diff options
Diffstat (limited to 'mlir/lib/Dialect')
17 files changed, 451 insertions, 127 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index f405d0c..c798adb 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -757,13 +757,13 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> { offset = numElements - 4l; } Type scaleSrcElemType = scaleSrcType.getElementType(); - auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}), - scaleSrcElemType); + auto newSrcType = + VectorType::get(ArrayRef{numElements}, scaleSrcElemType); Value newScaleSrc = vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc); auto extract = vector::ExtractStridedSliceOp::create( - rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset}, - ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1}); + rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size}, + ArrayRef{int64_t(1)}); rewriter.modifyOpInPlace(op, [&] { op->setOperand(opIdx, extract); setOpsel(opIdx, opsel); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 7e5ce26..749e2ba 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -125,9 +125,9 @@ static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest, // Use "unused attribute" marker to silence clang-tidy warning stemming from // the inability to see through "llvm::TypeSwitch". template <> -bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op, - Region *src, Region *dest, - const IRMapping &mapping) { +[[maybe_unused]] bool remainsLegalAfterInline(AffineApplyOp op, Region *src, + Region *dest, + const IRMapping &mapping) { // If it's a valid dimension, we need to check that it remains so. if (isValidDim(op.getResult(), src)) return remainsLegalAfterInline( @@ -1032,8 +1032,8 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map, /// Simplify the map while exploiting information on the values in `operands`. // Use "unused attribute" marker to silence warning stemming from the inability // to see through the template expansion. -static void LLVM_ATTRIBUTE_UNUSED -simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) { +[[maybe_unused]] static void simplifyMapWithOperands(AffineMap &map, + ArrayRef<Value> operands) { assert(map.getNumInputs() == operands.size() && "invalid operands for map"); SmallVector<AffineExpr> newResults; newResults.reserve(map.getNumResults()); @@ -1125,6 +1125,141 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, return success(*map != initialMap); } +/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form +/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`, +/// place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove) +/// into `replacementsMap`. If no entries were added to `replacementsMap`, +/// nothing was found. +static void shortenAddChainsContainingAll( + AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove, + AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) { + auto binOp = dyn_cast<AffineBinaryOpExpr>(e); + if (!binOp) + return; + AffineExpr lhs = binOp.getLHS(); + AffineExpr rhs = binOp.getRHS(); + if (binOp.getKind() != AffineExprKind::Add) { + shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap); + shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap); + return; + } + SmallVector<AffineExpr> toPreserve; + llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove); + AffineExpr thisTerm = rhs; + AffineExpr nextTerm = lhs; + + while (thisTerm) { + if (!ourTracker.erase(thisTerm)) { + toPreserve.push_back(thisTerm); + shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal, + replacementsMap); + } + auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm); + if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) { + thisTerm = nextTerm; + nextTerm = AffineExpr(); + } else { + thisTerm = nextBinOp.getRHS(); + nextTerm = nextBinOp.getLHS(); + } + } + if (!ourTracker.empty()) + return; + // We reverse the terms to be preserved here in order to preserve + // associativity between them. + AffineExpr newExpr = newVal; + for (AffineExpr preserved : llvm::reverse(toPreserve)) + newExpr = newExpr + preserved; + replacementsMap.insert({e, newExpr}); +} + +/// If this map contains of the expression `x_1 + x_1 * C_1 + ... x_n * C_N + +/// ...` (not necessarily in order) where the set of the `x_i` is the set of +/// outputs of an `affine.delinearize_index` whos inverse is that expression, +/// replace that expression with the input of that delinearize_index op. +/// +/// `unitDimInput` is the input that was detected as the potential start to this +/// replacement chain - if it isn't the rightmost result of the delinearization, +/// this method fails. (This is intended to ensure we don't have redundant scans +/// over the same expression). +/// +/// While this currently only handles delinearizations with a constant basis, +/// that isn't a fundamental limitation. +/// +/// This is a utility function for `replaceDimOrSym` below. +static LogicalResult replaceAffineDelinearizeIndexInverseExpression( + AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map, + SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) { + if (!delinOp.getDynamicBasis().empty()) + return failure(); + if (resultToReplace != delinOp.getMultiIndex().back()) + return failure(); + + MLIRContext *ctx = delinOp.getContext(); + SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr()); + for (auto [pos, dim] : llvm::enumerate(dims)) { + auto asResult = dyn_cast_if_present<OpResult>(dim); + if (!asResult) + continue; + if (asResult.getOwner() == delinOp.getOperation()) + resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx); + } + for (auto [pos, sym] : llvm::enumerate(syms)) { + auto asResult = dyn_cast_if_present<OpResult>(sym); + if (!asResult) + continue; + if (asResult.getOwner() == delinOp.getOperation()) + resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx); + } + if (llvm::is_contained(resToExpr, AffineExpr())) + return failure(); + + bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>); + int64_t stride = 1; + llvm::SmallDenseSet<AffineExpr, 4> expectedExprs; + // This isn't zip_equal since sometimes the delinearize basis is missing a + // size for the first result. + for (auto [binding, size] : llvm::zip( + llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) { + expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx)); + stride *= size; + } + if (resToExpr.size() != delinOp.getStaticBasis().size()) + expectedExprs.insert(resToExpr[0] * stride); + + DenseMap<AffineExpr, AffineExpr> replacements; + AffineExpr delinInExpr = isDimReplacement + ? getAffineDimExpr(dims.size(), ctx) + : getAffineSymbolExpr(syms.size(), ctx); + + for (AffineExpr e : map->getResults()) + shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements); + if (replacements.empty()) + return failure(); + + AffineMap origMap = *map; + if (isDimReplacement) + dims.push_back(delinOp.getLinearIndex()); + else + syms.push_back(delinOp.getLinearIndex()); + *map = origMap.replace(replacements, dims.size(), syms.size()); + + // Blank out dead dimensions and symbols + for (AffineExpr e : resToExpr) { + if (auto d = dyn_cast<AffineDimExpr>(e)) { + unsigned pos = d.getPosition(); + if (!map->isFunctionOfDim(pos)) + dims[pos] = nullptr; + } + if (auto s = dyn_cast<AffineSymbolExpr>(e)) { + unsigned pos = s.getPosition(); + if (!map->isFunctionOfSymbol(pos)) + syms[pos] = nullptr; + } + } + return success(); +} + /// Replace all occurrences of AffineExpr at position `pos` in `map` by the /// defining AffineApplyOp expression and operands. /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced. @@ -1157,6 +1292,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map, syms); } + if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) { + return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims, + syms); + } + auto affineApply = v.getDefiningOp<AffineApplyOp>(); if (!affineApply) return failure(); diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index cd216ef..4743941 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1357,7 +1357,7 @@ bool mlir::affine::isValidLoopInterchangePermutation( /// Returns true if `loops` is a perfectly nested loop nest, where loops appear /// in it from outermost to innermost. -bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] bool mlir::affine::isPerfectlyNested(ArrayRef<AffineForOp> loops) { assert(!loops.empty() && "no loops provided"); @@ -1920,8 +1920,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef, return copyNestRoot; } -static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED -emitRemarkForBlock(Block &block) { +[[maybe_unused]] static InFlightDiagnostic emitRemarkForBlock(Block &block) { return block.getParentOp()->emitRemark(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 01a16ce..ac35eea 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -134,10 +134,10 @@ static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams, /// These are unused for now. /// TODO: Move over to these once more types have been migrated to TypeDef. -LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +[[maybe_unused]] static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult -generatedTypePrinter(Type def, AsmPrinter &printer); +[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def, + AsmPrinter &printer); #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc" diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index d8f983f..6192d79 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3024,10 +3024,10 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } if (dynamicPointParseResult.has_value()) { - Type ChunkSizesType; + Type chunkSizesType; if (failed(*dynamicPointParseResult) || parser.parseComma() || - parser.parseType(ChunkSizesType) || - parser.resolveOperand(dynamicChunkSizes, ChunkSizesType, + parser.parseType(chunkSizesType) || + parser.resolveOperand(dynamicChunkSizes, chunkSizesType, result.operands)) { return failure(); } @@ -3399,9 +3399,9 @@ void transform::ContinuousTileSizesOp::getEffects( } static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, - Type targetType, Type tile_sizes, + Type targetType, Type tileSizes, Type) { - printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes}); + printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes}); } static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 642ced9..90cbbd8 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -40,6 +40,16 @@ static bool isScalarLikeType(Type type) { return type.isIntOrIndexOrFloat() || isa<ComplexType>(type); } +/// Helper function to attach the `VarName` attribute to an operation +/// if a variable name is provided. +static void attachVarNameAttr(Operation *op, OpBuilder &builder, + StringRef varName) { + if (!varName.empty()) { + auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName); + op->setAttr(acc::getVarNameAttrName(), varNameAttr); + } +} + struct MemRefPointerLikeModel : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, MemRefType> { @@ -83,7 +93,9 @@ struct MemRefPointerLikeModel // then we can generate an alloca operation. if (memrefTy.hasStaticShape()) { needsFree = false; // alloca doesn't need deallocation - return memref::AllocaOp::create(builder, loc, memrefTy).getResult(); + auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy); + attachVarNameAttr(allocaOp, builder, varName); + return allocaOp.getResult(); } // For dynamic memrefs, extract sizes from the original variable if @@ -103,8 +115,10 @@ struct MemRefPointerLikeModel // Static dimensions are handled automatically by AllocOp } needsFree = true; // alloc needs deallocation - return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes) - .getResult(); + auto allocOp = + memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes); + attachVarNameAttr(allocOp, builder, varName); + return allocOp.getResult(); } // TODO: Unranked not yet supported. diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index 135c033..645cbff 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -158,7 +158,7 @@ static FailureOr<GridOp> getGridAndVerify(Operation *op, } template <typename It> -bool isUnique(It begin, It end) { +static bool isUnique(It begin, It end) { if (begin == end) { return true; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp index a1711a6..069191c 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -143,8 +143,8 @@ void VarInfo::setNum(Var::Num n) { /// Helper function for `assertUsageConsistency` to better handle SMLoc /// mismatches. -LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc -minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) { +[[maybe_unused]] static llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1, + llvm::SMLoc sm2) { const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1)); assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`"); const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp index f539502..684c088 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp @@ -43,8 +43,8 @@ using namespace mlir::sparse_tensor; //===----------------------------------------------------------------------===// #ifndef NDEBUG -LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder, - Location loc, Value memref) { +[[maybe_unused]] static void dumpIndexMemRef(OpBuilder &builder, Location loc, + Value memref) { memref = memref::CastOp::create( builder, loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref); createFuncCall(builder, loc, "printMemrefInd", TypeRange{}, diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index 5aad671..1cba1bb 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "llvm/Support/FormatVariadic.h" namespace mlir { namespace tosa { @@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) { } TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { - return TargetEnvAttr::get(context, Level::eightK, + return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK, {Profile::pro_int, Profile::pro_fp}, {}); } @@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { return getDefaultTargetEnv(op->getContext()); } +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { + return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); +} + } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp index bcb880a..a0661e4 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -61,8 +61,8 @@ public: ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); - const auto targetEnvAttr = - TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + const auto targetEnvAttr = TargetEnvAttr::get( + ctx, specificationVersion, level, selectedProfiles, selectedExtensions); mod->setAttr(TargetEnvAttr::name, targetEnvAttr); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 20f9333..f072e3e 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { //===----------------------------------------------------------------------===// template <typename T> -FailureOr<SmallVector<T>> -TosaProfileCompliance::getOperatorDefinition(Operation *op, - CheckCondition &condition) { +FailureOr<OpComplianceInfo<T>> +TosaProfileCompliance::getOperatorDefinition(Operation *op) { const std::string opName = op->getName().getStringRef().str(); const auto complianceMap = getProfileComplianceMap<T>(); const auto it = complianceMap.find(opName); if (it == complianceMap.end()) return {}; - return findMatchedProfile<T>(op, it->second, condition); + return findMatchedEntry<T>(op, it->second); } template <typename T> @@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( if (specRequiredModeSet.size() == 0) return success(); - CheckCondition condition = CheckCondition::invalid; - const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition); - if (failed(maybeOpRequiredMode)) { + const auto maybeOpDefinition = getOperatorDefinition<T>(op); + if (failed(maybeOpDefinition)) { // Operators such as control-flow and shape ops do not have an operand type // restriction. When the profile compliance information of operation is not // found, confirm if the target have enabled the profile required from the // specification. - int mode_count = 0; + int modeCount = 0; for (const auto &cands : specRequiredModeSet) { if (targetEnv.allowsAnyOf(cands)) return success(); - mode_count += cands.size(); + modeCount += cands.size(); } op->emitOpError() << "illegal: requires" - << (mode_count > 1 ? " any of " : " ") << "[" + << (modeCount > 1 ? " any of " : " ") << "[" << llvm::join(stringifyProfile<T>(specRequiredModeSet), ", ") << "] but not enabled in target\n"; @@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( // Find the required profiles or extensions according to the operand type // combination. - const auto opRequiredMode = maybeOpRequiredMode.value(); + const auto opDefinition = maybeOpDefinition.value(); + const SmallVector<T> opRequiredMode = opDefinition.mode; + const CheckCondition condition = opDefinition.condition; + if (opRequiredMode.size() == 0) { // No matched restriction found. return success(); @@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( } } + // Ensure the matched op compliance version does not exceed the target + // specification version. + const VersionedTypeInfo versionedTypeInfo = + opDefinition.operandTypeInfoSet[0]; + const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second}; + const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()}; + if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) { + op->emitOpError() << "illegal: the target specification version (" + << stringifyVersion(targetVersion) + << ") is not backwards compatible with the op compliance " + "specification version (" + << stringifyVersion(complianceVersion) << ")\n"; + return failure(); + } + return success(); } @@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op, } LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { - CheckCondition condition = CheckCondition::invalid; - const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition); - const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition); + const auto maybeProfDef = getOperatorDefinition<Profile>(op); + const auto maybeExtDef = getOperatorDefinition<Extension>(op); if (failed(maybeProfDef) && failed(maybeExtDef)) return success(); - const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) || - (succeeded(maybeExtDef) && !maybeExtDef->empty()); + const bool hasEntry = + (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) || + (succeeded(maybeExtDef) && !maybeExtDef->mode.empty()); if (!hasEntry) { std::string message; llvm::raw_string_ostream os(message); @@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { SmallVector<TypeInfo> bestTypeInfo; const auto searchBestMatch = [&](auto map) { for (const auto &complianceInfos : map[opName]) { - for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) { + for (const auto &versionedTypeInfos : + complianceInfos.operandTypeInfoSet) { + const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first; const int matches = llvm::count_if( llvm::zip_equal(current, typeInfos), [&](const auto zipType) { return isSameTypeInfo(std::get<0>(zipType), @@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { // Find the profiles or extensions requirement according to the signature of // type of the operand list. template <typename T> -SmallVector<T> TosaProfileCompliance::findMatchedProfile( - Operation *op, SmallVector<OpComplianceInfo<T>> compInfo, - CheckCondition &condition) { +OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry( + Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) { assert(compInfo.size() != 0 && "profile-based compliance information is empty"); @@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile( return {}; for (size_t i = 0; i < compInfo.size(); i++) { - SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet; - for (SmallVector<TypeInfo> expected : sets) { + SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet; + for (const auto &set : sets) { + SmallVector<TypeInfo> expected = set.first; assert(present.size() == expected.size() && "the entries for profile-based compliance do not match between " "the generated metadata and the type definition retrieved from " " the operation"); - bool is_found = true; + bool isFound = true; // Compare the type signature between the given operation and the // compliance metadata. for (size_t j = 0; j < expected.size(); j++) { if (!isSameTypeInfo(present[j], expected[j])) { // Verify the next mode set from the list. - is_found = false; + isFound = false; break; } } - if (is_found == true) { - condition = compInfo[i].condition; - return compInfo[i].mode; + if (isFound == true) { + SmallVector<VersionedTypeInfo> typeInfoSet{set}; + OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet, + compInfo[i].condition}; + return info; } } } diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp index 9a24c2b..a2cff6a 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -21,10 +21,10 @@ using namespace mlir; // These are automatically generated by ODS but are not used as the Transform // dialect uses a different dispatch mechanism to support dialect extensions. -LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +[[maybe_unused]] static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult -generatedTypePrinter(Type def, AsmPrinter &printer); +[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def, + AsmPrinter &printer); #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc" diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 9beb22d..1599ae9 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -727,6 +727,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const { } printer << ">"; } +// a helper utility to perform binary operation on OpFoldResult. +// If both a and b are attributes, it will simply return the result. +// Otherwise, the corresponding arith op will be generated, and an +// contant op will be created if one of them is an attribute. +template <typename ArithOp> +OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, + OpBuilder &builder) { + auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a); + auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b); + return builder.create<ArithOp>(loc, aVal, bVal).getResult(); +} + +// a helper utility to perform division operation on OpFoldResult and int64_t. +#define div(a, b) \ + genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform reminder operation on OpFoldResult and int64_t. +#define rem(a, b) \ + genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform multiply operation on OpFoldResult and int64_t. +#define mul(a, b) \ + genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform addition operation on two OpFoldResult. +#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder) + +// block the given offsets according to the block shape +// say the original offset is [y, x], and the block shape is [By, Bx], +// then the blocked offset is [y/By, x/Bx, y%By, x%Bx] +SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc, + ArrayRef<OpFoldResult> offsets, + ArrayRef<int64_t> blockShape) { + + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + SmallVector<OpFoldResult> blockedOffsets; + SmallVector<OpFoldResult> divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); + + return blockedOffsets; +} + +// Get strides as vector of integer for MemDesc. +SmallVector<int64_t> MemDescType::getStrideShape() { + + SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end()); + + ArrayAttr strideAttr = getStrideAttr(); + SmallVector<int64_t> strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast<IntegerAttr>(attr).getInt()); + } + + SmallVector<int64_t> innerBlkShape = getBlockShape(); + + // get perm from FCD to LCD + // perm[i] = the dim with i-th smallest stride + SmallVector<int, 4> perm = + llvm::to_vector<4>(llvm::seq<int>(0, strides.size())); + llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); + + assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); + + SmallVector<int64_t> innerBlkStride(innerBlkShape.size()); + innerBlkStride[perm[0]] = 1; + for (size_t i = 1; i < perm.size(); ++i) + innerBlkStride[perm[i]] = + innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]]; + + // compute the original matrix shape using the stride info + // and compute the number of blocks in each dimension + // The shape of highest dim can't be derived from stride info, + // but doesn't impact the stride computation for blocked layout. + SmallVector<int64_t> matrixShapeOrig(matrixShape.size()); + SmallVector<int64_t> BlkShapeOrig(matrixShape.size()); + for (size_t i = 0; i < perm.size() - 1; ++i) { + matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]]; + BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]]; + } + + int64_t innerBlkSize = 1; + for (auto s : innerBlkShape) + innerBlkSize *= s; + + SmallVector<int64_t> outerBlkStride(matrixShape.size()); + outerBlkStride[perm[0]] = innerBlkSize; + for (size_t i = 0; i < perm.size() - 1; ++i) { + outerBlkStride[perm[i + 1]] = + outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]]; + } + + // combine the inner and outer strides + SmallVector<int64_t> blockedStrides; + blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end()); + blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end()); + + return blockedStrides; +} + +// Calculate the linear offset using the blocked offsets and stride +Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, + ArrayRef<OpFoldResult> offsets) { + + SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end()); + SmallVector<int64_t> blockShape = getBlockShape(); + SmallVector<int64_t> strides = getStrideShape(); + SmallVector<OpFoldResult> blockedOffsets; + + // blockshape equal to matrixshape means no blocking + if (llvm::equal(blockShape, matrixShape)) { + // remove the outer dims from strides + strides.erase(strides.begin(), strides.begin() + matrixShape.size()); + } else { + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + // say the original offset is [y, x], and the block shape is [By, Bx], + // then the blocked offset is [y/By, x/Bx, y%By, x%Bx] + + SmallVector<OpFoldResult> divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); + offsets = blockedOffsets; + } + + // Start with initial value as matrix descriptor's base offset. + Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0); + for (size_t i = 0; i < offsets.size(); ++i) { + OpFoldResult mulResult = mul(offsets[i], strides[i]); + Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult); + linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); + } + + return linearOffset; +} } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 81b5788..abd12e2 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -20,8 +20,8 @@ #define DEBUG_TYPE "xegpu" -namespace mlir { -namespace xegpu { +using namespace mlir; +using namespace mlir::xegpu; static bool isSharedMemory(const MemRefType &memrefTy) { Attribute attr = memrefTy.getMemorySpace(); @@ -173,6 +173,49 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, return success(); } +LogicalResult +IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, + UnitAttr subgroup_block_io, + function_ref<InFlightDiagnostic()> emitError) { + + if (!dataTy) { + if (subgroup_block_io) + return emitError() << "subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + else + return success(); + } + + if (mdescTy.getRank() != 2) + return emitError() << "mem_desc must be 2D."; + + ArrayRef<int64_t> dataShape = dataTy.getShape(); + ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + + if (dataShape.size() == 2) { + if (subgroup_block_io) + return emitError() << "subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitError() << "data shape must not exceed mem_desc shape."; + } else { + SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); + // if the subgroup_block_io attribute is set, mdescTy must have block + // attribute + if (subgroup_block_io && !blockShape.size()) + return emitError() << "mem_desc must have block attribute when " + "subgroup_block_io is set."; + // if the subgroup_block_io attribute is set, the memdesc should be row + // major + if (subgroup_block_io && mdescTy.isColMajor()) + return emitError() << "mem_desc should be row major when " + "subgroup_block_io is set."; + } + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -1049,23 +1092,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, llvm::SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + // Call the generated builder with all parameters (including optional ones as + // nullptr/empty) build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*subgroup_block_io=*/nullptr, layout); } LogicalResult LoadMatrixOp::verify() { - VectorType resTy = getRes().getType(); - MemDescType mdescTy = getMemDesc().getType(); - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); + auto resTy = dyn_cast<VectorType>(getRes().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MemDescType mdescTy = getMemDesc().getType(); - ArrayRef<int64_t> valueShape = resTy.getShape(); - ArrayRef<int64_t> mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed mem_desc shape."); - return success(); + return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1080,62 +1120,18 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*subgroup_block_io=*/nullptr, layout); } LogicalResult StoreMatrixOp::verify() { - VectorType dataTy = getData().getType(); - MemDescType mdescTy = getMemDesc().getType(); - - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); - - ArrayRef<int64_t> dataShape = dataTy.getShape(); - ArrayRef<int64_t> mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("data shape must not exceed mem_desc shape."); - - return success(); -} - -//===----------------------------------------------------------------------===// -// XeGPU_MemDescSubviewOp -//===----------------------------------------------------------------------===// - -void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state, - Type resTy, Value src, - llvm::ArrayRef<OpFoldResult> offsets) { - llvm::SmallVector<Value> dynamicOffsets; - llvm::SmallVector<int64_t> staticOffsets; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); - auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); - build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr); -} - -LogicalResult MemDescSubviewOp::verify() { - MemDescType srcTy = getSrc().getType(); - MemDescType resTy = getRes().getType(); - ArrayRef<int64_t> srcShape = srcTy.getShape(); - ArrayRef<int64_t> resShape = resTy.getShape(); - - if (srcTy.getRank() < resTy.getRank()) - return emitOpError("result rank must not exceed source rank."); - if (llvm::any_of( - llvm::zip_equal(resShape, srcShape.take_back(resShape.size())), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed source shape."); - - if (srcTy.getStrides() != resTy.getStrides()) - return emitOpError("result must inherit the source strides."); - - return success(); + auto dataTy = dyn_cast<VectorType>(getData().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MemDescType mdescTy = getMemDesc().getType(); + return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } -} // namespace xegpu -} // namespace mlir - namespace mlir { #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc> } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index a178d0f..aafa1b7 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -941,7 +941,9 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> { LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - VectorType valueTy = op.getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType()); + assert(valueTy && "the value type must be vector type!"); + std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); if (!targetShape || targetShape->size() != (size_t)valueTy.getRank()) return failure(); @@ -984,7 +986,8 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> { return failure(); Location loc = op.getLoc(); - VectorType valueTy = op.getData().getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType()); + assert(valueTy && "the value type must be vector type!"); ArrayRef<int64_t> shape = valueTy.getShape(); auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index c28d2fc..31a967d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -991,7 +991,8 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> { return failure(); ArrayRef<int64_t> wgShape = op.getDataShape(); - VectorType valueTy = op.getRes().getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType()); + assert(valueTy && "the value type must be vector type!"); Type elemTy = valueTy.getElementType(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); |