Diffstat (limited to 'mlir/lib/Dialect')
146 files changed, 7228 insertions, 2184 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 9a0a230..11a40d6 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -511,6 +511,18 @@ LogicalResult DPPOp::verify() { } //===----------------------------------------------------------------------===// +// PermlaneSwapOp +//===----------------------------------------------------------------------===// +LogicalResult PermlaneSwapOp::verify() { + unsigned rowLength = getRowLength(); + + if (rowLength != 16 && rowLength != 32) + return emitOpError("row_length attribute must either be 16 or 32."); + + return success(); +} + +//===----------------------------------------------------------------------===// // GatherToLDSOp //===----------------------------------------------------------------------===// @@ -518,8 +530,8 @@ LogicalResult GatherToLDSOp::verify() { MemRefType srcType = cast<MemRefType>(getSrc().getType()); MemRefType dstType = cast<MemRefType>(getDst().getType()); - if (!dstType.areTrailingDimsContiguous(dstType.getRank())) - return emitOpError("destination types must be contiguous"); + if (!dstType.areTrailingDimsContiguous(1)) + return emitOpError("destination type inner most dim must be contiguous"); auto elemType = srcType.getElementType(); // Check $src and $dst element types are the same. diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt index 729e3da..d35853b 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt @@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms ResolveStridedMetadata.cpp ADDITIONAL_HEADER_DIRS - {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms DEPENDS MLIRAMDGPUTransformsIncGen diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp index a3fdc7e..d547510 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp @@ -28,62 +28,79 @@ struct AmdgpuFoldMemRefOpsPass final } }; +static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc, + Value view, mlir::OperandRange indices, + SmallVectorImpl<Value> &resolvedIndices, + Value &memrefBase, StringRef role) { + Operation *defOp = view.getDefiningOp(); + if (!defOp) { + return failure(); + } + return llvm::TypeSwitch<Operation *, LogicalResult>(defOp) + .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) { + mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides( + rewriter, loc, subviewOp.getMixedOffsets(), + subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices, + resolvedIndices); + memrefBase = subviewOp.getSource(); + return success(); + }) + .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) { + if (failed(mlir::memref::resolveSourceIndicesExpandShape( + loc, rewriter, expandShapeOp, indices, resolvedIndices, + false))) { + return failure(); + } + memrefBase = expandShapeOp.getViewSource(); + return success(); + }) + .Case<memref::CollapseShapeOp>( + [&](memref::CollapseShapeOp collapseShapeOp) { + if (failed(mlir::memref::resolveSourceIndicesCollapseShape( + loc, rewriter, collapseShapeOp, indices, + resolvedIndices))) { + return failure(); + } + memrefBase = collapseShapeOp.getViewSource(); + return success(); + }) + .Default([&](Operation *op) { + 
return rewriter.notifyMatchFailure( + op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or " + "CollapseShapeOp") + .str()); + }); +} + struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GatherToLDSOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value memrefSource; - SmallVector<Value> sourceIndices; - auto foldResult = - llvm::TypeSwitch<Operation *, LogicalResult>( - op.getSrc().getDefiningOp()) - .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) { - // If the source is a SubViewOp, we can directly rewrite the - // GatherToLDSOp. - mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides( - rewriter, loc, subviewOp.getMixedOffsets(), - subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), - op.getSrcIndices(), sourceIndices); - memrefSource = subviewOp.getSource(); - return success(); - }) - .Case<memref::ExpandShapeOp>( - [&](memref::ExpandShapeOp expandShapeOp) { - if (failed(mlir::memref::resolveSourceIndicesExpandShape( - loc, rewriter, expandShapeOp, op.getSrcIndices(), - sourceIndices, false))) { - return failure(); - } - memrefSource = expandShapeOp.getViewSource(); - return success(); - }) - .Case<memref::CollapseShapeOp>( - [&](memref::CollapseShapeOp collapseShapeOp) { - if (failed(mlir::memref::resolveSourceIndicesCollapseShape( - loc, rewriter, collapseShapeOp, op.getSrcIndices(), - sourceIndices))) { - return failure(); - } - memrefSource = collapseShapeOp.getViewSource(); - return success(); - }) - .Default([&](Operation *op) { - // If the source is not a SubViewOp, ExpandShapeOp, or - // CollapseShapeOp, we cannot fold the GatherToLDSOp. - return rewriter.notifyMatchFailure( - op, - "source producer is not one of SubViewOp, ExpandShapeOp, or " - "CollapseShapeOp"); - }); + SmallVector<Value> sourceIndices, destIndices; + Value memrefSource, memrefDest; + + auto foldSrcResult = + foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(), + sourceIndices, memrefSource, "source"); + + if (failed(foldSrcResult)) { + memrefSource = op.getSrc(); + sourceIndices = op.getSrcIndices(); + } + + auto foldDstResult = + foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(), + destIndices, memrefDest, "destination"); - if (failed(foldResult)) { - return failure(); + if (failed(foldDstResult)) { + memrefDest = op.getDst(); + destIndices = op.getDstIndices(); } rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices, - op.getDst(), op.getDstIndices(), + memrefDest, destIndices, op.getTransferType()); return success(); diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 6f3110c..68990ef 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) { if (parser.parseGreater()) return nullptr; - return TileType::get(shape, elementType); + return TileType::getChecked( + [&] { return parser.emitError(parser.getNameLoc()); }, shape, + elementType); } void amx::TileType::print(AsmPrinter &os) const { diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp index 86edc2b..b405ec2 100644 --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -93,13 +93,13 @@ FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) { 
int64_t lb = forOp.getConstantLowerBound(); dividend[pos] = 1; dividend.back() -= lb; - addLocalFloorDiv(dividend, step); + unsigned qPos = addLocalFloorDiv(dividend, step); // Second constraint: (iv - lb) - step * q = 0. SmallVector<int64_t, 8> eq(getNumCols(), 0); eq[pos] = 1; eq.back() -= lb; // For the local var just added above. - eq[getNumCols() - 2] = -step; + eq[qPos] = -step; addEquality(eq); } } diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp index 2f85e0b..166d39e 100644 --- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp @@ -21,6 +21,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <numeric> #include <optional> @@ -548,19 +549,19 @@ bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) { // Check whether there is any negative direction vector in the // dependence components found above, which means that dependence is // violated by the default hyper-rect tiling method. - LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated " - "for dependence at depth: " - << Twine(d) << " between:\n";); - LLVM_DEBUG(srcAccess.opInst->dump()); - LLVM_DEBUG(dstAccess.opInst->dump()); + LDBG() << "Checking whether tiling legality violated " + << "for dependence at depth: " << Twine(d) << " between:" + << OpWithFlags(srcAccess.opInst, OpPrintingFlags().skipRegions()) + << "\nand:\n" + << OpWithFlags(dstAccess.opInst, + OpPrintingFlags().skipRegions()); for (const DependenceComponent &depComp : depComps) { if (depComp.lb.has_value() && depComp.ub.has_value() && *depComp.lb < *depComp.ub && *depComp.ub < 0) { - LLVM_DEBUG(llvm::dbgs() - << "Dependence component lb = " << Twine(*depComp.lb) - << " ub = " << Twine(*depComp.ub) - << " is negative at depth: " << Twine(d) - << " and thus violates the legality rule.\n"); + LDBG() << "Dependence component lb = " << Twine(*depComp.lb) + << " ub = " << Twine(*depComp.ub) + << " is negative at depth: " << Twine(d) + << " and thus violates the legality rule."; return false; } } diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp index a89c1ae..99ea20b 100644 --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/IntegerSet.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <optional> @@ -241,7 +242,7 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg, } bool MemRefDependenceGraph::init() { - LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n"); + LDBG() << "--- Initializing MDG ---"; // Map from a memref to the set of ids of the nodes that have ops accessing // the memref. DenseMap<Value, SetVector<unsigned>> memrefAccesses; @@ -288,8 +289,7 @@ bool MemRefDependenceGraph::init() { // Return false if non-handled/unknown region-holding ops are found. We // won't know what such ops do or what its regions mean; for e.g., it may // not be an imperative op. - LLVM_DEBUG(llvm::dbgs() - << "MDG init failed; unknown region-holding op found!\n"); + LDBG() << "MDG init failed; unknown region-holding op found!"; return false; } // We aren't creating nodes for memory-effect free ops either with no @@ -297,7 +297,7 @@ bool MemRefDependenceGraph::init() { // interface. 
} - LLVM_DEBUG(llvm::dbgs() << "Created " << nodes.size() << " nodes\n"); + LDBG() << "Created " << nodes.size() << " nodes"; // Add dependence edges between nodes which produce SSA values and their // users. Load ops can be considered as the ones producing SSA values. @@ -556,9 +556,8 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId, gatherDefiningNodes(dstId, definingNodes); if (llvm::any_of(definingNodes, [&](unsigned id) { return hasDependencePath(srcId, id); })) { - LLVM_DEBUG(llvm::dbgs() - << "Can't fuse: a defining op with a user in the dst " - "loop has dependence from the src loop\n"); + LDBG() << "Can't fuse: a defining op with a user in the dst " + << "loop has dependence from the src loop"; return nullptr; } @@ -957,20 +956,20 @@ std::optional<bool> ComputationSliceState::isSliceValid() const { FlatAffineValueConstraints srcConstraints; // TODO: Store the source's domain to avoid computation at each depth. if (failed(getSourceAsConstraints(srcConstraints))) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n"); + LDBG() << "Unable to compute source's domain"; return std::nullopt; } // As the set difference utility currently cannot handle symbols in its // operands, validity of the slice cannot be determined. if (srcConstraints.getNumSymbolVars() > 0) { - LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n"); + LDBG() << "Cannot handle symbols in source domain"; return std::nullopt; } // TODO: Handle local vars in the source domains while using the 'projectOut' // utility below. Currently, aligning is not done assuming that there will be // no local vars in the source domain. if (srcConstraints.getNumLocalVars() != 0) { - LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n"); + LDBG() << "Cannot handle locals in source domain"; return std::nullopt; } @@ -978,7 +977,7 @@ std::optional<bool> ComputationSliceState::isSliceValid() const { // fusion succeeds. FlatAffineValueConstraints sliceConstraints; if (failed(getAsConstraints(&sliceConstraints))) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n"); + LDBG() << "Unable to compute slice's domain"; return std::nullopt; } @@ -987,11 +986,11 @@ std::optional<bool> ComputationSliceState::isSliceValid() const { sliceConstraints.projectOut(ivs.size(), sliceConstraints.getNumVars() - ivs.size()); - LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n"); - LLVM_DEBUG(srcConstraints.dump()); - LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds " - "(expressed in terms of its source's IVs):\n"); - LLVM_DEBUG(sliceConstraints.dump()); + LDBG() << "Domain of the source of the slice:\n" + << "Source constraints:" << srcConstraints + << "\nDomain of the slice if this fusion succeeds " + << "(expressed in terms of its source's IVs):\n" + << "Slice constraints:" << sliceConstraints; // TODO: Store 'srcSet' to avoid recalculating for each depth. 
PresburgerSet srcSet(srcConstraints); @@ -999,7 +998,7 @@ std::optional<bool> ComputationSliceState::isSliceValid() const { PresburgerSet diffSet = sliceSet.subtract(srcSet); if (!diffSet.isIntegerEmpty()) { - LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n"); + LDBG() << "Incorrect slice"; return false; } return true; @@ -1172,8 +1171,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, unsigned rank = access.getRank(); - LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op - << "\ndepth: " << loopDepth << "\n";); + LDBG() << "MemRefRegion::compute: " << *op << " depth: " << loopDepth; // 0-d memrefs. if (rank == 0) { @@ -1236,7 +1234,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, if (auto constVal = getConstantIntValue(symbol)) cst.addBound(BoundType::EQ, symbol, constVal.value()); } else { - LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value"); + LDBG() << "unknown affine dimensional value"; return failure(); } } @@ -1260,7 +1258,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, // Add access function equalities to connect loop IVs to data dimensions. if (failed(cst.composeMap(&accessValueMap))) { op->emitError("getMemRefRegion: compose affine map failed"); - LLVM_DEBUG(accessValueMap.getAffineMap().dump()); + LDBG() << "Access map: " << accessValueMap.getAffineMap(); return failure(); } @@ -1317,8 +1315,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, } cst.removeTrivialRedundancy(); - LLVM_DEBUG(llvm::dbgs() << "Memory region:\n"); - LLVM_DEBUG(cst.dump()); + LDBG() << "Memory region: " << cst; return success(); } @@ -1346,14 +1343,14 @@ std::optional<int64_t> MemRefRegion::getRegionSize() { auto memRefType = cast<MemRefType>(memref.getType()); if (!memRefType.getLayout().isIdentity()) { - LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); + LDBG() << "Non-identity layout map not yet supported"; return false; } // Compute the extents of the buffer. std::optional<int64_t> numElements = getConstantBoundingSizeAndShape(); if (!numElements) { - LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n"); + LDBG() << "Dynamic shapes not yet supported"; return std::nullopt; } auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType); @@ -1397,8 +1394,7 @@ LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp, /*addMemRefDimBounds=*/false))) return success(); - LLVM_DEBUG(llvm::dbgs() << "Memory region"); - LLVM_DEBUG(region.getConstraints()->dump()); + LDBG() << "Memory region: " << region.getConstraints(); bool outOfBounds = false; unsigned rank = loadOrStoreOp.getMemRefType().getRank(); @@ -1558,7 +1554,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, // Check if 'loopDepth' exceeds nesting depth of src/dst ops. 
if ((!isBackwardSlice && loopDepth > getNestingDepth(a)) || (isBackwardSlice && loopDepth > getNestingDepth(b))) { - LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n"); + LDBG() << "Invalid loop depth"; return SliceComputationResult::GenericFailure; } @@ -1571,7 +1567,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, &dependenceConstraints, /*dependenceComponents=*/nullptr, /*allowRAR=*/readReadAccesses); if (result.value == DependenceResult::Failure) { - LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n"); + LDBG() << "Dependence check failed"; return SliceComputationResult::GenericFailure; } if (result.value == DependenceResult::NoDependence) @@ -1586,8 +1582,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, if (sliceUnionCst.getNumDimAndSymbolVars() == 0) { // Initialize 'sliceUnionCst' with the bounds computed in previous step. if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute slice bound constraints\n"); + LDBG() << "Unable to compute slice bound constraints"; return SliceComputationResult::GenericFailure; } assert(sliceUnionCst.getNumDimAndSymbolVars() > 0); @@ -1597,8 +1592,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'. FlatAffineValueConstraints tmpSliceCst; if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute slice bound constraints\n"); + LDBG() << "Unable to compute slice bound constraints"; return SliceComputationResult::GenericFailure; } @@ -1630,8 +1624,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, if (sliceUnionCst.getNumLocalVars() > 0 || tmpSliceCst.getNumLocalVars() > 0 || failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute union bounding box of slice bounds\n"); + LDBG() << "Unable to compute union bounding box of slice bounds"; return SliceComputationResult::GenericFailure; } } @@ -1639,7 +1632,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, // Empty union. if (sliceUnionCst.getNumDimAndSymbolVars() == 0) { - LLVM_DEBUG(llvm::dbgs() << "empty slice union - unexpected\n"); + LDBG() << "empty slice union - unexpected"; return SliceComputationResult::GenericFailure; } @@ -1652,7 +1645,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, unsigned innermostCommonLoopDepth = getInnermostCommonLoopDepth(ops, &surroundingLoops); if (loopDepth > innermostCommonLoopDepth) { - LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n"); + LDBG() << "Exceeds max loop depth"; return SliceComputationResult::GenericFailure; } @@ -1696,7 +1689,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, // that the slice is valid, otherwise return appropriate failure status. 
std::optional<bool> isSliceValid = sliceUnion->isSliceValid(); if (!isSliceValid) { - LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n"); + LDBG() << "Cannot determine if the slice is valid"; return SliceComputationResult::GenericFailure; } if (!*isSliceValid) @@ -2050,7 +2043,8 @@ static std::optional<int64_t> getMemoryFootprintBytes(Block &block, if (failed( region->compute(opInst, /*loopDepth=*/getNestingDepth(&*block.begin())))) { - LLVM_DEBUG(opInst->emitError("error obtaining memory region")); + LDBG() << "Error obtaining memory region"; + opInst->emitError("error obtaining memory region"); return failure(); } @@ -2058,9 +2052,11 @@ static std::optional<int64_t> getMemoryFootprintBytes(Block &block, if (inserted) { it->second = std::move(region); } else if (failed(it->second->unionBoundingBox(*region))) { - LLVM_DEBUG(opInst->emitWarning( + LDBG() << "getMemoryFootprintBytes: unable to perform a union on a " + "memory region"; + opInst->emitWarning( "getMemoryFootprintBytes: unable to perform a union on a memory " - "region")); + "region"); return failure(); } return WalkResult::advance(); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 22608a1..7e5ce26 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -427,6 +427,21 @@ bool mlir::affine::isValidSymbol(Value value) { return false; } +/// A utility function to check if a value is defined at the top level of +/// `region` or is an argument of `region` or is defined above the region. +static bool isTopLevelValueOrAbove(Value value, Region *region) { + Region *parentRegion = value.getParentRegion(); + do { + if (parentRegion == region) + return true; + Operation *regionOp = region->getParentOp(); + if (regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) + break; + region = region->getParentOp()->getParentRegion(); + } while (region); + return false; +} + /// A value can be used as a symbol for `region` iff it meets one of the /// following conditions: /// *) It is a constant. @@ -445,19 +460,12 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) { return false; // A top-level value is a valid symbol. - if (region && ::isTopLevelValue(value, region)) + if (region && isTopLevelValueOrAbove(value, region)) return true; auto *defOp = value.getDefiningOp(); - if (!defOp) { - // A block argument that is not a top-level value is a valid symbol if it - // dominates region's parent op. - Operation *regionOp = region ? region->getParentOp() : nullptr; - if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) - if (auto *parentOpRegion = region->getParentOp()->getParentRegion()) - return isValidSymbol(value, parentOpRegion); + if (!defOp) return false; - } // Constant operation is ok. Attribute operandCst; @@ -475,12 +483,6 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) { if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp)) return isDimOpValidSymbol(dimOp, region); - // Check for values dominating `region`'s parent op. - Operation *regionOp = region ? 
region->getParentOp() : nullptr; - if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) - if (auto *parentRegion = region->getParentOp()->getParentRegion()) - return isValidSymbol(value, parentRegion); - return false; } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 6c9adff..ff0157e 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <iomanip> #include <optional> @@ -95,8 +96,8 @@ static bool canRemoveSrcNodeAfterFusion( // Otherwise, the src loop can't be removed. if (fusedLoopInsPoint != depNodeOp && !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) { - LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't " - "dominate dependence\n"); + LDBG() << "Src loop can't be removed: dst loop doesn't " + << "dominate dependence"; return false; } @@ -109,14 +110,13 @@ static bool canRemoveSrcNodeAfterFusion( if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) { std::optional<bool> isMaximal = fusionSlice.isMaximal(); if (!isMaximal) { - LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine " - "if fusion is maximal\n"); + LDBG() << "Src loop can't be removed: can't determine " + << "if fusion is maximal"; return false; } if (!*isMaximal) { - LLVM_DEBUG(llvm::dbgs() - << "Src loop can't be removed: fusion is not maximal\n"); + LDBG() << "Src loop can't be removed: fusion is not maximal"; return false; } } @@ -190,7 +190,8 @@ static bool isEscapingMemref(Value memref, Block *block) { // Check if this is defined to be an alias of another memref. if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp)) - if (isEscapingMemref(viewOp.getViewSource(), block)) + if (memref == viewOp.getViewDest() && + isEscapingMemref(viewOp.getViewSource(), block)) return true; // Any op besides allocating ops wouldn't guarantee alias freedom @@ -279,19 +280,19 @@ static std::optional<double> getAdditionalComputeFraction( AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth, ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost, int64_t &fusedLoopNestComputeCost) { - LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";); + LDBG() << "Determining additional compute fraction..."; // Compute cost of sliced and unsliced src loop nest. // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) { - LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n"); + LDBG() << "Failed to get source loop nest stats."; return std::nullopt; } // Compute cost of dst loop nest. LoopNestStats dstLoopNestStats; if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) { - LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n"); + LDBG() << "Failed to get destination loop nest stats."; return std::nullopt; } @@ -304,14 +305,14 @@ static std::optional<double> getAdditionalComputeFraction( const ComputationSliceState &slice = depthSliceUnions[depth - 1]; // Skip slice union if it wasn't computed for this depth. 
if (slice.isEmpty()) { - LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n"); + LDBG() << "Slice wasn't computed."; return std::nullopt; } if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp, dstLoopNestStats, slice, &fusedLoopNestComputeCost)) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n"); + LDBG() << "Unable to compute fusion compute cost"; return std::nullopt; } @@ -348,9 +349,8 @@ static Value createPrivateMemRef(AffineForOp forOp, MemRefAccess bM(cast<AffineWriteOpInterface>(b)); return aM == bM; })) { - LLVM_DEBUG(llvm::dbgs() - << "Private memref creation unsupported for multiple producer " - "stores with different access functions.\n"); + LDBG() << "Private memref creation unsupported for multiple producer " + << "stores with different access functions."; return nullptr; } @@ -455,8 +455,7 @@ static Value createPrivateMemRef(AffineForOp forOp, assert(succeeded(res) && "replaceAllMemrefUsesWith should always succeed here"); (void)res; - LLVM_DEBUG(llvm::dbgs() << "Created private memref of type: " << newMemRefType - << '\n'); + LDBG() << "Created private memref of type: " << newMemRefType; return newMemRef; } @@ -505,15 +504,12 @@ static bool isFusionProfitable(AffineForOp srcForOp, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold) { - LLVM_DEBUG({ - llvm::dbgs() - << "Checking whether fusion is profitable between source nest:\n"; - llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n"; - llvm::dbgs() << dstForOp << "\n"; - }); + LDBG() << "Checking whether fusion is profitable between source nest:"; + LDBG() << ' ' << srcForOp << " and destination nest:"; + LDBG() << dstForOp; if (maxLegalFusionDepth == 0) { - LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n"); + LDBG() << "Can't fuse: maxLegalFusionDepth is 0"; return false; } @@ -537,8 +533,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, // TODO: Suppport multiple producer stores in profitability // analysis. if (producerStores.size() > 1) { - LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not " - "supported for multiple producer store case.\n"); + LDBG() << "Limited profitability analysis. Not " + << "supported for multiple producer store case."; int64_t sliceCost; int64_t fusedLoopNestComputeCost; // We will still fuse if fusion obeys the specified compute @@ -547,12 +543,11 @@ static bool isFusionProfitable(AffineForOp srcForOp, srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost, fusedLoopNestComputeCost); if (!fraction || fraction > computeToleranceThreshold) { - LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds " - "compute tolerance. Not fusing.\n"); + LDBG() << "Additional computation exceeds " + << "compute tolerance. Not fusing."; return false; } - LLVM_DEBUG(llvm::dbgs() - << "Considering fusion profitable at max legal depth.\n"); + LDBG() << "Considering fusion profitable at max legal depth."; return true; } @@ -574,8 +569,7 @@ static bool isFusionProfitable(AffineForOp srcForOp, // Compute src loop nest write region size. 
MemRefRegion srcWriteRegion(srcStoreOp->getLoc()); if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute MemRefRegion for source operation\n"); + LDBG() << "Unable to compute MemRefRegion for source operation"; return false; } @@ -609,8 +603,7 @@ static bool isFusionProfitable(AffineForOp srcForOp, getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions, sliceCost, fusedLoopNestComputeCost); if (!mayAdditionalComputeFraction) { - LLVM_DEBUG(llvm::dbgs() - << "Can't determine additional compute fraction.\n"); + LDBG() << "Can't determine additional compute fraction."; continue; } double additionalComputeFraction = *mayAdditionalComputeFraction; @@ -620,9 +613,7 @@ static bool isFusionProfitable(AffineForOp srcForOp, // depth 'i'. MemRefRegion sliceWriteRegion(srcStoreOp->getLoc()); if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to compute slice write region at loopDepth: " << i - << "\n"); + LDBG() << "Failed to compute slice write region at loopDepth: " << i; continue; } @@ -630,9 +621,7 @@ static bool isFusionProfitable(AffineForOp srcForOp, sliceWriteRegion.getRegionSize(); if (!maybeSliceWriteRegionSizeBytes.has_value() || *maybeSliceWriteRegionSizeBytes == 0) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to get slice write region size at loopDepth: " << i - << "\n"); + LDBG() << "Failed to get slice write region size at loopDepth: " << i; continue; } int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes; @@ -649,9 +638,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, << " storage reduction factor: " << storageReduction << "x\n" << " fused nest cost: " << fusedLoopNestComputeCost << "\n" << " src write region size: " << srcWriteRegionSizeBytes << "\n" - << " slice write region size: " << sliceWriteRegionSizeBytes - << "\n"; - llvm::dbgs() << msg.str(); + << " slice write region size: " << sliceWriteRegionSizeBytes; + LDBG() << msg.str(); }); // TODO: This is a placeholder cost model. @@ -670,28 +658,24 @@ static bool isFusionProfitable(AffineForOp srcForOp, // A simple cost model: fuse if it reduces the memory footprint. if (!bestDstLoopDepth) { - LLVM_DEBUG( - llvm::dbgs() - << "All fusion choices involve more than the threshold amount of " - "redundant computation; NOT fusing.\n"); + LDBG() << "All fusion choices involve more than the threshold amount of " + << "redundant computation; NOT fusing."; return false; } if (!bestDstLoopDepth) { - LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n"); + LDBG() << "no fusion depth could be evaluated."; return false; } // Set dstLoopDepth based on best values from search. 
*dstLoopDepth = *bestDstLoopDepth; - LLVM_DEBUG( - llvm::dbgs() << " LoopFusion fusion stats:" - << "\n best loop depth: " << bestDstLoopDepth - << "\n src loop nest compute cost: " << srcLoopNestCost - << "\n dst loop nest compute cost: " << dstLoopNestCost - << "\n fused loop nest compute cost: " - << minFusedLoopNestComputeCost << "\n"); + LDBG() << " LoopFusion fusion stats:"; + LDBG() << " best loop depth: " << bestDstLoopDepth; + LDBG() << " src loop nest compute cost: " << srcLoopNestCost; + LDBG() << " dst loop nest compute cost: " << dstLoopNestCost; + LDBG() << " fused loop nest compute cost: " << minFusedLoopNestComputeCost; auto dstMemSize = getMemoryFootprintBytes(dstForOp); auto srcMemSize = getMemoryFootprintBytes(srcForOp); @@ -699,8 +683,7 @@ static bool isFusionProfitable(AffineForOp srcForOp, std::optional<double> storageReduction; if (!dstMemSize || !srcMemSize) { - LLVM_DEBUG(llvm::dbgs() - << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); + LDBG() << " fusion memory benefit cannot be evaluated; NOT fusing."; return false; } @@ -710,13 +693,13 @@ static bool isFusionProfitable(AffineForOp srcForOp, assert(sliceMemEstimate && "expected value"); auto fusedMem = dstMemSizeVal + *sliceMemEstimate; - LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" - << " dst mem: " << dstMemSizeVal << "\n" - << " fused mem: " << fusedMem << "\n" - << " slice mem: " << sliceMemEstimate << "\n"); + LDBG() << " src mem: " << srcMemSizeVal; + LDBG() << " dst mem: " << dstMemSizeVal; + LDBG() << " fused mem: " << fusedMem; + LDBG() << " slice mem: " << sliceMemEstimate; if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) { - LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); + LDBG() << "Fusion is not profitable; NOT fusing."; return false; } storageReduction = @@ -734,8 +717,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, << std::setprecision(2) << additionalComputeFraction << "% redundant computation and a "; msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>"); - msg << "% storage reduction.\n"; - llvm::dbgs() << msg.str(); + msg << "% storage reduction."; + LDBG() << msg.str(); }); return true; @@ -895,7 +878,7 @@ public: /// No fusion is performed when producers with a user count greater than /// `maxSrcUserCount` for any of the memrefs involved. void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) { - LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); + LDBG() << "Evaluating dst loop " << dstId; // Skip if this node was removed (fused into another node). if (mdg->nodes.count(dstId) == 0) return; @@ -909,7 +892,7 @@ public: if (dstNode->op->getNumResults() > 0) return; - LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); + LDBG() << "Evaluating dst loop " << dstId; // Sink sequential loops in 'dstNode' (and thus raise parallel loops) // while preserving relative order. 
This can increase the maximum loop @@ -936,18 +919,14 @@ public: auto *srcNode = mdg->getNode(srcId); auto srcAffineForOp = cast<AffineForOp>(srcNode->op); - LLVM_DEBUG(llvm::dbgs() - << "Trying to fuse producer loop nest " << srcId - << " with consumer loop nest " << dstId << "\n"); - LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: " - << computeToleranceThreshold << '\n'); - LLVM_DEBUG(llvm::dbgs() - << "Producer loop nest:\n" - << *srcNode->op << "\n and consumer loop nest:\n" - << *dstNode->op << '\n'); + LDBG() << "Trying to fuse producer loop nest " << srcId + << " with consumer loop nest " << dstId; + LDBG() << "Compute tolerance threshold: " << computeToleranceThreshold; + LDBG() << "Producer loop nest:"; + LDBG() << *srcNode->op << " and consumer loop nest:"; + LDBG() << *dstNode->op; - LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId - << " for dst loop " << dstId << "\n"); + LDBG() << "Evaluating src loop " << srcId << " for dst loop " << dstId; // Skip if 'srcNode' is a loop nest returning values. // TODO: support loop nests that return values. @@ -1018,19 +997,16 @@ public: &depthSliceUnions[i - 1], strategy); if (result.value == FusionResult::Success) { maxLegalFusionDepth = i; - LLVM_DEBUG(llvm::dbgs() - << "Found valid slice for depth: " << i << '\n'); + LDBG() << "Found valid slice for depth: " << i; } } if (maxLegalFusionDepth == 0) { - LLVM_DEBUG(llvm::dbgs() - << "Can't fuse: fusion is not legal at any depth\n"); + LDBG() << "Can't fuse: fusion is not legal at any depth"; continue; } - LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: " - << maxLegalFusionDepth << '\n'); + LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth; double computeToleranceThresholdToUse = computeToleranceThreshold; @@ -1040,7 +1016,7 @@ public: // producer-consumer memref access for example). Check this and allow // fusion accordingly. if (hasCyclicDependence(srcAffineForOp)) { - LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n"); + LDBG() << "Source nest has a cyclic dependence."; // Maximal fusion does not check for compute tolerance threshold; so // perform the maximal fusion only when the redundanation computation // is zero. @@ -1053,18 +1029,15 @@ public: srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost, fusedLoopNestComputeCost); if (!fraction || fraction > 0) { - LLVM_DEBUG( - llvm::dbgs() - << "Can't perform maximal fusion with a cyclic dependence " - "and non-zero additional compute.\n"); + LDBG() << "Can't perform maximal fusion with a cyclic dependence " + << "and non-zero additional compute."; return; } } else { // Set redundant computation tolerance to zero regardless of what // the user specified. Without this, fusion would be invalid. - LLVM_DEBUG(llvm::dbgs() - << "Setting compute tolerance to zero since " - "source has a cylic dependence.\n"); + LDBG() << "Setting compute tolerance to zero since " + << "source has a cylic dependence."; computeToleranceThresholdToUse = 0; } } @@ -1107,8 +1080,7 @@ public: if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId, removeSrcNode)) { // Create a private version of this memref. - LLVM_DEBUG(llvm::dbgs() - << "Creating private memref for " << memref << '\n'); + LDBG() << "Creating private memref for " << memref; // Create a private version of this memref. 
privateMemrefs.insert(memref); } @@ -1118,10 +1090,9 @@ public: fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice); dstNodeChanged = true; - LLVM_DEBUG(llvm::dbgs() - << "Fused src loop " << srcId << " into dst loop " << dstId - << " at depth " << bestDstLoopDepth << ":\n" - << dstAffineForOp << "\n"); + LDBG() << "Fused src loop " << srcId << " into dst loop " << dstId + << " at depth " << bestDstLoopDepth << ":"; + LDBG() << dstAffineForOp; // Move 'dstAffineForOp' before 'insertPointInst' if needed. if (fusedLoopInsPoint != dstAffineForOp) @@ -1179,8 +1150,7 @@ public: dstLoopCollector.memrefFrees); if (removeSrcNode) { - LLVM_DEBUG(llvm::dbgs() - << "Removing src loop " << srcId << " after fusion\n"); + LDBG() << "Removing src loop " << srcId << " after fusion"; // srcNode is no longer valid after it is removed from mdg. srcAffineForOp.erase(); mdg->removeNode(srcId); @@ -1195,7 +1165,7 @@ public: /// user count greater than `maxSrcUserCount` for any of the memrefs involved /// are encountered. void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { - LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n"); + LDBG() << "--- Producer/Consumer Fusion ---"; init(); while (!worklist.empty()) { unsigned dstId = worklist.back(); @@ -1207,7 +1177,7 @@ public: // Visits each node in the graph, and for each node, attempts to fuse it with // its sibling nodes (nodes which share a parent, but no dependence edges). void fuseSiblingNodes() { - LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n"); + LDBG() << "--- Sibling Fusion ---"; init(); while (!worklist.empty()) { unsigned dstId = worklist.back(); @@ -1289,8 +1259,7 @@ public: maxLegalFusionDepth = i; } - LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: " - << maxLegalFusionDepth << '\n'); + LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth; // Skip if fusion is not feasible at any loop depths. if (maxLegalFusionDepth == 0) @@ -1304,7 +1273,7 @@ public: // producer-consumer memref access for example). Check this and allow // fusion accordingly. if (hasCyclicDependence(sibAffineForOp)) { - LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n"); + LDBG() << "Source nest has a cyclic dependence."; // Maximal fusion does not check for compute tolerance threshold; so // perform the maximal fusion only when the redundanation computation is // zero. @@ -1316,17 +1285,15 @@ public: sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost, fusedLoopNestComputeCost); if (!fraction || fraction > 0) { - LLVM_DEBUG( - llvm::dbgs() - << "Can't perform maximal fusion with a cyclic dependence " - "and non-zero additional compute.\n"); + LDBG() << "Can't perform maximal fusion with a cyclic dependence " + << "and non-zero additional compute."; return; } } else { // Set redundant computation tolerance to zero regardless of what the // user specified. Without this, fusion would be invalid. - LLVM_DEBUG(llvm::dbgs() << "Setting compute tolerance to zero since " - "source has a cyclic dependence.\n"); + LDBG() << "Setting compute tolerance to zero since " + << "source has a cyclic dependence."; computeToleranceThresholdToUse = 0.0; } } @@ -1356,8 +1323,7 @@ public: // slice is used in the destination. 
auto isMaximal = bestSlice.isMaximal(); if (!isMaximal.value_or(false)) { - LLVM_DEBUG(llvm::dbgs() - << "Slice isn't maximal; not performing sibling fusion.\n"); + LDBG() << "Slice isn't maximal; not performing sibling fusion."; continue; } @@ -1374,10 +1340,9 @@ public: if (insertPointInst != dstForInst) dstForInst->moveBefore(insertPointInst); - LLVM_DEBUG(llvm::dbgs() - << "Fused sibling nest " << sibId << " into destination nest " - << dstNode->id << " at depth " << bestDstLoopDepth << ":\n" - << dstAffineForOp << "\n"); + LDBG() << "Fused sibling nest " << sibId << " into destination nest " + << dstNode->id << " at depth " << bestDstLoopDepth << ":"; + LDBG() << dstAffineForOp; // Update data dependence graph state post fusion. updateStateAfterSiblingFusion(sibNode, dstNode); @@ -1555,7 +1520,7 @@ public: void LoopFusion::runOnBlock(Block *block) { MemRefDependenceGraph g(*block); if (!g.init()) { - LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n"); + LDBG() << "MDG init failed"; return; } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp index 41cd739..c6abb0d 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <optional> @@ -251,20 +252,20 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp, FusionStrategy fusionStrategy) { // Return 'failure' if 'dstLoopDepth == 0'. if (dstLoopDepth == 0) { - LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n"); + LDBG() << "Cannot fuse loop nests at depth 0"; return FusionResult::FailPrecondition; } // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block. auto *block = srcForOp->getBlock(); if (block != dstForOp->getBlock()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n"); + LDBG() << "Cannot fuse loop nests in different blocks"; return FusionResult::FailPrecondition; } // Return 'failure' if no valid insertion point for fused loop nest in 'block' // exists which would preserve dependences. if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) { - LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n"); + LDBG() << "Fusion would violate dependences in block"; return FusionResult::FailBlockDependence; } @@ -277,14 +278,14 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp, // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'. SmallVector<Operation *, 4> opsA; if (!gatherLoadsAndStores(forOpA, opsA)) { - LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n"); + LDBG() << "Fusing loops with affine.if unsupported"; return FusionResult::FailPrecondition; } // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'. SmallVector<Operation *, 4> opsB; if (!gatherLoadsAndStores(forOpB, opsB)) { - LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n"); + LDBG() << "Fusing loops with affine.if unsupported"; return FusionResult::FailPrecondition; } @@ -296,7 +297,7 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp, // TODO: 'getMaxLoopDepth' does not support forward slice fusion. 
assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion"); if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) { - LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n"); + LDBG() << "Fusion would violate loop dependences"; return FusionResult::FailFusionDependence; } } @@ -339,12 +340,12 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp, strategyOpsA, opsB, dstLoopDepth, numCommonLoops, isSrcForOpBeforeDstForOp, srcSlice); if (sliceComputationResult.value == SliceComputationResult::GenericFailure) { - LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); + LDBG() << "computeSliceUnion failed"; return FusionResult::FailPrecondition; } if (sliceComputationResult.value == SliceComputationResult::IncorrectSliceFailure) { - LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n"); + LDBG() << "Incorrect slice computation"; return FusionResult::FailIncorrectSlice; } @@ -477,7 +478,7 @@ bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot, auto *parentForOp = forOp->getParentOp(); if (forOp != forOpRoot) { if (!isa<AffineForOp>(parentForOp)) { - LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n"); + LDBG() << "Expected parent AffineForOp"; return WalkResult::interrupt(); } // Add mapping to 'forOp' from its parent AffineForOp. @@ -498,7 +499,7 @@ bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot, std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); if (!maybeConstTripCount) { // Currently only constant trip count loop nests are supported. - LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n"); + LDBG() << "Non-constant trip count unsupported"; return WalkResult::interrupt(); } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 2de057d..cd216ef 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -21,9 +21,11 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/MapVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <optional> @@ -365,12 +367,11 @@ checkIfHyperRectangular(MutableArrayRef<AffineForOp> input) { if (input.size() <= 1) return success(); if (failed(getIndexSet(ops, &cst))) { - LLVM_DEBUG(llvm::dbgs() << "Index set computation failed!\n"); + LDBG() << "Index set computation failed!"; return failure(); } if (!cst.isHyperRectangular(0, input.size())) { - LLVM_DEBUG(llvm::dbgs() - << "Non-hyperrectangular nests not supported for tiling!\n"); + LDBG() << "Non-hyperrectangular nests not supported for tiling!"; return failure(); } return success(); @@ -385,14 +386,13 @@ static LogicalResult performPreTilingChecks(MutableArrayRef<AffineForOp> input, if (llvm::any_of(input, [](AffineForOp op) { return op.getNumResults() > 0; })) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot tile nest where a loop has yield values\n"); + LDBG() << "Cannot tile nest where a loop has yield values"; return failure(); } // Check if the supplied `for` ops are all successively nested. 
if (!isPerfectlyNested(input)) { - LLVM_DEBUG(llvm::dbgs() << "input loops not perfectly nested"); + LDBG() << "input loops not perfectly nested"; return failure(); } @@ -1098,7 +1098,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp, // If the trip count is lower than the unroll jam factor, no unroll jam. if (mayBeConstantTripCount && *mayBeConstantTripCount < unrollJamFactor) { - LLVM_DEBUG(llvm::dbgs() << "[failed] trip count < unroll-jam factor\n"); + LDBG() << "[failed] trip count < unroll-jam factor"; return failure(); } @@ -1339,6 +1339,15 @@ bool mlir::affine::isValidLoopInterchangePermutation( unsigned maxLoopDepth = loops.size(); if (maxLoopDepth == 1) return true; + + // We cannot guarantee the validity of the interchange if the loops have + // iter_args, since the dependence analysis does not take them into account. + // Conservatively return false in such cases. + if (llvm::any_of(loops, [](AffineForOp loop) { + return loop.getNumIterOperands() > 0; + })) + return false; + // Gather dependence components for dependences between all ops in loop nest // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth]. std::vector<SmallVector<DependenceComponent, 2>> depCompsVec; @@ -1766,9 +1775,7 @@ findHighestBlockForPlacement(const MemRefRegion ®ion, Block &block, // We can't hoist past the definition of the memref being copied. Value memref = region.memref; if (!memref.getParentRegion()->isAncestor(enclosingOp->getParentRegion())) { - LLVM_DEBUG( - llvm::dbgs() - << "memref definition will end up not dominating hoist location\n"); + LDBG() << "memref definition will end up not dominating hoist location"; break; } @@ -1977,7 +1984,7 @@ static LogicalResult generateCopy( auto memRefType = cast<MemRefType>(memref.getType()); if (!memRefType.getLayout().isIdentity()) { - LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); + LDBG() << "Non-identity layout map not yet supported"; return failure(); } @@ -1989,7 +1996,7 @@ static LogicalResult generateCopy( unsigned rank = memRefType.getRank(); if (rank == 0) { - LLVM_DEBUG(llvm::dbgs() << "Non-zero ranked memrefs supported\n"); + LDBG() << "Non-zero ranked memrefs supported"; return failure(); } @@ -2001,19 +2008,18 @@ static LogicalResult generateCopy( std::optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(&fastBufferShape, &lbs); if (!numElements) { - LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n"); + LDBG() << "Non-constant region size not supported"; return failure(); } if (llvm::any_of(lbs, [](AffineMap lb) { return lb.getNumResults() > 1; })) { // This can be supported in the future if needed. 
- LLVM_DEBUG(llvm::dbgs() - << "Max lower bound for memref region start not supported\n"); + LDBG() << "Max lower bound for memref region start not supported"; return failure(); } if (*numElements == 0) { - LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n"); + LDBG() << "Nothing to copy"; return success(); } @@ -2021,9 +2027,8 @@ static LogicalResult generateCopy( for (unsigned i = 0; i < rank; ++i) { region.getLowerAndUpperBound(i, lbMaps[i], ubMaps[i]); if (lbMaps[i].getNumResults() == 0 || ubMaps[i].getNumResults() == 0) { - LLVM_DEBUG(llvm::dbgs() - << "Missing lower or upper bound for region along dimension: " - << i << '\n'); + LDBG() << "Missing lower or upper bound for region along dimension: " + << i; return failure(); } } @@ -2122,7 +2127,7 @@ static LogicalResult generateCopy( // TODO: use all stride levels once DmaStartOp is extended for // multi-level strides. if (dmaStrideInfos.size() > 1) { - LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n"); + LDBG() << "Only up to one level of stride supported"; return failure(); } @@ -2309,10 +2314,11 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, // surrounding the this block range. unsigned copyDepth = getNestingDepth(&*begin); - LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth - << "\n"); - LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n"); - LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n"); + LDBG() << "Generating copies at depth " << copyDepth; + LDBG() << "from begin: " + << OpWithFlags(&*begin, OpPrintingFlags().skipRegions()); + LDBG() << "to inclusive end: " + << OpWithFlags(&*std::prev(end), OpPrintingFlags().skipRegions()); // List of memory regions to copy for. We need a map vector to have a // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here @@ -2349,8 +2355,8 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, return; if (!memref.getParentRegion()->isAncestor(block->getParent())) { - LLVM_DEBUG(llvm::dbgs() << "memref definition is inside of the depth at " - "which copy-in/copy-out would happen\n"); + LDBG() << "memref definition is inside of the depth at " + << "which copy-in/copy-out would happen"; return; } @@ -2358,12 +2364,10 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, auto region = std::make_unique<MemRefRegion>(opInst->getLoc()); if (failed(region->compute(opInst, copyDepth, /*sliceState=*/nullptr, /*addMemRefDimBounds=*/false))) { - LLVM_DEBUG(llvm::dbgs() - << "Error obtaining memory region: semi-affine maps?\n"); - LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n"); + LDBG() << "Error obtaining memory region: semi-affine maps?"; + LDBG() << "over-approximating to the entire memref"; if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) { - LLVM_DEBUG( - opInst->emitError("non-constant memref sizes not yet supported")); + LDBG() << "non-constant memref sizes not yet supported"; error = true; return; } @@ -2392,13 +2396,11 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, // Perform a union with the existing region. if (failed(it->second->unionBoundingBox(*region))) { - LLVM_DEBUG(llvm::dbgs() - << "Memory region bounding box failed; " - "over-approximating to the entire memref\n"); + LDBG() << "Memory region bounding box failed; " + << "over-approximating to the entire memref"; // If the union fails, we will overapproximate. 
if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) { - LLVM_DEBUG(opInst->emitError( - "non-constant memref sizes not yet supported")); + LDBG() << "non-constant memref sizes not yet supported"; error = true; return true; } @@ -2428,8 +2430,7 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, }); if (error) { - LLVM_DEBUG(begin->emitError( - "copy generation failed for one or more memref's in this block\n")); + LDBG() << "copy generation failed for one or more memref's in this block"; return failure(); } @@ -2466,8 +2467,7 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, processRegions(writeRegions); if (!ret) { - LLVM_DEBUG(begin->emitError( - "copy generation failed for one or more memref's in this block\n")); + LDBG() << "copy generation failed for one or more memref's in this block"; return failure(); } @@ -2608,7 +2608,7 @@ static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops, /*boundFloorDivisor=*/nullptr, /*ub=*/nullptr, &fullTileLbPos, &fullTileUbPos)) { - LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n"); + LDBG() << "Can't get constant diff pair for a loop"; return nullptr; } @@ -2667,8 +2667,7 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest, for (auto loop : inputNest) { // TODO: straightforward to generalize to a non-unit stride. if (loop.getStepAsInt() != 1) { - LLVM_DEBUG(llvm::dbgs() - << "[tile separation] non-unit stride not implemented\n"); + LDBG() << "[tile separation] non-unit stride not implemented"; return failure(); } SmallVector<Operation *, 1> loopOp{loop.getOperation()}; @@ -2682,8 +2681,8 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest, /*boundFloorDivisor=*/nullptr, /*ub=*/nullptr, &lbPos, &ubPos) || lbPos == ubPos) { - LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / " - "equalities not yet handled\n"); + LDBG() << "[tile separation] Can't get constant diff / " + << "equalities not yet handled"; return failure(); } @@ -2741,8 +2740,8 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest, AffineIfOp ifOp = createSeparationCondition(inputNest, b); if (!ifOp) { fullTileLoops.front().erase(); - LLVM_DEBUG(llvm::dbgs() << "All tiles are full tiles, or failure creating " - "separation condition\n"); + LDBG() << "All tiles are full tiles, or failure creating " + << "separation condition"; return failure(); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 488c3c3..7d4d818 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, case AtomicRMWKind::addi: case AtomicRMWKind::maxu: case AtomicRMWKind::ori: + case AtomicRMWKind::xori: return builder.getZeroAttr(resultType); case AtomicRMWKind::andi: return builder.getIntegerAttr( @@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) { // Integer operations. 
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; }) .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; }) - .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; }) + .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; }) .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; }) .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; }) .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; }) @@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, return arith::OrIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::andi: return arith::AndIOp::create(builder, loc, lhs, rhs); + case AtomicRMWKind::xori: + return arith::XOrIOp::create(builder, loc, lhs, rhs); // TODO: Add remaining reduction operations. default: (void)emitOptionalError(loc, "Reduction operation type not supported"); diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 93682a9..4780dbb 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -12,7 +12,7 @@ add_mlir_dialect_library(MLIRArithTransforms UnsignedWhenEquivalent.cpp ADDITIONAL_HEADER_DIRS - {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms DEPENDS MLIRArithTransformsIncGen diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp index 1aa8064..35365f2 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp @@ -158,13 +158,11 @@ protected: PatternRewriter &rewriter) { // Check iterator types for matrix multiplication. SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray(); - if (!((itTypes.size() == 3 && - (itTypes[0] == vector::IteratorType::parallel && - itTypes[1] == vector::IteratorType::parallel && - itTypes[2] == vector::IteratorType::reduction)) || - (itTypes.size() == 2 && - (itTypes[0] == vector::IteratorType::parallel && - itTypes[1] == vector::IteratorType::reduction)))) + if ((itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || + itTypes[1] != vector::IteratorType::parallel || + itTypes[2] != vector::IteratorType::reduction) && + (itTypes.size() != 2 || itTypes[0] != vector::IteratorType::parallel || + itTypes[1] != vector::IteratorType::reduction)) return rewriter.notifyMatchFailure( op, "iterator types do not correspond to matrix multiplication"); diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp index 35b0bd1..6cb2a56 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp @@ -183,9 +183,9 @@ protected: Value acc; // Conventional names for matrix dimensions. - int64_t M = 0; - int64_t N = 0; - int64_t K = 0; + int64_t m = 0; + int64_t n = 0; + int64_t k = 0; // Create the matrix mulitply and accumulate operation according to // `mmlaOp`. @@ -286,41 +286,41 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, // Single-dimension vector type for the entire RHS tile. 
- auto flatRhsTileType = VectorType::get(/*shape=*/K * N, operandEltType, + auto flatRhsTileType = VectorType::get(/*shape=*/k * n, operandEltType, /*scalableDims=*/{true}); // Vector type having the same number of elements as a row in the // accumulator/output tile and the same element type. - auto accRowTy = VectorType::get(/*shape=*/N, resultEltType, + auto accRowTy = VectorType::get(/*shape=*/n, resultEltType, /*scalableDims=*/{true}); // Vector type having twice the number of elements as a row in the // accumulator/output tile the same element type. - auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, resultEltType, + auto accRowX2Ty = VectorType::get(/*shape=*/2 * n, resultEltType, /*scalableDims=*/{true}); // Vector type having half the number of elements as a row in the // accumulator/output tile and an integer element type with twice the bit // width. - auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(), + auto accRow64Ty = VectorType::get(/*shape=*/n / 2, rewriter.getI64Type(), /*scalableDims=*/{true}); // Vector type having the same the number of elements as a row in the // accumulator/output tile and an integer element type with twice the bit // width. - auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(), + auto accRowX264Ty = VectorType::get(/*shape=*/n, rewriter.getI64Type(), /*scalableDims=*/{true}); Location loc = op.getLoc(); // Extract LHS sub-tiles with logical shape <2xK>. SmallVector<Value> lhsTile; - for (int64_t i = 0; i < M; i += 2) { + for (int64_t i = 0; i < m; i += 2) { // Extract two consecutive rows of the LHS tile. auto r0 = vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i}); auto r1 = vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1}); // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile. - SmallVector<int64_t> shuffleIdx(2 * K); + SmallVector<int64_t> shuffleIdx(2 * k); std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0); auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx); // Turn it into a scalable vector. @@ -337,13 +337,13 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, // Extract the RHS sub-tiles with logical shape <Kx[2]>. SmallVector<Value> rhsTile; - for (int64_t j = 0; j < N; j += 2) + for (int64_t j = 0; j < n; j += 2) rhsTile.push_back(vector::ScalableExtractOp::create( - rewriter, loc, flatRhsType, rhs, j * K)); + rewriter, loc, flatRhsType, rhs, j * k)); // Extract and pack the ACC sub-tiles. SmallVector<Value> accTile; - for (int64_t i = 0; i < M; i += 2) { + for (int64_t i = 0; i < m; i += 2) { // Extract two consecutive rows of the accumulator tile. auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(), ArrayRef<int64_t>{i}); @@ -370,28 +370,28 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64); } // Extract ACC sub-tiles. - for (int64_t j = 0; j < N; j += 2) + for (int64_t j = 0; j < n; j += 2) accTile.push_back(vector::ScalableExtractOp::create( rewriter, loc, flatAccType, accTileVec, j * 2)); } // Emit sub-tile matrix multiplications. 
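A small worked example of the sub-tile bookkeeping used below, with m and n chosen purely for illustration: the accumulator splits into (m / 2) * (n / 2) two-row sub-tiles, and sub-tile (i, j) is stored at index i * n / 2 + j.

constexpr int m = 4, n = 4; // illustrative sizes, not taken from this patch
static_assert((m / 2) * (n / 2) == 4, "four 2x2 sub-tiles in total");
static_assert(1 * n / 2 + 1 == 3, "sub-tile (1, 1) is the last entry");
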
SmallVector<Value> outTile; - for (int64_t i = 0; i < M / 2; ++i) - for (int64_t j = 0; j < N / 2; ++j) { - Value mmla = createMMLA(rewriter, loc, accTile[i * N / 2 + j], lhsTile[i], + for (int64_t i = 0; i < m / 2; ++i) + for (int64_t j = 0; j < n / 2; ++j) { + Value mmla = createMMLA(rewriter, loc, accTile[i * n / 2 + j], lhsTile[i], rhsTile[j]); outTile.push_back(mmla); } // Unpack the OUT sub-tiles and insert into the result. Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType()); - for (int64_t i = 0; i < M / 2; ++i) { + for (int64_t i = 0; i < m / 2; ++i) { // Collect a number of sub-tiles in a row. Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty); - for (int64_t j = 0; j < N / 2; ++j) + for (int64_t j = 0; j < n / 2; ++j) row = vector::ScalableInsertOp::create( - rewriter, loc, outTile[i * N / 2 + j], row, j * 4); + rewriter, loc, outTile[i * n / 2 + j], row, j * 4); // Unpack the row to obtain two rows of the output. If we have the out // sub-tiles transposed we obtain two consecutive output rows by @@ -432,9 +432,9 @@ public: VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); - M = lhsType.getDimSize(0); - N = rhsType.getDimSize(0); - K = rhsType.getDimSize(1); + m = lhsType.getDimSize(0); + n = rhsType.getDimSize(0); + k = rhsType.getDimSize(1); // Check the operands have the expected shape: // * for LHS: fixed vector MxK @@ -442,8 +442,8 @@ public: // * K == 8 // * M and N even and at least 2 if (lhsType.isScalable() || !rhsType.getScalableDims()[0] || - rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 || - M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 || + rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 8 || + m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 || !rhsType.getScalableDims()[0]) return rewriter.notifyMatchFailure(op, "non-matching operand shape"); @@ -504,9 +504,9 @@ public: VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); - M = lhsType.getDimSize(0); - N = rhsType.getDimSize(0); - K = rhsType.getDimSize(1); + m = lhsType.getDimSize(0); + n = rhsType.getDimSize(0); + k = rhsType.getDimSize(1); // Check the operands have the expected shape: // * for LHS: fixed vector MxK @@ -514,8 +514,8 @@ public: // * K == 4 // * M and N even and at least 2 if (lhsType.isScalable() || !rhsType.getScalableDims()[0] || - rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 4 || - M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 || + rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 4 || + m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 || !rhsType.getScalableDims()[0]) return rewriter.notifyMatchFailure(op, "non-matching operand shape"); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp index ddc64ea..91e37dd 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp @@ -248,7 +248,7 @@ LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { Region *definingRegion = value.getParentRegion(); // Last users of the `value` inside all blocks where the value dies. 
- llvm::SmallSet<Operation *, 4> lastUsers; + llvm::SmallPtrSet<Operation *, 4> lastUsers; // Find blocks in the `definingRegion` that have users of the `value` (if // there are multiple users in the block, which one will be selected is diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 7eb729f..56ff212 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -463,8 +463,12 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> { // which otherwise could prevent removal of unnecessary allocs. Value canonicalSource = source; while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>( - canonicalSource.getDefiningOp())) + canonicalSource.getDefiningOp())) { + if (canonicalSource != iface.getViewDest()) { + break; + } canonicalSource = iface.getViewSource(); + } std::optional<Operation *> maybeCloneDeallocOp = memref::findDealloc(cloneOp.getOutput()); @@ -806,14 +810,12 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> { if (!srcTensorType) return failure(); auto currentOutputMemRefType = - dyn_cast<MemRefType>(toBuffer.getResult().getType()); + dyn_cast<BaseMemRefType>(toBuffer.getResult().getType()); if (!currentOutputMemRefType) return failure(); - auto memrefType = MemRefType::get(srcTensorType.getShape(), - srcTensorType.getElementType(), - currentOutputMemRefType.getLayout(), - currentOutputMemRefType.getMemorySpace()); + auto memrefType = currentOutputMemRefType.cloneWith( + srcTensorType.getShape(), srcTensorType.getElementType()); Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType, tensorCastOperand.getOperand(), toBuffer.getReadOnly()); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp index 8916526..a465c95 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -37,8 +37,12 @@ using namespace mlir::bufferization; /// Given a memref value, return the "base" value by skipping over all /// ViewLikeOpInterface ops (if any) in the reverse use-def chain. static Value getViewBase(Value value) { - while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) + while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) { + if (value != viewLikeOp.getViewDest()) { + break; + } value = viewLikeOp.getViewSource(); + } return value; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp index 8f983ab..0b2e080 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp @@ -121,7 +121,7 @@ void BufferViewFlowAnalysis::build(Operation *op) { // Add additional dependencies created by view changes to the alias list. if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) { registerDependencies(viewInterface.getViewSource(), - viewInterface->getResult(0)); + viewInterface.getViewDest()); return WalkResult::advance(); } @@ -231,8 +231,12 @@ static bool isFunctionArgument(Value v) { /// Given a memref value, return the "base" value by skipping over all /// ViewLikeOpInterface ops (if any) in the reverse use-def chain. 
static Value getViewBase(Value value) { - while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) + while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) { + if (value != viewLikeOp.getViewDest()) { + break; + } value = viewLikeOp.getViewSource(); + } return value; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 91f6f25..68ef519 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -20,6 +20,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/PassManager.h" +#include "llvm/Support/DebugLog.h" #include <optional> namespace mlir { @@ -328,20 +329,16 @@ LogicalResult bufferization::bufferizeOp(Operation *op, "blocks"); // Bufferize the op. - LLVM_DEBUG(llvm::dbgs() - << "//===-------------------------------------------===//\n" - << "IR after bufferizing: " << nextOp->getName() << "\n"); + LDBG(3) << "//===-------------------------------------------===//\n" + << "IR after bufferizing: " << nextOp->getName(); rewriter.setInsertionPoint(nextOp); if (failed( bufferizableOp.bufferize(rewriter, options, bufferizationState))) { - LLVM_DEBUG(llvm::dbgs() - << "failed to bufferize\n" - << "//===-------------------------------------------===//\n"); + LDBG(2) << "failed to bufferize\n" + << "//===-------------------------------------------===//"; return nextOp->emitError("failed to bufferize op"); } - LLVM_DEBUG(llvm::dbgs() - << *op - << "\n//===-------------------------------------------===//\n"); + LDBG(3) << *op << "\n//===-------------------------------------------===//"; } // Return early if the top-level op is entirely gone. 
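The getViewBase-style helpers updated in the three bufferization files above all gain the same guard; a condensed sketch of the intended walk (the function name here is illustrative, the interface calls are the ones used in this patch):

static Value walkToViewBase(Value value) {
  while (auto viewOp = value.getDefiningOp<ViewLikeOpInterface>()) {
    // Only step through the op if `value` really is the view it produces;
    // other results of a view-like op must not be chased to the view source.
    if (value != viewOp.getViewDest())
      break;
    value = viewOp.getViewSource();
  }
  return value;
}
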
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index a8e8353..fb7f2bb 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -56,6 +56,7 @@ #include "mlir/Interfaces/SubsetOpInterface.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Support/DebugLog.h" MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState) @@ -616,13 +617,10 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, if (getParallelRegion(def.getParentRegion(), options) != getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(), options)) { - LLVM_DEBUG( - llvm::dbgs() - << "\n- bufferizes out-of-place due to parallel region:\n"); - LLVM_DEBUG(llvm::dbgs() - << " unConflictingWrite = operand " - << uConflictingWrite->getOperandNumber() << " of " - << *uConflictingWrite->getOwner() << "\n"); + LDBG() << "\n- bufferizes out-of-place due to parallel region:\n" + << " unConflictingWrite = operand " + << uConflictingWrite->getOperandNumber() << " of " + << *uConflictingWrite->getOwner(); return true; } } @@ -631,9 +629,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, for (OpOperand *uRead : usesRead) { Operation *readingOp = uRead->getOwner(); - LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n"); - LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber() - << " of " << *readingOp << "\n"); + LDBG() << "\n- check conflict:\n" + << " uRead = operand " << uRead->getOperandNumber() << " of " + << *readingOp; // Find the definition of uRead by following the SSA use-def chain. // E.g.: @@ -648,23 +646,22 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, const SetVector<Value> &definitions = state.findDefinitionsCached(uRead); if (definitions.empty()) { // Fast path: No conflict if there are no definitions. - LLVM_DEBUG(llvm::dbgs() - << " no conflict: read value has no definitions\n"); + LDBG() << " no conflict: read value has no definitions"; continue; } // Look for conflicting memory writes. Potential conflicts are writes to an // alias that have been decided to bufferize inplace. for (OpOperand *uConflictingWrite : usesWrite) { - LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand " - << uConflictingWrite->getOperandNumber() << " of " - << *uConflictingWrite->getOwner() << "\n"); + LDBG() << " unConflictingWrite = operand " + << uConflictingWrite->getOperandNumber() << " of " + << *uConflictingWrite->getOwner(); // Check if op dominance can be used to rule out read-after-write // conflicts. bool useDominance = canUseOpDominance(uRead, uConflictingWrite, definitions, state); - LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n"); + LDBG() << "\n- useDominance = " << useDominance; // Throughout this loop, check for multiple requirements that have to be // met for uConflictingWrite to be an actual conflict. @@ -680,8 +677,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // inside a loop), there may be no meaningful `happensBefore` // relationship. 
if (happensBefore(readingOp, conflictingWritingOp, domInfo)) { - LLVM_DEBUG(llvm::dbgs() - << " no conflict: read happens before write\n"); + LDBG() << " no conflict: read happens before write"; continue; } @@ -693,8 +689,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // Note: If the op is executed multiple times (e.g., because it is // inside a loop), it may be conflicting with itself. if (uConflictingWrite == uRead) { - LLVM_DEBUG(llvm::dbgs() - << " no conflict: read and write are same use\n"); + LDBG() << " no conflict: read and write are same use"; continue; } @@ -705,8 +700,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // multiple times. if (state.insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) { - LLVM_DEBUG(llvm::dbgs() << " no conflict: read and write are in " - "mutually exclusive regions\n"); + LDBG() << " no conflict: read and write are in " + "mutually exclusive regions"; continue; } @@ -721,9 +716,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, state, uRead, uConflictingWrite->get()) || hasEquivalentValueInReverseUseDefChain( state, uConflictingWrite, uRead->get())) { - LLVM_DEBUG( - llvm::dbgs() - << " no conflict: op bufferizes to element-wise access\n"); + LDBG() << " no conflict: op bufferizes to element-wise access"; continue; } } @@ -733,15 +726,14 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // No conflict if the operands are non-conflicting subsets. if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) { - LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n"); + LDBG() << " no conflict: non-conflicting subsets"; continue; } // No conflict if the op interface says so. if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) { if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) { - LLVM_DEBUG(llvm::dbgs() - << " no conflict: op interace of reading op says 'no'\n"); + LDBG() << " no conflict: op interace of reading op says 'no'"; continue; } } @@ -751,9 +743,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, options.dynCastBufferizableOp(conflictingWritingOp)) { if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) { - LLVM_DEBUG( - llvm::dbgs() - << " no conflict: op interace of writing op says 'no'\n"); + LDBG() << " no conflict: op interace of writing op says 'no'"; continue; } } @@ -761,29 +751,26 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // Check all possible definitions. for (Value definition : definitions) { - LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n"); + LDBG() << " * definition = " << definition; // No conflict if the conflicting write happens before the definition. if (Operation *defOp = definition.getDefiningOp()) { if (happensBefore(conflictingWritingOp, defOp, domInfo)) { // conflictingWritingOp happens before defOp. No conflict. - LLVM_DEBUG(llvm::dbgs() - << " no conflict: write happens before definition\n"); + LDBG() << " no conflict: write happens before definition"; continue; } // No conflict if conflictingWritingOp is contained in defOp. 
if (defOp->isProperAncestor(conflictingWritingOp)) { - LLVM_DEBUG( - llvm::dbgs() - << " no conflict: write is contained in definition\n"); + LDBG() << " no conflict: write is contained in definition"; continue; } } else { auto bbArg = cast<BlockArgument>(definition); Block *block = bbArg.getOwner(); if (!block->findAncestorOpInBlock(*conflictingWritingOp)) { - LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg " - "and write happens outside of block\n"); + LDBG() << " no conflict: definition is bbArg " + "and write happens outside of block"; // conflictingWritingOp happens outside of the block. No // conflict. continue; @@ -795,8 +782,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite); if (aliases.getNumAliases() == 1 && aliases.getAliases()[0].value == definition) { - LLVM_DEBUG(llvm::dbgs() - << " no conflict: definition and write are same\n"); + LDBG() << " no conflict: definition and write are same"; continue; } @@ -804,7 +790,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, if (options.printConflicts) annotateConflict(uRead, uConflictingWrite, definition); - LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n"); + LDBG() << " => RaW CONFLICT FOUND"; return true; } } @@ -958,7 +944,7 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand, for (AliasingValue alias : state.getAliasingValues(operand)) state.applyOnAliases(alias.value, checkReadOnly); if (foundReadOnly) { - LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n"); + LDBG() << "=> NOT WRITABLE"; return true; } @@ -987,10 +973,9 @@ void OneShotAnalysisState::resetCache() { static LogicalResult bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state, const DominanceInfo &domInfo) { - LLVM_DEBUG( - llvm::dbgs() << "//===-------------------------------------------===//\n" - << "Analyzing operand #" << operand.getOperandNumber() - << " of " << *operand.getOwner() << "\n"); + LDBG() << "//===-------------------------------------------===//\n" + << "Analyzing operand #" << operand.getOperandNumber() << " of " + << *operand.getOwner(); bool foundInterference = wouldCreateWriteToNonWritableBuffer(operand, state) || @@ -1001,8 +986,7 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state, else state.bufferizeInPlace(operand); - LLVM_DEBUG(llvm::dbgs() - << "//===-------------------------------------------===//\n"); + LDBG() << "//===-------------------------------------------===//"; return success(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index 725fa24..b593cca 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -51,14 +51,8 @@ static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); } /// Return "true" if the given op is guaranteed to have neither "Allocate" nor /// "Free" side effects. static bool hasNeitherAllocateNorFreeSideEffect(Operation *op) { - if (isa<MemoryEffectOpInterface>(op)) - return !hasEffect<MemoryEffects::Allocate>(op) && - !hasEffect<MemoryEffects::Free>(op); - // If the op does not implement the MemoryEffectOpInterface but has has - // recursive memory effects, then this op in isolation (without its body) does - // not have any side effects. 
All the ops inside the regions of this op will - // be processed separately. - return op->hasTrait<OpTrait::HasRecursiveMemoryEffects>(); + return !mightHaveEffect<MemoryEffects::Allocate>(op) && + !mightHaveEffect<MemoryEffects::Free>(op); } /// Return "true" if the given op has buffer semantics. I.e., it has buffer @@ -517,9 +511,7 @@ LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) { // MemoryEffectOpInterface. They usually do not have side effects apart // from the callee, which will be analyzed separately. (This is similar to // "recursive memory effects".) - if (!isa<MemoryEffectOpInterface>(op) && - !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>() && - !isa<CallOpInterface>(op)) + if (hasUnknownEffects(op) && !isa<CallOpInterface>(op)) return op->emitError( "ops with unknown memory side effects are not supported"); diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index 053ee95..0acb4b1 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -41,6 +41,7 @@ add_subdirectory(Transform) add_subdirectory(UB) add_subdirectory(Utils) add_subdirectory(Vector) +add_subdirectory(WasmSSA) add_subdirectory(X86Vector) add_subdirectory(XeGPU) diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt index 37b4cfc..47740d3 100644 --- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt @@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIRControlFlowTransforms BufferizableOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS - {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms LINK_LIBS PUBLIC MLIRBufferizationDialect diff --git a/mlir/lib/Dialect/DLTI/Traits.cpp b/mlir/lib/Dialect/DLTI/Traits.cpp index 34f2dd5..3f6dd29 100644 --- a/mlir/lib/Dialect/DLTI/Traits.cpp +++ b/mlir/lib/Dialect/DLTI/Traits.cpp @@ -24,7 +24,7 @@ LogicalResult mlir::impl::verifyHasDefaultDLTIDataLayoutTrait(Operation *op) { } DataLayoutSpecInterface mlir::impl::getDataLayoutSpec(Operation *op) { - return op->getAttrOfType<DataLayoutSpecAttr>( + return op->getAttrOfType<DataLayoutSpecInterface>( DLTIDialect::kDataLayoutAttrName); } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index e6a3154..00ce3b5 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -114,11 +114,8 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) { bool mlir::emitc::isSupportedFloatType(Type type) { if (auto floatType = llvm::dyn_cast<FloatType>(type)) { switch (floatType.getWidth()) { - case 16: { - if (llvm::isa<Float16Type, BFloat16Type>(type)) - return true; - return false; - } + case 16: + return llvm::isa<Float16Type, BFloat16Type>(type); case 32: case 64: return true; @@ -134,6 +131,12 @@ bool mlir::emitc::isPointerWideType(Type type) { type); } +bool mlir::emitc::isFundamentalType(Type type) { + return llvm::isa<IndexType>(type) || isPointerWideType(type) || + isSupportedIntegerType(type) || isSupportedFloatType(type) || + isa<emitc::PointerType>(type); +} + /// Check that the type of the initial value is compatible with the operations /// result type. 
static LogicalResult verifyInitializationAttribute(Operation *op, @@ -378,6 +381,52 @@ OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } // ExpressionOp //===----------------------------------------------------------------------===// +ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) { + SmallVector<OpAsmParser::UnresolvedOperand> operands; + if (parser.parseOperandList(operands)) + return parser.emitError(parser.getCurrentLocation()) << "expected operands"; + if (succeeded(parser.parseOptionalKeyword("noinline"))) + result.addAttribute(ExpressionOp::getDoNotInlineAttrName(result.name), + parser.getBuilder().getUnitAttr()); + Type type; + if (parser.parseColonType(type)) + return parser.emitError(parser.getCurrentLocation(), + "expected function type"); + auto fnType = llvm::dyn_cast<FunctionType>(type); + if (!fnType) + return parser.emitError(parser.getCurrentLocation(), + "expected function type"); + if (parser.resolveOperands(operands, fnType.getInputs(), + parser.getCurrentLocation(), result.operands)) + return failure(); + if (fnType.getNumResults() != 1) + return parser.emitError(parser.getCurrentLocation(), + "expected single return type"); + result.addTypes(fnType.getResults()); + Region *body = result.addRegion(); + SmallVector<OpAsmParser::Argument> argsInfo; + for (auto [unresolvedOperand, operandType] : + llvm::zip(operands, fnType.getInputs())) { + OpAsmParser::Argument argInfo; + argInfo.ssaName = unresolvedOperand; + argInfo.type = operandType; + argsInfo.push_back(argInfo); + } + if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true)) + return failure(); + return success(); +} + +void emitc::ExpressionOp::print(OpAsmPrinter &p) { + p << ' '; + p.printOperands(getDefs()); + p << " : "; + p.printFunctionalType(getOperation()); + p.shadowRegionArgs(getRegion(), getDefs()); + p << ' '; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); +} + Operation *ExpressionOp::getRootOp() { auto yieldOp = cast<YieldOp>(getBody()->getTerminator()); Value yieldedValue = yieldOp.getResult(); @@ -1398,6 +1447,7 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) { //===----------------------------------------------------------------------===// // FieldOp //===----------------------------------------------------------------------===// + static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op, TypeAttr type, Attribute initialValue) { @@ -1455,6 +1505,15 @@ LogicalResult FieldOp::verify() { //===----------------------------------------------------------------------===// // GetFieldOp //===----------------------------------------------------------------------===// + +LogicalResult GetFieldOp::verify() { + auto parentClassOp = getOperation()->getParentOfType<emitc::ClassOp>(); + if (!parentClassOp.getOperation()) + return emitOpError(" must be nested within an emitc.class operation"); + + return success(); +} + LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) { mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr(); FieldOp fieldOp = diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp index 3f0690c..f8469b8 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp @@ -9,7 +9,9 @@ #include "mlir/Dialect/EmitC/Transforms/Transforms.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/IR/IRMapping.h" +#include 
"mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/STLExtras.h" namespace mlir { namespace emitc { @@ -24,20 +26,24 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) { Location loc = op->getLoc(); builder.setInsertionPointAfter(op); - auto expressionOp = emitc::ExpressionOp::create(builder, loc, resultType); + auto expressionOp = + emitc::ExpressionOp::create(builder, loc, resultType, op->getOperands()); // Replace all op's uses with the new expression's result. result.replaceAllUsesWith(expressionOp.getResult()); - // Create an op to yield op's value. - Region ®ion = expressionOp.getRegion(); - Block &block = region.emplaceBlock(); + Block &block = expressionOp.createBody(); + IRMapping mapper; + for (auto [operand, arg] : + llvm::zip(expressionOp.getOperands(), block.getArguments())) + mapper.map(operand, arg); builder.setInsertionPointToEnd(&block); - auto yieldOp = emitc::YieldOp::create(builder, loc, result); - // Move op into the new expression. - op->moveBefore(yieldOp); + Operation *rootOp = builder.clone(*op, mapper); + op->erase(); + // Create an op to yield op's value. + emitc::YieldOp::create(builder, loc, rootOp->getResults()[0]); return expressionOp; } @@ -53,51 +59,93 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> { using OpRewritePattern<ExpressionOp>::OpRewritePattern; LogicalResult matchAndRewrite(ExpressionOp expressionOp, PatternRewriter &rewriter) const override { - bool anythingFolded = false; - for (Operation &op : llvm::make_early_inc_range( - expressionOp.getBody()->without_terminator())) { - // Don't fold expressions whose result value has its address taken. - auto applyOp = dyn_cast<emitc::ApplyOp>(op); - if (applyOp && applyOp.getApplicableOperator() == "&") - continue; - - for (Value operand : op.getOperands()) { - auto usedExpression = operand.getDefiningOp<ExpressionOp>(); - if (!usedExpression) - continue; - - // Don't fold expressions with multiple users: assume any - // re-materialization was done separately. - if (!usedExpression.getResult().hasOneUse()) - continue; - - // Don't fold expressions with side effects. - if (usedExpression.hasSideEffects()) - continue; - - // Fold the used expression into this expression by cloning all - // instructions in the used expression just before the operation using - // its value. 
- rewriter.setInsertionPoint(&op); - IRMapping mapper; - for (Operation &opToClone : - usedExpression.getBody()->without_terminator()) { - Operation *clone = rewriter.clone(opToClone, mapper); - mapper.map(&opToClone, clone); - } - - Operation *expressionRoot = usedExpression.getRootOp(); - Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot); - assert(clonedExpressionRootOp && - "Expected cloned expression root to be in mapper"); - assert(clonedExpressionRootOp->getNumResults() == 1 && - "Expected cloned root to have a single result"); - - rewriter.replaceOp(usedExpression, clonedExpressionRootOp); - anythingFolded = true; - } + Block *expressionBody = expressionOp.getBody(); + ExpressionOp usedExpression; + SetVector<Value> foldedOperands; + + auto takesItsOperandsAddress = [](Operation *user) { + auto applyOp = dyn_cast<emitc::ApplyOp>(user); + return applyOp && applyOp.getApplicableOperator() == "&"; + }; + + // Select as expression to fold the first operand expression that + // - doesn't have its result value's address taken, + // - has a single user: assume any re-materialization was done separately, + // - has no side effects, + // and save all other operands to be used later as operands in the folded + // expression. + for (auto [operand, arg] : llvm::zip(expressionOp.getOperands(), + expressionBody->getArguments())) { + ExpressionOp operandExpression = operand.getDefiningOp<ExpressionOp>(); + if (usedExpression || !operandExpression || + llvm::any_of(arg.getUsers(), takesItsOperandsAddress) || + !operandExpression.getResult().hasOneUse() || + operandExpression.hasSideEffects()) + foldedOperands.insert(operand); + else + usedExpression = operandExpression; } - return anythingFolded ? success() : failure(); + + // If no operand expression was selected, bail out. + if (!usedExpression) + return failure(); + + // Collect additional operands from the folded expression. + for (Value operand : usedExpression.getOperands()) + foldedOperands.insert(operand); + + // Create a new expression to hold the folding result. + rewriter.setInsertionPointAfter(expressionOp); + auto foldedExpression = emitc::ExpressionOp::create( + rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(), + foldedOperands.getArrayRef(), expressionOp.getDoNotInline()); + Block &foldedExpressionBody = foldedExpression.createBody(); + + // Map each operand of the new expression to its matching block argument. + IRMapping mapper; + for (auto [operand, arg] : llvm::zip(foldedExpression.getOperands(), + foldedExpressionBody.getArguments())) + mapper.map(operand, arg); + + // Prepare to fold the used expression and the matched expression into the + // newly created folded expression. + auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold, + bool withTerminator) { + Block *expressionToFoldBody = expressionToFold.getBody(); + for (auto [operand, arg] : + llvm::zip(expressionToFold.getOperands(), + expressionToFoldBody->getArguments())) { + mapper.map(arg, mapper.lookup(operand)); + } + + for (Operation &opToClone : expressionToFoldBody->without_terminator()) + rewriter.clone(opToClone, mapper); + + if (withTerminator) + rewriter.clone(*expressionToFoldBody->getTerminator(), mapper); + }; + rewriter.setInsertionPointToStart(&foldedExpressionBody); + + // First, fold the used expression into the new expression and map its + // result to the clone of its root operation within the new expression. 
+ foldExpression(usedExpression, /*withTerminator=*/false); + Operation *expressionRoot = usedExpression.getRootOp(); + Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot); + assert(clonedExpressionRootOp && + "Expected cloned expression root to be in mapper"); + assert(clonedExpressionRootOp->getNumResults() == 1 && + "Expected cloned root to have a single result"); + mapper.map(usedExpression.getResult(), + clonedExpressionRootOp->getResults()[0]); + + // Now fold the matched expression into the new expression. + foldExpression(expressionOp, /*withTerminator=*/true); + + // Complete the rewrite. + rewriter.replaceOp(expressionOp, foldedExpression); + rewriter.eraseOp(usedExpression); + + return success(); } }; diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp index c55e26e..06d7e07 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp @@ -64,8 +64,8 @@ public: TypeAttr typeAttr = TypeAttr::get(val.getType()); fields.push_back({fieldName, typeAttr}); - FieldOp fieldop = rewriter.create<emitc::FieldOp>( - funcOp->getLoc(), fieldName, typeAttr, nullptr); + FieldOp fieldop = emitc::FieldOp::create(rewriter, funcOp->getLoc(), + fieldName, typeAttr, nullptr); if (argAttrs && idx < argAttrs->size()) { fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx)); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 5a72ef1..b87b4f4 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -756,7 +756,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result, Type asyncTokenType, ValueRange asyncDependencies, TypeRange workgroupAttributions, TypeRange privateAttributions, Value clusterSizeX, - Value clusterSizeY, Value clusterSizeZ) { + Value clusterSizeY, Value clusterSizeZ, + FlatSymbolRefAttr module, FlatSymbolRefAttr function) { OpBuilder::InsertionGuard g(builder); // Add a WorkGroup attribution attribute. This attribute is required to @@ -781,6 +782,12 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result, if (dynamicSharedMemorySize) result.addOperands(dynamicSharedMemorySize); + // Add optional module and function attributes. + if (module) + result.addAttribute(getModuleAttrName(result.name), module); + if (function) + result.addAttribute(getFunctionAttrName(result.name), function); + // Create a kernel body region with kNumConfigRegionAttributes + N memory // attributions, where the first kNumConfigRegionAttributes arguments have // `index` type and the rest have the same types as the data operands. @@ -944,6 +951,21 @@ void LaunchOp::print(OpAsmPrinter &p) { p << ' ' << getDynamicSharedMemorySizeKeyword() << ' ' << getDynamicSharedMemorySize(); + // Print optional module attribute. + StringRef moduleAttrName = getModuleAttrName(); + if (auto module = getModule()) { + p << ' ' << moduleAttrName << '('; + p.printSymbolName(*module); + p << ')'; + } + // Print optional function attribute. 
+ StringRef functionAttrName = getFunctionAttrName(); + if (auto function = getFunction()) { + p << ' ' << functionAttrName << '('; + p.printSymbolName(*function); + p << ')'; + } + printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions()); printAttributions(p, getPrivateKeyword(), getPrivateAttributions()); @@ -952,7 +974,8 @@ void LaunchOp::print(OpAsmPrinter &p) { p.printRegion(getBody(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{ LaunchOp::getOperandSegmentSizeAttr(), - getNumWorkgroupAttributionsAttrName()}); + getNumWorkgroupAttributionsAttrName(), + moduleAttrName, functionAttrName}); } // Parse the size assignment blocks for blocks and threads. These have the form @@ -990,6 +1013,9 @@ parseSizeAssignment(OpAsmParser &parser, /// `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional) /// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment /// `threads` `(` ssa-id-list `)` `in` ssa-reassignment +/// (`dynamic_shared_memory_size` ssa-use)? +/// (`module(` symbol-ref-id `)`)? +/// (`function(` symbol-ref-id `)`)? /// memory-attribution /// region attr-dict? /// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` @@ -1060,6 +1086,27 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } + // Parse optional module attribute. + StringRef moduleAttrName = getModuleAttrName(result.name); + if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) { + FlatSymbolRefAttr moduleSymbol; + if (parser.parseLParen() || + parser.parseAttribute(moduleSymbol, Type(), moduleAttrName, + result.attributes) || + parser.parseRParen()) + return failure(); + } + // Parse optional function attribute. + StringRef functionAttrName = getFunctionAttrName(result.name); + if (succeeded(parser.parseOptionalKeyword(functionAttrName))) { + FlatSymbolRefAttr funcSymbol; + if (parser.parseLParen() || + parser.parseAttribute(funcSymbol, Type(), functionAttrName, + result.attributes) || + parser.parseRParen()) + return failure(); + } + // Create the region arguments, it has kNumConfigRegionAttributes arguments // that correspond to block/thread identifiers and grid/block sizes, all // having `index` type, a variadic number of WorkGroup Attributions and @@ -2439,8 +2486,7 @@ LogicalResult WarpExecuteOnLane0Op::verify() { if (getArgs().size() != getWarpRegion().getNumArguments()) return emitOpError( "expected same number op arguments and block arguments."); - auto yield = - cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = getTerminator(); if (yield.getNumOperands() != getNumResults()) return emitOpError( "expected same number of yield operands and return values."); @@ -2464,6 +2510,50 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) { verifyDistributedType(lhs, rhs, getWarpSize(), getOperation())); } +gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() { + return cast<gpu::YieldOp>(getBody()->getTerminator()); +} + +//===----------------------------------------------------------------------===// +// GPU_SubgroupBroadcastOp +//===----------------------------------------------------------------------===// + +void gpu::SubgroupBroadcastOp::inferResultRanges( + ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { + setResultRange(getResult(), argRanges.front()); +} + +Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() { + switch (getBroadcastType()) { + case 
BroadcastType::first_active_lane: + // Cannot speculate first_lane broadcast, because speculating it across + // control flow can change the active lanes. + return Speculation::NotSpeculatable; + case BroadcastType::any_lane: + LLVM_FALLTHROUGH; + case BroadcastType::specific_lane: + // Speculation should be safe as long as we inside structured control flow. + return Speculation::Speculatable; + } +} + +LogicalResult gpu::SubgroupBroadcastOp::verify() { + switch (getBroadcastType()) { + case BroadcastType::first_active_lane: + LLVM_FALLTHROUGH; + case BroadcastType::any_lane: + if (getLane()) + return emitOpError() + << "lane can only be specified for `specific_lane` broadcast"; + return success(); + case BroadcastType::specific_lane: + if (!getLane()) + return emitOpError() + << "lane must be specified for `specific_lane` broadcast"; + return success(); + } +} + //===----------------------------------------------------------------------===// // GPU KernelMetadataAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 21cb2f6..c766539 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/TransformOps/Utils.h" @@ -43,6 +44,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/LogicalResult.h" +#include <optional> #include <type_traits> using namespace mlir; @@ -170,7 +172,16 @@ void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) { void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns( RewritePatternSet &patterns) { - populateGpuPromoteShuffleToAMDGPUPatterns(patterns); + std::optional<StringRef> chipsetName = getChipset(); + std::optional<amdgpu::Chipset> maybeChipset; + if (chipsetName) { + FailureOr<amdgpu::Chipset> parsedChipset = + amdgpu::Chipset::parse(*chipsetName); + assert(llvm::succeeded(parsedChipset) && "expected valid chipset"); + maybeChipset = parsedChipset; + } + + populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp index 9bf11c7..d2c2138 100644 --- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp @@ -25,6 +25,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_GPUELIMINATEBARRIERS @@ -37,9 +38,6 @@ using namespace mlir::gpu; #define DEBUG_TYPE "gpu-erase-barriers" #define DEBUG_TYPE_ALIAS "gpu-erase-barries-alias" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") - // The functions below provide interface-like verification, but are too specific // to barrier elimination to become interfaces. 
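For context on the chipset plumbing added in GPUTransformOps.cpp above, a sketch of how a caller might produce the optional chipset; it assumes the usual `using namespace mlir;` and an illustrative chipset name:

static void addAmdShufflePatterns(RewritePatternSet &patterns,
                                  StringRef chipsetName /* e.g. "gfx950" */) {
  std::optional<amdgpu::Chipset> maybeChipset;
  if (FailureOr<amdgpu::Chipset> parsed = amdgpu::Chipset::parse(chipsetName);
      succeeded(parsed))
    maybeChipset = *parsed;
  // The permlane-swap promotion is only registered for gfx950 and newer.
  populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
}
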
@@ -424,27 +422,18 @@ static bool maybeCaptured(Value v) { /// everything. This seems sufficient to achieve barrier removal in structured /// control flow, more complex cases would require a proper dataflow analysis. static bool mayAlias(Value first, Value second) { - DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, { - DBGS_ALIAS() << "checking aliasing between "; - DBGS_ALIAS() << first << "\n"; - DBGS_ALIAS() << " and "; - DBGS_ALIAS() << second << "\n"; - }); + LDBG(DEBUG_TYPE_ALIAS, 1) + << "checking aliasing between " << first << " and " << second; first = getBase(first); second = getBase(second); - DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, { - DBGS_ALIAS() << "base "; - DBGS_ALIAS() << first << "\n"; - DBGS_ALIAS() << " and "; - DBGS_ALIAS() << second << "\n"; - }); + LDBG(DEBUG_TYPE_ALIAS, 1) << "base " << first << " and " << second; // Values derived from the same base memref do alias (unless we do a more // advanced analysis to prove non-overlapping accesses). if (first == second) { - DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> do alias!\n"); + LDBG(DEBUG_TYPE_ALIAS, 1) << "-> do alias!"; return true; } @@ -493,7 +482,7 @@ static bool mayAlias(Value first, Value second) { return false; // Otherwise, conservatively assume aliasing. - DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> may alias!\n"); + LDBG(DEBUG_TYPE_ALIAS, 1) << "-> may alias!"; return true; } @@ -567,20 +556,16 @@ haveConflictingEffects(ArrayRef<MemoryEffects::EffectInstance> beforeEffects, continue; // Other kinds of effects create a conflict, e.g. read-after-write. - LLVM_DEBUG( - DBGS() << "found a conflict between (before): " << before.getValue() - << " read:" << isa<MemoryEffects::Read>(before.getEffect()) - << " write:" << isa<MemoryEffects::Write>(before.getEffect()) - << " alloc:" - << isa<MemoryEffects::Allocate>(before.getEffect()) << " free:" - << isa<MemoryEffects::Free>(before.getEffect()) << "\n"); - LLVM_DEBUG( - DBGS() << "and (after): " << after.getValue() - << " read:" << isa<MemoryEffects::Read>(after.getEffect()) - << " write:" << isa<MemoryEffects::Write>(after.getEffect()) - << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect()) - << " free:" << isa<MemoryEffects::Free>(after.getEffect()) - << "\n"); + LDBG() << "found a conflict between (before): " << before.getValue() + << " read:" << isa<MemoryEffects::Read>(before.getEffect()) + << " write:" << isa<MemoryEffects::Write>(before.getEffect()) + << " alloc:" << isa<MemoryEffects::Allocate>(before.getEffect()) + << " free:" << isa<MemoryEffects::Free>(before.getEffect()); + LDBG() << "and (after): " << after.getValue() + << " read:" << isa<MemoryEffects::Read>(after.getEffect()) + << " write:" << isa<MemoryEffects::Write>(after.getEffect()) + << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect()) + << " free:" << isa<MemoryEffects::Free>(after.getEffect()); return true; } } @@ -595,8 +580,8 @@ public: LogicalResult matchAndRewrite(BarrierOp barrier, PatternRewriter &rewriter) const override { - LLVM_DEBUG(DBGS() << "checking the necessity of: " << barrier << " " - << barrier.getLoc() << "\n"); + LDBG() << "checking the necessity of: " << barrier << " " + << barrier.getLoc(); SmallVector<MemoryEffects::EffectInstance> beforeEffects; getEffectsBefore(barrier, beforeEffects, /*stopAtBarrier=*/true); @@ -605,14 +590,12 @@ public: getEffectsAfter(barrier, afterEffects, /*stopAtBarrier=*/true); if (!haveConflictingEffects(beforeEffects, afterEffects)) { - LLVM_DEBUG(DBGS() << "the surrounding barriers are sufficient, removing " - << 
barrier << "\n"); + LDBG() << "the surrounding barriers are sufficient, removing " << barrier; rewriter.eraseOp(barrier); return success(); } - LLVM_DEBUG(DBGS() << "barrier is necessary: " << barrier << " " - << barrier.getLoc() << "\n"); + LDBG() << "barrier is necessary: " << barrier << " " << barrier.getLoc(); return failure(); } }; diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 99f5c5b..97adad6 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -356,8 +356,8 @@ public: auto funcWalkResult = func.walk([&](gpu::LaunchOp op) { SetVector<Value> operands; std::string kernelFnName; - if (op.getKernelFunc()) { - kernelFnName = op.getKernelFunc()->getRootReference().str(); + if (op.getFunction()) { + kernelFnName = op.getFunction()->str(); } else { kernelFnName = Twine(op->getParentOfType<SymbolOpInterface>().getName(), @@ -403,9 +403,8 @@ private: OpBuilder builder(context); std::string kernelModuleName; gpu::GPUModuleOp kernelModule; - if (gpuLaunchOp.getKernelModule()) { - kernelModuleName = - gpuLaunchOp.getKernelModule()->getRootReference().str(); + if (gpuLaunchOp.getModule()) { + kernelModuleName = gpuLaunchOp.getModule()->str(); kernelModule = parentSymbolTable.lookup<gpu::GPUModuleOp>(kernelModuleName); } else { @@ -432,8 +431,7 @@ private: if (std::optional<SymbolTable::UseRange> symbolUses = SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { for (SymbolTable::SymbolUse symbolUse : *symbolUses) { - StringRef symbolName = - cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue(); + StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference(); if (symbolTable.lookup(symbolName)) continue; diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp index 18c69f5..67cef8a 100644 --- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp @@ -11,16 +11,21 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/PatternMatch.h" +#include <optional> using namespace mlir; namespace { + +constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9, 5, 0); + /// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64 /// and offset must be a constant integer in the range [0, 31]. struct PromoteShuffleToSwizzlePattern @@ -56,9 +61,48 @@ struct PromoteShuffleToSwizzlePattern return success(); } }; + +/// Try to promote `gpu.shuffle` to `amdgpu.permlane_swap`, width must be 64 +/// and offset must be a constant integer in the set {16, 32}. 
+struct PromoteShuffleToPermlanePattern + : public OpRewritePattern<gpu::ShuffleOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(gpu::ShuffleOp op, + PatternRewriter &rewriter) const override { + if (op.getMode() != gpu::ShuffleMode::XOR) + return rewriter.notifyMatchFailure(op, + "only xor shuffle mode is supported"); + + if (!isConstantIntValue(op.getWidth(), 64)) + return rewriter.notifyMatchFailure(op, + "only 64 width shuffle is supported"); + + std::optional<int64_t> offset = getConstantIntValue(op.getOffset()); + if (!offset) + return rewriter.notifyMatchFailure(op, + "offset must be a constant integer"); + + int64_t offsetValue = *offset; + if (offsetValue != 16 && offsetValue != 32) + return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31"); + + Location loc = op.getLoc(); + Value res = amdgpu::PermlaneSwapOp::create( + rewriter, loc, op.getResult(0).getType(), op.getValue(), offsetValue); + Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1); + rewriter.replaceOp(op, {res, valid}); + return success(); + } +}; + } // namespace void mlir::populateGpuPromoteShuffleToAMDGPUPatterns( - RewritePatternSet &patterns) { - patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext()); + RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) { + patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(), + /*benefit*/ 1); + if (maybeChipset && *maybeChipset >= kGfx950) + patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(), + /*benefit*/ 2); } diff --git a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp index e9cf493..6da76e9 100644 --- a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" +#include "mlir/Target/LLVM/XeVM/Target.h" #include "llvm/Support/Regex.h" namespace mlir { diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp index 384d1a0..88f531f 100644 --- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp +++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Value.h" +#include "llvm/ADT/DenseMap.h" #include <numeric> @@ -55,28 +56,30 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns( SmallVector<size_t> &indices) const { SmallVector<Type> types(warpOp.getResultTypes().begin(), warpOp.getResultTypes().end()); - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(), - yield.getOperands().end()); + gpu::YieldOp yield = warpOp.getTerminator(); + SmallVector<Value> yieldValues(yield.getOperands().begin(), + yield.getOperands().end()); + llvm::SmallDenseMap<Value, unsigned> indexLookup; + // Record the value -> first index mapping for faster lookup. + for (auto [i, v] : llvm::enumerate(yieldValues)) { + if (!indexLookup.count(v)) + indexLookup[v] = i; + } + for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) { - if (yieldValues.insert(value)) { + // If the value already exists in the yield, don't create a new output. 
+ if (indexLookup.count(value)) { + indices.push_back(indexLookup[value]); + } else { + // If the value is new, add it to the yield and to the types. + yieldValues.push_back(value); types.push_back(type); indices.push_back(yieldValues.size() - 1); - } else { - // If the value already exit the region don't create a new output. - for (auto [idx, yieldOperand] : - llvm::enumerate(yieldValues.getArrayRef())) { - if (yieldOperand == value) { - indices.push_back(idx); - break; - } - } } } - yieldValues.insert_range(newYieldedValues); + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, yieldValues.getArrayRef(), types); + rewriter, warpOp, yieldValues, types); rewriter.replaceOp(warpOp, newWarpOp.getResults().take_front(warpOp.getNumResults())); return newWarpOp; @@ -85,8 +88,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns( OpOperand *WarpDistributionPattern::getWarpResult( WarpExecuteOnLane0Op warpOp, llvm::function_ref<bool(Operation *)> fn) const { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); for (OpOperand &yieldOperand : yield->getOpOperands()) { Value yieldValues = yieldOperand.get(); Operation *definedOp = yieldValues.getDefiningOp(); diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index ff55f17..ec581ac 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLLVMDialect MLIRInferTypeOpInterface MLIRIR MLIRMemorySlotInterfaces + MLIRPtrMemorySpaceInterfaces MLIRSideEffectInterfaces MLIRSupport ) diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp index 894de44..7220e10 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp @@ -12,10 +12,20 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/Regex.h" #define DEBUG_TYPE "ptx-builder" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") //===----------------------------------------------------------------------===// // BasicPtxBuilderInterface @@ -28,50 +38,122 @@ using namespace NVVM; static constexpr int64_t kSharedMemorySpace = 3; -static char getRegisterType(Type type) { - if (type.isInteger(1)) - return 'b'; - if (type.isInteger(16)) - return 'h'; - if (type.isInteger(32)) - return 'r'; - if (type.isInteger(64)) - return 'l'; - if (type.isF32()) - return 'f'; - if (type.isF64()) - return 'd'; - if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) { - // Shared address spaces is addressed with 32-bit pointers. 
- if (ptr.getAddressSpace() == kSharedMemorySpace) { +static FailureOr<char> getRegisterType(Type type, Location loc) { + MLIRContext *ctx = type.getContext(); + auto i16 = IntegerType::get(ctx, 16); + auto i32 = IntegerType::get(ctx, 32); + auto f32 = Float32Type::get(ctx); + + auto getRegisterTypeForScalar = [&](Type type) -> FailureOr<char> { + if (type.isInteger(1)) + return 'b'; + if (type.isInteger(16)) + return 'h'; + if (type.isInteger(32)) return 'r'; + if (type.isInteger(64)) + return 'l'; + if (type.isF32()) + return 'f'; + if (type.isF64()) + return 'd'; + if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) { + // Shared address spaces is addressed with 32-bit pointers. + if (ptr.getAddressSpace() == kSharedMemorySpace) { + return 'r'; + } + return 'l'; } - return 'l'; + // register type for struct is not supported. + mlir::emitError( + loc, "The register type could not be deduced from MLIR type. The ") + << type + << " is not supported. Supported types are:" + "i1, i16, i32, i64, f32, f64," + "pointers.\nPlease use llvm.bitcast if you have different type. " + "\nSee the constraints from here: " + "https://docs.nvidia.com/cuda/inline-ptx-assembly/" + "index.html#constraints"; + return failure(); + }; + + // Packed registers + if (auto v = dyn_cast<VectorType>(type)) { + assert(v.getNumDynamicDims() == 0 && "Dynamic vectors are not supported"); + + int64_t lanes = v.getNumElements(); + Type elem = v.getElementType(); + + // Case 1. Single vector + if (lanes <= 1) + return getRegisterTypeForScalar(elem); + + // Case 2. Packed registers + Type widened = elem; + switch (lanes) { + + case 2: + if (elem.isF16() || elem.isBF16()) // vector<2xf16> + widened = f32; + else if (elem.isFloat(8)) // vector<2xf8> + widened = i16; + break; + case 4: + if (elem.isInteger(8)) // vector<i8x4> + widened = i32; + else if (elem.isFloat(8)) // vector<f8x4> + widened = f32; + else if (elem.isFloat(4)) // vector<f4x4> + widened = i16; + break; + // Other packing is not supported + default: + break; + } + return getRegisterTypeForScalar(widened); } - // register type for struct is not supported. - llvm_unreachable("The register type could not deduced from MLIR type"); - return '?'; + + return getRegisterTypeForScalar(type); } -static char getRegisterType(Value v) { +static FailureOr<char> getRegisterType(Value v, Location loc) { if (v.getDefiningOp<LLVM::ConstantOp>()) return 'n'; - return getRegisterType(v.getType()); + return getRegisterType(v.getType(), loc); } -void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { - LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n"); +/// Extract every element of a struct value. +static SmallVector<Value> extractStructElements(PatternRewriter &rewriter, + Location loc, Value structVal) { + auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.getType()); + assert(structTy && "expected LLVM struct"); + + SmallVector<Value> elems; + for (unsigned i : llvm::seq<unsigned>(0, structTy.getBody().size())) + elems.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, structVal, i)); + + return elems; +} + +LogicalResult PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { + LDBG() << v << "\t Modifier : " << itype << "\n"; + registerModifiers.push_back(itype); + + Location loc = interfaceOp->getLoc(); auto getModifier = [&]() -> const char * { - if (itype == PTXRegisterMod::ReadWrite) { - assert(false && "Read-Write modifier is not supported. 
Try setting the " - "same value as Write and Read separately."); - return "+"; - } - if (itype == PTXRegisterMod::Write) { + switch (itype) { + case PTXRegisterMod::Read: + return ""; + case PTXRegisterMod::Write: return "="; + case PTXRegisterMod::ReadWrite: + // "Read-Write modifier is not actually supported + // Interface will change it to "=" later and add integer mapping + return "+"; } - return ""; + llvm_unreachable("Unknown PTX register modifier"); }; + auto addValue = [&](Value v) { if (itype == PTXRegisterMod::Read) { ptxOperands.push_back(v); @@ -90,35 +172,273 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { } for (auto [idx, t] : llvm::enumerate(stype.getBody())) { if (itype != PTXRegisterMod::Write) { - Value extractValue = LLVM::ExtractValueOp::create( - rewriter, interfaceOp->getLoc(), v, idx); + Value extractValue = + LLVM::ExtractValueOp::create(rewriter, loc, v, idx); addValue(extractValue); } if (itype == PTXRegisterMod::ReadWrite) { ss << idx << ","; } else { - ss << getModifier() << getRegisterType(t) << ","; + FailureOr<char> regType = getRegisterType(t, loc); + if (failed(regType)) + return rewriter.notifyMatchFailure(loc, + "failed to get register type"); + ss << getModifier() << regType.value() << ","; } } - return; + return success(); } // Handle Scalars addValue(v); - ss << getModifier() << getRegisterType(v) << ","; + FailureOr<char> regType = getRegisterType(v, loc); + if (failed(regType)) + return rewriter.notifyMatchFailure(loc, "failed to get register type"); + ss << getModifier() << regType.value() << ","; + return success(); +} + +/// Check if the operation needs to pack and unpack results. +static bool +needsPackUnpack(BasicPtxBuilderInterface interfaceOp, + bool needsManualRegisterMapping, + SmallVectorImpl<PTXRegisterMod> ®isterModifiers) { + if (needsManualRegisterMapping) + return false; + const unsigned writeOnlyVals = interfaceOp->getNumResults(); + const unsigned readWriteVals = + llvm::count_if(registerModifiers, [](PTXRegisterMod m) { + return m == PTXRegisterMod::ReadWrite; + }); + return (writeOnlyVals + readWriteVals) > 1; +} + +/// Pack the result types of the interface operation. +/// If the operation has multiple results, it packs them into a struct +/// type. Otherwise, it returns the original result types. +static SmallVector<Type> +packResultTypes(BasicPtxBuilderInterface interfaceOp, + bool needsManualRegisterMapping, + SmallVectorImpl<PTXRegisterMod> ®isterModifiers, + SmallVectorImpl<Value> &ptxOperands) { + MLIRContext *ctx = interfaceOp->getContext(); + TypeRange resultRange = interfaceOp->getResultTypes(); + + if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping, + registerModifiers)) { + // Single value path: + if (interfaceOp->getResults().size() == 1) + return SmallVector<Type>{resultRange.front()}; + + // No declared results: if there is an RW, forward its type. 
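+    // (At most one such read-write operand can exist on this path, because
+    // needsPackUnpack returned false for this op.)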
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) + if (m == PTXRegisterMod::ReadWrite) + return SmallVector<Type>{v.getType()}; + } + + SmallVector<Type> packed; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) + if (m == PTXRegisterMod::ReadWrite) + packed.push_back(v.getType()); + for (Type t : resultRange) + packed.push_back(t); + + if (packed.empty()) + return {}; + + auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false); + return SmallVector<Type>{sTy}; +} + +/// Canonicalize the register constraints: +/// - Turn every "+X" into "=X" +/// - Append (at the very end) the 0-based indices of tokens that were "+X" +/// Examples: +/// "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2" +/// "+f,+f,+r,=r,=r" -> "=f,=f,=r,=r,=r,0,1,2" +static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) { + SmallVector<llvm::StringRef> toks; + SmallVector<std::string> out; + SmallVector<unsigned> plusIdx; + + csv.split(toks, ','); + out.reserve(toks.size() + 8); + + for (unsigned i = 0, e = toks.size(); i < e; ++i) { + StringRef t = toks[i].trim(); + if (t.consume_front("+")) { + plusIdx.push_back(i); + out.push_back(("=" + t).str()); + } else { + out.push_back(t.str()); + } + } + + // Append indices of original "+X" tokens. + for (unsigned idx : plusIdx) + out.push_back(std::to_string(idx)); + + // Join back to CSV. + std::string result; + result.reserve(csv.size() + plusIdx.size() * 2); + llvm::raw_string_ostream os(result); + for (size_t i = 0; i < out.size(); ++i) { + if (i) + os << ','; + os << out[i]; + } + return os.str(); +} + +constexpr llvm::StringLiteral kReadWritePrefix{"rw"}; +constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"}; +constexpr llvm::StringLiteral kReadOnlyPrefix{"r"}; + +/// Returns a regex that matches {$rwN}, {$wN}, {$rN} +static llvm::Regex getPredicateMappingRegex() { + llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})", + kReadWritePrefix, kWriteOnlyPrefix, + kReadOnlyPrefix) + .str()); + return rx; +} + +void mlir::NVVM::countPlaceholderNumbers( + StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW, + llvm::SmallDenseSet<unsigned int> &seenW, + llvm::SmallDenseSet<unsigned int> &seenR, + llvm::SmallVectorImpl<unsigned int> &rwNums, + llvm::SmallVectorImpl<unsigned int> &wNums, + llvm::SmallVectorImpl<unsigned int> &rNums) { + + llvm::Regex rx = getPredicateMappingRegex(); + StringRef rest = ptxCode; + + SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number + while (!rest.empty() && rx.match(rest, &m)) { + unsigned num = 0; + (void)m[2].getAsInteger(10, num); + // Insert it into the vector only the first time we see this number + if (m[1].equals_insensitive(kReadWritePrefix)) { + if (seenRW.insert(num).second) + rwNums.push_back(num); + } else if (m[1].equals_insensitive(kWriteOnlyPrefix)) { + if (seenW.insert(num).second) + wNums.push_back(num); + } else { + if (seenR.insert(num).second) + rNums.push_back(num); + } + + const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size(); + rest = rest.drop_front(advance); + } +} + +/// Rewrites `{$rwN}`, `{$wN}`, and `{$rN}` placeholders in `ptxCode` into +/// compact `$K` indices: +/// - All `rw*` first (sorted by N), +/// - Then `w*`, +/// - Then `r*`. +/// If there a predicate, it comes always in the end. +/// Each number is assigned once; duplicates are ignored. 
+/// +/// Example Input: +/// "{ +/// reg .pred p; +/// setp.ge.s32 p, {$r0}, {$r1};" +/// selp.s32 {$rw0}, {$r0}, {$r1}, p; +/// selp.s32 {$rw1}, {$r0}, {$r1}, p; +/// selp.s32 {$w0}, {$r0}, {$r1}, p; +/// selp.s32 {$w1}, {$r0}, {$r1}, p; +/// }\n" +/// Example Output: +/// "{ +/// reg .pred p; +/// setp.ge.s32 p, $4, $5;" +/// selp.s32 $0, $4, $5, p; +/// selp.s32 $1, $4, $5, p; +/// selp.s32 $2, $4, $5, p; +/// selp.s32 $3, $4, $5, p; +/// }\n" +static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) { + llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR; + llvm::SmallVector<unsigned> rwNums, wNums, rNums; + + // Step 1. Count Register Placeholder numbers + countPlaceholderNumbers(ptxCode, seenRW, seenW, seenR, rwNums, wNums, rNums); + + // Step 2. Sort the Register Placeholder numbers + llvm::sort(rwNums); + llvm::sort(wNums); + llvm::sort(rNums); + + // Step 3. Create mapping from original to new IDs + llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap; + unsigned nextId = 0; + for (unsigned n : rwNums) + rwMap[n] = nextId++; + for (unsigned n : wNums) + wMap[n] = nextId++; + for (unsigned n : rNums) + rMap[n] = nextId++; + + // Step 4. Rewrite the PTX code with new IDs + std::string out; + out.reserve(ptxCode.size()); + size_t prev = 0; + StringRef rest = ptxCode; + SmallVector<StringRef, 3> matches; + llvm::Regex rx = getPredicateMappingRegex(); + while (!rest.empty() && rx.match(rest, &matches)) { + // Compute absolute match bounds in the original buffer. + size_t absStart = (size_t)(matches[0].data() - ptxCode.data()); + size_t absEnd = absStart + matches[0].size(); + + // Emit text before the match. + out.append(ptxCode.data() + prev, ptxCode.data() + absStart); + + // Emit compact $K + unsigned num = 0; + (void)matches[2].getAsInteger(10, num); + unsigned id = 0; + if (matches[1].equals_insensitive(kReadWritePrefix)) + id = rwMap.lookup(num); + else if (matches[1].equals_insensitive(kWriteOnlyPrefix)) + id = wMap.lookup(num); + else + id = rMap.lookup(num); + + out.push_back('$'); + out += std::to_string(id); + + prev = absEnd; + + const size_t advance = + (size_t)(matches[0].data() - rest.data()) + matches[0].size(); + rest = rest.drop_front(advance); + } + + // Step 5. Tail. + out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size()); + return out; } LLVM::InlineAsmOp PtxBuilder::build() { auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(), LLVM::AsmDialect::AD_ATT); - auto resultTypes = interfaceOp->getResultTypes(); + SmallVector<Type> resultTypes = packResultTypes( + interfaceOp, needsManualRegisterMapping, registerModifiers, ptxOperands); // Remove the last comma from the constraints string. if (!registerConstraints.empty() && registerConstraints[registerConstraints.size() - 1] == ',') registerConstraints.pop_back(); + registerConstraints = canonicalizeRegisterConstraints(registerConstraints); std::string ptxInstruction = interfaceOp.getPtx(); + if (!needsManualRegisterMapping) + ptxInstruction = rewriteAsmPlaceholders(ptxInstruction); // Add the predicate to the asm string. 
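 // (A predicated PTX instruction is written by prefixing it with an `@` guard.)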
if (interfaceOp.getPredicate().has_value() && @@ -136,7 +456,7 @@ LLVM::InlineAsmOp PtxBuilder::build() { rewriter, interfaceOp->getLoc(), /*result types=*/resultTypes, /*operands=*/ptxOperands, - /*asm_string=*/llvm::StringRef(ptxInstruction), + /*asm_string=*/ptxInstruction, /*constraints=*/registerConstraints.data(), /*has_side_effects=*/interfaceOp.hasSideEffect(), /*is_align_stack=*/false, LLVM::TailCallKind::None, @@ -146,10 +466,89 @@ LLVM::InlineAsmOp PtxBuilder::build() { void PtxBuilder::buildAndReplaceOp() { LLVM::InlineAsmOp inlineAsmOp = build(); - LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n"); - if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) { - rewriter.replaceOp(interfaceOp, inlineAsmOp); - } else { + LDBG() << "\n Generated PTX \n\t" << inlineAsmOp; + + // Case 0: no result at all → just erase wrapper op. + if (!hasResult) { rewriter.eraseOp(interfaceOp); + return; + } + + if (needsManualRegisterMapping) { + rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults()); + return; + } + + // Case 1: Simple path, return single scalar + if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping, + registerModifiers)) { + if (inlineAsmOp->getNumResults() > 0) { + rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults()); + } else { + // RW-only case with no declared results: forward the RW value. + SmallVector<Value> results; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) + if (m == PTXRegisterMod::ReadWrite) { + results.push_back(v); + break; + } + rewriter.replaceOp(interfaceOp, results); + } + return; + } + + const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) { + return m == PTXRegisterMod::ReadWrite; + }); + + // All multi-value paths produce a single struct result we need to unpack. + assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) && + "expected struct return for multi-result inline asm"); + Value structVal = inlineAsmOp.getResult(0); + SmallVector<Value> unpacked = + extractStructElements(rewriter, interfaceOp->getLoc(), structVal); + + // Case 2: only declared results (no RW): replace the op with all unpacked. + if (!hasRW && interfaceOp->getResults().size() > 0) { + rewriter.replaceOp(interfaceOp, unpacked); + return; + } + + // Case 3: RW-only (no declared results): update RW uses and erase wrapper. + if (hasRW && interfaceOp->getResults().size() == 0) { + unsigned idx = 0; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) { + if (m != PTXRegisterMod::ReadWrite) + continue; + Value repl = unpacked[idx++]; + v.replaceUsesWithIf(repl, [&](OpOperand &use) { + Operation *owner = use.getOwner(); + return owner != interfaceOp && owner != inlineAsmOp; + }); + } + rewriter.eraseOp(interfaceOp); + return; + } + + // Case 4: mixed (RW + declared results). + { + // First rewrite RW operands in place. + unsigned idx = 0; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) { + if (m != PTXRegisterMod::ReadWrite) + continue; + Value repl = unpacked[idx++]; + v.replaceUsesWithIf(repl, [&](OpOperand &use) { + Operation *owner = use.getOwner(); + return owner != interfaceOp && owner != inlineAsmOp; + }); + } + // The remaining unpacked values correspond to the declared results. 
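+    // (packResultTypes placed the read-write values first in the struct, so
+    // the declared results start at `idx`.)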
+ SmallVector<Value> tail; + tail.reserve(unpacked.size() - idx); + for (unsigned i = idx, e = unpacked.size(); i < e; ++i) + tail.push_back(unpacked[i]); + + rewriter.replaceOp(interfaceOp, tail); } } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index 1e02bfe..e268e8f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Ptr/IR/PtrEnums.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -51,6 +53,87 @@ void LLVMDialect::registerAttributes() { } //===----------------------------------------------------------------------===// +// AddressSpaceAttr +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is an LLVM type that can be loaded or stored. +static bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, + std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) { + if (!isLoadableType(type)) { + if (emitError) + emitError() << "type must be LLVM type with size, but got " << type; + return false; + } + if (ordering == ptr::AtomicOrdering::not_atomic) + return true; + + // To check atomic validity we need a datalayout. + if (!dataLayout) { + if (emitError) + emitError() << "expected a valid data layout"; + return false; + } + if (!isTypeCompatibleWithAtomicOp(type, *dataLayout)) { + if (emitError) + emitError() << "unsupported type " << type << " for atomic access"; + return false; + } + return true; +} + +bool AddressSpaceAttr::isValidLoad( + Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) const { + return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError); +} + +bool AddressSpaceAttr::isValidStore( + Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) const { + return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError); +} + +bool AddressSpaceAttr::isValidAtomicOp( + ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering, + std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) const { + // TODO: update this method once `ptr.atomic_rmw` is implemented. + assert(false && "unimplemented, see TODO in the source."); + return false; +} + +bool AddressSpaceAttr::isValidAtomicXchg( + Type type, ptr::AtomicOrdering successOrdering, + ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) const { + // TODO: update this method once `ptr.atomic_cmpxchg` is implemented. + assert(false && "unimplemented, see TODO in the source."); + return false; +} + +bool AddressSpaceAttr::isValidAddrSpaceCast( + Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const { + // TODO: update this method once the `ptr.addrspace_cast` op is added to the + // dialect. 
+ assert(false && "unimplemented, see TODO in the source."); + return false; +} + +bool AddressSpaceAttr::isValidPtrIntCast( + Type intLikeTy, Type ptrLikeTy, + function_ref<InFlightDiagnostic()> emitError) const { + // TODO: update this method once the int-cast ops are added to the `ptr` + // dialect. + assert(false && "unimplemented, see TODO in the source."); + return false; +} + +//===----------------------------------------------------------------------===// // AliasScopeAttr //===----------------------------------------------------------------------===// @@ -374,6 +457,43 @@ TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) { getAttributeName()); } +FailureOr<Attribute> TargetFeaturesAttr::query(DataLayoutEntryKey key) { + auto stringKey = dyn_cast<StringAttr>(key); + if (!stringKey) + return failure(); + + if (contains(stringKey)) + return UnitAttr::get(getContext()); + + if (contains((std::string("+") + stringKey.strref()).str())) + return BoolAttr::get(getContext(), true); + + if (contains((std::string("-") + stringKey.strref()).str())) + return BoolAttr::get(getContext(), false); + + return failure(); +} + +//===----------------------------------------------------------------------===// +// TargetAttr +//===----------------------------------------------------------------------===// + +FailureOr<::mlir::Attribute> TargetAttr::query(DataLayoutEntryKey key) { + if (auto stringAttrKey = dyn_cast<StringAttr>(key)) { + if (stringAttrKey.getValue() == "triple") + return getTriple(); + if (stringAttrKey.getValue() == "chip") + return getChip(); + if (stringAttrKey.getValue() == "features" && getFeatures()) + return getFeatures(); + } + return failure(); +} + +//===----------------------------------------------------------------------===// +// ModuleFlagAttr +//===----------------------------------------------------------------------===// + LogicalResult ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError, LLVM::ModFlagBehavior flagBehavior, StringAttr key, diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 422039f..ef27070 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -141,6 +141,38 @@ static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) { return success(); } +static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder, + bool isExpandLoad, + uint64_t alignment = 1) { + // From + // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics + // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics + // + // The pointer alignment defaults to 1. + if (alignment == 1) { + return nullptr; + } + + auto emptyDictAttr = builder.getDictionaryAttr({}); + auto alignmentAttr = builder.getI64IntegerAttr(alignment); + auto namedAttr = + builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr); + SmallVector<mlir::NamedAttribute> attrs = {namedAttr}; + auto alignDictAttr = builder.getDictionaryAttr(attrs); + // From + // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics + // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics + // + // The align parameter attribute can be provided for [expandload]'s first + // argument. The align parameter attribute can be provided for + // [compressstore]'s second argument. + int pos = isExpandLoad ? 0 : 1; + return pos == 0 ? 
builder.getArrayAttr( + {alignDictAttr, emptyDictAttr, emptyDictAttr}) + : builder.getArrayAttr( + {emptyDictAttr, alignDictAttr, emptyDictAttr}); +} + //===----------------------------------------------------------------------===// // Operand bundle helpers. //===----------------------------------------------------------------------===// @@ -821,8 +853,8 @@ void LoadOp::getEffects( /// Returns true if the given type is supported by atomic operations. All /// integer, float, and pointer types with a power-of-two bitsize and a minimal /// size of 8 bits are supported. -static bool isTypeCompatibleWithAtomicOp(Type type, - const DataLayout &dataLayout) { +bool LLVM::isTypeCompatibleWithAtomicOp(Type type, + const DataLayout &dataLayout) { if (!isa<IntegerType, LLVMPointerType>(type)) if (!isCompatibleFloatingPointType(type)) return false; @@ -836,8 +868,9 @@ static bool isTypeCompatibleWithAtomicOp(Type type, /// Verifies the attributes and the type of atomic memory access operations. template <typename OpTy> -LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType, - ArrayRef<AtomicOrdering> unsupportedOrderings) { +static LogicalResult +verifyAtomicMemOp(OpTy memOp, Type valueType, + ArrayRef<AtomicOrdering> unsupportedOrderings) { if (memOp.getOrdering() != AtomicOrdering::not_atomic) { DataLayout dataLayout = DataLayout::closest(memOp); if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout)) @@ -1087,7 +1120,7 @@ static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) { /// Verify that the parameter and return types of the variadic callee type match /// the `callOp` argument and result types. template <typename OpTy> -LogicalResult verifyCallOpVarCalleeType(OpTy callOp) { +static LogicalResult verifyCallOpVarCalleeType(OpTy callOp) { std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType(); if (!varCalleeType) return success(); @@ -2500,7 +2533,7 @@ LogicalResult GlobalOp::verifyRegions() { // LLVM::GlobalCtorsOp //===----------------------------------------------------------------------===// -LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) { +static LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) { if (data.empty()) return success(); @@ -4117,6 +4150,32 @@ LogicalResult LLVM::masked_scatter::verify() { } //===----------------------------------------------------------------------===// +// masked_expandload (intrinsic) +//===----------------------------------------------------------------------===// + +void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state, + mlir::TypeRange resTys, Value ptr, + Value mask, Value passthru, + uint64_t align) { + ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align); + build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs, + /*res_attrs=*/nullptr); +} + +//===----------------------------------------------------------------------===// +// masked_compressstore (intrinsic) +//===----------------------------------------------------------------------===// + +void LLVM::masked_compressstore::build(OpBuilder &builder, + OperationState &state, Value value, + Value ptr, Value mask, uint64_t align) { + ArrayAttr argAttrs = + getLLVMAlignParamForCompressExpand(builder, false, align); + build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs, + /*res_attrs=*/nullptr); +} + +//===----------------------------------------------------------------------===// // InlineAsmOp 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index e7d5dad..ef38027 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -19,6 +19,7 @@ #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "sroa" @@ -734,9 +735,8 @@ static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout, return false; }) .Default([&](Type type) { - LLVM_DEBUG(llvm::dbgs() - << "[sroa] Unsupported type for offset computations" - << type << "\n"); + LDBG() << "[sroa] Unsupported type for offset computations" + << type; return true; }); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp index 78b4411..297640c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -24,7 +24,9 @@ using namespace mlir::LLVM; /// prints it as usual. static void dispatchPrint(AsmPrinter &printer, Type type) { if (isCompatibleType(type) && - !llvm::isa<IntegerType, FloatType, VectorType>(type)) + !(llvm::isa<IntegerType, FloatType, VectorType>(type) || + (llvm::isa<PtrLikeTypeInterface>(type) && + !llvm::isa<LLVMPointerType>(type)))) return mlir::LLVM::detail::printType(type, printer); printer.printType(type); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index fee2d3e..2dd0132 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -13,6 +13,7 @@ #include "TypeDetail.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinTypes.h" @@ -701,6 +702,17 @@ const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const { // Utility functions. //===----------------------------------------------------------------------===// +/// Check whether type is a compatible ptr type. These are pointer-like types +/// with no element type, no metadata, and using the LLVM AddressSpaceAttr +/// memory space. 
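+/// For example, a `ptr.ptr` type whose memory space is the LLVM dialect's
+/// AddressSpaceAttr satisfies all three conditions.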
+static bool isCompatiblePtrType(Type type) { + auto ptrTy = dyn_cast<PtrLikeTypeInterface>(type); + if (!ptrTy) + return false; + return !ptrTy.hasPtrMetadata() && ptrTy.getElementType() == nullptr && + isa<AddressSpaceAttr>(ptrTy.getMemorySpace()); +} + bool mlir::LLVM::isCompatibleOuterType(Type type) { // clang-format off if (llvm::isa< @@ -734,7 +746,7 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) { if (auto vecType = llvm::dyn_cast<VectorType>(type)) return vecType.getRank() == 1; - return false; + return isCompatiblePtrType(type); } static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) { @@ -784,6 +796,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) { LLVMX86AMXType >([](Type) { return true; }) // clang-format on + .Case<PtrLikeTypeInterface>( + [](Type type) { return isCompatiblePtrType(type); }) .Default([](Type) { return false; }); if (!result) @@ -805,6 +819,18 @@ bool mlir::LLVM::isCompatibleType(Type type) { return LLVMDialect::isCompatibleType(type); } +bool mlir::LLVM::isLoadableType(Type type) { + return /*LLVM_PrimitiveType*/ ( + LLVM::isCompatibleOuterType(type) && + !isa<LLVM::LLVMVoidType, LLVM::LLVMFunctionType>(type)) && + /*LLVM_OpaqueStruct*/ + !(isa<LLVM::LLVMStructType>(type) && + cast<LLVM::LLVMStructType>(type).isOpaque()) && + /*LLVM_AnyTargetExt*/ + !(isa<LLVM::LLVMTargetExtType>(type) && + !cast<LLVM::LLVMTargetExtType>(type).supportsMemOps()); +} + bool mlir::LLVM::isCompatibleFloatingPointType(Type type) { return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type, Float80Type, Float128Type, LLVMPPCFP128Type>(type); @@ -818,7 +844,8 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) { if (auto intType = llvm::dyn_cast<IntegerType>(elementType)) return intType.isSignless(); return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type, - Float80Type, Float128Type, LLVMPointerType>(elementType); + Float80Type, Float128Type, LLVMPointerType>(elementType) || + isCompatiblePtrType(elementType); } return false; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index e0977f5..376e3c3 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/NVPTXAddrSpace.h" #include "llvm/Support/raw_ostream.h" #include <cassert> #include <optional> @@ -50,7 +51,6 @@ using namespace NVVM; // This verifier is shared among the following Ops: // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) -// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) // CpAsyncBulkTensorReduceOp (TMA Store-Reduce) static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, @@ -82,8 +82,27 @@ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() { } LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() { - if (getCoordinates().size() > 5) - return emitError("Maximum 5 coordinates and dimension is supported."); + TMAStoreMode mode = getMode(); + // We lower through inline-ptx when getPredicate() is true. 
+  // a) Only TILE mode is supported
+  // b) Cache-hint is not supported
+  if (getPredicate()) {
+    if (mode != TMAStoreMode::TILE)
+      return emitError("Inline-ptx lowering supported only for Tile mode.");
+    if (getL2CacheHint())
+      return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
+  }
+
+  size_t dims = getCoordinates().size();
+  switch (mode) {
+  case TMAStoreMode::TILE:
+    return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
+  case TMAStoreMode::IM2COL:
+    return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
+  case TMAStoreMode::TILE_SCATTER4:
+    if (dims != 5)
+      return emitError("Scatter4 mode expects 5 coordinates");
+  }
   return success();
 }
@@ -98,17 +117,59 @@ LogicalResult CpAsyncOp::verify() {
   return success();
 }
 
+// This parameter verification is shared across the TMA Load and Prefetch ops.
+static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
+                                         TMALoadMode mode, Location loc) {
+  if (tensorDims < 1 || tensorDims > 5)
+    return emitError(loc, "expects between 1 and 5 coordinates");
+
+  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
+                                size_t expectedIm2colOff) -> LogicalResult {
+    if (isIm2col && (tensorDims < 3))
+      return emitError(loc)
+             << "to use " << stringifyEnum(mode)
+             << " mode, the tensor has to be at least 3-dimensional";
+
+    if (numIm2colOff != expectedIm2colOff)
+      return emitError(loc) << "expected " << expectedIm2colOff
+                            << " im2col offsets (provided " << numIm2colOff
+                            << ")";
+
+    return success();
+  };
+
+  switch (mode) {
+  case TMALoadMode::TILE:
+    return checkTMALoadParams(mode, false, 0);
+  case TMALoadMode::IM2COL:
+    return checkTMALoadParams(mode, true, tensorDims - 2);
+  case TMALoadMode::IM2COL_W:
+  case TMALoadMode::IM2COL_W_128:
+    return checkTMALoadParams(mode, true, 2);
+  case TMALoadMode::TILE_GATHER4:
+    return (tensorDims == 5)
+               ? 
checkTMALoadParams(mode, false, 0) + : emitError(loc, "Gather4 mode expects 5 coordinates"); + } + return success(); +} + LogicalResult CpAsyncBulkTensorPrefetchOp::verify() { - size_t numIm2ColOffsets = getIm2colOffsets().size(); - bool isIm2Col = numIm2ColOffsets > 0; - return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, - numIm2ColOffsets, getLoc()); + return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(), + getMode(), getLoc()); } LogicalResult CpAsyncBulkTensorReduceOp::verify() { - bool isIm2Col = (getMode() == TMAStoreMode::IM2COL); - return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0, - getLoc()); + TMAStoreMode mode = getMode(); + size_t dims = getCoordinates().size(); + switch (mode) { + case TMAStoreMode::TILE: + return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc()); + case TMAStoreMode::IM2COL: + return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc()); + case TMAStoreMode::TILE_SCATTER4: + return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp"); + } + return success(); } LogicalResult ConvertFloatToTF32Op::verify() { @@ -189,6 +250,26 @@ LogicalResult BulkStoreOp::verify() { return success(); } +LogicalResult PMEventOp::verify() { + auto eventId = getEventId(); + auto maskedEventId = getMaskedEventId(); + if (!maskedEventId && !eventId) { + return emitOpError() << "either `id` or `mask` must be set"; + } + + if (maskedEventId && eventId) { + return emitOpError() << "`id` and `mask` cannot be set at the same time"; + } + + if (eventId) { + if (eventId < 0 || eventId > 15) { + return emitOpError() << "`id` must be between 0 and 15"; + } + } + + return llvm::success(); +} + // Given the element type of an operand and whether or not it is an accumulator, // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the // operand's element type. 
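For reference, a minimal standalone sketch (not part of the patch; the helper name expectedIm2colOffsets is illustrative) of the im2col-offset counts that verifyTMALoadParams above enforces per load mode:

static size_t expectedIm2colOffsets(NVVM::TMALoadMode mode, size_t tensorDims) {
  // Mirrors the switch in verifyTMALoadParams: TILE-like modes take no
  // im2col offsets, IM2COL takes tensorDims - 2, and the IM2COL_W variants
  // always take 2. TILE_GATHER4 additionally requires tensorDims == 5.
  switch (mode) {
  case NVVM::TMALoadMode::TILE:
  case NVVM::TMALoadMode::TILE_GATHER4:
    return 0;
  case NVVM::TMALoadMode::IM2COL:
    return tensorDims - 2; // only valid for tensorDims >= 3
  case NVVM::TMALoadMode::IM2COL_W:
  case NVVM::TMALoadMode::IM2COL_W_128:
    return 2; // also requires tensorDims >= 3
  }
  llvm_unreachable("unknown TMALoadMode");
}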
@@ -791,24 +872,58 @@ LogicalResult NVVM::WMMAMmaOp::verify() { } LogicalResult NVVM::LdMatrixOp::verify() { - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); - if (addressSpace != NVVM::kSharedMemorySpace) - return emitOpError("expected source pointer in memory space 3"); - - if (getNum() != 1 && getNum() != 2 && getNum() != 4) - return emitOpError("expected num attribute to be 1, 2 or 4"); + uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN(); + if (m == 8 && n == 8) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 " + "matrix"); + } + if (getEltType() != LdStMatrixEltType::B16) { + return emitOpError("expected element type to be b16 for 8x8 matrix"); + } + } else if (m == 8 && n == 16) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 " + "matrix"); + } + if (getLayout() != MMALayout::row) { + return emitOpError("expected layout to be row for 8x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { + return emitOpError("expected element type to be b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 8x16 matrix"); + } + } else if (m == 16 && n == 16) { + if (num != 1 && num != 2) { + return emitOpError("expected num attribute to be 1 or 2 for 16x16 " + "matrix"); + } + if (getLayout() != MMALayout::col) { + return emitOpError("expected layout to be col for 16x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8 && + getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { + return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 16x16 matrix"); + } + } else { + return emitOpError("expected shape to be 8x8, 8x16 or 16x16"); + } Type i32 = IntegerType::get(getContext(), 32); - if (getNum() == 1 && getType() != i32) + uint32_t numElements = (m == 16 && n == 16 ? 
num * 2 : num); + if (numElements == 1 && getType() != i32) return emitOpError("expected destination type is i32"); - if (getNum() == 2 || getNum() == 4) { + if (numElements == 2 || numElements == 4) { Type dstType = LLVM::LLVMStructType::getLiteral( - getContext(), SmallVector<Type>(getNum(), i32)); + getContext(), SmallVector<Type>(numElements, i32)); if (getType() != dstType) return emitOpError("expected destination type is a structure of ") - << getNum() << " elements of type i32"; + << numElements << " elements of type i32"; } + return success(); } @@ -1069,7 +1184,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() { return ptx; } -void NVVM::WgmmaMmaAsyncOp::getAsmValues( +bool NVVM::WgmmaMmaAsyncOp::getAsmValues( RewriterBase &rewriter, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues) { @@ -1100,7 +1215,9 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues( {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())), mlir::NVVM::PTXRegisterMod::Read}); } + return true; // Has manual mapping } + LogicalResult NVVM::FenceProxyOp::verify() { if (getKind() == NVVM::ProxyKind::TENSORMAP) return emitOpError() << "tensormap proxy is not a supported proxy kind"; @@ -1216,30 +1333,70 @@ LogicalResult NVVM::PrefetchOp::verify() { unsigned addressSpace = llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace(); std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority(); + std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel(); - if (getUniform()) { - if (getCacheLevel() != CacheLevel::L1) - return emitOpError("unsupported cache level, the only supported uniform " - "cache level is L1"); + if (getTensormap() && cacheLevel) + return emitOpError("cannot specify both tensormap and cache level"); - if (addressSpace != MemSpace::kGenericMemorySpace) + if (getTensormap()) { + if (addressSpace != MemSpace::kGenericMemorySpace && + addressSpace != MemSpace::kConstantMemorySpace) { return emitOpError( - "prefetch to uniform cache requires a generic pointer"); - } + "prefetch tensormap requires a generic or constant pointer"); + } - if (evictPriority) { - if (getCacheLevel() != CacheLevel::L2) + if (evictPriority) { return emitOpError( - "cache eviction priority supported only for cache level L2"); - - if (addressSpace != MemSpace::kGlobalMemorySpace) - return emitOpError("cache eviction priority requires a global pointer"); + "prefetch tensormap does not support eviction priority"); + } - if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal && - *evictPriority != NVVM::CacheEvictionPriority::EvictLast) + if (getInParamSpace() && addressSpace != MemSpace::kGenericMemorySpace) { return emitOpError( - "unsupported cache eviction priority, only evict_last and " - "evict_normal are supported"); + "in_param_space can only be specified for a generic pointer"); + } + + } else if (cacheLevel) { + if (addressSpace != MemSpace::kGenericMemorySpace && + addressSpace != MemSpace::kGlobalMemorySpace && + addressSpace != MemSpace::kLocalMemorySpace) { + return emitOpError("prefetch to cache level requires a generic, global, " + "or local pointer"); + } + + if (getUniform()) { + if (*cacheLevel != CacheLevel::L1) { + return emitOpError( + "unsupported cache level, the only supported uniform " + "cache level is L1"); + } + + if (addressSpace != MemSpace::kGenericMemorySpace) { + return emitOpError( + "prefetch to uniform cache requires a generic pointer"); + } + } + + if (evictPriority) { + if (*cacheLevel != CacheLevel::L2) + 
return emitOpError( + "cache eviction priority supported only for cache level L2"); + + if (addressSpace != MemSpace::kGlobalMemorySpace) + return emitOpError("cache eviction priority requires a global pointer"); + + if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal && + *evictPriority != NVVM::CacheEvictionPriority::EvictLast) + return emitOpError( + "unsupported cache eviction priority, only evict_last and " + "evict_normal are supported"); + } + + if (getPredicate()) + return emitOpError("predicate supported only on prefetch tensormap"); + + } else { + return emitOpError( + "requires specification of either cache level or tensormap"); } return success(); @@ -1379,28 +1536,102 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs( return {id, std::move(args)}; } -llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims, - bool isIm2Col) { - switch (tensorDims) { - case 1: - return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d; - case 2: - return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d; - case 3: - return isIm2Col - ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d - : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d; - case 4: - return isIm2Col - ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d - : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d; - case 5: - return isIm2Col - ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d - : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d; - default: - llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp."); - } +mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op); + llvm::SmallVector<llvm::Value *> args; + + // Fill the Intrinsic Args + args.push_back(mt.lookupValue(thisOp.getTmaDescriptor())); + + for (auto v : thisOp.getCoordinates()) + args.push_back(mt.lookupValue(v)); + for (auto v : thisOp.getIm2colOffsets()) + args.push_back(mt.lookupValue(v)); + + mlir::Value cacheHint = thisOp.getL2CacheHint(); + const bool hasCacheHint = static_cast<bool>(cacheHint); + llvm::Value *i64Unused = + llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0); + args.push_back(hasCacheHint ? 
mt.lookupValue(cacheHint) : i64Unused); + args.push_back(builder.getInt1(hasCacheHint)); + + const unsigned NI = llvm::Intrinsic::not_intrinsic; + static constexpr llvm::Intrinsic::ID IDTable[][6] = { + {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d}, + {NI, NI, NI, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d}, + {NI, NI, NI, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d}, + {NI, NI, NI, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d}, + {NI, NI, NI, NI, NI, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}}; + + static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1, + "TMALoadModes must match number of rows in IDTable"); + size_t mode = static_cast<size_t>(thisOp.getMode()); + size_t dim = thisOp.getCoordinates().size(); + llvm::Intrinsic::ID id = IDTable[mode][dim]; + if (id == llvm::Intrinsic::not_intrinsic) + llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp."); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair +CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op); + llvm::SmallVector<llvm::Value *> args; + + // Fill the Intrinsic Args + args.push_back(mt.lookupValue(thisOp.getSrcMem())); + args.push_back(mt.lookupValue(thisOp.getTmaDescriptor())); + + for (auto v : thisOp.getCoordinates()) + args.push_back(mt.lookupValue(v)); + + mlir::Value cacheHint = thisOp.getL2CacheHint(); + const bool hasCacheHint = static_cast<bool>(cacheHint); + llvm::Value *i64Unused = + llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0); + args.push_back(hasCacheHint ? 
mt.lookupValue(cacheHint) : i64Unused); + args.push_back(builder.getInt1(hasCacheHint)); + + const unsigned NI = llvm::Intrinsic::not_intrinsic; + static constexpr llvm::Intrinsic::ID IDTable[][6] = { + {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d}, + {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d}, + {NI, NI, NI, NI, NI, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}}; + + static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1, + "TMAStoreModes must match number of rows in IDTable"); + size_t mode = static_cast<size_t>(thisOp.getMode()); + size_t dim = thisOp.getCoordinates().size(); + llvm::Intrinsic::ID id = IDTable[mode][dim]; + if (id == llvm::Intrinsic::not_intrinsic) + llvm_unreachable( + "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp."); + + return {id, std::move(args)}; } #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \ @@ -1566,7 +1797,7 @@ Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op, unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType()) .getAddressSpace(); bool isShared = as == NVVMMemorySpace::kSharedMemorySpace; - bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2; + bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2; llvm::Intrinsic::ID id; if (isShared) { @@ -1588,7 +1819,7 @@ llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::SmallVector<llvm::Value *> &args) { auto curOp = cast<NVVM::Tcgen05DeallocOp>(op); - auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1) + auto id = (curOp.getGroup() == CTAGroupKind::CTA_1) ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2; @@ -1616,7 +1847,7 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, .getAddressSpace(); bool isShared = as == NVVMMemorySpace::kSharedMemorySpace; bool hasMulticast = static_cast<bool>(curOp.getMulticastMask()); - bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2; + bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2; llvm::Intrinsic::ID id = is2CTAMode ? 
GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast) @@ -1648,7 +1879,7 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) { auto curOp = cast<NVVM::Tcgen05CpOp>(op); - bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2; + bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2; auto srcFmt = curOp.getSrcFormat(); auto mc = curOp.getMulticast(); @@ -1774,26 +2005,47 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs( return {ids[type], args}; } -llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) { +static llvm::Value *getParamCastedAddr(llvm::Value *addr, + llvm::IRBuilderBase &builder) { + return builder.CreateAddrSpaceCast( + addr, + llvm::PointerType::get(builder.getContext(), + llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM)); +} + +NVVM::IDArgPair +PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { using MemSpace = NVVM::NVVMMemorySpace; using CacheLevel = NVVM::PrefetchCacheLevel; - NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel(); + std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel(); std::optional<NVVM::CacheEvictionPriority> evictPriority = op.getEvictPriority(); unsigned addressSpace = llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType()) .getAddressSpace(); - if (op.getUniform() && cacheLevel == CacheLevel::L1) - return llvm::Intrinsic::nvvm_prefetchu_L1; + llvm::SmallVector<llvm::Value *> args; + llvm::Value *addr = mt.lookupValue(op.getAddr()); + args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder) + : addr); + + if (op.getTensormap()) + return {llvm::Intrinsic::nvvm_prefetch_tensormap, args}; + + assert(cacheLevel && "expected cache level for non-tensormap prefetch"); + + if (op.getUniform() && *cacheLevel == CacheLevel::L1) + return {llvm::Intrinsic::nvvm_prefetchu_L1, args}; - if (evictPriority && cacheLevel == CacheLevel::L2) { + if (evictPriority && *cacheLevel == CacheLevel::L2) { switch (*evictPriority) { case NVVM::CacheEvictionPriority::EvictLast: - return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last; + return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args}; case NVVM::CacheEvictionPriority::EvictNormal: - return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal; + return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args}; default: llvm_unreachable("Invalid cache eviction priority"); } @@ -1801,21 +2053,41 @@ llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) { switch (addressSpace) { case MemSpace::kGenericMemorySpace: - return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1 - : llvm::Intrinsic::nvvm_prefetch_L2; + return *cacheLevel == CacheLevel::L1 + ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args}) + : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args}); case MemSpace::kGlobalMemorySpace: - return cacheLevel == CacheLevel::L1 - ? llvm::Intrinsic::nvvm_prefetch_global_L1 - : llvm::Intrinsic::nvvm_prefetch_global_L2; + return *cacheLevel == CacheLevel::L1 + ? NVVM::IDArgPair( + {llvm::Intrinsic::nvvm_prefetch_global_L1, args}) + : NVVM::IDArgPair( + {llvm::Intrinsic::nvvm_prefetch_global_L2, args}); case MemSpace::kLocalMemorySpace: - return cacheLevel == CacheLevel::L1 - ? llvm::Intrinsic::nvvm_prefetch_local_L1 - : llvm::Intrinsic::nvvm_prefetch_local_L2; + return *cacheLevel == CacheLevel::L1 + ? 
NVVM::IDArgPair( + {llvm::Intrinsic::nvvm_prefetch_local_L1, args}) + : NVVM::IDArgPair( + {llvm::Intrinsic::nvvm_prefetch_local_L2, args}); default: llvm_unreachable("Invalid pointer address space"); } } +bool NVVM::InlinePtxOp::getAsmValues( + RewriterBase &rewriter, + llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> + &asmValues) { + for (auto arg : getReadWriteArgs()) + asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite}); + for (auto arg : getResults()) + asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write}); + for (auto arg : getReadOnlyArgs()) + asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read}); + if (getPredicate()) + asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read}); + return false; // No manual mapping needed +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// @@ -1854,19 +2126,31 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, attrName == NVVMDialect::getReqntidAttrName() || attrName == NVVMDialect::getClusterDimAttrName()) { auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue()); - if (!values || values.empty() || values.size() > 3) + if (!values || values.empty() || values.size() > 3) { return op->emitError() << "'" << attrName << "' attribute must be integer array with maximum 3 index"; + } } // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer // attribute if (attrName == NVVMDialect::getMinctasmAttrName() || attrName == NVVMDialect::getMaxnregAttrName() || attrName == NVVMDialect::getClusterMaxBlocksAttrName()) { - if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) + if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) { return op->emitError() << "'" << attrName << "' attribute must be integer constant"; + } + } + // blocksareclusters must be used along with reqntid and cluster_dim + if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) { + if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) || + !op->hasAttr(NVVMDialect::getClusterDimAttrName())) { + return op->emitError() + << "'" << attrName << "' attribute must be used along with " + << "'" << NVVMDialect::getReqntidAttrName() << "' and " + << "'" << NVVMDialect::getClusterDimAttrName() << "'"; + } } return success(); diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp index 8317b67..23b4130 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" using namespace mlir; using namespace LLVM; @@ -63,9 +63,8 @@ DIExpressionRewriter::simplify(DIExpressionAttr expr, } if (maxNumRewrites && numRewrites >= *maxNumRewrites) { - LLVM_DEBUG(llvm::dbgs() - << "LLVMDIExpressionSimplifier exceeded max num rewrites (" - << maxNumRewrites << ")\n"); + LDBG() << "LLVMDIExpressionSimplifier exceeded max num rewrites (" + << maxNumRewrites << ")"; // Skip rewriting the rest. 
result.append(inputs.begin(), inputs.end()); } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp index b951df8..4ea2ac9 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp @@ -129,7 +129,6 @@ handleInlinedAllocas(Operation *call, OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPoint(allocaOp); LLVM::LifetimeStartOp::create(builder, allocaOp.getLoc(), - arraySize.getValue().getLimitedValue(), allocaOp.getResult()); } allocaOp->moveAfter(newConstant); @@ -147,7 +146,6 @@ handleInlinedAllocas(Operation *call, for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) { if (shouldInsertLifetime) LLVM::LifetimeEndOp::create(builder, allocaOp.getLoc(), - arraySize.getValue().getLimitedValue(), allocaOp.getResult()); } } @@ -237,8 +235,10 @@ getUnderlyingObjectSet(Value pointerValue) { WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) { // Attempt to advance to the source of the underlying view-like operation. // Examples of view-like operations include GEPOp and AddrSpaceCastOp. - if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>()) - return WalkContinuation::advanceTo(viewOp.getViewSource()); + if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>()) { + if (val == viewOp.getViewDest()) + return WalkContinuation::advanceTo(viewOp.getViewSource()); + } // Attempt to advance to control flow predecessors. std::optional<SmallVector<Value>> controlFlowPredecessors = diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 34c63d3..578931e 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -194,9 +194,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state, ArrayRef<AffineMap> indexingMaps) { // Initialize indexingMaps attribute, for MatmulOp. SmallVector<Attribute, 3> indexingMapsAttrVal; - indexingMapsAttrVal = llvm::map_to_vector( - MatmulOp::getDefaultIndexingMaps(b.getContext()), - [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); + indexingMapsAttrVal = + llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + }); state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal)); return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs, attributes, regionBuilder); @@ -1569,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -// Retrieve the operation from the body, if it is the only one (except -// yield) and if it gets the same amount of arguments as the body does. -// If initFirst flag is enabled, we check that init takes the first position in -// operands of payload. -static Operation *findPayloadOp(Block *body, bool initFirst = false) { +static bool canUseShortForm(Block *body, bool initFirst = false) { + // Check if the body can be printed in short form. The following 4 conditions + // must be satisfied: + + // 1) The body must contain exactly 2 operations: the payload op and a yield. if (body->getOperations().size() != 2) - return nullptr; + return false; Operation &payload = body->getOperations().front(); - assert(isa<YieldOp>(body->getOperations().back())); + // 2) The payload op must have the same number of operands as the number of + // block arguments. 
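Editorial sketch of the two printed forms that canUseShortForm decides between, with the IR abbreviated and types chosen for illustration only:

// Short form (payload and yield elided):
//   %r = linalg.map { arith.addf }
//          ins(%a, %b : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xf32>)
// Full form (printed whenever any of the four conditions below fails):
//   %r = linalg.map ins(%a, %b : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xf32>)
//          (%x: f32, %y: f32) {
//            %0 = arith.addf %x, %y : f32
//            linalg.yield %0 : f32
//          }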
if (payload.getNumOperands() == 0 || payload.getNumOperands() != body->getNumArguments()) - return nullptr; + return false; + + // 3) If `initFirst` is true (e.g., for reduction ops), the init block + // must be the first operand of the payload op, otherwise, the operands + // must match the block arguments in order. if (initFirst) { // check init if (payload.getOperands().back() != body->getArgument(0)) - return nullptr; + return false; // check rest for (const auto &[operand, bbArg] : llvm::zip(payload.getOperands(), body->getArguments().drop_front())) { if (bbArg != operand) - return nullptr; + return false; } } else { for (const auto &[operand, bbArg] : llvm::zip(payload.getOperands(), body->getArguments())) { if (bbArg != operand) - return nullptr; + return false; } } - return &payload; + + // 4) The `yield` operand must be the result of the payload op. + auto yieldOp = cast<YieldOp>(body->getTerminator()); + return yieldOp.getNumOperands() == 1 && + yieldOp.getOperand(0).getDefiningOp() && + yieldOp.getOperand(0).getDefiningOp() == &payload; } -void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { +static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { SmallVector<StringRef> elidedAttrs; std::string attrToElide; p << " { " << payloadOp->getName().getStringRef(); @@ -1621,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { void MapOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - Operation *payloadOp = findPayloadOp(mapper); - if (payloadOp) { - printShortForm(p, payloadOp); + bool useShortForm = canUseShortForm(mapper); + if (useShortForm) { + printShortForm(p, &mapper->getOperations().front()); } printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); p.printOptionalAttrDict((*this)->getAttrs()); - if (!payloadOp) { + if (!useShortForm) { // Print region if the payload op was not detected. p.increaseIndent(); p.printNewline(); @@ -1828,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, void ReduceOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true); - if (payloadOp) { - printShortForm(p, payloadOp); + bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true); + if (useShortForm) { + printShortForm(p, &mapper->getOperations().front()); } printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); - if (!payloadOp) { + if (!useShortForm) { // Print region if the payload op was not detected. p.increaseIndent(); p.printNewline(); @@ -3749,6 +3760,25 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) { // MatMulOp //===----------------------------------------------------------------------===// +static FailureOr<SmallVector<SmallVector<int64_t>>> +getAffineResultPositions(ArrayAttr maps) { + SmallVector<SmallVector<int64_t>> positions; + for (auto map : maps) { + AffineMapAttr attr = dyn_cast<AffineMapAttr>(map); + if (!attr) + return failure(); + SmallVector<int64_t> pos; + for (auto result : attr.getAffineMap().getResults()) { + auto dim = dyn_cast<AffineDimExpr>(result); + if (!dim) + return failure(); + pos.push_back(dim.getPosition()); + } + positions.push_back(pos); + } + return positions; +} + /// Returns a list of AffineMap with the typical matmul indexing charactristic. 
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) { AffineExpr d0, d1, d2; @@ -3760,6 +3790,20 @@ SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) { return indexingMaps; } +bool MatmulOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 2} && + (*positions)[1] == SmallVector<int64_t>{2, 1} && + (*positions)[2] == SmallVector<int64_t>{0, 1}; +} + SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() { return SmallVector<utils::IteratorType>{utils::IteratorType::parallel, utils::IteratorType::parallel, @@ -3836,7 +3880,7 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) { return expr.isFunctionOfDim(bcastMap.getNumDims() - 1); } -FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) { +static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) { if (parser.parseOptionalKeyword("indexing_maps")) return ArrayAttr{ nullptr}; // Success in case indexing_maps was not provided. @@ -3912,6 +3956,380 @@ Speculation::Speculatability MatmulOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); } +SmallVector<AffineMap> +MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2); + AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context); + AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context); + AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{2, 0} && + (*positions)[1] == SmallVector<int64_t>{2, 1} && + (*positions)[2] == SmallVector<int64_t>{0, 1}; +} + +void linalg::MatmulTransposeAOp::build(OpBuilder &builder, + OperationState &result, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeAOp +MatmulTransposeAOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeAOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeAOp +MatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + 
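Editorial aside: the transpose-A default maps built just below are easiest to read in printed affine-map form. A minimal standalone sketch (assuming the usual MLIR development headers and libraries; it simply mirrors MatmulTransposeAOp::getDefaultIndexingMaps) prints the three maps whose result positions {2,0}, {2,1}, {0,1} are what isDefaultIndexingMaps checks for:

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::AffineExpr d0, d1, d2;
  mlir::bindDims(&ctx, d0, d1, d2);
  // Same construction as MatmulTransposeAOp::getDefaultIndexingMaps.
  mlir::AffineMap lhs = mlir::AffineMap::get(3, 0, {d2, d0}, &ctx); // (d0, d1, d2) -> (d2, d0)
  mlir::AffineMap rhs = mlir::AffineMap::get(3, 0, {d2, d1}, &ctx); // (d0, d1, d2) -> (d2, d1)
  mlir::AffineMap out = mlir::AffineMap::get(3, 0, {d0, d1}, &ctx); // (d0, d1, d2) -> (d0, d1)
  llvm::outs() << lhs << "\n" << rhs << "\n" << out << "\n";
  return 0;
}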
OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeAOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeAOp +MatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool MatmulTransposeAOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::MatmulOp>(op) && + MatmulTransposeAOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + +SmallVector<AffineMap> +MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2); + AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context); + AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context); + AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 2} && + (*positions)[1] == SmallVector<int64_t>{1, 2} && + (*positions)[2] == SmallVector<int64_t>{0, 1}; +} + +void linalg::MatmulTransposeBOp::build(OpBuilder &builder, + OperationState &result, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeBOp +MatmulTransposeBOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeBOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeBOp +MatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + 
build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeBOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeBOp +MatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool MatmulTransposeBOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::MatmulOp>(op) && + MatmulTransposeBOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + +SmallVector<AffineMap> +BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2, d3; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2, d3); + AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context); + AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context); + AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} && + (*positions)[1] == SmallVector<int64_t>{0, 3, 2} && + (*positions)[2] == SmallVector<int64_t>{0, 1, 2}; +} + +void linalg::BatchMatmulTransposeAOp::build( + OpBuilder &builder, OperationState &result, ValueRange inputs, + ValueRange outputs, ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeAOp +BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeAOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeAOp +BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState 
state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeAOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeAOp +BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool BatchMatmulTransposeAOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::BatchMatmulOp>(op) && + BatchMatmulTransposeAOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + +SmallVector<AffineMap> +BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2, d3; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2, d3); + AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context); + AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context); + AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} && + (*positions)[1] == SmallVector<int64_t>{0, 2, 3} && + (*positions)[2] == SmallVector<int64_t>{0, 1, 2}; +} + +void linalg::BatchMatmulTransposeBOp::build( + OpBuilder &builder, OperationState &result, ValueRange inputs, + ValueRange outputs, ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeBOp +BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeBOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeBOp +BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + 
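Editorial sketch: with classof implemented this way, a plain linalg.batch_matmul whose indexing_maps follow the transpose-A pattern is classified as a BatchMatmulTransposeAOp even though no separate op is spelled in the IR. Roughly (syntax abbreviated, shapes illustrative):

// %0 = linalg.batch_matmul
//        indexing_maps = [affine_map<(b, m, n, k) -> (b, k, m)>,
//                         affine_map<(b, m, n, k) -> (b, k, n)>,
//                         affine_map<(b, m, n, k) -> (b, m, n)>]
//        ins(%A, %B : tensor<2x16x8xf32>, tensor<2x16x4xf32>)
//        outs(%C : tensor<2x8x4xf32>) -> tensor<2x8x4xf32>
// isa<linalg::BatchMatmulTransposeAOp>(op) now returns true for this op.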
ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeBOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeBOp +BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool BatchMatmulTransposeBOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::BatchMatmulOp>(op) && + BatchMatmulTransposeBOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + //===----------------------------------------------------------------------===// // ContractOp //===----------------------------------------------------------------------===// @@ -4120,6 +4538,20 @@ BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) { return indexingMaps; } +bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} && + (*positions)[1] == SmallVector<int64_t>{0, 3, 2} && + (*positions)[2] == SmallVector<int64_t>{0, 1, 2}; +} + SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() { return SmallVector<utils::IteratorType>{ utils::IteratorType::parallel, utils::IteratorType::parallel, @@ -5042,7 +5474,7 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc, /// Returns true if the tiles and the tiled dims are constant. template <typename OpTy> -bool areTilesAndTiledDimsAllConstant(OpTy op) { +static bool areTilesAndTiledDimsAllConstant(OpTy op) { static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, "applies to only pack or unpack operations"); ShapedType packedType = (std::is_same<OpTy, PackOp>::value) @@ -5345,11 +5777,18 @@ ArrayRef<int64_t> UnPackOp::getAllOuterDims() { SmallVector<int64_t> UnPackOp::getTiledOuterDims() { auto innerDimsPos = getInnerDimsPos(); - auto packedShape = getSourceType().getShape(); + SmallVector<int64_t> outerDims(getAllOuterDims()); SmallVector<int64_t> res; + // Recover the original order of the outer dims. + SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm()); + invertPermutationVector(outerDimPermInv); + if (!outerDimPermInv.empty()) + applyPermutationToVector(outerDims, outerDimPermInv); + + // Collect the outer dims corresponding to the tilled inner dims. 
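Editorial worked example of the getTiledOuterDims fix (shapes chosen for illustration): take a linalg.unpack with inner_dims_pos = [1], inner_tiles = [8], outer_dims_perm = [1, 0], unpacking a source tensor<8x16x8xf32> into tensor<16x64xf32>.

// getAllOuterDims()        -> [8, 16]  (outer dims in packed-source order)
// apply inverse of [1, 0]  -> [16, 8]  (outer dims back in the original order)
// getTiledOuterDims()      -> [8]      (outerDims[1], i.e. 64 divided by the tile of 8)
// The previous code indexed the raw packed shape and would have returned [16].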
for (auto index : innerDimsPos) - res.push_back(packedShape[index]); + res.push_back(outerDims[index]); return res; } @@ -5646,6 +6085,19 @@ BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) { return indexingMaps; } +bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} && + (*positions)[1] == SmallVector<int64_t>{0, 3, 2} && + (*positions)[2] == SmallVector<int64_t>{1, 2}; +} unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; } std::string BatchReduceMatmulOp::getLibraryCallName() { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index bdfc8d0..f0c1f44 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h" @@ -27,6 +28,7 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Transform/Utils/Utils.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" @@ -68,12 +70,7 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); // We want to discourage direct use of PatternRewriter in APIs but In this // very specific case, an IRRewriter is not enough. - struct TrivialPatternRewriter : public PatternRewriter { - public: - explicit TrivialPatternRewriter(MLIRContext *context) - : PatternRewriter(context) {} - }; - TrivialPatternRewriter rewriter(operation->getContext()); + PatternRewriter rewriter(operation->getContext()); rewriter.setInsertionPoint(operation); auto result = pattern.returningMatchAndRewrite(op, rewriter); if (failed(result)) @@ -1985,14 +1982,19 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter, // Convert the padding values to attributes. SmallVector<Attribute> paddingValues; - for (auto const &it : + for (auto const &[untypedAttr, elementOrTensorType] : llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) { - auto attr = dyn_cast<TypedAttr>(std::get<0>(it)); + + if (isa<ub::PoisonAttr>(untypedAttr)) { + paddingValues.push_back(untypedAttr); + continue; + } + auto attr = dyn_cast<TypedAttr>(untypedAttr); if (!attr) { - emitOpError("expects padding values to be typed attributes"); + emitOpError("expects padding values to be typed attributes or poison"); return DiagnosedSilenceableFailure::definiteFailure(); } - Type elementType = getElementTypeOrSelf(std::get<1>(it)); + Type elementType = getElementTypeOrSelf(elementOrTensorType); // Try to parse string attributes to obtain an attribute of element type. 
if (auto stringAttr = dyn_cast<StringAttr>(attr)) { auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute( @@ -2000,7 +2002,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter, /*numRead=*/nullptr, /*isKnownNullTerminated=*/true)); if (!parsedAttr || parsedAttr.getType() != elementType) { auto diag = this->emitOpError("expects a padding that parses to ") - << elementType << ", got " << std::get<0>(it); + << elementType << ", got " << untypedAttr; diag.attachNote(linalgTarget.getLoc()) << "when applied to this op"; return DiagnosedSilenceableFailure::definiteFailure(); } @@ -2235,8 +2237,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter, llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) { auto attr = dyn_cast<TypedAttr>(untypedAttr); Type elementType = getElementTypeOrSelf(elementOrTensorType); + + if (isa<ub::PoisonAttr>(untypedAttr)) { + paddingValues.push_back(untypedAttr); + continue; + } if (!attr) { - emitOpError("expects padding values to be typed attributes"); + emitOpError("expects padding values to be typed attributes or poison"); return DiagnosedSilenceableFailure::definiteFailure(); } // Try to parse string attributes to obtain an attribute of element type. @@ -3783,8 +3790,15 @@ LogicalResult TileUsingForallOp::verify() { void transform::VectorizeChildrenAndApplyPatternsOp::build( OpBuilder &builder, OperationState &result, Value target, - bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) { + bool foldTypeExtensionsIntoContract, bool vectorizePadding, + bool vectorizeExtract, bool flatten1DDepthwiseConv) { result.addOperands(target); + if (foldTypeExtensionsIntoContract) { + result.addAttribute( + VectorizeChildrenAndApplyPatternsOp:: + getFoldTypeExtensionsIntoContractAttrName(result.name), + builder.getUnitAttr()); + } if (vectorizePadding) { result.addAttribute( VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName( @@ -3875,6 +3889,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne( patterns.add<CopyVectorizationPattern>(ctx); + if (getFoldTypeExtensionsIntoContract()) + vector::populateFoldArithExtensionPatterns(patterns); + if (getVectorizePadding()) { linalg::populatePadOpVectorizationPatterns(patterns); // This creates an alternative path for lowering tensor.pad - by diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp index 3908d73..6912da3f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp @@ -55,8 +55,8 @@ static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, // Skip the batch dimension if present. // Offset all dimensions accordingly. 
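Editorial note on the new fold_type_extensions_into_contract option wired in above: populateFoldArithExtensionPatterns folds arith extension ops that feed a vector.contract into the contraction itself, yielding a mixed-precision contract. Roughly (IR abbreviated, shapes illustrative):

// %a32 = arith.extf %a : vector<4x8xf16> to vector<4x8xf32>
// %b32 = arith.extf %b : vector<8x4xf16> to vector<8x4xf32>
// %r   = vector.contract ... %a32, %b32, %acc
//          : vector<4x8xf32>, vector<8x4xf32> into vector<4x4xf32>
// ==> after folding:
// %r   = vector.contract ... %a, %b, %acc
//          : vector<4x8xf16>, vector<8x4xf16> into vector<4x4xf32>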
SmallVector<int64_t, 3> offsetDims(dims); - for (size_t i = 0; i < offsetDims.size(); i++) - offsetDims[i] += batchDimsOffset; + for (int64_t &offsetDim : offsetDims) + offsetDim += batchDimsOffset; auto tileOp = cast<TilingInterface>(linalgOp.getOperation()); OpBuilder builder(tileOp); @@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns( RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) { patterns.add<BlockPackMatmul<linalg::GenericOp>, BlockPackMatmul<linalg::MatmulOp>, - BlockPackMatmul<linalg::BatchMatmulOp>, - BlockPackMatmul<linalg::MatmulTransposeAOp>, - BlockPackMatmul<linalg::BatchMatmulTransposeAOp>, - BlockPackMatmul<linalg::MatmulTransposeBOp>, - BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>( - patterns.getContext(), controlFn); + BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(), + controlFn); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 70f846e..fb39e186 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -23,9 +23,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms InlineScalarOperands.cpp Interchange.cpp Loops.cpp + MorphOps.cpp TransposeMatmul.cpp ShardingInterfaceImpl.cpp - NamedOpConversions.cpp + SimplifyDepthwiseConv.cpp + NamedToElementwise.cpp BlockPackMatmul.cpp PackAndUnpackPatterns.cpp Padding.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp index d1eb270..108abe8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" @@ -50,28 +51,71 @@ static Value createMul(Location loc, Value x, Value y, Type accType, return arith::MulFOp::create(builder, loc, xConvert, yConvert); } -// Delinearizes the given composite `index` by the basis specified in `factors`. -static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index, - ArrayRef<int64_t> factors) { - assert(!factors.empty() && "empty factor list"); - SmallVector<Value> basis; - for (int64_t f : factors) - basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f))); - FailureOr<SmallVector<Value>> multiIndex = - affine::delinearizeIndex(b, loc, index, basis); - assert(!failed(multiIndex) && "Failed to linearize img2col index"); - return *multiIndex; +// Generate the affine expression to compute the convolved index +// for the input as `oIndex * stride + fIndex`, +// where oIndex: output iterator; fIndex: filter iterator. +static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride, + bool useSymbols = true) { + AffineExpr oExpr, fExpr; + if (useSymbols) + bindSymbols(b.getContext(), oExpr, fExpr); + else + bindDims(b.getContext(), oExpr, fExpr); + return AffineExpr(stride * oExpr + fExpr); } -// Given indices corresponding to iterators in the output (oIndex) and filter -// (fIndex) for a convolution, compute the convolved index for the -// input as `oIndex * stride + fIndex`. 
-static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, - Value fIndex, int64_t stride) { - AffineExpr oExpr, fExpr; - bindSymbols(b.getContext(), oExpr, fExpr); - AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr); - return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex}); +// Stores the affine expressions to map the iteration space of the im2col matrix +// to the corresponding indices of the output and filter matrices +struct Im2ColToOperandsExprs { + AffineExpr fhIndex; + AffineExpr fwIndex; + AffineExpr icIndex; + AffineExpr ohIndex; + AffineExpr owIndex; +}; + +// Stores the affine expressions to map the iteration space of the im2col matrix +// to the input matrix indices +struct Im2ColToInputDimsExprs { + AffineExpr bIndex; + AffineExpr hIndex; + AffineExpr wIndex; + AffineExpr cIndex; +}; + +/// Construct the affine expressions that map the indices of the im2col matrix +/// to the corresponding input tensor indices for a 2D convolution with the the +/// provided strides. +/// +/// @param exprs Affine expressions for output and filter indices. +/// @param strides [height, width] stride values for the convolution. +/// @param rewriter Pattern rewriter. +/// @return Affine expressions mapping im2col matrix indices to input +/// offsets. +static Im2ColToInputDimsExprs +getIm2ColInputExpressions(Im2ColToOperandsExprs exprs, + ArrayRef<int64_t> strides, RewriterBase &rewriter) { + // maps the iteration space of the im2col matrix to (output_y, filter_y) + auto hIndicesMap = AffineMap::inferFromExprList( + {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0]; + // maps the iteration space of the im2col matrix to (output_x, filter_x) + auto wIndicesMap = AffineMap::inferFromExprList( + {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0]; + // Compute the input indexing map, to map the indices of the im2col matrix to + // the original input offsets. Each element of the im2col matrix corresponds + // to a pair of (out_element, filter_element). First, we build the expressions + // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs; + // then we compose them with the maps that map the im2col matrix elements to + // the (out_element, filter_element) pairs. 
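Editorial note: composing the delinearization maps with the convolved-index expressions collapses everything into a single input indexing map, which is what lets the rewrite below read the input directly instead of emitting linalg.index plus tensor.extract in the generic body. For the NHWC case the composed map is, roughly:

// With strides [sh, sw], output width ow, filter width fw, input channels ic:
//   (n, m, k) -> (n,
//                 sh * (m floordiv ow) + (k floordiv (fw * ic)),
//                 sw * (m mod ow)      + ((k floordiv ic) mod fw),
//                 k mod ic)
// i.e. im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic].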
+ auto bIndexExpr = rewriter.getAffineDimExpr(0U); + auto hIndexExpr = getConvolvedExpr(rewriter, strides[0], + /*useSymbols*/ false); + hIndexExpr = hIndexExpr.compose(hIndicesMap); + auto wIndexExpr = getConvolvedExpr(rewriter, strides[1], + /*useSymbols*/ false); + wIndexExpr = wIndexExpr.compose(wIndicesMap); + auto cIndexExpr = exprs.icIndex; + return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr}; } FailureOr<std::pair<Operation *, Operation *>> @@ -135,44 +179,37 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { auto reduction = utils::IteratorType::reduction; SmallVector<utils::IteratorType> img2colIterators(nloops, parallel); + // Given an index of the im2col matrix, retrieve the corresponding indices of + // the output and filter matrices + auto mIndicesExprs = + delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1}); + auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U), + ArrayRef<int64_t>{fw * ic, ic, 1}); + Im2ColToOperandsExprs i2cToOperExprs; + i2cToOperExprs.fhIndex = kIndicesExprs[0]; + i2cToOperExprs.fwIndex = kIndicesExprs[1]; + i2cToOperExprs.icIndex = kIndicesExprs[2]; + i2cToOperExprs.ohIndex = mIndicesExprs[0]; + i2cToOperExprs.owIndex = mIndicesExprs[1]; + + // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] + Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions( + i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()), + rewriter); + auto inMap = + AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex, + inExprs.wIndex, inExprs.cIndex}}, + rewriter.getContext())[0]; + SmallVector<AffineMap> img2colIndexingMaps = { - AffineMap::getMultiDimIdentityMap(nloops, context)}; + inMap, AffineMap::getMultiDimIdentityMap(nloops, context)}; auto img2ColTensor = linalg::GenericOp::create( rewriter, loc, colTensor.getType(), - /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, + /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); - Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); - Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); - - // Recover the original iteration indices from the problem/input sizes. - SmallVector<Value> mIndices = unrollIndex( - nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow}); - auto ohIndex = mIndices[0]; - auto owIndex = mIndices[1]; - - SmallVector<Value> kIndices = unrollIndex( - nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic}); - auto fhIndex = kIndices[0]; - auto fwIndex = kIndices[1]; - auto icIndex = kIndices[2]; - - // Extract the input element corresponding to the expanded indices. 
- Value hIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex, - convOp.getStrides().getValues<int64_t>()[0]); - Value wIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex, - convOp.getStrides().getValues<int64_t>()[1]); - - // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] - SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex}; - Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, - extractionIndices); - linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); + linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]); }); // Because the filter does not share the same batch dimension, @@ -421,44 +458,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { auto reduction = utils::IteratorType::reduction; SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel); - SmallVector<AffineMap, 4> img2colIndexingMaps = { - AffineMap::getMultiDimIdentityMap(nloops, context)}; + // Recover the original iteration indices from the problem/input sizes: + // given an index of the im2col matrix, retrieve the corresponding indices of + // the output and filter matrices + auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U), + ArrayRef<int64_t>{fh * fw, fw, 1}); + auto mIndicesExprs = + delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1}); + Im2ColToOperandsExprs i2cToOperExprs; + i2cToOperExprs.icIndex = kIndicesExprs[0]; + i2cToOperExprs.fhIndex = kIndicesExprs[1]; + i2cToOperExprs.fwIndex = kIndicesExprs[2]; + i2cToOperExprs.ohIndex = mIndicesExprs[0]; + i2cToOperExprs.owIndex = mIndicesExprs[1]; + Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions( + i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()), + rewriter); + auto inMap = + AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex, + inExprs.hIndex, inExprs.wIndex}}, + rewriter.getContext())[0]; + // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw] + SmallVector<AffineMap> img2colIndexingMaps = { + inMap, AffineMap::getMultiDimIdentityMap(nloops, context)}; auto img2ColTensor = linalg::GenericOp::create( rewriter, loc, colTensor.getType(), - /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, + /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); - Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); - Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); - - // Recover the original iteration indices from the problem/input sizes. - SmallVector<Value> kIndices = unrollIndex( - nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw}); - auto icIndex = kIndices[0]; - auto fhIndex = kIndices[1]; - auto fwIndex = kIndices[2]; - - SmallVector<Value> nIndices = unrollIndex( - nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow}); - auto ohIndex = nIndices[0]; - auto owIndex = nIndices[1]; - - // Extract the input element corresponding to the expanded indices. 
- Value hIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex, - convOp.getStrides().getValues<int64_t>()[0]); - Value wIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex, - convOp.getStrides().getValues<int64_t>()[1]); - - // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw] - SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex}; - Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, - extractionIndices); - linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); + linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]); }); // Because the filter does not share the same batch dimension, @@ -545,6 +574,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { Value reshapedOutput = tensor::CollapseShapeOp::create( rewriter, loc, reshapedOutputType, output, outputReassocIndices); + // Shape of the Toeplitz matrix produced by Im2col. SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic}; Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape, inputType.getElementType()); @@ -556,44 +586,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { auto reduction = utils::IteratorType::reduction; SmallVector<utils::IteratorType> img2colIterators(nloops, parallel); + // Given an index of the im2col matrix, retrieve the corresponding indices of + // the output and filter matrices + auto mIndicesExprs = + delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1}); + auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U), + ArrayRef<int64_t>{fw * ic, ic, 1}); + Im2ColToOperandsExprs i2cToOperExprs; + i2cToOperExprs.fhIndex = kIndicesExprs[0]; + i2cToOperExprs.fwIndex = kIndicesExprs[1]; + i2cToOperExprs.icIndex = kIndicesExprs[2]; + i2cToOperExprs.ohIndex = mIndicesExprs[0]; + i2cToOperExprs.owIndex = mIndicesExprs[1]; + + // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] + Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions( + i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()), + rewriter); + auto inMap = + AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex, + inExprs.wIndex, inExprs.cIndex}}, + rewriter.getContext())[0]; SmallVector<AffineMap> img2colIndexingMaps = { - AffineMap::getMultiDimIdentityMap(nloops, context)}; + inMap, AffineMap::getMultiDimIdentityMap(nloops, context)}; auto img2ColTensor = linalg::GenericOp::create( rewriter, loc, colTensor.getType(), - /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, + /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); - Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); - Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); - - // Recover the original iteration indices from the problem/input sizes. - SmallVector<Value> mIndices = unrollIndex( - nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow}); - auto ohIndex = mIndices[0]; - auto owIndex = mIndices[1]; - - SmallVector<Value> kIndices = unrollIndex( - nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic}); - auto fhIndex = kIndices[0]; - auto fwIndex = kIndices[1]; - auto icIndex = kIndices[2]; - - // Extract the input element corresponding to the expanded indices. 
- Value hIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex, - convOp.getStrides().getValues<int64_t>()[0]); - Value wIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex, - convOp.getStrides().getValues<int64_t>()[1]); - - // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] - SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex}; - Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, - extractionIndices); - linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); + linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]); }); // Because we didn't transpose the filters we don't actually have a batched diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index 76ddee4..2ff7f46 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -75,7 +75,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, // layout for best compatibility. Value toBuffer = bufferization::ToBufferOp::create( b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), - tensorSource, /*readOnly=*/true); + tensorSource, /*read_only=*/true); memref::CopyOp::create(b, loc, toBuffer, memrefDest); } break; case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: { @@ -84,7 +84,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, // layout for best compatibility. Value toBuffer = bufferization::ToBufferOp::create( b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), - tensorSource, /*readOnly=*/true); + tensorSource, /*read_only=*/true); linalg::CopyOp::create(b, loc, toBuffer, memrefDest); } break; }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 0a9c176..40085a2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -6,10 +6,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/Dominance.h" #include "llvm/ADT/SetOperations.h" @@ -1236,6 +1238,272 @@ private: ControlPropagationFn controlFn; }; +// This struct contains infomation about extract_slice dims. +struct SliceDimInfo { + OpFoldResult offset; + OpFoldResult sliceSize; + OpFoldResult outputSize; +}; + +/// Return the first input extract slice operand, if present, for the current +/// generic op. +static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) { + OpOperand *sliceOperand = nullptr; + for (auto operand : genericOp.getDpsInputOperands()) { + auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>(); + if (!extractOp) + continue; + sliceOperand = operand; + break; + } + if (!sliceOperand) { + return failure(); + } + return sliceOperand; +} + +// Return a map of dims that have partial slices on them so that other operands +// can use this information. 
Also return a bool mentioning if a reduction dim +// has a non full slice as that can be used to fold the original extract slice. +static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>> +getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) { + tensor::ExtractSliceOp producerSliceOp = + sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); + assert(producerSliceOp && "expect a valid ExtractSliceOp"); + llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap; + SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets(); + SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes(); + + SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult( + genericOp.getContext(), producerSliceOp.getSourceType().getShape()); + + for (auto [idx, expr] : llvm::enumerate( + genericOp.getMatchingIndexingMap(sliceOperand).getResults())) { + // If we have a full slice in a dimension then we dont need to add it to + // the partial slice map. + if (isConstantIntValue(offsets[idx], 0) && + isEqualConstantIntOrValue(sizes[idx], shape[idx])) { + continue; + } + // We only support partial slices of AffineDimExprs so bail-out if thats not + // the case. + if (!isa<AffineDimExpr>(expr)) { + return failure(); + } + SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]}; + int64_t dimPos = cast<AffineDimExpr>(expr).getPosition(); + partialSliceDimMap[dimPos] = sliceDimInfo; + } + // Next check if the dims with partial slice info are used in non + // AffineDimExpr in other operands and if they are then bail-out. + for (OpOperand &operand : genericOp->getOpOperands()) { + if (operand == *sliceOperand) { + continue; + } + AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand); + if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) { + if (isa<AffineDimExpr>(expr)) { + return false; + } + WalkResult status = expr.walk([&](AffineExpr expr) { + if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { + if (partialSliceDimMap.contains(dimExpr.getPosition())) { + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + if (status.wasInterrupted()) { + return true; + } + return false; + })) { + return failure(); + } + } + return partialSliceDimMap; +} + +static FailureOr<std::tuple<GenericOp, Value>> +pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter, + GenericOp genericOp, + ControlPropagationFn controlFn) { + if (genericOp.getNumResults() != 1) + return rewriter.notifyMatchFailure( + genericOp, "propagation through multi-result generic is unsupported."); + if (hasGatherSemantics(genericOp)) + return rewriter.notifyMatchFailure( + genericOp, + "propagation through generic with gather semantics is unsupported."); + // Collect the sliced operand, if present. 
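Editorial sketch of the overall rewrite implemented below (names and shapes are illustrative only): the extract_slice on the input is removed, every other operand is padded back to the full iteration size with ub.poison, the generic runs on the full domain, and the slice is re-applied to its result.

// Before:
//   %s = tensor.extract_slice %in[%off] [%sz] [1] : tensor<64xf32> to tensor<?xf32>
//   %r = linalg.generic ... ins(%s : tensor<?xf32>) outs(%init : tensor<?xf32>)
// After push-down:
//   %pad_init = tensor.pad %init low[%off] high[...] { ub.poison padding value } ...
//   %full     = linalg.generic ... ins(%in : tensor<64xf32>) outs(%pad_init : tensor<64xf32>)
//   %r        = tensor.extract_slice %full[%off] [%sz] [1] : tensor<64xf32> to tensor<?xf32>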
+ auto maybeSliceOperand = getSliceOperand(genericOp); + if (failed(maybeSliceOperand)) + return failure(); + OpOperand *sliceOperand = *maybeSliceOperand; + unsigned OperandIndex = sliceOperand->getOperandNumber(); + + if (!controlFn(sliceOperand)) + return failure(); + + tensor::ExtractSliceOp producerSliceOp = + sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); + assert(producerSliceOp && "expect a valid ExtractSliceOp"); + + if (producerSliceOp.getSource().getType().getRank() != + producerSliceOp.getResult().getType().getRank()) { + return rewriter.notifyMatchFailure( + genericOp, + "propagation of rank-reducing extract slice is unsupported."); + } + + SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides(); + if (!areAllConstantIntValue(strides, 1)) + return rewriter.notifyMatchFailure( + genericOp, "propagation of strided extract slice is unsupported."); + + // check if we can support the propagation of this extractSlice + // through the generic op and if so return the dimensions that + + auto maybePartialSliceDimMap = + getPartialSliceDimInfo(genericOp, sliceOperand); + + if (failed(maybePartialSliceDimMap)) { + return failure(); + } + + auto partialSliceDimMap = *maybePartialSliceDimMap; + + SmallVector<utils::IteratorType> iterators = + genericOp.getIteratorTypesArray(); + bool hasPartialReductionDimSlice = + llvm::any_of(partialSliceDimMap, [&](const auto &slice) { + int64_t sliceDim = slice.first; + return iterators[sliceDim] == utils::IteratorType::reduction; + }); + + // Store the padding information as (dimPos, lowPad, highPad, PaddedShape). + Location loc = genericOp->getLoc(); + AffineExpr dim0, dim1; + bindDims(rewriter.getContext(), dim0, dim1); + auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); + auto sub = [&](OpFoldResult v1, OpFoldResult v2) { + return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap, + {v1, v2}); + }; + + MLIRContext *ctx = genericOp.getContext(); + SmallVector<Value> paddedInputs; + for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) { + if (idx == OperandIndex && !hasPartialReductionDimSlice) { + paddedInputs.push_back(producerSliceOp.getSource()); + continue; + } + AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand); + SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) { + if (!isa<AffineDimExpr>(expr)) { + continue; + } + AffineDimExpr dimExpr = cast<AffineDimExpr>(expr); + if (!partialSliceDimMap.contains(dimExpr.getPosition())) { + continue; + } + SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()]; + operandLowPads[idx] = sliceDimInfo.offset; + operandHighPads[idx] = + sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), + sliceDimInfo.sliceSize); + } + auto paddingValue = ub::PoisonOp::create( + rewriter, loc, getElementTypeOrSelf(operand->get().getType())); + auto paddedOperand = tensor::PadOp::create( + rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads, + paddingValue, /*nofold=*/false); + paddedInputs.push_back(paddedOperand); + } + AffineMap outputIndexingMap = + genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0)); + + auto outputShapeType = + llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType()); + SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector( + 
outputShapeType.getShape(), + [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); }); + SmallVector<OpFoldResult> newSizes = OutputShape; + SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 1)); + for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) { + if (!isa<AffineDimExpr>(expr)) { + continue; + } + AffineDimExpr dimExpr = cast<AffineDimExpr>(expr); + if (!partialSliceDimMap.contains(dimExpr.getPosition())) { + continue; + } + SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()]; + outputLowPads[idx] = sliceDimInfo.offset; + outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), + sliceDimInfo.sliceSize); + OutputShape[idx] = sliceDimInfo.outputSize; + newSizes[idx] = sliceDimInfo.sliceSize; + } + Value newPadOutput; + auto outputElType = + getElementTypeOrSelf(genericOp.getDpsInits()[0].getType()); + if (isGenericOutsNotUsed(genericOp)) { + newPadOutput = + tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType); + } else { + auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType); + newPadOutput = tensor::PadOp::create( + rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads, + outputHighPads, paddingValue, /*nofold=*/false); + } + + auto newGenericOp = linalg::GenericOp::create( + rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput}, + genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(), + /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); + rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), + newGenericOp.getRegion().begin()); + + auto extractOp = tensor::ExtractSliceOp::create( + rewriter, loc, + newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)), + outputLowPads, newSizes, newStrides); + Value extractRes = extractOp.getResult(); + + return std::make_tuple(newGenericOp, extractRes); +} + +class PushDownExtractSliceOpThroughGenericOp final + : public OpRewritePattern<GenericOp> { +public: + PushDownExtractSliceOpThroughGenericOp(MLIRContext *context, + ControlPropagationFn fun) + : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + auto genericAndRepl = + pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn); + if (failed(genericAndRepl)) + return failure(); + rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); + return success(); + } + +private: + ControlPropagationFn controlFn; +}; + } // namespace void mlir::linalg::populateDataLayoutPropagationPatterns( @@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns( PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>( patterns.getContext(), controlPackUnPackPropagation); } + +void mlir::linalg::populateExtractSliceSinkingPatterns( + RewritePatternSet &patterns, + const ControlPropagationFn &controlPackUnPackPropagation) { + patterns.insert<PushDownExtractSliceOpThroughGenericOp>( + patterns.getContext(), controlPackUnPackPropagation); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index bf66ed0..22690da 100644 --- 
a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -691,9 +691,9 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { auto newResultType = RankedTensorType::get( newResultShape, padOp.getResultType().getElementType()); - auto newPadOp = rewriter.create<tensor::PadOp>( - padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad, - newHighPad, paddingVal, padOp.getNofold()); + auto newPadOp = tensor::PadOp::create( + rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource, + newLowPad, newHighPad, paddingVal, padOp.getNofold()); Value dest = padOp.getResult(); if (options.rankReductionStrategy == @@ -1052,12 +1052,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> { static bool constexpr reduceLeft = (std::is_same_v<FromOpTy, BatchMatmulOp> && std::is_same_v<ToOpTy, BatchVecmatOp>) || - (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> && - std::is_same_v<ToOpTy, BatchVecmatOp>) || (std::is_same_v<FromOpTy, MatmulOp> && std::is_same_v<ToOpTy, VecmatOp>) || - (std::is_same_v<FromOpTy, MatmulTransposeAOp> && - std::is_same_v<ToOpTy, VecmatOp>) || (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>); /// Look for non-batch spatial dims to collapse. @@ -1113,27 +1109,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns( MLIRContext *context = patterns.getContext(); // Unbatching patterns for unit batch size patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context); - patterns - .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>( - context); - patterns - .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>( - context); patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context); patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context); // Non-batch rank 1 reducing patterns patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context); patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context); - patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context); - patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context); // Batch rank 1 reducing patterns patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context); patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context); - patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>( - context); - patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>( - context); // Non-batch rank 0 reducing patterns patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp index c523153..baf4083 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -20,13 +20,26 @@ namespace mlir { using namespace mlir; +static inline bool isScalarLike(Type t) { + return isa<IntegerType, FloatType, IndexType, ComplexType>(t); +} + static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { if (!OpTrait::hasElementwiseMappableTraits(op)) return false; - // TODO: The conversion pattern can be made to work for `any_of` here, but - // it's more complex as it requires tracking which operands are scalars. - return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>); + auto types = op->getOperandTypes(); + + // We want at least one ranked tensor. 
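+  // Scalar-like operands (integer, float, index, complex) may appear alongside
+  // it; the conversion maps them with a zero-result indexing map so they are
+  // effectively broadcast across the iteration space.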
+ bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>); + + // No invalid operands (i.e., every operand is a ranked tensor or + // scalar-like). + bool noneInvalid = llvm::none_of(types, [](Type t) { + return !(isa<RankedTensorType>(t) || isScalarLike(t)); + }); + + return anyRankedTensor && noneInvalid; } /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over @@ -81,13 +94,41 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { return rewriter.notifyMatchFailure( op, "requires elementwise op on ranked tensors"); - auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank(); - SmallVector<AffineMap, 3> indexingMaps( - op->getNumResults() + op->getNumOperands(), - rewriter.getMultiDimIdentityMap(rank)); - SmallVector<utils::IteratorType, 6> iteratorTypes( + auto resTy = cast<RankedTensorType>(op->getResult(0).getType()); + auto rank = resTy.getRank(); + + // Maps: identity for tensors (rank > 0), scalar map for scalars. + AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, + /*results=*/{}, rewriter.getContext()); + AffineMap idMap = rewriter.getMultiDimIdentityMap(rank); + + // Match phase. + SmallVector<bool> isScalarOperand; + isScalarOperand.reserve(op->getNumOperands()); + for (Type ty : op->getOperandTypes()) { + if (isScalarLike(ty)) + isScalarOperand.push_back(true); + else if (auto rt = dyn_cast<RankedTensorType>(ty)) + isScalarOperand.push_back(false); + else + return rewriter.notifyMatchFailure( + op, + "unsupported operand type (expected scalar-like or ranked tensor)"); + } + + // Create indexing maps. + SmallVector<AffineMap> indexingMaps; + indexingMaps.reserve(op->getNumOperands() + op->getNumResults()); + + for (bool isScalar : isScalarOperand) + indexingMaps.push_back(isScalar ? 
scalarMap : idMap); + + indexingMaps.append(op->getNumResults(), idMap); + + SmallVector<utils::IteratorType> iteratorTypes( rank, utils::IteratorType::parallel); - auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op); + SmallVector<Value> outputs = + getOrCreateOperandsMatchingResultTypes(rewriter, op); rewriter.replaceOpWithNewOp<linalg::GenericOp>( op, /*resultTensorTypes=*/op->getResultTypes(), /*inputs=*/op->getOperands(), @@ -96,14 +137,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { /*iteratorTypes=*/iteratorTypes, /*bodyBuilder=*/ [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { - auto resultTypes = llvm::to_vector<6>( + SmallVector<Type> resultEltTys = llvm::to_vector<6>( llvm::map_range(op->getResultTypes(), [](Type type) { return cast<TensorType>(type).getElementType(); })); - auto *scalarOp = + Operation *scalarOp = builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), - resultTypes, op->getAttrs()); + resultEltTys, op->getAttrs()); linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index fd530f2..9436f1c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -594,7 +594,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl( auto clonedForOp = scf::ForOp::create( rewriter, loc, bvm.lookupOrDefault(forOp.getLowerBound()), bvm.lookupOrDefault(forOp.getUpperBound()), - bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor); + bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); // Map the induction var, region args and results to the `clonedForOp`. bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 58986a6..36434cf 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -55,7 +55,8 @@ static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp newLoop = scf::ForOp::create( rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), - loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); + loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}, + loop.getUnsignedCmp()); // Generate the new yield with the replaced operand. 
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator()); @@ -165,8 +166,12 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, Value source = transferRead.getBase(); // Skip view-like Ops and retrive the actual soruce Operation - while (auto srcOp = source.getDefiningOp<ViewLikeOpInterface>()) - source = srcOp.getViewSource(); + while (auto viewLike = source.getDefiningOp<ViewLikeOpInterface>()) { + if (viewLike.getViewDest() != source) { + break; + } + source = viewLike.getViewSource(); + } llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(), source.getUsers().end()); @@ -177,7 +182,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, if (!processed.insert(user).second) continue; if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) { - users.append(viewLike->getUsers().begin(), viewLike->getUsers().end()); + Value viewDest = viewLike.getViewDest(); + users.append(viewDest.getUsers().begin(), viewDest.getUsers().end()); continue; } if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) diff --git a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp new file mode 100644 index 0000000..f261ccb --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp @@ -0,0 +1,62 @@ +//===- MorphOps.cpp - conversion between named,category and generic ops ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements conversions between linalg ops: +// named <--> category (elementwise, contraction, ..) <--> generic. 
+//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_LINALGMORPHOPSPASS +#include "mlir/Dialect/Linalg/Passes.h.inc" +} // namespace mlir + +#define DEBUG_TYPE "linalg-morphism" + +using namespace mlir; +using namespace mlir::linalg; + +namespace { +struct LinalgMorphOpsPass + : public impl::LinalgMorphOpsPassBase<LinalgMorphOpsPass> { + + using impl::LinalgMorphOpsPassBase< + LinalgMorphOpsPass>::LinalgMorphOpsPassBase; + + void runOnOperation() override; +}; + +void LinalgMorphOpsPass::runOnOperation() { + + RewritePatternSet patterns(&getContext()); + + // Lowering paths (named -> category -> generic) + if (namedToCategory) { + populateLinalgNamedToElementwisePatterns(patterns); + } + if (namedToGeneric || categoryToGeneric) { + populateLinalgNamedOpsGeneralizationPatterns(patterns); + } + + // Lifting paths (named <- category <- generic) + if (genericToNamed) { + populateLinalgGenericOpsSpecializationPatterns(patterns); + } + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); +} +} // namespace diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp new file mode 100644 index 0000000..00a076b --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp @@ -0,0 +1,98 @@ +//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements rewriting those linalg named ops that are essentially +// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further +// optimization on `linalg.elementwise` such as folding transpose, broadcast. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::linalg; + +#define DEBUG_TYPE "linalg-named-to-elementwise" + +namespace { +ElementwiseKind getKind(Operation *op) { + return llvm::TypeSwitch<Operation *, ElementwiseKind>(op) + .Case([](SelectOp) { return ElementwiseKind::select; }) + .Case([](AddOp) { return ElementwiseKind::add; }) + .Case([](SubOp) { return ElementwiseKind::sub; }) + .Case([](MulOp) { return ElementwiseKind::mul; }) + .Case([](DivOp) { return ElementwiseKind::div; }) + .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; }) + .Case([](PowFOp) { return ElementwiseKind::powf; }) + .Case([](ExpOp) { return ElementwiseKind::exp; }) + .Case([](LogOp) { return ElementwiseKind::log; }) + .Case([](AbsOp) { return ElementwiseKind::abs; }) + .Case([](CeilOp) { return ElementwiseKind::ceil; }) + .Case([](FloorOp) { return ElementwiseKind::floor; }) + .Case([](NegFOp) { return ElementwiseKind::negf; }) + .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; }) + .Case([](RoundOp) { return ElementwiseKind::round; }) + .Case([](SqrtOp) { return ElementwiseKind::sqrt; }) + .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; }) + .Case([](SquareOp) { return ElementwiseKind::square; }) + .Case([](TanhOp) { return ElementwiseKind::tanh; }) + .Case([](ErfOp) { return ElementwiseKind::erf; }) + .Default([&](Operation *op) { + llvm_unreachable("unhandled case in named to elementwise"); + return ElementwiseKind::sub; + }); +} + +template <typename NamedOpTy> +struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> { + using OpRewritePattern<NamedOpTy>::OpRewritePattern; + + LogicalResult matchAndRewrite(NamedOpTy op, + PatternRewriter &rewriter) const override { + SmallVector<NamedAttribute> attrs; + auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op)); + attrs.push_back(rewriter.getNamedAttr("kind", kindAttr)); + attrs.push_back( + rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps())); + + rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(), + op.getDpsInits(), attrs); + return success(); + } +}; +} // namespace + +void mlir::linalg::populateLinalgNamedToElementwisePatterns( + RewritePatternSet &patterns) { + patterns.add<NamedToElementwisePattern<SelectOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext()); + 
patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext()); + patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 2e62523..8942670 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinAttributes.h" @@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, Value paddingValue; if (auto complexTy = dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) { - auto complexAttr = cast<ArrayAttr>(paddingValueAttr); - paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), - complexTy, complexAttr); - } else { - paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(), - cast<TypedAttr>(paddingValueAttr)); + if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) { + paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), + complexTy, complexAttr); + } + } else if (isa<ub::PoisonAttr>(paddingValueAttr)) { + paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(), + getElementTypeOrSelf(v.getType())); + } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) { + paddingValue = + arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr); } + assert(paddingValue && "failed to create value from padding attribute"); // Pad the operand to the bounding box defined by `paddedShape`. 
SmallVector<int64_t> tensorShape; @@ -257,11 +263,11 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, paddingValue, /*nofold=*/false, dynDims); } -FailureOr<TilingInterface> -linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad, - const PadTilingInterfaceOptions &constOptions, - SmallVector<tensor::PadOp> &padOps, - PadSizeComputationFunction computePaddingSizeFun) { +FailureOr<TilingInterface> linalg::rewriteAsPaddedOp( + RewriterBase &rewriter, TilingInterface opToPad, + const PadTilingInterfaceOptions &constOptions, + SmallVector<tensor::PadOp> &padOps, + const PadSizeComputationFunction &computePaddingSizeFun) { LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n"); Location loc = opToPad.getLoc(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp index a2bd9d9..27ccf3c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp @@ -21,7 +21,7 @@ #include "llvm/ADT/TypeSwitch.h" namespace mlir { -#define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS +#define GEN_PASS_DEF_SIMPLIFYDEPTHWISECONVPASS #include "mlir/Dialect/Linalg/Passes.h.inc" } // namespace mlir @@ -143,23 +143,22 @@ struct SimplifyDepthwiseConvQOp } }; -struct LinalgNamedOpConversionPass - : public impl::LinalgNamedOpConversionPassBase< - LinalgNamedOpConversionPass> { - using impl::LinalgNamedOpConversionPassBase< - LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase; +struct SimplifyDepthwiseConvPass + : public impl::SimplifyDepthwiseConvPassBase<SimplifyDepthwiseConvPass> { + using impl::SimplifyDepthwiseConvPassBase< + SimplifyDepthwiseConvPass>::SimplifyDepthwiseConvPassBase; void runOnOperation() override { Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); - populateLinalgNamedOpConversionPatterns(patterns); + populateSimplifyDepthwiseConvPatterns(patterns); if (failed(applyPatternsGreedily(op, std::move(patterns)))) return signalPassFailure(); } }; } // namespace -void mlir::linalg::populateLinalgNamedOpConversionPatterns( +void mlir::linalg::populateSimplifyDepthwiseConvPatterns( RewritePatternSet &patterns) { patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>( patterns.getContext()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 455e1a6..35ba4f15 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter, /// Codegen the different matmul variants. 
if (numOfBatchDims) { - if (a == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter, - genericOp); - if (b == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter, - genericOp); return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp); } - - if (a == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp); - if (b == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp); return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index bb725f2..e9a8b25 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -29,6 +29,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/raw_ostream.h" #include <utility> @@ -38,9 +39,6 @@ using namespace mlir; using namespace mlir::linalg; -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") - //===----------------------------------------------------------------------===// // Transformations exposed as functional-style API calls. //===----------------------------------------------------------------------===// @@ -91,11 +89,11 @@ static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) { } return true; } +#endif // NDEBUG static std::string stringifyReassocIndices(ReassociationIndicesRef ri) { return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/""); } -#endif // NDEBUG /// Return the index of the first result of `map` that is a function of /// AffineDimExpr(dim), std::nullopt otherwise. 
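For reference, `stringifyReassocIndices` above is a thin wrapper around `llvm::interleaved` from `llvm/Support/InterleavedRange.h`, the same helper the new `LDBG()` lines stream directly. A minimal standalone sketch of the formatting it produces; the sample values and the `main` driver are illustrative only, not part of the patch:

```
#include <cstdint>

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // One reassociation group, e.g. dims {0, 1, 2} collapsing into a single dim.
  llvm::SmallVector<int64_t> group = {0, 1, 2};
  // Prints "|0, 1, 2": comma-separated values with the "|" prefix and empty
  // suffix used by stringifyReassocIndices.
  llvm::outs() << llvm::interleaved(group, ", ", /*Prefix=*/"|", /*Suffix=*/"")
               << "\n";
  return 0;
}
```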
@@ -276,23 +274,18 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows, highs, paddingValue, /*nofold=*/false); - LLVM_DEBUG( - DBGSNL(); DBGSNL(); - DBGS() << "insertPositions: " - << llvm::interleaved(packingMetadata.insertPositions); - DBGSNL(); DBGS() << "outerPositions: " - << llvm::interleaved(packingMetadata.outerPositions); - DBGSNL(); DBGS() << "packedShape: " - << llvm::interleaved(packedTensorType.getShape()); - DBGSNL(); DBGS() << "packedToStripMinedShapePerm: " - << llvm::interleaved(packedToStripMinedShapePerm); - DBGSNL(); - DBGS() << "reassociations: " - << llvm::interleaved(llvm::map_range( - packingMetadata.reassociations, stringifyReassocIndices)); - DBGSNL(); - DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape); - DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL();); + LDBG() << "insertPositions: " + << llvm::interleaved(packingMetadata.insertPositions); + LDBG() << "outerPositions: " + << llvm::interleaved(packingMetadata.outerPositions); + LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape()); + LDBG() << "packedToStripMinedShapePerm: " + << llvm::interleaved(packedToStripMinedShapePerm); + LDBG() << "reassociations: " + << llvm::interleaved(llvm::map_range(packingMetadata.reassociations, + stringifyReassocIndices)); + LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape); + LDBG() << "collapsed type: " << collapsed; if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) { // Pack ops which operate as simple pads may not produce legal @@ -317,7 +310,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(), /*offsets=*/zeros, sizes, /*strides=*/ones); - LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL();); + LDBG() << "insert_slice op: " << insertSliceOp; rewriter.replaceOp(packOp, insertSliceOp->getResults()); @@ -339,10 +332,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, auto transposeOp = linalg::TransposeOp::create( rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm); - LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); - DBGS() << "reshape op: " << reshapeOp; DBGSNL(); - DBGS() << "transpPerm: " << llvm::interleaved(transpPerm); - DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL();); + LDBG() << "reshape op: " << reshapeOp; + LDBG() << "transpPerm: " << llvm::interleaved(transpPerm); + LDBG() << "transpose op: " << transposeOp; // 7. Replace packOp by transposeOp. 
rewriter.replaceOp(packOp, transposeOp->getResults()); @@ -410,21 +402,16 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm); - LLVM_DEBUG( - DBGSNL(); DBGSNL(); - DBGS() << "insertPositions: " - << llvm::interleaved(packingMetadata.insertPositions); - DBGSNL(); DBGS() << "packedShape: " - << llvm::interleaved(packedTensorType.getShape()); - DBGSNL(); DBGS() << "packedToStripMinedShapePerm: " - << llvm::interleaved(packedToStripMinedShapePerm); - DBGSNL(); - DBGS() << "reassociations: " - << llvm::interleaved(llvm::map_range( - packingMetadata.reassociations, stringifyReassocIndices)); - DBGSNL(); - DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape); - DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL();); + LDBG() << "insertPositions: " + << llvm::interleaved(packingMetadata.insertPositions); + LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape()); + LDBG() << "packedToStripMinedShapePerm: " + << llvm::interleaved(packedToStripMinedShapePerm); + LDBG() << "reassociations: " + << llvm::interleaved(llvm::map_range(packingMetadata.reassociations, + stringifyReassocIndices)); + LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape); + LDBG() << "collapsed type: " << collapsedType; // 4. Collapse from the stripMinedShape to the padded result. auto reshapeOp = tensor::CollapseShapeOp::create( @@ -486,10 +473,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); SmallVector<utils::IteratorType> iteratorTypes = linalgOp.getIteratorTypesArray(); - LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n" - << "maps: " << llvm::interleaved(indexingMaps) << "\n" - << "iterators: " << llvm::interleaved(iteratorTypes) - << "\n"); + LDBG() << "Start packing: " << linalgOp; + LDBG() << "maps: " << llvm::interleaved(indexingMaps); + LDBG() << "iterators: " << llvm::interleaved(iteratorTypes); SmallVector<linalg::PackOp> packOps; SmallVector<linalg::UnPackOp> unPackOps; @@ -511,14 +497,11 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand; listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims)); - LLVM_DEBUG( - DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i] - << "\n" - << "maps: " << llvm::interleaved(indexingMaps) << "\n" - << "iterators: " << llvm::interleaved(iteratorTypes) << "\n" - << "packedDimForEachOperand: " - << llvm::interleaved(packedOperandsDims.packedDimForEachOperand) - << "\n"); + LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i]; + LDBG() << "maps: " << llvm::interleaved(indexingMaps); + LDBG() << "iterators: " << llvm::interleaved(iteratorTypes); + LDBG() << "packedDimForEachOperand: " + << llvm::interleaved(packedOperandsDims.packedDimForEachOperand); } // Step 2. Propagate packing to all LinalgOp operands. 
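All the debug-output changes in these `Transforms.cpp` hunks follow one pattern: a multi-statement `LLVM_DEBUG(DBGS() << ...; DBGSNL(); ...)` block becomes one `LDBG()` statement per logical line, with the terminator supplied by the stream itself (hence the removal of the local `DBGS`/`DBGSNL` macros earlier in this file). A hedged sketch of the idiom; the helper function, its values, and the `DEBUG_TYPE` string below are made up for illustration:

```
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/DebugLog.h"          // provides LDBG()
#include "llvm/Support/InterleavedRange.h"  // provides llvm::interleaved()

#define DEBUG_TYPE "linalg-transforms"

// Each LDBG() call is gated on DEBUG_TYPE (like LLVM_DEBUG) and emits one
// self-terminated line, so no explicit DBGSNL() is needed between entries.
static void logPackingChoice(llvm::ArrayRef<int64_t> packedSizes) {
  LDBG() << "packedSizes: " << llvm::interleaved(packedSizes);
  LDBG() << "numPackedDims: " << packedSizes.size();
}
```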
@@ -534,10 +517,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, listOfPackedOperandsDim.extractPackedDimsForOperand(pos); SmallVector<OpFoldResult> innerPackSizes = listOfPackedOperandsDim.extractPackSizesForOperand(pos); - LLVM_DEBUG(DBGS() << "operand: " << operand << "\n" - << "innerPos: " << llvm::interleaved(innerPos) << "\n" - << "innerPackSizes: " - << llvm::interleaved(innerPackSizes) << "\n"); + LDBG() << "operand: " << operand; + LDBG() << "innerPos: " << llvm::interleaved(innerPos); + LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes); if (innerPackSizes.empty()) { inputsAndInits.push_back(operand); continue; @@ -776,8 +758,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, int64_t numLoops = linalgOp.getNumLoops(); if (numLoops <= 2) { - LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got " - << numLoops << "\nin: " << linalgOp << "\n"); + LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops + << " in: " << linalgOp; return rewriter.notifyMatchFailure( linalgOp, "need 3+ loops to find a matmul to pack"); } @@ -801,8 +783,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, FailureOr<ContractionDimensions> maybeDimensions = inferContractionDims(linalgOp); if (failed(maybeDimensions)) { - LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp - << "\n"); + LDBG() << "couldn't infer matmul iterators in: " << linalgOp; return rewriter.notifyMatchFailure(linalgOp, "couldn't infer matmul iterators"); } @@ -814,10 +795,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, // to plug a heuristic. int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(), kPos = maybeDimensions->k.back(); - LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); - DBGS() << "Start packing generic op greedily with (m@" << mPos - << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp - << "\n";); + LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@" + << nPos << ", k@" << kPos << "): " << linalgOp; // 2.a. Rewrite as a generic. auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation()); @@ -833,14 +812,14 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, // not change the indexings of any operand. SmallVector<int64_t> permutation = computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos); - LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n"); + LDBG() << "perm: " << llvm::interleaved(permutation); // Sign .. unsigned pollution. SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end()); FailureOr<GenericOp> interchangeResult = interchangeGenericOp(rewriter, genericOp, unsignedPerm); assert(succeeded(interchangeResult) && "unexpected failure interchanging op"); genericOp = *interchangeResult; - LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";); + LDBG() << "Generalized Op to pack: " << genericOp; // At this point, the op iterators are normalized to {leading, k, m, n}. // The layouts induced by packing will always be: @@ -862,12 +841,11 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, // Add leading zeros to match numLoops, we only pack the last 3 dimensions // post interchange. 
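  // (A packed size of 0 for a loop means that loop is left unpacked.)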
- LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: " - << llvm::interleaved(paddedSizesNextMultipleOf) << "\n" - << "loopRanges: " - << llvm::interleaved(llvm::map_range( - loopRanges, [](Range r) { return r.size; })) - << "\n"); + LDBG() << "paddedSizesNextMultipleOf: " + << llvm::interleaved(paddedSizesNextMultipleOf); + LDBG() << "loopRanges: " + << llvm::interleaved( + llvm::map_range(loopRanges, [](Range r) { return r.size; })); SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(), rewriter.getIndexAttr(0)); for (int64_t i = 0, e = numPackedDims; i < e; ++i) { @@ -883,8 +861,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, {loopRanges[adjustedPackedSizes.size()].size, rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])})); } - LLVM_DEBUG(DBGS() << "adjustedPackedSizes: " - << llvm::interleaved(adjustedPackedSizes) << "\n"); + LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes); // TODO: If we wanted to give the genericOp a name after packing, after // calling `pack` would be a good time. One would still need to check that @@ -1214,9 +1191,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( } srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end()); - LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n" - << "perm: " << llvm::interleaved(srcPermForTranspose) - << "\n"); + LDBG() << "Pack permutation: " << packOp; + LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose); // 2.1 Create tensor.empty (init value for TransposeOp) SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles, diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp index a2a4335..2650488 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp @@ -59,12 +59,12 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter, ArrayRef<int64_t>{1, 0}); Operation *newMatmulOp; if (transposeLHS) { - newMatmulOp = linalg::MatmulTransposeAOp::create( + newMatmulOp = MatmulTransposeAOp::create( rewriter, loc, matmulOp.getResultTypes(), ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]}, matmulOp.getOutputs()); } else { - newMatmulOp = linalg::MatmulTransposeBOp::create( + newMatmulOp = MatmulTransposeBOp::create( rewriter, loc, matmulOp.getResultTypes(), ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)}, matmulOp.getOutputs()); @@ -116,12 +116,12 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter, ArrayRef<int64_t>{0, 2, 1}); Operation *newMatmulOp; if (transposeLHS) { - newMatmulOp = linalg::BatchMatmulTransposeAOp::create( + newMatmulOp = BatchMatmulTransposeAOp::create( rewriter, loc, batchMatmulOp.getResultTypes(), ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]}, batchMatmulOp.getOutputs()); } else { - newMatmulOp = linalg::BatchMatmulTransposeBOp::create( + newMatmulOp = BatchMatmulTransposeBOp::create( rewriter, loc, batchMatmulOp.getResultTypes(), ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)}, batchMatmulOp.getOutputs()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 0860cea..406f05c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, 
linalg::PackOp packOp, inputShape[innerDimsPos[idx]] *= size; auto maskedRead = vector::createReadOrMaskedRead( rewriter, loc, packOp.getSource(), inputShape, padValue, - useInBoundsInsteadOfMasking); + useInBoundsInsteadOfMasking, + /*inputScalableVecSizes=*/{}); // Create ShapeCastOp. SmallVector<int64_t> destShape(inputVectorSizes); @@ -1878,19 +1879,46 @@ static VectorType getCollapsedVecType(VectorType type, return VectorType::get(newShape, type.getElementType(), newScalableFlags); } -/// Vectorize a `linalg::UnPackOp` to these 4 Ops: -/// Vector::TransferReadOp - Reads a vector from the source tensor -/// vector::TransposeOp - Transpose the Source tensor -/// ShapeCastOp - Reshape the data based on the target. -/// vector::TransferWriteOp. - Write the result vector back to the destination -/// tensor. -/// If the vector sizes are not provided: -/// * the vector sizes are determined by the input operand and attributes, -/// * update the inBounds attribute instead of masking. +/// Vectorize `linalg.unpack` as: +/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write +/// +/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes +/// for the xfer_read operation). This is sufficient to infer the other vector +/// sizes required here. +/// +/// If the vector sizes are not provided: +/// * the vector sizes are determined from the input tensor static shape. +/// * the inBounds attribute is used instead of masking. +/// +/// EXAMPLE (no vector sizes): +/// ``` +/// %unpack = linalg.unpack %src +/// inner_dims_pos = [0, 1] +/// inner_tiles = [8, 8] +/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32> +/// ``` +/// is vectorized as: +/// ``` +/// %read = vector.transfer_read %src +/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32> +/// %tr = vector.transpose %read, [0, 2, 1, 3] +/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32> +/// %sc = vector.shape_cast %tr +/// : vector<1x8x1x8xf32> to vector<8x8xf32> +/// %vector = vector.transfer_write %sc into %dest +/// : vector<8x8xf32>, tensor<8x8xf32> +/// ``` static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, ArrayRef<int64_t> inputVectorSizes, + ArrayRef<bool> inputScalableVecDims, SmallVectorImpl<Value> &newResults) { + if (!inputVectorSizes.empty()) { + assert(inputVectorSizes.size() == unpackOp.getSourceRank() && + "Invalid number of input vector sizes!"); + assert(inputVectorSizes.size() == inputScalableVecDims.size() && + "Incompatible number of vector sizes and vector scalable flags!"); + } // TODO: Introduce a parent class that will handle the insertion point update. OpBuilder::InsertionGuard g(rewriter); @@ -1898,88 +1926,40 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, RankedTensorType unpackTensorType = unpackOp.getSourceType(); - ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos(); - ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles(); ArrayRef<int64_t> sourceShape = unpackTensorType.getShape(); bool useInBoundsInsteadOfMasking = false; - ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm(); - - auto destSize = unpackOp.getDestRank(); - - if (!inputVectorSizes.empty()) - assert(inputVectorSizes.size() == destSize && - "Incorrect number of input vector sizes"); - - // vectorSizes is the shape of the vector that will be used to do final - // write on the destination tensor. It is set like this: Let's say the - // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M. 
- // Thus: - // 1. vectorSizes = sourceShape.take_front(N) - // 2. if outer_dims_perms is present: do that permutation on vectorSizes. - // 3. multiply all the locations in vectorSize pointed by innerDimPos by the - // innerTiles attribute value. - SmallVector<int64_t> vectorSizes(inputVectorSizes); - if (vectorSizes.empty()) { - llvm::append_range(vectorSizes, sourceShape.take_front(destSize)); - if (!outerDimsPerm.empty()) - applyPermutationToVector(vectorSizes, outerDimsPerm); - for (auto [i, pos] : llvm::enumerate(innerDimPos)) - vectorSizes[pos] *= innerTiles[i]; - useInBoundsInsteadOfMasking = true; - } + Location loc = unpackOp->getLoc(); - // readVectorSizes is the size of tensor used to read and apply mask. It is - // set like this: Let's say the vectorSize (VS) array is size 'N' and - // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of - // size M-N - // Thus: - // - initially: readVectorSizes = vectorInputSizes - // - Divide all the readMaskShape locations pointed by innerDimPos - // by the innerTileSize attribute value. - // - if outer_dims_perms is present: do that permutation on readVectorSizes. - // - Append the remaining shape from SS - // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16> - // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512, - // 128] and outer_dims_perm is [1, 0] then read shape is: - // ReadVectorSizes(initial): [512, 128] - // Final Value(after innerDim Adjustment): [512/32, 128/16] - // = [16, 8] - // After applying outer_dims_perm: [8, 16] - // After appending the rest of the sourceShape: [8, 16, 32, 16] - - SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end()); - - for (auto [index, size] : enumerate(innerTiles)) { - readVectorSizes[innerDimPos[index]] = - llvm::divideCeil(readVectorSizes[innerDimPos[index]], size); - } - if (!outerDimsPerm.empty()) { - applyPermutationToVector(readVectorSizes, outerDimsPerm); - } - readVectorSizes.append(sourceShape.begin() + vectorSizes.size(), - sourceShape.end()); + // Obtain vector sizes for the read operation. + SmallVector<int64_t> readVectorSizes(inputVectorSizes); + SmallVector<bool> readScalableVectorFlags(inputScalableVecDims); - Location loc = unpackOp->getLoc(); + // In the absence of input-vector-sizes, use the _static_ input tensor shape. + if (inputVectorSizes.empty()) { + if (ShapedType::isDynamicShape(sourceShape)) + return failure(); + readVectorSizes.assign(sourceShape.begin(), sourceShape.end()); + useInBoundsInsteadOfMasking = true; + } + + // -- Generate the read operation -- auto padValue = arith::ConstantOp::create( rewriter, loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType())); - - // Read result, mask if necessary. If transferReadOp shape is not equal - // to shape of source, then a mask is necessary. Value readResult = vector::createReadOrMaskedRead( rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue, - /*useInBoundsInsteadOfMasking=*/false); + useInBoundsInsteadOfMasking, readScalableVectorFlags); + // -- Generate the transpose operation -- PackingMetadata packMetadata; SmallVector<int64_t> lastDimToInsertPosPerm = getUnPackInverseSrcPerm(unpackOp, packMetadata); - // Transpose the appropriate rows to match output. vector::TransposeOp transposeOp = vector::TransposeOp::create( rewriter, loc, readResult, lastDimToInsertPosPerm); - // Collapse the vector to the size required by result. 
+ // -- Generate the shape_cast operation -- VectorType collapsedVecType = getCollapsedVecType( transposeOp.getType(), getSymbolLessAffineMaps(convertReassociationIndicesToExprs( @@ -1987,9 +1967,11 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create( rewriter, loc, collapsedVecType, transposeOp->getResult(0)); + // -- Generate the write operation -- Operation *write = createWriteOrMaskedWrite( rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(), /*writeIndices=*/{}, useInBoundsInsteadOfMasking); + newResults.push_back(write->getResult(0)); return success(); } @@ -2016,7 +1998,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, assert(succeeded(status) && "failed to reify result shapes"); auto maskedRead = vector::createReadOrMaskedRead( rewriter, loc, padOp.getSource(), inputVectorSizes, padValue, - /*useInBoundsInsteadOfMasking=*/false); + /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{}); // Create Xfer write Op Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0], @@ -2095,24 +2077,34 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, return success(); } -/// Need to check if the inner-tiles are static/constant. +//// This hook considers two cases: +/// (1) If the input-vector-sizes are empty, then the vector sizes will be +/// infered. This is only possible when all shapes are static. +/// (2) If the input-vector-sizes are non-empty (i.e. user provided), then +/// carry out basic sanity-checking. static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef<int64_t> inputVectorSizes) { + // If there are no input vector sizes and all shapes are static, there is + // nothing left to check. + if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() && + unpackOp.getSourceType().hasStaticShape()) + return success(); - if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) { - return !getConstantIntValue(res).has_value(); - })) { - LDBG() << "Inner-tiles must be constant: " << unpackOp; + // The number of input vector sizes must be equal to: + // * read-vector-rank + if (!inputVectorSizes.empty() && + (inputVectorSizes.size() != unpackOp.getSourceRank())) { + LDBG() << "Incorrect number of input vector sizes"; return failure(); } - ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape(); - bool satisfyEmptyCond = inputVectorSizes.empty() && - unpackOp.getDestType().hasStaticShape() && - unpackOp.getSourceType().hasStaticShape(); - if (!satisfyEmptyCond && - failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes))) + + // Check the vector sizes for the read operation. + if (failed(vector::isValidMaskedInputVector( + unpackOp.getSourceType().getShape(), inputVectorSizes))) { + LDBG() << "Invalid vector sizes for the read operation"; return failure(); + } return success(); } @@ -2436,6 +2428,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp, LDBG() << "pad value is not constant: " << packOp; return failure(); } + ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); bool satisfyEmptyCond = true; if (inputVectorSizes.empty()) { @@ -2499,8 +2492,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp, return success(); } -/// Preconditions for scalable vectors. This is quite restrictive - it models -/// the fact that in practice we would only make selected dimensions scalable. +/// Preconditions for scalable vectors. 
+/// +/// For Ops implementing the LinalgOp interface, this is quite restrictive - it +/// models the fact that in practice we would only make selected dimensions +/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed +/// unconditionally - we are yet to identify meaningful conditions. static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef<int64_t> inputVectorSizes, @@ -2516,10 +2513,11 @@ vectorizeScalableVectorPrecondition(Operation *op, auto linalgOp = dyn_cast<LinalgOp>(op); - // Cond 1: There's been no need for scalable vectorisation of - // non-linalg Ops so far - if (!linalgOp) - return failure(); + // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the + // exception of UnpackOp for which there is a dedicated hook. + if (!linalgOp) { + return success(isa<linalg::UnPackOp>(op)); + } // Cond 2: There's been no need for more than 2 scalable dims so far if (numOfScalableDims > 2) @@ -2565,7 +2563,7 @@ vectorizeScalableVectorPrecondition(Operation *op, "vectorization"; return failure(); } - if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) { + if (isa<linalg::MatmulOp>(op)) { LDBG() << "Scalable vectorization of the reduction dim in Matmul-like ops " "is not supported"; @@ -2606,17 +2604,12 @@ vectorizeScalableVectorPrecondition(Operation *op, return failure(); } - // Check to not let go the matmul with extended semantic, through this - // transform. - if (linalgOp.hasUserDefinedMaps()) - return failure(); - // Cond 4: Only the following ops are supported in the // presence of scalable vectors return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) || - isa<linalg::MatmulTransposeAOp>(op) || isa<linalg::DepthwiseConv1DNwcWcOp>(op) || isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) || + isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp)); } @@ -2750,7 +2743,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize( }) .Case<linalg::UnPackOp>([&](auto unpackOp) { return vectorizeAsTensorUnpackOp(rewriter, unpackOp, - inputVectorSizes, results); + inputVectorSizes, + inputScalableVecDims, results); }) .Case<tensor::InsertSliceOp>([&](auto sliceOp) { return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes, @@ -3142,7 +3136,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0)); Value read = mlir::vector::createReadOrMaskedRead( rewriter, loc, source, vecType.getShape(), padValue, - /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty()); + /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(), + /*inputScalableVecSizes=*/{}); // Create write auto writeIndices = diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index e1c0c24..d37a056 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -1,6 +1,6 @@ add_mlir_dialect_library(MLIRMathTransforms AlgebraicSimplification.cpp - ExpandPatterns.cpp + ExpandOps.cpp ExtendToSupportedTypes.cpp PolynomialApproximation.cpp UpliftToFMA.cpp diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp index 4a40a30..cd68039 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp @@ -13,14 +13,18 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" #include 
"mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; +namespace mlir::math { +#define GEN_PASS_DEF_MATHEXPANDOPSPASS +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +} // namespace mlir::math + /// Create a float constant. static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b) { @@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op, return success(); } -void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { - patterns.add(convertCtlzOp); -} - -void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) { - patterns.add(convertSinhOp); -} - -void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) { - patterns.add(convertCoshOp); -} - -void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { - patterns.add(convertTanOp); -} - -void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { - patterns.add(convertTanhOp); -} - -void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) { - patterns.add(convertAsinhOp); -} - -void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) { - patterns.add(convertAcoshOp); -} - -void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) { - patterns.add(convertAtanhOp); -} - -void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { - patterns.add(convertFmaFOp); -} - -void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { - patterns.add(convertCeilOp); -} - -void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { - patterns.add(convertExp2fOp); -} - -void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { - patterns.add(convertPowfOp); -} - -void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { - patterns.add(convertFPowIOp); -} - -void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { - patterns.add(convertRoundOp); +// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf` +static LogicalResult convertClampfOp(math::ClampFOp op, + PatternRewriter &rewriter) { + auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(), + op.getMin(), op.getFastmath()); + rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(), + op.getFastmath()); + return success(); } -void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) { - patterns.add(convertRoundEvenOp); +void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns, + ArrayRef<StringRef> opMnemonics) { + auto filter = [&](StringRef name) { + // This should be a static assert and `consume_front` take a twine, but none + // is currently possible. TODO: augment `StringRef::consume_front` and make + // `getDialectNamespace` use `std::string_view`. 
+ assert("math" == MathDialect::getDialectNamespace()); + name.consume_front("math."); + return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0); + }; + if (filter(CountLeadingZerosOp::getOperationName())) + patterns.add(convertCtlzOp); + if (filter(SinhOp::getOperationName())) + patterns.add(convertSinhOp); + if (filter(CoshOp::getOperationName())) + patterns.add(convertCoshOp); + if (filter(TanOp::getOperationName())) + patterns.add(convertTanOp); + if (filter(TanhOp::getOperationName())) + patterns.add(convertTanhOp); + if (filter(AsinhOp::getOperationName())) + patterns.add(convertAsinhOp); + if (filter(AcoshOp::getOperationName())) + patterns.add(convertAcoshOp); + if (filter(AtanhOp::getOperationName())) + patterns.add(convertAtanhOp); + if (filter(FmaOp::getOperationName())) + patterns.add(convertFmaFOp); + if (filter(CeilOp::getOperationName())) + patterns.add(convertCeilOp); + if (filter(Exp2Op::getOperationName())) + patterns.add(convertExp2fOp); + if (filter(PowFOp::getOperationName())) + patterns.add(convertPowfOp); + if (filter(FPowIOp::getOperationName())) + patterns.add(convertFPowIOp); + if (filter(RoundOp::getOperationName())) + patterns.add(convertRoundOp); + if (filter(RoundEvenOp::getOperationName())) + patterns.add(convertRoundEvenOp); + if (filter(RsqrtOp::getOperationName())) + patterns.add(convertRsqrtOp); + if (filter(ClampFOp::getOperationName())) + patterns.add(convertClampfOp); } -void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) { - patterns.add(convertRsqrtOp); -} +//===----------------------------------------------------------------------===// +// MathExpandOpsPass pass +//===----------------------------------------------------------------------===// +namespace { +struct MathExpandOpsPass final + : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> { + using MathExpandOpsPassBase::MathExpandOpsPassBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + SmallVector<StringRef> mnemonics = + llvm::to_vector_of<StringRef>(opMnemonics); + math::populateExpansionPatterns(patterns, mnemonics); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 74b968c..b59d73d 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() { case arith::AtomicRMWKind::minu: case arith::AtomicRMWKind::muli: case arith::AtomicRMWKind::ori: + case arith::AtomicRMWKind::xori: case arith::AtomicRMWKind::andi: if (!llvm::isa<IntegerType>(getValue().getType())) return emitOpError() << "with kind '" diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp index bbb269b..1939195 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp @@ -21,9 +21,9 @@ namespace { struct ReallocOpInterface : public BufferViewFlowOpInterface::ExternalModel<ReallocOpInterface, ReallocOp> { - void - populateDependencies(Operation *op, - RegisterDependenciesFn registerDependenciesFn) const { + void populateDependencies( + Operation *op, + const RegisterDependenciesFn ®isterDependenciesFn) const { auto reallocOp = cast<ReallocOp>(op); // memref.realloc may return the source operand. 
registerDependenciesFn(reallocOp.getSource(), reallocOp.getResult()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 9771bd2..d35566a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -959,7 +959,7 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp PatternRewriter &rewriter) const override { auto viewLikeOp = extractOp.getSource().getDefiningOp<ViewLikeOpInterface>(); - if (!viewLikeOp) + if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest()) return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source"); rewriter.modifyOpInPlace(extractOp, [&]() { extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp index 5d3cec4..860384f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -43,50 +43,34 @@ static bool overrideBuffer(Operation *op, Value buffer) { /// propagate the type change and erase old subview ops. static void replaceUsesAndPropagateType(RewriterBase &rewriter, Operation *oldOp, Value val) { - SmallVector<Operation *> opsToDelete; - SmallVector<OpOperand *> operandsToReplace; - - // Save the operand to replace / delete later (avoid iterator invalidation). - // TODO: can we use an early_inc iterator? - for (OpOperand &use : oldOp->getUses()) { - // Non-subview ops will be replaced by `val`. - auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner()); - if (!subviewUse) { - operandsToReplace.push_back(&use); + // Iterate with early_inc to erase current user inside the loop. + for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) { + Operation *user = use.getOwner(); + if (auto subviewUse = dyn_cast<memref::SubViewOp>(user)) { + // `subview(old_op)` is replaced by a new `subview(val)`. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(subviewUse); + MemRefType newType = memref::SubViewOp::inferRankReducedResultType( + subviewUse.getType().getShape(), cast<MemRefType>(val.getType()), + subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), + subviewUse.getStaticStrides()); + Value newSubview = memref::SubViewOp::create( + rewriter, subviewUse->getLoc(), newType, val, + subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), + subviewUse.getMixedStrides()); + + // Ouch recursion ... is this really necessary? + replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); + + // Safe to erase. + rewriter.eraseOp(subviewUse); continue; } - - // `subview(old_op)` is replaced by a new `subview(val)`. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(subviewUse); - MemRefType newType = memref::SubViewOp::inferRankReducedResultType( - subviewUse.getType().getShape(), cast<MemRefType>(val.getType()), - subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), - subviewUse.getStaticStrides()); - Value newSubview = memref::SubViewOp::create( - rewriter, subviewUse->getLoc(), newType, val, - subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), - subviewUse.getMixedStrides()); - - // Ouch recursion ... is this really necessary? - replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); - - opsToDelete.push_back(use.getOwner()); + // Non-subview: replace with new value. 
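+    // startOpModification/finalizeOpModification notify the rewriter that this
+    // user's operand is updated in place (no replacement op is created).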
+ rewriter.startOpModification(user); + use.set(val); + rewriter.finalizeOpModification(user); } - - // Perform late replacement. - // TODO: can we use an early_inc iterator? - for (OpOperand *operand : operandsToReplace) { - Operation *op = operand->getOwner(); - rewriter.startOpModification(op); - operand->set(val); - rewriter.finalizeOpModification(op); - } - - // Perform late op erasure. - // TODO: can we use an early_inc iterator? - for (Operation *op : opsToDelete) - rewriter.eraseOp(op); } // Transformation to do multi-buffering/array expansion to remove dependencies @@ -216,8 +200,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, offsets, sizes, strides); LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); - // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to - // handle dealloc uses separately.. + // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need + // to handle dealloc uses separately.. for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) { auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner()); if (!deallocOp) diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index 5af46a4..3de9c38 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -210,8 +210,10 @@ MemrefValue skipFullyAliasingOperations(MemrefValue source) { MemrefValue skipViewLikeOps(MemrefValue source) { while (auto op = source.getDefiningOp()) { if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) { - source = cast<MemrefValue>(viewLike.getViewSource()); - continue; + if (source == viewLike.getViewDest()) { + source = cast<MemrefValue>(viewLike.getViewSource()); + continue; + } } return source; } diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index cc03974..8474244 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -345,6 +345,19 @@ LogicalResult LdMatrixOp::verify() { // NVGPU_TmaAsyncLoadOp //===----------------------------------------------------------------------===// +unsigned getSwizzleBytes(TensorMapSwizzleKind kind) { + switch (kind) { + case TensorMapSwizzleKind::SWIZZLE_32B: + return 32; + case TensorMapSwizzleKind::SWIZZLE_64B: + return 64; + case TensorMapSwizzleKind::SWIZZLE_128B: + return 128; + default: + return 0; + } +} + std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref( Operation *op, nvgpu::TensorMapDescriptorType descType, std::optional<MemRefType> memrefType = std::nullopt) { @@ -373,10 +386,11 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref( descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) { unsigned lastDimensionByte = descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8; - if (lastDimensionByte != kMaxTMALastdimByte) + unsigned expectByte = getSwizzleBytes(descType.getSwizzle()); + if (lastDimensionByte != expectByte) return op->emitError() << "the tensormap descriptor must have last " "dimension of " - << kMaxTMALastdimByte << " bytes but it is " + << expectByte << " bytes but it is " << lastDimensionByte << " bytes"; } @@ -408,6 +422,12 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref( << descMemref << " != " << dstMemref; } + int lastDimBytes = + descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8; + if (lastDimBytes % 16 != 0) { + return 
op->emitError() << "the bytes in the last dimension of the tensor " + "map must be a multiple of 16"; + } return std::nullopt; } diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 485bb73..ded4c7a 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -173,9 +173,7 @@ void OpenACCDialect::initialize() { //===----------------------------------------------------------------------===// static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) { - if (arrayAttr && *arrayAttr && arrayAttr->size() > 0) - return true; - return false; + return arrayAttr && *arrayAttr && arrayAttr->size() > 0; } static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr, @@ -1390,6 +1388,36 @@ void acc::ParallelOp::addPrivatization(MLIRContext *context, setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } +void acc::ParallelOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + getFirstprivateOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getFirstprivatizationRecipesAttr()) + llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + +void acc::ParallelOp::addReduction(MLIRContext *context, + mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + getReductionOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getReductionRecipesAttr()) + llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + static ParseResult parseNumGangs( mlir::OpAsmParser &parser, llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, @@ -2041,6 +2069,36 @@ void acc::SerialOp::addPrivatization(MLIRContext *context, setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } +void acc::SerialOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + getFirstprivateOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getFirstprivatizationRecipesAttr()) + llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + +void acc::SerialOp::addReduction(MLIRContext *context, + mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + getReductionOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getReductionRecipesAttr()) + llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + //===----------------------------------------------------------------------===// // KernelsOp //===----------------------------------------------------------------------===// @@ -3059,6 +3117,20 @@ void acc::LoopOp::addPrivatization(MLIRContext *context, 
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } +void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + getReductionOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getReductionRecipesAttr()) + llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + //===----------------------------------------------------------------------===// // DataOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index c1c1767..6e43f28 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -3874,6 +3874,159 @@ LogicalResult AllocateDirOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// TargetAllocMemOp +//===----------------------------------------------------------------------===// + +mlir::Type omp::TargetAllocMemOp::getAllocatedType() { + return getInTypeAttr().getValue(); +} + +/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype, +/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )? +/// attr-dict-without-keyword +static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto &builder = parser.getBuilder(); + bool hasOperands = false; + std::int32_t typeparamsSize = 0; + + // Parse device number as a new operand + mlir::OpAsmParser::UnresolvedOperand deviceOperand; + mlir::Type deviceType; + if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType)) + return mlir::failure(); + if (parser.resolveOperand(deviceOperand, deviceType, result.operands)) + return mlir::failure(); + if (parser.parseComma()) + return mlir::failure(); + + mlir::Type intype; + if (parser.parseType(intype)) + return mlir::failure(); + result.addAttribute("in_type", mlir::TypeAttr::get(intype)); + llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; + llvm::SmallVector<mlir::Type> typeVec; + if (!parser.parseOptionalLParen()) { + // parse the LEN params of the derived type. 
(<params> : <types>) + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || + parser.parseColonTypeList(typeVec) || parser.parseRParen()) + return mlir::failure(); + typeparamsSize = operands.size(); + hasOperands = true; + } + std::int32_t shapeSize = 0; + if (!parser.parseOptionalComma()) { + // parse size to scale by, vector of n dimensions of type index + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None)) + return mlir::failure(); + shapeSize = operands.size() - typeparamsSize; + auto idxTy = builder.getIndexType(); + for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) + typeVec.push_back(idxTy); + hasOperands = true; + } + if (hasOperands && + parser.resolveOperands(operands, typeVec, parser.getNameLoc(), + result.operands)) + return mlir::failure(); + + mlir::Type restype = builder.getIntegerType(64); + if (!restype) { + parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; + return mlir::failure(); + } + llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize}; + result.addAttribute("operandSegmentSizes", + builder.getDenseI32ArrayAttr(segmentSizes)); + if (parser.parseOptionalAttrDict(result.attributes) || + parser.addTypeToList(restype, result.types)) + return mlir::failure(); + return mlir::success(); +} + +mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseTargetAllocMemOp(parser, result); +} + +void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) { + p << " "; + p.printOperand(getDevice()); + p << " : "; + p << getDevice().getType(); + p << ", "; + p << getInType(); + if (!getTypeparams().empty()) { + p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')'; + } + for (auto sh : getShape()) { + p << ", "; + p.printOperand(sh); + } + p.printOptionalAttrDict((*this)->getAttrs(), + {"in_type", "operandSegmentSizes"}); +} + +llvm::LogicalResult omp::TargetAllocMemOp::verify() { + mlir::Type outType = getType(); + if (!mlir::dyn_cast<IntegerType>(outType)) + return emitOpError("must be an integer type"); + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// WorkdistributeOp +//===----------------------------------------------------------------------===// + +LogicalResult WorkdistributeOp::verify() { + // Check that region exists and is not empty + Region &region = getRegion(); + if (region.empty()) + return emitOpError("region cannot be empty"); + // Verify single entry point. + Block &entryBlock = region.front(); + if (entryBlock.empty()) + return emitOpError("region must contain a structured block"); + // Verify single exit point.
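// Taken together, the checks in this verifier accept regions of roughly the
// following shape (a sketch; op spellings assumed from the diagnostics below):
//   omp.teams {
//     omp.workdistribute {
//       ...  // no omp.barrier, omp.parallel, or omp.teams nested inside
//       omp.terminator
//     }
//     omp.terminator
//   }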
+ bool hasTerminator = false; + for (Block &block : region) { + if (isa<TerminatorOp>(block.back())) { + if (hasTerminator) { + return emitOpError("region must have exactly one terminator"); + } + hasTerminator = true; + } + } + if (!hasTerminator) { + return emitOpError("region must be terminated with omp.terminator"); + } + auto walkResult = region.walk([&](Operation *op) -> WalkResult { + // No implicit barrier at end + if (isa<BarrierOp>(op)) { + return emitOpError( + "explicit barriers are not allowed in workdistribute region"); + } + // Check for invalid nested constructs + if (isa<ParallelOp>(op)) { + return emitOpError( + "nested parallel constructs not allowed in workdistribute"); + } + if (isa<TeamsOp>(op)) { + return emitOpError( + "nested teams constructs not allowed in workdistribute"); + } + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return failure(); + + Operation *parentOp = (*this)->getParentOp(); + if (!llvm::dyn_cast<TeamsOp>(parentOp)) + return emitOpError("workdistribute must be nested under teams"); + return success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt index 497468b..bd1e655 100644 --- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -1,3 +1,22 @@ +set(LLVM_OPTIONAL_SOURCES + MemorySpaceInterfaces.cpp + PtrAttrs.cpp + PtrTypes.cpp + PtrDialect.cpp +) + +add_mlir_dialect_library( + MLIRPtrMemorySpaceInterfaces + MemorySpaceInterfaces.cpp + + DEPENDS + MLIRPtrOpsEnumsGen + MLIRPtrMemorySpaceInterfacesIncGen + LINK_LIBS + PUBLIC + MLIRIR +) + add_mlir_dialect_library( MLIRPtrDialect PtrAttrs.cpp @@ -15,4 +34,5 @@ add_mlir_dialect_library( MLIRDataLayoutInterfaces MLIRMemorySlotInterfaces MLIRViewLikeInterface + MLIRPtrMemorySpaceInterfaces ) diff --git a/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp new file mode 100644 index 0000000..059e67f --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp @@ -0,0 +1,15 @@ +//===-- MemorySpaceInterfaces.cpp - ptr memory space interfaces -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ptr dialect memory space interfaces. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" + +#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp index 772d25d..ac3bcd6 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp @@ -22,26 +22,30 @@ constexpr const static unsigned kBitsInByte = 8; //===----------------------------------------------------------------------===// bool GenericSpaceAttr::isValidLoad( - Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment, + Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, function_ref<InFlightDiagnostic()> emitError) const { return true; } bool GenericSpaceAttr::isValidStore( - Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment, + Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, function_ref<InFlightDiagnostic()> emitError) const { return true; } bool GenericSpaceAttr::isValidAtomicOp( ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering, - IntegerAttr alignment, function_ref<InFlightDiagnostic()> emitError) const { + std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) const { return true; } bool GenericSpaceAttr::isValidAtomicXchg( Type type, ptr::AtomicOrdering successOrdering, - ptr::AtomicOrdering failureOrdering, IntegerAttr alignment, + ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, function_ref<InFlightDiagnostic()> emitError) const { return true; } diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index c5ec0ca..d5976b9 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -85,6 +85,124 @@ LogicalResult FromPtrOp::verify() { } //===----------------------------------------------------------------------===// +// LoadOp +//===----------------------------------------------------------------------===// + +/// Verifies the attributes and the type of atomic memory access operations. +template <typename OpTy> +static LogicalResult +verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) { + if (memOp.getOrdering() != AtomicOrdering::not_atomic) { + if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering())) + return memOp.emitOpError("unsupported ordering '") + << stringifyAtomicOrdering(memOp.getOrdering()) << "'"; + if (!memOp.getAlignment()) + return memOp.emitOpError("expected alignment for atomic access"); + return success(); + } + if (memOp.getSyncscope()) { + return memOp.emitOpError( + "expected syncscope to be null for non-atomic access"); + } + return success(); +} + +/// Verifies that the alignment attribute is a power of 2 if present. 
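/// For example, an absent alignment and alignments of 1, 8, or 16 are
/// accepted, while 0, -4, and 12 are rejected.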
+static LogicalResult +verifyAlignment(std::optional<int64_t> alignment, + function_ref<InFlightDiagnostic()> emitError) { + if (!alignment) + return success(); + if (alignment.value() <= 0) + return emitError() << "alignment must be positive"; + if (!llvm::isPowerOf2_64(alignment.value())) + return emitError() << "alignment must be a power of 2"; + return success(); +} + +void LoadOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable()); + // Volatile operations can have target-specific read-write effects on + // memory besides the one referred to by the pointer operand. + // Similarly, atomic operations that are monotonic or stricter cause + // synchronization that from a language point-of-view, are arbitrary + // read-writes into memory. + if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic && + getOrdering() != AtomicOrdering::unordered)) { + effects.emplace_back(MemoryEffects::Write::get()); + effects.emplace_back(MemoryEffects::Read::get()); + } +} + +LogicalResult LoadOp::verify() { + auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); }; + MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace(); + DataLayout dataLayout = DataLayout::closest(*this); + if (!ms.isValidLoad(getResult().getType(), getOrdering(), getAlignment(), + &dataLayout, emitDiag)) + return failure(); + if (failed(verifyAlignment(getAlignment(), emitDiag))) + return failure(); + return verifyAtomicMemOp(*this, + {AtomicOrdering::release, AtomicOrdering::acq_rel}); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Type type, + Value addr, unsigned alignment, bool isVolatile, + bool isNonTemporal, bool isInvariant, bool isInvariantGroup, + AtomicOrdering ordering, StringRef syncscope) { + build(builder, state, type, addr, + alignment ? std::optional<int64_t>(alignment) : std::nullopt, + isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering, + syncscope.empty() ? nullptr : builder.getStringAttr(syncscope)); +} + +//===----------------------------------------------------------------------===// +// StoreOp +//===----------------------------------------------------------------------===// + +void StoreOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable()); + // Volatile operations can have target-specific read-write effects on + // memory besides the one referred to by the pointer operand. + // Similarly, atomic operations that are monotonic or stricter cause + // synchronization that from a language point-of-view, are arbitrary + // read-writes into memory. 
+ if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic && + getOrdering() != AtomicOrdering::unordered)) { + effects.emplace_back(MemoryEffects::Write::get()); + effects.emplace_back(MemoryEffects::Read::get()); + } +} + +LogicalResult StoreOp::verify() { + auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); }; + MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace(); + DataLayout dataLayout = DataLayout::closest(*this); + if (!ms.isValidStore(getValue().getType(), getOrdering(), getAlignment(), + &dataLayout, emitDiag)) + return failure(); + if (failed(verifyAlignment(getAlignment(), emitDiag))) + return failure(); + return verifyAtomicMemOp(*this, + {AtomicOrdering::acquire, AtomicOrdering::acq_rel}); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value value, + Value addr, unsigned alignment, bool isVolatile, + bool isNonTemporal, bool isInvariantGroup, + AtomicOrdering ordering, StringRef syncscope) { + build(builder, state, value, addr, + alignment ? std::optional<int64_t>(alignment) : std::nullopt, + isVolatile, isNonTemporal, isInvariantGroup, ordering, + syncscope.empty() ? nullptr : builder.getStringAttr(syncscope)); +} + +//===----------------------------------------------------------------------===// // PtrAddOp //===----------------------------------------------------------------------===// @@ -152,10 +270,6 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) { #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc" -#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc" - -#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc" - #include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc" #define GET_TYPEDEF_CLASSES diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt index 825d119..deb7109 100644 --- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt @@ -4,7 +4,7 @@ add_mlir_dialect_library(MLIRQuantTransforms StripFuncQuantTypes.cpp ADDITIONAL_HEADER_DIRS - {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms DEPENDS MLIRQuantTransformsIncGen diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 0262a1b..84f9777 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -157,8 +157,7 @@ void ExecuteRegionOp::print(OpAsmPrinter &p) { p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); - - p.printOptionalAttrDict((*this)->getAttrs()); + p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"}); } LogicalResult ExecuteRegionOp::verify() { @@ -318,9 +317,12 @@ void ConditionOp::getSuccessorRegions( void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, Value ub, Value step, ValueRange initArgs, - BodyBuilderFn bodyBuilder) { + BodyBuilderFn bodyBuilder, bool unsignedCmp) { OpBuilder::InsertionGuard guard(builder); + if (unsignedCmp) + result.addAttribute(getUnsignedCmpAttrName(result.name), + builder.getUnitAttr()); result.addOperands({lb, ub, step}); result.addOperands(initArgs); for (Value v : initArgs) @@ -450,6 +452,9 @@ static void printInitializationList(OpAsmPrinter &p, } void ForOp::print(OpAsmPrinter &p) { + if (getUnsignedCmp()) + p << " unsigned"; + p << " " << getInductionVar() << " = " << getLowerBound() << " to " << getUpperBound() << " step " << 
getStep(); @@ -462,7 +467,8 @@ void ForOp::print(OpAsmPrinter &p) { p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/!getInitArgs().empty()); - p.printOptionalAttrDict((*this)->getAttrs()); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/getUnsignedCmpAttrName().strref()); } ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { @@ -472,6 +478,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::Argument inductionVariable; OpAsmParser::UnresolvedOperand lb, ub, step; + if (succeeded(parser.parseOptionalKeyword("unsigned"))) + result.addAttribute(getUnsignedCmpAttrName(result.name), + builder.getUnitAttr()); + // Parse the induction variable followed by '='. if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() || // Parse loop bounds. @@ -562,7 +572,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter, inits.append(newInitOperands.begin(), newInitOperands.end()); scf::ForOp newLoop = scf::ForOp::create( rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits, - [](OpBuilder &, Location, Value, ValueRange) {}); + [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp()); newLoop->setAttrs(getPrunedAttributeList(getOperation(), {})); // Generate the new yield values and append them to the scf.yield operation. @@ -806,7 +816,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, // 2. Create the new forOp shell. scf::ForOp newForOp = scf::ForOp::create( rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newIterOperands); + forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr, + forOp.getUnsignedCmp()); newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.getRegion().front(); SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(), @@ -931,7 +942,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { scf::ForOp newForOp = scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), - forOp.getUpperBound(), forOp.getStep(), newIterArgs); + forOp.getUpperBound(), forOp.getStep(), newIterArgs, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.getRegion().front(); @@ -989,12 +1001,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { /// Util function that tries to compute a constant diff between u and l. /// Returns std::nullopt when the difference between two AffineValueMap is /// dynamic. 
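/// The difference is returned as a full-width APInt so that the caller can
/// decide whether to interpret it as signed or unsigned (see the handling of
/// unsigned loops in SimplifyTrivialLoops below).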
-static std::optional<int64_t> computeConstDiff(Value l, Value u) { +static std::optional<APInt> computeConstDiff(Value l, Value u) { IntegerAttr clb, cub; if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) { llvm::APInt lbValue = clb.getValue(); llvm::APInt ubValue = cub.getValue(); - return (ubValue - lbValue).getSExtValue(); + return ubValue - lbValue; } // Else a simple pattern match for x + c or c + x @@ -1003,7 +1015,7 @@ static std::optional<int64_t> computeConstDiff(Value l, Value u) { u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) || matchPattern( u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l)))) - return diff.getSExtValue(); + return diff; return std::nullopt; } @@ -1022,13 +1034,15 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> { return success(); } - std::optional<int64_t> diff = + std::optional<APInt> diff = computeConstDiff(op.getLowerBound(), op.getUpperBound()); if (!diff) return failure(); // If the loop is known to have 0 iterations, remove it. - if (*diff <= 0) { + bool zeroOrLessIterations = + diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative()); + if (zeroOrLessIterations) { rewriter.replaceOp(op, op.getInitArgs()); return success(); } @@ -3384,9 +3398,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) { if (functionType.getNumInputs() != operands.size()) { return parser.emitError(typeLoc) - << "expected as many input types as operands " - << "(expected " << operands.size() << " got " - << functionType.getNumInputs() << ")"; + << "expected as many input types as operands " << "(expected " + << operands.size() << " got " << functionType.getNumInputs() << ")"; } // Resolve input operands. @@ -4222,14 +4235,15 @@ LogicalResult scf::IndexSwitchOp::verify() { << "see yield operation here"; } for (auto [idx, result, operand] : - llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(), - yield.getOperandTypes())) { - if (result == operand) + llvm::enumerate(getResultTypes(), yield.getOperands())) { + if (!operand) + return yield.emitOpError() << "operand " << idx << " is null\n"; + if (result == operand.getType()) continue; return (emitOpError("expected result #") << idx << " of each region to be " << result) .attachNote(yield.getLoc()) - << name << " returns " << operand << " here"; + << name << " returns " << operand.getType() << " here"; } return success(); }; diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index aea842d..71fe987 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -147,6 +147,45 @@ transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter, } //===----------------------------------------------------------------------===// +// ParallelForToNestedForOps +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { + auto payload = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(payload)) + return emitSilenceableError() << "expected a single payload op"; + + auto target = dyn_cast<scf::ParallelOp>(*payload.begin()); + if (!target) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "expected the payload to be scf.parallel"; 
+ diag.attachNote((*payload.begin())->getLoc()) << "payload op"; + return diag; + } + + if (getNumResults() != 1) { + DiagnosedSilenceableFailure diag = emitSilenceableError() + << "op expects one result, given " + << getNumResults(); + diag.attachNote(target.getLoc()) << "payload op"; + return diag; + } + + FailureOr<scf::LoopNest> loopNest = + scf::parallelForToNestedFors(rewriter, target); + if (failed(loopNest)) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "failed to convert parallel into nested fors"; + return diag; + } + + results.set(cast<OpResult>(getTransformed()[0]), {loopNest->loops.front()}); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// // LoopOutlineOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index f8799c5..fb179e6 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -769,7 +769,8 @@ struct ForOpInterface // Construct a new scf.for op with memref instead of tensor values. auto newForOp = scf::ForOp::create( rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), castedInitArgs); + forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr, + forOp.getUnsignedCmp()); newForOp->setAttrs(forOp->getAttrs()); Block *loopBody = newForOp.getBody(); diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt index 6d3bafb..a07d9d4 100644 --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRSCFTransforms LoopPipelining.cpp LoopRangeFolding.cpp LoopSpecialization.cpp + ParallelForToNestedFors.cpp ParallelLoopCollapsing.cpp ParallelLoopFusion.cpp ParallelLoopTiling.cpp diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index bee7780..ae52af5 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -58,9 +58,12 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> { auto *beforeBlock = rewriter.createBlock( &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs); rewriter.setInsertionPointToStart(whileOp.getBeforeBody()); - auto cmpOp = arith::CmpIOp::create( - rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt, - beforeBlock->getArgument(0), forOp.getUpperBound()); + arith::CmpIPredicate predicate = forOp.getUnsignedCmp() + ? arith::CmpIPredicate::ult + : arith::CmpIPredicate::slt; + auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), predicate, + beforeBlock->getArgument(0), + forOp.getUpperBound()); scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(), beforeBlock->getArguments()); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index 1130538..7e7fba4 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -791,6 +791,11 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, bool *modifiedIR) { if (modifiedIR) *modifiedIR = false; + + // TODO: Add support for unsigned loops. 
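// "Unsigned" refers to the scf.for comparison mode added in this patch,
// printed as `scf.for unsigned %i = %lb to %ub step %s`. Transforms that do
// support it recreate loops roughly as:
//   scf::ForOp::create(rewriter, loc, lb, ub, step, inits,
//                      /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());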
+ if (forOp.getUnsignedCmp()) + return failure(); + LoopPipelinerInternal pipeliner; if (!pipeliner.initializeLoopInfo(forOp, options)) return failure(); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index 4752c08..f1203b2 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -256,6 +256,10 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> { LogicalResult matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const override { + if (forOp.getUnsignedCmp()) + return rewriter.notifyMatchFailure(forOp, + "unsigned loops are not supported"); + // Do not peel already peeled loops. if (forOp->hasAttr(kPeeledLoopLabel)) return failure(); diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp new file mode 100644 index 0000000..8f7d5e3 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp @@ -0,0 +1,86 @@ +//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Transforms SCF.ParallelOp to nested scf.for ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS +#include "mlir/Dialect/SCF/Transforms/Passes.h.inc" +} // namespace mlir + +#define DEBUG_TYPE "parallel-for-to-nested-fors" +using namespace mlir; + +FailureOr<scf::LoopNest> +mlir::scf::parallelForToNestedFors(RewriterBase &rewriter, + scf::ParallelOp parallelOp) { + + if (!parallelOp.getResults().empty()) + return rewriter.notifyMatchFailure( + parallelOp, "Currently scf.parallel to scf.for conversion doesn't " + "support scf.parallel with results."); + + rewriter.setInsertionPoint(parallelOp); + + Location loc = parallelOp.getLoc(); + SmallVector<Value> lowerBounds = parallelOp.getLowerBound(); + SmallVector<Value> upperBounds = parallelOp.getUpperBound(); + SmallVector<Value> steps = parallelOp.getStep(); + + assert(lowerBounds.size() == upperBounds.size() && + lowerBounds.size() == steps.size() && + "Mismatched parallel loop bounds"); + + SmallVector<Value> ivs; + scf::LoopNest loopNest = + scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps); + + SmallVector<Value> newInductionVars = llvm::map_to_vector( + loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); }); + Block *linearizedBody = loopNest.loops.back().getBody(); + Block *parallelBody = parallelOp.getBody(); + rewriter.eraseOp(parallelBody->getTerminator()); + rewriter.inlineBlockBefore(parallelBody, linearizedBody->getTerminator(), + newInductionVars); + rewriter.eraseOp(parallelOp); + return loopNest; +} + +namespace { +struct ParallelForToNestedFors final + : public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> { + void runOnOperation() override { + Operation *parentOp = getOperation(); + IRRewriter 
rewriter(parentOp->getContext()); + + parentOp->walk( + [&](scf::ParallelOp parallelOp) { + if (failed(scf::parallelForToNestedFors(rewriter, parallelOp))) { + LLVM_DEBUG( + llvm::dbgs() + << "Failed to convert scf.parallel to nested scf.for ops for:\n" + << parallelOp << "\n"); + return WalkResult::advance(); + } + return WalkResult::advance(); + }); + } +}; +} // namespace + +std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() { + return std::make_unique<ParallelForToNestedFors>(); +} diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index 694cd85..4ea8321 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -269,10 +269,10 @@ namespace { struct ParallelLoopFusion : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> { void runOnOperation() override { - auto &AA = getAnalysis<AliasAnalysis>(); + auto &aa = getAnalysis<AliasAnalysis>(); auto mayAlias = [&](Value val1, Value val2) -> bool { - return !AA.alias(val1, val2).isNo(); + return !aa.alias(val1, val2).isNo(); }; getOperation()->walk([&](Operation *child) { diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 1b07b77..072bc50 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -52,8 +52,8 @@ public: SmallVector<unsigned> offsets; offsets.push_back(0); // Do the type conversion and record the offsets. - for (Type type : op.getResultTypes()) { - if (failed(typeConverter->convertTypes(type, dstTypes))) + for (Value v : op.getResults()) { + if (failed(typeConverter->convertType(v, dstTypes))) return rewriter.notifyMatchFailure(op, "could not convert result type"); offsets.push_back(dstTypes.size()); } @@ -116,7 +116,8 @@ public: llvm::getSingleElement(adaptor.getLowerBound()), llvm::getSingleElement(adaptor.getUpperBound()), llvm::getSingleElement(adaptor.getStep()), - flattenValues(adaptor.getInitArgs())); + flattenValues(adaptor.getInitArgs()), + /*bodyBuilder=*/nullptr, op.getUnsignedCmp()); // Reserve whatever attributes in the original op. newOp->setAttrs(op->getAttrs()); @@ -126,7 +127,6 @@ public: // Inline the type converted region from the original operation. rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), newOp.getRegion().end()); - return newOp; } }; @@ -225,15 +225,14 @@ void mlir::scf::populateSCFStructuralTypeConversions( void mlir::scf::populateSCFStructuralTypeConversionTarget( const TypeConverter &typeConverter, ConversionTarget &target) { - target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) { - return typeConverter.isLegal(op->getResultTypes()); - }); + target.addDynamicallyLegalOp<ForOp, IfOp>( + [&](Operation *op) { return typeConverter.isLegal(op->getResults()); }); target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) { // We only have conversions for a subset of ops that use scf.yield // terminators. 
if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp())) return true; - return typeConverter.isLegal(op.getOperandTypes()); + return typeConverter.isLegal(op.getOperands()); }); target.addDynamicallyLegalOp<WhileOp, ConditionOp>( [&](Operation *op) { return typeConverter.isLegal(op); }); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index c0e47ee..834c021 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -797,7 +797,8 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( inits.append(newInitOperands.begin(), newInitOperands.end()); auto newLoop = scf::ForOp::create( rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(), - loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); + loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}, + loopOp.getUnsignedCmp()); // Move the loop body to the new op. Block *loopBody = loopOp.getBody(); @@ -935,7 +936,8 @@ static LogicalResult addInitOperandsToLoopNest( auto newLoop = scf::ForOp::create( rewriter, forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), forLoop.getStep(), newInits, - [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); + [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}, + forLoop.getUnsignedCmp()); // Merge the body of the new loop with the body of the old loops. SmallVector<Value> sourceBlockArgs; @@ -1914,63 +1916,6 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter, return failure(); } -/// Check that the loop is perfectly nested. -/// The loops are expected to be ordered from outer most to inner most. -/// For example: -/// ``` -/// %0 = scf.for() -/// %1 = scf.for() -/// %2 = scf.for() -/// %3 = ... -/// yield %3 -/// yield %2 -/// yield %1 -/// ``` -/// Here loops should be [%0, %1]. -static bool -isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) { - assert(!loops.empty() && "unexpected empty loop nest"); - if (loops.size() == 1) { - return isa_and_nonnull<scf::ForOp>(loops.front().getOperation()); - } - for (auto [outerLoop, innerLoop] : - llvm::zip_equal(loops.drop_back(), loops.drop_front())) { - auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation()); - auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation()); - if (!outerFor || !innerFor) { - return false; - } - auto outerBBArgs = outerFor.getRegionIterArgs(); - auto innerIterArgs = innerFor.getInitArgs(); - if (outerBBArgs.size() != innerIterArgs.size()) { - return false; - } - - for (auto [outerBBArg, innerIterArg] : - llvm::zip_equal(outerBBArgs, innerIterArgs)) { - if (!llvm::hasSingleElement(outerBBArg.getUses()) || - innerIterArg != outerBBArg) { - return false; - } - } - - ValueRange outerYields = - cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands(); - ValueRange innerResults = innerFor.getResults(); - if (outerYields.size() != innerResults.size()) { - return false; - } - for (auto [outerYield, innerResult] : - llvm::zip_equal(outerYields, innerResults)) { - if (!llvm::hasSingleElement(innerResult.getUses()) || - outerYield != innerResult) { - return false; - } - } - } - return true; -} - /// Fetch the untiled consumer of the outermost scf.for's result which is /// yielded by a tensor.insert_slice from the innermost scf.for. 
This function /// makes the following assumptions : diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 5731795..684dff8 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1233,6 +1233,7 @@ static void getPerfectlyNestedLoopsImpl( static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef<scf::ForOp> targets) { + assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported"); auto originalStep = forOp.getStep(); auto iv = forOp.getInductionVar(); @@ -1241,6 +1242,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor, Loops innerLoops; for (auto t : targets) { + assert(!t.getUnsignedCmp() && "unsigned loops are not supported"); + // Save information for splicing ops out of t when done auto begin = t.getBody()->begin(); auto nOps = t.getBody()->getOperations().size(); @@ -1415,6 +1418,8 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter) { + assert(source.getUnsignedCmp() == target.getUnsignedCmp() && + "incompatible signedness"); unsigned numTargetOuts = target.getNumResults(); unsigned numSourceOuts = source.getNumResults(); @@ -1428,7 +1433,8 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, rewriter.setInsertionPointAfter(source); scf::ForOp fusedLoop = scf::ForOp::create( rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(), - source.getStep(), fusedInitArgs); + source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr, + source.getUnsignedCmp()); // Map original induction variables and operands to those of the fused loop. IRMapping mapping; @@ -1506,3 +1512,41 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter, rewriter.replaceOp(forallOp, normalizedForallOp); return normalizedForallOp; } + +bool mlir::isPerfectlyNestedForLoops( + MutableArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "unexpected empty loop nest"); + if (loops.size() == 1) + return isa_and_nonnull<scf::ForOp>(loops.front().getOperation()); + for (auto [outerLoop, innerLoop] : + llvm::zip_equal(loops.drop_back(), loops.drop_front())) { + auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation()); + auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation()); + if (!outerFor || !innerFor) + return false; + auto outerBBArgs = outerFor.getRegionIterArgs(); + auto innerIterArgs = innerFor.getInitArgs(); + if (outerBBArgs.size() != innerIterArgs.size()) + return false; + + for (auto [outerBBArg, innerIterArg] : + llvm::zip_equal(outerBBArgs, innerIterArgs)) { + if (!llvm::hasSingleElement(outerBBArg.getUses()) || + innerIterArg != outerBBArg) + return false; + } + + ValueRange outerYields = + cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands(); + ValueRange innerResults = innerFor.getResults(); + if (outerYields.size() != innerResults.size()) + return false; + for (auto [outerYield, innerResult] : + llvm::zip_equal(outerYields, innerResults)) { + if (!llvm::hasSingleElement(innerResult.getUses()) || + outerYield != innerResult) + return false; + } + } + return true; +} diff --git a/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp new file mode 100644 index 0000000..47fe4d9 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp @@ -0,0 +1,251 @@ +//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations 
-------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the SPV_ARM_graph operations in the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + +#include "SPIRVParsingUtils.h" + +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "llvm/Support/InterleavedRange.h" + +using namespace mlir; +using namespace mlir::spirv::AttrNames; + +//===----------------------------------------------------------------------===// +// spirv.GraphARM +//===----------------------------------------------------------------------===// + +ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser, + OperationState &result) { + Builder &builder = parser.getBuilder(); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // Parse the function signature. + bool isVariadic = false; + SmallVector<OpAsmParser::Argument> entryArgs; + SmallVector<Type> resultTypes; + SmallVector<DictionaryAttr> resultAttrs; + if (function_interface_impl::parseFunctionSignatureWithArguments( + parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, + resultAttrs)) + return failure(); + + SmallVector<Type> argTypes = llvm::map_to_vector( + entryArgs, [](const OpAsmParser::Argument &arg) { return arg.type; }); + GraphType grType = builder.getGraphType(argTypes, resultTypes); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(grType)); + + // If additional attributes are present, parse them. + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + // Add the attributes to the function arguments. + assert(resultAttrs.size() == resultTypes.size()); + call_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); + + // Parse the optional function body. + Region *body = result.addRegion(); + OptionalParseResult parseResult = + parser.parseOptionalRegion(*body, entryArgs); + return failure(parseResult.has_value() && failed(*parseResult)); +} + +void spirv::GraphARMOp::print(OpAsmPrinter &printer) { + // Print graph name, signature, and control. + printer << " "; + printer.printSymbolName(getSymName()); + GraphType grType = getFunctionType(); + function_interface_impl::printFunctionSignature( + printer, *this, grType.getInputs(), + /*isVariadic=*/false, grType.getResults()); + function_interface_impl::printFunctionAttributes(printer, *this, + {getFunctionTypeAttrName(), + getArgAttrsAttrName(), + getResAttrsAttrName()}); + + // Print the body. 
+ Region &body = this->getBody(); + if (!body.empty()) { + printer << ' '; + printer.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } +} + +LogicalResult spirv::GraphARMOp::verifyType() { + if (getFunctionType().getNumResults() < 1) + return emitOpError("there should be at least one result"); + return success(); +} + +LogicalResult spirv::GraphARMOp::verifyBody() { + for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) { + if (!isa<spirv::TensorArmType>(graphArgType)) { + return emitOpError("type of argument #") + << index << " must be a TensorArmType, but got " << graphArgType; + } + } + for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) { + if (!isa<spirv::TensorArmType>(graphResType)) { + return emitOpError("type of result #") + << index << " must be a TensorArmType, but got " << graphResType; + } + } + + if (!isExternal()) { + Block &entryBlock = front(); + + unsigned numArguments = this->getNumArguments(); + if (entryBlock.getNumArguments() != numArguments) + return emitOpError("entry block must have ") + << numArguments << " arguments to match graph signature"; + + for (auto [index, grArgType, blockArgType] : + llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) { + if (blockArgType != grArgType) { + return emitOpError("type of entry block argument #") + << index << '(' << blockArgType + << ") must match the type of the corresponding argument in " + << "graph signature(" << grArgType << ')'; + } + } + } + + GraphType grType = getFunctionType(); + auto walkResult = walk([grType](spirv::GraphOutputsARMOp op) -> WalkResult { + if (grType.getNumResults() != op.getNumOperands()) + return op.emitOpError("is returning ") + << op.getNumOperands() + << " value(s) but enclosing spirv.ARM.Graph requires " + << grType.getNumResults() << " result(s)"; + + ValueTypeRange<OperandRange> graphOutputOperandTypes = + op.getValue().getType(); + for (auto [index, type] : llvm::enumerate(graphOutputOperandTypes)) { + if (type != grType.getResult(index)) + return op.emitError("type of return operand ") + << index << " (" << type << ") doesn't match graph result type (" + << grType.getResult(index) << ")"; + } + return WalkResult::advance(); + }); + + return failure(walkResult.wasInterrupted()); +} + +void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state, + StringRef name, GraphType type, + ArrayRef<NamedAttribute> attrs, bool entryPoint) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs); + state.addAttribute(getEntryPointAttrName(state.name), + builder.getBoolAttr(entryPoint)); + state.addRegion(); +} + +ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() { + return getFunctionType().getInputs(); +} + +ArrayRef<Type> spirv::GraphARMOp::getResultTypes() { + return getFunctionType().getResults(); +} + +Region *spirv::GraphARMOp::getCallableRegion() { + return isExternal() ? nullptr : &getBody(); +} + +//===----------------------------------------------------------------------===// +// spirv.GraphOutputsARM +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GraphOutputsARMOp::verify() { + auto graph = cast<GraphARMOp>((*this)->getParentOp()); + + // The operand number and types must match the graph signature. 
+ const ArrayRef<Type> &results = graph.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@" + << graph.getName() << ") returns " << results.size(); + + for (auto [index, result] : llvm::enumerate(results)) + if (getOperand(index).getType() != result) + return emitError() << "type of return operand " << index << " (" + << getOperand(index).getType() + << ") doesn't match spirv.ARM.Graph result type (" + << result << ")" + << " in graph @" << graph.getName(); + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GraphEntryPointARM +//===----------------------------------------------------------------------===// + +void spirv::GraphEntryPointARMOp::build(OpBuilder &builder, + OperationState &state, + spirv::GraphARMOp graph, + ArrayRef<Attribute> interfaceVars) { + build(builder, state, SymbolRefAttr::get(graph), + builder.getArrayAttr(interfaceVars)); +} + +ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser, + OperationState &result) { + FlatSymbolRefAttr fn; + if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) + return failure(); + + SmallVector<Attribute, 4> interfaceVars; + if (!parser.parseOptionalComma()) { + // Parse the interface variables. + if (parser.parseCommaSeparatedList([&]() -> ParseResult { + // The name of the interface variable attribute is not important. + FlatSymbolRefAttr var; + NamedAttrList attrs; + if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) + return failure(); + interfaceVars.push_back(var); + return success(); + })) + return failure(); + } + result.addAttribute("interface", + parser.getBuilder().getArrayAttr(interfaceVars)); + return success(); +} + +void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) { + printer << " "; + printer.printSymbolName(getFn()); + ArrayRef<Attribute> interfaceVars = getInterface().getValue(); + if (!interfaceVars.empty()) { + printer << ", " << llvm::interleaved(interfaceVars); + } +} diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt index b9aa7b7..60d705d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -3,6 +3,7 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters) add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen) add_mlir_dialect_library(MLIRSPIRVDialect + ArmGraphOps.cpp AtomicOps.cpp CastOps.cpp ControlFlowOps.cpp diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp index d8dfe16..2f3a28f 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp @@ -31,6 +31,18 @@ static bool isNestedInFunctionOpInterface(Operation *op) { return isNestedInFunctionOpInterface(op->getParentOp()); } +/// Returns true if the given op is a GraphARM op or nested in a +/// GraphARM op without a module-like op in the middle. +static bool isNestedInGraphARMOpInterface(Operation *op) { + if (!op) + return false; + if (op->hasTrait<OpTrait::SymbolTable>()) + return false; + if (isa<spirv::GraphARMOp>(op)) + return true; + return isNestedInGraphARMOpInterface(op->getParentOp()); +} + /// Returns true if the given op is an module-like op that maintains a symbol /// table. 
static bool isDirectInModuleLikeOp(Operation *op) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index ddb3426..369b953 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -1322,7 +1322,7 @@ struct spirv::detail::TensorArmTypeStorage final : TypeStorage { } TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType) - : shape(std::move(shape)), elementType(std::move(elementType)) {} + : shape(shape), elementType(elementType) {} ArrayRef<int64_t> shape; Type elementType; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 8f4c4cc..49f4ce8 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -608,6 +608,45 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv, return wrapInStructAndGetPointer(arrayType, storageClass); } +static spirv::Dim convertRank(int64_t rank) { + switch (rank) { + case 1: + return spirv::Dim::Dim1D; + case 2: + return spirv::Dim::Dim2D; + case 3: + return spirv::Dim::Dim3D; + default: + llvm_unreachable("Invalid memref rank!"); + } +} + +static spirv::ImageFormat getImageFormat(Type elementType) { + return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType) + .Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; }) + .Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; }) + .Case<IntegerType>([](IntegerType intType) { + auto const isSigned = intType.isSigned() || intType.isSignless(); +#define BIT_WIDTH_CASE(BIT_WIDTH) \ + case BIT_WIDTH: \ + return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \ + : spirv::ImageFormat::R##BIT_WIDTH##ui + + switch (intType.getWidth()) { + BIT_WIDTH_CASE(16); + BIT_WIDTH_CASE(32); + default: + llvm_unreachable("Unhandled integer type!"); + } + }) + .Default([](Type) { + llvm_unreachable("Unhandled element type!"); + // We need to return something here to satisfy the type switch. + return spirv::ImageFormat::R32f; + }); +#undef BIT_WIDTH_CASE +} + static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type) { @@ -623,6 +662,41 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, } spirv::StorageClass storageClass = attr.getValue(); + // Images are a special case since they are an opaque type from which elements + // may be accessed via image specific ops or directly through a texture + // pointer. + if (storageClass == spirv::StorageClass::Image) { + const int64_t rank = type.getRank(); + if (rank < 1 || rank > 3) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot lower memref of rank " << rank + << " to a SPIR-V Image\n"); + return nullptr; + } + + // Note that we currently only support lowering to single element texels + // e.g. R32f. + auto elementType = type.getElementType(); + if (!isa<spirv::ScalarType>(elementType)) { + LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of " + << elementType << " to a SPIR-V Image\n"); + return nullptr; + } + + // Currently every memref in the image storage class is converted to a + // sampled image so we can hardcode the NeedSampler field. Future work + // will generalize this to support regular non-sampled images. 
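A small illustration of the element-type mapping performed by getImageFormat above (sketch only; `ctx` is an assumed MLIRContext pointer):

  // f32 -> R32f, f16 -> R16f; integers map by width and signedness, with
  // signless integers treated as signed.
  assert(getImageFormat(Float32Type::get(ctx)) == spirv::ImageFormat::R32f);
  assert(getImageFormat(IntegerType::get(ctx, 32)) == spirv::ImageFormat::R32i);
  assert(getImageFormat(IntegerType::get(ctx, 16, IntegerType::Unsigned)) ==
         spirv::ImageFormat::R16ui);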
+ auto spvImageType = spirv::ImageType::get( + elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown, + spirv::ImageArrayedInfo::NonArrayed, + spirv::ImageSamplingInfo::SingleSampled, + spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType)); + auto spvSampledImageType = spirv::SampledImageType::get(spvImageType); + auto imagePtrType = spirv::PointerType::get( + spvSampledImageType, spirv::StorageClass::UniformConstant); + return imagePtrType; + } + if (isa<IntegerType>(type.getElementType())) { if (type.getElementTypeBitWidth() == 1) return convertBoolMemrefType(targetEnv, options, type, storageClass); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index a53d0a7..670eabf 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -95,6 +95,13 @@ static LogicalResult checkAndUpdateCapabilityRequirements( return success(); } +static void addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) { + SetVector<spirv::Capability> tmp; + for (spirv::Capability cap : caps) + tmp.insert_range(getRecursiveImpliedCapabilities(cap)); + caps.insert_range(std::move(tmp)); +} + void UpdateVCEPass::runOnOperation() { spirv::ModuleOp module = getOperation(); @@ -151,6 +158,12 @@ void UpdateVCEPass::runOnOperation() { if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) valueTypes.push_back(globalVar.getType()); + // If the op is FunctionLike make sure to process input and result types. + if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) { + llvm::append_range(valueTypes, funcOpInterface.getArgumentTypes()); + llvm::append_range(valueTypes, funcOpInterface.getResultTypes()); + } + // Requirements from values' types SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; @@ -174,6 +187,8 @@ void UpdateVCEPass::runOnOperation() { if (walkResult.wasInterrupted()) return signalPassFailure(); + addAllImpliedCapabilities(deducedCapabilities); + // Update min version requirement for capabilities after deducing them. 
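A sketch of the implied-capability closure added above (illustration only; the Shader -> Matrix implication comes from the SPIR-V capability dependency table):

  SetVector<spirv::Capability> caps;
  caps.insert(spirv::Capability::Shader);
  addAllImpliedCapabilities(caps);
  // caps now also contains spirv::Capability::Matrix, so the deduced
  // capability list no longer relies on consumers to expand implications.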
for (spirv::Capability cap : deducedCapabilities) { if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) { diff --git a/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp index d4e7618..7a05dfe 100644 --- a/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp @@ -513,8 +513,9 @@ LogicalResult shard::detail::defaultAddShardingAnnotations( } #ifndef NDEBUG -static bool isValueCompatibleWithFullReplicationSharding(Value value, - Sharding sharding) { +static bool +isValueCompatibleWithFullReplicationSharding(Value value, + const Sharding &sharding) { if (isa<RankedTensorType>(value.getType())) { return isFullReplication(sharding); } diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 3e3d476..5dc61a2 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -477,10 +477,10 @@ reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, return targetShard; } -TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, GridOp grid, - Sharding sourceSharding, Sharding targetSharding, - TypedValue<ShapedType> sourceUnshardedValue, - TypedValue<ShapedType> sourceShard) { +static TypedValue<ShapedType> +reshard(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, + Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue, + TypedValue<ShapedType> sourceShard) { // If source and destination sharding are the same, no need to do anything. if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) && isFullReplication(targetSharding))) { @@ -535,7 +535,7 @@ using UnshardedToShardedValueMap = DenseMap<Value, Value>; // Get the types of block arguments for an partitioned block. // Reads the sharding annotations of the arguments to deduce the sharded types. // Types that are not ranked tensors are left unchanged. -SmallVector<Type> +static SmallVector<Type> shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection) { SmallVector<Type> res; diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp index 56b435c..9694a40 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp @@ -231,7 +231,9 @@ ParseResult DimLvlMapParser::parseLvlSpecList() { const auto loc = parser.getCurrentLocation(); const auto res = parser.parseCommaSeparatedList( mlir::OpAsmParser::Delimiter::Paren, - [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); }, + [this, requireLvlVarBinding]() -> ParseResult { + return parseLvlSpec(requireLvlVarBinding); + }, " in level-specifier list"); FAILURE_IF_FAILED(res) const auto specLvlRank = lvlSpecs.size(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp index 9e2e6ab..a1711a6 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -156,13 +156,14 @@ minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) { return pair1 <= pair2 ? 
sm1 : sm2; } -bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) { +static bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, + StringRef name) { const auto &var = env.access(id); return (var.getName() == name && var.getID() == id); } -bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc, - VarKind vk) { +static bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, + llvm::SMLoc loc, VarKind vk) { const auto &var = env.access(id); return var.getKind() == vk; } diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp index 3b97786..dabbea1 100644 --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -71,7 +71,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm, pm.addPass(createLowerAffinePass()); pm.addPass( createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions())); - pm.addPass(createFinalizeMemRefToLLVMConversionPass()); pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass()); pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass()); pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass()); @@ -79,12 +78,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm, pm.addPass(createConvertComplexToLibm()); pm.addPass( createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions())); - pm.addPass(createConvertComplexToLLVMPass()); - pm.addPass( - createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions())); - pm.addPass(createConvertFuncToLLVMPass()); - pm.addPass(createArithToLLVMConversionPass()); - pm.addPass(createConvertControlFlowToLLVMPass()); // Finalize GPU code generation. if (gpuCodegen) { @@ -99,8 +92,8 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm, pm.addPass(createGpuModuleToBinaryPass(gpuModuleToBinaryPassOptions)); } - // Convert poison values. - pm.addPass(createUBToLLVMConversionPass()); + // Convert to LLVM. + pm.addPass(createConvertToLLVMPass()); // Ensure all casts are realized. pm.addPass(createReconcileUnrealizedCastsPass()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index 3b4140e..ae7eef2 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -1219,8 +1219,9 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, /// Implements the rewriting for operator sort and sort_coo. 
template <typename OpTy> -LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, - uint64_t ny, PatternRewriter &rewriter) { +static LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, + AffineMap xPerm, uint64_t ny, + PatternRewriter &rewriter) { Location loc = op.getLoc(); SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()}; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 134aef3..0e88d31d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -730,9 +730,9 @@ public: {tensor, lvlCoords, values, filled, added, count}, EmitCInterface::On); Operation *parent = getTop(op); + rewriter.setInsertionPointAfter(parent); rewriter.replaceOp(op, adaptor.getTensor()); // Deallocate the buffers on exit of the loop nest. - rewriter.setInsertionPointAfter(parent); memref::DeallocOp::create(rewriter, loc, values); memref::DeallocOp::create(rewriter, loc, filled); memref::DeallocOp::create(rewriter, loc, added); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index 4464450..febec6d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -533,8 +533,10 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl, VectorType vtp = vectorType(vl, init.getType()); Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0), forOp.getRegionIterArg(0), init, vtp); - forOpNew = scf::ForOp::create(rewriter, loc, forOp.getLowerBound(), - forOp.getUpperBound(), step, vinit); + forOpNew = + scf::ForOp::create(rewriter, loc, forOp.getLowerBound(), + forOp.getUpperBound(), step, vinit, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); forOpNew->setAttr( LoopEmitter::getLoopEmitterLoopAttrName(), forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName())); @@ -605,8 +607,8 @@ public: ForOpRewriter(MLIRContext *context, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32) - : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization, - enableSIMDIndex32} {} + : OpRewritePattern(context), + vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {} LogicalResult matchAndRewrite(scf::ForOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 7d4b112..68584ec 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3200,20 +3200,6 @@ void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { setNameFn(getResult(), "padded"); } -// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it -// supports optional types. 
-void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, - Type typeToInfer, Type typeToInferFrom) {} - -ParseResult -parseInferType(OpAsmParser &parser, - std::optional<OpAsmParser::UnresolvedOperand> optOperand, - Type &typeToInfer, Type typeToInferFrom) { - if (optOperand) - typeToInfer = typeToInferFrom; - return success(); -} - LogicalResult PadOp::verify() { auto sourceType = llvm::cast<RankedTensorType>(getSource().getType()); auto resultType = llvm::cast<RankedTensorType>(getResult().getType()); @@ -4059,7 +4045,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// // Common Canonicalizers and Folders. //===----------------------------------------------------------------------===// -bool foldTensorCastPrecondition(DestinationStyleOpInterface op) { +static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) { // 1. InsertSliceOp has its own logic about folding tensor.cast ops. // 2. Exclude DPS ops that are also LoopLike from this interface as they // might need special handling of attached regions. diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index 2ec23e1..dfce835 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -327,172 +327,31 @@ struct BubbleUpExpandShapeThroughExtractSlice PatternRewriter &rewriter) const override { auto expandShapeOp = sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>(); + if (!expandShapeOp) { + return rewriter.notifyMatchFailure( + sliceOp, "tensor.extract_slice source not produced by expand_shape"); + } + SmallVector<ReassociationIndices> reassociation = + expandShapeOp.getReassociationIndices(); - if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp, - rewriter) - .failed()) + SmallVector<OpFoldResult> offsets, sizes, strides; + if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation, + offsets, sizes, strides))) return failure(); - // The tensor.extract_slice before applying the pattern works on the result - // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp) - // referring to the state before applying the pattern are named with the - // prefix "expanded", and ones referring to the state after applying the - // pattern are named with the prefix "collapsed". - SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets(); - SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes(); - SmallVector<OpFoldResult> expandedShape = - getMixedValues(expandShapeOp.getStaticOutputShape(), - expandShapeOp.getOutputShape(), rewriter); - - // Helper variables and function for accumulating the size values. - Location loc = expandShapeOp->getLoc(); - AffineExpr d0, d1, d2; - bindDims(rewriter.getContext(), d0, d1, d2); - // Multiply two integers. - auto mul = [&](OpFoldResult v1, OpFoldResult v2) { - auto mulMap = AffineMap::get(2, 0, {d0 * d1}); - return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap, - {v1, v2}); - }; - - // Compute new offsets, sizes, and strides for tensor.extract_slice. - // The new tensor.extract_slice will work on a tensor that has has a rank of - // ReassociationIndices.size(). In the loop a single offset, size, and - // stride value is computed per reassociation group. 
- SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes, - collapsedStrides; - for (const ReassociationIndices &indices : - expandShapeOp.getReassociationIndices()) { - // collapsedSize will hold the size of the single dim that represents the - // reassociation group in the non expanded tensor. - OpFoldResult collapsedSize = rewriter.getIndexAttr(1); - // The reassocGroupSizes and reassocGroupOffsets are used to create an - // affine.linearize_index op to linearize the single offset value required - // for this reassociation group. - SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets; - - for (long expandedDim : indices) { - // reassocGroupSizes and reassocGroupOffsets can be obtained directly - // from the expanded state, but the collapsed size requires calculation - // as it did not previously exist. - reassocGroupSizes.push_back(expandedShape[expandedDim]); - reassocGroupOffsets.push_back(expandedOffsets[expandedDim]); - collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]); - } - - SmallVector<Value> offsetVals = - llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) { - return getValueOrCreateConstantIndexOp(rewriter, loc, ofr); - }); - OpFoldResult collapsedOffset = - affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals, - reassocGroupSizes, - /*disjoint=*/true) - .getResult(); - collapsedOffsets.push_back(collapsedOffset); - collapsedSizes.push_back(collapsedSize); - - // Only unit stride is supported. - collapsedStrides.push_back(rewriter.getIndexAttr(1)); - } - // The shape of the result can be obtained from the sizes passed in. - SmallVector<Value> dynDims; - SmallVector<int64_t> shape; - dispatchIndexOpFoldResults(expandedSizes, dynDims, shape); - RankedTensorType resultType = RankedTensorType::get( - shape, expandShapeOp.getResultType().getElementType()); + SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes(); + RankedTensorType resultType = sliceOp.getResultType(); // Create a new ExtractSliceOp and ExpandShapeOp. + Location loc = sliceOp.getLoc(); Value newSliceOp = tensor::ExtractSliceOp::create( - rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes, - collapsedStrides); + rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides); rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( sliceOp, resultType, newSliceOp, expandShapeOp.getReassociationIndices(), expandedSizes); return success(); } - - // Helper function to check if all the required conditions for the - // tensor.extract_slice to be bubbled up through the tensor.expand_shape are - // met. - LogicalResult - checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp, - tensor::ExpandShapeOp expandShapeOp, - PatternRewriter &rewriter) const { - - if (!expandShapeOp) { - return rewriter.notifyMatchFailure( - sliceOp, "tensor.extract_slice source not produced by expand_shape"); - } - - if (!sliceOp.hasUnitStride()) { - return rewriter.notifyMatchFailure( - sliceOp, "unsupported: non-unit stride. 
Only contiguous slices can " - "be supported in this transformation."); - } - - SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets(); - SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); - - if (static_cast<size_t>(sliceOp.getResultType().getRank()) != - sizes.size()) { - return rewriter.notifyMatchFailure(sliceOp, - "unimplemented: rank reducing slice"); - } - - SmallVector<OpFoldResult> outputShape = - getMixedValues(expandShapeOp.getStaticOutputShape(), - expandShapeOp.getOutputShape(), rewriter); - - std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)> - isZeroOffsetAndFullSize = - [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) { - if (!isZeroInteger(offset)) - return false; - FailureOr<bool> maybeEqual = - ValueBoundsConstraintSet::areEqual(sliceSize, size); - return llvm::succeeded(maybeEqual) && maybeEqual.value(); - }; - - // Check that the slice is contiguous within each reassociation group. - // The slice is contiguous only if after the first dimension where a non - // unit slice is taken, the slice size on all subsequent dimensions of the - // group is equal to the entire size of the dimension. - // Examples of contiguous slices: - // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10] - // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10] - // Examples of non contiguous slices: - // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5] - // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5] - for (const ReassociationIndices &indices : - expandShapeOp.getReassociationIndices()) { - int64_t i = 0; - int64_t e = indices.size(); - // Find the first expanded dim after the first dim with non-unit extracted - // size. - for (; i < e; ++i) { - if (!isOneInteger(sizes[indices[i]])) { - // +1 to skip the first non-unit size dim. - i++; - break; - } - } - - // Verify that all subsequent dimensions extract the full size of the - // source tensor. - for (; i < e; ++i) { - int64_t expandedDim = indices[i]; - if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim], - outputShape[expandedDim])) { - return rewriter.notifyMatchFailure( - sliceOp, "Not a contiguous slice of the expanded tensor."); - } - } - } - - return success(); - } }; /// Converts `tensor.extract_slice(tensor.collapse_shape)` to @@ -582,170 +441,281 @@ struct BubbleUpCollapseShapeThroughExtractSlice "tensor.extract_slice source not produced by tensor.collapse_shape"); } - if (!sliceOp.hasUnitStride()) { - return rewriter.notifyMatchFailure( - sliceOp, "unsupported: non-unit stride. Only contiguous slices can " - "be supported in this transformation."); - } + SmallVector<OpFoldResult> offsets, sizes, strides; + if (failed(getExpandedExtractSliceInfo( + rewriter, sliceOp, collapseShapeOp.getReassociationIndices(), + collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides))) + return failure(); - // The tensor.extract_slice before applying the pattern works on the result - // of the tensor.collapse_shape, so variables (i.e. inputs for - // ExtractSliceOp) referring to the state before applying the pattern are - // named with the prefix "collapsed", and ones referring to the state after - // applying the pattern are named with the prefix "expanded". 
- SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets(); - SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes(); - - if (static_cast<size_t>(sliceOp.getResultType().getRank()) != - collapsedSizes.size()) { - return rewriter.notifyMatchFailure(sliceOp, - "unimplemented: rank reducing slice"); - } + Value newSliceOp = tensor::ExtractSliceOp::create( + rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets, + sizes, strides); + rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( + sliceOp, sliceOp.getResultType(), newSliceOp, + collapseShapeOp.getReassociationIndices()); - ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape(); - SmallVector<ReassociationIndices, 4> reassociationIndices = - collapseShapeOp.getReassociationIndices(); - - // Compute new offsets, sizes, and strides for tensor.extract_slice. - // The new tensor.extract_slice will work on a tensor that has has a rank - // equal to the rank of the src of the collapse_shape. In each iteration of - // the loop, the offsets and sizes will be computed per reassociation group. - SmallVector<OpFoldResult> expandedOffsets, expandedSizes; - SmallVector<OpFoldResult> expandedStrides(srcShape.size(), - rewriter.getIndexAttr(1)); - - for (auto [collapsedSize, collapsedOffset, reassocIndices] : - llvm::zip_equal(collapsedSizes, collapsedOffsets, - collapseShapeOp.getReassociationIndices())) { - // CASE #1 - size and/or offset are dynamic. - // In this case, the slice can be represented as a contiguous slice only - // if there is a single dimension in the reassociation group that has a - // size not equal to 1. - if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) { - int nonUnitSizeCount = 0; - for (int64_t expandedShapeIdx : reassocIndices) { - if (srcShape[expandedShapeIdx] != 1) { - nonUnitSizeCount++; - expandedSizes.push_back(collapsedSize); - expandedOffsets.push_back(collapsedOffset); - continue; - } - - expandedSizes.push_back(rewriter.getIndexAttr(1)); - expandedOffsets.push_back(rewriter.getIndexAttr(0)); - } + return success(); + } +}; - if (nonUnitSizeCount != 1) { - return rewriter.notifyMatchFailure( - sliceOp, - "unsupported: slice cannot be verified to be contiguous"); - } - continue; - } +} // namespace - // CASE #2 = size and offset are static. - // Verify that the slice can be represented as a contiguous slice of the - // src of the collapse_shape. - // Checking this is done on order of most internal dimensions first, - // so traversal is done in reverse order of the reassociation group. - // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2, - // ...,An] then we first find the size and offset for n...k+1 then for k - // and then for k-1...0. - - // currentCollapsedsize and currentCollapsedOffset are initialized with - // the original collapsed size and offset and divided by the expanded - // shape size in each dimension as we go along the reassociation group. - // In essence we are spreading the original collapsed size and offset over - // the various expanded slice dimensions. - // The variables are used both to check the validity of the slice and to - // compute the expanded sizes and offsets. 
- int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value(); - int64_t currentCollapsedOffset = - getConstantIntValue(collapsedOffset).value(); - - SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets; - - ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(), - reassocIndices.rend()); - int64_t idx = 0; - int64_t reassocGroupSize = reassocIndices.size(); - - // First handle the trailing dimensions where the slice size should be - // equal to the tensor shape and the offset should be 0 (n...k+1). - for (; idx < reassocGroupSize; ++idx) { - int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - - if (currentCollapsedsize < expandedShapeSize) - break; - - // We need to make sure that the slice size can be set to the shape size - // and the offset to 0. - if ((currentCollapsedsize % expandedShapeSize) != 0 || - (currentCollapsedOffset % expandedShapeSize) != 0) { - return rewriter.notifyMatchFailure( - sliceOp, "unsupported: cannot be extracted as a contiguous slice " - "of the src of the collapse_shape"); - } +LogicalResult mlir::tensor::getCollapsedExtractSliceInfo( + OpBuilder &b, tensor::ExtractSliceOp sliceOp, + ArrayRef<ReassociationIndices> reassociation, + SmallVectorImpl<OpFoldResult> &collapsedOffsets, + SmallVectorImpl<OpFoldResult> &collapsedSizes, + SmallVectorImpl<OpFoldResult> &collapsedStrides) { + if (!sliceOp.hasUnitStride()) { + return failure(); + } + + SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets(); + SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); - groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize)); - groupExpandedOffsets.push_back(rewriter.getIndexAttr(0)); + if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) { + return failure(); + } - currentCollapsedsize /= expandedShapeSize; - currentCollapsedOffset /= expandedShapeSize; + auto isZeroOffsetAndFullSize = [&](OpFoldResult offset, + OpFoldResult sliceSize, int64_t inputDim) { + if (!isZeroInteger(offset)) + return false; + ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim); + FailureOr<bool> maybeEqual = + ValueBoundsConstraintSet::areEqual(sliceSize, inputSize); + return llvm::succeeded(maybeEqual) && maybeEqual.value(); + }; + + // Check that the slice is contiguous within each reassociation group. + // The slice is contiguous only if after the first dimension where a non + // unit slice is taken, the slice size on all subsequent dimensions of the + // group is equal to the entire size of the dimension. + // Examples of contiguous slices: + // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10] + // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10] + // Examples of non contiguous slices: + // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5] + // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5] + for (const ReassociationIndices &indices : reassociation) { + int64_t i = 0; + int64_t e = indices.size(); + // Find the first expanded dim after the first dim with non-unit extracted + // size. + for (; i < e; ++i) { + if (!isOneInteger(sizes[indices[i]])) { + // +1 to skip the first non-unit size dim. + i++; + break; } + } + + // Verify that all subsequent dimensions extract the full size of the + // source tensor. 
+ for (; i < e; ++i) { + int64_t expandedDim = indices[i]; + if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim], + expandedDim)) { + return failure(); + } + } + } + + // The tensor.extract_slice before applying the pattern works on the result + // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp) + // referring to the state before applying the pattern are named with the + // prefix "expanded", and ones referring to the state after applying the + // pattern are named with the prefix "collapsed". + Location loc = sliceOp.getLoc(); + SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets(); + SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes(); + SmallVector<OpFoldResult> expandedShape = + getMixedSizes(b, loc, sliceOp.getSource()); + + // Helper variables and function for accumulating the size values. + AffineExpr d0, d1, d2; + bindDims(b.getContext(), d0, d1, d2); + // Multiply two integers. + auto mul = [&](OpFoldResult v1, OpFoldResult v2) { + auto mulMap = AffineMap::get(2, 0, {d0 * d1}); + return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2}); + }; + + // Compute new offsets, sizes, and strides for tensor.extract_slice. + // The new tensor.extract_slice will work on a tensor that has has a rank of + // ReassociationIndices.size(). In the loop a single offset, size, and + // stride value is computed per reassociation group. + for (const ReassociationIndices &indices : reassociation) { + // collapsedSize will hold the size of the single dim that represents the + // reassociation group in the non expanded tensor. + OpFoldResult collapsedSize = b.getIndexAttr(1); + // The reassocGroupSizes and reassocGroupOffsets are used to create an + // affine.linearize_index op to linearize the single offset value required + // for this reassociation group. + SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets; + + for (long expandedDim : indices) { + // reassocGroupSizes and reassocGroupOffsets can be obtained directly + // from the expanded state, but the collapsed size requires calculation + // as it did not previously exist. + reassocGroupSizes.push_back(expandedShape[expandedDim]); + reassocGroupOffsets.push_back(expandedOffsets[expandedDim]); + collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]); + } + + SmallVector<Value> offsetVals = + llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) { + return getValueOrCreateConstantIndexOp(b, loc, ofr); + }); + OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create( + b, loc, offsetVals, reassocGroupSizes, + /*disjoint=*/true) + .getResult(); + collapsedOffsets.push_back(collapsedOffset); + collapsedSizes.push_back(collapsedSize); + + // Only unit stride is supported. + collapsedStrides.push_back(b.getIndexAttr(1)); + } + return success(); +} + +LogicalResult mlir::tensor::getExpandedExtractSliceInfo( + OpBuilder &b, tensor::ExtractSliceOp sliceOp, + ArrayRef<ReassociationIndices> reassociation, + ArrayRef<int64_t> expandedShape, + SmallVectorImpl<OpFoldResult> &expandedOffsets, + SmallVectorImpl<OpFoldResult> &expandedSizes, + SmallVectorImpl<OpFoldResult> &expandedStrides) { + if (!sliceOp.hasUnitStride()) { + return failure(); + } + + // The tensor.extract_slice before applying the pattern works on the result + // of the tensor.collapse_shape, so variables (i.e. 
inputs for + // ExtractSliceOp) referring to the state before applying the pattern are + // named with the prefix "collapsed", and ones referring to the state after + // applying the pattern are named with the prefix "expanded". + SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets(); + SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes(); + if (static_cast<size_t>(sliceOp.getResultType().getRank()) != + collapsedSizes.size()) { + return failure(); + } - // Now handle the first dim where slicing occurs on (k). - if (idx < reassocGroupSize) { - int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; - // We need to make sure that the slice size in this dim + offset will - // not exceed the shape size. - if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) { - return rewriter.notifyMatchFailure( - sliceOp, "unsupported: slice cannot be extracted as a contiguous " - "slice of the src of the collapse_shape"); + // Compute new offsets, sizes, and strides for tensor.extract_slice. + // The new tensor.extract_slice will work on a tensor that has has a rank + // equal to the rank of the src of the collapse_shape. In each iteration of + // the loop, the offsets and sizes will be computed per reassociation group. + expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1)); + for (auto [collapsedSize, collapsedOffset, reassocIndices] : + llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) { + // CASE #1 - size and/or offset are dynamic. + // In this case, the slice can be represented as a contiguous slice only + // if there is a single dimension in the reassociation group that has a + // size not equal to 1. + if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) { + int nonUnitSizeCount = 0; + for (int64_t expandedShapeIdx : reassocIndices) { + if (expandedShape[expandedShapeIdx] != 1) { + nonUnitSizeCount++; + expandedSizes.push_back(collapsedSize); + expandedOffsets.push_back(collapsedOffset); + continue; } - groupExpandedSizes.push_back( - rewriter.getIndexAttr(currentCollapsedsize)); - groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); + expandedSizes.push_back(b.getIndexAttr(1)); + expandedOffsets.push_back(b.getIndexAttr(0)); + } - currentCollapsedOffset /= expandedShapeSize; + if (nonUnitSizeCount != 1) { + return failure(); } + continue; + } - // Now handle the leading dimensions where the slice size is equal to 1 - // (k-1...0). - // The size for these dimensions must be 1 because of how we constructed - // the slice size of the expanded shape. We spread the original collapsed - // size over the expanded shape sizes until we reached dimension k where - // the remaining size was smaller than the expanded shape size, and spread - // the remaining size on it. So, now we are left with only 1s. - for (idx++; idx < reassocGroupSize; ++idx) { - int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; - groupExpandedSizes.push_back(rewriter.getIndexAttr(1)); - groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); - currentCollapsedOffset /= expandedShapeSize; + // CASE #2 = size and offset are static. + // Verify that the slice can be represented as a contiguous slice of the + // src of the collapse_shape. + // Checking this is done on order of most internal dimensions first, + // so traversal is done in reverse order of the reassociation group. 
+ // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2, + // ...,An] then we first find the size and offset for n...k+1 then for k + // and then for k-1...0. + + // currentCollapsedsize and currentCollapsedOffset are initialized with + // the original collapsed size and offset and divided by the expanded + // shape size in each dimension as we go along the reassociation group. + // In essence we are spreading the original collapsed size and offset over + // the various expanded slice dimensions. + // The variables are used both to check the validity of the slice and to + // compute the expanded sizes and offsets. + int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value(); + int64_t currentCollapsedOffset = + getConstantIntValue(collapsedOffset).value(); + SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets; + ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(), + reassocIndices.rend()); + int64_t idx = 0; + int64_t reassocGroupSize = reassocIndices.size(); + + // First handle the trailing dimensions where the slice size should be + // equal to the tensor shape and the offset should be 0 (n...k+1). + for (; idx < reassocGroupSize; ++idx) { + int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]]; + + if (currentCollapsedsize < expandedShapeSize) + break; + + // We need to make sure that the slice size can be set to the shape size + // and the offset to 0. + if ((currentCollapsedsize % expandedShapeSize) != 0 || + (currentCollapsedOffset % expandedShapeSize) != 0) { + return failure(); } - expandedSizes.append(groupExpandedSizes.rbegin(), - groupExpandedSizes.rend()); - expandedOffsets.append(groupExpandedOffsets.rbegin(), - groupExpandedOffsets.rend()); + groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize)); + groupExpandedOffsets.push_back(b.getIndexAttr(0)); + + currentCollapsedsize /= expandedShapeSize; + currentCollapsedOffset /= expandedShapeSize; } - Value newSliceOp = tensor::ExtractSliceOp::create( - rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), - expandedOffsets, expandedSizes, expandedStrides); - rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( - sliceOp, sliceOp.getResultType(), newSliceOp, - collapseShapeOp.getReassociationIndices()); + // Now handle the first dim where slicing occurs on (k). + if (idx < reassocGroupSize) { + int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]]; + int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; + // We need to make sure that the slice size in this dim + offset will + // not exceed the shape size. + if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) { + return failure(); + } + groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize)); + groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim)); + currentCollapsedOffset /= expandedShapeSize; + } - return success(); + // Now handle the leading dimensions where the slice size is equal to 1 + // (k-1...0). + // The size for these dimensions must be 1 because of how we constructed + // the slice size of the expanded shape. We spread the original collapsed + // size over the expanded shape sizes until we reached dimension k where + // the remaining size was smaller than the expanded shape size, and spread + // the remaining size on it. So, now we are left with only 1s. 
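A worked example of the static-size path above (numbers chosen for illustration, not taken from the patch):

  //   collapse_shape [2, 3, 4] -> [24], slice offset = 12, size = 4.
  //   Trailing dim (size 4): 12 % 4 == 0 and 4 % 4 == 0, so it gets
  //   offset 0 / size 4; remaining size 1, remaining offset 3.
  //   First sliced dim (size 3): offsetInDim = 3 % 3 = 0 and 1 + 0 < 3, so it
  //   gets offset 0 / size 1; remaining offset 1.
  //   Leading dim (size 2): offset 1 % 2 = 1, size 1.
  //   Reversed back to source order: offsets [1, 0, 0], sizes [1, 1, 4].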
+ for (idx++; idx < reassocGroupSize; ++idx) { + int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]]; + int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; + groupExpandedSizes.push_back(b.getIndexAttr(1)); + groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim)); + currentCollapsedOffset /= expandedShapeSize; + } + expandedSizes.append(groupExpandedSizes.rbegin(), + groupExpandedSizes.rend()); + expandedOffsets.append(groupExpandedOffsets.rbegin(), + groupExpandedOffsets.rend()); } -}; - -} // namespace + return success(); +} void mlir::tensor::populateReassociativeReshapeFoldingPatterns( RewritePatternSet &patterns) { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index e3cba388..8d63646 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -122,8 +122,9 @@ struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> { const APFloat lowestVal = APFloat::getLargest(padConstVal.getSemantics(), true); return padConstVal == lowestVal; - } else if (auto padConstIntAttr = - mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) { + } + if (auto padConstIntAttr = + mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) { const APInt padConstVal = *padConstIntAttr.begin(); const unsigned int bitWidth = padConstVal.getBitWidth(); const APInt lowestVal = @@ -555,7 +556,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> { // Check we have a valid NaN propagation combination. const auto opNanMode = op.getNanMode(); const auto clampNanMode = clampOp.getNanMode(); - if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE") + if (opNanMode == NanPropagationMode::IGNORE && + clampNanMode == NanPropagationMode::PROPAGATE) return failure(); auto maxValAttr = op.getMaxValAttr(); @@ -636,10 +638,16 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> { } } + auto newMode = (opNanMode != clampNanMode) + ? tosa::NanPropagationMode::IGNORE + : opNanMode; + + auto newModeAttr = + NanPropagationModeAttr::get(rewriter.getContext(), newMode); + rewriter.replaceOpWithNewOp<tosa::ClampOp>( op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr, - rewriter.getStringAttr((opNanMode != clampNanMode) ? 
"IGNORE" - : opNanMode)); + newModeAttr); return success(); } }; @@ -1120,13 +1128,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { } if (rhsTy == resultTy) { - if (isSplatZero(resultETy, lhsAttr)) + if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape()) + // constant values can only be resized if resulting type is static return lhsAttr.resizeSplat(resultTy); if (isSplatOne(resultETy, lhsAttr, shift)) return rhs; } if (lhsTy == resultTy) { - if (isSplatZero(resultETy, rhsAttr)) + if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape()) return rhsAttr.resizeSplat(resultTy); if (isSplatOne(resultETy, rhsAttr, shift)) return lhs; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 3cafb19..bd7aee5 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -270,6 +270,244 @@ void mlir::tosa::printVariableOpTypeOrInitialValue( } } +namespace { + +// parse attributes with special handling for tosa enum attributes +template <typename EnumType> +ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser, + NamedAttrList &outAttrs) { + llvm::StringRef name; + if (parser.parseOptionalKeyword(&name) || parser.parseEqual()) + return failure(); + + // special handling: rounding_mode accepts a *bare* RoundingMode enum + // keyword. + llvm::StringRef kw; + if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) { + if (name == "rounding_mode" && + succeeded(parser.parseOptionalKeyword(&kw))) { + auto sym = symbolizeRoundingMode(kw); + if (!sym) + return parser.emitError(parser.getCurrentLocation()) + << "invalid rounding_mode value: " << kw; + auto attr = RoundingModeAttr::get(parser.getContext(), sym.value()); + outAttrs.push_back(NamedAttribute(name, attr)); + return success(); + } + } + // special handling: mode accepts a *bare* ResizeMode enum keyword. + if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) { + if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) { + auto sym = symbolizeResizeMode(kw); + if (!sym) + return parser.emitError(parser.getCurrentLocation()) + << "invalid resize mode value: " << kw; + auto attr = ResizeModeAttr::get(parser.getContext(), sym.value()); + outAttrs.push_back(NamedAttribute(name, attr)); + return success(); + } + } + // special handling: nan_mode accepts a *bare* NanPropagationMode enum + // keyword. 
+ if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) { + if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) { + auto sym = symbolizeNanPropagationMode(kw); + if (!sym) + return parser.emitError(parser.getCurrentLocation()) + << "invalid nan_mode value: " << kw; + auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value()); + outAttrs.push_back(NamedAttribute(name, attr)); + return success(); + } + } + + // Default path: parse any normal attribute literal, including fully qualified + // enum keyword + Attribute attr; + return parser.parseAttribute(attr, name, outAttrs); +} + +template <typename EnumType> +ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) { + // parse operands + SmallVector<OpAsmParser::UnresolvedOperand, 5> operands; + if (parser.parseCommaSeparatedList( + [&]() { return parser.parseOperand(operands.emplace_back()); })) + return failure(); + + // Parse { attr-dict } with special handling for enum bare token + NamedAttrList attrs; + if (succeeded(parser.parseOptionalLBrace()) && + failed(parser.parseOptionalRBrace())) { + do { + if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs)) + return failure(); + } while (succeeded(parser.parseOptionalComma())); + if (parser.parseRBrace()) + return failure(); + } + + FunctionType fnTy; + if (parser.parseColonType(fnTy)) + return failure(); + + // Resolve operands and types + if (failed(parser.resolveOperands(operands, fnTy.getInputs(), + parser.getCurrentLocation(), + result.operands))) + return failure(); + + result.addTypes(fnTy.getResult(0)); + result.addAttributes(attrs); + + return success(); +} + +void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) { + parser << namedAttr.getName().strref() << " = "; + auto attr = namedAttr.getValue(); + if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) { + parser << roundingModeAttr.getValue(); + } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) { + parser << resizeModeAttr.getValue(); + } else if (auto nanPropagationModeAttr = + dyn_cast<tosa::NanPropagationModeAttr>(attr)) { + parser << nanPropagationModeAttr.getValue(); + } else { + parser.printAttribute(attr); + } +} + +// print with special handling for default valued NanPropagationMode attribute +void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) { + parser << " "; + parser.printOperands(op->getOperands()); + + NamedAttrList toPrint(op->getAttrs()); + // remove default NanPropagate attribute + const auto kDefaultNanValue = NanPropagationMode::PROPAGATE; + for (auto attr : op->getAttrs()) { + if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) { + if (nanAttr.getValue() == kDefaultNanValue) { + // elide from toPrint + toPrint.erase(attr.getName()); + break; + } + } + } + + if (!toPrint.empty()) { + parser << " {"; + llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) { + printNamedAttr(parser, namedAttr); + }); + parser << "}"; + } + + parser << " : "; + parser.printFunctionalType(op); +} + +// print with special handling for enums: RoundingMode, ResizeMode +void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) { + parser << " "; + parser.printOperands(op->getOperands()); + + if (!op->getAttrs().empty()) { + parser << " {"; + llvm::interleaveComma(op->getAttrs(), parser, + [&](const NamedAttribute namedAttr) { + printNamedAttr(parser, namedAttr); + }); + parser << "}"; + } + + parser << " : "; + 
parser.printFunctionalType(op); +} + +} // namespace + +ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::RoundingMode>(parser, result); +} + +void RescaleOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + +ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::RoundingMode>(parser, result); +} + +void ApplyScaleOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + +ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::ResizeMode>(parser, result); +} + +void ResizeOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + +ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void ArgMaxOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void MaxPool2dOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void ClampOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void MaximumOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void MinimumOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void ReduceMaxOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void ReduceMinOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + //===----------------------------------------------------------------------===// // Tosa utilities. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp index 5590927..8143b27 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp @@ -658,10 +658,10 @@ void TosaReduceTransposes::runOnOperation() { // (like the TransposeOp we insert for ReshapeOp), // but in this case, that is specialized enough and overlaps // with another direct-use TransposeOp case we need to cover anyway. - transposeInfo.push_back({transposeOp, dependentOps}); + transposeInfo.emplace_back(transposeOp, dependentOps); // This is for the final replacement across all transposes. 
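The change just below swaps brace-initialized push calls for in-place construction; a generic sketch of the difference (illustrative element type, not from this patch):

  std::vector<std::pair<int, std::string>> v;  // needs <string>, <utility>, <vector>
  v.push_back({1, "a"});   // constructs a temporary pair, then moves it in
  v.emplace_back(1, "a");  // forwards the arguments and constructs in place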
- totalTransposeOrder.push({transposeOp, perms}); + totalTransposeOrder.emplace(transposeOp, perms); }); // We want to do a full fan-in analysis on a perms-level, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index c7b9534..790bbf7 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -508,14 +508,15 @@ private: bool attributeCheckRescale(Operation *op) { if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) { - if (rescale.getRoundingMode() == "DOUBLE_ROUND" && + if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !targetEnv.allows(Extension::doubleround)) { op->emitOpError() << "failed attribute check: rounding_mode = DOUBLE_ROUND " << "requires extension [doubleround]"; return false; - } else if (rescale.getRoundingMode() == "INEXACT_ROUND" && - !targetEnv.allows(Extension::inexactround)) { + } + if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND && + !targetEnv.allows(Extension::inexactround)) { op->emitOpError() << "failed attribute check: rounding_mode = INEXACT_ROUND " << "requires extension [inexactround]"; @@ -1122,7 +1123,7 @@ bool checkErrorIfRescale(Operation *op) { } // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND)) - if (!scale32 && roundingMode == "DOUBLE_ROUND") { + if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND) { op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true."; return false; } @@ -1307,7 +1308,8 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { if (isa<FloatType>(type)) { return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(type); - } else if (auto intTy = dyn_cast<IntegerType>(type)) { + } + if (auto intTy = dyn_cast<IntegerType>(type)) { if (intTy.isSignless()) { switch (intTy.getWidth()) { case 1: diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 9266a63..48df1a0 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -37,16 +37,13 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/InterleavedRange.h" #include <optional> #define DEBUG_TYPE "transform-dialect" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") - #define DEBUG_TYPE_MATCHER "transform-matcher" -#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") -#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) using namespace mlir; @@ -182,8 +179,7 @@ transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, DiagnosedSilenceableFailure result = state.applyTransform(cast<TransformOpInterface>(transform)); if (result.isSilenceableFailure()) { - LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() - << "\n"); + LDBG() << "alternative failed: " << result.getMessage(); failed = true; break; } @@ -1155,12 +1151,10 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, std::optional<DiagnosedSilenceableFailure> maybeFailure; for (Operation *root : state.getPayloadOps(getRoot())) { WalkResult walkResult = root->walk([&](Operation *op) { - DEBUG_MATCHER({ - DBGS_MATCHER() << "matching "; - op->print(llvm::dbgs(), - OpPrintingFlags().assumeVerified().skipRegions()); - llvm::dbgs() << " @" << op << "\n"; - }); + 
LDBG(1, DEBUG_TYPE_MATCHER) + << "matching " + << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions()) + << " @" << op; // Try matching. SmallVector<SmallVector<MappedValue>> mappings; @@ -1172,8 +1166,8 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, if (diag.isDefiniteFailure()) return WalkResult::interrupt(); if (diag.isSilenceableFailure()) { - DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() - << " failed: " << diag.getMessage()); + LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName() + << " failed: " << diag.getMessage(); return WalkResult::advance(); } @@ -1304,12 +1298,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, if (!getRestrictRoot() && op == root) return WalkResult::advance(); - DEBUG_MATCHER({ - DBGS_MATCHER() << "matching "; - op->print(llvm::dbgs(), - OpPrintingFlags().assumeVerified().skipRegions()); - llvm::dbgs() << " @" << op << "\n"; - }); + LDBG(1, DEBUG_TYPE_MATCHER) + << "matching " + << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions()) + << " @" << op; firstMatchArgument.clear(); firstMatchArgument.push_back(op); @@ -1322,8 +1314,8 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, if (diag.isDefiniteFailure()) return WalkResult::interrupt(); if (diag.isSilenceableFailure()) { - DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() - << " failed: " << diag.getMessage()); + LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName() + << " failed: " << diag.getMessage(); continue; } @@ -2173,10 +2165,10 @@ DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation( ::std::optional<::mlir::Operation *> maybeCurrent, transform::TransformResults &results, transform::TransformState &state) { if (!maybeCurrent.has_value()) { - DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; }); + LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp success"; return DiagnosedSilenceableFailure::success(); } - DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; }); + LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp failure"; return emitSilenceableError() << "operation is not empty"; } diff --git a/mlir/lib/Dialect/Transform/IR/Utils.cpp b/mlir/lib/Dialect/Transform/IR/Utils.cpp index d666390..773eb13 100644 --- a/mlir/lib/Dialect/Transform/IR/Utils.cpp +++ b/mlir/lib/Dialect/Transform/IR/Utils.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" using namespace mlir; @@ -90,7 +91,7 @@ transform::detail::mergeSymbolsInto(Operation *target, // // Rename private symbols in both ops in order to resolve conflicts that can // be resolved that way. - LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n"); + LDBG() << "renaming private symbols to resolve conflicts:"; // TODO: Do we *actually* need to test in both directions? for (auto &&[symbolTable, otherSymbolTable] : llvm::zip( SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable}, @@ -102,7 +103,7 @@ transform::detail::mergeSymbolsInto(Operation *target, if (!symbolOp) continue; StringAttr name = symbolOp.getNameAttr(); - LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n"); + LDBG() << " found @" << name.getValue(); // Check if there is a colliding op in the other module. 
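The debug-logging migration in this hunk follows one mechanical pattern; before and after, using the message just above:

  // Before: explicit macro wrapper, hand-written prefix and trailing newline.
  //   LLVM_DEBUG(DBGS() << "  found @" << name.getValue() << "\n");
  // After: LDBG() from llvm/Support/DebugLog.h supplies the debug-type prefix
  // and the newline itself.
  //   LDBG() << "  found @" << name.getValue();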
auto collidingOp = @@ -110,7 +111,7 @@ transform::detail::mergeSymbolsInto(Operation *target, if (!collidingOp) continue; - LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue()); + LDBG() << " collision found for @" << name.getValue(); // Collisions are fine if both opt are functions and can be merged. if (auto funcOp = dyn_cast<FunctionOpInterface>(op), @@ -119,13 +120,12 @@ transform::detail::mergeSymbolsInto(Operation *target, funcOp && collidingFuncOp) { if (canMergeInto(funcOp, collidingFuncOp) || canMergeInto(collidingFuncOp, funcOp)) { - LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and " - "will be merged\n"); + LDBG() << " but both ops are functions and will be merged"; continue; } // If they can't be merged, proceed like any other collision. - LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions"); + LDBG() << " and both ops are function definitions"; } // Collision can be resolved by renaming if one of the ops is private. @@ -133,7 +133,7 @@ transform::detail::mergeSymbolsInto(Operation *target, [&](SymbolOpInterface op, SymbolOpInterface otherOp, SymbolTable &symbolTable, SymbolTable &otherSymbolTable) -> InFlightDiagnostic { - LLVM_DEBUG(llvm::dbgs() << ", renaming\n"); + LDBG() << ", renaming"; FailureOr<StringAttr> maybeNewName = symbolTable.renameToUnique(op, {&otherSymbolTable}); if (failed(maybeNewName)) { @@ -142,8 +142,7 @@ transform::detail::mergeSymbolsInto(Operation *target, << "attempted renaming due to collision with this op"; return diag; } - LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue() - << "\n"); + LDBG() << " renamed to @" << maybeNewName->getValue(); return InFlightDiagnostic(); }; @@ -161,7 +160,7 @@ transform::detail::mergeSymbolsInto(Operation *target, return diag; continue; } - LLVM_DEBUG(llvm::dbgs() << ", emitting error\n"); + LDBG() << ", emitting error"; InFlightDiagnostic diag = symbolOp.emitError() << "doubly defined symbol @" << name.getValue(); diag.attachNote(collidingOp->getLoc()) << "previously defined here"; @@ -179,7 +178,7 @@ transform::detail::mergeSymbolsInto(Operation *target, // Step 2: // // Move all ops from `other` into target and merge public symbols. - LLVM_DEBUG(DBGS() << "moving all symbols into target\n"); + LDBG() << "moving all symbols into target"; { SmallVector<SymbolOpInterface> opsToMove; for (Operation &op : other->getRegion(0).front()) { @@ -193,13 +192,13 @@ transform::detail::mergeSymbolsInto(Operation *target, targetSymbolTable.lookup(op.getNameAttr())); // Move op even if we get a collision. - LLVM_DEBUG(DBGS() << " moving @" << op.getName()); + LDBG() << " moving @" << op.getName(); op->moveBefore(&target->getRegion(0).front(), target->getRegion(0).front().end()); // If there is no collision, we are done. if (!collidingOp) { - LLVM_DEBUG(llvm::dbgs() << " without collision\n"); + LDBG() << " without collision"; continue; } @@ -217,9 +216,9 @@ transform::detail::mergeSymbolsInto(Operation *target, } assert(canMergeInto(funcOp, collidingFuncOp)); - LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at " - << collidingFuncOp.getLoc() << ":\n" - << collidingFuncOp << "\n"); + LDBG() << " with collision, trying to keep op at " + << collidingFuncOp.getLoc() << ":\n" + << collidingFuncOp; // Update symbol table. This works with or without the previous `swap`. 
targetSymbolTable.remove(funcOp); @@ -239,6 +238,6 @@ transform::detail::mergeSymbolsInto(Operation *target, return target->emitError() << "failed to verify target op after merging symbols"; - LLVM_DEBUG(DBGS() << "done merging ops\n"); + LDBG() << "done merging ops"; return InFlightDiagnostic(); } diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index 14a4fdf..4f4620a 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -312,7 +312,7 @@ LogicalResult transform::TransformState::setParams(Value value, } template <typename Mapping, typename Key, typename Mapped> -void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) { +static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) { auto it = mapping.find(key); if (it == mapping.end()) return; @@ -771,7 +771,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( } template <typename T> -DiagnosedSilenceableFailure +static DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef<T> payload, transform::TransformOpInterface transform, unsigned operandNumber) { diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp index 41955c8..3ced1a6 100644 --- a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp @@ -100,12 +100,7 @@ LogicalResult PatternApplicatorExtension::findAllMatches( PatternApplicator applicator(it->second); // We want to discourage direct use of PatternRewriter in APIs but In this // very specific case, an IRRewriter is not enough. 
- struct TrivialPatternRewriter : public PatternRewriter { - public: - explicit TrivialPatternRewriter(MLIRContext *context) - : PatternRewriter(context) {} - }; - TrivialPatternRewriter rewriter(root->getContext()); + PatternRewriter rewriter(root->getContext()); applicator.applyDefaultCostModel(); root->walk([&](Operation *op) { if (succeeded(applicator.matchAndRewrite(op, rewriter))) diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp index 35ace1b..9ab484f 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp @@ -121,6 +121,80 @@ ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) { ->getLibraryModule(); } +static transform::TransformOpInterface +findTransformEntryPointNonRecursive(Operation *op, StringRef entryPoint) { + for (Region ®ion : op->getRegions()) { + for (Block &block : region.getBlocks()) { + for (auto namedSequenceOp : block.getOps<transform::NamedSequenceOp>()) { + if (namedSequenceOp.getSymName() == entryPoint) { + return cast<transform::TransformOpInterface>( + namedSequenceOp.getOperation()); + } + } + } + } + return nullptr; +} + +static transform::TransformOpInterface +findTransformEntryPointRecursive(Operation *op, StringRef entryPoint) { + transform::TransformOpInterface transform = nullptr; + op->walk<WalkOrder::PreOrder>( + [&](transform::NamedSequenceOp namedSequenceOp) { + if (namedSequenceOp.getSymName() == entryPoint) { + transform = cast<transform::TransformOpInterface>( + namedSequenceOp.getOperation()); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return transform; +} + +// Will look for the transform's entry point favouring NamedSequenceOps +// ops that exist within the operation without the need for nesting. +// If no operation exists in the blocks owned by op, then it will recursively +// walk the op in preorder and find the first NamedSequenceOp that matches +// the entry point's name. +// +// This allows for the following two use cases: +// 1. op is a module annotated with the transform.with_named_sequence attribute +// that has an entry point in its block. E.g., +// +// ```mlir +// module {transform.with_named_sequence} { +// transform.named_sequence @__transform_main(%arg0 : !transform.any_op) -> +// () { +// transform.yield +// } +// } +// ``` +// +// 2. op is a program which contains a nested module annotated with the +// transform.with_named_sequence attribute. 
E.g., +// +// ```mlir +// module { +// func.func @foo () { +// } +// +// module {transform.with_named_sequence} { +// transform.named_sequence @__transform_main(%arg0 : !transform.any_op) +// -> () { +// transform.yield +// } +// } +// } +// ``` +static transform::TransformOpInterface +findTransformEntryPointInOp(Operation *op, StringRef entryPoint) { + transform::TransformOpInterface transform = + findTransformEntryPointNonRecursive(op, entryPoint); + if (!transform) + transform = findTransformEntryPointRecursive(op, entryPoint); + return transform; +} + transform::TransformOpInterface transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module, StringRef entryPoint) { @@ -128,16 +202,8 @@ transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module, if (module) l.push_back(module); for (Operation *op : l) { - transform::TransformOpInterface transform = nullptr; - op->walk<WalkOrder::PreOrder>( - [&](transform::NamedSequenceOp namedSequenceOp) { - if (namedSequenceOp.getSymName() == entryPoint) { - transform = cast<transform::TransformOpInterface>( - namedSequenceOp.getOperation()); - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); + TransformOpInterface transform = + findTransformEntryPointInOp(op, entryPoint); if (transform) return transform; } diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index bc85cf4..7b2734d 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -407,7 +407,7 @@ mlir::convertReassociationIndicesToExprs( } template <typename AffineExprTy> -unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) { +static unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) { unsigned pos = 0; for (const auto &exprs : exprArrays) { for (auto expr : exprs) { diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index e6ef028..34385d7 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -276,7 +276,7 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub, if (!ubConstant) return std::nullopt; std::optional<int64_t> stepConstant = getConstantIntValue(step); - if (!stepConstant) + if (!stepConstant || *stepConstant == 0) return std::nullopt; return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a450056..9b2a455 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2402,6 +2402,16 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor, return foldToElementsFromElements(*this, results); } +LogicalResult +ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, + ToElementsOp::Adaptor adaptor, + SmallVectorImpl<Type> &inferredReturnTypes) { + auto vecType = cast<VectorType>(adaptor.getSource().getType()); + Type elType = vecType.getElementType(); + inferredReturnTypes.append(vecType.getNumElements(), elType); + return success(); +} + //===----------------------------------------------------------------------===// // FromElementsOp //===----------------------------------------------------------------------===// @@ -2456,8 +2466,12 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, if (llvm::any_of(elements, [](Attribute attr) { return !attr; })) return {}; + // 
DenseElementsAttr only supports int/index/float/complex types. auto destVecType = fromElementsOp.getDest().getType(); auto destEltType = destVecType.getElementType(); + if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType)) + return {}; + // Constant attributes might have a different type than the return type. // Convert them before creating the dense elements attribute. auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) { @@ -2768,8 +2782,8 @@ BroadcastableToResult mlir::vector::isBroadcastableTo( Type srcType, VectorType dstVectorType, std::pair<VectorDim, VectorDim> *mismatchingDims) { // Broadcast scalar to vector of the same element type. - if (srcType.isIntOrIndexOrFloat() && dstVectorType && - getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType)) + if (isa<VectorElementTypeInterface>(srcType) && dstVectorType && + srcType == getElementTypeOrSelf(dstVectorType)) return BroadcastableToResult::Success; // From now on, only vectors broadcast. VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType); @@ -2841,9 +2855,47 @@ LogicalResult BroadcastOp::verify() { llvm_unreachable("unexpected vector.broadcast op error"); } +// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible +// with broadcast's result type and shape_cast only adds or removes ones in the +// leading dimensions. +static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) { + auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>(); + if (!srcShapeCast) + return failure(); + + VectorType srcType = srcShapeCast.getSourceVectorType(); + VectorType destType = broadcastOp.getResultVectorType(); + // Check type compatibility. + if (vector::isBroadcastableTo(srcType, destType) != + BroadcastableToResult::Success) + return failure(); + + ArrayRef<int64_t> srcShape = srcType.getShape(); + ArrayRef<int64_t> shapecastShape = + srcShapeCast.getResultVectorType().getShape(); + // Trailing dimensions should be the same if shape_cast only alters the + // leading dimensions. + unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size()); + if (!llvm::equal(srcShape.take_back(numTrailingDims), + shapecastShape.take_back(numTrailingDims))) + return failure(); + + assert(all_of(srcShape.drop_back(numTrailingDims), + [](int64_t E) { return E == 1; }) && + all_of(shapecastShape.drop_back(numTrailingDims), + [](int64_t E) { return E == 1; }) && + "ill-formed shape_cast"); + + broadcastOp.getSourceMutable().assign(srcShapeCast.getSource()); + return success(); +} + OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { if (getSourceType() == getResultVectorType()) return getSource(); + if (succeeded(foldBroadcastOfShapeCast(*this))) + return getResult(); + if (!adaptor.getSource()) return {}; auto vectorType = getResultVectorType(); @@ -3238,6 +3290,18 @@ LogicalResult InsertOp::verify() { return success(); } +// Calculate the linearized position of the continuous chunk of elements to +// insert, based on the shape of the value to insert and the positions to insert +// at. 
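+// For example (hypothetical shapes): with a destination of type
+// vector<2x3x4xf32>, computeStrides gives [12, 4, 1]; inserting at positions
+// [1, 2] pads them to completePositions = [1, 2, 0], which linearizes to
+// 1 * 12 + 2 * 4 + 0 * 1 = 20.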
+static int64_t calculateInsertPosition(VectorType destTy, + ArrayRef<int64_t> positions) { + llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0); + assert(positions.size() <= completePositions.size() && + "positions size must be less than or equal to destTy rank"); + copy(positions, completePositions.begin()); + return linearize(completePositions, computeStrides(destTy.getShape())); +} + namespace { // If insertOp is only inserting unit dimensions it can be transformed to a @@ -3275,6 +3339,132 @@ public: return success(); } }; + +/// Pattern to optimize a chain of insertions. +/// +/// This pattern identifies chains of vector.insert operations that: +/// 1. Only insert values at static positions. +/// 2. Completely initialize all elements in the resulting vector. +/// 3. All intermediate insert operations have only one use. +/// +/// When these conditions are met, the entire chain can be replaced with a +/// single vector.from_elements operation. +/// +/// To keep this pattern simple, and avoid spending too much time on matching +/// fragmented insert chains, this pattern only considers the last insert op in +/// the chain. +/// +/// Example transformation: +/// %poison = ub.poison : vector<2xi32> +/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32> +/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32> +/// -> +/// %result = vector.from_elements %c1, %c2 : vector<2xi32> +class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(InsertOp op, + PatternRewriter &rewriter) const override { + + VectorType destTy = op.getDestVectorType(); + if (destTy.isScalable()) + return failure(); + // Ensure this is the trailing vector.insert op in a chain of inserts. + for (Operation *user : op.getResult().getUsers()) + if (auto insertOp = dyn_cast<InsertOp>(user)) + if (insertOp.getDest() == op.getResult()) + return failure(); + + InsertOp currentOp = op; + SmallVector<InsertOp> chainInsertOps; + while (currentOp) { + // Check cond 1: Dynamic position is not supported. + if (currentOp.hasDynamicPosition()) + return failure(); + + chainInsertOps.push_back(currentOp); + currentOp = currentOp.getDest().getDefiningOp<InsertOp>(); + // Check cond 3: Intermediate inserts have only one use to avoid an + // explosion of vectors. + if (currentOp && !currentOp->hasOneUse()) + return failure(); + } + + int64_t vectorSize = destTy.getNumElements(); + int64_t initializedCount = 0; + SmallVector<bool> initializedDestIdxs(vectorSize, false); + SmallVector<int64_t> pendingInsertPos; + SmallVector<int64_t> pendingInsertSize; + SmallVector<Value> pendingInsertValues; + + for (auto insertOp : chainInsertOps) { + // This pattern can do nothing with poison index. + if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex)) + return failure(); + + // Calculate the linearized position for inserting elements. + int64_t insertBeginPosition = + calculateInsertPosition(destTy, insertOp.getStaticPosition()); + + // The valueToStore operand may be a vector or a scalar. Need to handle + // both cases. 
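+ // For example (hypothetical values): a scalar contributes a single
+ // element, while a vector contributes a contiguous block:
+ //   %a = vector.insert %f, %dest[1, 2] : f32 into vector<2x3xf32>  // insertSize = 1
+ //   %b = vector.insert %v, %dest[1] : vector<3xf32> into vector<2x3xf32>  // insertSize = 3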
+ int64_t insertSize = 1; + if (auto srcVectorType = + llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType())) + insertSize = srcVectorType.getNumElements(); + + assert(insertBeginPosition + insertSize <= vectorSize && + "insert would overflow the vector"); + + for (auto index : llvm::seq<int64_t>(insertBeginPosition, + insertBeginPosition + insertSize)) { + if (initializedDestIdxs[index]) + continue; + initializedDestIdxs[index] = true; + ++initializedCount; + } + + // Defer the creation of ops before we can make sure the pattern can + // succeed. + pendingInsertPos.push_back(insertBeginPosition); + pendingInsertSize.push_back(insertSize); + pendingInsertValues.push_back(insertOp.getValueToStore()); + + if (initializedCount == vectorSize) + break; + } + + // Check cond 2: all positions must be initialized. + if (initializedCount != vectorSize) + return failure(); + + SmallVector<Value> elements(vectorSize); + for (auto [insertBeginPosition, insertSize, valueToStore] : + llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize, + pendingInsertValues))) { + auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType()); + + if (!srcVectorType) { + elements[insertBeginPosition] = valueToStore; + continue; + } + + SmallVector<Type> elementToInsertTypes(insertSize, + srcVectorType.getElementType()); + // Get all elements from the vector in row-major order. + auto elementsToInsert = rewriter.create<vector::ToElementsOp>( + op.getLoc(), elementToInsertTypes, valueToStore); + for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) { + elements[insertBeginPosition + linearIdx] = + elementsToInsert.getResult(linearIdx); + } + } + + rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements); + return success(); + } +}; + } // namespace static Attribute @@ -3301,13 +3491,9 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, !insertOp->hasOneUse()) return {}; - // Calculate the linearized position of the continuous chunk of elements to - // insert. - llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0); - copy(insertOp.getStaticPosition(), completePositions.begin()); + // Calculate the linearized position for inserting elements. 
int64_t insertBeginPosition = - linearize(completePositions, computeStrides(destTy.getShape())); - + calculateInsertPosition(destTy, insertOp.getStaticPosition()); SmallVector<Attribute> insertedValues; Type destEltType = destTy.getElementType(); @@ -3343,7 +3529,8 @@ static Value foldInsertUseChain(InsertOp insertOp) { void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context); + results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat, + InsertChainFullyInitialized>(context); } OpFoldResult InsertOp::fold(FoldAdaptor adaptor) { @@ -5599,7 +5786,7 @@ LogicalResult GatherOp::verify() { if (resVType.getElementType() != baseType.getElementType()) return emitOpError("base and result element type should match"); - if (llvm::size(getIndices()) != baseType.getRank()) + if (llvm::size(getOffsets()) != baseType.getRank()) return emitOpError("requires ") << baseType.getRank() << " indices"; if (resVType.getShape() != indVType.getShape()) return emitOpError("expected result dim to match indices dim"); @@ -5671,11 +5858,11 @@ public: if (!isa<MemRefType>(op.getBase().getType())) return rewriter.notifyMatchFailure(op, "base must be of memref type"); - if (failed(isZeroBasedContiguousSeq(op.getIndexVec()))) + if (failed(isZeroBasedContiguousSeq(op.getIndices()))) return failure(); rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(), - op.getIndices(), op.getMask(), + op.getOffsets(), op.getMask(), op.getPassThru()); return success(); } @@ -5699,7 +5886,7 @@ LogicalResult ScatterOp::verify() { if (valueVType.getElementType() != memType.getElementType()) return emitOpError("base and valueToStore element type should match"); - if (llvm::size(getIndices()) != memType.getRank()) + if (llvm::size(getOffsets()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getShape() != indVType.getShape()) return emitOpError("expected valueToStore dim to match indices dim"); @@ -5734,11 +5921,11 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ScatterOp op, PatternRewriter &rewriter) const override { - if (failed(isZeroBasedContiguousSeq(op.getIndexVec()))) + if (failed(isZeroBasedContiguousSeq(op.getIndices()))) return failure(); rewriter.replaceOpWithNewOp<MaskedStoreOp>( - op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore()); + op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore()); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 2d5cc07..fe066dc 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns( vector::populateVectorGatherLoweringPatterns(patterns); } +void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorFromElementsLoweringPatterns(patterns); +} + void transform::ApplyLowerScanPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorScanLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 6619619..546099c 100644 --- 
a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -162,7 +162,7 @@ struct GatherOpInterface return failure(); replaceOpWithNewBufferizedOp<vector::GatherOp>( rewriter, gatherOp, gatherOp.getVectorType(), *buffer, - gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(), + gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(), gatherOp.getPassThru()); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 9e287fc..acbf2b7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorBitCast.cpp LowerVectorBroadcast.cpp LowerVectorContract.cpp + LowerVectorFromElements.cpp LowerVectorGather.cpp LowerVectorInterleave.cpp LowerVectorMask.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp new file mode 100644 index 0000000..c22fd54 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp @@ -0,0 +1,65 @@ +//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.from_elements' operation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" + +#define DEBUG_TYPE "lower-vector-from-elements" + +using namespace mlir; + +namespace { + +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. For example: +/// ``` +/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> +/// +/// ==> +/// +/// %0 = ub.poison : vector<2x3xf32> +/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> +/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> +/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> +/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d from_elements +/// ops. 
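+/// For example (illustrative), a vector<2x2x2xf32> from_elements takes two
+/// applications: the first produces two vector<2x2xf32> from_elements ops,
+/// and the second splits each of those into 1-D ops:
+/// ```
+/// %r0 = vector.from_elements %e0, %e1 : vector<2xf32>
+/// %r1 = vector.from_elements %e2, %e3 : vector<2xf32>
+/// %r2 = vector.from_elements %e4, %e5 : vector<2xf32>
+/// %r3 = vector.from_elements %e6, %e7 : vector<2xf32>
+/// ```
+/// with the intermediate results reassembled by vector.insert ops.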
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::FromElementsOp op, + PatternRewriter &rewriter) const override { + ValueRange allElements = op.getElements(); + + auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + size_t subTyNumElements = subTy.getNumElements(); + assert((index + 1) * subTyNumElements <= allElements.size() && + "out of bounds"); + ValueRange subElements = + allElements.slice(index * subTyNumElements, subTyNumElements); + return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); + }; + + return unrollVectorOp(op, rewriter, unrollFromElementsFn); + } +}; + +} // namespace + +void mlir::vector::populateVectorFromElementsLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add<UnrollFromElements>(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index e062f55..9830189 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -54,27 +54,13 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> { LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { - VectorType resultTy = op.getType(); - if (resultTy.getRank() < 2) - return rewriter.notifyMatchFailure(op, "already 1-D"); - - // Unrolling doesn't take vscale into account. Pattern is disabled for - // vectors with leading scalable dim(s). - if (resultTy.getScalableDims().front()) - return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); - - Location loc = op.getLoc(); - Value indexVec = op.getIndexVec(); + Value indexVec = op.getIndices(); Value maskVec = op.getMask(); Value passThruVec = op.getPassThru(); - Value result = arith::ConstantOp::create(rewriter, loc, resultTy, - rewriter.getZeroAttr(resultTy)); - - VectorType subTy = VectorType::Builder(resultTy).dropDim(0); - - for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { - int64_t thisIdx[1] = {i}; + auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + int64_t thisIdx[1] = {index}; Value indexSubVec = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx); @@ -82,15 +68,12 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> { vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx); Value passThruSubVec = vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx); - Value subGather = vector::GatherOp::create( - rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec, - maskSubVec, passThruSubVec); - result = - vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx); - } + return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(), + op.getOffsets(), indexSubVec, maskSubVec, + passThruSubVec); + }; - rewriter.replaceOp(op, result); - return success(); + return unrollVectorOp(op, rewriter, unrollGatherFn); } }; @@ -158,18 +141,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> { // 2. Generate new gather indices that will model the // strided access. 
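// For example (hypothetical): with a trailing stride of 4 in the original
// source, an index vector [0, 1, 2] is rescaled to [0, 4, 8] so the gather
// over the collapsed, contiguous memref reads the same elements.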
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim); - VectorType vType = op.getIndexVec().getType(); + VectorType vType = op.getIndices().getType(); Value mulCst = arith::ConstantOp::create( rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride)); Value newIdxs = - arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst); + arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst); // 3. Create an updated gather op with the collapsed input memref and the // updated indices. Value newGather = vector::GatherOp::create( rewriter, op.getLoc(), op.getResult().getType(), collapsed, - op.getIndices(), newIdxs, op.getMask(), op.getPassThru()); + op.getOffsets(), newIdxs, op.getMask(), op.getPassThru()); rewriter.replaceOp(op, newGather); return success(); @@ -212,8 +195,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> { Value indexVec = rewriter.createOrFold<arith::IndexCastOp>( loc, op.getIndexVectorType().clone(rewriter.getIndexType()), - op.getIndexVec()); - auto baseOffsets = llvm::to_vector(op.getIndices()); + op.getIndices()); + auto baseOffsets = llvm::to_vector(op.getOffsets()); Value lastBaseOffset = baseOffsets.back(); Value result = op.getPassThru(); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 45ef7f0..5617b06 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -269,7 +269,7 @@ public: // Replace the `vector.mask` operation. rewriter.replaceOpWithNewOp<GatherOp>( maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(), - gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(), + gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(), passthru); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index bb0f339..c84eb2c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -528,8 +528,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern { LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); Operation *lastNode = yield->getPrevNode(); auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode); if (!writeOp) @@ -706,6 +705,52 @@ struct WarpOpConstant : public WarpDistributionPattern { } }; +/// Sink out step op feeding into a warp op yield. +/// Vector step op is treated similarly to arith.constant, apart from +/// the result that represents a sequence [0, vec_size). +/// Due to the vec_size == warp_size limitation, +/// we can simply wrap the lane id into a vector (i.e., broadcast). +/// Supporting vec_size != warp_size may involve preserving the step +/// result and using additional arith ops (the exact details are TBD). +/// ``` +/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) { +/// ... +/// %cst = vector.step : vector<32xindex> +/// gpu.yield %cst : vector<32xindex> +/// } +/// ``` +/// To +/// ``` +/// gpu.warp_execute_on_lane_0(%arg0) { +/// ... 
+/// } +/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex> +struct WarpOpStep final : public WarpDistributionPattern { + using Base::Base; + LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *yieldOperand = + getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>); + if (!yieldOperand) + return failure(); + const unsigned operandIdx = yieldOperand->getOperandNumber(); + auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>(); + VectorType resTy = stepOp.getResult().getType(); + if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize())) + return rewriter.notifyMatchFailure( + warpOp, + llvm::formatv("Expected result size ({0}) to be of warp size ({1})", + resTy.getNumElements(), warpOp.getWarpSize())); + VectorType newVecTy = + cast<VectorType>(warpOp.getResult(operandIdx).getType()); + rewriter.setInsertionPointAfter(warpOp); + Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(), + newVecTy, warpOp.getLaneid()); + rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec); + return success(); + } +}; + /// Sink out transfer_read op feeding into a warp op yield. /// ``` /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { @@ -846,8 +891,7 @@ struct WarpOpDeadResult : public WarpDistributionPattern { newYieldValues.reserve(warpOp->getNumResults()); DenseMap<Value, int64_t> dedupYieldOperandPositionMap; DenseMap<OpResult, int64_t> dedupResultPositionMap; - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); // Some values may be yielded multiple times and correspond to multiple // results. Deduplicating occurs by taking each result with its matching @@ -901,8 +945,7 @@ struct WarpOpForwardOperand : public WarpDistributionPattern { using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); Value valForwarded; unsigned resultIndex; for (OpOperand &operand : yield->getOpOperands()) { @@ -1708,8 +1751,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto warpOpYield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp warpOpYield = warpOp.getTerminator(); // Only pick up `ForOp` if it is the last op in the region. Operation *lastNode = warpOpYield->getPrevNode(); auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode); @@ -1826,7 +1868,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); auto newForOp = scf::ForOp::create( rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newForOpOperands); + forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr, + forOp.getUnsignedCmp()); // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the // newly created `ForOp`. This `WarpOp` will contain all ops that were // contained within the original `ForOp` body. 
@@ -2019,7 +2062,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask, - WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>( + WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>( patterns.getContext(), benefit); patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn, benefit); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 491b448..7dde631 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -762,6 +762,42 @@ struct LinearizeVectorStore final } }; +/// This pattern linearizes `vector.from_elements` operations by converting +/// the result type to a 1-D vector while preserving all element values. +/// The transformation creates a linearized `vector.from_elements` followed by +/// a `vector.shape_cast` to restore the original multidimensional shape. +/// +/// Example: +/// +/// %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32> +/// +/// is converted to: +/// +/// %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32> +/// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32> +/// +struct LinearizeVectorFromElements final + : public OpConversionPattern<vector::FromElementsOp> { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorFromElements(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + LogicalResult + matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType dstTy = + getTypeConverter()->convertType<VectorType>(fromElementsOp.getType()); + assert(dstTy && "vector type destination expected."); + + OperandRange elements = fromElementsOp.getElements(); + assert(elements.size() == static_cast<size_t>(dstTy.getNumElements()) && + "expected same number of elements"); + rewriter.replaceOpWithNewOp<vector::FromElementsOp>(fromElementsOp, dstTy, + elements); + return success(); + } +}; + } // namespace /// This method defines the set of operations that are linearizable, and hence @@ -854,7 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( patterns .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast, LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad, - LinearizeVectorStore>(typeConverter, patterns.getContext()); + LinearizeVectorStore, LinearizeVectorFromElements>( + typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index c707f38..369857f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -98,8 +98,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { // If the user has already been processed skip. 
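// For example (hypothetical): a memref.subview between the write and a later
// read is handled by following the users of its view destination below, so
// the aliasing access is still visited.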
if (!processed.insert(user).second) continue; - if (isa<ViewLikeOpInterface>(user)) { - users.append(user->getUsers().begin(), user->getUsers().end()); + if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) { + Value viewDest = viewLike.getViewDest(); + users.append(viewDest.getUsers().begin(), viewDest.getUsers().end()); continue; } if (isMemoryEffectFree(user)) @@ -182,8 +183,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { // If the user has already been processed skip. if (!processed.insert(user).second) continue; - if (isa<ViewLikeOpInterface>(user)) { - users.append(user->getUsers().begin(), user->getUsers().end()); + if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) { + Value viewDest = viewLike.getViewDest(); + users.append(viewDest.getUsers().begin(), viewDest.getUsers().end()); continue; } if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 2269a40..dbb5eb3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -600,7 +600,7 @@ struct BubbleDownVectorBitCastForExtract // Get the first element of the mixed position as integer. auto mixedPos = extractOp.getMixedPosition(); - if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0])) + if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0])) return failure(); uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt(); @@ -2274,7 +2274,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> { LogicalResult matchAndRewrite(MulOpType mulOp, PatternRewriter &rewriter) const override { - auto resType = llvm::cast<VectorType>(mulOp.getResult().getType()); + auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType()); if (!resType) return failure(); if (resType.getRank() != 2) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 501abec..e8ecb0c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -640,7 +640,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> { // decomposed shape from each of the index, mask, and pass-through // vectors. 
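// For example (hypothetical): unrolling a gather producing vector<2x4xf32>
// with target shape [1, 4] extracts vector<1x4> slices of the index, mask,
// and pass-through vectors at element offsets [0, 0] and [1, 0].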
Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( - loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides); + loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides); Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( loc, gatherOp.getMask(), elementOffsets, *targetShape, strides); Value passThruSubVec = @@ -648,7 +648,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> { loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides); auto slicedGather = vector::GatherOp::create( - rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), + rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(), indexSubVec, maskSubVec, passThruSubVec); result = rewriter.createOrFold<vector::InsertStridedSliceOp>( diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 10ed2bc..841e138 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -279,14 +279,16 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) { // Attempt to unroll until targetRank or the first scalable dimension (which // cannot be unrolled). auto shapeToUnroll = vType.getShape().drop_back(targetRank); - auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank); - auto it = llvm::find(scalableDimsToUnroll, true); - auto firstScalableDim = it - scalableDimsToUnroll.begin(); + auto inputScalableVecDimsToUnroll = + vType.getScalableDims().drop_back(targetRank); + auto it = llvm::find(inputScalableVecDimsToUnroll, true); + auto firstScalableDim = it - inputScalableVecDimsToUnroll.begin(); if (firstScalableDim == 0) return {}; // All scalable dimensions should be removed now. - scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim); - assert(!llvm::is_contained(scalableDimsToUnroll, true) && + inputScalableVecDimsToUnroll = + inputScalableVecDimsToUnroll.slice(0, firstScalableDim); + assert(!llvm::is_contained(inputScalableVecDimsToUnroll, true) && "unexpected leading scalable dimension"); // Create an unroll iterator for leading dimensions. shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim); @@ -319,15 +321,15 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, ArrayRef<int64_t> inputVectorSizes, Value padValue, bool useInBoundsInsteadOfMasking, - ArrayRef<bool> scalableDims) { + ArrayRef<bool> inputScalableVecDims) { assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) && "invalid input vector sizes"); auto sourceShapedType = cast<ShapedType>(source.getType()); auto sourceShape = sourceShapedType.getShape(); assert(sourceShape.size() == inputVectorSizes.size() && "expected same ranks."); - auto vectorType = - VectorType::get(inputVectorSizes, padValue.getType(), scalableDims); + auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(), + inputScalableVecDims); assert(padValue.getType() == sourceShapedType.getElementType() && "expected same pad element type to match source element type"); int64_t readRank = inputVectorSizes.size(); @@ -356,8 +358,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, ? 
memref::getMixedSizes(builder, loc, source) : tensor::getMixedSizes(builder, loc, source); - auto maskType = - VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims); + auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(), + inputScalableVecDims); Value mask = vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims); return mlir::vector::maskOperation(builder, transferReadOp, mask) @@ -385,9 +387,34 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape, staticSize <= inputSize; })) { LDBG() << "Input vector sizes must be greater than or equal to iteration " - "space " - "static sizes"; + "space static sizes"; return failure(); } return success(); } + +LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, + vector::UnrollVectorOpFn unrollFn) { + assert(op->getNumResults() == 1 && "expected single result"); + assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type"); + VectorType resultTy = cast<VectorType>(op->getResult(0).getType()); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(op, "already 1-D"); + + // Unrolling doesn't take vscale into account. Pattern is disabled for + // vectors with leading scalable dim(s). + if (resultTy.getScalableDims().front()) + return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); + + Location loc = op->getLoc(); + Value result = ub::PoisonOp::create(rewriter, loc, resultTy); + VectorType subTy = VectorType::Builder(resultTy).dropDim(0); + + for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { + Value subVector = unrollFn(rewriter, loc, subTy, i); + result = vector::InsertOp::create(rewriter, loc, subVector, result, i); + } + + rewriter.replaceOp(op, result); + return success(); +} diff --git a/mlir/lib/Dialect/WasmSSA/CMakeLists.txt b/mlir/lib/Dialect/WasmSSA/CMakeLists.txt new file mode 100644 index 0000000..f33061b2 --- /dev/null +++ b/mlir/lib/Dialect/WasmSSA/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/WasmSSA/IR/CMakeLists.txt b/mlir/lib/Dialect/WasmSSA/IR/CMakeLists.txt new file mode 100644 index 0000000..9fc2d7b --- /dev/null +++ b/mlir/lib/Dialect/WasmSSA/IR/CMakeLists.txt @@ -0,0 +1,24 @@ +add_mlir_dialect_library(MLIRWasmSSADialect + WasmSSAOps.cpp + WasmSSADialect.cpp + WasmSSAInterfaces.cpp + WasmSSATypes.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/WasmSSA + + DEPENDS + MLIRWasmSSAOpsIncGen + MLIRWasmSSAInterfacesIncGen + + LINK_LIBS PUBLIC + MLIRCastInterfaces + MLIRDataLayoutInterfaces + MLIRDialect + MLIRInferTypeOpInterface + MLIRIR + MLIRSupport + + PRIVATE + MLIRFunctionInterfaces + ) diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSADialect.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSADialect.cpp new file mode 100644 index 0000000..98c3555 --- /dev/null +++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSADialect.cpp @@ -0,0 +1,38 @@ +//===- WebAssemblyDialect.cpp - MLIR WebAssembly dialect implementation ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" + +#include "llvm/ADT/TypeSwitch.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Support/LLVM.h" + +using namespace mlir; +using namespace mlir::wasmssa; + +#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsDialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// TableGen'd types definitions +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsTypes.cpp.inc" + +void wasmssa::WasmSSADialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsTypes.cpp.inc" + >(); +} diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp new file mode 100644 index 0000000..61cdf6f --- /dev/null +++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp @@ -0,0 +1,69 @@ +//===- WasmSSAInterfaces.cpp - WasmSSA Interfaces -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines op interfaces for the WasmSSA dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h" +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/LogicalResult.h" + +namespace mlir::wasmssa { +#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp.inc" + +namespace detail { +LogicalResult verifyLabelBranchingOpInterface(Operation *op) { + auto branchInterface = dyn_cast<LabelBranchingOpInterface>(op); + llvm::FailureOr<LabelLevelOpInterface> res = + LabelBranchingOpInterface::getTargetOpFromBlock( + op->getBlock(), branchInterface.getExitLevel()); + return res; +} + +LogicalResult verifyConstantExpressionInterface(Operation *op) { + Region &initializerRegion = op->getRegion(0); + WalkResult resultState = + initializerRegion.walk([&](Operation *currentOp) -> WalkResult { + if (isa<ReturnOp>(currentOp) || + currentOp->hasTrait<ConstantExprOpTrait>()) + return WalkResult::advance(); + op->emitError("expected a constant initializer for this operator, got ") + << currentOp; + return WalkResult::interrupt(); + }); + return success(!resultState.wasInterrupted()); +} + +LogicalResult verifyLabelLevelInterface(Operation *op) { + Block *target = cast<LabelLevelOpInterface>(op).getLabelTarget(); + Region *targetRegion = target->getParent(); + if (targetRegion != op->getParentRegion() && + targetRegion->getParentOp() != op) + return op->emitError("target should be a block defined at the same level as " + "the operation or in its region."); + return success(); +} +} // namespace detail + +llvm::FailureOr<LabelLevelOpInterface> +LabelBranchingOpInterface::getTargetOpFromBlock(::mlir::Block *block, + uint32_t breakLevel) { + LabelLevelOpInterface res{}; + for (size_t curLevel{0}; curLevel <= 
breakLevel; curLevel++) { + res = dyn_cast_or_null<LabelLevelOpInterface>(block->getParentOp()); + if (!res) + return failure(); + block = res->getBlock(); + } + return res; +} +} // namespace mlir::wasmssa diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp new file mode 100644 index 0000000..89b62a2 --- /dev/null +++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp @@ -0,0 +1,494 @@ +//===- WasmSSAOps.cpp - WasmSSA dialect operations ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" +#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "llvm/Support/Casting.h" + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +using namespace mlir; +namespace { +ParseResult parseElseRegion(OpAsmParser &opParser, Region &elseRegion) { + std::string keyword; + std::ignore = opParser.parseOptionalKeywordOrString(&keyword); + if (keyword == "else") + return opParser.parseRegion(elseRegion); + return ParseResult::success(); +} + +void printElseRegion(OpAsmPrinter &opPrinter, Operation *op, + Region &elseRegion) { + if (elseRegion.empty()) + return; + opPrinter.printKeywordOrString("else "); + opPrinter.printRegion(elseRegion); +} + +ParseResult parseWasmVisibility(OpAsmParser &opParser, StringAttr &visibility) { + std::string keyword; + auto initLocation = opParser.getCurrentLocation(); + std::ignore = opParser.parseOptionalKeywordOrString(&keyword); + if (keyword == "nested" or keyword == "") { + visibility = StringAttr::get(opParser.getContext(), "nested"); + return ParseResult::success(); + } + + if (keyword == "public" || keyword == "private") { + visibility = StringAttr::get(opParser.getContext(), keyword); + return ParseResult::success(); + } + opParser.emitError(initLocation, "expecting symbol visibility"); + return ParseResult::failure(); +} + +void printWasmVisibility(OpAsmPrinter &opPrinter, Operation *op, + Attribute visibility) { + opPrinter.printKeywordOrString(cast<StringAttr>(visibility).strref()); +} +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.cpp.inc" + +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Types.h" +#include "llvm/Support/LogicalResult.h" + +using namespace wasmssa; + +namespace { +inline LogicalResult +inferTeeGetResType(ValueRange operands, + SmallVectorImpl<Type> &inferredReturnTypes) { + if (operands.empty()) + return failure(); + auto opType = dyn_cast<LocalRefType>(operands.front().getType()); + if (!opType) + return failure(); + inferredReturnTypes.push_back(opType.getElementType()); + return success(); +} + +ParseResult parseImportOp(OpAsmParser &parser, OperationState &result) { + std::string importName; + auto *ctx = parser.getContext(); + ParseResult res = parser.parseString(&importName); + result.addAttribute("importName", StringAttr::get(ctx, 
importName)); + + std::string fromStr; + res = parser.parseKeywordOrString(&fromStr); + if (failed(res) || fromStr != "from") + return failure(); + + std::string moduleName; + res = parser.parseString(&moduleName); + if (failed(res)) + return failure(); + result.addAttribute("moduleName", StringAttr::get(ctx, moduleName)); + + std::string asStr; + res = parser.parseKeywordOrString(&asStr); + if (failed(res) || asStr != "as") + return failure(); + + StringAttr symbolName; + res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(), + result.attributes); + return res; +} +} // namespace + +//===----------------------------------------------------------------------===// +// BlockOp +//===----------------------------------------------------------------------===// + +Block *BlockOp::getLabelTarget() { return getTarget(); } + +//===----------------------------------------------------------------------===// +// BlockReturnOp +//===----------------------------------------------------------------------===// + +std::size_t BlockReturnOp::getExitLevel() { return 0; } + +Block *BlockReturnOp::getTarget() { + return cast<LabelBranchingOpInterface>(getOperation()) + .getTargetOp() + .getOperation() + ->getSuccessor(0); +} + +//===----------------------------------------------------------------------===// +// ExtendLowBitsSOp +//===----------------------------------------------------------------------===// + +LogicalResult ExtendLowBitsSOp::verify() { + auto bitsToTake = getBitsToTake().getValue().getLimitedValue(); + if (bitsToTake != 32 && bitsToTake != 16 && bitsToTake != 8) + return emitError("extend op can only take 8, 16 or 32 bits. Got ") + << bitsToTake; + + if (bitsToTake >= getInput().getType().getIntOrFloatBitWidth()) + return emitError("trying to extend the ") + << bitsToTake << " low bits from a " << getInput().getType() + << " value is illegal"; + return success(); +} + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +Block *FuncOp::addEntryBlock() { + if (!getBody().empty()) { + emitError("adding entry block to a FuncOp which already has one"); + return &getBody().front(); + } + Block &block = getBody().emplaceBlock(); + for (auto argType : getFunctionType().getInputs()) + block.addArgument(LocalRefType::get(argType), getLoc()); + return █ +} + +void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef symbol, FunctionType funcType) { + FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}, "nested"); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = [&parser](Builder &builder, ArrayRef<Type> argTypes, + ArrayRef<Type> results, + function_interface_impl::VariadicFlag, + std::string &) { + SmallVector<Type> argTypesWithoutLocal{}; + argTypesWithoutLocal.reserve(argTypes.size()); + llvm::for_each(argTypes, [&parser, &argTypesWithoutLocal](Type argType) { + auto refType = dyn_cast<LocalRefType>(argType); + auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); + if (!refType) { + mlir::emitError(loc, "invalid type for wasm.func argument. 
Expecting " + "!wasm<local T>, got ") + << argType; + return; + } + argTypesWithoutLocal.push_back(refType.getElementType()); + }); + + return builder.getFunctionType(argTypesWithoutLocal, results); + }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +LogicalResult FuncOp::verifyBody() { + if (getBody().empty()) + return success(); + Block &entry = getBody().front(); + if (entry.getNumArguments() != getFunctionType().getNumInputs()) + return emitError("entry block should have same number of arguments as " + "function type. Function type has ") + << getFunctionType().getNumInputs() << ", entry block has " + << entry.getNumArguments(); + + for (auto [argNo, funcSignatureType, blockType] : llvm::enumerate( + getFunctionType().getInputs(), entry.getArgumentTypes())) { + auto blockLocalRefType = dyn_cast<LocalRefType>(blockType); + if (!blockLocalRefType) + return emitError("entry block argument type should be LocalRefType, got ") + << blockType << " for block argument " << argNo; + if (blockLocalRefType.getElementType() != funcSignatureType) + return emitError("func argument type #") + << argNo << "(" << funcSignatureType + << ") doesn't match entry block referenced type (" + << blockLocalRefType.getElementType() << ")"; + } + return success(); +} + +void FuncOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// FuncImportOp +//===----------------------------------------------------------------------===// + +void FuncImportOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef symbol, StringRef moduleName, + StringRef importName, FunctionType type) { + FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, + type, {}, {}, odsBuilder.getStringAttr("nested")); +} + +//===----------------------------------------------------------------------===// +// GlobalOp +//===----------------------------------------------------------------------===// + +void GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef symbol, Type type, bool isMutable) { + GlobalOp::build(odsBuilder, odsState, symbol, type, isMutable, + odsBuilder.getStringAttr("nested")); +} + +// Custom formats +ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { + StringAttr symbolName; + Type globalType; + auto *ctx = parser.getContext(); + ParseResult res = parser.parseSymbolName( + symbolName, SymbolTable::getSymbolAttrName(), result.attributes); + + res = parser.parseType(globalType); + result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType)); + std::string mutableString; + res = parser.parseOptionalKeywordOrString(&mutableString); + if (res.succeeded() && mutableString == "mutable") + result.addAttribute("isMutable", UnitAttr::get(ctx)); + std::string visibilityString; + res = parser.parseOptionalKeywordOrString(&visibilityString); + if (res.succeeded()) + result.addAttribute("sym_visibility", + StringAttr::get(ctx, visibilityString)); + res = parser.parseColon(); + Region *globalInitRegion = result.addRegion(); + res = parser.parseRegion(*globalInitRegion); + return res; +} + +void GlobalOp::print(OpAsmPrinter &printer) { + printer << " @" << 
getSymName().str() << " " << getType(); + if (getIsMutable()) + printer << " mutable"; + if (auto vis = getSymVisibility()) + printer << " " << *vis; + printer << " :"; + Region &body = getRegion(); + if (!body.empty()) { + printer << ' '; + printer.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } +} + +//===----------------------------------------------------------------------===// +// GlobalGetOp +//===----------------------------------------------------------------------===// + +LogicalResult +GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // If the parent requires a constant context, verify that global.get is a + // constant as defined per the wasm standard. + if (!this->getOperation() + ->getParentWithTrait<ConstantExpressionInitializerOpTrait>()) + return success(); + Operation *symTabOp = SymbolTable::getNearestSymbolTable(*this); + StringRef referencedSymbol = getGlobal(); + Operation *definitionOp = symbolTable.lookupSymbolIn( + symTabOp, StringAttr::get(this->getContext(), referencedSymbol)); + if (!definitionOp) + return emitError() << "symbol @" << referencedSymbol << " is undefined"; + auto definitionImport = dyn_cast<GlobalImportOp>(definitionOp); + if (!definitionImport || definitionImport.getIsMutable()) { + return emitError("global.get op is considered constant if it's referring " + "to a import.global symbol marked non-mutable"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// GlobalImportOp +//===----------------------------------------------------------------------===// + +void GlobalImportOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef symbol, StringRef moduleName, + StringRef importName, Type type, bool isMutable) { + GlobalImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, + type, isMutable, odsBuilder.getStringAttr("nested")); +} + +ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) { + auto *ctx = parser.getContext(); + ParseResult res = parseImportOp(parser, result); + if (res.failed()) + return failure(); + std::string mutableOrSymVisString; + res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString); + if (res.succeeded() && mutableOrSymVisString == "mutable") { + result.addAttribute("isMutable", UnitAttr::get(ctx)); + res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString); + } + + if (res.succeeded()) + result.addAttribute("sym_visibility", + StringAttr::get(ctx, mutableOrSymVisString)); + res = parser.parseColon(); + + Type importedType; + res = parser.parseType(importedType); + if (res.succeeded()) + result.addAttribute(getTypeAttrName(result.name), + TypeAttr::get(importedType)); + return res; +} + +void GlobalImportOp::print(OpAsmPrinter &printer) { + printer << " \"" << getImportName() << "\" from \"" << getModuleName() + << "\" as @" << getSymName(); + if (getIsMutable()) + printer << " mutable"; + if (auto vis = getSymVisibility()) + printer << " " << *vis; + printer << " : " << getType(); +} + +//===----------------------------------------------------------------------===// +// IfOp +//===----------------------------------------------------------------------===// + +Block *IfOp::getLabelTarget() { return getTarget(); } + +//===----------------------------------------------------------------------===// +// LocalOp +//===----------------------------------------------------------------------===// + +LogicalResult LocalOp::inferReturnTypes( + 
MLIRContext *context, ::std::optional<Location> location, + ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, + RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) { + LocalOp::GenericAdaptor<ValueRange> adaptor{operands, attributes, properties, + regions}; + auto type = adaptor.getTypeAttr(); + if (!type) + return failure(); + auto resType = LocalRefType::get(type.getContext(), type.getValue()); + inferredReturnTypes.push_back(resType); + return success(); +} + +//===----------------------------------------------------------------------===// +// LocalGetOp +//===----------------------------------------------------------------------===// + +LogicalResult LocalGetOp::inferReturnTypes( + MLIRContext *context, ::std::optional<Location> location, + ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, + RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) { + return inferTeeGetResType(operands, inferredReturnTypes); +} + +//===----------------------------------------------------------------------===// +// LocalSetOp +//===----------------------------------------------------------------------===// + +LogicalResult LocalSetOp::verify() { + if (getLocalVar().getType().getElementType() != getValue().getType()) + return emitError("input type and result type of local.set do not match"); + return success(); +} + +//===----------------------------------------------------------------------===// +// LocalTeeOp +//===----------------------------------------------------------------------===// + +LogicalResult LocalTeeOp::inferReturnTypes( + MLIRContext *context, ::std::optional<Location> location, + ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, + RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) { + return inferTeeGetResType(operands, inferredReturnTypes); +} + +LogicalResult LocalTeeOp::verify() { + if (getLocalVar().getType().getElementType() != getValue().getType() || + getValue().getType() != getResult().getType()) + return emitError("input type and output type of local.tee do not match"); + return success(); +} + +//===----------------------------------------------------------------------===// +// LoopOp +//===----------------------------------------------------------------------===// + +Block *LoopOp::getLabelTarget() { return &getBody().front(); } + +//===----------------------------------------------------------------------===// +// MemOp +//===----------------------------------------------------------------------===// + +void MemOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef symbol, LimitType limit) { + MemOp::build(odsBuilder, odsState, symbol, limit, + odsBuilder.getStringAttr("nested")); +} + +//===----------------------------------------------------------------------===// +// MemImportOp +//===----------------------------------------------------------------------===// + +void MemImportOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef symbol, StringRef moduleName, + StringRef importName, LimitType limits) { + MemImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, + limits, odsBuilder.getStringAttr("nested")); +} + +//===----------------------------------------------------------------------===// +// ReinterpretOp +//===----------------------------------------------------------------------===// + +LogicalResult ReinterpretOp::verify() { + auto inT = getInput().getType(); + auto resT = getResult().getType(); + if 
(inT == resT)
+    return emitError("reinterpret input and output type should be distinct");
+  if (inT.getIntOrFloatBitWidth() != resT.getIntOrFloatBitWidth())
+    return emitError() << "input type (" << inT << ") and output type (" << resT
+                       << ") have incompatible bit widths";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState) {}
+
+//===----------------------------------------------------------------------===//
+// TableOp
+//===----------------------------------------------------------------------===//
+
+void TableOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                    StringRef symbol, TableType type) {
+  TableOp::build(odsBuilder, odsState, symbol, type,
+                 odsBuilder.getStringAttr("nested"));
+}
+
+//===----------------------------------------------------------------------===//
+// TableImportOp
+//===----------------------------------------------------------------------===//
+
+void TableImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                          StringRef symbol, StringRef moduleName,
+                          StringRef importName, TableType type) {
+  TableImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
+                       type, odsBuilder.getStringAttr("nested"));
+}
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp
new file mode 100644
index 0000000..bee8c81
--- /dev/null
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp
@@ -0,0 +1,18 @@
+//===- WasmSSATypes.cpp - WasmSSA dialect types -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Types.h" +#include "llvm/Support/LogicalResult.h" + +#include <optional> + +namespace mlir::wasmssa { +#include "mlir/Dialect/WasmSSA/IR/WasmSSATypeConstraints.cpp.inc" +} // namespace mlir::wasmssa diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt index 242a97c..7869a28 100644 --- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt @@ -7,13 +7,18 @@ add_mlir_dialect_library(MLIRXeGPUDialect DEPENDS MLIRXeGPUIncGen + MLIRXeGPUAttrInterfaceIncGen MLIRXeGPUAttrsIncGen MLIRXeGPUEnumsIncGen LINK_LIBS PUBLIC MLIRArithDialect + MLIRIndexDialect + MLIRAffineUtils MLIRArithUtils MLIRDialectUtils + MLIRGPUDialect + MLIRXeVMDialect MLIRIR MLIRViewLikeInterface MLIRVectorDialect diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 3c0ca114..7f3be7f 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -6,12 +6,16 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" using std::optional; @@ -33,10 +37,61 @@ void XeGPUDialect::initialize() { >(); } +/// Generates instructions to compute offsets for a subgroup identified by +/// its multidimensional indices (sgId), using the specified subgroup layout +/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data +/// dimensions (sizePerWg). 
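+///
+/// For instance, with sgLayout = [2, 2], sizePerSg = [8, 8], sizePerWg =
+/// [32, 16] and sgId = [1, 0]: localOffset = [8, 0], distUnit = [16, 16],
+/// and the computed offsets are [8, 0] and [24, 0], i.e. this subgroup is
+/// assigned two distribution units round-robin along dimension 0.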
+static SmallVector<SmallVector<Value>> +genOffsetsComputingInsts(OpBuilder &builder, Location loc, + SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout, + ArrayRef<int64_t> sizePerSg, + ArrayRef<int64_t> sizePerWg) { + + SmallVector<SmallVector<Value>> offsets; + + // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i] + SmallVector<Value> localOffsets = llvm::map_to_vector( + llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value { + return builder.createOrFold<index::MulOp>( + loc, std::get<0>(t), + builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t))); + }); + + // distUnit[i] is the minimum value between sizePerWg[i] and + // sgLayout[i] * sizePerSg[i] + SmallVector<int64_t> distUnit = llvm::map_to_vector( + llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)), + [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); + + for (SmallVector<int64_t> unitOffs : + StaticTileOffsetRange(sizePerWg, distUnit)) { + SmallVector<Value> base = + llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value { + return arith::ConstantIndexOp::create(builder, loc, d); + }); + + SmallVector<Value> adds = llvm::map_to_vector( + llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value { + return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t), + std::get<1>(t)); + }); + + SmallVector<Value> mods = llvm::map_to_vector( + llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value { + return builder.createOrFold<index::RemUOp>( + loc, std::get<0>(t), + arith::ConstantIndexOp::create(builder, loc, std::get<1>(t))); + }); + + offsets.push_back(mods); + } + return offsets; +} + // Checks if the given shape can be evenly distributed based on the layout // and data factors provided by the LayoutAttr. bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, - xegpu::LayoutAttr attr) { + xegpu::DistributeLayoutAttr attr) { assert(attr && "Layout attribute is missing."); // Checks whether the given shape can be evenly distributed using the @@ -49,52 +104,51 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, // smaller than `layout[i] * data[i]`, allowing multiple compute units to // share the data. auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape, - DenseI32ArrayAttr layout, DenseI32ArrayAttr data, + SmallVector<int64_t> layout, + SmallVector<int64_t> data, bool rr = true) -> optional<SmallVector<int64_t>> { llvm::SmallVector<int64_t> newShape(shape); - if (layout) { - auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef()); - if (vec.size() != shape.size()) + if (layout.size()) { + if (layout.size() != shape.size()) return std::nullopt; - auto ratio = computeShapeRatio(shape, vec); + auto ratio = computeShapeRatio(shape, layout); if (!ratio.has_value()) return std::nullopt; newShape = ratio.value(); } - if (data) { - auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef()); - if (vec.size() != shape.size()) + if (data.size()) { + if (data.size() != shape.size()) return std::nullopt; - auto ratio = computeShapeRatio(newShape, vec); + auto ratio = computeShapeRatio(newShape, data); if (!ratio.has_value() && rr) - ratio = computeShapeRatio(vec, newShape); + ratio = computeShapeRatio(data, newShape); if (!ratio.has_value()) return std::nullopt; // if data is not null, we always return it for next phase. 
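+      // (The reversed ratio above handles the sharing case: when the
+      // remaining shape is smaller than the data tile, multiple compute
+      // units share the same data.)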
- newShape = vec; + newShape = data; } return newShape; }; // check the sgLayout and sgData auto maybeSgShape = - tryDistribute(shape, attr.getSgLayout(), attr.getSgData()); + tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt()); if (!maybeSgShape) return false; auto sgShape = maybeSgShape.value(); // check InstData, it neither have layout nor need round-robin auto maybeInstShape = - tryDistribute(sgShape, nullptr, attr.getInstData(), false); + tryDistribute(sgShape, {}, attr.getInstDataAsInt(), false); if (!maybeInstShape) return false; auto instShape = maybeInstShape.value(); // check LaneLayout and LaneData - auto maybeLaneShape = - tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false); + auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(), + attr.getLaneDataAsInt(), false); return maybeLaneShape.has_value(); } @@ -211,6 +265,150 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, return success(); } +FailureOr<SmallVector<Value>> +LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, + Value linearId) { + // delinearizeSubgroupId is only available for + // workgroup-level layout attribute + if (!isForWorkgroup()) + return failure(); + + // TODO: handle order attribute + auto hasDefaultOrder = [&]() { + DenseI32ArrayAttr order = getOrder(); + return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>( + llvm::reverse(order.asArrayRef()))); + }; + if (!hasDefaultOrder()) + return mlir::emitError(loc, "order attribute is currently not supported."); + + auto dims = llvm::map_to_vector(getSgLayoutAsInt(), [&](int64_t d) -> Value { + return builder.createOrFold<arith::ConstantIndexOp>(loc, d); + }); + + return affine::delinearizeIndex(builder, loc, linearId, dims); +} + +/// Implements DistributeLayoutAttr::getOffsets to generate +/// instructions for computing multi-dimensional offsets when distributed by +/// LayoutAttr. 
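+/// If sg_data is not specified, the per-subgroup tile is derived as
+/// shape[i] / sg_layout[i]; e.g. a [32, 32] shape with sg_layout = [4, 2]
+/// gives an [8, 16] tile per subgroup, assuming the shape divides evenly.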
+FailureOr<SmallVector<SmallVector<Value>>> +LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, + ArrayRef<int64_t> shape) { + if (!isForWorkgroup()) + return failure(); + + SmallVector<int64_t> sgLayout = getSgLayoutAsInt(); + SmallVector<int64_t> sgShape = getSgDataAsInt(); + if (sgShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, sgLayout)) + sgShape = derivedShape.value(); + else + return failure(); + } + + // delinearize Ids + auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + if (failed(maybeIds)) + return failure(); + SmallVector<Value> sgIds = *maybeIds; + + return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, + shape); +} + +//===----------------------------------------------------------------------===// +// XeGPU_SliceAttr +//===----------------------------------------------------------------------===// +LogicalResult +SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError, + xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) { + if (!parent || !dims) + return emitError() << "expected parent layout and dims attribute"; + + int64_t rank = parent.getRank(); + + // check every element in dims is unique and smaller than rank + llvm::SmallDenseSet<int64_t> seen; + for (int64_t dim : dims.asArrayRef()) { + if (dim < 0 || dim >= rank) + return emitError() << "invalid dim (" << dim << ") in slice attribute."; + if (!seen.insert(dim).second) + return emitError() << "repeated dim (" << dim << ") in slice attribute."; + } + return success(); +} + +SliceAttr SliceAttr::flatten() const { + xegpu::DistributeLayoutAttr parent = getParent(); + SmallVector<DenseI64ArrayAttr> slicedDims({getDims()}); + + while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) { + parent = sliceAttr.getParent(); + slicedDims.push_back(sliceAttr.getDims()); + } + + auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent); + SmallVector<int64_t> indices = + llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank())); + + // get remaining dims (flattend) by applying slice ops with all slicedDims + SmallVector<int64_t> remainingDims(indices); + for (auto dim : llvm::reverse(slicedDims)) + remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims), + dim.asArrayRef()); + + // get flattend sliced dims by applying slice ops with the remaining dims + SmallVector<int64_t> flattendDims = XeGPUDialect::slice( + llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims)); + + return xegpu::SliceAttr::get( + getContext(), layoutAttr, + DenseI64ArrayAttr::get(getContext(), flattendDims)); +} + +FailureOr<SmallVector<Value>> +SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, + Value linearId) { + SliceAttr attr = flatten(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + return parent.delinearizeSubgroupId(builder, loc, linearId); +} + +/// Implements DistributeLayoutAttr::getOffsets to generate +/// instructions for computing multi-dimensional offsets when distributed by +/// SliceAttr. 
+FailureOr<SmallVector<SmallVector<Value>>> +SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, + ArrayRef<int64_t> shape) { + assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape."); + if (!isForWorkgroup()) + return failure(); + + SmallVector<int64_t> sgLayout = getSgLayoutAsInt(); + SmallVector<int64_t> sgShape = getSgDataAsInt(); + if (sgShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, sgLayout)) + sgShape = derivedShape.value(); + else + return failure(); + } + + // delinearize Ids + auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + if (failed(maybeIds)) + return failure(); + + // The effective sgIds for offsets computing correspond + // to the dims that are not sliced. + ArrayRef<int64_t> dims = flatten().getDims().asArrayRef(); + SmallVector<Value> sgIds = + XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims); + + return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, + shape); +} + //===----------------------------------------------------------------------===// // XeGPU_RangeAttr //===----------------------------------------------------------------------===// @@ -230,7 +428,7 @@ RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, // XeGPU_TensorDescType //===----------------------------------------------------------------------===// -mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { +mlir::Type TensorDescType::parse(AsmParser &parser) { llvm::SmallVector<int64_t> shape; mlir::Type elementType; mlir::FailureOr<mlir::Attribute> encoding; @@ -280,7 +478,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { layout.value_or(mlir::Attribute())); } -void TensorDescType::print(::mlir::AsmPrinter &printer) const { +void TensorDescType::print(AsmPrinter &printer) const { printer << "<"; auto shape = getShape(); @@ -325,10 +523,10 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape, return Base::get(context, shape, elementType, attr, layout); } -LogicalResult TensorDescType::verify( - llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - llvm::ArrayRef<int64_t> shape, mlir::Type elementType, - mlir::Attribute encoding, mlir::Attribute layout) { +LogicalResult +TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError, + llvm::ArrayRef<int64_t> shape, mlir::Type elementType, + mlir::Attribute encoding, mlir::Attribute layout) { size_t rank = shape.size(); if (rank == 0) @@ -394,6 +592,119 @@ LogicalResult TensorDescType::verify( return success(); } +//===----------------------------------------------------------------------===// +// XeGPU_MemDescType +//===----------------------------------------------------------------------===// +mlir::Type MemDescType::parse(AsmParser &parser) { + llvm::SmallVector<int64_t> shape; + mlir::Type elementType; + mlir::FailureOr<MemLayoutAttr> layout; + + // Parse literal '<' + if (parser.parseLess()) + return {}; + + auto shapeLoc = parser.getCurrentLocation(); + if (mlir::failed(parser.parseDimensionList(shape, false, true))) { + parser.emitError(shapeLoc, "failed to parse parameter 'shape'"); + return {}; + } + + auto elemTypeLoc = parser.getCurrentLocation(); + if (mlir::failed(parser.parseType(elementType))) { + parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'"); + return {}; + } + + // parse optional attributes + if (mlir::succeeded(parser.parseOptionalComma())) { + MemLayoutAttr attr; + ParseResult res = parser.parseAttribute(attr); + if 
(mlir::failed(res))
+      return {};
+    layout = attr;
+  }
+
+  // Parse literal '>'
+  if (parser.parseGreater())
+    return {};
+
+  MLIRContext *ctxt = parser.getContext();
+  return MemDescType::getChecked(
+      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
+      elementType, layout.value_or(MemLayoutAttr()));
+}
+
+void MemDescType::print(AsmPrinter &printer) const {
+  printer << "<";
+
+  printer.printDimensionList(getShape());
+  printer << 'x';
+  printer << getElementType();
+
+  if (auto layout = getMemLayout())
+    printer << ", " << layout;
+
+  printer << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemLayoutAttr
+//===----------------------------------------------------------------------===//
+
+Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
+
+  auto context = parser.getContext();
+  llvm::SMLoc loc = parser.getCurrentLocation();
+
+  llvm::SmallDenseSet<StringRef> seenKeys;
+  SmallVector<NamedAttribute> attributes;
+
+  auto parseElt = [&]() -> ParseResult {
+    StringRef nameId;
+    if (failed(parser.parseKeyword(&nameId)))
+      return parser.emitError(loc, "expected valid attribute name");
+
+    if (!seenKeys.insert(nameId).second)
+      return parser.emitError(loc, "duplicate key '")
+             << nameId << "' in mem layout attribute";
+
+    if (failed(parser.parseEqual()))
+      return failure();
+
+    Attribute attr;
+    if (failed(parser.parseAttribute(attr)))
+      return failure();
+    attributes.emplace_back(nameId, attr);
+    return success();
+  };
+
+  // Parse literal '<'
+  if (parser.parseLess())
+    return {};
+
+  if (failed(parser.parseCommaSeparatedList(parseElt)))
+    return {};
+
+  // Parse literal '>'
+  if (parser.parseGreater())
+    return {};
+
+  return parser.getChecked<MemLayoutAttr>(
+      loc, context, DictionaryAttr::get(context, attributes));
+}
+
+void MemLayoutAttr::print(AsmPrinter &printer) const {
+  printer << "<";
+  ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
+  for (size_t i = 0; i < attrs.size(); i++) {
+    printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
+    if (i < attrs.size() - 1)
+      printer << ", ";
+  }
+  printer << ">";
+}
+
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3..aca6654 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -21,6 +23,17 @@
 namespace mlir {
 namespace xegpu {
 
+bool isSharedMemory(const MemRefType &memrefTy) {
+  Attribute attr = memrefTy.getMemorySpace();
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+    return intAttr.getInt() == 3;
+  if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
+    return memrefSpace.getValue() == MemorySpace::SLM;
+  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
+
 template <typename T>
 static std::string makeString(T array, bool breakline = false) {
   std::string buf;
@@ -45,13 +58,6 @@ static SmallVector<int64_t> getShapeOf(Type type) {
   return shape;
 }
 
-static int64_t getRankOf(Value
val) { - auto type = val.getType(); - if (auto ty = llvm::dyn_cast<ShapedType>(type)) - return ty.getRank(); - return 0; -} - static bool isReadHintOrNone(const CachePolicyAttr &attr) { if (!attr) return true; @@ -76,13 +82,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, if (!tdescTy.isScattered()) return emitError() << "Expects a scattered TensorDesc."; - if (!valueTy) - return emitError() << "Expecting a vector type result."; + auto chunkSize = tdescTy.getChunkSizeAsInt(); + if (!valueTy) { + if (chunkSize > 1) + return emitError() << "Expecting chunk size == 1 for scalar result"; + if (dyn_cast<VectorType>(maskTy)) + return emitError() << "Expecting a vector type result."; + return success(); + } auto maskShape = getShapeOf(maskTy); auto valueShape = getShapeOf(valueTy); auto tdescShape = getShapeOf(tdescTy); - auto chunkSize = tdescTy.getChunkSizeAsInt(); if (valueTy.getElementType() != tdescTy.getElementType()) return emitError() @@ -111,25 +122,49 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, } static LogicalResult -isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, - int64_t chunkSize, +isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, + VectorType valueTy, int64_t chunkSize, function_ref<InFlightDiagnostic()> emitError) { - if (!valueTy) - return emitError() << "Expecting a vector type result."; + auto maskVecTy = dyn_cast<VectorType>(maskTy); + auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy); + if (!valueTy) { + if (chunkSize > 1) + return emitError() << "Expecting chunk size == 1 for scalar result"; + if (maskVecTy || offsetsVecTy) + return emitError() << "Expecting scalar mask and offsets."; + else if (maskVecTy && offsetsVecTy) + return emitError() << "Expecting a vector type result."; + return success(); + } + auto valueSize = valueTy.getNumElements(); + // SIMT mode with scalar mask and offsets. 
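+  // In this mode each lane provides a single offset and a scalar mask, so the
+  // value must carry exactly chunkSize elements.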
+ if (!maskVecTy && !offsetsVecTy) { + if (valueSize != chunkSize) + return emitError() << "value elements must match chunk size " + << chunkSize; + return success(); + } auto maskShape = getShapeOf(maskTy); auto valueShape = getShapeOf(valueTy); - // a valid shape for SIMT case - if (valueTy.getRank() == 1) { - if (valueTy.getNumElements() != chunkSize) - return emitError() << "value elements must match chunk size " << chunkSize - << " for SIMT code."; - return success(); + if (!maskVecTy) + return emitError() << "Expecting a vector type mask."; + int64_t maskSize = maskVecTy.getNumElements(); + + if (chunkSize > 1) { + if ((valueTy.getRank() == 1) && (valueSize != chunkSize)) + return emitError() << "value elements must match chunk size " + << chunkSize; + } else { + if (valueSize != maskSize) + return emitError() + << "Mask should match value except the chunk size dim."; } - llvm::SmallVector<int64_t> expectedMaskShape(valueShape); + if (maskSize == 1) + return success(); if (chunkSize > 1) expectedMaskShape.pop_back(); if (expectedMaskShape != maskShape) @@ -156,41 +191,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, } void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, - Type tdesc, TypedValue<MemRefType> source, + Type tdesc, Value source, llvm::ArrayRef<OpFoldResult> shape, llvm::ArrayRef<OpFoldResult> strides) { - assert(shape.size() && strides.size() && shape.size() == strides.size() && - "Shape and strides must be present and of equal size for ui64 " - "initialization."); + Type srcTy = source.getType(); + assert((isa<IntegerType, MemRefType>(srcTy)) && + "Source has to be either int or memref."); - llvm::SmallVector<int64_t> staticShape; - llvm::SmallVector<int64_t> staticStrides; llvm::SmallVector<Value> dynamicShape; llvm::SmallVector<Value> dynamicStrides; - dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - - auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); - auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); - - build(builder, state, tdesc, source, ValueRange({}), dynamicShape, - dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, - staticStridesAttr); -} - -void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, - Type tdesc, TypedValue<IntegerType> source, - llvm::ArrayRef<OpFoldResult> shape, - llvm::ArrayRef<OpFoldResult> strides) { - assert(shape.size() && strides.size() && shape.size() == strides.size() && - "Shape and strides must be present and of equal size for ui64 " - "initialization."); - llvm::SmallVector<int64_t> staticShape; llvm::SmallVector<int64_t> staticStrides; - llvm::SmallVector<Value> dynamicShape; - llvm::SmallVector<Value> dynamicStrides; dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); @@ -198,6 +210,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); + if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) { + auto memrefShape = memrefTy.getShape(); + auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); + + // if shape and strides are from Memref, we don't need attributes for them + // to keep the IR print clean. 
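+    // They remain recoverable from the memref type itself, so dropping the
+    // attributes loses no information.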
+ if (staticShape == memrefShape && staticStrides == memrefStrides) { + staticShapeAttr = DenseI64ArrayAttr(); + staticStridesAttr = DenseI64ArrayAttr(); + } + } + build(builder, state, tdesc, source, ValueRange({}), dynamicShape, dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, staticStridesAttr); @@ -265,8 +289,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, } LogicalResult CreateNdDescOp::verify() { - auto rank = (int64_t)getMixedOffsets().size(); - bool invalidRank = false; + size_t rank = getMixedSizes().size(); + bool invalidRank = rank != getMixedStrides().size(); bool invalidElemTy = false; // Memory space of created TensorDesc should match with the source. @@ -280,31 +304,28 @@ LogicalResult CreateNdDescOp::verify() { << " Source: " << srcMemorySpace << ", TensorDesc: " << tdescMemorySpace; + if (size_t offsetRank = getMixedOffsets().size()) + invalidRank |= (offsetRank != rank); + // check source type matches the rank if it is a memref. // It also should have the same ElementType as TensorDesc. - auto memrefTy = dyn_cast<MemRefType>(getSourceType()); - if (memrefTy) { - invalidRank |= (memrefTy.getRank() != rank); + if (auto memrefTy = dyn_cast<MemRefType>(getSourceType())) invalidElemTy |= memrefTy.getElementType() != getElementType(); - } if (llvm::isa<IntegerType>(getSourceType())) { // strides and shape must present for integer source. if (getMixedStrides().empty() || getMixedSizes().empty()) - return emitOpError("Expecting strides and shape to be present for " + return emitOpError("expecting strides and shape to be present for " "integer source."); } - // mismatches among shape, strides, and offsets are - // already handeled by OffsetSizeAndStrideOpInterface. - // So they are not check here. 
if (invalidRank) return emitOpError( "Expecting the rank of shape, strides, offsets, and source (if source " "is a memref) should match with each other."); // check result TensorDesc rank - if (getType().getRank() > rank) + if (getType().getRank() > (int64_t)rank) return emitOpError( "Expecting the TensorDesc rank is not greater than the " "ranks of shape, strides, offsets or the memref source."); @@ -360,13 +381,10 @@ ParseResult parseOptionalDynamicIndexList( void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, DenseI64ArrayAttr integers) { - - if (!integers) + if (!integers || integers.empty()) return; - - return printDynamicIndexList(printer, op, values, integers, - /*scalableFlags=*/{}, {}, - AsmParser::Delimiter::Square); + printDynamicIndexList(printer, op, values, integers, + /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square); } //===----------------------------------------------------------------------===// // XeGPU_PrefetchNdOp @@ -381,6 +399,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, l1_hint, l2_hint, l3_hint); } +void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, + Value tensorDesc, ArrayRef<OpFoldResult> offsets, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + SmallVector<Value> dynamicOffsets; + SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + + build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint, + l2_hint, l3_hint); +} + LogicalResult PrefetchNdOp::verify() { auto tdescTy = getTensorDescType(); if (tdescTy.isScattered()) @@ -423,6 +456,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, l3_hint); } +void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, + Value tensorDesc, ArrayRef<OpFoldResult> offsets, + UnitAttr packed, DenseI64ArrayAttr transpose, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + SmallVector<Value> dynamicOffsets; + SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + + build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr, + packed, transpose, l1_hint, l2_hint, l3_hint); +} + LogicalResult LoadNdOp::verify() { auto tdescTy = getTensorDescType(); auto valueTy = getType(); @@ -529,6 +578,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint); } +void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, + Value tensorDesc, ArrayRef<OpFoldResult> offsets, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + SmallVector<Value> dynamicOffsets; + SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + + build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr, + l1_hint, l2_hint, l3_hint); +} + LogicalResult StoreNdOp::verify() { auto dstTy = getTensorDescType(); // Tile auto valTy = getValueType(); // Vector @@ -635,10 +699,6 @@ void CreateDescOp::build(OpBuilder &builder, OperationState &state, 
LogicalResult CreateDescOp::verify() { auto tdescTy = getTensorDescType(); - if (getRankOf(getSource()) > 1) - return emitOpError( - "Expecting the source is a 1D memref or pointer (uint64_t)."); - if (!tdescTy.isScattered()) return emitOpError("Expects a scattered TensorDesc.\n"); @@ -673,12 +733,14 @@ LogicalResult CreateDescOp::verify() { LogicalResult PrefetchOp::verify() { auto tdescTy = getTensorDescType(); - if (tdescTy && !tdescTy.isScattered()) - return emitOpError("Expects a scattered TensorDesc.\n"); + if (!tdescTy && !getOffsets()) + return emitOpError("Expects offsets."); - if (!tdescTy && getRankOf(getSource()) > 1) - return emitOpError( - "Expecting the source is a 1D memref or pointer (uint64_t)."); + if (tdescTy && getOffsets()) + return emitOpError("offsets not allowed."); + + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc."); if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -689,6 +751,13 @@ LogicalResult PrefetchOp::verify() { if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); + auto srcTy = getSourceType(); + if (srcTy.isInteger() && !getOffsetAlignByteAttr()) + return emitOpError("offset_align_byte is required with integer source."); + + if (getOffsetAlignByteAttr() && !srcTy.isInteger()) + return emitOpError("offset_align_byte only allowed with integer source."); + return success(); } @@ -696,7 +765,8 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { - build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint); + build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint, + IntegerAttr{}); } //===----------------------------------------------------------------------===// @@ -707,13 +777,15 @@ LogicalResult LoadGatherOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); + if (!tdescTy && !getOffsets()) + return emitOpError("Expects offsets."); + + if (tdescTy && getOffsets()) + return emitOpError("offsets not allowed."); + if (tdescTy && !tdescTy.isScattered()) return emitOpError("Expects a scattered TensorDesc."); - if (!tdescTy && getRankOf(getSource()) > 1) - return emitOpError( - "Expecting the source is a 1D memref or pointer (uint64_t)."); - if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -730,10 +802,11 @@ LogicalResult LoadGatherOp::verify() { uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); auto memTy = dyn_cast<MemRefType>(srcTy); - if (memTy && (valueTy.getElementType() != memTy.getElementType())) + if (memTy && (getElementType() != memTy.getElementType())) return emitError() << "Value should have the same element type as MemRef."; - return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + auto offsetsTy = getOffsets().getType(); + return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize, [&]() { return emitOpError(); }); } @@ -746,6 +819,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, l1_hint, l2_hint, l3_hint); } +void LoadGatherOp::build(OpBuilder &builder, OperationState &state, + Type valueType, Value source, + ArrayRef<OpFoldResult> offsets, Value mask, + IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + auto loc = source.getLoc(); + 
int64_t size = static_cast<int64_t>(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, + l2_hint, l3_hint); +} + //===----------------------------------------------------------------------===// // XeGPU_StoreScatterOp //===----------------------------------------------------------------------===// @@ -754,12 +843,14 @@ LogicalResult StoreScatterOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); - if (tdescTy && !tdescTy.isScattered()) - return emitOpError("Expects a scattered TensorDesc.\n"); + if (!tdescTy && !getOffsets()) + return emitOpError("Expects offsets."); - if (!tdescTy && getRankOf(getDest()) > 1) - return emitOpError( - "Expecting the dest is a 1D memref or pointer (uint64_t)."); + if (tdescTy && getOffsets()) + return emitOpError("offsets not allowed."); + + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc."); if (!isWriteHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -778,10 +869,11 @@ LogicalResult StoreScatterOp::verify() { uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); auto memTy = dyn_cast<MemRefType>(destTy); - if (memTy && (valueTy.getElementType() != memTy.getElementType())) + if (memTy && (getElementType() != memTy.getElementType())) return emitError() << "Value should have the same element type as MemRef."; - return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + auto offsetsTy = getOffsets().getType(); + return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize, [&]() { return emitOpError(); }); } @@ -794,6 +886,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, l2_hint, l3_hint); } +void StoreScatterOp::build(OpBuilder &builder, OperationState &state, + Value value, Value dest, + ArrayRef<OpFoldResult> offsets, Value mask, + IntegerAttr chunk_size, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + auto loc = dest.getLoc(); + int64_t size = static_cast<int64_t>(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + // Call the correct builder overload that does not expect result types. + build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, + l3_hint); +} + //===----------------------------------------------------------------------===// // XeGPU_UpdateOffsetOp //===----------------------------------------------------------------------===// @@ -888,8 +998,8 @@ LogicalResult ConvertLayoutOp::verify() { // both input and target layouts should be WgLayout or SgLayout at the same // time. 
- if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) && - (!srcLayout.isSgLayout() || !resLayout.isSgLayout())) + if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) && + (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup())) return emitOpError("expected input layout and target layout be WgLayout or " "SgLayout at the same time."); @@ -928,9 +1038,107 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add<FoldConvertLayoutOp>(context); } +//===----------------------------------------------------------------------===// +// XeGPU_LoadMatrixOp +//===----------------------------------------------------------------------===// +void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, + TypedValue<MemDescType> memDesc, + llvm::ArrayRef<OpFoldResult> offsets, + DistributeLayoutAttr layout) { + llvm::SmallVector<Value> dynamicOffsets; + llvm::SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, + layout); +} + +LogicalResult LoadMatrixOp::verify() { + VectorType resTy = getRes().getType(); + MemDescType mdescTy = getMemDesc().getType(); + + if (mdescTy.getRank() != 2) + return emitOpError("mem_desc must be 2D."); + + ArrayRef<int64_t> valueShape = resTy.getShape(); + ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed mem_desc shape."); + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_StoreMatrixOp +//===----------------------------------------------------------------------===// +void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, + TypedValue<MemDescType> memDesc, + llvm::ArrayRef<OpFoldResult> offsets, + DistributeLayoutAttr layout) { + llvm::SmallVector<Value> dynamicOffsets; + llvm::SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, + layout); +} + +LogicalResult StoreMatrixOp::verify() { + VectorType dataTy = getData().getType(); + MemDescType mdescTy = getMemDesc().getType(); + + if (mdescTy.getRank() != 2) + return emitOpError("mem_desc must be 2D."); + + ArrayRef<int64_t> dataShape = dataTy.getShape(); + ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("data shape must not exceed mem_desc shape."); + + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_MemDescSubviewOp +//===----------------------------------------------------------------------===// + +void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state, + Type resTy, Value src, + llvm::ArrayRef<OpFoldResult> offsets) { + llvm::SmallVector<Value> dynamicOffsets; + llvm::SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, resTy, src, 
dynamicOffsets, staticOffsetsAttr); +} + +LogicalResult MemDescSubviewOp::verify() { + MemDescType srcTy = getSrc().getType(); + MemDescType resTy = getRes().getType(); + ArrayRef<int64_t> srcShape = srcTy.getShape(); + ArrayRef<int64_t> resShape = resTy.getShape(); + + if (srcTy.getRank() < resTy.getRank()) + return emitOpError("result rank must not exceed source rank."); + + if (llvm::any_of( + llvm::zip_equal(resShape, srcShape.take_back(resShape.size())), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed source shape."); + + if (srcTy.getStrides() != resTy.getStrides()) + return emitOpError("result must inherit the source strides."); + + return success(); +} + } // namespace xegpu } // namespace mlir +namespace mlir { +#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc> +} // namespace mlir #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc> #define GET_OP_CLASSES #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc> diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index d82c541..9ee002e 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -84,9 +84,10 @@ struct ConvertLayoutOpPattern using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, PatternRewriter &rewriter) const override { - xegpu::LayoutAttr input_layout = op.getInputLayoutAttr(); - xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr(); - if (!input_layout.getInstData() || !target_layout.getInstData()) + xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr(); + xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr(); + if (input_layout.getInstDataAsInt().empty() || + target_layout.getInstDataAsInt().empty()) return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp."); input_layout = input_layout.dropInstData(); @@ -140,10 +141,11 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const { else value = (Value)operandOrResult; - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult); - if (layout && layout.isSgLayout()) { - if (auto inst_data = layout.getInstData()) - return llvm::to_vector_of<int64_t>(inst_data.asArrayRef()); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(operandOrResult); + if (layout && layout.isForSubgroup()) { + if (!layout.getInstDataAsInt().empty()) + return layout.getInstDataAsInt(); if (auto type = dyn_cast<ShapedType>(value.getType())) return llvm::to_vector(type.getShape()); @@ -204,13 +206,15 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const { // skip the op if any of its operands or results has workgroup level layouts bool hasWgLayoutOperands = llvm::any_of(op->getOpOperands(), [](OpOperand &opr) { - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr); - return layout && layout.isWgLayout(); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(opr); + return layout && layout.isForWorkgroup(); }); bool hasWgLayoutResults = llvm::any_of(op->getOpResults(), [](OpResult result) { - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result); - return layout && layout.isWgLayout(); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(result); + return layout && layout.isForWorkgroup(); }); if (hasWgLayoutOperands || hasWgLayoutResults) { LDBG() << "skip unrolling for op with workgroup level layout: " << *op; @@ -220,8 +224,8 @@ 
bool XeGPUBlockingPass::needsUnroll(Operation *op) const { auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) { Type valTy = value.getType(); if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) { - xegpu::LayoutAttr layout = tdescTy.getLayoutAttr(); - return layout && layout.getInstData(); + xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr(); + return layout && !layout.getInstDataAsInt().empty(); } auto shapedType = dyn_cast<ShapedType>(valTy); return shapedType && !llvm::equal(tileShape, shapedType.getShape()); @@ -247,7 +251,8 @@ void XeGPUBlockingPass::runOnOperation() { // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr. // This ensures that the LayoutAttr remains accessible even if the defining // operation is replaced. - xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); }); + xegpu::setDistributeLayoutAttrs( + op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); }); auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { @@ -272,7 +277,7 @@ void XeGPUBlockingPass::runOnOperation() { auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()); - if (layout && layout.isWgLayout()) + if (layout && layout.isForWorkgroup()) return failure(); int count; @@ -289,7 +294,7 @@ void XeGPUBlockingPass::runOnOperation() { ArrayRef<int64_t> shape = type.getShape(); xegpu::LayoutAttr layout = type.getLayoutAttr(); - if (layout && layout.isWgLayout()) + if (layout && layout.isForWorkgroup()) return failure(); int count; @@ -377,7 +382,7 @@ void XeGPUBlockingPass::runOnOperation() { if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) { op->removeAttr(name); if (!isa<LoopLikeOpInterface>(op)) - xegpu::setLayoutAttr(result, layout.dropInstData()); + xegpu::setDistributeLayoutAttr(result, layout.dropInstData()); } } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index bef8804..5cb47b2 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -718,7 +718,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, } // If the result is a vector type, add a temporary layout attribute to the // op. - xegpu::setLayoutAttr(result, layout); + xegpu::setDistributeLayoutAttr(result, layout); } return success(); } @@ -800,7 +800,7 @@ updateControlFlowOps(mlir::OpBuilder &builder, // If the type is a vector type and this region argument is an OpResult, // set the layout attribute on the OpResult. 
if (auto result = dyn_cast<OpResult>(successorInput)) - xegpu::setLayoutAttr(result, successorOperandLayout); + xegpu::setDistributeLayoutAttr(result, successorOperandLayout); } } return success(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 8957ea5..dddb5ea 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -277,22 +277,13 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern { descOp, "the tensor descriptor lacks layout attribute"); SmallVector<size_t> newRetIndices; - SmallVector<Value> newYieldValues; - SmallVector<Type> newYieldTypes; - - for (Value operand : descOp->getOperands()) { - newYieldValues.push_back(operand); - newYieldTypes.push_back(operand.getType()); - } rewriter.setInsertionPoint(warpOp); gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, /* new yieled values = */ newYieldValues, - /* new yielded types = */ newYieldTypes, newRetIndices); + rewriter, warpOp, /* new yieled values = */ descOp->getOperands(), + /* new yielded types = */ descOp.getOperandTypes(), newRetIndices); - SmallVector<Value> newDescOperands; - for (size_t i : newRetIndices) { - newDescOperands.push_back(newWarpOp.getResult(i)); - } + SmallVector<Value> newDescOperands = llvm::map_to_vector( + newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); }); rewriter.setInsertionPointAfter(newWarpOp); xegpu::TensorDescType distributedTensorDescTy = descOp.getType().dropLayouts(); // Distributed tensor descriptor type @@ -345,8 +336,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern { using gpu::WarpDistributionPattern::WarpDistributionPattern; LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); Operation *lastNode = yield->getPrevNode(); auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode); if (!storeOp) @@ -458,8 +448,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern { // Make sure the same load op is the last operation in the warp op body. // This ensure that load op is not sinked earlier violating any barrier // synchronizations. - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); return yield->getPrevNode() == op; }); @@ -696,39 +685,30 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern { warpOp, "warp result is not a xegpu::UpdateNdOffset op"); auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>(); unsigned operandIdx = operand->getOperandNumber(); - // new update op does not have layout attribute. 
- xegpu::TensorDescType newTensorDescTy = - updateOp.getTensorDescType().dropLayouts(); - SmallVector<Value, 3> newYieldValues; - SmallVector<Type, 3> newYieldTypes; - for (Value operand : updateOp->getOperands()) { - newYieldValues.push_back(operand); - if (isa<xegpu::TensorDescType>(operand.getType())) { - newYieldTypes.push_back(newTensorDescTy); - } else { - newYieldTypes.push_back(operand.getType()); - } - } SmallVector<size_t> newRetIndices; gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices); + rewriter, warpOp, updateOp->getOperands(), updateOp.getOperandTypes(), + newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); - SmallVector<Value> newUpdateOperands; - for (size_t i : newRetIndices) { - // For the tensor descriptor operand, the layout attribute is dropped - // after distribution. Types needs to be resolved in this case. - if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) { - newUpdateOperands.push_back(resolveDistributedTy( - newWarpOp.getResult(i), newTensorDescTy, rewriter)); - } else { - newUpdateOperands.push_back(newWarpOp.getResult(i)); - } - } + // new update op does not have layout attribute. + xegpu::TensorDescType distributedTensorDescTy = + updateOp.getTensorDescType().dropLayouts(); + SmallVector<Value> newUpdateOperands = + llvm::map_to_vector(newRetIndices, [&](size_t i) { + // For the tensor descriptor operand, the layout attribute is + // dropped after distribution. Types needs to be resolved in this + // case. + if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) { + return resolveDistributedTy(newWarpOp.getResult(i), + distributedTensorDescTy, rewriter); + } + return newWarpOp.getResult(i); + }); // Create a new update op outside the warp op. auto newUpdateOp = xegpu::UpdateNdOffsetOp::create( - rewriter, newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands, - updateOp->getAttrs()); + rewriter, newWarpOp.getLoc(), distributedTensorDescTy, + newUpdateOperands, updateOp->getAttrs()); xegpu::removeLayoutAttrs(newUpdateOp); Value distributedVal = newWarpOp.getResult(operandIdx); // Resolve the distributed type with the original type. @@ -770,8 +750,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern { using gpu::WarpDistributionPattern::WarpDistributionPattern; LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); Operation *lastNode = yield->getPrevNode(); auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode); if (!prefetchOp) @@ -812,8 +791,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern { using gpu::WarpDistributionPattern::WarpDistributionPattern; LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); Operation *lastNode = yield->getPrevNode(); // The last node must be a gpu::BarrierOp. 
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode); @@ -859,14 +837,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() { if (!isa<VectorType>(operand.get().getType())) continue; - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand); + auto layout = + xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand); if (!layout) { op->emitError("Could not find layout attribute for operand ") << operand.getOperandNumber() << " of operation " << op->getName(); signalPassFailure(); return; } - xegpu::setLayoutAttr(operand, layout); + xegpu::setDistributeLayoutAttr(operand, layout); } }); // Step 2: Move all operations of a GPU function inside @@ -900,7 +879,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() { if (vecRank == 0) return AffineMap::get(val.getContext()); // Get the layout of the vector type. - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val); + // TODO: support more layout types + auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val); // If no layout is specified, assume the inner most dimension is distributed // for now. if (!layout) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 850f70c..9f627c7 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -34,38 +34,29 @@ using namespace mlir; namespace { -// Check if there is sg id range attached to the scf.if op. -static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange, - int64_t &endOfRange) { - Operation *parent = op->getParentOp(); - // Find the outermost scf::IfOp with xegpu.sg_id_range. +// Retrieve the RangeAttr if it is specified. +static xegpu::RangeAttr getRangeSpecAttr(Operation *op) { + Operation *parent = op->getParentOfType<scf::IfOp>(); while (parent) { - if (auto ifOp = dyn_cast<scf::IfOp>(parent)) { - if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>( - ifOp->getAttr("sg_id_range"))) { - startOfRange = attr.getStart().getInt(); - endOfRange = attr.getEnd().getInt(); - break; - } - } - parent = parent->getParentOp(); + if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>( + parent->getAttr("sg_id_range"))) + return attr; + parent = parent->getParentOfType<scf::IfOp>(); } - // Return false if startOfRange is 0 - return (startOfRange > 0 && endOfRange > startOfRange); + return {}; } static std::pair<SmallVector<int64_t>, int> -getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { +getSgShapeAndCount(ArrayRef<int64_t> shape, + xegpu::DistributeLayoutAttr layout) { int count = 1; SmallVector<int64_t> sgShape(shape); - - if (layout && layout.isWgLayout()) { - DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout(); - auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef()); - if (DenseI32ArrayAttr sgDataAttr = layout.getSgData()) - sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef()); - else - sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape); + if (layout && layout.isForWorkgroup()) { + SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt(); + if (!layout.getSgDataAsInt().empty()) + sgShape = layout.getSgDataAsInt(); + else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout)) + sgShape = *maybeDerivedSgData; SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape); // Clamp distUnit to the original shape to handle cases where data is // shared among subgroups, which may cause distUnit 
to exceed the original @@ -77,6 +68,67 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { return std::make_pair(sgShape, count); } +/// Utility helper for deriving a list of offsets for each sub-TensorDescs +/// or sub-MemDescs to be accessed by current subgroup (sgId) based on the +/// associated distribute layout attribute, the shape, subgroup id and the +/// original offsets of the op +template < + typename OpType, + typename = std::enable_if_t<llvm::is_one_of< + OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp, + xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>> +static LogicalResult +genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, + SmallVector<SmallVector<OpFoldResult>> &offsetsList) { + Location loc = op.getLoc(); + SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets(); + // not applicable to ops without offsets operands. + if (origOffsets.empty()) + return failure(); + + // not applicable to ops without workgroup layout attributes + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr); + + // verify and adjust the sgId if the range specifier is present + xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op); + if (sgIdRange) { + int64_t startOfRange = sgIdRange.getStart().getInt(); + int64_t endOfRange = sgIdRange.getEnd().getInt(); + // verify the RangeAttr against the layout attribute + if (layout.getNumSubgroups() != endOfRange - startOfRange) + return rewriter.notifyMatchFailure( + op, "sg_layout size must match the sg_id_range"); + // adjust the sgId if necessary + if (startOfRange > 0) { + Value startOfRangeVal = + rewriter.create<arith::ConstantIndexOp>(loc, startOfRange); + sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal); + } + } + + // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory + // descriptors to be accessed, based on the layout information. + ArrayRef<int64_t> wgShape = op.getDataShape(); + auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + if (failed(maybeDescOffsets)) + return failure(); + + // Compute the final global offsets for each accessed sub-tensor + // or sub-memory descriptor. + for (const auto &sgOffsets : *maybeDescOffsets) { + SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned( + rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets); + offsetsList.push_back(std::move(newOffsets)); + } + + // callback(offsetsList); + return success(); +} + /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor /// from a workgroup descriptor. It replaces the offsets and sizes with /// appropriate values for the subgroup. 
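For illustration, here is a minimal standalone sketch of the per-subgroup offset derivation that genOffsetsList relies on, written with plain integers instead of Values/OpFoldResults. The sg_layout, sg_data, and base offsets are invented, and the round-robin wrap-around is ignored by assuming sg_layout * sg_data covers the whole workgroup shape; the actual lowering goes through DistributeLayoutAttr::getOffsets and xegpu::addWithRightAligned.

#include <array>
#include <cstdint>
#include <cstdio>

// Simplified scalar model: delinearize the (range-adjusted) subgroup id over
// sg_layout, scale by sg_data to get the subgroup-local offset, then add it,
// right-aligned, to the op's original offsets.
int main() {
  const std::array<int64_t, 2> sgLayout = {8, 4};     // hypothetical sg_layout
  const std::array<int64_t, 2> sgData = {32, 32};     // hypothetical sg_data
  const std::array<int64_t, 2> origOffsets = {0, 64}; // hypothetical op offsets
  const int64_t sgId = 5; // linear subgroup id, already range-adjusted

  // Row-major delinearization of 5 over [8, 4] gives the coordinate (1, 1).
  const std::array<int64_t, 2> sgCoord = {sgId / sgLayout[1],
                                          sgId % sgLayout[1]};
  for (size_t i = 0; i < 2; ++i) {
    const int64_t local = sgCoord[i] * sgData[i]; // subgroup-local offset
    const int64_t global = origOffsets[i] + local;
    std::printf("dim %zu: offset = %lld\n", i, static_cast<long long>(global));
  }
  return 0;
}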
@@ -125,125 +177,74 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern; - // Calculate offset for each subgroup - static SmallVector<OpFoldResult> - calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc, - const SmallVector<OpFoldResult> &originalOffsets, - const SmallVector<Value> &localOffset, - const SmallVector<int64_t> &distUnitBaseAddr, - const SmallVector<int64_t> &distUnitShape) { - assert(localOffset.size() == distUnitBaseAddr.size() && - "localOffset and distUnitBaseAddr must have the same rank"); - - SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(), - originalOffsets.end()); - size_t rank = localOffset.size(); - for (size_t i = 0; i < rank; ++i) { - size_t dimIdx = originalOffsets.size() - rank + i; - Value constOffset = - arith::ConstantIndexOp::create(rewriter, loc, distUnitBaseAddr[i]); - Value offset = - rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset); - Value modValue = - arith::ConstantIndexOp::create(rewriter, loc, distUnitShape[i]); - Value offsetMod = - rewriter.createOrFold<index::RemUOp>(loc, offset, modValue); - Value origOffset = getValueOrCreateConstantIndexOp( - rewriter, loc, originalOffsets[dimIdx]); - Value globalOffset = - rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod); - globalOffsets[dimIdx] = globalOffset; - } - - return globalOffsets; - } - LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + MLIRContext *ctx = op.getContext(); xegpu::TensorDescType tdescTy = op.getType(); - auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout()); - if (!layout) - return failure(); - Type elemTy = tdescTy.getElementType(); ArrayRef<int64_t> wgShape = tdescTy.getShape(); - // sgLayout must be present for workgroup-level distribution. 
- SmallVector<int64_t> sgLayout; - if (auto sgLayoutAttr = layout.getSgLayout()) - sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef()); - else - return rewriter.notifyMatchFailure( - op, "sgLayout attribute is required in layout"); - + Type elemTy = tdescTy.getElementType(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + auto newTdescTy = + xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), + layout.dropSgLayoutAndData()); - // TODO : Handle order attribute - // Get the subgroup ID - auto linearSgId = - gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - - // Create constants for layout dimensions - SmallVector<Value> sgLayoutDim(sgLayout.size()); - SmallVector<Value> sgDataDim(sgShape.size()); + SmallVector<Value> newOps; + for (auto offsets : offsetsList) { + auto newOp = xegpu::CreateNdDescOp::create( + rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets, + op.getMixedSizes(), op.getMixedStrides()); - for (size_t i = 0; i < sgLayout.size(); i++) { - sgLayoutDim[i] = - arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]); - sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); + newOps.push_back(newOp); } + rewriter.replaceOpWithMultiple(op, {newOps}); - int64_t startOfRange = -1, endOfRange = -1; - bool sgIdRangeSpecified = - isSgIdRangeSpecified(op, startOfRange, endOfRange); - - Value adjustedSgId = linearSgId; - if (sgIdRangeSpecified) { - int64_t sgCount = endOfRange - startOfRange; - if (computeProduct(sgLayout) != sgCount) - return rewriter.notifyMatchFailure( - op, "sg_layout size must match the sg_id_range"); - // Subtract startOfRange from the original subgroup id to get the adjusted - // sg id - Value startOfRangeVal = - arith::ConstantIndexOp::create(rewriter, loc, startOfRange); - adjustedSgId = - rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal); - } + return success(); + } +}; + +// This pattern transforms the CreateNdDescOp without offsets to create a +// subgroup descriptor from a workgroup descriptor +struct WgToSgCreateNdOpNoOffset + : public OpConversionPattern<xegpu::CreateNdDescOp> { + using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern; - auto deLinearizeSgId = - affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim); - if (failed(deLinearizeSgId)) + LogicalResult + matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Check no offsets are specified. 
+ if (!op.getMixedOffsets().empty()) return failure(); - SmallVector<Value> sgIds = *deLinearizeSgId; - - // Calculate distribution unit shape and local offsets for subgroup - SmallVector<int64_t> distUnitShape(sgLayout.size()); - SmallVector<Value> localOffset(sgLayout.size()); - for (size_t i = 0; i < sgLayout.size(); i++) { - distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]); - localOffset[i] = - rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]); - } - SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets(); + Location loc = op.getLoc(); + MLIRContext *ctx = op.getContext(); + xegpu::TensorDescType tdescTy = op.getType(); + auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + Type elemTy = tdescTy.getElementType(); + ArrayRef<int64_t> wgShape = tdescTy.getShape(); + + SmallVector<int64_t> sgShape; + int count; + std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); xegpu::TensorDescType newTdescTy = xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), layout.dropSgLayoutAndData()); - SmallVector<Value> newCreateNdOps; - for (SmallVector<int64_t> distUnitBaseAddr : - StaticTileOffsetRange(wgShape, distUnitShape)) { - SmallVector<OpFoldResult> globalOffsets = - calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset, - distUnitBaseAddr, distUnitShape); - - auto newCreateNdOp = xegpu::CreateNdDescOp::create( - rewriter, loc, newTdescTy, op.getSource(), globalOffsets, - op.getMixedSizes(), op.getMixedStrides()); - newCreateNdOps.push_back(newCreateNdOp); - } + + SmallVector<Value> newCreateNdOps(count); + std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() { + return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy, + op.getSource(), op.getMixedSizes(), + op.getMixedStrides()); + }); rewriter.replaceOpWithMultiple(op, {newCreateNdOps}); return success(); @@ -256,12 +257,10 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> { LogicalResult matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector<Value> newLoadOps; - - int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); - if ((offsetSize != 0) || op.getConstOffsetsAttr()) + if (!op.getMixedOffsets().empty()) return failure(); + SmallVector<Value> newLoadOps; for (auto src : adaptor.getTensorDesc()) { xegpu::TensorDescType tdescTy = dyn_cast<xegpu::TensorDescType>(src.getType()); @@ -284,9 +283,7 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> { LogicalResult matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - - int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); - if ((offsetSize != 0) || op.getConstOffsetsAttr()) + if (!op.getMixedOffsets().empty()) return failure(); for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc())) @@ -298,6 +295,84 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> { } }; +// This pattern transforms the LoadNdOp with explicit offsets to load +// subgroup data. 
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { + using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + SmallVector<Value> newOps; + for (auto [tdesc, offsets] : + llvm::zip(adaptor.getTensorDesc(), offsetsList)) { + auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType()); + VectorType newResTy = + VectorType::get(tdescTy.getShape(), tdescTy.getElementType()); + auto newOp = xegpu::LoadNdOp::create( + rewriter, op.getLoc(), newResTy, tdesc, offsets, + /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(), + op.getL2HintAttr(), op.getL3HintAttr()); + newOps.push_back(newOp); + } + rewriter.replaceOpWithMultiple(op, {newOps}); + + return success(); + } +}; + +// This pattern transforms the StoreNdOp with explicit offsets to store +// subgroup data. +struct WgToSgStoreNdOpWithOffset + : public OpConversionPattern<xegpu::StoreNdOp> { + using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + for (auto [v, tdesc, offsets] : + llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) { + rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, tdesc, offsets, + op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); + } + rewriter.eraseOp(op); + + return success(); + } +}; + +// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch +// subgroup data. +struct WgToSgPrefetchNdOpWithOffset + : public OpConversionPattern<xegpu::PrefetchNdOp> { + using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + for (auto [tdesc, offsets] : + llvm::zip(adaptor.getTensorDesc(), offsetsList)) { + rewriter.create<xegpu::PrefetchNdOp>( + op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); + } + rewriter.eraseOp(op); + + return success(); + } +}; + /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a /// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the /// offsets of the new subgroup src tensor descriptors. 
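Several patterns in this file (WgToSgCreateNdOp, WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, WgToSgLoadMatrixOp) size their per-subgroup values through getSgShapeAndCount. A minimal standalone model of the shape half of that helper is sketched below with invented layout values; the distribution-unit clamping and the resulting count are omitted, and even divisibility is assumed.

#include <cassert>
#include <cstdint>
#include <vector>

// Simplified model: the subgroup tile shape is sg_data when it is provided,
// otherwise the workgroup shape divided element-wise by sg_layout.
std::vector<int64_t> deriveSgShape(const std::vector<int64_t> &wgShape,
                                   const std::vector<int64_t> &sgLayout,
                                   const std::vector<int64_t> &sgData) {
  if (!sgData.empty())
    return sgData;
  assert(wgShape.size() == sgLayout.size() && "rank mismatch");
  std::vector<int64_t> sgShape;
  for (size_t i = 0; i < wgShape.size(); ++i)
    sgShape.push_back(wgShape[i] / sgLayout[i]); // assumes even divisibility
  return sgShape;
}

int main() {
  // A 256x128 workgroup tile with sg_layout = [8, 4] and no sg_data yields a
  // 32x32 tile per subgroup.
  const auto sgShape = deriveSgShape({256, 128}, {8, 4}, {});
  return (sgShape[0] == 32 && sgShape[1] == 32) ? 0 : 1;
}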
@@ -331,7 +406,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> { if (resultTy.getRank() != 2) return failure(); - auto originalLayout = xegpu::getLayoutAttr(op.getResult()); + auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult()); if (!originalLayout) return failure(); @@ -354,8 +429,8 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> { VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]}, resultTy.getElementType()); tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands); - xegpu::setLayoutAttr(cast<OpResult>(tmpC), - originalLayout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC), + originalLayout.dropSgLayoutAndData()); newDpasOps.push_back(tmpC); } @@ -395,8 +470,9 @@ struct WgToSgVectorBroadcastOp VectorType resultType = op.getResult().getType(); ArrayRef<int64_t> wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); - if (!layout || !layout.getSgLayout()) + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) return failure(); // TODO: Currently only supports cases where the source and result ranks @@ -411,10 +487,8 @@ struct WgToSgVectorBroadcastOp VectorType::get(sgShape, resultType.getElementType()); // Check if the output layout is distributable - SmallVector<int64_t> sgLayout; - if (auto sgLayoutAttr = layout.getSgLayout()) - sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef()); - else + SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt(); + if (sgLayout.empty()) return failure(); if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout)) @@ -433,8 +507,8 @@ struct WgToSgVectorBroadcastOp for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), newResultType, operand); - xegpu::setLayoutAttr(newBroadcast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), + layout.dropSgLayoutAndData()); newBroadcastOps.push_back(newBroadcast.getResult()); } @@ -460,8 +534,9 @@ struct WgToSgElementwiseOp : public ConversionPattern { ArrayRef<int64_t> wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0)); - if (!layout || !layout.getSgLayout()) + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op->getResult(0)); + if (!layout || !layout.isForWorkgroup()) return failure(); SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; @@ -526,8 +601,8 @@ struct WgToSgElementwiseOp : public ConversionPattern { // is lowered to: // #a = #xegpu.layout<inst_data = [16, 16]> // #b = #xegpu.layout<inst_data = [8, 16]> -// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32> -// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32> +// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32> +// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32> // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32> // clang-format on struct WgToSgConvertLayoutOp @@ -536,10 +611,12 @@ struct WgToSgConvertLayoutOp LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - xegpu::LayoutAttr input = op.getInputLayout(); - 
xegpu::LayoutAttr target = op.getTargetLayout(); + // TODO: currently, we only support LayoutAttr + auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout()); + auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout()); - if (!input || !target || !input.isWgLayout() || !target.isWgLayout()) + if (!input || !target || !input.isForWorkgroup() || + !target.isForWorkgroup()) return rewriter.notifyMatchFailure( op, "Input and target layouts must have subgroup layout"); @@ -649,16 +726,213 @@ struct UnrealizedConversionCastOpPattern } }; +// This pattern distributes arith.constant op into subgroup-level constants +struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { + using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue()); + auto vecType = dyn_cast<VectorType>(op.getType()); + if (!vecAttr || !vecAttr.isSplat() || !vecType) + return failure(); + + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + ArrayRef<int64_t> wgShape = vecType.getShape(); + SmallVector<int64_t> sgShape; + int count; + std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); + + // Current limitation: constant of vector with single value. + // TODO: support more complex cases, e.g., vector with multiple values. + Attribute singleVal = vecAttr.getSplatValue<Attribute>(); + + auto newType = VectorType::get(sgShape, vecType.getElementType()); + auto sgAttr = DenseElementsAttr::get(newType, singleVal); + auto cstOp = + arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); + if (auto newLayout = layout.dropSgLayoutAndData()) + xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout); + SmallVector<Value> newConsts(count, cstOp); + + rewriter.replaceOpWithMultiple(op, {newConsts}); + return success(); + } +}; + +// This pattern transforms the LoadGatherOp with explicit offsets to load +// subgroup data +struct WgToSgLoadGatherOpWithOffset + : public OpConversionPattern<xegpu::LoadGatherOp> { + using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getOffsets()) + return failure(); + + Location loc = op.getLoc(); + VectorType resultType = dyn_cast<VectorType>(op.getResult().getType()); + if (!resultType) + return failure(); + ArrayRef<int64_t> wgShape = resultType.getShape(); + + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + + // The offsets need to be distributed + auto offsetsVecType = + dyn_cast<VectorType>(adaptor.getOffsets().front().getType()); + auto maskVecType = + dyn_cast<VectorType>(adaptor.getMask().front().getType()); + if (!offsetsVecType || !maskVecType || + offsetsVecType.getShape() != maskVecType.getShape()) { + return rewriter.notifyMatchFailure(op, + "offsets have not been distributed"); + } + + SmallVector<Value> newLoadOps; + auto chunkSizeAttr = + rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1)); + VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); + for (auto 
[offsets, mask] : + llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { + auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>( + loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); + xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), + layout.dropSgLayoutAndData()); + newLoadOps.push_back(newLoadOp); + } + rewriter.replaceOpWithMultiple(op, {newLoadOps}); + return success(); + } +}; + +// This pattern transforms the StoreScatterOp with explicit offsets to store +// subgroup data +struct WgToSgStoreScatterOpWithOffset + : public OpConversionPattern<xegpu::StoreScatterOp> { + using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getOffsets()) + return failure(); + + Location loc = op.getLoc(); + VectorType valueType = dyn_cast<VectorType>(op.getValue().getType()); + if (!valueType) + return failure(); + + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getValue()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + // The offsets need to be distributed + auto offsetsVecType = + dyn_cast<VectorType>(adaptor.getOffsets().front().getType()); + auto maskVecType = + dyn_cast<VectorType>(adaptor.getMask().front().getType()); + if (!offsetsVecType || !maskVecType || + offsetsVecType.getShape() != maskVecType.getShape()) { + return rewriter.notifyMatchFailure(op, + "offsets have not been distributed"); + } + + auto chunkSizeOpt = op.getChunkSize(); + int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1; + auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); + for (auto [val, offs, mask] : llvm::zip( + adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { + rewriter.create<xegpu::StoreScatterOp>( + loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(), + op.getL2HintAttr(), op.getL3HintAttr()); + // Update the layout attribute to drop sg_layout and sg_data. 
+ if (auto newLayout = layout.dropSgLayoutAndData()) + op->setAttr("layout", newLayout); + } + rewriter.eraseOp(op); + return success(); + } +}; + +struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> { + using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + ArrayRef<int64_t> wgShape = op.getDataShape(); + VectorType valueTy = op.getRes().getType(); + Type elemTy = valueTy.getElementType(); + + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType newResTy = VectorType::get(sgShape, elemTy); + SmallVector<Value> newOps; + for (auto offsets : offsetsList) { + auto newOp = rewriter.create<xegpu::LoadMatrixOp>( + op.getLoc(), newResTy, op.getMemDesc(), offsets, + layout.dropSgLayoutAndData()); + newOps.push_back(newOp); + } + rewriter.replaceOpWithMultiple(op, {newOps}); + + return success(); + } +}; + +struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> { + using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList)) + rewriter.create<xegpu::StoreMatrixOp>(op.getLoc(), v, op.getMemDesc(), + offsets, + layout.dropSgLayoutAndData()); + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace namespace mlir { namespace xegpu { void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { - patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, - WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, - UnrealizedConversionCastOpPattern, WgToSgElementwiseOp, - WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>( - patterns.getContext()); + patterns + .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp, + WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset, + WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, + WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern, + WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp, + WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, + WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, + WgToSgStoreMatrixOp>(patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -748,8 +1022,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return xegpu::TensorDescType(); }; - auto isLegal = [&](xegpu::LayoutAttr layout) -> bool { - return !layout || !layout.isWgLayout(); + auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool { + return !layout || !layout.isForWorkgroup(); }; target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp, @@ -761,13 +1035,46 @@ void XeGPUWgToSgDistributePass::runOnOperation() { }); target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool { - auto layout = xegpu::getLayoutAttr(op.getResult()); + auto layout = xegpu::getDistributeLayoutAttr(op.getResult()); return 
isLegal(layout); }); + target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>( + [=](xegpu::LoadMatrixOp op) -> bool { + return isLegal(op.getLayoutAttr()); + }); + + target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>( + [=](xegpu::StoreMatrixOp op) -> bool { + return isLegal(op.getLayoutAttr()); + }); + + target.addDynamicallyLegalOp<arith::ConstantOp>( + [=](arith::ConstantOp op) -> bool { + auto vecType = dyn_cast<VectorType>(op.getType()); + if (!vecType) + return true; + return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); + }); + + target.addDynamicallyLegalOp<xegpu::LoadGatherOp>( + [=](xegpu::LoadGatherOp op) -> bool { + auto layout = xegpu::getDistributeLayoutAttr(op.getResult()); + return isLegal(layout); + }); + + target.addDynamicallyLegalOp<xegpu::StoreScatterOp>( + [=](xegpu::StoreScatterOp op) -> bool { + // Check if the layout attribute is present on the result. + auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout"); + if (!layout) + return true; + return isLegal(layout); + }); + target.addDynamicallyLegalOp<vector::BroadcastOp>( [=](vector::BroadcastOp op) -> bool { - return isLegal(xegpu::getLayoutAttr(op.getResult())); + return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); }); target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>( @@ -795,7 +1102,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() { } } - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0)); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op->getResult(0)); return isLegal(layout); }); diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt index 98e84a4..d9bf4a1 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt @@ -7,5 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUUtils LINK_LIBS PUBLIC MLIRIR MLIRSCFTransforms + MLIRGPUDialect + MLIRXeVMDialect MLIRXeGPUDialect ) diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index 2cf21fb..cac1ffe 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -11,6 +11,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" @@ -38,7 +41,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) { auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout()); // It only works for subgroup level layout, which only has lane_layout // and lane_data, and is to distribute a SIMD code into SIMT code. 
- if (!layout || !layout.isSgLayout()) + if (!layout || !layout.isForSubgroup()) return failure(); SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef()); @@ -111,7 +114,7 @@ std::string xegpu::getLayoutName(const OpResult result) { return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str(); } -xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) { +xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) { if (!value) return nullptr; @@ -129,11 +132,11 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) { // for LoadNdOp, the layout is stored in the tensor descriptor if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp)) - return getLayoutAttr(loadNd.getTensorDesc()); + return getDistributeLayoutAttr(loadNd.getTensorDesc()); std::string layoutName = getLayoutName(result); if (defOp->hasAttr(layoutName)) - return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName); + return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); } if (auto arg = dyn_cast<BlockArgument>(value)) { @@ -141,49 +144,51 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) { if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) { OpOperand *tiedInit = loop.getTiedLoopInit(arg); if (tiedInit) - return getLayoutAttr(tiedInit->get()); + return getDistributeLayoutAttr(tiedInit->get()); } } return nullptr; } -xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) { +xegpu::DistributeLayoutAttr +xegpu::getDistributeLayoutAttr(const OpOperand &opr) { Operation *op = opr.getOwner(); std::string layoutName = xegpu::getLayoutName(opr); if (op->hasAttr(layoutName)) - return op->getAttrOfType<xegpu::LayoutAttr>(layoutName); - return getLayoutAttr(opr.get()); + return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); + return getDistributeLayoutAttr(opr.get()); } template <typename T, typename> -void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) { +void xegpu::setDistributeLayoutAttr(const T &operandOrResult, + const DistributeLayoutAttr layout) { Operation *owner = operandOrResult.getOwner(); std::string name = xegpu::getLayoutName(operandOrResult); - if (layout && !owner->hasAttrOfType<LayoutAttr>(name)) + if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name)) owner->setAttr(name, layout); } // Explicit instantiation for OpResult -template void -xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result, - const mlir::xegpu::LayoutAttr layout); +template void xegpu::setDistributeLayoutAttr<mlir::OpResult>( + const mlir::OpResult &result, + const mlir::xegpu::DistributeLayoutAttr layout); // Explicit instantiation for OpOperand -template void -xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand, - const mlir::xegpu::LayoutAttr layout); +template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>( + const mlir::OpOperand &operand, + const mlir::xegpu::DistributeLayoutAttr layout); -void xegpu::setLayoutAttrs(Operation *op, - function_ref<LayoutAttr(Value)> getLayoutImpl) { +void xegpu::setDistributeLayoutAttrs( + Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) { op->walk([&](Operation *nestOp) { for (OpOperand &opr : nestOp->getOpOperands()) { auto layout = getLayoutImpl(opr.get()); - setLayoutAttr(opr, layout); + setDistributeLayoutAttr(opr, layout); } for (OpResult result : nestOp->getOpResults()) { auto layout = getLayoutImpl(result); - setLayoutAttr(result, layout); + setDistributeLayoutAttr(result, layout); } }); } @@ -192,7 +197,7 @@ template 
<typename T, typename> void xegpu::removeLayoutAttr(const T &operandOrResult) { Operation *owner = operandOrResult.getOwner(); std::string name = xegpu::getLayoutName(operandOrResult); - if (owner->hasAttrOfType<LayoutAttr>(name)) + if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) owner->removeAttr(name); } @@ -303,7 +308,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( if (!inputTy || !resultTy) return WalkResult::skip(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(input); if (!layout) return WalkResult::skip(); @@ -341,7 +347,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( } { // perform the conversion from RankedTensorType to VectorType based on the - // LayoutAttr + // DistributeLayoutAttr // Handle the UnrealizedConversionCastOp introduced by the first step. // For vector->RankedTensorType, it will simply forward the inputs. @@ -404,3 +410,49 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( (void)mlir::applyPartialConversion(op, target, std::move(patterns)); } } + +std::optional<std::string> xegpu::getChipStr(Operation *op) { + auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>(); + + if (!gpuModuleOp) + return std::nullopt; + + auto targetAttrs = gpuModuleOp.getTargets(); + if (targetAttrs) { + for (auto &attr : *targetAttrs) { + auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr); + if (xevmAttr) + return xevmAttr.getChip().str(); + } + } + + return std::nullopt; +} + +/// Generates element-wise addition ops of two arrays with automatic alignment. +/// When the input arrays have different sizes, the shorter array is +/// right-aligned with the longer array, and the unmatched leading elements from +/// the longer array are preserved unchanged. This is commonly used for offset +/// computation where higher-dimensional offsets need to be added to +/// lower-dimensional adjustments. +/// +/// Example: +/// lhs = [l1, l2, l3], rhs = [r1, r2] +/// Result: [l1, l2+r1, l3+r2] +SmallVector<OpFoldResult> +xegpu::addWithRightAligned(OpBuilder &builder, Location loc, + ArrayRef<OpFoldResult> lhs, + ArrayRef<OpFoldResult> rhs) { + // ensure a is the longer (or equal-length) array + ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs; + ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs; + SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size())); + a = a.slice(a.size() - b.size()); + for (auto [l, r] : llvm::zip(a, b)) { + auto lval = getValueOrCreateConstantIndexOp(builder, loc, l); + auto rval = getValueOrCreateConstantIndexOp(builder, loc, r); + results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval)); + } + return results; +}
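As a quick check of the right-alignment rule documented above, here is an integer-only sketch of the same behavior; the real helper operates on OpFoldResults and folds constant indices via createOrFold.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Right-aligned element-wise add: the shorter vector is aligned to the end of
// the longer one, and the unmatched leading elements pass through unchanged.
std::vector<int64_t> addWithRightAligned(std::vector<int64_t> lhs,
                                         std::vector<int64_t> rhs) {
  if (lhs.size() < rhs.size())
    std::swap(lhs, rhs);
  std::vector<int64_t> out(lhs.begin(), lhs.end() - rhs.size());
  for (size_t i = 0; i < rhs.size(); ++i)
    out.push_back(lhs[lhs.size() - rhs.size() + i] + rhs[i]);
  return out;
}

int main() {
  // [10, 20, 30] + [1, 2]  ->  [10, 21, 32]
  const auto r = addWithRightAligned({10, 20, 30}, {1, 2});
  std::printf("[%lld, %lld, %lld]\n", static_cast<long long>(r[0]),
              static_cast<long long>(r[1]), static_cast<long long>(r[2]));
  return 0;
}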