Diffstat (limited to 'mlir/lib/Dialect')
127 files changed, 11075 insertions, 1399 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index df955fc..b7a665b 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -55,6 +55,10 @@ void AMDGPUDialect::initialize() { #define GET_OP_LIST #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" + >(); addAttributes< #define GET_ATTRDEF_LIST #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" @@ -339,19 +343,45 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns( } //===----------------------------------------------------------------------===// -// ScaledExtPacked816Op +// ScaledExtPackedMatrixOp //===----------------------------------------------------------------------===// -LogicalResult ScaledExtPacked816Op::verify() { +LogicalResult ScaledExtPackedMatrixOp::verify() { int blockSize = getBlockSize(); - assert((blockSize == 16 || blockSize == 32) && "invalid block size"); + assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size"); + int firstScaleByte = getFirstScaleByte(); - if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) { - return emitOpError( - "blockSize of 16 can only have firstScaleByte be 0 or 1."); - } - if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) { - return emitOpError( - "blockSize of 32 can only have firstScaleByte be 0 or 2."); + int firstScaleLane = getFirstScaleLane(); + auto sourceType = cast<VectorType>(getSource().getType()); + Type elementType = sourceType.getElementType(); + auto floatType = cast<FloatType>(elementType); + unsigned bitWidth = floatType.getWidth(); + + assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth)); + + const bool is_fp8 = bitWidth == 8; + const bool is_block_16 = blockSize == 16; + + if (!is_fp8) { + if (is_block_16) { + if (!llvm::is_contained({0, 1}, firstScaleByte)) { + return emitOpError("blockSize of 16 can only have firstScaleByte be 0 " + "or 1 for f4 and f6."); + } + } else { + if (!llvm::is_contained({0, 2}, firstScaleByte)) { + return emitOpError("blockSize of 32 can only have firstScaleByte be 0 " + "or 2 for f4 and f6."); + } + } + } else { + if (is_block_16) { + bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) || + ((firstScaleLane == 16) && (firstScaleByte == 2)); + if (!is_valid) { + return emitOpError("blockSize of 16 can only have (firstScaleLane, " + "firstScaleByte) be (0, 0) or (16, 2) for f8."); + } + } } return success(); @@ -567,6 +597,53 @@ LogicalResult PermlaneSwapOp::verify() { } //===----------------------------------------------------------------------===// +// MemoryCounterWaitOp +//===----------------------------------------------------------------------===// + +namespace { +/// Fuse adjacent memory counter wait ops, taking the minimum value of the +/// counters. 
+struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> { + using Base::Base; + + LogicalResult matchAndRewrite(MemoryCounterWaitOp op, + PatternRewriter &rewriter) const override { + auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode()); + if (!next) + return failure(); + + auto setters = {&MemoryCounterWaitOp::setLoad, + &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs, + &MemoryCounterWaitOp::setExp, + &MemoryCounterWaitOp::setTensor}; + auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(), + op.getTensor()}; + auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(), + next.getExp(), next.getTensor()}; + rewriter.modifyOpInPlace(op, [&] { + for (auto [setter, lhs, rhs] : + llvm::zip_equal(setters, lhsVals, rhsVals)) { + if (lhs && rhs) { + (op.*setter)(std::min(*lhs, *rhs)); + } else if (lhs) { + (op.*setter)(*lhs); + } else if (rhs) { + (op.*setter)(*rhs); + } + } + }); + rewriter.eraseOp(next); + return success(); + } +}; +} // namespace + +void MemoryCounterWaitOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add<FuseMemoryCounterWaitOp>(context); +} + +//===----------------------------------------------------------------------===// // GatherToLDSOp //===----------------------------------------------------------------------===// @@ -662,19 +739,123 @@ LogicalResult TransposeLoadOp::verify() { }; auto validNumElems = kValidLoadSizeMap.find(elementTypeSize); - if (validNumElems == kValidLoadSizeMap.end()) { + if (validNumElems == kValidLoadSizeMap.end()) return emitOpError("Unsupported element type size for transpose load: ") << elementTypeSize << " bits"; - } - if (numElements != validNumElems->second) { + + if (numElements != validNumElems->second) return emitOpError( "Transferring type size mismatch: expected num of elements: ") << validNumElems->second; + + return success(); +} + +//===----------------------------------------------------------------------===// +// MakeDmaBaseOp +//===----------------------------------------------------------------------===// + +LogicalResult MakeDmaBaseOp::verify() { + + auto ldsType = cast<MemRefType>(getLds().getType()); + auto globalType = cast<MemRefType>(getGlobal().getType()); + if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace())) + return emitOpError( + "lds memref must have workgroup address space attribute."); + if (!hasGlobalMemorySpace(globalType.getMemorySpace())) + return emitOpError( + "global memref must have global address space attribute."); + + Type elementType = ldsType.getElementType(); + unsigned width = elementType.getIntOrFloatBitWidth(); + + if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width)) + return emitOpError( + "element type must be 1, 2, 4, or 8 bytes long but type was ") + << width << " bits long."; + + return success(); +} + +//===----------------------------------------------------------------------===// +// MakeDmaDescriptorOp +//===----------------------------------------------------------------------===// + +LogicalResult MakeDmaDescriptorOp::verify() { + ArrayRef<int64_t> globalStaticStrides = getGlobalStaticStrides(); + + if (globalStaticStrides.empty()) + return emitOpError("strides must not be empty."); + if (globalStaticStrides.back() != 1) + return emitOpError("strides for the innermost dimension must be 1."); + + ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes(); + size_t rank = globalStaticSizes.size(); + if (rank > 5) + return emitOpError("tensor and tile must be at most of 
rank 5."); + if (rank != globalStaticStrides.size()) + return emitOpError("strides and sizes must have same rank."); + + ArrayRef<int64_t> sharedStaticSizes = getSharedStaticSizes(); + if (rank != sharedStaticSizes.size()) + return emitOpError("tensor must have same rank as tile."); + + unsigned elementTypeWidth = getElementTypeWidth(); + if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidth)) + return emitOpError( + "element type width must be 1, 2, 4 or 8 bytes, but was ") + << elementTypeWidth << " bits long"; + + if (Value atomicBarrierAddress = getAtomicBarrierAddress()) { + auto atomicBarrierAddressType = + cast<MemRefType>(atomicBarrierAddress.getType()); + bool barrierInLDS = + hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace()); + if (!barrierInLDS) + return emitOpError("atomic barrier address must be in LDS."); } + if (getEarlyTimeout() && !getWorkgroupMask()) + return emitOpError( + "early timeout does not apply when workgroup_mask is not set."); return success(); } +OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) { + SmallVector<OpFoldResult> mixedGlobalSizes(getMixedGlobalSizes()); + SmallVector<OpFoldResult> mixedGlobalStrides(getMixedGlobalStrides()); + SmallVector<OpFoldResult> mixedSharedSizes(getMixedSharedSizes()); + + if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true)) && + failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true)) && + failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true))) + return nullptr; + + SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides, + dynamicSharedSizes; + SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides, + staticSharedSizes; + + dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes, + staticGlobalSizes); + setGlobalStaticSizes(staticGlobalSizes); + getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes); + + dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides, + staticGlobalStrides); + setGlobalStaticStrides(staticGlobalStrides); + getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides); + + dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes, + staticSharedSizes); + setSharedStaticSizes(staticSharedSizes); + getSharedDynamicSizesMutable().assign(dynamicSharedSizes); + return getResult(); +} + //===----------------------------------------------------------------------===// // ScaledMFMAOp //===----------------------------------------------------------------------===// @@ -813,5 +994,8 @@ void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results, #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp index f15c63c..89ef51f 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp @@ -33,19 +33,18 @@ using namespace mlir::amdgpu; /// This pattern supports lowering of: `vector.maskedload` to `vector.load` /// and `arith.select` if the memref is in buffer address space. 
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter, - vector::MaskedLoadOp maskedOp) { - auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType()); +static LogicalResult hasBufferAddressSpace(Type type) { + auto memRefType = dyn_cast<MemRefType>(type); if (!memRefType) - return rewriter.notifyMatchFailure(maskedOp, "not a memref source"); + return failure(); Attribute addrSpace = memRefType.getMemorySpace(); if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace)) - return rewriter.notifyMatchFailure(maskedOp, "no address space"); + return failure(); if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() != amdgpu::AddressSpace::FatRawBuffer) - return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space"); + return failure(); return success(); } @@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> { LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp, PatternRewriter &rewriter) const override { if (maskedOp->hasAttr(kMaskedloadNeedsMask)) - return failure(); + return rewriter.notifyMatchFailure(maskedOp, "already rewritten"); - if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) { - return failure(); + if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) { + return rewriter.notifyMatchFailure( + maskedOp, "isn't a load from a fat buffer resource"); } // Check if this is either a full inbounds load or an empty, oob load. If @@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp, PatternRewriter &rewriter) const override { + if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType()))) + return rewriter.notifyMatchFailure( + loadOp, "buffer loads are handled by a more specialized pattern"); + FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask()); if (failed(maybeCond)) { - return failure(); + return rewriter.notifyMatchFailure(loadOp, + "isn't loading a broadcasted scalar"); } Value cond = maybeCond.value(); @@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp, PatternRewriter &rewriter) const override { + // A condition-free implementation of fully masked stores requires + // 1) an accessor for the num_records field on buffer resources/fat pointers + // 2) knowledge that said field will always be set accurately - that is, + // that writes to x < num_records of offset wouldn't trap, which is + // something a pattern user would need to assert or we'd need to prove. + // + // Therefore, conditional stores to buffers still go down this path at + // present. 
+ FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask()); if (failed(maybeCond)) { return failure(); diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp index 4d2d873..3d1a734 100644 --- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp @@ -66,9 +66,10 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos, .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; }) .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; }) .Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; }) + .Case([](arith::XOrIOp) { return arith::AtomicRMWKind::xori; }) + .Case([](arith::MaxNumFOp) { return arith::AtomicRMWKind::maxnumf; }) + .Case([](arith::MinNumFOp) { return arith::AtomicRMWKind::minnumf; }) .Default([](Operation *) -> std::optional<arith::AtomicRMWKind> { - // TODO: AtomicRMW supports other kinds of reductions this is - // currently not detecting, add those when the need arises. return std::nullopt; }); if (!maybeKind) diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp index b405ec2..edfae7e 100644 --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -342,8 +342,7 @@ void FlatAffineValueConstraints::getIneqAsAffineValueMap( if (inequality[pos] > 0) // Lower bound. - std::transform(bound.begin(), bound.end(), bound.begin(), - std::negate<int64_t>()); + llvm::transform(bound, bound.begin(), std::negate<int64_t>()); else // Upper bound (which is exclusive). bound.back() += 1; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 0c35921..c6addfb 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -5421,7 +5421,7 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final return rewriter.notifyMatchFailure(op, "no unit basis entries to replace"); - if (newIndices.size() == 0) { + if (newIndices.empty()) { rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0); return success(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp index c942c02..b04e2d6 100644 --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -82,7 +82,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { ArrayRef<int64_t> oldShape = oldMemRefType.getShape(); SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank()); newShape[0] = 2; - std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1); + llvm::copy(oldShape, newShape.begin() + 1); return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({}); }; diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 4743941..8f1249e 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1711,6 +1711,12 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) { outermost.getBody()->getOperations().splice( Block::iterator(secondOutermostLoop.getOperation()), innermost.getBody()->getOperations()); + for (auto [iter, init] : + llvm::zip_equal(secondOutermostLoop.getRegionIterArgs(), + secondOutermostLoop.getInits())) { + 
iter.replaceAllUsesWith(init); + iter.dropAllUses(); + } secondOutermostLoop.erase(); return success(); } diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 845be20..deba160 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1327,9 +1327,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( assert(cast<MemRefType>(oldMemRef.getType()).getElementType() == cast<MemRefType>(newMemRef.getType()).getElementType()); - std::unique_ptr<DominanceInfo> domInfo; - std::unique_ptr<PostDominanceInfo> postDomInfo; - // Walk all uses of old memref; collect ops to perform replacement. We use a // DenseSet since an operation could potentially have multiple uses of a // memref (although rare), and the replacement later is going to erase ops. diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index de3efc9f..e256915 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -389,8 +389,8 @@ def TruncIExtUIToExtUI : // trunci(shrsi(x, c)) -> trunci(shrui(x, c)) def TruncIShrSIToTrunciShrUI : Pat<(Arith_TruncIOp:$tr - (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow), - (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow), + (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0), $exact), $overflow), + (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)), $exact), $overflow), [(TruncationMatchesShiftAmount $x, $tr, $c0)]>; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index adeb50b..c4e81e5 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -35,7 +35,7 @@ static Value createConst(Location loc, Type type, int value, } /// Create a float constant. 
-static Value createFloatConst(Location loc, Type type, APFloat value, +static Value createFloatConst(Location loc, Type type, const APFloat &value, PatternRewriter &rewriter) { auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast<ShapedType>(type)) { diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp index 39e398b..cb7c3d7 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp @@ -150,7 +150,7 @@ public: rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask()); } - auto extOp = op.getLhs().getDefiningOp(); + auto *extOp = op.getLhs().getDefiningOp(); arm_sme::CombiningKind kind = op.getKind(); if (kind == arm_sme::CombiningKind::Add) { @@ -311,8 +311,8 @@ public: rhsMask = packInputs(rhs0Mask, rhs1Mask); } - auto lhsExtOp = op.getLhs().getDefiningOp(); - auto rhsExtOp = op.getRhs().getDefiningOp(); + auto *lhsExtOp = op.getLhs().getDefiningOp(); + auto *rhsExtOp = op.getRhs().getDefiningOp(); arm_sme::CombiningKind kind = op.getKind(); if (kind == arm_sme::CombiningKind::Add) { diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index 8e4a49d..e19b917 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -17,8 +17,6 @@ using namespace mlir::async; #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc" -constexpr StringRef AsyncDialect::kAllowedToBlockAttrName; - void AsyncDialect::initialize() { addOperations< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index e0cf353..9b11270 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -680,16 +680,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const { return false; } -// bufferization.to_buffer is not allowed to change the rank. -static void ensureToBufferOpIsValid(Value tensor, Type memrefType) { -#ifndef NDEBUG - auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType()); - assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() == - rankedTensorType.getRank()) && - "to_buffer would be invalid: mismatching ranks"); -#endif -} - FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state) { @@ -708,7 +698,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value, FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state); if (failed(bufferType)) return failure(); - ensureToBufferOpIsValid(value, *bufferType); + return bufferization::ToBufferOp::create(rewriter, value.getLoc(), *bufferType, value) .getResult(); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index 6c08cdf..bd177ba 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -21,25 +21,6 @@ using namespace mlir::bufferization; #include "mlir/Dialect/Bufferization/IR/BufferizationOpsDialect.cpp.inc" -/// Attribute name used to mark function arguments who's buffers can be written -/// to during One-Shot Module Bufferize. 
-constexpr const ::llvm::StringLiteral BufferizationDialect::kWritableAttrName; - -/// Attribute name used to mark the bufferization layout for region arguments -/// during One-Shot Module Bufferize. -constexpr const ::llvm::StringLiteral - BufferizationDialect::kBufferLayoutAttrName; - -/// An attribute that can be attached to ops with an allocation and/or -/// deallocation side effect. It indicates that the op is under a "manual -/// deallocation" scheme. In the case of an allocation op, the returned -/// value is *not* an automatically managed allocation and assigned an -/// ownership of "false". Furthermore, only deallocation ops that are -/// guaranteed to deallocate a buffer under "manual deallocation" are -/// allowed to have this attribute. (Deallocation ops without this -/// attribute are rejected by the ownership-based buffer deallocation pass.) -constexpr const ::llvm::StringLiteral BufferizationDialect::kManualDeallocation; - //===----------------------------------------------------------------------===// // Bufferization Dialect Interfaces //===----------------------------------------------------------------------===// @@ -73,9 +54,6 @@ struct BuiltinTensorExternalModel mlir::LogicalResult verifyCompatibleBufferType( mlir::Type tensor, BufferLikeType bufferType, llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const { - assert(isa<TensorType>(tensor) && "expected tensor type"); - assert(isa<BaseMemRefType>(bufferType) && "expected memref type"); - auto tensorType = cast<ShapedType>(tensor); auto memrefType = cast<ShapedType>(bufferType); diff --git a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp index 51feec7..f8eb45c 100644 --- a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp +++ b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp @@ -17,6 +17,10 @@ // Pipeline implementation. //===----------------------------------------------------------------------===// +void mlir::bufferization::buildBufferDeallocationPipeline(OpPassManager &pm) { + buildBufferDeallocationPipeline(pm, BufferDeallocationPipelineOptions()); +} + void mlir::bufferization::buildBufferDeallocationPipeline( OpPassManager &pm, const BufferDeallocationPipelineOptions &options) { memref::ExpandReallocPassOptions expandAllocPassOptions{ @@ -44,5 +48,7 @@ void mlir::bufferization::registerBufferizationPipelines() { "The default pipeline for automatically inserting deallocation " "operations after one-shot bufferization. 
Deallocation operations " "(except `memref.realloc`) may not be present already.", - buildBufferDeallocationPipeline); + [](OpPassManager &pm, const BufferDeallocationPipelineOptions &options) { + buildBufferDeallocationPipeline(pm, options); + }); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index b9ee0a4..d0742ec 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -217,7 +217,9 @@ updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, } if (!options.filterFn(&callee)) return; - if (callee.isExternal() || callee.isPublic()) + if (callee.isPublic() && !options.modifyPublicFunctions) + return; + if (callee.isExternal()) return; SmallVector<Value, 6> replaceWithNewCallResults; @@ -295,7 +297,9 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( // function. AllocDynamicSizesMap map; for (auto func : module.getOps<func::FuncOp>()) { - if (func.isExternal() || func.isPublic()) + if (func.isPublic() && !options.modifyPublicFunctions) + continue; + if (func.isExternal()) continue; if (!options.filterFn(&func)) continue; @@ -326,6 +330,8 @@ struct BufferResultsToOutParamsPass options.hoistStaticAllocs = true; if (hoistDynamicAllocs) options.hoistDynamicAllocs = true; + if (modifyPublicFunctions) + options.modifyPublicFunctions = true; if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), options))) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index 1784964..677c0ba 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dominance.h" #include "mlir/Interfaces/SubsetOpInterface.h" +#include "mlir/Transforms/RegionUtils.h" namespace mlir { namespace bufferization { @@ -105,8 +106,13 @@ Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter, // this replacement. Operation *insertionPoint = findValidInsertionPoint(emptyTensorOp, user, neededValues); - if (!insertionPoint) - return {}; + if (!insertionPoint) { + // If no already suitable insertion point was found, attempt to move all + // needed values before the user. + if (failed(moveValueDefinitions(rewriter, neededValues, user))) + return {}; + insertionPoint = user; + } rewriter.setInsertionPoint(insertionPoint); Value replacement = diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index 9ccbfd3..5dfe3e6 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -497,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, // terminates. All of them must be equivalent subsets. 
SetVector<Value> backwardSlice = state.findValueInReverseUseDefChain(opOperand, matchingSubset); - return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset)); + return llvm::all_of(backwardSlice, matchingSubset); } /// Return "true" if the given "read" and potentially conflicting "write" are diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt index 58551bb..05a787f 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt @@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect MLIRControlFlowInterfaces MLIRIR MLIRSideEffectInterfaces + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index f1da1a1..d2078d8 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -445,6 +446,37 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { return success(replaced); } }; + +/// If the destination block of a conditional branch contains only +/// ub.unreachable, unconditionally branch to the other destination. +struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> { + using OpRewritePattern<CondBranchOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + // If the "true" destination is unreachable, branch to the "false" + // destination. + Block *trueDest = condbr.getTrueDest(); + Block *falseDest = condbr.getFalseDest(); + if (llvm::hasSingleElement(*trueDest) && + isa<ub::UnreachableOp>(trueDest->getTerminator())) { + rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest, + condbr.getFalseOperands()); + return success(); + } + + // If the "false" destination is unreachable, branch to the "true" + // destination. 
+ if (llvm::hasSingleElement(*falseDest) && + isa<ub::UnreachableOp>(falseDest->getTerminator())) { + rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, + condbr.getTrueOperands()); + return success(); + } + + return failure(); + } +}; } // namespace void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -452,7 +484,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, SimplifyCondBranchIdenticalSuccessors, SimplifyCondBranchFromCondBranchOnSameCondition, - CondBranchTruthPropagation>(context); + CondBranchTruthPropagation, DropUnreachableCondBranch>(context); } SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp index 173d58b..da572f1 100644 --- a/mlir/lib/Dialect/DLTI/DLTI.cpp +++ b/mlir/lib/Dialect/DLTI/DLTI.cpp @@ -606,11 +606,6 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringRef> keys, return dlti::query(op, entryKeys, emitError); } -constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutAttrName; -constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessKey; -constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessBig; -constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessLittle; - namespace { class TargetDataLayoutInterface : public DataLayoutDialectInterface { public: diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 0992ce14..b0566dd 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -226,6 +226,21 @@ FailureOr<SmallVector<ReplacementItem>> parseFormatString( } //===----------------------------------------------------------------------===// +// AddressOfOp +//===----------------------------------------------------------------------===// + +LogicalResult AddressOfOp::verify() { + emitc::LValueType referenceType = getReference().getType(); + emitc::PointerType resultType = getResult().getType(); + + if (referenceType.getValueType() != resultType.getPointee()) + return emitOpError("requires result to be a pointer to the type " + "referenced by operand"); + + return success(); +} + +//===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -380,6 +395,20 @@ LogicalResult emitc::ConstantOp::verify() { OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } //===----------------------------------------------------------------------===// +// DereferenceOp +//===----------------------------------------------------------------------===// + +LogicalResult DereferenceOp::verify() { + emitc::PointerType pointerType = getPointer().getType(); + + if (pointerType.getPointee() != getResult().getType().getValueType()) + return emitOpError("requires result to be an lvalue of the type " + "pointed to by operand"); + + return success(); +} + +//===----------------------------------------------------------------------===// // ExpressionOp //===----------------------------------------------------------------------===// @@ -584,6 +613,10 @@ void ForOp::print(OpAsmPrinter &p) { LogicalResult ForOp::verifyRegions() { // Check that the body defines as single block argument for the induction // variable. 
+ if (getBody()->getNumArguments() != 1) + return emitOpError("expected body to have a single block argument for the " + "induction variable"); + if (getInductionVar().getType() != getLowerBound().getType()) return emitOpError( "expected induction variable to be same type as bounds and step"); diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp index b4cb093..d6dfd02 100644 --- a/mlir/lib/Dialect/Func/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp @@ -254,3 +254,28 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp, return std::make_pair(*newFuncOpOrFailure, newCallOp); } + +FailureOr<func::FuncOp> +func::lookupFnDecl(SymbolOpInterface symTable, StringRef name, + FunctionType funcT, SymbolTableCollection *symbolTables) { + FuncOp func; + if (symbolTables) { + func = symbolTables->lookupSymbolIn<FuncOp>( + symTable, StringAttr::get(symTable->getContext(), name)); + } else { + func = llvm::dyn_cast_or_null<FuncOp>( + SymbolTable::lookupSymbolIn(symTable, name)); + } + + if (!func) + return func; + + mlir::FunctionType foundFuncT = func.getFunctionType(); + // Assert the signature of the found function is same as expected + if (funcT != foundFuncT) { + return func.emitError("matched function '") + << name << "' but with different type: " << foundFuncT + << " (expected " << funcT << ")"; + } + return func; +} diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 6c6d8d2..61a630a 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; } StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); } bool MMAMatrixType::isValidElementType(Type elementType) { - return elementType.isF16() || elementType.isF32() || + return elementType.isF16() || elementType.isF32() || elementType.isF64() || elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) || elementType.isInteger(32); } @@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, if (!MMAMatrixType::isValidElementType(elementType)) return emitError() - << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32"; + << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64"; return success(); } diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt index ec68acf..85b7b1ce 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_dialect_library(MLIRGPUPipelines MLIRNVVMToLLVM MLIRReconcileUnrealizedCasts MLIRSCFToControlFlow + MLIRVectorToLLVMPass MLIRVectorToSCF MLIRXeGPUTransforms MLIRXeGPUToXeVM diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp index 2c3e466..5462cdd 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp @@ -72,6 +72,7 @@ void buildGpuPassPipeline(OpPassManager &pm, ConvertGpuOpsToNVVMOpsOptions opt; opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv; opt.indexBitwidth = options.indexBitWidth; + opt.allowPatternRollback = options.allowPatternRollback; pm.addNestedPass<gpu::GPUModuleOp>(createConvertGpuOpsToNVVMOps(opt)); pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); 
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp index 1a1485b..38313dc 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp @@ -63,13 +63,20 @@ void buildGPUPassPipeline(OpPassManager &pm, if (options.xegpuOpLevel == "workgroup") { pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute()); pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); + xegpu::XeGPUPropagateLayoutOptions layoutOptions; + layoutOptions.layoutKind = "inst"; + pm.addNestedPass<gpu::GPUModuleOp>( + xegpu::createXeGPUPropagateLayout(layoutOptions)); pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking()); pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); } if (options.xegpuOpLevel == "subgroup" || options.xegpuOpLevel == "workgroup") { - pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout()); + xegpu::XeGPUPropagateLayoutOptions layoutOptions; + layoutOptions.layoutKind = "lane"; + pm.addNestedPass<gpu::GPUModuleOp>( + xegpu::createXeGPUPropagateLayout(layoutOptions)); pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute()); pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); @@ -104,8 +111,11 @@ void buildPostGPUCommonPassPipeline( pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions)); } pm.addPass(createLowerAffinePass()); + pm.addPass(createConvertVectorToLLVMPass()); pm.addPass(createConvertToLLVMPass()); pm.addPass(createReconcileUnrealizedCastsPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); // gpu-module-to-binary { GpuModuleToBinaryPassOptions gpuToModuleBinOptions; diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp index cd13840..70d2e11 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -143,8 +143,8 @@ private: }; /// Erases `executeOp` and returns a clone with additional `results`. -async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, - ValueRange results) { +static async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, + ValueRange results) { // Add values to async.yield op. 
Operation *yieldOp = executeOp.getBody()->getTerminator(); yieldOp->insertOperands(yieldOp->getNumOperands(), results); diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp index 3c44733..95d5cad 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp @@ -39,10 +39,10 @@ void GpuModuleToBinaryPass::runOnOperation() { RewritePatternSet patterns(&getContext()); auto targetFormat = llvm::StringSwitch<std::optional<CompilationTarget>>(compilationTarget) - .Cases("offloading", "llvm", CompilationTarget::Offload) - .Cases("assembly", "isa", CompilationTarget::Assembly) - .Cases("binary", "bin", CompilationTarget::Binary) - .Cases("fatbinary", "fatbin", CompilationTarget::Fatbin) + .Cases({"offloading", "llvm"}, CompilationTarget::Offload) + .Cases({"assembly", "isa"}, CompilationTarget::Assembly) + .Cases({"binary", "bin"}, CompilationTarget::Binary) + .Cases({"fatbinary", "fatbin"}, CompilationTarget::Fatbin) .Default(std::nullopt); if (!targetFormat) getOperation()->emitError() << "Invalid format specified."; diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp index 212ccc9..8d10aac 100644 --- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp @@ -169,7 +169,7 @@ LogicalResult getSegmentSizes(Operation *op, StringRef elemName, LogicalResult getOperandSegmentSizes(Operation *op, ArrayRef<Variadicity> variadicities, SmallVectorImpl<int> &segmentSizes) { - return getSegmentSizes(op, "operand", "operand_segment_sizes", + return getSegmentSizes(op, "operand", "operandSegmentSizes", op->getNumOperands(), variadicities, segmentSizes); } @@ -180,7 +180,7 @@ LogicalResult getOperandSegmentSizes(Operation *op, LogicalResult getResultSegmentSizes(Operation *op, ArrayRef<Variadicity> variadicities, SmallVectorImpl<int> &segmentSizes) { - return getSegmentSizes(op, "result", "result_segment_sizes", + return getSegmentSizes(op, "result", "resultSegmentSizes", op->getNumResults(), variadicities, segmentSizes); } diff --git a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp index 183d0e3..887e8e1 100644 --- a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Transforms/InliningUtils.h" using namespace mlir; using namespace mlir::index; @@ -15,10 +16,23 @@ using namespace mlir::index; //===----------------------------------------------------------------------===// // IndexDialect //===----------------------------------------------------------------------===// +namespace { +/// This class defines the interface for handling inlining for index +/// dialect operations. +struct IndexInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + /// All index dialect ops can be inlined. 
+ bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } +}; +} // namespace void IndexDialect::initialize() { registerAttributes(); registerOperations(); + addInterfaces<IndexInlinerInterface>(); declarePromisedInterface<ConvertToLLVMPatternInterface, IndexDialect>(); } diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index cc66fac..a73f0c1 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLLVMDialect MLIRControlFlowInterfaces MLIRDataLayoutInterfaces MLIRFunctionInterfaces + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR MLIRMemorySlotInterfaces diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index feaffa3..160b6ae 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16"; static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; +static constexpr llvm::StringRef kPrintApFloat = "printApFloat"; static constexpr llvm::StringRef kPrintString = "printString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; @@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } +FailureOr<LLVM::LLVMFuncOp> +mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { + return lookupOrCreateReservedFn( + b, moduleOp, kPrintApFloat, + {IntegerType::get(moduleOp->getContext(), 32), + IntegerType::get(moduleOp->getContext(), 64)}, + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); +} + static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { return LLVM::LLVMPointerType::get(context); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index b8331e0..9f87e50 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -219,11 +219,16 @@ bool TBAANodeAttr::classof(Attribute attr) { MemoryEffectsAttr MemoryEffectsAttr::get(MLIRContext *context, ArrayRef<ModRefInfo> memInfoArgs) { if (memInfoArgs.empty()) - return MemoryEffectsAttr::get(context, ModRefInfo::ModRef, - ModRefInfo::ModRef, ModRefInfo::ModRef); - if (memInfoArgs.size() == 3) + return MemoryEffectsAttr::get(context, /*other=*/ModRefInfo::ModRef, + /*argMem=*/ModRefInfo::ModRef, + /*inaccessibleMem=*/ModRefInfo::ModRef, + /*errnoMem=*/ModRefInfo::ModRef, + /*targetMem0=*/ModRefInfo::ModRef, + /*targetMem1=*/ModRefInfo::ModRef); + if (memInfoArgs.size() == 6) return MemoryEffectsAttr::get(context, memInfoArgs[0], memInfoArgs[1], - memInfoArgs[2]); + memInfoArgs[2], memInfoArgs[3], + memInfoArgs[4], memInfoArgs[5]); return {}; } @@ -234,6 +239,12 @@ bool MemoryEffectsAttr::isReadWrite() { return false; if (this->getOther() != ModRefInfo::ModRef) return false; + if (this->getErrnoMem() != ModRefInfo::ModRef) + return false; + if (this->getTargetMem0() != ModRefInfo::ModRef) + return false; + if (this->getTargetMem1() != ModRefInfo::ModRef) + return false; return true; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp 
b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 2731069..5b81948 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -640,8 +640,6 @@ SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { // Code for LLVM::GEPOp. //===----------------------------------------------------------------------===// -constexpr int32_t GEPOp::kDynamicIndex; - GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() { return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(), getDynamicIndices()); @@ -4226,6 +4224,34 @@ LogicalResult InlineAsmOp::verify() { } //===----------------------------------------------------------------------===// +// UDivOp +//===----------------------------------------------------------------------===// +Speculation::Speculatability UDivOp::getSpeculatability() { + // X / 0 => UB + Value divisor = getRhs(); + if (matchPattern(divisor, m_IntRangeWithoutZeroU())) + return Speculation::Speculatable; + + return Speculation::NotSpeculatable; +} + +//===----------------------------------------------------------------------===// +// SDivOp +//===----------------------------------------------------------------------===// +Speculation::Speculatability SDivOp::getSpeculatability() { + // This function conservatively assumes that all signed division by -1 are + // not speculatable. + // X / 0 => UB + // INT_MIN / -1 => UB + Value divisor = getRhs(); + if (matchPattern(divisor, m_IntRangeWithoutZeroS()) && + matchPattern(divisor, m_IntRangeWithoutNegOneS())) + return Speculation::Speculatable; + + return Speculation::NotSpeculatable; +} + +//===----------------------------------------------------------------------===// // LLVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index ce93d18..5dc4fa2 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -667,6 +667,7 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries, static constexpr llvm::StringRef kSpirvPrefix = "spirv."; static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount"; +static constexpr llvm::StringRef kAMDGCNNamedBarrier = "amdgcn.named.barrier"; bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const { // See llvm/lib/IR/Type.cpp for reference. 
@@ -676,6 +677,9 @@ bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const { properties |= (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal); + if (getExtTypeName() == kAMDGCNNamedBarrier) + properties |= LLVMTargetExtType::CanBeGlobal; + return (properties & prop) == prop; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index a5ffb9e..5ce56e6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -31,6 +31,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/NVVMIntrinsicUtils.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/NVPTXAddrSpace.h" @@ -48,6 +49,47 @@ using namespace NVVM; static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic; //===----------------------------------------------------------------------===// +// Helper/Utility methods +//===----------------------------------------------------------------------===// + +static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) { + auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType()); + return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS); +} + +static bool isPtrInGenericSpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic); +} + +static bool isPtrInSharedCTASpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared); +} + +static bool isPtrInSharedClusterSpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster); +} + +static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder, + llvm::Value *ptr, + NVVMMemorySpace targetAS) { + unsigned AS = static_cast<unsigned>(targetAS); + return builder.CreateAddrSpaceCast( + ptr, llvm::PointerType::get(builder.getContext(), AS)); +} + +// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in LLVM +static llvm::nvvm::CTAGroupKind +getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) { + switch (ctaGroup) { + case NVVM::CTAGroupKind::CTA_1: + return llvm::nvvm::CTAGroupKind::CG_1; + case NVVM::CTAGroupKind::CTA_2: + return llvm::nvvm::CTAGroupKind::CG_2; + } + llvm_unreachable("unsupported cta_group value"); +} + +//===----------------------------------------------------------------------===// // Verifier methods //===----------------------------------------------------------------------===// @@ -199,6 +241,83 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() { return success(); } +LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() { + bool isSharedCTA = isPtrInSharedCTASpace(getDstMem()); + if (isSharedCTA && getMulticastMask()) + return emitError("Multicast is not supported with shared::cta mode."); + + return success(); +} + +static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, + NVVM::MemScopeKind scope, + Value retVal = nullptr) { + if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER) + return op->emitError("mbarrier scope must be either CTA or Cluster"); + + bool isSharedCluster = isPtrInSharedClusterSpace(addr); + bool hasRetValue = static_cast<bool>(retVal); + if (isSharedCluster && hasRetValue) + return op->emitError( + "mbarrier in shared_cluster space cannot return any value"); + + return success(); +} + +LogicalResult MBarrierArriveOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + 
getRes()); +} + +LogicalResult MBarrierArriveDropOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierArriveExpectTxOp::verify() { + // The inline-ptx version of this Op does not support all features. + // With predicate, this Op lowers to inline-ptx. So, verify and + // error-out if there are unsupported features. + if (getPredicate()) { + if (getScope() != NVVM::MemScopeKind::CTA) + return emitError("mbarrier scope must be CTA when using predicate"); + + if (isPtrInSharedClusterSpace(getAddr())) + return emitError("mbarrier in shared_cluster space is not supported when " + "using predicate"); + + if (getRes()) + return emitError("return-value is not supported when using predicate"); + + if (getRelaxed() == true) + return emitError("mbarrier with relaxed semantics is not supported when " + "using predicate"); + } + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierArriveDropExpectTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierExpectTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierCompleteTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierTestWaitOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierTryWaitOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + LogicalResult ConvertFloatToTF32Op::verify() { using RndMode = NVVM::FPRoundingMode; switch (getRnd()) { @@ -365,6 +484,108 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() { return success(); } +LogicalResult PermuteOp::verify() { + using Mode = NVVM::PermuteMode; + bool hasHi = static_cast<bool>(getHi()); + + switch (getMode()) { + case Mode::DEFAULT: + case Mode::F4E: + case Mode::B4E: + if (!hasHi) + return emitError("mode '") + << stringifyPermuteMode(getMode()) << "' requires 'hi' operand."; + break; + case Mode::RC8: + case Mode::ECL: + case Mode::ECR: + case Mode::RC16: + if (hasHi) + return emitError("mode '") << stringifyPermuteMode(getMode()) + << "' does not accept 'hi' operand."; + break; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// Stochastic Rounding Conversion Ops +//===----------------------------------------------------------------------===// + +static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, + FPRoundingMode rnd, + bool hasRandomBits, + Operation *op) { + static constexpr FPRoundingMode validRndModes[] = { + FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS}; + + if (!llvm::is_contained(validRndModes, rnd)) { + return op->emitOpError( + "Only RN, RZ, and RS rounding modes are supported for " + "conversions from f32x2 to ") + << dstType << "."; + } + + if (rnd == FPRoundingMode::RS) { + if (!hasRandomBits) { + return op->emitOpError("random_bits is required for RS rounding mode."); + } + } else { + if (hasRandomBits) { + return op->emitOpError( + "random_bits not supported for RN and RZ rounding modes."); + } + } + + return success(); +} + +LogicalResult ConvertF32x2ToF16x2Op::verify() { + return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(), + getRandomBits() ? 
true : false, *this); +} + +LogicalResult ConvertF32x2ToBF16x2Op::verify() { + return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(), + getRandomBits() ? true : false, *this); +} + +LogicalResult ConvertF32x4ToF8x4Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) + << " types are supported for conversions from f32x4 to f8x4."; + + return success(); +} + +LogicalResult ConvertF32x4ToF6x4Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) + return emitOpError("Only ") + << mlir::Float6E2M3FNType::get(ctx) << " and " + << mlir::Float6E3M2FNType::get(ctx) + << " types are supported for conversions from f32x4 to f6x4."; + + return success(); +} + +LogicalResult ConvertF32x4ToF4x4Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy())) + return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx) + << " type is supported for conversions from " + "f32x4 to f4x4."; + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -866,16 +1087,517 @@ LogicalResult MmaOp::verify() { return success(); } -LogicalResult ShflOp::verify() { - if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid")) +MMATypes MmaSpOp::accumPtxType() { + std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType( + getODSOperands(2).getTypes().front(), /*isAccumulator=*/true); + assert(val.has_value() && "accumulator PTX type should always be inferrable"); + return val.value(); +} + +MMATypes MmaSpOp::resultPtxType() { + std::optional<mlir::NVVM::MMATypes> val = + MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true); + assert(val.has_value() && "result PTX type should always be inferrable"); + return val.value(); +} + +mlir::NVVM::IDArgPair +MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MmaSpOp>(op); + + // Get operands + llvm::SmallVector<llvm::Value *> args; + for (mlir::Value v : thisOp.getOperands()) + args.push_back(mt.lookupValue(v)); + + // Get intrinsic ID using the existing getIntrinsicID method + auto intId = MmaSpOp::getIntrinsicID( + thisOp.getShape().getM(), thisOp.getShape().getN(), + thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(), + thisOp.getOrderedMetadata(), thisOp.getKind(), + *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(), + thisOp.accumPtxType(), thisOp.resultPtxType()); + + return {intId, args}; +} + +void MmaSpOp::print(OpAsmPrinter &p) { + SmallVector<Type, 4> regTypes; + struct OperandFragment { + StringRef operandName; + StringRef ptxTypeAttr; + SmallVector<Value, 4> regs; + explicit OperandFragment(StringRef name, StringRef ptxTypeName) + : operandName(name), ptxTypeAttr(ptxTypeName) {} + }; + + std::array<OperandFragment, 5> frags{ + OperandFragment("A", getMultiplicandAPtxTypeAttrName()), + OperandFragment("B", getMultiplicandBPtxTypeAttrName()), + OperandFragment("C", ""), OperandFragment("sparseMetadata", ""), + OperandFragment("selector", "")}; + SmallVector<StringRef, 4> ignoreAttrNames{ + mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()}; + + // Handle variadic operands A, B, C + for 
(unsigned fragIdx = 0; fragIdx < 3; fragIdx++) { + auto &frag = frags[fragIdx]; + auto varOperandSpec = getODSOperandIndexAndLength(fragIdx); + for (auto operandIdx = varOperandSpec.first; + operandIdx < varOperandSpec.first + varOperandSpec.second; + operandIdx++) { + frag.regs.push_back(this->getOperand(operandIdx)); + if (operandIdx == varOperandSpec.first) { + regTypes.push_back(this->getOperand(operandIdx).getType()); + } + } + std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType( + regTypes.back(), /*isAccumulator=*/fragIdx >= 2); + if (inferredType) + ignoreAttrNames.push_back(frag.ptxTypeAttr); + } + + // Handle sparse metadata and selector (single operands) + frags[3].regs.push_back(getSparseMetadata()); + frags[4].regs.push_back(getSparsitySelector()); + + auto printMmaSpOperand = [&](const OperandFragment &frag) -> void { + p << " " << frag.operandName; + p << "["; + p.printOperands(frag.regs); + p << "]"; + }; + + for (const auto &frag : frags) + printMmaSpOperand(frag); + + p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames); + p << " : "; + p << "("; + for (int i = 0; i < 3; ++i) { + p << regTypes[i]; + if (i < 2) + p << ", "; + } + p << ") -> " << getResult().getType(); +} + +void MmaSpOp::build( + OpBuilder &builder, OperationState &result, Type resultType, + ValueRange operandA, ValueRange operandB, ValueRange operandC, + Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape, + std::optional<MMAIntOverflow> intOverflow, + std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) { + + assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); + MLIRContext *ctx = builder.getContext(); + result.addAttribute( + "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2])); + + result.addOperands(operandA); + result.addOperands(operandB); + result.addOperands(operandC); + result.addOperands(sparseMetadata); + result.addOperands(sparsitySelector); + + if (multiplicandPtxTypes) { + result.addAttribute("multiplicandAPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); + result.addAttribute("multiplicandBPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); + } else { + if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false)) + result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res)); + if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false)) + result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res)); + } + + if (intOverflow.has_value()) + result.addAttribute("intOverflowBehavior", + MMAIntOverflowAttr::get(ctx, *intOverflow)); + + result.addTypes(resultType); + result.addAttribute( + MmaSpOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()), + static_cast<int32_t>(operandB.size()), + static_cast<int32_t>(operandC.size()), 1, + 1})); // sparseMetadata and sparsitySelector +} + +ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) { + struct OperandFragment { + std::optional<MMATypes> elemtype; + SmallVector<OpAsmParser::UnresolvedOperand, 4> regs; + SmallVector<Type> regTypes; + }; + + Builder &builder = parser.getBuilder(); + std::array<OperandFragment, 6> frags; // A, B, C, sparseMetadata, selector + + NamedAttrList namedAttributes; + + // A helper to parse the operand segments. 
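+  // A hedged example of the textual form this parser accepts; the operand
+  // names, element types, and attribute spelling below are illustrative
+  // assumptions, not taken from a test in this patch:
+  //
+  //   nvvm.mma.sp A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+  //       sparseMetadata[%meta] selector[%sel]
+  //       {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+  //       : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
+  //       -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>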
+ auto parseMmaSpOperand = [&](StringRef operandName, + OperandFragment &frag) -> LogicalResult { + if (parser.parseKeyword(operandName).failed()) + return failure(); + if (parser + .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare) + .failed()) + return failure(); return success(); - auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType()); - auto elementType = (type && type.getBody().size() == 2) - ? llvm::dyn_cast<IntegerType>(type.getBody()[1]) - : nullptr; - if (!elementType || elementType.getWidth() != 1) - return emitError("expected return type to be a two-element struct with " - "i1 as the second element"); + }; + + // Parse the operand segments. + if (parseMmaSpOperand("A", frags[0]).failed()) + return failure(); + if (parseMmaSpOperand("B", frags[1]).failed()) + return failure(); + if (parseMmaSpOperand("C", frags[2]).failed()) + return failure(); + if (parseMmaSpOperand("sparseMetadata", frags[3]).failed()) + return failure(); + if (parseMmaSpOperand("selector", frags[4]).failed()) + return failure(); + + if (parser.parseOptionalAttrDict(namedAttributes).failed()) + return failure(); + + // Parse the type specification and resolve operands. + SmallVector<Type, 3> operandTypes; + if (failed(parser.parseColon())) + return failure(); + if (failed(parser.parseLParen())) + return failure(); + if (failed(parser.parseTypeList(operandTypes))) + return failure(); + if (failed(parser.parseRParen())) + return failure(); + if (operandTypes.size() != 3) + return parser.emitError( + parser.getNameLoc(), + "expected one type for each operand segment but got " + + Twine(operandTypes.size()) + " types"); + for (const auto &iter : llvm::enumerate(operandTypes)) { + auto &frag = frags[iter.index()]; + frag.regTypes.resize(frag.regs.size(), iter.value()); + if (failed(parser.resolveOperands(frag.regs, frag.regTypes, + parser.getNameLoc(), result.operands))) + return failure(); + frag.elemtype = + MmaOp::inferOperandMMAType(frag.regTypes[0], + /*isAccumulator*/ iter.index() >= 2); + } + + Type resultType; + if (parser.parseArrow() || parser.parseType(resultType)) + return failure(); + frags[5].elemtype = + MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true); + + // Resolve sparse metadata and selector (assume i32 type) + Type i32Type = builder.getIntegerType(32); + if (parser + .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(), + result.operands) + .failed()) + return failure(); + if (parser + .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(), + result.operands) + .failed()) + return failure(); + + std::array<StringRef, 2> names{"multiplicandAPtxType", + "multiplicandBPtxType"}; + for (unsigned idx = 0; idx < names.size(); idx++) { + const auto &frag = frags[idx]; + std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]); + if (!frag.elemtype.has_value() && !attr.has_value()) { + return parser.emitError( + parser.getNameLoc(), + "attribute " + names[idx] + + " is not provided explicitly and cannot be inferred"); + } + if (!attr.has_value()) + result.addAttribute( + names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype)); + } + + result.addTypes(resultType); + if (!namedAttributes.empty()) + result.addAttributes(namedAttributes); + result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast<int32_t>(frags[0].regs.size()), + static_cast<int32_t>(frags[1].regs.size()), + static_cast<int32_t>(frags[2].regs.size()), + 1, // sparseMetadata + 1 // sparsitySelector 
+ })); + return success(); +} + +LogicalResult MmaSpOp::verify() { + MLIRContext *context = getContext(); + auto f16Ty = Float16Type::get(context); + auto i32Ty = IntegerType::get(context, 32); + auto f16x2Ty = VectorType::get(2, f16Ty); + auto f32Ty = Float32Type::get(context); + auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( + context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + + auto s32x4StructTy = + LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty}); + auto f32x8StructTy = + LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty)); + auto f16x2x2StructTy = + LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty}); + auto f32x4StructTy = + LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty}); + auto s32x2StructTy = + LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty}); + + std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(), + getShapeAttr().getK()}; + + // These variables define the set of allowed data types for matrices A, B, C, + // and result. + using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>; + using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>; + AllowedShapes allowedShapes; + AllowedTypes expectedA; + AllowedTypes expectedB; + AllowedTypes expectedC; + SmallVector<Type> expectedResult; + + // When M = 16, we just need to calculate the number of 8xk tiles, where + // k is a factor that depends on the data type. + if (mmaShape[0] == 16) { + int64_t kFactor; + Type multiplicandFragType; + switch (*getMultiplicandAPtxType()) { + case MMATypes::tf32: + kFactor = 4; + multiplicandFragType = i32Ty; + expectedResult.push_back(LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty})); + // Sparse MMA supports m16n8k8 and m16n8k16 for tf32 + allowedShapes.push_back({16, 8, 8}); + allowedShapes.push_back({16, 8, 16}); + break; + case MMATypes::bf16: + kFactor = 8; + multiplicandFragType = i32Ty; + expectedResult.push_back(LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty})); + // Sparse MMA supports m16n8k16 and m16n8k32 for bf16 + allowedShapes.push_back({16, 8, 16}); + allowedShapes.push_back({16, 8, 32}); + break; + case MMATypes::f16: + kFactor = 8; + multiplicandFragType = f16x2Ty; + expectedResult.push_back(f16x2x2StructTy); + expectedResult.push_back(f32x4StructTy); + // Sparse MMA supports m16n8k16 and m16n8k32 for f16 + allowedShapes.push_back({16, 8, 16}); + allowedShapes.push_back({16, 8, 32}); + break; + case MMATypes::s4: + case MMATypes::u4: + kFactor = 32; + // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4 + allowedShapes.push_back({16, 8, 64}); + allowedShapes.push_back({16, 8, 128}); + break; + case MMATypes::s8: + case MMATypes::u8: + kFactor = 16; + // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8 + allowedShapes.push_back({16, 8, 32}); + allowedShapes.push_back({16, 8, 64}); + break; + case MMATypes::e4m3: + case MMATypes::e5m2: + case MMATypes::e3m2: + case MMATypes::e2m3: + case MMATypes::e2m1: + kFactor = 32; + multiplicandFragType = i32Ty; + expectedResult.push_back(f16x2x2StructTy); + expectedResult.push_back(f32x4StructTy); + // Sparse MMA supports m16n8k64 for FP8 types + allowedShapes.push_back({16, 8, 64}); + break; + default: + return emitError("invalid shape or multiplicand type: " + + stringifyEnum(getMultiplicandAPtxType().value())); + } + + if (isIntegerPtxType(getMultiplicandAPtxType().value())) { + expectedResult.push_back(s32x4StructTy); + expectedC.emplace_back(4, i32Ty); + 
multiplicandFragType = i32Ty; + } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 && + *getMultiplicandAPtxType() <= MMATypes::e2m1) { + // FP8 types + expectedC.emplace_back(2, f16x2Ty); + expectedC.emplace_back(4, f32Ty); + } else { + expectedC.emplace_back(2, f16x2Ty); + expectedC.emplace_back(4, f32Ty); + } + + // For sparse MMA, A operand is compressed (2:4 sparsity means half the + // elements) + int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2; + int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor); + expectedA.emplace_back(unitA, multiplicandFragType); + expectedB.emplace_back(unitB, multiplicandFragType); + + if (resultPtxType() != accumPtxType()) + return emitOpError("ctype does not match dtype"); + } + + // In the M=8 case, there is only 1 possible case per data type. + if (mmaShape[0] == 8) { + if (*getMultiplicandAPtxType() == MMATypes::f16) { + expectedA.emplace_back(2, f16x2Ty); + expectedB.emplace_back(2, f16x2Ty); + expectedResult.push_back(f16x2x4StructTy); + expectedResult.push_back(f32x8StructTy); + expectedC.emplace_back(4, f16x2Ty); + expectedC.emplace_back(8, f32Ty); + allowedShapes.push_back({8, 8, 4}); + } + if (*getMultiplicandAPtxType() == MMATypes::f64) { + Type f64Ty = Float64Type::get(context); + expectedA.emplace_back(1, f64Ty); + expectedB.emplace_back(1, f64Ty); + expectedC.emplace_back(2, f64Ty); + expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral( + context, SmallVector<Type>(2, f64Ty))); + allowedShapes.push_back({8, 8, 4}); + } + if (isIntegerPtxType(getMultiplicandAPtxType().value())) { + expectedA.push_back({i32Ty}); + expectedB.push_back({i32Ty}); + expectedC.push_back({i32Ty, i32Ty}); + expectedResult.push_back(s32x2StructTy); + if (isInt4PtxType(getMultiplicandAPtxType().value())) + allowedShapes.push_back({8, 8, 32}); + if (isInt8PtxType(getMultiplicandAPtxType().value())) + allowedShapes.push_back({8, 8, 16}); + } + } + + std::string errorMessage; + llvm::raw_string_ostream errorStream(errorMessage); + + // Check that we matched an existing shape/dtype combination. + if (expectedA.empty() || expectedB.empty() || expectedC.empty() || + !llvm::is_contained(allowedShapes, mmaShape)) { + errorStream << "unimplemented variant for MMA shape <"; + llvm::interleaveComma(mmaShape, errorStream); + errorStream << ">"; + return emitOpError(errorMessage); + } + + // Verify the operand types for segments of A, B, and C operands. 
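+  // Worked example for f16 with shape m16n8k16 (kFactor = 8): the sparse A
+  // operand is halved, so unitA = (16/8) * (16/8) / 2 = 2 and
+  // unitB = (8/8) * (16/8) = 2. The loop below therefore expects A and B to
+  // each be two vector<2xf16> registers and C to be either two vector<2xf16>
+  // registers or four f32 registers.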
+ std::array<StringRef, 3> operandNames{"A", "B", "C"}; + for (const auto &iter : llvm::enumerate( + SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) { + auto spec = this->getODSOperandIndexAndLength(iter.index()); + SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first, + operand_type_begin() + spec.first + + spec.second); + bool match = llvm::is_contained(iter.value(), operandTySeg); + + if (!match) { + errorStream << "Could not match types for the " + << operandNames[iter.index()] + << " operands; expected one of "; + for (const auto &x : iter.value()) { + errorStream << x.size() << "x" << x[0] << " "; + } + errorStream << "but got "; + llvm::interleaveComma(operandTySeg, errorStream); + return emitOpError(errorMessage); + } + } + + // Check the result type + if (!llvm::any_of(expectedResult, [&](Type expectedResultType) { + return expectedResultType == getResult().getType(); + })) { + errorStream + << "Could not match allowed types for the result; expected one of "; + llvm::interleaveComma(expectedResult, errorStream); + errorStream << " but got " << getResult().getType(); + return emitOpError(errorMessage); + } + + // Ensure int4/int8 MMA variants specify the accum overflow behavior + // attribute. + if (isInt4PtxType(*getMultiplicandAPtxType()) || + isInt8PtxType(*getMultiplicandAPtxType())) { + if (!getIntOverflowBehavior()) + return emitOpError("op requires " + + getIntOverflowBehaviorAttrName().strref() + + " attribute"); + } + + // Validate sparse metadata type (should be i32) + if (!getSparseMetadata().getType().isInteger(32)) { + return emitOpError() << "sparse metadata must be i32 type"; + } + + // Validate sparsity selector type (should be i32) + if (!getSparsitySelector().getType().isInteger(32)) { + return emitOpError() << "sparsity selector must be i32 type"; + } + + return success(); +} + +LogicalResult ShflOp::verify() { + auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType()); + + auto verifyTypeError = [&](Twine desc, Type expectedType, + Type actualType) -> LogicalResult { + return emitOpError("expected " + desc + " to be of type ") + << expectedType << " but got " << actualType << " instead"; + }; + + if (returnStructType) { + if (!getReturnValueAndIsValid()) + return emitOpError("\"return_value_and_is_valid\" attribute must be " + "specified when the return type is a struct type"); + + if (returnStructType.getBody().size() != 2) + return emitOpError("expected return type to be a two-element struct"); + + llvm::ArrayRef<Type> returnStruct = returnStructType.getBody(); + auto resultType = returnStruct[0]; + if (resultType != getVal().getType()) + return verifyTypeError("first element in the returned struct", + getVal().getType(), resultType); + + auto predicateType = returnStruct[1]; + if (!predicateType.isInteger(1)) + return verifyTypeError("second element in the returned struct", + mlir::IntegerType::get(getContext(), 1), + predicateType); + } else { + if (getReturnValueAndIsValid()) + return emitOpError("expected return type to be a two-element struct"); + + if (getType() != getVal().getType()) + return verifyTypeError("return type", getVal().getType(), getType()); + } return success(); } @@ -1376,6 +2098,13 @@ bool NVVM::WgmmaMmaAsyncOp::getAsmValues( return true; // Has manual mapping } +LogicalResult NVVM::FenceSyncRestrictOp::verify() { + if (getOrder() != NVVM::MemOrderKind::ACQUIRE && + getOrder() != NVVM::MemOrderKind::RELEASE) + return emitOpError("only acquire and release semantics are supported"); + return 
success(); +} + LogicalResult NVVM::FenceProxyOp::verify() { if (getKind() == NVVM::ProxyKind::TENSORMAP) return emitOpError() << "tensormap proxy is not a supported proxy kind"; @@ -1398,7 +2127,6 @@ LogicalResult NVVM::FenceProxyAcquireOp::verify() { if (getToProxy() != NVVM::ProxyKind::TENSORMAP) return emitOpError("uni-directional proxies only support tensormap " "for to_proxy attribute"); - return success(); } @@ -1410,7 +2138,19 @@ LogicalResult NVVM::FenceProxyReleaseOp::verify() { if (getToProxy() != NVVM::ProxyKind::TENSORMAP) return emitOpError("uni-directional proxies only support tensormap " "for to_proxy attribute"); + return success(); +} + +LogicalResult NVVM::FenceProxySyncRestrictOp::verify() { + if (getOrder() != NVVM::MemOrderKind::ACQUIRE && + getOrder() != NVVM::MemOrderKind::RELEASE) + return emitOpError("only acquire and release semantics are supported"); + if (getFromProxy() != NVVM::ProxyKind::GENERIC) + return emitOpError("only generic is support for from_proxy attribute"); + + if (getToProxy() != NVVM::ProxyKind::async) + return emitOpError("only async is supported for to_proxy attribute"); return success(); } @@ -1426,6 +2166,15 @@ LogicalResult NVVM::BarrierOp::verify() { if (getNumberOfThreads() && !getBarrierId()) return emitOpError( "barrier id is missing, it should be set between 0 to 15"); + + if (getBarrierId() && (getReductionOp() || getReductionPredicate())) + return emitOpError("reduction are only available when id is 0"); + + if ((getReductionOp() && !getReductionPredicate()) || + (!getReductionOp() && getReductionPredicate())) + return emitOpError("reduction predicate and reduction operation must be " + "specified together"); + return success(); } @@ -1577,6 +2326,43 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() { return success(); } +LogicalResult NVVM::ReduxOp::verify() { + mlir::Type reduxType = getType(); + + if (!reduxType.isF32()) { + if (getAbs()) + return emitOpError("abs attribute is supported only for f32 type"); + if (getNan()) + return emitOpError("nan attribute is supported only for f32 type"); + } + + NVVM::ReduxKind kind = getKind(); + switch (kind) { + case NVVM::ReduxKind::ADD: + case NVVM::ReduxKind::AND: + case NVVM::ReduxKind::OR: + case NVVM::ReduxKind::XOR: + case NVVM::ReduxKind::MAX: + case NVVM::ReduxKind::MIN: + case NVVM::ReduxKind::UMAX: + case NVVM::ReduxKind::UMIN: + if (!reduxType.isInteger(32)) + return emitOpError("'") + << stringifyEnum(kind) << "' redux kind unsupported with " + << reduxType << " type. Only supported type is 'i32'."; + break; + case NVVM::ReduxKind::FMIN: + case NVVM::ReduxKind::FMAX: + if (!reduxType.isF32()) + return emitOpError("'") + << stringifyEnum(kind) << "' redux kind unsupported with " + << reduxType << " type. Only supported type is 'f32'."; + break; + } + + return success(); +} + /// Packs the given `field` into the `result`. /// The `result` is 64-bits and each `field` can be 32-bits or narrower. static llvm::Value * @@ -1626,26 +2412,76 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op, //===----------------------------------------------------------------------===// std::string NVVM::MBarrierInitOp::getPtx() { - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace(); - return (addressSpace == NVVMMemorySpace::Shared) - ? std::string("mbarrier.init.shared.b64 [%0], %1;") - : std::string("mbarrier.init.b64 [%0], %1;"); + bool isShared = isPtrInSharedCTASpace(getAddr()); + return isShared ? 
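+  // In the strings below, %0 is expected to bind to the mbarrier address and
+  // %1 to the count operand of this op.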
std::string("mbarrier.init.shared.b64 [%0], %1;") + : std::string("mbarrier.init.b64 [%0], %1;"); +} + +std::string NVVM::MBarrierArriveExpectTxOp::getPtx() { + bool isShared = isPtrInSharedCTASpace(getAddr()); + return isShared + ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;") + : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); +} + +std::string NVVM::MBarrierTryWaitParityOp::getPtx() { + bool isShared = isPtrInSharedCTASpace(getAddr()); + llvm::StringRef space = isShared ? ".shared" : ""; + + return llvm::formatv("{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}", + space); } //===----------------------------------------------------------------------===// // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// +mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::BarrierOp>(op); + llvm::Value *barrierId = thisOp.getBarrierId() + ? mt.lookupValue(thisOp.getBarrierId()) + : builder.getInt32(0); + llvm::Intrinsic::ID id; + llvm::SmallVector<llvm::Value *> args; + if (thisOp.getNumberOfThreads()) { + id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count; + args.push_back(barrierId); + args.push_back(mt.lookupValue(thisOp.getNumberOfThreads())); + } else if (thisOp.getReductionOp()) { + switch (*thisOp.getReductionOp()) { + case NVVM::BarrierReduction::AND: + id = llvm::Intrinsic::nvvm_barrier0_and; + break; + case NVVM::BarrierReduction::OR: + id = llvm::Intrinsic::nvvm_barrier0_or; + break; + case NVVM::BarrierReduction::POPC: + id = llvm::Intrinsic::nvvm_barrier0_popc; + break; + } + args.push_back(mt.lookupValue(thisOp.getReductionPredicate())); + } else { + id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all; + args.push_back(barrierId); + } + + return {id, std::move(args)}; +} + mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::MBarrierInitOp>(op); - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) - .getAddressSpace(); - llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) - ? llvm::Intrinsic::nvvm_mbarrier_init_shared - : llvm::Intrinsic::nvvm_mbarrier_init; + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared + : llvm::Intrinsic::nvvm_mbarrier_init; // Fill the Intrinsic Args llvm::SmallVector<llvm::Value *> args; @@ -1658,16 +2494,353 @@ mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::MBarrierInvalOp>(op); - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) - .getAddressSpace(); - llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared ? 
llvm::Intrinsic::nvvm_mbarrier_inval_shared : llvm::Intrinsic::nvvm_mbarrier_inval; return {id, {mt.lookupValue(thisOp.getAddr())}}; } +mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster}; + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getTxcount())); + + return {IDs[index], std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster}; + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getTxcount())); + + return {IDs[index], std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster}; + auto id = thisOp.getRelaxed() ? 
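+  // Example: cluster scope with a shared::cluster address yields index 3,
+  // selecting nvvm_mbarrier_arrive_scope_cluster_space_cluster, or its
+  // _relaxed_ counterpart when the relaxed flag is set.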
relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // When count is not explicitly specified, the default is 1. + llvm::LLVMContext &ctx = mt.getLLVMContext(); + bool hasCount = static_cast<bool>(thisOp.getCount()); + llvm::Value *count = + hasCount ? mt.lookupValue(thisOp.getCount()) + : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1); + + return {id, {mbar, count}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster}; + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // When count is not explicitly specified, the default is 1. + llvm::LLVMContext &ctx = mt.getLLVMContext(); + bool hasCount = static_cast<bool>(thisOp.getCount()); + llvm::Value *count = + hasCount ? mt.lookupValue(thisOp.getCount()) + : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1); + + return {id, {mbar, count}}; +} + +bool MBarrierArriveExpectTxOp::getAsmValues( + RewriterBase &rewriter, + llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> + &asmValues) { + // Add all the operands but not the attrs to the asmValues list. + // The attrs here are used to generate the right variants for + // intrinsics-lowering. So, we ignore them while generating inline-PTX. + for (auto val : getOperands()) + asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read}); + + return false; +} + +mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 
1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, txcount}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, txcount}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = + isShared ? 
llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete; + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = + isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete; + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op); + bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: isPhaseParity + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, input}}; +} + +mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op); + bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + bool hasTicks = static_cast<bool>(thisOp.getTicks()); + // bit-0: isPhaseParity + // bit-1: Scope + // bit-2: hasTicks + size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) | + (isPhaseParity ? 
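+                 // Example: a 32-bit phase-parity operand, cluster scope, and
+                 // a ticks operand give index 0b111, selecting
+                 // nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta
+                 // from the table below; the _tl_ variants take the extra
+                 // ticks argument.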
1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the mbarrier pointer + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mbar); + args.push_back(mt.lookupValue(thisOp.getStateOrPhase())); + if (hasTicks) + args.push_back(mt.lookupValue(thisOp.getTicks())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + + llvm::Intrinsic::ID id; + if (thisOp.getNoinc()) { + id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared + : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc; + } else { + id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared + : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive; + } + + return {id, {mt.lookupValue(thisOp.getAddr())}}; +} + #define CP_ASYNC_ID_IMPL(mod, size, suffix) \ llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix @@ -1737,11 +2910,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( args.push_back(mt.lookupValue(thisOp.getSrcMem())); args.push_back(mt.lookupValue(thisOp.getSize())); - // Multicast mask, if available. + // Multicast mask for shared::cluster only, if available. mlir::Value multicastMask = thisOp.getMulticastMask(); const bool hasMulticastMask = static_cast<bool>(multicastMask); - llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); - args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused); + const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem()); + if (!isSharedCTA) { + llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); + args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) + : i16Unused); + } // Cache hint, if available. 
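+  // Note: for a shared::cta destination the multicast operand and its i1 flag
+  // are dropped entirely and the ..._global_to_shared_cta intrinsic is chosen
+  // below; for shared::cluster the mask (or an unused i16 zero) plus the i1
+  // flag are kept. The cache hint handled next is appended, together with its
+  // i1 flag, in both cases.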
mlir::Value cacheHint = thisOp.getL2CacheHint(); @@ -1750,11 +2927,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused); // Flag arguments for multicast and cachehint. - args.push_back(builder.getInt1(hasMulticastMask)); + if (!isSharedCTA) + args.push_back(builder.getInt1(hasMulticastMask)); args.push_back(builder.getInt1(hasCacheHint)); llvm::Intrinsic::ID id = - llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; + isSharedCTA + ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta + : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; return {id, std::move(args)}; } @@ -2469,6 +3649,155 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \ }() +NVVM::IDArgPair +ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + static constexpr llvm::Intrinsic::ID rndRNIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rn, + llvm::Intrinsic::nvvm_ff2f16x2_rn_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRZIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rz, + llvm::Intrinsic::nvvm_ff2f16x2_rz_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRSIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rs, + llvm::Intrinsic::nvvm_ff2f16x2_rs_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite, + }; + + unsigned hasRelu = op.getRelu() ? 1 : 0; + unsigned hasSatFinite = + (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0; + // idx: bit-0 - relu + // bit-1 - satfinite + unsigned idx = (hasSatFinite << 1) | hasRelu; + + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(op.getSrcHi())); + args.push_back(mt.lookupValue(op.getSrcLo())); + if (op.getRandomBits()) + args.push_back(mt.lookupValue(op.getRandomBits())); + + switch (op.getRnd()) { + case FPRoundingMode::RN: + return {rndRNIds[idx], std::move(args)}; + case FPRoundingMode::RZ: + return {rndRZIds[idx], std::move(args)}; + case FPRoundingMode::RS: + return {rndRSIds[idx], std::move(args)}; + default: + llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op"); + } +} + +NVVM::IDArgPair +ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + static constexpr llvm::Intrinsic::ID rndRNIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rn, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRZIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rz, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRSIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rs, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite, + }; + + unsigned hasRelu = op.getRelu() ? 1 : 0; + unsigned hasSatFinite = + (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 
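+      // Example: rnd = rs with relu and satfinite set selects
+      // nvvm_ff2bf16x2_rs_relu_satfinite with arguments (srcHi, srcLo,
+      // randomBits); for rn and rz the randomBits argument is omitted.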
1 : 0; + // idx: bit-0 - relu + // bit-1 - satfinite + unsigned idx = (hasSatFinite << 1) | hasRelu; + + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(op.getSrcHi())); + args.push_back(mt.lookupValue(op.getSrcLo())); + if (op.getRandomBits()) + args.push_back(mt.lookupValue(op.getRandomBits())); + + switch (op.getRnd()) { + case FPRoundingMode::RN: + return {rndRNIds[idx], std::move(args)}; + case FPRoundingMode::RZ: + return {rndRZIds[idx], std::move(args)}; + case FPRoundingMode::RS: + return {rndRSIds[idx], std::move(args)}; + default: + llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op"); + } +} + +llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() { + mlir::Type dstTy = getDstTy(); + bool hasRelu = getRelu(); + + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite; + }) + .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op"); + return llvm::Intrinsic::not_intrinsic; + }); +} + +llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() { + mlir::Type dstTy = getDstTy(); + bool hasRelu = getRelu(); + + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite; + }) + .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op"); + return llvm::Intrinsic::not_intrinsic; + }); +} + +llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() { + mlir::Type dstTy = getDstTy(); + bool hasRelu = getRelu(); + + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op"); + return llvm::Intrinsic::not_intrinsic; + }); +} + llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) { auto curOp = cast<NVVM::Tcgen05CpOp>(op); bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2; @@ -2508,6 +3837,9 @@ LogicalResult Tcgen05LdOp::verify() { if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset()) result = emitError("shape 16x32bx2 requires offset argument"); + if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset()) + result = emitError("offset argument is only supported for shape 16x32bx2"); + auto resTy = getRes().getType(); unsigned resLen = isa<VectorType>(resTy) ? 
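+  // Example of the checks above: shape 16x32bx2 without an offset operand and
+  // any other shape with an offset operand are both rejected.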
llvm::cast<VectorType>(resTy).getNumElements() @@ -2751,6 +4083,630 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs( return {intrinsicID, args}; } +mlir::NVVM::IDArgPair +PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::PermuteOp>(op); + NVVM::PermuteMode mode = thisOp.getMode(); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e, + llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8, + llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr, + llvm::Intrinsic::nvvm_prmt_rc16}; + + unsigned modeIndex = static_cast<unsigned>(mode); + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getLo())); + + // Only first 3 modes (Default, f4e, b4e) need the hi operand. + if (modeIndex < 3) + args.push_back(mt.lookupValue(thisOp.getHi())); + + args.push_back(mt.lookupValue(thisOp.getSelector())); + + return {IDs[modeIndex], args}; +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair +Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMAOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + const bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + + using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>; + using CtaGroupArray = std::array<EnableAShiftArray, 2>; + using IsATensorArray = std::array<CtaGroupArray, 2>; + using HasScaleInputDArray = std::array<IsATensorArray, 2>; + using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>; + + // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift] + static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = { + { // without diable output lane + {{// without scale input D + {{ + // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift, + }}}, + }}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift, + }}}}}}}, + // with disable output lane + {{ // without scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2, + notIntrinsic}}}, + {{// cg1 + { + 
llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift, + }, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg2, + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift, + }}}}}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2, + notIntrinsic}}}, + // tensor + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift}, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2, + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift, + }}}}}}}}}; + + llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD()); + bool hasScaleInputD = ScaleInputD != nullptr; + + llvm::Value *DisableOutputLane = + mt.lookupValue(thisOp.getDisableOutputLane()); + bool hasDisableOutputLane = DisableOutputLane != nullptr; + + const unsigned ctaGroup = + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())); + + llvm::Intrinsic::ID ID = + tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor] + [ctaGroup - 1][thisOp.getAShift()]; + + assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp."); + + if (hasScaleInputD) + args.push_back(ScaleInputD); + + if (hasDisableOutputLane) + args.push_back(DisableOutputLane); + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + + if (!hasDisableOutputLane) + args.push_back(builder.getInt32(ctaGroup)); + + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + +static LogicalResult +verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, + NVVM::CTAGroupKind ctaGroup, bool hasAShift, + NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) { + + if (disableOutputLane) { + mlir::VectorType disableOutputLaneType = + cast<mlir::VectorType>(disableOutputLane.getType()); + if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 && + disableOutputLaneType.getNumElements() != 4) || + (ctaGroup == NVVM::CTAGroupKind::CTA_2 && + disableOutputLaneType.getNumElements() != 8)) + return emitError(loc) << "Disable Output Lane of length " + << disableOutputLaneType.getNumElements() + << " is incompatible with CtaGroupAttr"; + } + + if (hasAShift && !isATensor) + return emitError( + loc, "A-shift can be applied only when matrix A is in tensor memory"); + + if (hasAShift == true && (collectorOp == Tcgen05MMACollectorOp::FILL || + collectorOp == Tcgen05MMACollectorOp::USE)) + return emitError( + loc, "Cannot use collector buffer operation fill or use with ashift"); + + return success(); +} + +LogicalResult Tcgen05MMAOp::verify() { + return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()), + getDisableOutputLane(), getCtaGroup(), getAShift(), + getCollectorOp(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.sp functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = 
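+      // The intrinsic table below mirrors the dense Tcgen05MMAOp lowering.
+      // Example: a tensor-memory A operand with ashift, and neither scale_d
+      // nor disable-output-lane operands, selects
+      // nvvm_tcgen05_mma_sp_tensor_ashift for either CTA group.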
cast<NVVM::Tcgen05MMASparseOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + + using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>; + using CtaGroupArray = std::array<EnableAShiftArray, 2>; + using IsATensorArray = std::array<CtaGroupArray, 2>; + using HasScaleInputDArray = std::array<IsATensorArray, 2>; + using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>; + + // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift] + static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = { + { // without diable output lane + {{// without scale input D + {{ + // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift, + }}}, + }}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d, + notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d, + notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift, + }}}}}}}, + // with disable output lane + {{ // without scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2, + notIntrinsic}}}, + {{// cg1 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift, + }, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift, + }}}}}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2, + notIntrinsic}}}, + // tensor + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift}, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift, + }}}}}}}}}; + + llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD()); + bool hasScaleInputD = ScaleInputD != nullptr; + + llvm::Value *DisableOutputLane = + mt.lookupValue(thisOp.getDisableOutputLane()); + bool hasDisableOutputLane = DisableOutputLane != nullptr; + + unsigned ctaGroup = + 
static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())); + + llvm::Intrinsic::ID ID = + tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor] + [ctaGroup - 1][thisOp.getAShift()]; + + assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp."); + + if (hasScaleInputD) + args.push_back(ScaleInputD); + + if (hasDisableOutputLane) + args.push_back(DisableOutputLane); + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + + if (!hasDisableOutputLane) + args.push_back(builder.getInt32(ctaGroup)); + + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + +LogicalResult Tcgen05MMASparseOp::verify() { + return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()), + getDisableOutputLane(), getCtaGroup(), getAShift(), + getCollectorOp(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.block_scale functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getScaleA())); + args.push_back(mt.lookupValue(thisOp.getScaleB())); + args.push_back(builder.getInt32( + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + auto kind = thisOp.getKind(); + auto blockScale = thisOp.getBlockScale(); + llvm::Intrinsic::ID ID = [&]() { + if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor + ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale + : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32; + + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) { + return isATensor + ? 
llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16; + } + } + llvm_unreachable("Invalid tcgen05.mma.block_scale attributes"); + }(); + + return {ID, args}; +} + +static LogicalResult +verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, + NVVM::Tcgen05MMABlockScaleKind kind, + NVVM::Tcgen05MMABlockScale blockScale, + Location loc) { + + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT && + kind == Tcgen05MMABlockScaleKind::MXF4NVF4) + return emitError(loc, "mxf4nvf4 requires block scale attribute"); + + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 && + kind != Tcgen05MMABlockScaleKind::MXF4NVF4) + return emitError(loc, + llvm::formatv("{} kind does not support block16 attribute", + stringifyEnum(kind))); + + return success(); +} + +LogicalResult Tcgen05MMABlockScaleOp::verify() { + return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(), + getBlockScale(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.sp.block_scale functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + args.push_back(mt.lookupValue(thisOp.getScaleA())); + args.push_back(mt.lookupValue(thisOp.getScaleB())); + args.push_back(builder.getInt32( + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + auto kind = thisOp.getKind(); + auto blockScale = thisOp.getBlockScale(); + llvm::Intrinsic::ID ID = [&]() { + if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? 
llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32; + + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16; + } + } + llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes"); + }(); + + return {ID, args}; +} + +LogicalResult Tcgen05MMASparseBlockScaleOp::verify() { + return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(), + getBlockScale(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.ws functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + + mlir::Value ZeroColMask = thisOp.getZeroColMask(); + llvm::Intrinsic::ID ID = notIntrinsic; + if (ZeroColMask) { + args.push_back(mt.lookupValue(ZeroColMask)); + ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask; + } else + ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared; + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.ws.sp functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + + mlir::Value ZeroColMask = thisOp.getZeroColMask(); + llvm::Intrinsic::ID ID = notIntrinsic; + if (ZeroColMask) { + args.push_back(mt.lookupValue(ZeroColMask)); + ID = isATensor + ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask; + } else + ID = isATensor ? 
llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared; + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// @@ -2954,16 +4910,20 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) { "Minimum NVVM target SM version is sm_20"); } - gpuModuleOp->walk([&](Operation *op) { - if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) { - const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion(); - if (!requirement.isCompatibleWith(targetSMVersion)) { - op->emitOpError() << "is not supported on " << getChip(); - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); + if (gpuModuleOp + ->walk([&](Operation *op) { + if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) { + const NVVMCheckSMVersion requirement = + reqOp.getRequiredMinSMVersion(); + if (!requirement.isCompatibleWith(targetSMVersion)) { + op->emitOpError() << "is not supported on " << getChip(); + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }) + .wasInterrupted()) + return failure(); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp index 67573c4..12dd225 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp @@ -109,8 +109,12 @@ static Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr, return FusedLoc::get(context, {loc}, lexicalBlockFileAttr); } +/// Adds DILexicalBlockFileAttr for operations with CallSiteLoc and operations +/// from different files than their containing function. static void setLexicalBlockFileAttr(Operation *op) { - if (auto callSiteLoc = dyn_cast<CallSiteLoc>(op->getLoc())) { + Location opLoc = op->getLoc(); + + if (auto callSiteLoc = dyn_cast<CallSiteLoc>(opLoc)) { auto callerLoc = callSiteLoc.getCaller(); auto calleeLoc = callSiteLoc.getCallee(); LLVM::DIScopeAttr scopeAttr; @@ -122,6 +126,45 @@ static void setLexicalBlockFileAttr(Operation *op) { op->setLoc( CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc)); } + + return; + } + + auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); + if (!funcOp) + return; + + FileLineColLoc opFileLoc = extractFileLoc(opLoc); + if (!opFileLoc) + return; + + FileLineColLoc funcFileLoc = extractFileLoc(funcOp.getLoc()); + if (!funcFileLoc) + return; + + StringRef opFile = opFileLoc.getFilename().getValue(); + StringRef funcFile = funcFileLoc.getFilename().getValue(); + + // Handle cross-file operations: add DILexicalBlockFileAttr when the + // operation's source file differs from its containing function. 
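+  // For example (file names here are hypothetical, purely for illustration):
+  // an op located at loc("util.h":12:3) but nested in a function defined in
+  // "main.cpp" gets wrapped in a FusedLoc whose metadata is a
+  // DILexicalBlockFileAttr with scope = the function's DISubprogramAttr,
+  // file = a DIFileAttr for "util.h", and discriminator = 0, so the emitted
+  // debug info points back at the header the op originated from.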
+ if (opFile != funcFile) { + auto funcOpLoc = llvm::dyn_cast_if_present<FusedLoc>(funcOp.getLoc()); + if (!funcOpLoc) + return; + auto scopeAttr = dyn_cast<LLVM::DISubprogramAttr>(funcOpLoc.getMetadata()); + if (!scopeAttr) + return; + + auto *context = op->getContext(); + LLVM::DIFileAttr opFileAttr = + LLVM::DIFileAttr::get(context, llvm::sys::path::filename(opFile), + llvm::sys::path::parent_path(opFile)); + + LLVM::DILexicalBlockFileAttr lexicalBlockFileAttr = + LLVM::DILexicalBlockFileAttr::get(context, scopeAttr, opFileAttr, 0); + + Location newLoc = FusedLoc::get(context, {opLoc}, lexicalBlockFileAttr); + op->setLoc(newLoc); } } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index dcc1ef9..b4b1347 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { // FillOpInterface implementation //===----------------------------------------------------------------------===// +namespace { enum class MatchFillResult { Success = 0, NotLinalgOp, WrongNumOperands, - NotScalarInput + NotScalarInput, + TypeMismatch }; +} // namespace static MatchFillResult isFillInterfaceImpl(Operation *op) { auto linalgOp = dyn_cast<linalg::LinalgOp>(op); @@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) { if (!linalgOp.isScalar(value)) return MatchFillResult::NotScalarInput; + // Check that the scalar input type matches the output element type. + OpOperand *output = linalgOp.getDpsInitOperand(0); + Type scalarType = value->get().getType(); + Type outputElementType = getElementTypeOrSelf(output->get().getType()); + if (scalarType != outputElementType) + return MatchFillResult::TypeMismatch; + return MatchFillResult::Success; } LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) { - auto res = isFillInterfaceImpl(op); + MatchFillResult res = isFillInterfaceImpl(op); if (res == MatchFillResult::NotLinalgOp) return op->emitError("expected a LinalgOp"); if (res == MatchFillResult::WrongNumOperands) return op->emitError("expected op with 1 input and 1 output"); if (res == MatchFillResult::NotScalarInput) return op->emitError("expected op with scalar input"); + if (res == MatchFillResult::TypeMismatch) { + auto linalgOp = cast<linalg::LinalgOp>(op); + Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType(); + Type outputElementType = + getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType()); + return op->emitOpError("expected fill value type (") + << scalarType << ") to match output element type (" + << outputElementType << ")"; + } return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 3dc45ed..33ec79b 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1338,8 +1338,6 @@ Speculation::Speculatability GenericOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); } -LogicalResult GenericOp::verify() { return success(); } - namespace { /// Remove linalg operations that are just copying the values from inputs to @@ -2091,7 +2089,7 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor, return failure(); // Single dimension transpose. 
- if (getPermutation().size() == 0) { + if (getPermutation().empty()) { result.push_back(getInput()); return success(); } @@ -4885,13 +4883,6 @@ void ElementwiseOp::print(OpAsmPrinter &p) { elidedAttrs); } -LogicalResult ElementwiseOp::verify() { - // All necessary checks are done either by - // - EnumAttr (e.g. unknown operation kind) - // - verifyStructuredOpInterface (incorrect map, sizes). - return success(); -} - /// Implements the block region builder for the ElementwiseOp. This is called by /// 'fillStructuredOpRegion'. void ElementwiseOp::regionBuilder( diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 3a43382..b8c1bad 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -176,7 +176,8 @@ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults( if (auto attr = dyn_cast<Attribute>(paramOrHandle)) { reified.push_back(cast<IntegerAttr>(attr).getInt()); continue; - } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) { + } + if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) { ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle)); if (params.size() != 1) return transformOp.emitSilenceableError() << "expected a single param"; @@ -997,8 +998,11 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, // Iterate over the outputs of the producer and over the loop bbArgs and // check if any bbArg points to the same value as the producer output. In // such case, make the producer output point to the bbArg directly. - for (OpOperand &initOperandPtr : - cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) { + auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone); + if (!dpsInterface) + return; + + for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) { Value producerOperand = clone->getOperand(initOperandPtr.getOperandNumber()); for (BlockArgument containerIterArg : @@ -1060,7 +1064,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, resultNumber, offsets, sizes); // Cleanup clone. - if (dyn_cast<LoopLikeOpInterface>(containingOp)) + if (isa<LoopLikeOpInterface>(containingOp)) rewriter.eraseOp(tileableProducer); return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 22690da..9e6c1e6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); auto rankReducedType = cast<RankedTensorType>( tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( - reassociation->size(), sliceOp.getSourceType(), offsets, sizes, - strides)); + reassociation->size(), sliceOp.getSourceType(), sizes)); Location loc = sliceOp.getLoc(); Value newSlice = tensor::ExtractSliceOp::create( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 05fc7cb..421ab5e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1038,6 +1038,62 @@ private: ControlFusionFn controlFoldingReshapes; }; +/// Carries information about a padded dimension. 
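+///
+/// For example (values are purely illustrative): padding a tensor<8x16xf32>
+/// with low = [0, 2] and high = [0, 3] is described by
+/// paddedShape = [8, 21], lowPad = [0, 2] and highPad = [0, 3].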
+struct PadDimInfo { + // The resulting shape after padding each dimension. + SmallVector<int64_t> paddedShape; + + // Low and high padding amounts for each dimension. + SmallVector<OpFoldResult> lowPad; + SmallVector<OpFoldResult> highPad; +}; + +/// Computes the expanded padding information for the given pad operation based +/// on the provided expanded shape and reassociation indices. Returns a list of +/// PadDimInfo containing the low and high padding amounts and the padded +/// size for each dimension, or failure if the expansion is not possible. +static FailureOr<PadDimInfo> +computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape, + ArrayRef<ReassociationIndices> reassociations, + PatternRewriter &rewriter) { + // If the padding value depends on the index values of the pad operation, + // then it may not be valid to expand the dimensions, since it will change + // the index values on which the padding value depends. This is not currently + // supported by the pad expansion patterns, but it could be implemented + // similarly to the expansion of linalg.generic ops with linalg.index ops in + // the body, as is done in `updateExpandedGenericOpRegion`. + if (!padOp.getConstantPaddingValue()) + return failure(); + + // Expanded dimensions cannot have padding because the resulting padding may + // not be representable by a tensor.pad op. There are some special cases where + // it is possible (like expanding unit dims), but supporting these cases is + // NYI, so disallow it for now. + ArrayRef<int64_t> low = padOp.getStaticLow(); + ArrayRef<int64_t> high = padOp.getStaticHigh(); + for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { + if (reInd.size() != 1 && (l != 0 || h != 0)) + return failure(); + } + + SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad()); + SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad()); + ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape(); + PadDimInfo padDimInfo; + padDimInfo.paddedShape.assign(expandedShape); + padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0)); + padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0)); + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + if (reInd.size() == 1) { + padDimInfo.paddedShape[reInd[0]] = paddedShape[idx]; + padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx]; + padDimInfo.highPad[reInd[0]] = mixedHighPad[idx]; + } + } + + return padDimInfo; +} + class FoldPadWithProducerReshapeOpByExpansion : public OpRewritePattern<tensor::PadOp> { public: @@ -1053,46 +1109,96 @@ public: padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>(); if (!reshapeOp) return failure(); - if (!reshapeOp->hasOneUse()) - return failure(); if (!controlFoldingReshapes(&padOp.getSourceMutable())) { return rewriter.notifyMatchFailure(padOp, "fusion blocked by control function"); } - ArrayRef<int64_t> low = padOp.getStaticLow(); - ArrayRef<int64_t> high = padOp.getStaticHigh(); + RankedTensorType expandedType = reshapeOp.getSrcType(); SmallVector<ReassociationIndices> reassociations = reshapeOp.getReassociationIndices(); + FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding( + padOp, expandedType.getShape(), reassociations, rewriter); + if (failed(maybeExpandedPadding)) + return failure(); + PadDimInfo &expandedPadding = maybeExpandedPadding.value(); - for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { - if (reInd.size() != 1 && (l != 0 || h != 0)) - return failure(); + Location loc = 
padOp->getLoc(); + RankedTensorType expandedPaddedType = + padOp.getResultType().clone(expandedPadding.paddedShape); + + auto newPadOp = tensor::PadOp::create( + rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), + expandedPadding.lowPad, expandedPadding.highPad, + padOp.getConstantPaddingValue(), padOp.getNofold()); + + rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( + padOp, padOp.getResultType(), newPadOp.getResult(), reassociations); + + return success(); + } + +private: + ControlFusionFn controlFoldingReshapes; +}; + +class FoldReshapeWithProducerPadOpByExpansion + : public OpRewritePattern<tensor::ExpandShapeOp> { +public: + FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, + PatternRewriter &rewriter) const override { + tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>(); + if (!padOp) + return failure(); + + if (!controlFoldingReshapes(&expandOp.getSrcMutable())) { + return rewriter.notifyMatchFailure(expandOp, + "fusion blocked by control function"); } - SmallVector<OpFoldResult> newLow, newHigh; - RankedTensorType expandedType = reshapeOp.getSrcType(); - RankedTensorType paddedType = padOp.getResultType(); - SmallVector<int64_t> expandedPaddedShape(expandedType.getShape()); + RankedTensorType expandedType = expandOp.getResultType(); + SmallVector<ReassociationIndices> reassociations = + expandOp.getReassociationIndices(); + FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding( + padOp, expandedType.getShape(), reassociations, rewriter); + if (failed(maybeExpandedPadding)) + return failure(); + PadDimInfo &expandedPadding = maybeExpandedPadding.value(); + + Location loc = expandOp->getLoc(); + SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape(); + SmallVector<int64_t> newExpandedShape(expandedType.getShape()); + rewriter.setInsertionPointAfterValue(padOp.getSource()); + SmallVector<OpFoldResult> padSrcSizes = + tensor::getMixedSizes(rewriter, loc, padOp.getSource()); for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + // We know that any reassociation with multiple dims is not padded because + // of the requirements of computeExpandedPadding. 
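+      // E.g. (illustrative): with reassociations [[0], [1, 2]], only the
+      // first source dim may carry padding, so its expanded size is taken
+      // from the unpadded pad source, while the dims expanded from the
+      // second group keep the expand_shape result sizes, which are already
+      // unpadded.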
if (reInd.size() == 1) { - expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx]; - } - for (size_t i = 0; i < reInd.size(); ++i) { - newLow.push_back(padOp.getMixedLowPad()[idx]); - newHigh.push_back(padOp.getMixedHighPad()[idx]); + newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx); + newExpandedSizes[reInd[0]] = padSrcSizes[idx]; } } - - Location loc = padOp->getLoc(); - RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape); + RankedTensorType newExpandedType = expandedType.clone(newExpandedShape); + auto newExpandOp = tensor::ExpandShapeOp::create( + rewriter, loc, newExpandedType, padOp.getSource(), reassociations, + newExpandedSizes); + RankedTensorType expandedPaddedType = + padOp.getResultType().clone(expandedPadding.paddedShape); + rewriter.setInsertionPoint(expandOp); auto newPadOp = tensor::PadOp::create( - rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh, + rewriter, loc, expandedPaddedType, newExpandOp.getResult(), + expandedPadding.lowPad, expandedPadding.highPad, padOp.getConstantPaddingValue(), padOp.getNofold()); - rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( - padOp, padOp.getResultType(), newPadOp.getResult(), reassociations); + rewriter.replaceOp(expandOp, newPadOp.getResult()); return success(); } @@ -1921,6 +2027,62 @@ private: ControlFusionFn controlFoldingReshapes; }; +/// Computes the collapsed padding information for the given pad operation based +/// on the provided collapsed shape and reassociation indices. Returns a +/// PadDimInfo containing the low and high padding amounts and the collapsed +/// shape for each dimension, or failure if the collapse is not possible. +static FailureOr<PadDimInfo> +computeCollapsedPadding(tensor::PadOp padOp, + ArrayRef<ReassociationIndices> reassociations, + PatternRewriter &rewriter) { + // If the padding value depends on the index values of the pad operation, + // then it may not be valid to collapse the dimensions, since it will change + // the index values on which the padding value depends. This is not currently + // supported by the pad collapsing patterns, but it could be implemented + // similarly to the collapsing of linalg.generic ops with linalg.index ops in + // the body, as is done in `generateCollapsedIndexingRegion`. + if (!padOp.getConstantPaddingValue()) + return failure(); + + // Collapsed dimensions cannot have padding because this can produce strided + // padding that isn't representable by a tensor.pad op. There are some special + // cases where it is possible (like collapsing unit dims), but supporting + // these cases is NYI, so disallow it for now. + ArrayRef<int64_t> low = padOp.getStaticLow(); + ArrayRef<int64_t> high = padOp.getStaticHigh(); + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + for (int64_t dim : reInd) { + if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1) + return failure(); + } + } + + // Initialize padding values for collapsed tensors with zeros + ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape(); + PadDimInfo padDimInfo; + padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0)); + padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0)); + + // Update padding for dimensions that are not being collapsed, and compute + // the collapsed padded shape. 
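+  // E.g. (illustrative): collapsing the result of a pad that turns
+  // tensor<4x3x5xf32> into tensor<7x3x5xf32> (low = [1, 0, 0],
+  // high = [2, 0, 0]) with reassociations [[0], [1, 2]] yields
+  // paddedShape = [7, 15], lowPad = [1, 0] and highPad = [2, 0].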
+ SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad()); + SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad()); + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + if (reInd.size() == 1) { + padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]]; + padDimInfo.highPad[idx] = mixedHighPad[reInd[0]]; + } + SaturatedInteger collapsedSize = SaturatedInteger::wrap(1); + for (int64_t dim : reInd) { + collapsedSize = + collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]); + } + padDimInfo.paddedShape.push_back(collapsedSize.asInteger()); + } + + return padDimInfo; +} + class FoldPadWithProducerReshapeOpByCollapsing : public OpRewritePattern<tensor::PadOp> { public: @@ -1936,57 +2098,40 @@ public: padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>(); if (!reshapeOp) return failure(); - if (!reshapeOp->hasOneUse()) - return failure(); if (!controlFoldingReshapes(&padOp.getSourceMutable())) { return rewriter.notifyMatchFailure(padOp, "fusion blocked by control function"); } - ArrayRef<int64_t> low = padOp.getStaticLow(); - ArrayRef<int64_t> high = padOp.getStaticHigh(); SmallVector<ReassociationIndices> reassociations = reshapeOp.getReassociationIndices(); + FailureOr<PadDimInfo> maybeCollapsedPadding = + computeCollapsedPadding(padOp, reassociations, rewriter); + if (failed(maybeCollapsedPadding)) + return failure(); + PadDimInfo &collapsedPadding = maybeCollapsedPadding.value(); - for (auto reInd : reassociations) { - if (reInd.size() == 1) - continue; - if (llvm::any_of(reInd, [&](int64_t ind) { - return low[ind] != 0 || high[ind] != 0; - })) { - return failure(); - } - } - - SmallVector<OpFoldResult> newLow, newHigh; - RankedTensorType collapsedType = reshapeOp.getSrcType(); - RankedTensorType paddedType = padOp.getResultType(); - SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape()); - SmallVector<OpFoldResult> expandedPaddedSizes( - getMixedValues(reshapeOp.getStaticOutputShape(), - reshapeOp.getOutputShape(), rewriter)); + SmallVector<OpFoldResult> expandedPaddedSizes = + reshapeOp.getMixedOutputShape(); AffineExpr d0, d1, d2; bindDims(rewriter.getContext(), d0, d1, d2); auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2}); Location loc = reshapeOp->getLoc(); - for (auto [idx, reInd] : llvm::enumerate(reassociations)) { - OpFoldResult l = padOp.getMixedLowPad()[reInd[0]]; - OpFoldResult h = padOp.getMixedHighPad()[reInd[0]]; + for (auto [reInd, l, h] : + llvm::zip_equal(reassociations, collapsedPadding.lowPad, + collapsedPadding.highPad)) { if (reInd.size() == 1) { - collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]]; - OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply( + expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply( rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]}); - expandedPaddedSizes[reInd[0]] = paddedSize; } - newLow.push_back(l); - newHigh.push_back(h); } RankedTensorType collapsedPaddedType = - paddedType.clone(collapsedPaddedShape); + padOp.getType().clone(collapsedPadding.paddedShape); auto newPadOp = tensor::PadOp::create( - rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh, + rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), + collapsedPadding.lowPad, collapsedPadding.highPad, padOp.getConstantPaddingValue(), padOp.getNofold()); rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( @@ -2000,6 +2145,52 @@ private: ControlFusionFn controlFoldingReshapes; }; +class FoldReshapeWithProducerPadOpByCollapsing + : public 
OpRewritePattern<tensor::CollapseShapeOp> { +public: + FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp, + PatternRewriter &rewriter) const override { + tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>(); + if (!padOp) + return failure(); + + if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) { + return rewriter.notifyMatchFailure(padOp, + "fusion blocked by control function"); + } + + SmallVector<ReassociationIndices> reassociations = + reshapeOp.getReassociationIndices(); + RankedTensorType collapsedPaddedType = reshapeOp.getResultType(); + FailureOr<PadDimInfo> maybeCollapsedPadding = + computeCollapsedPadding(padOp, reassociations, rewriter); + if (failed(maybeCollapsedPadding)) + return failure(); + PadDimInfo &collapsedPadding = maybeCollapsedPadding.value(); + + Location loc = reshapeOp->getLoc(); + auto newCollapseOp = tensor::CollapseShapeOp::create( + rewriter, loc, padOp.getSource(), reassociations); + + auto newPadOp = tensor::PadOp::create( + rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(), + collapsedPadding.lowPad, collapsedPadding.highPad, + padOp.getConstantPaddingValue(), padOp.getNofold()); + + rewriter.replaceOp(reshapeOp, newPadOp.getResult()); + return success(); + } + +private: + ControlFusionFn controlFoldingReshapes; +}; + /// Pattern to collapse dimensions. template <typename LinalgType> class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> { @@ -2239,6 +2430,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( controlFoldingReshapes); patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(), controlFoldingReshapes); + patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(), + controlFoldingReshapes); patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(), controlFoldingReshapes); } @@ -2250,6 +2443,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( controlFoldingReshapes); patterns.add<FoldPadWithProducerReshapeOpByCollapsing>( patterns.getContext(), controlFoldingReshapes); + patterns.add<FoldReshapeWithProducerPadOpByCollapsing>( + patterns.getContext(), controlFoldingReshapes); patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(), controlFoldingReshapes); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp index 9974ccd..cbd6357 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp @@ -200,10 +200,10 @@ static void populateOpPayload( SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands(); updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos); - SmallVector<OpOperand *> origOutputOperands = llvm::to_vector(llvm::map_range( - genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); - SmallVector<OpOperand *> newOutputOperands = llvm::to_vector(llvm::map_range( - newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); + SmallVector<OpOperand *> origOutputOperands = + llvm::to_vector(llvm::make_pointer_range(genericOp.getDpsInitsMutable())); + SmallVector<OpOperand *> newOutputOperands = + 
llvm::to_vector(llvm::make_pointer_range(newOp.getDpsInitsMutable())); updateReplacements(origOutputOperands, newOutputOperands, origOutsToNewOutsPos); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 9436f1c..161d978 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -913,8 +913,7 @@ static Value replaceByPackingResult(RewriterBase &rewriter, llvm_unreachable("loop independence prerequisite not met"); // offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0]. - std::copy(loopIterationCounts.begin(), loopIterationCounts.end(), - offsets.begin()); + llvm::copy(loopIterationCounts, offsets.begin()); hoistedPackedTensor = scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front()) ->getResult(0); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 40fc0d6..c2485a0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,6 +237,69 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp); } +/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy` +/// with `dilations` and `strides`. +template <typename ConvOpTy> +static FailureOr<LinalgOp> +specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, + ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) { + SmallVector<Value> inputs = genericOp.getDpsInputs(); + ValueRange outputs = genericOp.getDpsInits(); + SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics() + ? TypeRange(ValueRange(outputs)) + : TypeRange{}; + LinalgOp namedOp; + // Ops with no dilations and no strides. + if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> || + std::is_same_v<ConvOpTy, linalg::Conv2DOp> || + std::is_same_v<ConvOpTy, linalg::Conv3DOp>) { + namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes, + inputs, outputs); + } else { + Attribute stridesAttr = rewriter.getI64TensorAttr(strides); + Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations); + namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>( + genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr); + } + return namedOp; +} + +/// Converts linalg.generic to named linalg.*conv/pooling* where possible. +static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter, + GenericOp genericOp) { + SmallVector<int64_t> dilations, strides; +#define CONV_OP_SPECIALIZER(ConvOpTy) \ + if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides)) \ + return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \ + strides); \ + // ----------------------------- + // Convolution ops. + // ----------------------------- + CONV_OP_SPECIALIZER(linalg::Conv1DOp); + CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp); + CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp); + CONV_OP_SPECIALIZER(linalg::Conv2DOp); + CONV_OP_SPECIALIZER(linalg::Conv3DOp); + // ----------------------------- + // Depthwise Convolution ops. 
+ // ----------------------------- + CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp); + // ----------------------------- + // Pooling ops. + // ----------------------------- + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp); +#undef CONV_OP_SPECIALIZER + return failure(); +} + } // namespace //===----------------------------------------------------------------------===// @@ -316,6 +379,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter, if (isaContractionOpInterface(genericOp)) { return specializeLinalgContractions(rewriter, genericOp); } + + // Convolution - e.g. *conv/pooling* + if (isaConvolutionOpInterface(genericOp)) { + return specializeLinalgConvolutions(rewriter, genericOp); + } return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 705d6f2..8e14ef4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -452,8 +452,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc()); AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap(); - if (!shapeSizesToLoopsMap) - return failure(); + assert(shapeSizesToLoopsMap && "invalid linalgOp with null ShapesToLoopsMap"); auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges( b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 57b610b..50a84ac 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -167,7 +167,7 @@ struct LinalgOpTilingInterface llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) { auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr); if (!dimExpr) - continue; + return failure(); unsigned position = dimExpr.getPosition(); auto it = mappedOffsets.find(position); if (it != mappedOffsets.end()) { @@ -216,8 +216,6 @@ struct LinalgOpTilingInterface SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { auto linalgOp = cast<LinalgOp>(op); - std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets, - iterationSpaceSizes; SmallVector<AffineMap> indexingMaps = llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) { OpOperand &opOperand = linalgOp->getOpOperand(operandNumber); @@ -359,6 +357,32 @@ struct LinalgOpTilingInterface /// Inline the op payload and store the result. return inlinePayload(builder, linalgOp, ivs, indexedValues); } + + bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes) const { + // The verifier gives all the necessary requirements for consumer fusion. 
+ return true; + } + + bool isOpFusableWithProducerSlices( + Operation *op, ArrayRef<unsigned> operandNumbers, + ArrayRef<SmallVector<OpFoldResult>> allOffsets, + ArrayRef<SmallVector<OpFoldResult>> allSizes) const { + + auto linalgOp = cast<LinalgOp>(op); + SmallVector<AffineMap> indexingMaps = + llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) { + OpOperand &opOperand = linalgOp->getOpOperand(operandNumber); + return linalgOp.getMatchingIndexingMap(&opOperand); + }); + // Check that offsets/sizes are consistent across all operands. + OpBuilder b(op); + SmallVector<OpFoldResult> mappedOffsets, mappedSizes; + return succeeded(getMappedOffsetAndSize(linalgOp, b, indexingMaps, + allOffsets, allSizes, mappedOffsets, + mappedSizes)); + } }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index bd25e94..67e2b9f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -232,10 +232,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, // 2. Compute the permutation vector to shuffle packed shape into the shape // before any outer or inner permutations have been applied. - PackingMetadata packingMetadata = computePackingMetadata( - packedTensorType.getRank(), packOp.getInnerDimsPos()); + PackingMetadata packingMetadata; SmallVector<int64_t> packedToStripMinedShapePerm = - getPackInverseDestPerm(packOp); + getPackInverseDestPerm(packOp, packingMetadata); // 3. Compute the stripMinedShape: this is the packed shape before any outer // or inner permutations have been applied. @@ -1168,12 +1167,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( "this is not supported ATM!"); } - Attribute zeroIdxAttr = rewriter.getIndexAttr(0); - Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); int64_t srcRank = packOp.getSourceRank(); - int64_t destRank = packOp.getDestRank(); // 1. Get the input that is going to be packed. If the input requires padding, // add a padding operation and return that as the input. @@ -1263,14 +1259,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( writeSizes.push_back(tileSizeOfr); } - // TODO: Add a constructor for tensor.insert_slice that doesn't require - // strides nor offsets. - SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); - SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); - auto insert = tensor::InsertSliceOp::create( - rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), - writeOffsets, writeSizes, writeStrides); + rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes); // 4. 
Replace tensor.packOp with tensor.insert_slice created above rewriter.replaceOp(packOp, insert.getResult()); @@ -1280,7 +1270,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const { - int64_t srcRank = unpackOp.getSourceRank(); int64_t destRank = unpackOp.getDestRank(); ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape(); ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos(); @@ -1297,7 +1286,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( Value source = unpackOp.getSource(); DenseMap<int64_t, OpFoldResult> dimAndTileMapping = unpackOp.getDimAndTileMapping(); - Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of @@ -1308,9 +1296,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( // outer-tiled-dims being all 1), this will be // [ outer-untiled-dims, tile-sizes ] SmallVector<OpFoldResult> extractSliceSizes; - // The offset and strides attributes for ExtractSliceOp. - SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr); - SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr); // Shape for EmptyOp that's used as the init value for TransposeOp below. // This should be: @@ -1365,8 +1350,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( Type elemType = unpackOp.getSourceType().getElementType(); auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType); Value innerTile = tensor::ExtractSliceOp::create( - rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets, - extractSliceSizes, extractSliceStrides); + rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes); // 2. Transpose the tile to match the outer corresponding tile order. SmallVector<int64_t> perm = getPackUnpackRankReducedPerm( @@ -1382,9 +1366,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( // 3. Handle in-complete tiles if needed. It truncates trailing data from the // transposed tile. - int numLoops = shapeForEmptyOp.size(); - SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr); - SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr); SmallVector<OpFoldResult> tileSizes; ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape(); for (auto i : llvm::seq<unsigned>(0, destRank)) { @@ -1394,13 +1375,11 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( } auto partialTile = - tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0], - tileOffsets, tileSizes, tileStrides); + tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(), + transposedOp.getResult()[0], tileSizes); // 4. Insert the result to the destination tensor. 
SmallVector<OpFoldResult> writeSizes; - SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); - SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); for (int i = 0, idx = 0; i < destRank; ++i) { if (dimAndTileMapping.count(i) || destShape[i] != 1) writeSizes.push_back(tileSizes[idx++]); @@ -1408,8 +1387,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( writeSizes.push_back(oneIdxAttr); } auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile, - unpackOp.getDest(), writeOffsets, - writeSizes, writeStrides); + unpackOp.getDest(), writeSizes); rewriter.replaceOp(unpackOp, insert.getResult()); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index cb6199f..bb3bccd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -746,12 +746,12 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, auto vectorType = state.getCanonicalVecType( getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap); + SmallVector<Value> indices(linalgOp.getRank(outputOperand), + arith::ConstantIndexOp::create(rewriter, loc, 0)); + Operation *write; if (vectorType.getRank() > 0) { AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap)); - SmallVector<Value> indices( - linalgOp.getRank(outputOperand), - arith::ConstantIndexOp::create(rewriter, loc, 0)); value = broadcastIfNeeded(rewriter, value, vectorType); assert(value.getType() == vectorType && "Incorrect type"); write = vector::TransferWriteOp::create( @@ -762,7 +762,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, value = vector::BroadcastOp::create(rewriter, loc, vectorType, value); assert(value.getType() == vectorType && "Incorrect type"); write = vector::TransferWriteOp::create(rewriter, loc, value, - outputOperand->get(), ValueRange{}); + outputOperand->get(), indices); } write = state.maskOperation(rewriter, write, linalgOp, opOperandMap); @@ -1564,13 +1564,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, return success(); } -/// Given a linalg::PackOp, return the `dest` shape before any packing -/// permutations. 
-static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp, - ArrayRef<int64_t> destShape) { - return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp)); -} - /// Determines whether a mask for xfer_write is trivially "all true" /// /// Given all the inputs required to generate a mask (mask sizes and shapes), @@ -1761,99 +1754,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, return mlir::vector::maskOperation(builder, write, maskForWrite); } -/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant -/// padding value and (3) input vector sizes into: -/// -/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds -/// -/// As in the following example: -/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2] -/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32> -/// -/// This pack would be vectorized to: -/// -/// %load = vector.mask %mask { -/// vector.transfer_read %arg0[%c0, %c0, %c0], %cst -/// {in_bounds = [true, true, true]} : -/// tensor<32x7x16xf32>, vector<32x8x16xf32> -/// } : vector<32x8x16xi1> -> vector<32x8x16xf32> -/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32> -/// to vector<32x4x2x1x16xf32> -/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2] -/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> -/// %write = vector.transfer_write %transpose, -/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0] -/// {in_bounds = [true, true, true, true, true]} -/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> -/// -/// If the (3) input vector sizes are not provided, the vector sizes are -/// determined by the result tensor shape and the `in_bounds` -/// attribute is used instead of masking to mark out-of-bounds accesses. -/// -/// NOTE: The input vector sizes specify the dimensions corresponding to the -/// outer dimensions of the output tensor. The remaining dimensions are -/// computed based on, e.g., the static inner tiles. -/// Supporting dynamic inner tiles will require the user to specify the -/// missing vector sizes. This is left as a TODO. -static LogicalResult -vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, - ArrayRef<int64_t> inputVectorSizes, - SmallVectorImpl<Value> &newResults) { - // TODO: Introduce a parent class that will handle the insertion point update. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(packOp); - - Location loc = packOp.getLoc(); - std::optional<Value> padValue = packOp.getPaddingValue() - ? std::optional(packOp.getPaddingValue()) - : std::nullopt; - - // If the input vector sizes are not provided, then the vector sizes are - // determined by the result tensor shape. In case the vector sizes aren't - // provided, we update the inBounds attribute instead of masking. - bool useInBoundsInsteadOfMasking = false; - if (inputVectorSizes.empty()) { - ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); - inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank()); - useInBoundsInsteadOfMasking = true; - } - - // Create masked TransferReadOp. 
- SmallVector<int64_t> inputShape(inputVectorSizes); - auto innerTiles = packOp.getStaticInnerTiles(); - auto innerDimsPos = packOp.getInnerDimsPos(); - auto outerDimsPerm = packOp.getOuterDimsPerm(); - if (!outerDimsPerm.empty()) - applyPermutationToVector(inputShape, - invertPermutationVector(outerDimsPerm)); - for (auto [idx, size] : enumerate(innerTiles)) - inputShape[innerDimsPos[idx]] *= size; - auto maskedRead = vector::createReadOrMaskedRead( - rewriter, loc, packOp.getSource(), inputShape, padValue, - useInBoundsInsteadOfMasking, - /*inputScalableVecSizes=*/{}); - - // Create ShapeCastOp. - SmallVector<int64_t> destShape(inputVectorSizes); - destShape.append(innerTiles.begin(), innerTiles.end()); - auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape), - packOp.getDestType().getElementType()); - auto shapeCastOp = - vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead); - - // Create TransposeOp. - auto destPermutation = - invertPermutationVector(getPackInverseDestPerm(packOp)); - auto transposeOp = vector::TransposeOp::create( - rewriter, loc, shapeCastOp.getResult(), destPermutation); - - // Create TransferWriteOp. - Operation *write = createWriteOrMaskedWrite( - rewriter, loc, transposeOp.getResult(), packOp.getDest()); - newResults.push_back(write->getResult(0)); - return success(); -} - /// Given the re-associations, "collapses" the input Vector type /// /// This is similar to CollapseShapeOp::inferCollapsedType with two notable @@ -1901,12 +1801,120 @@ static VectorType getCollapsedVecType(VectorType type, return VectorType::get(newShape, type.getElementType(), newScalableFlags); } +/// Vectorize `linalg.pack` as: +/// * xfer_read -> shape_cast -> transpose -> xfer_write +/// +/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector +/// sizes for the xfer_write operation). This is sufficient to infer the other +/// vector sizes required here. +/// +/// If the vector sizes are not provided: +/// * the vector sizes are determined from the destination tensor static shape. +/// * the inBounds attribute is used instead of masking. +/// +/// EXAMPLE (no vector sizes): +/// ``` +/// %pack = tensor.pack %src +/// inner_dims_pos = [2, 1] +/// inner_tiles = [16, 2] +/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32> +/// `` +/// is vectorizes as: +/// ``` +/// %read = vector.transfer_read %src +/// : tensor<32x7x16xf32>, vector<32x8x16xf32> +/// %sc = vector.shape_cast %read +/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> +/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2] +/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> +/// %write = vector.transfer_write %tr into %dest +/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> +/// ``` +static LogicalResult +vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, + ArrayRef<int64_t> inputVectorSizes, + SmallVectorImpl<Value> &newResults) { + if (!inputVectorSizes.empty()) { + assert(inputVectorSizes.size() == packOp.getDestRank() && + "Invalid number of input vector sizes!"); + } + + // TODO: Introduce a parent class that will handle the insertion point update. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(packOp); + + Location loc = packOp.getLoc(); + std::optional<Value> padValue = packOp.getPaddingValue() + ? 
std::optional(packOp.getPaddingValue()) + : std::nullopt; + + SmallVector<int64_t> destShape = + SmallVector<int64_t>(packOp.getDestType().getShape()); + + // This is just a convenience alias to clearly communicate that the input + // vector sizes determine the _write_ sizes. + ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes; + + // In the absence of input-vector-sizes, use the _static_ input tensor shape. + // In addition, use the inBounds attribute instead of masking. + bool useInBoundsInsteadOfMasking = false; + if (writeVectorSizes.empty()) { + if (ShapedType::isDynamicShape(destShape)) + return rewriter.notifyMatchFailure(packOp, + "unable to infer vector sizes"); + + writeVectorSizes = destShape; + useInBoundsInsteadOfMasking = true; + } + + // Compute pre-transpose-write-vector-type, i.e. the write vector type + // _before_ the transposition (i.e. before dimension permutation). This is + // done by inverting the permutation/transposition that's part of the Pack + // operation. This type is required to: + // 1) compute the read vector type for masked-read below, and + // 2) generate shape-cast Op below that expands the read vector type. + PackingMetadata packMetadata; + SmallVector<int64_t> preTransposeWriteVecSizses(writeVectorSizes); + auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata); + applyPermutationToVector(preTransposeWriteVecSizses, destInvPermutation); + auto preTransposeWriteVecType = VectorType::get( + preTransposeWriteVecSizses, packOp.getType().getElementType()); + + // Compute vector type for the _read_ opeartion. This is simply + // pre-transpose-write-vector-type with the dimensions collapsed + // as per the Pack operation. + VectorType readVecType = getCollapsedVecType( + preTransposeWriteVecType, + getSymbolLessAffineMaps(convertReassociationIndicesToExprs( + rewriter.getContext(), packMetadata.reassociations))); + + // Create masked TransferReadOp. + auto maskedRead = vector::createReadOrMaskedRead( + rewriter, loc, packOp.getSource(), readVecType, padValue, + useInBoundsInsteadOfMasking); + + // Create ShapeCastOp. + auto shapeCastOp = vector::ShapeCastOp::create( + rewriter, loc, preTransposeWriteVecType, maskedRead); + + // Create TransposeOp. + auto destPermutation = invertPermutationVector(destInvPermutation); + auto transposeOp = vector::TransposeOp::create( + rewriter, loc, shapeCastOp.getResult(), destPermutation); + + // Create TransferWriteOp. + Operation *write = createWriteOrMaskedWrite( + rewriter, loc, transposeOp.getResult(), packOp.getDest()); + newResults.push_back(write->getResult(0)); + return success(); +} + /// Vectorize `linalg.unpack` as: /// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write /// -/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes -/// for the xfer_read operation). This is sufficient to infer the other vector -/// sizes required here. +/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector +/// sizes for the xfer_read operation). This is sufficient to infer the other +/// vector sizes required here. /// /// If the vector sizes are not provided: /// * the vector sizes are determined from the input tensor static shape. @@ -1960,16 +1968,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, // In the absence of input-vector-sizes, use the _static_ input tensor shape. 
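   // Illustrative note (hypothetical shapes): for an unpack whose packed
   // source is tensor<4x16x8x32xf32> and no user-provided sizes, the read
   // vector sizes become [4, 16, 8, 32] and the in_bounds attribute is used
   // instead of a mask.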
if (inputVectorSizes.empty()) { if (ShapedType::isDynamicShape(sourceShape)) - return failure(); + return rewriter.notifyMatchFailure(unpackOp, + "Unable to infer vector sizes!"); readVectorSizes.assign(sourceShape.begin(), sourceShape.end()); useInBoundsInsteadOfMasking = true; } // -- Generate the read operation -- + VectorType readVecType = + VectorType::get(readVectorSizes, unpackTensorType.getElementType(), + readScalableVectorFlags); Value readResult = vector::createReadOrMaskedRead( - rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt, - useInBoundsInsteadOfMasking, readScalableVectorFlags); + rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt, + useInBoundsInsteadOfMasking); // -- Generate the transpose operation -- PackingMetadata packMetadata; @@ -2015,9 +2027,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, .reifyResultShapes(rewriter, reifiedReturnShapes); (void)status; // prevent unused variable warning on non-assert builds assert(succeeded(status) && "failed to reify result shapes"); + auto readType = VectorType::get(inputVectorSizes, padValue.getType()); auto maskedRead = vector::createReadOrMaskedRead( - rewriter, loc, padOp.getSource(), inputVectorSizes, padValue, - /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{}); + rewriter, loc, padOp.getSource(), readType, padValue, + /*useInBoundsInsteadOfMasking=*/false); // Create Xfer write Op Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0], @@ -2212,9 +2225,9 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, state.getCanonicalVecType(elemType, readMap.compose(indexingMap)); Value read = mlir::vector::createReadOrMaskedRead( - rewriter, loc, opOperand.get(), readType.getShape(), + rewriter, loc, opOperand.get(), readType, /*padding=*/arith::getZeroConstant(rewriter, loc, elemType), - /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims()); + /*useInBoundsInsteadOfMasking=*/false); vecOperands.push_back(read); } @@ -2443,6 +2456,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp, ArrayRef<int64_t> inputVectorSizes) { auto padValue = packOp.getPaddingValue(); Attribute cstAttr; + // TODO: Relax this condiiton if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) { LDBG() << "pad value is not constant: " << packOp; return failure(); @@ -3154,9 +3168,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, SmallVector<Value> readIndices( vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0)); Value read = mlir::vector::createReadOrMaskedRead( - rewriter, loc, source, vecType.getShape(), padValue, - /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(), - /*inputScalableVecSizes=*/{}); + rewriter, loc, source, vecType, padValue, + /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty()); // Create write auto writeIndices = diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 24d3722..01e6e1e 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -171,29 +171,24 @@ computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos, namespace mlir { namespace linalg { -SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) { +SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp, + PackingMetadata &metadata) { - PackingMetadata pMetadata; int64_t packedRank = packOp.getDestType().getRank(); ArrayRef<int64_t> innerDimPos = 
packOp.getInnerDimsPos(); ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm(); SmallVector<int64_t> packInvDestPerm = - computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata); + computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata); return packInvDestPerm; } -SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) { - PackingMetadata metadata; - return getUnPackInverseSrcPerm(unpackOp, metadata); -} - SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp, PackingMetadata &metadata) { - int64_t unpackRank = unpackOp.getSourceType().getRank(); + int64_t packedRank = unpackOp.getSourceType().getRank(); ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos(); ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm(); SmallVector<int64_t> unpackInvSrcPerm = - computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata); + computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata); return unpackInvSrcPerm; } @@ -240,6 +235,731 @@ bool isReductionIterator(utils::IteratorType iteratorType) { return iteratorType == utils::IteratorType::reduction; } +//===----------------------------------------------------------------------===// +// Convolution matcher utilities +//===----------------------------------------------------------------------===// + +/// Returns the BlockArgument that leads to `val`, if any. Traverses optional +/// ext* ops. +static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) { + BlockArgument blockArg = dyn_cast<BlockArgument>(val); + if ((blockArg)) + return blockArg; + + Operation *defOp = val.getDefiningOp(); + if (!dyn_cast_if_present<arith::ExtFOp>(defOp) && + !dyn_cast_if_present<arith::ExtSIOp>(defOp) && + !dyn_cast_if_present<arith::ExtUIOp>(defOp)) { + return nullptr; + } + return dyn_cast<BlockArgument>(defOp->getOperand(0)); +} + +/// Utility to match block body for convolution ops. +/// The body is thus expected to yield :- +/// %out + (%lhs * %rhs) +/// where: %lhs, %rhs and %out are block arguments and +/// %lhs and %rhs can have optional upcast operation. +static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) { + Operation *addOp = yieldVal.getDefiningOp(); + if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp)) + return false; + + Operation *mulOp = addOp->getOperand(1).getDefiningOp(); + if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp)) + return false; + + BlockArgument lhsBlockArg = + getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0)); + BlockArgument rhsBlockArg = + getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1)); + BlockArgument outBlockArg = + getBlockArgumentWithOptionalExtOps(addOp->getOperand(0)); + if (!lhsBlockArg || !rhsBlockArg || !outBlockArg || + lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body || + outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 || + rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2) + return false; + return true; +} + +/// Utility to match block body for linalg.pool* ops. +template <typename... 
OpTypes> +static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { + Operation *defOp = yieldVal.getDefiningOp(); + if (!(isa_and_present<OpTypes>(defOp) || ...)) + return false; + + BlockArgument lhsArg = + getBlockArgumentWithOptionalExtOps(defOp->getOperand(0)); + BlockArgument rhsArg = + getBlockArgumentWithOptionalExtOps(defOp->getOperand(1)); + if (!lhsArg || !rhsArg || lhsArg.getOwner() != body || + rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 || + rhsArg.getArgNumber() != 0) + return false; + return true; +} + +static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, + body); +} + +// max_unsigned ops should not allow float data type. +// TODO(#164800): Retire OPDSL logic. +static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, + body); +} + +static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, + body); +} + +// min_unsigned ops should not allow float data type. +// TODO(#164800): Retire OPDSL logic. +static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, + body); +} + +static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body); +} + +static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex, + uint32_t dimIndex) { + auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue(); + if (dimIndex < affineMap.getNumResults()) + return affineMap.getResult(dimIndex); + return nullptr; +} + +/// Check if `expr` is either: +/// - a dimension expr alone (implying multiplication by 1), or +/// - a multiplication of dimension expr by any positive constant != 1 +/// In both cases we will capture the dimension expression into `dim` and +/// return the constant multiplier. Returns -1 in case of a match failure. 
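+/// A few illustrative cases (hypothetical expressions):
+///   d1 * 3   ->  returns 3, dim = d1
+///   d2       ->  returns 1, dim = d2
+///   d0 + d1  ->  returns -1 (no match)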
+static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) { + if ((dim = dyn_cast<AffineDimExpr>(expr))) + return 1; + + auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr); + if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) + return -1; + + AffineExpr lhs = mulExpr.getLHS(); + AffineExpr rhs = mulExpr.getRHS(); + + AffineConstantExpr cst = nullptr; + if (((dim = dyn_cast<AffineDimExpr>(lhs)) && + (cst = dyn_cast<AffineConstantExpr>(rhs))) || + ((dim = dyn_cast<AffineDimExpr>(rhs)) && + (cst = dyn_cast<AffineConstantExpr>(lhs)))) + return cst.getValue(); + return -1; +} + +/// Given an array of AffineMaps `indexingMaps` verify the following +/// commutatively:- +/// indexingMaps[0].getResult(iDim) == +/// indexingMaps[1].getResult(fDim) * <c0> + +/// indexingMaps[n-1].getResult(oDim) * <c1> +/// where, +/// - c0 and c1 can be any constant, +/// - n is the size of the indexingMaps' array, +/// - 0, 1 and n-1 are input, filter and output map indices respectively, +/// - iDim, fDim and oDim are the input, filter and output dimension +/// indices in their respective indexing maps +/// Example: +/// #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) +/// -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)> +/// #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +/// #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +/// +/// Here, +/// #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3 +/// Therefore, +/// matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride) +/// would return true and update dilation = 3 and stride = 2 +static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, + unsigned fDim, unsigned oDim, + int64_t &dilation, int64_t &stride) { + unsigned inputMapIdx = 0, filterMapIdx = 1, + outputMapIdx = indexingMaps.size() - 1; + AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim); + auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr); + if (!addExpr || addExpr.getKind() != AffineExprKind::Add) + return false; + + AffineExpr dim0, dim1; + int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0); + int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1); + + if (c0 == -1 || c1 == -1) + return false; + // Pattern matched with dims and constants extracted. + AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim); + AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim); + if (dim0 == fExpr && dim1 == oExpr) { + dilation = c0; + stride = c1; + return true; + } + if (dim1 == fExpr && dim0 == oExpr) { + dilation = c1; + stride = c0; + return true; + } + return false; +} + +/// Returns true if the given indexing maps matches with the expected indexing +/// maps. +static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected, + ArrayAttr indexingMaps, MLIRContext *context) { + SmallVector<AffineMap, 4> expectedIndexingMaps = + AffineMap::inferFromExprList(mapListExpected, context); + return indexingMaps == + ArrayAttr::get( + context, llvm::to_vector<4>(llvm::map_range( + expectedIndexingMaps, [&](AffineMap m) -> Attribute { + return AffineMapAttr::get(m); + }))); +} + +/// Enum representing pooling operation types used by ConvMatcherBuilder. +enum class PoolingType { + None, + MaxSigned, + MaxUnsigned, + MinSigned, + MinUnsigned, + Sum +}; + +/// Helper class for building convolution op matchers with minimal boilerplate. 
+/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well +/// as Pooling ops. +/// +/// Usage: Create an instance with the op, spatial rank, and output pointers for +/// extracted dilations/strides. Then chain matchStride() calls for each spatial +/// dimension, followed by matchMaps() to verify indexing maps, and finally +/// matchBody() to verify the operation body pattern. +/// +/// The `matched` flag starts as `true` and is set to `false` if any match step +/// fails. This allows chaining multiple match calls; once any match fails, all +/// subsequent calls become no-ops and the final result is `false`. +/// +/// The `dilations` and `strides` pointers are output parameters that get +/// populated with the extracted dilation and stride values from the operation's +/// indexing maps during matchStride() calls. These values are initially set to +/// 1 for each spatial dimension and updated as patterns are matched. +class ConvMatcherBuilder { + LinalgOp op; + MLIRContext *ctx; + SmallVector<int64_t> *dilations, *strides; + ArrayAttr indexingMaps; + PoolingType poolingType; + bool matched = true; + +public: + ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d, + SmallVector<int64_t> *s, + PoolingType poolingType = PoolingType::None) + : op(op), ctx(op->getContext()), dilations(d), strides(s), + indexingMaps(op.getIndexingMaps()), poolingType(poolingType) { + *dilations = SmallVector<int64_t>(spatialRank, 1); + *strides = SmallVector<int64_t>(spatialRank, 1); + } + + /// Get affine dimension expression for dimension `i`. + AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); } + + /// Build strided expression: base * stride[idx] + kernel * dilation[idx]. + AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) { + return base * (*strides)[idx] + kernel * (*dilations)[idx]; + } + + /// Match stride/dilation pattern for a spatial dimension. + /// Returns *this for method chaining. + ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim, + unsigned idx) { + if (matched) { + matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim, + (*dilations)[idx], (*strides)[idx]); + } + return *this; + } + + /// Match expected indexing maps layout. Returns *this for method chaining. + ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) { + if (matched) + matched &= convLayoutMatches(maps, indexingMaps, ctx); + return *this; + } + + /// Match body pattern. This should be called last. + bool matchBody() { + if (!matched) + return false; + Block *body = op.getBlock(); + auto yieldOp = cast<linalg::YieldOp>(body->getTerminator()); + switch (poolingType) { + case PoolingType::None: + return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body); + case PoolingType::MaxSigned: + return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MaxUnsigned: + return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MinSigned: + return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MinUnsigned: + return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::Sum: + return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body); + } + return false; + } +}; + +//===----------------------------------------------------------------------===// +// Matchers for specific convolution operation. 
+//===----------------------------------------------------------------------===// + +// #inputMap = affine_map<(W, w) -> (W + w)> +// #filterMap = affine_map<(W, w) -> (w)> +// #outputMap = affine_map<(W, w) -> (W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op, + SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv1DOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr W = m.dim(0); + AffineExpr w = m.dim(1); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .matchMaps({/*inputMap=*/{m.strided(W, w, 0)}, + /*filterMap=*/{w}, + /*outputMap=*/{W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, F, w, c) -> (N, W + w, c)> +// #filterMap = affine_map<(N, W, F, w, c) -> (w, c, F)> +// #outputMap = affine_map<(N, W, F, w, c) -> (N, W, F)> +template <> +bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv1DNwcWcfOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr F = m.dim(2); + AffineExpr w = m.dim(3); + AffineExpr c = m.dim(4); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c}, + /*filterMap=*/{w, c, F}, + /*outputMap=*/{N, W, F}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, F, W, c, w) -> (N, c, W + w)> +// #filterMap = affine_map<(N, F, W, c, w) -> (F, c, w)> +// #outputMap = affine_map<(N, F, W, c, w) -> (N, F, W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv1DNcwFcwOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr F = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr c = m.dim(3); + AffineExpr w = m.dim(4); + + return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)}, + /*filterMap=*/{F, c, w}, + /*outputMap=*/{N, F, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(H, W, h, w) -> (H + h, W + w)> +// #filterMap = affine_map<(H, W, h, w) -> (h, w)> +// #outputMap = affine_map<(H, W, h, w) -> (H, W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op, + SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv2DOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr H = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr h = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1) + .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{h, w}, + /*outputMap=*/{H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + 
w)> +// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)> +// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op, + SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv3DOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr D = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr d = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1) + .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2) + .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2)}, + /*filterMap=*/{d, h, w}, + /*outputMap=*/{D, H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)> +// #filterMap = affine_map<(N, W, C, w) -> (C, w)> +// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv1DNcwCwOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)}, + /*filterMap=*/{C, w}, + /*outputMap=*/{N, C, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)> +// #filterMap = affine_map<(N, W, C, w) -> (w, C)> +// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv1DNwcWcOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C}, + /*filterMap=*/{w, C}, + /*outputMap=*/{N, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)> +// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)> +// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr CM = m.dim(3); + AffineExpr w = m.dim(4); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, m.strided(W, 
w, 0), C}, + /*filterMap=*/{w, C, CM}, + /*outputMap=*/{N, W, C, CM}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv2DNchwChwOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{C, h, w}, + /*outputMap=*/{N, C, H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (N, D + d, H + h, W + w, C)> +// #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (d, h, w, C, CM)> +// #outputMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (N, D, H, W, C, CM)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr D = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr CM = m.dim(4); + AffineExpr d = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + AffineExpr C = m.dim(8); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2) + .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2), C}, + /*filterMap=*/{d, h, w, C, CM}, + /*outputMap=*/{N, D, H, W, C, CM}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMaxOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MaxSigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = 
affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMinOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MinSigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcSumOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::Sum); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MaxUnsigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMinUnsignedOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement 
ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MinUnsigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index 1382c7ac..d358362 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRMemorySlotInterfaces MLIRShapedOpInterfaces MLIRSideEffectInterfaces + MLIRUBDialect MLIRValueBoundsOpInterface MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp index 6ff63df..a1e3f10 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index dfa2e4e..5404238 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -61,15 +62,8 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape, // Interfaces for AllocaOp //===----------------------------------------------------------------------===// -static bool isSupportedElementType(Type type) { - return llvm::isa<MemRefType>(type) || - OpBuilder(type.getContext()).getZeroAttr(type); -} - SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { MemRefType type = getType(); - if (!isSupportedElementType(type.getElementType())) - return {}; if (!type.hasStaticShape()) return {}; // Make sure the memref contains only a single element. @@ -81,16 +75,7 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - assert(isSupportedElementType(slot.elemType)); - // TODO: support more types. 
- return TypeSwitch<Type, Value>(slot.elemType) - .Case([&](MemRefType t) { - return memref::AllocaOp::create(builder, getLoc(), t); - }) - .Default([&](Type t) { - return arith::ConstantOp::create(builder, getLoc(), t, - builder.getZeroAttr(t)); - }); + return ub::PoisonOp::create(builder, getLoc(), slot.elemType); } std::optional<PromotableAllocationOpInterface> diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 1c21a2f..1035d7c 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1074,13 +1074,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) { return subview.getDynamicSize(sourceIndex); } - if (auto sizeInterface = - dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { - assert(sizeInterface.isDynamicSize(unsignedIndex) && - "Expected dynamic subview size"); - return sizeInterface.getDynamicSize(unsignedIndex); - } - // dim(memrefcast) -> dim if (succeeded(foldMemRefCast(*this))) return getResult(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index bd02516..c9352e8 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -959,7 +959,11 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp PatternRewriter &rewriter) const override { auto viewLikeOp = extractOp.getSource().getDefiningOp<ViewLikeOpInterface>(); - if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest()) + // ViewLikeOpInterface by itself doesn't guarantee to preserve the base + // pointer in general and `memref.view` is one such example, so just check + // for a few specific cases. 
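+      // For instance, memref.view applies a byte shift to the source's base
+      // pointer, so folding the extract onto the source would return the
+      // wrong aligned pointer; subview and reinterpret_cast keep the same
+      // aligned base and only adjust the descriptor's offset/sizes/strides.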
+ if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() || + !isa<memref::SubViewOp, memref::ReinterpretCastOp>(viewLikeOp)) return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source"); rewriter.modifyOpInPlace(extractOp, [&]() { extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 214410f..3667fdb 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -347,28 +347,55 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite( loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices, isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation())))) return failure(); - llvm::TypeSwitch<Operation *, void>(loadOp) + + return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp) .Case([&](affine::AffineLoadOp op) { rewriter.replaceOpWithNewOp<affine::AffineLoadOp>( loadOp, expandShapeOp.getViewSource(), sourceIndices); + return success(); }) .Case([&](memref::LoadOp op) { rewriter.replaceOpWithNewOp<memref::LoadOp>( loadOp, expandShapeOp.getViewSource(), sourceIndices, op.getNontemporal()); + return success(); }) .Case([&](vector::LoadOp op) { rewriter.replaceOpWithNewOp<vector::LoadOp>( op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, op.getNontemporal()); + return success(); }) .Case([&](vector::MaskedLoadOp op) { rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>( op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, op.getMask(), op.getPassThru()); + return success(); + }) + .Case([&](vector::TransferReadOp op) { + // We only support minor identity maps in the permutation attribute. + if (!op.getPermutationMap().isMinorIdentity()) + return failure(); + + // We only support the case where the source of the expand shape has + // rank greater than or equal to the vector rank. + const int64_t sourceRank = sourceIndices.size(); + const int64_t vectorRank = op.getVectorType().getRank(); + if (sourceRank < vectorRank) + return failure(); + + // We need to construct a new minor identity map since we will have lost + // some dimensions in folding away the expand shape. 
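+          // E.g. (hypothetical types): a vector<8xf32> read through an
+          // expand_shape of memref<32xf32> into memref<4x8xf32> goes from
+          // the 2-D minor identity map (d0, d1) -> (d1) to the 1-D identity
+          // (d0) -> (d0), since the folded read indexes the rank-1 source.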
+ auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank, + op.getContext()); + + rewriter.replaceOpWithNewOp<vector::TransferReadOp>( + op, op.getVectorType(), expandShapeOp.getViewSource(), + sourceIndices, minorIdMap, op.getPadding(), op.getMask(), + op.getInBounds()); + return success(); }) .DefaultUnreachable("unexpected operation"); - return success(); } template <typename OpTy> @@ -659,6 +686,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { LoadOpOfExpandShapeOpFolder<memref::LoadOp>, LoadOpOfExpandShapeOpFolder<vector::LoadOp>, LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>, + LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>, StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>, StoreOpOfExpandShapeOpFolder<memref::StoreOp>, StoreOpOfExpandShapeOpFolder<vector::StoreOp>, diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 6a81a15..c498c8a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> { if (!dimIndex) return failure(); - ReifiedRankedShapedTypeDims reifiedResultShapes; - if (failed(reifyResultShapes(rewriter, dimValue.getOwner(), - reifiedResultShapes))) + FailureOr<OpFoldResult> replacement = reifyDimOfResult( + rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex); + if (failed(replacement)) return failure(); - unsigned resultNumber = dimValue.getResultNumber(); - // Do not apply pattern if the IR is invalid (dim out of bounds). - if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size()) - return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds"); - Value replacement = getValueOrCreateConstantIndexOp( - rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]); - rewriter.replaceOp(dimOp, replacement); + // Check if the OpFoldResult is empty (unreifiable dimension). 
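+    // (A dimension that cannot be reified is simply left untouched here
+    // rather than treated as an error.)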
+ if (!replacement.value()) + return failure(); + Value replacementVal = getValueOrCreateConstantIndexOp( + rewriter, dimOp.getLoc(), replacement.value()); + rewriter.replaceOp(dimOp, replacementVal); return success(); } }; @@ -166,12 +165,14 @@ namespace { struct ResolveRankedShapeTypeResultDimsPass final : public memref::impl::ResolveRankedShapeTypeResultDimsPassBase< ResolveRankedShapeTypeResultDimsPass> { + using Base::Base; void runOnOperation() override; }; struct ResolveShapedTypeResultDimsPass final : public memref::impl::ResolveShapedTypeResultDimsPassBase< ResolveShapedTypeResultDimsPass> { + using Base::Base; void runOnOperation() override; }; @@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns( void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + auto result = applyPatternsGreedily(getOperation(), std::move(patterns)); + if (errorOnPatternIterationLimit && failed(result)) { + getOperation()->emitOpError( + "dim operation resolution hit pattern iteration limit"); return signalPassFailure(); + } } void ResolveShapedTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + auto result = applyPatternsGreedily(getOperation(), std::move(patterns)); + if (errorOnPatternIterationLimit && failed(result)) { + getOperation()->emitOpError( + "dim operation resolution hit pattern iteration limit"); return signalPassFailure(); + } } diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 14152c5..e5cc41e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -268,61 +268,82 @@ struct SubViewOpInterface MemRefType sourceType = subView.getSource().getType(); // For each dimension, assert that: - // 0 <= offset < dim_size - // 0 <= offset + (size - 1) * stride < dim_size + // For empty slices (size == 0) : 0 <= offset <= dim_size + // For non-empty slices (size > 0): 0 <= offset < dim_size + // 0 <= offset + (size - 1) * stride + // dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); + auto metadataOp = ExtractStridedMetadataOp::create(builder, loc, subView.getSource()); + for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) { - // Reset insertion point to before the operation for each dimension + // Reset insertion point to before the operation for each dimension. builder.setInsertionPoint(subView); + Value offset = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedOffsets()[i]); Value size = getValueOrCreateConstantIndexOp(builder, loc, subView.getMixedSizes()[i]); Value stride = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedStrides()[i]); - - // Verify that offset is in-bounds. Value dimSize = metadataOp.getSizes()[i]; - Value offsetInBounds = - generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create(builder, loc, offsetInBounds, + + // Verify that offset is in-bounds (conditional on slice size). 
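+      // E.g. (hypothetical sizes): with dimSize == 10, an empty slice may
+      // legally use offset == 10 (one past the end), whereas a non-empty
+      // slice requires offset in [0, 10).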
+ Value sizeIsZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, size, zero); + auto offsetCheckIf = scf::IfOp::create( + builder, loc, sizeIsZero, + [&](OpBuilder &b, Location loc) { + // For empty slices, offset can be at the boundary: 0 <= offset <= + // dimSize. + Value offsetGEZero = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sge, offset, zero); + Value offsetLEDimSize = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sle, offset, dimSize); + Value emptyOffsetValid = + arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize); + scf::YieldOp::create(b, loc, emptyOffsetValid); + }, + [&](OpBuilder &b, Location loc) { + // For non-empty slices, offset must be a valid index: 0 <= offset + // dimSize. + Value offsetInBounds = + generateInBoundsCheck(b, loc, offset, zero, dimSize); + scf::YieldOp::create(b, loc, offsetInBounds); + }); + + Value offsetCondition = offsetCheckIf.getResult(0); + cf::AssertOp::create(builder, loc, offsetCondition, generateErrorMessage(op, "offset " + std::to_string(i) + " is out-of-bounds")); - // Only verify if size > 0 + // Verify that the slice endpoint is in-bounds (only for non-empty + // slices). Value sizeIsNonZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::sgt, size, zero); + auto ifOp = scf::IfOp::create( + builder, loc, sizeIsNonZero, + [&](OpBuilder &b, Location loc) { + // Verify that slice does not run out-of-bounds. + Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one); + Value sizeMinusOneTimesStride = + arith::MulIOp::create(b, loc, sizeMinusOne, stride); + Value lastPos = + arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride); + Value lastPosInBounds = + generateInBoundsCheck(b, loc, lastPos, zero, dimSize); + scf::YieldOp::create(b, loc, lastPosInBounds); + }, + [&](OpBuilder &b, Location loc) { + Value trueVal = + arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); + scf::YieldOp::create(b, loc, trueVal); + }); - auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), - sizeIsNonZero, /*withElseRegion=*/true); - - // Populate the "then" region (for size > 0). - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - - // Verify that slice does not run out-of-bounds. - Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); - Value sizeMinusOneTimesStride = - arith::MulIOp::create(builder, loc, sizeMinusOne, stride); - Value lastPos = - arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); - Value lastPosInBounds = - generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - - scf::YieldOp::create(builder, loc, lastPosInBounds); - - // Populate the "else" region (for size == 0). - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - Value trueVal = - arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); - scf::YieldOp::create(builder, loc, trueVal); - - builder.setInsertionPointAfter(ifOp); Value finalCondition = ifOp.getResult(0); - cf::AssertOp::create( builder, loc, finalCondition, generateErrorMessage(op, diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index 6200366..e548698 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -133,17 +133,20 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, } /// Returns true if all the uses of op are not read/load. 
-/// There can be SubviewOp users as long as all its users are also +/// There can be view-like-op users as long as all its users are also /// StoreOp/transfer_write. If return true it also fills out the uses, if it /// returns false uses is unchanged. static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) { std::vector<Operation *> opUses; for (OpOperand &use : op->getUses()) { Operation *useOp = use.getOwner(); + // Use escaped the scope + if (useOp->mightHaveTrait<OpTrait::IsTerminator>()) + return false; if (isa<memref::DeallocOp>(useOp) || (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 && !mlir::hasEffect<MemoryEffects::Read>(useOp)) || - (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) { + (isa<ViewLikeOpInterface>(useOp) && resultIsNotRead(useOp, opUses))) { opUses.push_back(useOp); continue; } diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 2a857ed..0d05313 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -675,7 +675,7 @@ MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc, Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( OpBuilder &b, Location loc, OpFoldResult laneId, Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { - auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn)); + auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn); Type elementType = getElementTypeOrSelf(memref.getType()); auto vt = VectorType::get(vectorShape, elementType); @@ -727,7 +727,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand( [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { toStore.push_back(v); }); - return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn)); + return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn); } static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, @@ -792,7 +792,7 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { if (failed(maybeInfo)) return failure(); - MmaSyncInfo info = *maybeInfo; + const MmaSyncInfo &info = *maybeInfo; auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns; auto [lhsShape, rhsShape, resShape] = info.vectorShapes; Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef, diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp index 40e769e..1d775fb 100644 --- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp +++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp @@ -41,5 +41,12 @@ InFlightDiagnostic OpenACCSupport::emitNYI(Location loc, const Twine &message) { return mlir::emitError(loc, "not yet implemented: " + message); } +bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) { + if (impl) + return impl->isValidSymbolUse(user, symbol, definingOpPtr); + return acc::isValidSymbolUse(user, symbol, definingOpPtr); +} + } // namespace acc } // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 35eba72..47f1222 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include 
"mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" @@ -203,12 +204,91 @@ struct MemRefPointerLikeModel return false; } + + mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> srcPtr, + Type valueType) const { + // Load from a memref - only valid for scalar memrefs (rank 0). + // This is because the address computation for memrefs is part of the load + // (and not computed separately), but the API does not have arguments for + // indexing. + auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr); + if (!memrefValue) + return {}; + + auto memrefTy = memrefValue.getType(); + + // Only load from scalar memrefs (rank 0) + if (memrefTy.getRank() != 0) + return {}; + + return memref::LoadOp::create(builder, loc, memrefValue); + } + + bool genStore(Type pointer, OpBuilder &builder, Location loc, + Value valueToStore, TypedValue<PointerLikeType> destPtr) const { + // Store to a memref - only valid for scalar memrefs (rank 0) + // This is because the address computation for memrefs is part of the store + // (and not computed separately), but the API does not have arguments for + // indexing. + auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr); + if (!memrefValue) + return false; + + auto memrefTy = memrefValue.getType(); + + // Only store to scalar memrefs (rank 0) + if (memrefTy.getRank() != 0) + return false; + + memref::StoreOp::create(builder, loc, valueToStore, memrefValue); + return true; + } }; struct LLVMPointerPointerLikeModel : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel, LLVM::LLVMPointerType> { Type getElementType(Type pointer) const { return Type(); } + + mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> srcPtr, + Type valueType) const { + // For LLVM pointers, we need the valueType to determine what to load + if (!valueType) + return {}; + + return LLVM::LoadOp::create(builder, loc, valueType, srcPtr); + } + + bool genStore(Type pointer, OpBuilder &builder, Location loc, + Value valueToStore, TypedValue<PointerLikeType> destPtr) const { + LLVM::StoreOp::create(builder, loc, valueToStore, destPtr); + return true; + } +}; + +struct MemrefAddressOfGlobalModel + : public AddressOfGlobalOpInterface::ExternalModel< + MemrefAddressOfGlobalModel, memref::GetGlobalOp> { + SymbolRefAttr getSymbol(Operation *op) const { + auto getGlobalOp = cast<memref::GetGlobalOp>(op); + return getGlobalOp.getNameAttr(); + } +}; + +struct MemrefGlobalVariableModel + : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel, + memref::GlobalOp> { + bool isConstant(Operation *op) const { + auto globalOp = cast<memref::GlobalOp>(op); + return globalOp.getConstant(); + } + + Region *getInitRegion(Operation *op) const { + // GlobalOp uses attributes for initialization, not regions + return nullptr; + } }; /// Helper function for any of the times we need to modify an ArrayAttr based on @@ -302,6 +382,11 @@ void OpenACCDialect::initialize() { MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext()); LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>( *getContext()); + + // Attach operation interfaces + memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>( + *getContext()); + memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext()); } //===----------------------------------------------------------------------===// @@ -467,6 
+552,28 @@ checkValidModifier(Op op, acc::DataClauseModifier validModifiers) { return success(); } +template <typename OpT, typename RecipeOpT> +static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) { + // Mappable types do not need a recipe because it is possible to generate one + // from its API. Reject reductions though because no API is available for them + // at this time. + if (mlir::acc::isMappableType(op.getVar().getType()) && + !std::is_same_v<OpT, acc::ReductionOp>) + return success(); + + mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr(); + if (!operandRecipe) + return op->emitOpError() << "recipe expected for " << operandName; + + auto decl = + SymbolTable::lookupNearestSymbolFrom<RecipeOpT>(op, operandRecipe); + if (!decl) + return op->emitOpError() + << "expected symbol reference " << operandRecipe << " to point to a " + << operandName << " declaration"; + return success(); +} + static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var) { // Either `var` or `varPtr` keyword is required. @@ -573,6 +680,18 @@ static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, } } +static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, + mlir::SymbolRefAttr &recipeAttr) { + if (failed(parser.parseAttribute(recipeAttr))) + return failure(); + return success(); +} + +static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::SymbolRefAttr recipeAttr) { + p << recipeAttr; +} + //===----------------------------------------------------------------------===// // DataBoundsOp //===----------------------------------------------------------------------===// @@ -595,6 +714,9 @@ LogicalResult acc::PrivateOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed( + checkRecipe<acc::PrivateOp, acc::PrivateRecipeOp>(*this, "private"))) + return failure(); return success(); } @@ -609,6 +731,9 @@ LogicalResult acc::FirstprivateOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed(checkRecipe<acc::FirstprivateOp, acc::FirstprivateRecipeOp>( + *this, "firstprivate"))) + return failure(); return success(); } @@ -637,6 +762,9 @@ LogicalResult acc::ReductionOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed(checkRecipe<acc::ReductionOp, acc::ReductionRecipeOp>( + *this, "reduction"))) + return failure(); return success(); } @@ -1042,6 +1170,65 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> { } }; +/// Remove empty acc.kernel_environment operations. If the operation has wait +/// operands, create a acc.wait operation to preserve synchronization. +struct RemoveEmptyKernelEnvironment + : public OpRewritePattern<acc::KernelEnvironmentOp> { + using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op, + PatternRewriter &rewriter) const override { + assert(op->getNumRegions() == 1 && "expected op to have one region"); + + Block &block = op.getRegion().front(); + if (!block.empty()) + return failure(); + + // Conservatively disable canonicalization of empty acc.kernel_environment + // operations if the wait operands in the kernel_environment cannot be fully + // represented by acc.wait operation. 
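+    // acc.wait only carries a single flat wait list (plus optional async /
+    // devnum), so per-device-type or multi-segment wait information has no
+    // faithful encoding; the checks below bail out in those cases.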
+ + // Disable canonicalization if device type is not the default + if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) { + for (auto attr : deviceTypeAttr) { + if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) { + if (dtAttr.getValue() != mlir::acc::DeviceType::None) + return failure(); + } + } + } + + // Disable canonicalization if any wait segment has a devnum + if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) { + for (auto attr : hasDevnumAttr) { + if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) { + if (boolAttr.getValue()) + return failure(); + } + } + } + + // Disable canonicalization if there are multiple wait segments + if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) { + if (segmentsAttr.size() > 1) + return failure(); + } + + // Remove empty kernel environment. + // Preserve synchronization by creating acc.wait operation if needed. + if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr()) + rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(), + /*asyncOperand=*/Value(), + /*waitDevnum=*/Value(), + /*async=*/nullptr, + /*ifCond=*/Value()); + else + rewriter.eraseOp(op); + + return success(); + } +}; + //===----------------------------------------------------------------------===// // Recipe Region Helpers //===----------------------------------------------------------------------===// @@ -1263,6 +1450,28 @@ PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, return recipe; } +std::optional<PrivateRecipeOp> +PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, + FirstprivateRecipeOp firstprivRecipe) { + // Create the private.recipe op with the same type as the firstprivate.recipe. + OpBuilder::InsertionGuard guard(builder); + auto varType = firstprivRecipe.getType(); + auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType); + + // Clone the init region + IRMapping mapping; + firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping); + + // Clone destroy region if the firstprivate.recipe has one. 
+ if (!firstprivRecipe.getDestroyRegion().empty()) { + IRMapping mapping; + firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(), + mapping); + } + return recipe; +} + //===----------------------------------------------------------------------===// // FirstprivateRecipeOp //===----------------------------------------------------------------------===// @@ -1373,40 +1582,6 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() { } //===----------------------------------------------------------------------===// -// Custom parser and printer verifier for private clause -//===----------------------------------------------------------------------===// - -static ParseResult parseSymOperandList( - mlir::OpAsmParser &parser, - llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, - llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) { - llvm::SmallVector<SymbolRefAttr> attributes; - if (failed(parser.parseCommaSeparatedList([&]() { - if (parser.parseAttribute(attributes.emplace_back()) || - parser.parseArrow() || - parser.parseOperand(operands.emplace_back()) || - parser.parseColonType(types.emplace_back())) - return failure(); - return success(); - }))) - return failure(); - llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), - attributes.end()); - symbols = ArrayAttr::get(parser.getContext(), arrayAttr); - return success(); -} - -static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, - mlir::OperandRange operands, - mlir::TypeRange types, - std::optional<mlir::ArrayAttr> attributes) { - llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) { - p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " - << std::get<1>(it).getType(); - }); -} - -//===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// @@ -1425,45 +1600,19 @@ static LogicalResult checkDataOperands(Op op, return success(); } -template <typename Op> -static LogicalResult -checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes, - mlir::OperandRange operands, llvm::StringRef operandName, - llvm::StringRef symbolName, bool checkOperandType = true) { - if (!operands.empty()) { - if (!attributes || attributes->size() != operands.size()) - return op->emitOpError() - << "expected as many " << symbolName << " symbol reference as " - << operandName << " operands"; - } else { - if (attributes) - return op->emitOpError() - << "unexpected " << symbolName << " symbol reference"; - return success(); - } - +template <typename OpT, typename RecipeOpT> +static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, + const mlir::ValueRange &operands, + llvm::StringRef operandName) { llvm::DenseSet<Value> set; - for (auto args : llvm::zip(operands, *attributes)) { - mlir::Value operand = std::get<0>(args); - + for (mlir::Value operand : operands) { + if (!mlir::isa<OpT>(operand.getDefiningOp())) + return accConstructOp->emitOpError() + << "expected " << operandName << " as defining op"; if (!set.insert(operand).second) - return op->emitOpError() + return accConstructOp->emitOpError() << operandName << " operand appears more than once"; - - mlir::Type varType = operand.getType(); - auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args)); - auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef); - if (!decl) - return op->emitOpError() - << "expected symbol reference " << symbolRef << " 
to point to a " - << operandName << " declaration"; - - if (checkOperandType && decl.getType() && decl.getType() != varType) - return op->emitOpError() << "expected " << operandName << " (" << varType - << ") to be the same type as " << operandName - << " declaration (" << decl.getType() << ")"; } - return success(); } @@ -1520,17 +1669,17 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch( } LogicalResult acc::ParallelOp::verify() { - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( @@ -1661,7 +1810,6 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder, mlir::ValueRange gangPrivateOperands, mlir::ValueRange gangFirstPrivateOperands, mlir::ValueRange dataClauseOperands) { - ParallelOp::build( odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr, @@ -1670,9 +1818,8 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder, /*numGangsDeviceType=*/nullptr, numWorkers, /*numWorkersDeviceType=*/nullptr, vectorLength, /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond, - /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr, - gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands, - /*firstprivatizations=*/nullptr, dataClauseOperands, + /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands, + gangFirstPrivateOperands, dataClauseOperands, /*defaultAttr=*/nullptr, /*combined=*/nullptr); } @@ -1749,46 +1896,22 @@ void acc::ParallelOp::addWaitOperands( void acc::ParallelOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::ParallelOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - 
recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::ParallelOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } static ParseResult parseNumGangs( @@ -2356,17 +2479,17 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { } LogicalResult acc::SerialOp::verify() { - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( @@ -2430,46 +2553,22 @@ void acc::SerialOp::addWaitOperands( void acc::SerialOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::SerialOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::SerialOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - 
llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } //===----------------------------------------------------------------------===// @@ -2599,6 +2698,27 @@ LogicalResult acc::KernelsOp::verify() { return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands()); } +void acc::KernelsOp::addPrivatization(MLIRContext *context, + mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getPrivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getFirstprivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addReduction(MLIRContext *context, + mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getReductionOperandsMutable().append(op.getResult()); +} + void acc::KernelsOp::addNumWorkersOperand( MLIRContext *context, mlir::Value newValue, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { @@ -2691,6 +2811,15 @@ void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// +// KernelEnvironmentOp +//===----------------------------------------------------------------------===// + +void acc::KernelEnvironmentOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add<RemoveEmptyKernelEnvironment>(context); +} + +//===----------------------------------------------------------------------===// // LoopOp //===----------------------------------------------------------------------===// @@ -2899,19 +3028,21 @@ bool hasDuplicateDeviceTypes( } /// Check for duplicates in the DeviceType array attribute. -LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) { +/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found. 
+static std::optional<mlir::acc::DeviceType> +checkDeviceTypes(mlir::ArrayAttr deviceTypes) { llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes; if (!deviceTypes) - return success(); + return std::nullopt; for (auto attr : deviceTypes) { auto deviceTypeAttr = mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr); if (!deviceTypeAttr) - return failure(); + return mlir::acc::DeviceType::None; if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second) - return failure(); + return deviceTypeAttr.getValue(); } - return success(); + return std::nullopt; } LogicalResult acc::LoopOp::verify() { @@ -2938,9 +3069,10 @@ LogicalResult acc::LoopOp::verify() { getCollapseDeviceTypeAttr().getValue().size()) return emitOpError() << "collapse attribute count must match collapse" << " device_type count"; - if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr()))) - return emitOpError() - << "duplicate device_type found in collapseDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in collapseDeviceType attribute"; // Check gang if (!getGangOperands().empty()) { @@ -2953,8 +3085,12 @@ LogicalResult acc::LoopOp::verify() { return emitOpError() << "gangOperandsArgType attribute count must match" << " gangOperands count"; } - if (getGangAttr() && failed(checkDeviceTypes(getGangAttr()))) - return emitOpError() << "duplicate device_type found in gang attribute"; + if (getGangAttr()) { + if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in gang attribute"; + } if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getGangOperands(), getGangOperandsSegmentsAttr(), @@ -2962,22 +3098,30 @@ LogicalResult acc::LoopOp::verify() { return failure(); // Check worker - if (failed(checkDeviceTypes(getWorkerAttr()))) - return emitOpError() << "duplicate device_type found in worker attribute"; - if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))) - return emitOpError() << "duplicate device_type found in " - "workerNumOperandsDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in worker attribute"; + if (auto duplicateDeviceType = + checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in workerNumOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(), getWorkerNumOperandsDeviceTypeAttr(), "worker"))) return failure(); // Check vector - if (failed(checkDeviceTypes(getVectorAttr()))) - return emitOpError() << "duplicate device_type found in vector attribute"; - if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))) - return emitOpError() << "duplicate device_type found in " - "vectorOperandsDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in vector attribute"; + if (auto duplicateDeviceType = + checkDeviceTypes(getVectorOperandsDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << 
acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in vectorOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(), getVectorOperandsDeviceTypeAttr(), "vector"))) @@ -3042,19 +3186,19 @@ LogicalResult acc::LoopOp::verify() { } } - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (getCombined().has_value() && @@ -3068,8 +3212,12 @@ LogicalResult acc::LoopOp::verify() { if (getRegion().empty()) return emitError("expected non-empty body."); - // When it is container-like - it is expected to hold a loop-like operation. - if (isContainerLike()) { + if (getUnstructured()) { + if (!isContainerLike()) + return emitError( + "unstructured acc.loop must not have induction variables"); + } else if (isContainerLike()) { + // When it is container-like - it is expected to hold a loop-like operation. // Obtain the maximum collapse count - we use this to check that there // are enough loops contained. 
uint64_t collapseCount = getCollapseValue().value_or(1); @@ -3484,45 +3632,21 @@ void acc::LoopOp::addGangOperands( void acc::LoopOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::LoopOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } //===----------------------------------------------------------------------===// @@ -3987,7 +4111,8 @@ LogicalResult acc::RoutineOp::verify() { if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1)) return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " - "be present at the same time"; + "be present at the same time for device_type `" + << acc::stringifyDeviceType(dtype) << "`"; } return success(); @@ -4284,6 +4409,100 @@ RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) { return std::nullopt; } +void RoutineOp::addSeq(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addVector(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addWorker(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addGang(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addGang(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + uint64_t val) { + llvm::SmallVector<mlir::Attribute> dimValues; + llvm::SmallVector<mlir::Attribute> deviceTypes; + + if (getGangDimAttr()) + llvm::copy(getGangDimAttr(), std::back_inserter(dimValues)); + if (getGangDimDeviceTypeAttr()) 
+ llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes)); + + assert(dimValues.size() == deviceTypes.size()); + + if (effectiveDeviceTypes.empty()) { + dimValues.push_back( + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val)); + deviceTypes.push_back( + acc::DeviceTypeAttr::get(context, acc::DeviceType::None)); + } else { + for (DeviceType dt : effectiveDeviceTypes) { + dimValues.push_back( + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val)); + deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt)); + } + } + assert(dimValues.size() == deviceTypes.size()); + + setGangDimAttr(mlir::ArrayAttr::get(context, dimValues)); + setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes)); +} + +void RoutineOp::addBindStrName(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + mlir::StringAttr val) { + unsigned before = getBindStrNameDeviceTypeAttr() + ? getBindStrNameDeviceTypeAttr().size() + : 0; + + setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes)); + unsigned after = getBindStrNameDeviceTypeAttr().size(); + + llvm::SmallVector<mlir::Attribute> vals; + if (getBindStrNameAttr()) + llvm::copy(getBindStrNameAttr(), std::back_inserter(vals)); + for (unsigned i = 0; i < after - before; ++i) + vals.push_back(val); + + setBindStrNameAttr(mlir::ArrayAttr::get(context, vals)); +} + +void RoutineOp::addBindIDName(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + mlir::SymbolRefAttr val) { + unsigned before = + getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0; + + setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes)); + unsigned after = getBindIdNameDeviceTypeAttr().size(); + + llvm::SmallVector<mlir::Attribute> vals; + if (getBindIdNameAttr()) + llvm::copy(getBindIdNameAttr(), std::back_inserter(vals)); + for (unsigned i = 0; i < after - before; ++i) + vals.push_back(val); + + setBindIdNameAttr(mlir::ArrayAttr::get(context, vals)); +} + //===----------------------------------------------------------------------===// // InitOp //===----------------------------------------------------------------------===// @@ -4667,3 +4886,12 @@ mlir::acc::getMutableDataOperands(mlir::Operation *accOp) { .Default([&](mlir::Operation *) { return nullptr; })}; return dataOperands; } + +mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) { + auto recipe{ + llvm::TypeSwitch<mlir::Operation *, mlir::SymbolRefAttr>(accOp) + .Case<ACC_DATA_ENTRY_OPS>( + [&](auto entry) { return entry.getRecipeAttr(); }) + .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })}; + return recipe; +} diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp new file mode 100644 index 0000000..67cdf10 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp @@ -0,0 +1,781 @@ +//===- ACCImplicitData.cpp ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements the OpenACC specification for "Variables with +// Implicitly Determined Data Attributes" (OpenACC 3.4 spec, section 2.6.2). +// +// Overview: +// --------- +// The pass automatically generates data clause operations for variables used +// within OpenACC compute constructs (parallel, kernels, serial) that do not +// already have explicit data clauses. The semantics follow these rules: +// +// 1. If there is a default(none) clause visible, no implicit data actions +// apply. +// +// 2. An aggregate variable (arrays, derived types, etc.) will be treated as: +// - In a present clause when default(present) is visible. +// - In a copy clause otherwise. +// +// 3. A scalar variable will be treated as if it appears in: +// - A copy clause if the compute construct is a kernels construct. +// - A firstprivate clause otherwise (parallel, serial). +// +// Requirements: +// ------------- +// To use this pass in a pipeline, the following requirements must be met: +// +// 1. Type Interface Implementation: Variables from the dialect being used +// must implement one or both of the following MLIR interfaces: +// `acc::MappableType` and/or `acc::PointerLikeType` +// +// These interfaces provide the necessary methods for the pass to: +// - Determine variable type categories (scalar vs. aggregate) +// - Generate appropriate bounds information +// - Generate privatization recipes +// +// 2. Operation Interface Implementation: Operations that access partial +// entities or create views should implement the following MLIR +// interfaces: `acc::PartialEntityAccess` and/or +// `mlir::ViewLikeOpInterface` +// +// These interfaces are used for proper data clause ordering, ensuring +// that base entities are mapped before derived entities (e.g., a +// struct is mapped before its fields, an array is mapped before +// subarray views). +// +// 3. Analysis Registration (Optional): If custom behavior is needed for +// variable name extraction or alias analysis, the dialect should +// pre-register the `acc::OpenACCSupport` and `mlir::AliasAnalysis` analyses. +// +// If not registered, default behavior will be used. +// +// Implementation Details: +// ----------------------- +// The pass performs the following operations: +// +// 1. Finds candidate variables which are live-in to the compute region and +// are not already in a data clause or private clause. +// +// 2. Generates both data "entry" and "exit" clause operations that match +// the intended action depending on variable type: +// - copy -> acc.copyin (entry) + acc.copyout (exit) +// - present -> acc.present (entry) + acc.delete (exit) +// - firstprivate -> acc.firstprivate (entry only, no exit) +// +// 3. Ensures that default clause is taken into consideration by looking +// through current construct and parent constructs to find the "visible +// default clause". +// +// 4. Fixes up SSA value links so that uses in the acc region reference the +// result of the newly created data clause operations. +// +// 5. When generating implicit data clause operations, it also adds variable +// name information and marks them with the implicit flag. +// +// 6. Recipes are generated by calling the appropriate entrypoints in the +// MappableType and PointerLikeType interfaces. +// +// 7. 
AliasAnalysis is used to determine if a variable is already covered by +// an existing data clause (e.g., an interior pointer covered by its parent). +// +// Examples: +// --------- +// +// Example 1: Scalar in parallel construct (implicit firstprivate) +// +// Before: +// func.func @test() { +// %scalar = memref.alloca() {acc.var_name = "x"} : memref<f32> +// acc.parallel { +// %val = memref.load %scalar[] : memref<f32> +// acc.yield +// } +// } +// +// After: +// func.func @test() { +// %scalar = memref.alloca() {acc.var_name = "x"} : memref<f32> +// %firstpriv = acc.firstprivate varPtr(%scalar : memref<f32>) +// -> memref<f32> {implicit = true, name = "x"} +// acc.parallel firstprivate(@recipe -> %firstpriv : memref<f32>) { +// %val = memref.load %firstpriv[] : memref<f32> +// acc.yield +// } +// } +// +// Example 2: Scalar in kernels construct (implicit copy) +// +// Before: +// func.func @test() { +// %scalar = memref.alloca() {acc.var_name = "n"} : memref<i32> +// acc.kernels { +// %val = memref.load %scalar[] : memref<i32> +// acc.terminator +// } +// } +// +// After: +// func.func @test() { +// %scalar = memref.alloca() {acc.var_name = "n"} : memref<i32> +// %copyin = acc.copyin varPtr(%scalar : memref<i32>) -> memref<i32> +// {dataClause = #acc<data_clause acc_copy>, +// implicit = true, name = "n"} +// acc.kernels dataOperands(%copyin : memref<i32>) { +// %val = memref.load %copyin[] : memref<i32> +// acc.terminator +// } +// acc.copyout accPtr(%copyin : memref<i32>) +// to varPtr(%scalar : memref<i32>) +// {dataClause = #acc<data_clause acc_copy>, +// implicit = true, name = "n"} +// } +// +// Example 3: Array (aggregate) in parallel (implicit copy) +// +// Before: +// func.func @test() { +// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32> +// acc.parallel { +// %c0 = arith.constant 0 : index +// %val = memref.load %array[%c0] : memref<100xf32> +// acc.yield +// } +// } +// +// After: +// func.func @test() { +// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32> +// %copyin = acc.copyin varPtr(%array : memref<100xf32>) +// -> memref<100xf32> +// {dataClause = #acc<data_clause acc_copy>, +// implicit = true, name = "arr"} +// acc.parallel dataOperands(%copyin : memref<100xf32>) { +// %c0 = arith.constant 0 : index +// %val = memref.load %copyin[%c0] : memref<100xf32> +// acc.yield +// } +// acc.copyout accPtr(%copyin : memref<100xf32>) +// to varPtr(%array : memref<100xf32>) +// {dataClause = #acc<data_clause acc_copy>, +// implicit = true, name = "arr"} +// } +// +// Example 4: Array with default(present) +// +// Before: +// func.func @test() { +// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32> +// acc.parallel { +// %c0 = arith.constant 0 : index +// %val = memref.load %array[%c0] : memref<100xf32> +// acc.yield +// } attributes {defaultAttr = #acc<defaultvalue present>} +// } +// +// After: +// func.func @test() { +// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32> +// %present = acc.present varPtr(%array : memref<100xf32>) +// -> memref<100xf32> +// {implicit = true, name = "arr"} +// acc.parallel dataOperands(%present : memref<100xf32>) +// attributes {defaultAttr = #acc<defaultvalue present>} { +// %c0 = arith.constant 0 : index +// %val = memref.load %present[%c0] : memref<100xf32> +// acc.yield +// } +// acc.delete accPtr(%present : memref<100xf32>) +// {dataClause = #acc<data_clause acc_present>, +// implicit = true, name = "arr"} +// } +// 
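+// Example 5 (illustrative sketch, added here to mirror the examples above):
+// aggregate with default(none)
+//
+// Per rule 1 in the overview, a visible default(none) clause disables all
+// implicit data actions, so the pass leaves the construct unchanged and the
+// variable must be mapped explicitly by the user. The attribute spelling
+// below follows Example 4 and is assumed rather than taken from a test.
+//
+// Before (and after, unchanged):
+// func.func @test() {
+//   %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32>
+//   acc.parallel {
+//     %c0 = arith.constant 0 : index
+//     %val = memref.load %array[%c0] : memref<100xf32>
+//     acc.yield
+//   } attributes {defaultAttr = #acc<defaultvalue none>}
+// }
+//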
+//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Dialect/OpenACC/OpenACCUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include <type_traits> + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCIMPLICITDATA +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-implicit-data" + +using namespace mlir; + +namespace { + +class ACCImplicitData : public acc::impl::ACCImplicitDataBase<ACCImplicitData> { +public: + using acc::impl::ACCImplicitDataBase<ACCImplicitData>::ACCImplicitDataBase; + + void runOnOperation() override; + +private: + /// Looks through the `dominatingDataClauses` to find the original data clause + /// op for an alias. Returns nullptr if no original data clause op is found. + template <typename OpT> + Operation *getOriginalDataClauseOpForAlias( + Value var, OpBuilder &builder, OpT computeConstructOp, + const SmallVector<Value> &dominatingDataClauses); + + /// Generates the appropriate `acc.copyin`, `acc.present`,`acc.firstprivate`, + /// etc. data clause op for a candidate variable. + template <typename OpT> + Operation *generateDataClauseOpForCandidate( + Value var, ModuleOp &module, OpBuilder &builder, OpT computeConstructOp, + const SmallVector<Value> &dominatingDataClauses, + const std::optional<acc::ClauseDefaultValue> &defaultClause); + + /// Generates the implicit data ops for a compute construct. + template <typename OpT> + void generateImplicitDataOps( + ModuleOp &module, OpT computeConstructOp, + std::optional<acc::ClauseDefaultValue> &defaultClause); + + /// Generates a private recipe for a variable. + acc::PrivateRecipeOp generatePrivateRecipe(ModuleOp &module, Value var, + Location loc, OpBuilder &builder, + acc::OpenACCSupport &accSupport); + + /// Generates a firstprivate recipe for a variable. + acc::FirstprivateRecipeOp + generateFirstprivateRecipe(ModuleOp &module, Value var, Location loc, + OpBuilder &builder, + acc::OpenACCSupport &accSupport); + + /// Generates recipes for a list of variables. + void generateRecipes(ModuleOp &module, OpBuilder &builder, + Operation *computeConstructOp, + const SmallVector<Value> &newOperands); +}; + +/// Determines if a variable is a candidate for implicit data mapping. +/// Returns true if the variable is a candidate, false otherwise. +static bool isCandidateForImplicitData(Value val, Region &accRegion) { + // Ensure the variable is an allowed type for data clause. + if (!acc::isPointerLikeType(val.getType()) && + !acc::isMappableType(val.getType())) + return false; + + // If this is already coming from a data clause, we do not need to generate + // another. + if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.getDefiningOp())) + return false; + + // If this is only used by private clauses, it is not a real live-in. 
+ if (acc::isOnlyUsedByPrivateClauses(val, accRegion)) + return false; + + return true; +} + +template <typename OpT> +Operation *ACCImplicitData::getOriginalDataClauseOpForAlias( + Value var, OpBuilder &builder, OpT computeConstructOp, + const SmallVector<Value> &dominatingDataClauses) { + auto &aliasAnalysis = this->getAnalysis<AliasAnalysis>(); + for (auto dataClause : dominatingDataClauses) { + if (auto *dataClauseOp = dataClause.getDefiningOp()) { + // Only accept clauses that guarantee that the alias is present. + if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp, + acc::DevicePtrOp>(dataClauseOp)) + if (aliasAnalysis.alias(acc::getVar(dataClauseOp), var).isMust()) + return dataClauseOp; + } + } + return nullptr; +} + +// Generates bounds for variables that have unknown dimensions +static void fillInBoundsForUnknownDimensions(Operation *dataClauseOp, + OpBuilder &builder) { + + if (!acc::getBounds(dataClauseOp).empty()) + // If bounds are already present, do not overwrite them. + return; + + // For types that have unknown dimensions, attempt to generate bounds by + // relying on MappableType being able to extract it from the IR. + auto var = acc::getVar(dataClauseOp); + auto type = var.getType(); + if (auto mappableTy = dyn_cast<acc::MappableType>(type)) { + if (mappableTy.hasUnknownDimensions()) { + TypeSwitch<Operation *>(dataClauseOp) + .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClauseOp) { + if (std::is_same_v<decltype(dataClauseOp), acc::DevicePtrOp>) + return; + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(dataClauseOp); + auto bounds = mappableTy.generateAccBounds(var, builder); + if (!bounds.empty()) + dataClauseOp.getBoundsMutable().assign(bounds); + }); + } + } +} + +acc::PrivateRecipeOp +ACCImplicitData::generatePrivateRecipe(ModuleOp &module, Value var, + Location loc, OpBuilder &builder, + acc::OpenACCSupport &accSupport) { + auto type = var.getType(); + std::string recipeName = + accSupport.getRecipeName(acc::RecipeKind::private_recipe, type, var); + + // Check if recipe already exists + auto existingRecipe = module.lookupSymbol<acc::PrivateRecipeOp>(recipeName); + if (existingRecipe) + return existingRecipe; + + // Set insertion point to module body in a scoped way + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto recipe = + acc::PrivateRecipeOp::createAndPopulate(builder, loc, recipeName, type); + if (!recipe.has_value()) + return accSupport.emitNYI(loc, "implicit private"), nullptr; + return recipe.value(); +} + +acc::FirstprivateRecipeOp +ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module, Value var, + Location loc, OpBuilder &builder, + acc::OpenACCSupport &accSupport) { + auto type = var.getType(); + std::string recipeName = + accSupport.getRecipeName(acc::RecipeKind::firstprivate_recipe, type, var); + + // Check if recipe already exists + auto existingRecipe = + module.lookupSymbol<acc::FirstprivateRecipeOp>(recipeName); + if (existingRecipe) + return existingRecipe; + + // Set insertion point to module body in a scoped way + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto recipe = acc::FirstprivateRecipeOp::createAndPopulate(builder, loc, + recipeName, type); + if (!recipe.has_value()) + return accSupport.emitNYI(loc, "implicit firstprivate"), nullptr; + return recipe.value(); +} + +void ACCImplicitData::generateRecipes(ModuleOp &module, OpBuilder &builder, + Operation 
*computeConstructOp, + const SmallVector<Value> &newOperands) { + auto &accSupport = this->getAnalysis<acc::OpenACCSupport>(); + for (auto var : newOperands) { + auto loc{var.getLoc()}; + if (auto privateOp = dyn_cast<acc::PrivateOp>(var.getDefiningOp())) { + auto recipe = generatePrivateRecipe( + module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport); + if (recipe) + privateOp.setRecipeAttr( + SymbolRefAttr::get(module->getContext(), recipe.getSymName())); + } else if (auto firstprivateOp = + dyn_cast<acc::FirstprivateOp>(var.getDefiningOp())) { + auto recipe = generateFirstprivateRecipe( + module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport); + if (recipe) + firstprivateOp.setRecipeAttr(SymbolRefAttr::get( + module->getContext(), recipe.getSymName().str())); + } else { + accSupport.emitNYI(var.getLoc(), "implicit reduction"); + } + } +} + +// Generates the data entry data op clause so that it adheres to OpenACC +// rules as follows (line numbers and specification from OpenACC 3.4): +// 1388 An aggregate variable will be treated as if it appears either: +// 1389 - In a present clause if there is a default(present) clause visible at +// the compute construct. +// 1391 - In a copy clause otherwise. +// 1392 A scalar variable will be treated as if it appears either: +// 1393 - In a copy clause if the compute construct is a kernels construct. +// 1394 - In a firstprivate clause otherwise. +template <typename OpT> +Operation *ACCImplicitData::generateDataClauseOpForCandidate( + Value var, ModuleOp &module, OpBuilder &builder, OpT computeConstructOp, + const SmallVector<Value> &dominatingDataClauses, + const std::optional<acc::ClauseDefaultValue> &defaultClause) { + auto &accSupport = this->getAnalysis<acc::OpenACCSupport>(); + acc::VariableTypeCategory typeCategory = + acc::VariableTypeCategory::uncategorized; + if (auto mappableTy = dyn_cast<acc::MappableType>(var.getType())) { + typeCategory = mappableTy.getTypeCategory(var); + } else if (auto pointerLikeTy = + dyn_cast<acc::PointerLikeType>(var.getType())) { + typeCategory = pointerLikeTy.getPointeeTypeCategory( + cast<TypedValue<acc::PointerLikeType>>(var), + pointerLikeTy.getElementType()); + } + + bool isScalar = + acc::bitEnumContainsAny(typeCategory, acc::VariableTypeCategory::scalar); + bool isAnyAggregate = acc::bitEnumContainsAny( + typeCategory, acc::VariableTypeCategory::aggregate); + Location loc = computeConstructOp->getLoc(); + + Operation *op = nullptr; + op = getOriginalDataClauseOpForAlias(var, builder, computeConstructOp, + dominatingDataClauses); + if (op) { + if (isa<acc::NoCreateOp>(op)) + return acc::NoCreateOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var), + acc::getBounds(op)); + + if (isa<acc::DevicePtrOp>(op)) + return acc::DevicePtrOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var), + acc::getBounds(op)); + + // The original data clause op is a PresentOp, CopyinOp, or CreateOp, + // hence guaranteed to be present. 
+ return acc::PresentOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var), + acc::getBounds(op)); + } else if (isScalar) { + if (enableImplicitReductionCopy && + acc::isOnlyUsedByReductionClauses(var, + computeConstructOp->getRegion(0))) { + auto copyinOp = + acc::CopyinOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var)); + copyinOp.setDataClause(acc::DataClause::acc_reduction); + return copyinOp.getOperation(); + } + if constexpr (std::is_same_v<OpT, acc::KernelsOp> || + std::is_same_v<OpT, acc::KernelEnvironmentOp>) { + // Scalars are implicit copyin in kernels construct. + // We also do the same for acc.kernel_environment because semantics + // of user variable mappings should be applied while ACC construct exists + // and at this point we should only be dealing with unmapped variables + // that were made live-in by the compiler. + // TODO: This may be revisited. + auto copyinOp = + acc::CopyinOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var)); + copyinOp.setDataClause(acc::DataClause::acc_copy); + return copyinOp.getOperation(); + } else { + // Scalars are implicit firstprivate in parallel and serial construct. + return acc::FirstprivateOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var)); + } + } else if (isAnyAggregate) { + Operation *newDataOp = nullptr; + + // When default(present) is true, the implicit behavior is present. + if (defaultClause.has_value() && + defaultClause.value() == acc::ClauseDefaultValue::Present) { + newDataOp = acc::PresentOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var)); + newDataOp->setAttr(acc::getFromDefaultClauseAttrName(), + builder.getUnitAttr()); + } else { + auto copyinOp = + acc::CopyinOp::create(builder, loc, var, + /*structured=*/true, /*implicit=*/true, + accSupport.getVariableName(var)); + copyinOp.setDataClause(acc::DataClause::acc_copy); + newDataOp = copyinOp.getOperation(); + } + + return newDataOp; + } else { + // This is not a fatal error - for example when the element type is + // pointer type (aka we have a pointer of pointer), it is potentially a + // deep copy scenario which is not being handled here. + // Other types need to be canonicalized. Thus just log unhandled cases. + LLVM_DEBUG(llvm::dbgs() + << "Unhandled case for implicit data mapping " << var << "\n"); + } + return nullptr; +} + +// Ensures that result values from the acc data clause ops are used inside the +// acc region. ie: +// acc.kernels { +// use %val +// } +// => +// %dev = acc.dataop %val +// acc.kernels { +// use %dev +// } +static void legalizeValuesInRegion(Region &accRegion, + SmallVector<Value> &newPrivateOperands, + SmallVector<Value> &newDataClauseOperands) { + for (Value dataClause : + llvm::concat<Value>(newDataClauseOperands, newPrivateOperands)) { + Value var = acc::getVar(dataClause.getDefiningOp()); + replaceAllUsesInRegionWith(var, dataClause, accRegion); + } +} + +// Adds the private operands to the compute construct operation. 
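+// Only acc.private and acc.firstprivate results are handled here:
+// generateDataClauseOpForCandidate never materializes an acc.reduction op
+// (reduction-only scalars become an implicit copyin with the acc_reduction
+// data clause instead), so any other defining op is treated as unreachable.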
+template <typename OpT> +static void addNewPrivateOperands(OpT &accOp, + const SmallVector<Value> &privateOperands) { + if (privateOperands.empty()) + return; + + for (auto priv : privateOperands) { + if (isa<acc::PrivateOp>(priv.getDefiningOp())) { + accOp.getPrivateOperandsMutable().append(priv); + } else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) { + accOp.getFirstprivateOperandsMutable().append(priv); + } else { + llvm_unreachable("unhandled reduction operand"); + } + } +} + +static Operation *findDataExitOp(Operation *dataEntryOp) { + auto res = acc::getAccVar(dataEntryOp); + for (auto *user : res.getUsers()) + if (isa<ACC_DATA_EXIT_OPS>(user)) + return user; + return nullptr; +} + +// Generates matching data exit operation as described in the acc dialect +// for how data clauses are decomposed: +// https://mlir.llvm.org/docs/Dialects/OpenACCDialect/#operation-categories +// Key ones used here: +// * acc {construct} copy -> acc.copyin (before region) + acc.copyout (after +// region) +// * acc {construct} present -> acc.present (before region) + acc.delete +// (after region) +static void +generateDataExitOperations(OpBuilder &builder, Operation *accOp, + const SmallVector<Value> &newDataClauseOperands, + const SmallVector<Value> &sortedDataClauseOperands) { + builder.setInsertionPointAfter(accOp); + Value lastDataClause = nullptr; + for (auto dataEntry : llvm::reverse(sortedDataClauseOperands)) { + if (llvm::find(newDataClauseOperands, dataEntry) == + newDataClauseOperands.end()) { + // If this is not a new data clause operand, we should not generate an + // exit operation for it. + lastDataClause = dataEntry; + continue; + } + if (lastDataClause) + if (auto *dataExitOp = findDataExitOp(lastDataClause.getDefiningOp())) + builder.setInsertionPointAfter(dataExitOp); + Operation *dataEntryOp = dataEntry.getDefiningOp(); + if (isa<acc::CopyinOp>(dataEntryOp)) { + auto copyoutOp = acc::CopyoutOp::create( + builder, dataEntryOp->getLoc(), dataEntry, acc::getVar(dataEntryOp), + /*structured=*/true, /*implicit=*/true, + acc::getVarName(dataEntryOp).value(), acc::getBounds(dataEntryOp)); + copyoutOp.setDataClause(acc::DataClause::acc_copy); + } else if (isa<acc::PresentOp, acc::NoCreateOp>(dataEntryOp)) { + auto deleteOp = acc::DeleteOp::create( + builder, dataEntryOp->getLoc(), dataEntry, + /*structured=*/true, /*implicit=*/true, + acc::getVarName(dataEntryOp).value(), acc::getBounds(dataEntryOp)); + deleteOp.setDataClause(acc::getDataClause(dataEntryOp).value()); + } else if (isa<acc::DevicePtrOp>(dataEntryOp)) { + // Do nothing. + } else { + llvm_unreachable("unhandled data exit"); + } + lastDataClause = dataEntry; + } +} + +/// Returns all base references of a value in order. +/// So for example, if we have a reference to a struct field like +/// s.f1.f2.f3, this will return <s, s.f1, s.f1.f2, s.f1.f2.f3>. +/// Any intermediate casts/view-like operations are included in the +/// chain as well. +static SmallVector<Value> getBaseRefsChain(Value val) { + SmallVector<Value> baseRefs; + baseRefs.push_back(val); + while (true) { + Value prevVal = val; + + val = acc::getBaseEntity(val); + if (val != baseRefs.front()) + baseRefs.insert(baseRefs.begin(), val); + + // If this is a view-like operation, it is effectively another + // view of the same entity so we should add it to the chain also. 
+ if (auto viewLikeOp = val.getDefiningOp<ViewLikeOpInterface>()) { + val = viewLikeOp.getViewSource(); + baseRefs.insert(baseRefs.begin(), val); + } + + // Continue loop if we made any progress + if (val == prevVal) + break; + } + + return baseRefs; +} + +static void insertInSortedOrder(SmallVector<Value> &sortedDataClauseOperands, + Operation *newClause) { + auto *insertPos = + std::find_if(sortedDataClauseOperands.begin(), + sortedDataClauseOperands.end(), [&](Value dataClauseVal) { + // Get the base refs for the current clause we are looking + // at. + auto var = acc::getVar(dataClauseVal.getDefiningOp()); + auto baseRefs = getBaseRefsChain(var); + + // If the newClause is of a base ref of an existing clause, + // we should insert it right before the current clause. + // Thus return true to stop iteration when this is the + // case. + return std::find(baseRefs.begin(), baseRefs.end(), + acc::getVar(newClause)) != baseRefs.end(); + }); + + if (insertPos != sortedDataClauseOperands.end()) { + newClause->moveBefore(insertPos->getDefiningOp()); + sortedDataClauseOperands.insert(insertPos, acc::getAccVar(newClause)); + } else { + sortedDataClauseOperands.push_back(acc::getAccVar(newClause)); + } +} + +template <typename OpT> +void ACCImplicitData::generateImplicitDataOps( + ModuleOp &module, OpT computeConstructOp, + std::optional<acc::ClauseDefaultValue> &defaultClause) { + // Implicit data attributes are only applied if "[t]here is no default(none) + // clause visible at the compute construct." + if (defaultClause.has_value() && + defaultClause.value() == acc::ClauseDefaultValue::None) + return; + assert(!defaultClause.has_value() || + defaultClause.value() == acc::ClauseDefaultValue::Present); + + // 1) Collect live-in values. + Region &accRegion = computeConstructOp->getRegion(0); + SetVector<Value> liveInValues; + getUsedValuesDefinedAbove(accRegion, liveInValues); + + // 2) Run the filtering to find relevant pointers that need copied. + auto isCandidate{[&](Value val) -> bool { + return isCandidateForImplicitData(val, accRegion); + }}; + auto candidateVars( + llvm::to_vector(llvm::make_filter_range(liveInValues, isCandidate))); + if (candidateVars.empty()) + return; + + // 3) Generate data clauses for the variables. 
+ SmallVector<Value> newPrivateOperands; + SmallVector<Value> newDataClauseOperands; + OpBuilder builder(computeConstructOp); + if (!candidateVars.empty()) { + LLVM_DEBUG(llvm::dbgs() << "== Generating clauses for ==\n" + << computeConstructOp << "\n"); + } + auto &domInfo = this->getAnalysis<DominanceInfo>(); + auto &postDomInfo = this->getAnalysis<PostDominanceInfo>(); + auto dominatingDataClauses = + acc::getDominatingDataClauses(computeConstructOp, domInfo, postDomInfo); + for (auto var : candidateVars) { + auto newDataClauseOp = generateDataClauseOpForCandidate( + var, module, builder, computeConstructOp, dominatingDataClauses, + defaultClause); + fillInBoundsForUnknownDimensions(newDataClauseOp, builder); + LLVM_DEBUG(llvm::dbgs() << "Generated data clause for " << var << ":\n" + << "\t" << *newDataClauseOp << "\n"); + if (isa_and_nonnull<acc::PrivateOp, acc::FirstprivateOp, acc::ReductionOp>( + newDataClauseOp)) { + newPrivateOperands.push_back(acc::getAccVar(newDataClauseOp)); + } else if (isa_and_nonnull<ACC_DATA_CLAUSE_OPS>(newDataClauseOp)) { + newDataClauseOperands.push_back(acc::getAccVar(newDataClauseOp)); + dominatingDataClauses.push_back(acc::getAccVar(newDataClauseOp)); + } + } + + // 4) Legalize values in region (aka the uses in the region are the result + // of the data clause ops) + legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands); + + // 5) Generate private recipes which are required for properly attaching + // private operands. + if constexpr (!std::is_same_v<OpT, acc::KernelsOp> && + !std::is_same_v<OpT, acc::KernelEnvironmentOp>) + generateRecipes(module, builder, computeConstructOp, newPrivateOperands); + + // 6) Figure out insertion order for the new data clause operands. + SmallVector<Value> sortedDataClauseOperands( + computeConstructOp.getDataClauseOperands()); + for (auto newClause : newDataClauseOperands) + insertInSortedOrder(sortedDataClauseOperands, newClause.getDefiningOp()); + + // 7) Generate the data exit operations. + generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands, + sortedDataClauseOperands); + // 8) Add all of the new operands to the compute construct op. + if constexpr (!std::is_same_v<OpT, acc::KernelsOp> && + !std::is_same_v<OpT, acc::KernelEnvironmentOp>) + addNewPrivateOperands(computeConstructOp, newPrivateOperands); + computeConstructOp.getDataClauseOperandsMutable().assign( + sortedDataClauseOperands); +} + +void ACCImplicitData::runOnOperation() { + ModuleOp module = this->getOperation(); + module.walk([&](Operation *op) { + if (isa<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(op)) { + assert(op->getNumRegions() == 1 && "must have 1 region"); + + auto defaultClause = acc::getDefaultAttr(op); + llvm::TypeSwitch<Operation *, void>(op) + .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>( + [&](auto op) { + generateImplicitDataOps(module, op, defaultClause); + }) + .Default([&](Operation *) {}); + } + }); +} + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp new file mode 100644 index 0000000..8cab223 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp @@ -0,0 +1,431 @@ +//===- ACCImplicitDeclare.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass applies implicit `acc declare` actions to global variables
+// referenced in OpenACC compute regions and routine functions.
+//
+// Overview:
+// ---------
+// Global references in acc regions (for globals not marked with `acc
+// declare` by the user) can be handled in one of two ways:
+// - Mapped through data clauses
+// - Implicitly marked as `acc declare` (this pass)
+//
+// The OpenACC specification itself only describes implicit data mapping
+// rules (the first option); their implementation is captured in the
+// `ACCImplicitData` pass.
+//
+// However, in certain cases it is advantageous, or even required, to use
+// implicit `acc declare` instead:
+// - Functions that are implicitly marked as `acc routine` through
+//   `ACCImplicitRoutine` may reference globals. Since data mapping is only
+//   possible for compute regions, such globals can only be made available
+//   on the device through `acc declare`.
+// - The compiler can generate and use globals needed by the IR
+//   representation, such as type descriptors or names used for runtime
+//   calls and error reporting. These are often introduced after frontend
+//   semantic checking, since they are implementation details, so such
+//   compiler-generated globals are never visible for a user to mark with
+//   `acc declare`.
+// - Constant globals, such as filename strings or data initialization
+//   values, are never mutated but are still needed for correct runtime
+//   execution. If a kernel is launched 1000 times, mapping such a global
+//   1000 times is wasteful, so these globals benefit from being marked
+//   with `acc declare`.
+//
+// This pass automatically marks global variables with the `acc.declare`
+// attribute when they are referenced in OpenACC compute constructs or
+// routine functions and meet the criteria noted above, ensuring they are
+// properly handled for device execution.
+//
+// The pass performs two main optimizations:
+//
+// 1. Hoisting: For non-constant globals referenced in compute regions, the
+//    pass hoists the address-of operation out of the region when possible,
+//    allowing them to be implicitly mapped through normal data clause
+//    mechanisms rather than requiring declare marking.
+//
+// 2. Declaration: For globals that must be available on the device (constants,
+//    globals in routines, globals in recipe operations), the pass adds the
+//    `acc.declare` attribute with the copyin data clause.
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements must be met:
+//
+// 1. Operation Interface Implementation: Operations that compute addresses
+//    of global variables must implement `acc::AddressOfGlobalOpInterface`,
+//    and operations that represent globals must implement
+//    `acc::GlobalVariableOpInterface`. Additionally, any operations that
+//    indirectly access globals must implement the
+//    `acc::IndirectGlobalAccessOpInterface`.
+//
+// 2. Analysis Registration (Optional): If custom behavior is needed for
+//    determining if a symbol use is valid within GPU regions, the dialect
+//    should pre-register the `acc::OpenACCSupport` analysis.
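+//
+//    As an illustration of requirement 1, a dialect can attach the needed
+//    interfaces as external models, in the same way this patch attaches them
+//    to memref::GetGlobalOp and memref::GlobalOp in OpenACC.cpp. The sketch
+//    below uses a hypothetical `mydialect::AddrOfOp` with a hypothetical
+//    `getGlobalNameAttr()` accessor; only the interface and method names are
+//    taken from this patch:
+//
+//      struct MyAddressOfGlobalModel
+//          : acc::AddressOfGlobalOpInterface::ExternalModel<
+//                MyAddressOfGlobalModel, mydialect::AddrOfOp> {
+//        SymbolRefAttr getSymbol(Operation *op) const {
+//          return cast<mydialect::AddrOfOp>(op).getGlobalNameAttr();
+//        }
+//      };
+//
+//      // In MyDialect::initialize():
+//      mydialect::AddrOfOp::attachInterface<MyAddressOfGlobalModel>(
+//          *getContext());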
+// +// Examples: +// --------- +// +// Example 1: Non-constant global in compute region (hoisted) +// +// Before: +// memref.global @g_scalar : memref<f32> = dense<0.0> +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_scalar : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// After: +// memref.global @g_scalar : memref<f32> = dense<0.0> +// func.func @test() { +// %addr = memref.get_global @g_scalar : memref<f32> +// acc.serial { +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// Example 2: Constant global in compute region (declared) +// +// Before: +// memref.global constant @g_const : memref<f32> = dense<1.0> +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_const : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// After: +// memref.global constant @g_const : memref<f32> = dense<1.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_const : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// Example 3: Global in acc routine (declared) +// +// Before: +// memref.global @g_data : memref<f32> = dense<0.0> +// acc.routine @routine_0 func(@device_func) +// func.func @device_func() attributes {acc.routine_info = ...} { +// %addr = memref.get_global @g_data : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// } +// +// After: +// memref.global @g_data : memref<f32> = dense<0.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// acc.routine @routine_0 func(@device_func) +// func.func @device_func() attributes {acc.routine_info = ...} { +// %addr = memref.get_global @g_data : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// } +// +// Example 4: Global in private recipe (declared if recipe is used) +// +// Before: +// memref.global @g_init : memref<f32> = dense<0.0> +// acc.private.recipe @priv_recipe : memref<f32> init { +// ^bb0(%arg0: memref<f32>): +// %alloc = memref.alloc() : memref<f32> +// %global = memref.get_global @g_init : memref<f32> +// %val = memref.load %global[] : memref<f32> +// memref.store %val, %alloc[] : memref<f32> +// acc.yield %alloc : memref<f32> +// } destroy { ... } +// func.func @test() { +// %var = memref.alloc() : memref<f32> +// %priv = acc.private varPtr(%var : memref<f32>) +// recipe(@priv_recipe) -> memref<f32> +// acc.parallel private(%priv : memref<f32>) { ... } +// } +// +// After: +// memref.global @g_init : memref<f32> = dense<0.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// acc.private.recipe @priv_recipe : memref<f32> init { +// ^bb0(%arg0: memref<f32>): +// %alloc = memref.alloc() : memref<f32> +// %global = memref.get_global @g_init : memref<f32> +// %val = memref.load %global[] : memref<f32> +// memref.store %val, %alloc[] : memref<f32> +// acc.yield %alloc : memref<f32> +// } destroy { ... } +// func.func @test() { +// %var = memref.alloc() : memref<f32> +// %priv = acc.private varPtr(%var : memref<f32>) +// recipe(@priv_recipe) -> memref<f32> +// acc.parallel private(%priv : memref<f32>) { ... 
} +// } +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCIMPLICITDECLARE +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-implicit-declare" + +using namespace mlir; + +namespace { + +using GlobalOpSetT = llvm::SmallSetVector<Operation *, 16>; + +/// Checks whether a use of the requested `globalOp` should be considered +/// for hoisting out of acc region due to avoid `acc declare`ing something +/// that instead should be implicitly mapped. +static bool isGlobalUseCandidateForHoisting(Operation *globalOp, + Operation *user, + SymbolRefAttr symbol, + acc::OpenACCSupport &accSupport) { + // This symbol is valid in GPU region. This means semantics + // would change if moved to host - therefore it is not a candidate. + if (accSupport.isValidSymbolUse(user, symbol)) + return false; + + bool isConstant = false; + bool isFunction = false; + + if (auto globalVarOp = dyn_cast<acc::GlobalVariableOpInterface>(globalOp)) + isConstant = globalVarOp.isConstant(); + + if (isa<FunctionOpInterface>(globalOp)) + isFunction = true; + + // Constants should be kept in device code to ensure they are duplicated. + // Function references should be kept in device code to ensure their device + // addresses are computed. Everything else should be hoisted since we already + // proved they are not valid symbols in GPU region. + return !isConstant && !isFunction; +} + +/// Checks whether it is valid to use acc.declare marking on the global. +bool isValidForAccDeclare(Operation *globalOp) { + // For functions - we use acc.routine marking instead. + return !isa<FunctionOpInterface>(globalOp); +} + +/// Checks whether a recipe operation has meaningful use of its symbol that +/// justifies processing its regions for global references. Returns false if: +/// 1. The recipe has no symbol uses at all, or +/// 2. The only symbol use is the recipe's own symbol definition +template <typename RecipeOpT> +static bool hasRelevantRecipeUse(RecipeOpT &recipeOp, ModuleOp &mod) { + std::optional<SymbolTable::UseRange> symbolUses = recipeOp.getSymbolUses(mod); + + // No recipe symbol uses. + if (!symbolUses.has_value() || symbolUses->empty()) + return false; + + // If more than one use, assume it's used. + auto begin = symbolUses->begin(); + auto end = symbolUses->end(); + if (begin != end && std::next(begin) != end) + return true; + + // If single use, check if the use is the recipe itself. + const SymbolTable::SymbolUse &use = *symbolUses->begin(); + return use.getUser() != recipeOp.getOperation(); +} + +// Hoists addr_of operations for non-constant globals out of OpenACC regions. +// This way - they are implicitly mapped instead of being considered for +// implicit declare. 
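+// For example (a sketch; see also Example 1 in the file header), a get_global
+// nested like
+//   acc.serial {
+//     %addr = memref.get_global @g_scalar : memref<f32>
+//     ...
+//   }
+// is rewritten to
+//   %addr = memref.get_global @g_scalar : memref<f32>
+//   acc.serial {
+//     ...
+//   }
+// so that %addr can later be mapped implicitly by the ACCImplicitData pass.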
+template <typename AccConstructT> +static void hoistNonConstantDirectUses(AccConstructT accOp, + acc::OpenACCSupport &accSupport) { + accOp.walk([&](acc::AddressOfGlobalOpInterface addrOfOp) { + SymbolRefAttr symRef = addrOfOp.getSymbol(); + if (symRef) { + Operation *globalOp = + SymbolTable::lookupNearestSymbolFrom(addrOfOp, symRef); + if (isGlobalUseCandidateForHoisting(globalOp, addrOfOp, symRef, + accSupport)) { + addrOfOp->moveBefore(accOp); + LLVM_DEBUG( + llvm::dbgs() << "Hoisted:\n\t" << addrOfOp << "\n\tfrom:\n\t"; + accOp->print(llvm::dbgs(), + OpPrintingFlags{}.skipRegions().enableDebugInfo()); + llvm::dbgs() << "\n"); + } + } + }); +} + +// Collects the globals referenced in a device region +static void collectGlobalsFromDeviceRegion(Region ®ion, + GlobalOpSetT &globals, + acc::OpenACCSupport &accSupport, + SymbolTable &symTab) { + region.walk([&](Operation *op) { + // 1) Only consider relevant operations which use symbols + auto addrOfOp = dyn_cast<acc::AddressOfGlobalOpInterface>(op); + if (addrOfOp) { + SymbolRefAttr symRef = addrOfOp.getSymbol(); + // 2) Found an operation which uses the symbol. Next determine if it + // is a candidate for `acc declare`. Some of the criteria considered + // is whether this symbol is not already a device one (either because + // acc declare is already used or this is a CUF global). + Operation *globalOp = nullptr; + bool isCandidate = !accSupport.isValidSymbolUse(op, symRef, &globalOp); + // 3) Add the candidate to the set of globals to be `acc declare`d. + if (isCandidate && globalOp && isValidForAccDeclare(globalOp)) + globals.insert(globalOp); + } else if (auto indirectAccessOp = + dyn_cast<acc::IndirectGlobalAccessOpInterface>(op)) { + // Process operations that indirectly access globals + llvm::SmallVector<SymbolRefAttr> symbols; + indirectAccessOp.getReferencedSymbols(symbols, &symTab); + for (SymbolRefAttr symRef : symbols) + if (Operation *globalOp = symTab.lookup(symRef.getLeafReference())) + if (isValidForAccDeclare(globalOp)) + globals.insert(globalOp); + } + }); +} + +// Adds the declare attribute to the operation `op`. +static void addDeclareAttr(MLIRContext *context, Operation *op, + acc::DataClause clause) { + op->setAttr(acc::getDeclareAttrName(), + acc::DeclareAttr::get(context, + acc::DataClauseAttr::get(context, clause))); +} + +// This pass applies implicit declare actions for globals referenced in +// OpenACC compute and routine regions. +class ACCImplicitDeclare + : public acc::impl::ACCImplicitDeclareBase<ACCImplicitDeclare> { +public: + using ACCImplicitDeclareBase<ACCImplicitDeclare>::ACCImplicitDeclareBase; + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *context = &getContext(); + acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>(); + + // 1) Start off by hoisting any AddressOf operations out of acc region + // for any cases we do not want to `acc declare`. This is because we can + // rely on implicit data mapping in majority of cases without uselessly + // polluting the device globals. + mod.walk([&](Operation *op) { + TypeSwitch<Operation *, void>(op) + .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>( + [&](auto accOp) { + hoistNonConstantDirectUses(accOp, accSupport); + }); + }); + + // 2) Collect global symbols which need to be `acc declare`d. Do it for + // compute regions, acc routine, and existing globals with the declare + // attribute. 
+ SymbolTable symTab(mod); + GlobalOpSetT globalsToAccDeclare; + mod.walk([&](Operation *op) { + TypeSwitch<Operation *, void>(op) + .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>( + [&](auto accOp) { + collectGlobalsFromDeviceRegion( + accOp.getRegion(), globalsToAccDeclare, accSupport, symTab); + }) + .Case<FunctionOpInterface>([&](auto func) { + if ((acc::isAccRoutine(func) || + acc::isSpecializedAccRoutine(func)) && + !func.isExternal()) + collectGlobalsFromDeviceRegion(func.getFunctionBody(), + globalsToAccDeclare, accSupport, + symTab); + }) + .Case<acc::GlobalVariableOpInterface>([&](auto globalVarOp) { + if (globalVarOp->getAttr(acc::getDeclareAttrName())) + if (Region *initRegion = globalVarOp.getInitRegion()) + collectGlobalsFromDeviceRegion(*initRegion, globalsToAccDeclare, + accSupport, symTab); + }) + .Case<acc::PrivateRecipeOp>([&](auto privateRecipe) { + if (hasRelevantRecipeUse(privateRecipe, mod)) { + collectGlobalsFromDeviceRegion(privateRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion(privateRecipe.getDestroyRegion(), + globalsToAccDeclare, accSupport, + symTab); + } + }) + .Case<acc::FirstprivateRecipeOp>([&](auto firstprivateRecipe) { + if (hasRelevantRecipeUse(firstprivateRecipe, mod)) { + collectGlobalsFromDeviceRegion(firstprivateRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion( + firstprivateRecipe.getDestroyRegion(), globalsToAccDeclare, + accSupport, symTab); + collectGlobalsFromDeviceRegion(firstprivateRecipe.getCopyRegion(), + globalsToAccDeclare, accSupport, + symTab); + } + }) + .Case<acc::ReductionRecipeOp>([&](auto reductionRecipe) { + if (hasRelevantRecipeUse(reductionRecipe, mod)) { + collectGlobalsFromDeviceRegion(reductionRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion( + reductionRecipe.getCombinerRegion(), globalsToAccDeclare, + accSupport, symTab); + } + }); + }); + + // 3) Finally, generate the appropriate declare actions needed to ensure + // this is considered for device global. + for (Operation *globalOp : globalsToAccDeclare) { + LLVM_DEBUG( + llvm::dbgs() << "Global is being `acc declare copyin`d: "; + globalOp->print(llvm::dbgs(), + OpPrintingFlags{}.skipRegions().enableDebugInfo()); + llvm::dbgs() << "\n"); + + // Mark it as declare copyin. + addDeclareAttr(context, globalOp, acc::DataClause::acc_copyin); + + // TODO: May need to create the global constructor which does the mapping + // action. It is not yet clear if this is needed yet (since the globals + // might just end up in the GPU image without requiring mapping via + // runtime). + } + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp new file mode 100644 index 0000000..12efaf4 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp @@ -0,0 +1,237 @@ +//===- ACCImplicitRoutine.cpp - OpenACC Implicit Routine Transform -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements the implicit rules described in OpenACC specification +// for `Routine Directive` (OpenACC 3.4 spec, section 2.15.1). 
+// +// "If no explicit routine directive applies to a procedure whose definition +// appears in the program unit being compiled, then the implementation applies +// an implicit routine directive to that procedure if any of the following +// conditions holds: +// - The procedure is called or its address is accessed in a compute region." +// +// The specification further states: +// "When the implementation applies an implicit routine directive to a +// procedure, it must recursively apply implicit routine directives to other +// procedures for which the above rules specify relevant dependencies. Such +// dependencies can form a cycle, so the implementation must take care to avoid +// infinite recursion." +// +// This pass implements these requirements by: +// 1. Walking through all OpenACC compute constructs and functions already +// marked with `acc routine` in the module and identifying function calls +// within these regions. +// 2. Creating implicit `acc.routine` operations for functions that don't +// already have routine declarations. +// 3. Recursively walking through all existing `acc routine` and creating +// implicit routine operations for function calls within these routines, +// while avoiding infinite recursion through proper tracking. +// +// Requirements: +// ------------- +// To use this pass in a pipeline, the following requirements must be met: +// +// 1. Operation Interface Implementation: Operations that define functions +// or call functions should implement `mlir::FunctionOpInterface` and +// `mlir::CallOpInterface` respectively. +// +// 2. Analysis Registration (Optional): If custom behavior is needed for +// determining if a symbol use is valid within GPU regions, the dialect +// should pre-register the `acc::OpenACCSupport` analysis. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include <queue> + +#define DEBUG_TYPE "acc-implicit-routine" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCIMPLICITROUTINE +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +namespace { + +using namespace mlir; + +class ACCImplicitRoutine + : public acc::impl::ACCImplicitRoutineBase<ACCImplicitRoutine> { +private: + unsigned routineCounter = 0; + static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_"; + + // Count existing routine operations and update counter + void initRoutineCounter(ModuleOp module) { + module.walk([&](acc::RoutineOp routineOp) { routineCounter++; }); + } + + // Check if routine has a default bind clause or a device-type specific bind + // clause. Returns true if `acc routine` has a default bind clause or + // a device-type specific bind clause. + bool isACCRoutineBindDefaultOrDeviceType(acc::RoutineOp op, + acc::DeviceType deviceType) { + // Fast check to avoid device-type specific lookups. 
+ if (!op.getBindIdName() && !op.getBindStrName()) + return false; + return op.getBindNameValue().has_value() || + op.getBindNameValue(deviceType).has_value(); + } + + // Generate a unique name for the routine and create the routine operation + acc::RoutineOp createRoutineOp(OpBuilder &builder, Location loc, + FunctionOpInterface &callee) { + std::string routineName = + (accRoutinePrefix + std::to_string(routineCounter++)).str(); + auto routineOp = acc::RoutineOp::create( + builder, loc, + /* sym_name=*/builder.getStringAttr(routineName), + /* func_name=*/ + mlir::SymbolRefAttr::get(builder.getContext(), + builder.getStringAttr(callee.getName())), + /* bindIdName=*/nullptr, + /* bindStrName=*/nullptr, + /* bindIdNameDeviceType=*/nullptr, + /* bindStrNameDeviceType=*/nullptr, + /* worker=*/nullptr, + /* vector=*/nullptr, + /* seq=*/nullptr, + /* nohost=*/nullptr, + /* implicit=*/builder.getUnitAttr(), + /* gang=*/nullptr, + /* gangDim=*/nullptr, + /* gangDimDeviceType=*/nullptr); + + // Assert that the callee does not already have routine info attribute + assert(!callee->hasAttr(acc::getRoutineInfoAttrName()) && + "function is already associated with a routine"); + + callee->setAttr( + acc::getRoutineInfoAttrName(), + mlir::acc::RoutineInfoAttr::get( + builder.getContext(), + {mlir::SymbolRefAttr::get(builder.getContext(), + builder.getStringAttr(routineName))})); + return routineOp; + } + + // Used to walk through a compute region looking for function calls. + void + implicitRoutineForCallsInComputeRegions(Operation *op, SymbolTable &symTab, + mlir::OpBuilder &builder, + acc::OpenACCSupport &accSupport) { + op->walk([&](CallOpInterface callOp) { + if (!callOp.getCallableForCallee()) + return; + + auto calleeSymbolRef = + dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee()); + // When call is done through ssa value, the callee is not a symbol. + // Skip it because we don't know the call target. + if (!calleeSymbolRef) + return; + + auto callee = symTab.lookup<FunctionOpInterface>( + calleeSymbolRef.getLeafReference().str()); + // If the callee does not exist or is already a valid symbol for GPU + // regions, skip it + + assert(callee && "callee function must be found in symbol table"); + if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef)) + return; + builder.setInsertionPoint(callee); + createRoutineOp(builder, callee.getLoc(), callee); + }); + } + + // Recursively handle calls within a routine operation + void implicitRoutineForCallsInRoutine(acc::RoutineOp routineOp, + mlir::OpBuilder &builder, + acc::OpenACCSupport &accSupport, + acc::DeviceType targetDeviceType) { + // When bind clause is used, it means that the target is different than the + // function to which the `acc routine` is used with. Skip this case to + // avoid implicitly recursively marking calls that would not end up on + // device. + if (isACCRoutineBindDefaultOrDeviceType(routineOp, targetDeviceType)) + return; + + SymbolTable symTab(routineOp->getParentOfType<ModuleOp>()); + std::queue<acc::RoutineOp> routineQueue; + routineQueue.push(routineOp); + while (!routineQueue.empty()) { + auto currentRoutine = routineQueue.front(); + routineQueue.pop(); + auto func = symTab.lookup<FunctionOpInterface>( + currentRoutine.getFuncName().getLeafReference()); + func.walk([&](CallOpInterface callOp) { + if (!callOp.getCallableForCallee()) + return; + + auto calleeSymbolRef = + dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee()); + // When call is done through ssa value, the callee is not a symbol. 
+ // Skip it because we don't know the call target. + if (!calleeSymbolRef) + return; + + auto callee = symTab.lookup<FunctionOpInterface>( + calleeSymbolRef.getLeafReference().str()); + // If the callee does not exist or is already a valid symbol for GPU + // regions, skip it + assert(callee && "callee function must be found in symbol table"); + if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef)) + return; + builder.setInsertionPoint(callee); + auto newRoutineOp = createRoutineOp(builder, callee.getLoc(), callee); + routineQueue.push(newRoutineOp); + }); + } + } + +public: + using ACCImplicitRoutineBase<ACCImplicitRoutine>::ACCImplicitRoutineBase; + + void runOnOperation() override { + auto module = getOperation(); + mlir::OpBuilder builder(module.getContext()); + SymbolTable symTab(module); + initRoutineCounter(module); + + acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>(); + + // Handle compute regions + module.walk([&](Operation *op) { + if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op)) + implicitRoutineForCallsInComputeRegions(op, symTab, builder, + accSupport); + }); + + // Use the device type option from the pass options. + acc::DeviceType targetDeviceType = deviceType; + + // Handle existing routines + module.walk([&](acc::RoutineOp routineOp) { + implicitRoutineForCallsInRoutine(routineOp, builder, accSupport, + targetDeviceType); + }); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp new file mode 100644 index 0000000..f41ce276 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp @@ -0,0 +1,117 @@ +//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass converts acc.serial into acc.parallel with num_gangs(1) +// num_workers(1) vector_length(1). +// +// This transformation simplifies processing of acc regions by unifying the +// handling of serial and parallel constructs. Since an OpenACC serial region +// executes sequentially (like a parallel region with a single gang, worker, and +// vector), this conversion is semantically equivalent while enabling code reuse +// in later compilation stages. 
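+//
+// For illustration, the rewrite is roughly as follows (a sketch; the exact
+// printed form of the clause operands may differ):
+//
+//   Before:
+//     acc.serial {
+//       ...
+//       acc.yield
+//     }
+//
+//   After:
+//     %c1 = arith.constant 1 : i32
+//     acc.parallel num_gangs({%c1 : i32}) num_workers(%c1 : i32)
+//         vector_length(%c1 : i32) {
+//       ...
+//       acc.yield
+//     }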
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCLEGALIZESERIAL +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-legalize-serial" + +namespace { +using namespace mlir; + +struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> { + using OpRewritePattern<acc::SerialOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::SerialOp serialOp, + PatternRewriter &rewriter) const override { + + const Location loc = serialOp.getLoc(); + + // Create a container holding the constant value of 1 for use as the + // num_gangs, num_workers, and vector_length attributes. + llvm::SmallVector<mlir::Value> numValues; + auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32); + numValues.push_back(value); + + // Since num_gangs is specified as both attributes and values, create a + // segment attribute. + llvm::SmallVector<int32_t> numGangsSegments; + numGangsSegments.push_back(numValues.size()); + auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments); + + // Create a device_type attribute set to `none` which ensures that + // the parallel dimensions specification applies to the default clauses. + llvm::SmallVector<mlir::Attribute> crtDeviceTypes; + auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + rewriter.getContext(), mlir::acc::DeviceType::None); + crtDeviceTypes.push_back(crtDeviceTypeAttr); + auto devTypeAttr = + mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes); + + LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n"); + + // Create a new acc.parallel op with the same operands - except include the + // num_gangs, num_workers, and vector_length attributes. 
+ acc::ParallelOp parOp = acc::ParallelOp::create( + rewriter, loc, serialOp.getAsyncOperands(), + serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(), + serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(), + serialOp.getWaitOperandsDeviceTypeAttr(), + serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues, + gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues, + devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(), + serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(), + serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(), + serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(), + serialOp.getCombinedAttr()); + + parOp.getRegion().takeBody(serialOp.getRegion()); + + LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n"); + rewriter.replaceOp(serialOp, parOp); + + return success(); + } +}; + +class ACCLegalizeSerial + : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> { +public: + using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase; + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns(context); + patterns.insert<ACCSerialOpConversion>(context); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt index 7d93495..10a1796 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt @@ -1,4 +1,8 @@ add_mlir_dialect_library(MLIROpenACCTransforms + ACCImplicitData.cpp + ACCImplicitDeclare.cpp + ACCImplicitRoutine.cpp + ACCLegalizeSerial.cpp LegalizeDataValues.cpp ADDITIONAL_HEADER_DIRS @@ -14,7 +18,10 @@ add_mlir_dialect_library(MLIROpenACCTransforms MLIROpenACCTypeInterfacesIncGen LINK_LIBS PUBLIC + MLIRAnalysis + MLIROpenACCAnalysis MLIROpenACCDialect + MLIROpenACCUtils MLIRFuncDialect MLIRIR MLIRPass diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index fbac28e..7f27b44 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -9,8 +9,13 @@ #include "mlir/Dialect/OpenACC/OpenACCUtils.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/Support/Casting.h" mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { @@ -155,3 +160,109 @@ mlir::Value mlir::acc::getBaseEntity(mlir::Value val) { return val; } + +bool mlir::acc::isValidSymbolUse(mlir::Operation *user, + mlir::SymbolRefAttr symbol, + mlir::Operation **definingOpPtr) { + mlir::Operation *definingOp = + mlir::SymbolTable::lookupNearestSymbolFrom(user, symbol); + + // If there are no defining ops, we have no way to ensure validity because + // we cannot check for any attributes. + if (!definingOp) + return false; + + if (definingOpPtr) + *definingOpPtr = definingOp; + + // Check if the defining op is a recipe (private, reduction, firstprivate). + // Recipes are valid as they get materialized before being offloaded to + // device. They are only instructions for how to materialize. 
+ if (mlir::isa<mlir::acc::PrivateRecipeOp, mlir::acc::ReductionRecipeOp, + mlir::acc::FirstprivateRecipeOp>(definingOp)) + return true; + + // Check if the defining op is a function + if (auto func = + mlir::dyn_cast_if_present<mlir::FunctionOpInterface>(definingOp)) { + // If this symbol is actually an acc routine - then it is expected for it + // to be offloaded - therefore it is valid. + if (func->hasAttr(mlir::acc::getRoutineInfoAttrName())) + return true; + + // If this symbol is a call to an LLVM intrinsic, then it is likely valid. + // Check the following: + // 1. The function is private + // 2. The function has no body + // 3. Name starts with "llvm." + // 4. The function's name is a valid LLVM intrinsic name + if (func.getVisibility() == mlir::SymbolTable::Visibility::Private && + func.getFunctionBody().empty() && func.getName().starts_with("llvm.") && + llvm::Intrinsic::lookupIntrinsicID(func.getName()) != + llvm::Intrinsic::not_intrinsic) + return true; + } + + // A declare attribute is needed for symbol references. + bool hasDeclare = definingOp->hasAttr(mlir::acc::getDeclareAttrName()); + return hasDeclare; +} + +llvm::SmallVector<mlir::Value> +mlir::acc::getDominatingDataClauses(mlir::Operation *computeConstructOp, + mlir::DominanceInfo &domInfo, + mlir::PostDominanceInfo &postDomInfo) { + llvm::SmallSetVector<mlir::Value, 8> dominatingDataClauses; + + llvm::TypeSwitch<mlir::Operation *>(computeConstructOp) + .Case<mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp>( + [&](auto op) { + for (auto dataClause : op.getDataClauseOperands()) { + dominatingDataClauses.insert(dataClause); + } + }) + .Default([](mlir::Operation *) {}); + + // Collect the data clauses from enclosing data constructs. + mlir::Operation *currParentOp = computeConstructOp->getParentOp(); + while (currParentOp) { + if (mlir::isa<mlir::acc::DataOp>(currParentOp)) { + for (auto dataClause : mlir::dyn_cast<mlir::acc::DataOp>(currParentOp) + .getDataClauseOperands()) { + dominatingDataClauses.insert(dataClause); + } + } + currParentOp = currParentOp->getParentOp(); + } + + // Find the enclosing function/subroutine + auto funcOp = + computeConstructOp->getParentOfType<mlir::FunctionOpInterface>(); + if (!funcOp) + return dominatingDataClauses.takeVector(); + + // Walk the function to find `acc.declare_enter`/`acc.declare_exit` pairs that + // dominate and post-dominate the compute construct and add their data + // clauses to the list. + funcOp->walk([&](mlir::acc::DeclareEnterOp declareEnterOp) { + if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) { + // Collect all `acc.declare_exit` ops for this token. + llvm::SmallVector<mlir::acc::DeclareExitOp> exits; + for (auto *user : declareEnterOp.getToken().getUsers()) + if (auto declareExit = mlir::dyn_cast<mlir::acc::DeclareExitOp>(user)) + exits.push_back(declareExit); + + // Only add clauses if every `acc.declare_exit` op post-dominates the + // compute construct. 
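+      // For example (sketch, simplified IR):
+      //   %token = acc.declare_enter dataOperands(%a : ...)
+      //   acc.parallel { ... }                  // computeConstructOp
+      //   acc.declare_exit token(%token)
+      // Here the enter dominates and the single exit post-dominates the
+      // compute construct, so %a is added to the dominating data clauses.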
+ if (!exits.empty() && + llvm::all_of(exits, [&](mlir::acc::DeclareExitOp exitOp) { + return postDomInfo.postDominates(exitOp, computeConstructOp); + })) { + for (auto dataClause : declareEnterOp.getDataClauseOperands()) + dominatingDataClauses.insert(dataClause); + } + } + }); + + return dominatingDataClauses.takeVector(); +} diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 1b069c6..103295d 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -617,6 +617,7 @@ parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, break; case ClauseScheduleKind::Auto: case ClauseScheduleKind::Runtime: + case ClauseScheduleKind::Distribute: chunkSize = std::nullopt; } @@ -1817,6 +1818,9 @@ static ParseResult parseMapClause(OpAsmParser &parser, if (mapTypeMod == "ref_ptr_ptee") mapTypeBits |= ClauseMapFlags::ref_ptr_ptee; + if (mapTypeMod == "is_device_ptr") + mapTypeBits |= ClauseMapFlags::is_device_ptr; + return success(); }; @@ -1886,6 +1890,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op, mapTypeStrs.push_back("ref_ptee"); if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee)) mapTypeStrs.push_back("ref_ptr_ptee"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr)) + mapTypeStrs.push_back("is_device_ptr"); if (mapFlags == ClauseMapFlags::none) mapTypeStrs.push_back("none"); @@ -2824,6 +2830,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, ArrayRef<NamedAttribute> attributes) { build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), + /*linear_var_types*/ nullptr, /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr, /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr, /*private_needs_barrier=*/false, @@ -2842,8 +2849,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, WsloopOp::build( builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars, - clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod, - clauses.ordered, clauses.privateVars, + clauses.linearStepVars, clauses.linearVarTypes, clauses.nowait, + clauses.order, clauses.orderMod, clauses.ordered, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), @@ -2888,17 +2895,16 @@ LogicalResult WsloopOp::verifyRegions() { void SimdOp::build(OpBuilder &builder, OperationState &state, const SimdOperands &clauses) { MLIRContext *ctx = builder.getContext(); - // TODO Store clauses in op: linearVars, linearStepVars - SimdOp::build(builder, state, clauses.alignedVars, - makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr, - /*linear_vars=*/{}, /*linear_step_vars=*/{}, - clauses.nontemporalVars, clauses.order, clauses.orderMod, - clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.reductionMod, - clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen, - clauses.simdlen); + SimdOp::build( + builder, state, clauses.alignedVars, + makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr, + clauses.linearVars, clauses.linearStepVars, clauses.linearVarTypes, + clauses.nontemporalVars, clauses.order, clauses.orderMod, + clauses.privateVars, makeArrayAttr(ctx, 
clauses.privateSyms), + clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen, + clauses.simdlen); } LogicalResult SimdOp::verify() { diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt index 423e1c3..b111117 100644 --- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt @@ -19,5 +19,5 @@ add_mlir_dialect_library(MLIRSCFDialect MLIRSideEffectInterfaces MLIRTensorDialect MLIRValueBoundsOpInterface + MLIRTransformUtils ) - diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 2946b53..c4bd31f 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -26,6 +26,7 @@ #include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -2565,6 +2566,39 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> { struct ConditionPropagation : public OpRewritePattern<IfOp> { using OpRewritePattern<IfOp>::OpRewritePattern; + /// Kind of parent region in the ancestor cache. + enum class Parent { Then, Else, None }; + + /// Returns the kind of region ("then", "else", or "none") of the + /// IfOp that the given region is transitively nested in. Updates + /// the cache accordingly. + static Parent getParentType(Region *toCheck, IfOp op, + DenseMap<Region *, Parent> &cache, + Region *endRegion) { + SmallVector<Region *> seen; + while (toCheck != endRegion) { + auto found = cache.find(toCheck); + if (found != cache.end()) + return found->second; + seen.push_back(toCheck); + if (&op.getThenRegion() == toCheck) { + for (Region *region : seen) + cache[region] = Parent::Then; + return Parent::Then; + } + if (&op.getElseRegion() == toCheck) { + for (Region *region : seen) + cache[region] = Parent::Else; + return Parent::Else; + } + toCheck = toCheck->getParentRegion(); + } + + for (Region *region : seen) + cache[region] = Parent::None; + return Parent::None; + } + LogicalResult matchAndRewrite(IfOp op, PatternRewriter &rewriter) const override { // Early exit if the condition is constant since replacing a constant @@ -2580,9 +2614,12 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> { Value constantTrue = nullptr; Value constantFalse = nullptr; + DenseMap<Region *, Parent> cache; for (OpOperand &use : llvm::make_early_inc_range(op.getCondition().getUses())) { - if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) { + switch (getParentType(use.getOwner()->getParentRegion(), op, cache, + op.getCondition().getParentRegion())) { + case Parent::Then: { changed = true; if (!constantTrue) @@ -2591,8 +2628,9 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> { rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantTrue); }); - } else if (op.getElseRegion().isAncestor( - use.getOwner()->getParentRegion())) { + break; + } + case Parent::Else: { changed = true; if (!constantFalse) @@ -2601,6 +2639,10 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> { rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantFalse); }); + break; + } + case Parent::None: + break; } } @@ -3646,6 +3688,133 @@ LogicalResult scf::WhileOp::verify() { } namespace { +/// Move a 
scf.if op that is directly before the scf.condition op in the while +/// before region, and whose condition matches the condition of the +/// scf.condition op, down into the while after region. +/// +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// %res = scf.if %cond -> (...) { +/// use(%additional_used_values) +/// ... // then block +/// scf.yield %then_value +/// } else { +/// scf.yield %else_value +/// } +/// scf.condition(%cond) %res, ... +/// } do { +/// ^bb0(%res_arg, ...): +/// use(%res_arg) +/// ... +/// +/// becomes +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// scf.condition(%cond) %else_value, ..., %additional_used_values +/// } do { +/// ^bb0(%res_arg ..., %additional_args): : +/// use(%additional_args) +/// ... // if then block +/// use(%then_value) +/// ... +struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> { + using OpRewritePattern<scf::WhileOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + auto conditionOp = op.getConditionOp(); + + // Only support ifOp right before the condition at the moment. Relaxing this + // would require to: + // - check that the body does not have side-effects conflicting with + // operations between the if and the condition. + // - check that results of the if operation are only used as arguments to + // the condition. + auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode()); + + // Check that the ifOp is directly before the conditionOp and that it + // matches the condition of the conditionOp. Also ensure that the ifOp has + // no else block with content, as that would complicate the transformation. + // TODO: support else blocks with content. + if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() || + (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty())) + return failure(); + + assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) && + *ifOp->user_begin() == conditionOp)) && + "ifOp has unexpected uses"); + + Location loc = op.getLoc(); + + // Replace uses of ifOp results in the conditionOp with the yielded values + // from the ifOp branches. + for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) { + auto it = llvm::find(ifOp->getResults(), arg); + if (it != ifOp->getResults().end()) { + size_t ifOpIdx = it.getIndex(); + Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx); + Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx); + + rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue); + rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue); + } + } + + // Collect additional used values from before region. + SetVector<Value> additionalUsedValuesSet; + visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) { + if (&op.getBefore() == operand->get().getParentRegion()) + additionalUsedValuesSet.insert(operand->get()); + }); + + // Create new whileOp with additional used values as results. 
+ auto additionalUsedValues = additionalUsedValuesSet.getArrayRef(); + auto additionalValueTypes = llvm::map_to_vector( + additionalUsedValues, [](Value val) { return val.getType(); }); + size_t additionalValueSize = additionalUsedValues.size(); + SmallVector<Type> newResultTypes(op.getResultTypes()); + newResultTypes.append(additionalValueTypes); + + auto newWhileOp = + scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits()); + + rewriter.modifyOpInPlace(newWhileOp, [&] { + newWhileOp.getBefore().takeBody(op.getBefore()); + newWhileOp.getAfter().takeBody(op.getAfter()); + newWhileOp.getAfter().addArguments( + additionalValueTypes, + SmallVector<Location>(additionalValueSize, loc)); + }); + + rewriter.modifyOpInPlace(conditionOp, [&] { + conditionOp.getArgsMutable().append(additionalUsedValues); + }); + + // Replace uses of additional used values inside the ifOp then region with + // the whileOp after region arguments. + rewriter.replaceUsesWithIf( + additionalUsedValues, + newWhileOp.getAfterArguments().take_back(additionalValueSize), + [&](OpOperand &use) { + return ifOp.getThenRegion().isAncestor( + use.getOwner()->getParentRegion()); + }); + + // Inline ifOp then region into new whileOp after region. + rewriter.eraseOp(ifOp.thenYield()); + rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(), + newWhileOp.getAfterBody()->begin()); + rewriter.eraseOp(ifOp); + rewriter.replaceOp(op, + newWhileOp->getResults().drop_back(additionalValueSize)); + return success(); + } +}; + /// Replace uses of the condition within the do block with true, since otherwise /// the block would not be evaluated. /// @@ -4302,7 +4471,7 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { LogicalResult matchAndRewrite(WhileOp loop, PatternRewriter &rewriter) const override { - auto oldBefore = loop.getBeforeBody(); + auto *oldBefore = loop.getBeforeBody(); ConditionOp oldTerm = loop.getConditionOp(); ValueRange beforeArgs = oldBefore->getArguments(); ValueRange termArgs = oldTerm.getArgs(); @@ -4323,7 +4492,7 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { beforeArgs); } - auto oldAfter = loop.getAfterBody(); + auto *oldAfter = loop.getAfterBody(); SmallVector<Type> newResultTypes(beforeArgs.size()); for (auto &&[i, j] : llvm::enumerate(*mapping)) @@ -4332,8 +4501,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { auto newLoop = WhileOp::create( rewriter, loop.getLoc(), newResultTypes, loop.getInits(), /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr); - auto newBefore = newLoop.getBeforeBody(); - auto newAfter = newLoop.getAfterBody(); + auto *newBefore = newLoop.getBeforeBody(); + auto *newAfter = newLoop.getAfterBody(); SmallVector<Value> newResults(beforeArgs.size()); SmallVector<Value> newAfterArgs(beforeArgs.size()); @@ -4358,7 +4527,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add<RemoveLoopInvariantArgsFromBeforeBlock, RemoveLoopInvariantValueYielded, WhileConditionTruth, WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults, - WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context); + WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp index 8f7d5e3..c469a99 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp 
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp @@ -44,7 +44,6 @@ mlir::scf::parallelForToNestedFors(RewriterBase &rewriter, lowerBounds.size() == steps.size() && "Mismatched parallel loop bounds"); - SmallVector<Value> ivs; scf::LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 29b770f..009c2c3 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest( for (auto [outerLoop, innerLoop] : llvm::zip_equal(loops.drop_back(), loops.drop_front())) { // Again assume that all the outer loops are scf.for operations. - auto outerForLoop = cast<scf::ForOp>(outerLoop); + auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation()); auto outerLoopYield = cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); SmallVector<Value> newYields = @@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter, return clonedSlices; } -/// Implementation of fusing consumer of a single slice by computing the -/// slice of the consumer in-place for scf loop. -FailureOr<scf::SCFFuseConsumerOfSliceResult> -mlir::scf::tileAndFuseConsumerOfSlices( - RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, - MutableArrayRef<LoopLikeOpInterface> loops) { - if (candidateSlices.empty()) { - return rewriter.notifyMatchFailure( - rewriter.getUnknownLoc(), - "no candidate slices provided for consumer fusion"); - } - // Return if `loops` is empty, return an error for now. Caller is expected - // to handle this case. - if (loops.empty()) { - return rewriter.notifyMatchFailure( - candidateSlices.front(), - "cannot call tile and fuse consumer with an empty loop nest"); - } +static FailureOr<scf::SCFFuseConsumerOfSliceResult> +tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp, + ArrayRef<OpOperand *> consumerOpOperands, + ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "expected loops to be not empty"); - if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || - llvm::all_of(candidateSlices, - llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + // 1. Check assumption for loop with `reorderOperations` disabled. + if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) { return rewriter.notifyMatchFailure( - candidateSlices.front(), - "candidates slices need to be all `tensor.extract_slice`s or " - "`tensor.parallel_insert_slice`s"); - } - - // 1. Get the consumer of scf.for for the result yielded by - // tensor.insert_slice/parallel_insert_slice. 
- SmallVector<OpOperand *> consumerOpOperands; - Operation *consumerOp; - { - FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = - getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); - if (failed(maybeConsumerOpOperand)) { - return rewriter.notifyMatchFailure(candidateSlices.front(), - "could not fetch consumer to fuse"); - } - std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); - consumerOp = consumerOpOperands.front()->getOwner(); + loops.front(), "the first user of loop should not dominate any define " + "of consumer operand(s)"); } LoopLikeOpInterface outerMostLoop = loops.front(); LoopLikeOpInterface innerMostLoop = loops.back(); - // Check assumption for loop with `reorderOperations` disabled. - if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { - return rewriter.notifyMatchFailure( - outerMostLoop, "the first user of loop should not dominate any define " - "of consumer operand(s)"); - } - OpBuilder::InsertionGuard g(rewriter); - // 2. Check consumer is not using scf loop's output as init. auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp); if (!dstOp) @@ -2428,11 +2391,166 @@ mlir::scf::tileAndFuseConsumerOfSlices( llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) { return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum); }); + auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands); return scf::SCFFuseConsumerOfSliceResult{ - std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands), + std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands), std::move(tileAndFuseResult->tiledOps)}; } +/// Implementation of fusing consumer of a single slice by computing the +/// slice of the consumer in-place for scf loop. +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumerOfSlices( + RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + if (candidateSlices.empty()) { + return rewriter.notifyMatchFailure( + rewriter.getUnknownLoc(), + "no candidate slices provided for consumer fusion"); + } + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "cannot call tile and fuse consumer with an empty loop nest"); + } + + if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || + llvm::all_of(candidateSlices, + llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "candidates slices need to be all `tensor.extract_slice`s or " + "`tensor.parallel_insert_slice`s"); + } + + // Get the consumer of scf.for for the result yielded by + // tensor.insert_slice/parallel_insert_slice. + FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands = + getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); + if (failed(maybeConsumerOpOperands)) { + return rewriter.notifyMatchFailure(candidateSlices.front(), + "could not fetch consumer to fuse"); + } + Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner(); + + return tileAndFuseConsumerOfSlicesImpl(rewriter, consumerOp, + maybeConsumerOpOperands.value(), + candidateSlices, loops); +} + +/// For a given `result` of a `forallOp` return the +/// `tensor.parallel_insert_slice` op (or combining op) that is used to +/// construct this result. 
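+///
+/// For example (a sketch):
+///   %res = scf.forall ... shared_outs(%out = %init) -> (tensor<...>) {
+///     ...
+///     scf.forall.in_parallel {
+///       tensor.parallel_insert_slice %tile into %out[...] [...] [...]
+///           : tensor<...> into tensor<...>
+///     }
+///   }
+/// For %res, the `tensor.parallel_insert_slice` shown above is the op that
+/// would be returned.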
+static std::optional<Operation *> +getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) { + if (result.getOwner() != forallOp) + return std::nullopt; + BlockArgument bbArg = forallOp.getTiedBlockArgument(result); + SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg); + // If the number of combining ops is not 1, then this is unexpected. Return + // nullopt. + if (combiningOps.size() != 1) + return std::nullopt; + return combiningOps[0]; +} + +/// For a given result of the loop nest that is a tiled loop nest, return the +/// insert slice-like op that is used for consumer fusion +static std::optional<Operation *> +getProducingInsertSliceLikeOp(OpResult result, + ArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "Expected loops to be not empty"); + LoopLikeOpInterface outerMostLoop = loops.front(); + if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) { + assert(loops.size() == 1 && + "expected only a single loop when tiling using scf.forall"); + return getProducingParallelInsertSlice(forallOp, result); + } + // Assume that the loop nest is a nested `scf.for` that is created through + // tiling and retrieve the `tensor.insert_slice` operation used to construct + // the result. + while (loops.size() != 1) { + LoopLikeOpInterface loop = loops.front(); + if (result.getOwner() != loop) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loop.getOperation()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + auto innerForResult = + dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber())); + if (!innerForResult) + return std::nullopt; + result = innerForResult; + loops = loops.drop_front(); + } + LoopLikeOpInterface loop = loops.front(); + if (result.getOwner() != loop) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loop.getOperation()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + auto insertSliceOp = yieldOp.getOperand(result.getResultNumber()) + .getDefiningOp<tensor::InsertSliceOp>(); + if (!insertSliceOp) + return std::nullopt; + return insertSliceOp; +} + +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer, + MutableArrayRef<LoopLikeOpInterface> loops) { + if (!isa<TilingInterface>(consumer)) { + return rewriter.notifyMatchFailure( + consumer, "unhandled consumer that does not implement TilingInterface"); + } + + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + consumer, "cannot call tile and fuse consumer with an empty loop nest"); + } + + LoopLikeOpInterface outermostLoop = loops.front(); + + // Collect the operands of the consumer that come from the outermost loop of + // the loop nest. + SmallVector<OpOperand *> consumerFusableOperands; + for (OpOperand &opOperand : consumer->getOpOperands()) { + if (opOperand.get().getDefiningOp() == outermostLoop) { + consumerFusableOperands.push_back(&opOperand); + } + } + + // Nothing to fuse. Just return an empty set. + if (consumerFusableOperands.empty()) { + return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands, + SmallVector<OpOperand *>{}, + SmallVector<Operation *>{}}; + } + + // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices + // for fusion. 
+ SmallVector<Operation *> candidateSlices; + candidateSlices.reserve(consumerFusableOperands.size()); + for (OpOperand *opOperand : consumerFusableOperands) { + std::optional<Operation *> slice = + getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops); + if (!slice) { + return rewriter.notifyMatchFailure( + consumer, + "couldnt find producing insert-slice like operation for operand"); + } + candidateSlices.push_back(slice.value()); + } + return tileAndFuseConsumerOfSlicesImpl( + rewriter, consumer, consumerFusableOperands, candidateSlices, loops); +} + //===----------------------------------------------------------------------===// // lowerToLoopsUsingSCFForOp implementation. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index f0b46e6..a846d7e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -220,6 +220,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() { } //===----------------------------------------------------------------------===// +// spirv.Switch +//===----------------------------------------------------------------------===// + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + DenseIntElementsAttr literals, BlockRange targets, + ArrayRef<ValueRange> targetOperands) { + build(builder, result, selector, defaultOperands, targetOperands, literals, + defaultTarget, targets); +} + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + ArrayRef<APInt> literals, BlockRange targets, + ArrayRef<ValueRange> targetOperands) { + DenseIntElementsAttr literalsAttr; + if (!literals.empty()) { + ShapedType literalType = VectorType::get( + static_cast<int64_t>(literals.size()), selector.getType()); + literalsAttr = DenseIntElementsAttr::get(literalType, literals); + } + build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr, + targets, targetOperands); +} + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + ArrayRef<int32_t> literals, BlockRange targets, + ArrayRef<ValueRange> targetOperands) { + DenseIntElementsAttr literalsAttr; + if (!literals.empty()) { + ShapedType literalType = VectorType::get( + static_cast<int64_t>(literals.size()), selector.getType()); + literalsAttr = DenseIntElementsAttr::get(literalType, literals); + } + build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr, + targets, targetOperands); +} + +LogicalResult SwitchOp::verify() { + std::optional<DenseIntElementsAttr> literals = getLiterals(); + BlockRange targets = getTargets(); + + if (!literals && targets.empty()) + return success(); + + Type selectorType = getSelector().getType(); + Type literalType = literals->getType().getElementType(); + if (literalType != selectorType) + return emitOpError() << "'selector' type (" << selectorType + << ") should match literals type (" << literalType + << ")"; + + if (literals && literals->size() != static_cast<int64_t>(targets.size())) + return emitOpError() << "number of literals (" << literals->size() + << ") should match number of targets (" + << targets.size() << ")"; + return success(); +} + +SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) 
{ + assert(index < getNumSuccessors() && "invalid successor index"); + return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() + : getTargetOperandsMutable(index - 1)); +} + +Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { + std::optional<DenseIntElementsAttr> literals = getLiterals(); + + if (!literals) + return getDefaultTarget(); + + SuccessorRange targets = getTargets(); + if (auto value = dyn_cast_or_null<IntegerAttr>(operands.front())) { + for (auto [index, literal] : llvm::enumerate(literals->getValues<APInt>())) + if (literal == value.getValue()) + return targets[index]; + return getDefaultTarget(); + } + return nullptr; +} + +//===----------------------------------------------------------------------===// // spirv.mlir.loop //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp index 2f3a28f..8575487 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp @@ -81,6 +81,83 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, } } +/// Adapted from the cf.switch implementation. +/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? +/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* +static ParseResult parseSwitchOpCases( + OpAsmParser &parser, Type &selectorType, Block *&defaultTarget, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, + SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals, + SmallVectorImpl<Block *> &targets, + SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> + &targetOperands, + SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) { + if (parser.parseKeyword("default") || parser.parseColon() || + parser.parseSuccessor(defaultTarget)) + return failure(); + if (succeeded(parser.parseOptionalLParen())) { + if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, + /*allowResultNumber=*/false) || + parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) + return failure(); + } + + SmallVector<APInt> values; + unsigned bitWidth = selectorType.getIntOrFloatBitWidth(); + while (succeeded(parser.parseOptionalComma())) { + int64_t value = 0; + if (failed(parser.parseInteger(value))) + return failure(); + values.push_back(APInt(bitWidth, value, /*isSigned=*/true)); + + Block *target; + SmallVector<OpAsmParser::UnresolvedOperand> operands; + SmallVector<Type> operandTypes; + if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target))) + return failure(); + if (succeeded(parser.parseOptionalLParen())) { + if (failed(parser.parseOperandList(operands, + OpAsmParser::Delimiter::None)) || + failed(parser.parseColonTypeList(operandTypes)) || + failed(parser.parseRParen())) + return failure(); + } + targets.push_back(target); + targetOperands.emplace_back(operands); + targetOperandTypes.emplace_back(operandTypes); + } + + if (!values.empty()) { + ShapedType literalType = + VectorType::get(static_cast<int64_t>(values.size()), selectorType); + literals = DenseIntElementsAttr::get(literalType, values); + } + return success(); +} + +static void +printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType, + Block *defaultTarget, OperandRange defaultOperands, + TypeRange defaultOperandTypes, DenseIntElementsAttr literals, + SuccessorRange targets, OperandRangeRange targetOperands, + const TypeRangeRange 
&targetOperandTypes) { + p << " default: "; + p.printSuccessorAndUseList(defaultTarget, defaultOperands); + + if (!literals) + return; + + for (auto [index, literal] : llvm::enumerate(literals.getValues<APInt>())) { + p << ','; + p.printNewline(); + p << " "; + p << literal.getLimitedValue(); + p << ": "; + p.printSuccessorAndUseList(targets[index], targetOperands[index]); + } + p.printNewline(); +} + } // namespace mlir::spirv // TablenGen'erated operation definitions. diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index cb9b7f6..f07307f 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -502,6 +502,11 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, << type << " illegal: cannot handle zero-element tensors\n"); return nullptr; } + if (arrayElemCount > std::numeric_limits<unsigned>::max()) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot fit tensor into target type\n"); + return nullptr; + } Type arrayElemType = convertScalarType(targetEnv, options, scalarType); if (!arrayElemType) diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index 645cbff..5941f7d 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -476,38 +476,37 @@ void GridShapeOp::getAsmResultNames( //===----------------------------------------------------------------------===// void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - FlatSymbolRefAttr grid, - ArrayRef<GridAxesAttr> split_axes, - ArrayRef<int64_t> static_halos, - ArrayRef<int64_t> static_offsets) { + FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes, + ArrayRef<int64_t> staticHalos, + ArrayRef<int64_t> staticOffsets) { return build( - b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes), - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); + b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes), + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {}); } void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes, - ArrayRef<int64_t> static_halos, - ArrayRef<int64_t> static_offsets) { + llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes, + ArrayRef<int64_t> staticHalos, + ArrayRef<int64_t> staticOffsets) { return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid), - GridAxesArrayAttr::get(b.getContext(), split_axes), - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), + GridAxesArrayAttr::get(b.getContext(), splitAxes), + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {}); } void ShardingOp::build( ::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes, - ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes, - ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) { + FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes, + ::mlir::ArrayRef<::mlir::OpFoldResult> haloSizes, + ::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets) { 
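A minimal sketch of the OpFoldResult split performed in the builder body that follows, assuming an OpBuilder `b` and an index-typed Value `dyn` defined elsewhere (both names are illustrative, not from this patch):

  SmallVector<OpFoldResult> haloSizes = {b.getIndexAttr(1), dyn};
  SmallVector<int64_t> staticHalos;
  SmallVector<Value> dynamicHalos;
  dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos);
  // staticHalos is now {1, ShapedType::kDynamic} and dynamicHalos holds {dyn},
  // i.e. exactly the DenseI64ArrayAttr/operand pair the op encodes.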
mlir::SmallVector<int64_t> staticHalos, staticDims; mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims; - dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos); - dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims); + dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos); + dispatchIndexOpFoldResults(shardedDimsOffsets, dynamicDims, staticDims); return build( - b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes), + b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes), ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos, ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims); } @@ -576,7 +575,7 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return failure(); } if (mlir::ShapedType::isDynamicShape(grid->getShape()) && - getStaticShardedDimsOffsets().size() > 0) { + !getStaticShardedDimsOffsets().empty()) { return emitError() << "sharded dims offsets are not allowed for " "device grids with dynamic shape."; } @@ -650,14 +649,14 @@ public: if (dynamicOffs.empty() && !staticOffs.empty()) { assert(staticOffs.size() >= 2); auto diff = staticOffs[1] - staticOffs[0]; - bool all_same = staticOffs.size() > 2; + bool allSame = staticOffs.size() > 2; for (auto i = 2u; i < staticOffs.size(); ++i) { if (staticOffs[i] - staticOffs[i - 1] != diff) { - all_same = false; + allSame = false; break; } } - if (all_same) { + if (allSame) { staticOffs.clear(); modified = true; } @@ -749,7 +748,7 @@ bool Sharding::operator==(const Sharding &rhs) const { bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); } -Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {} +Sharding::Sharding(::mlir::FlatSymbolRefAttr grid) : grid(grid) {} Sharding::Sharding(Value rhs) { auto shardingOp = rhs.getDefiningOp<ShardingOp>(); @@ -767,21 +766,20 @@ Sharding::Sharding(Value rhs) { SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets())); } -Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_, - ArrayRef<GridAxesAttr> split_axes_, - ArrayRef<int64_t> static_halo_sizes_, - ArrayRef<int64_t> static_sharded_dims_offsets_, - ArrayRef<Value> dynamic_halo_sizes_, - ArrayRef<Value> dynamic_sharded_dims_offsets_) { - Sharding res(grid_); - if (split_axes_.empty()) { +Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid, + ArrayRef<GridAxesAttr> splitAxes, + ArrayRef<int64_t> staticHaloSizes, + ArrayRef<int64_t> staticShardedDimsOffsets, + ArrayRef<Value> dynamicHaloSizes, + ArrayRef<Value> dynamicShardedDimsOffsets) { + Sharding res(grid); + if (splitAxes.empty()) { return res; } - res.split_axes.resize(split_axes_.size()); - for (auto [i, axis] : llvm::enumerate(split_axes_)) { - res.split_axes[i] = - GridAxesAttr::get(grid_.getContext(), axis.asArrayRef()); + res.split_axes.resize(splitAxes.size()); + for (auto [i, axis] : llvm::enumerate(splitAxes)) { + res.split_axes[i] = GridAxesAttr::get(grid.getContext(), axis.asArrayRef()); } auto clone = [](const auto src, auto &dst) { @@ -789,10 +787,10 @@ Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_, llvm::copy(src, dst.begin()); }; - clone(static_halo_sizes_, res.static_halo_sizes); - clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets); - clone(dynamic_halo_sizes_, res.dynamic_halo_sizes); - clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets); + clone(staticHaloSizes, res.static_halo_sizes); + clone(staticShardedDimsOffsets, 
res.static_sharded_dims_offsets); + clone(dynamicHaloSizes, res.dynamic_halo_sizes); + clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets); return res; } @@ -809,10 +807,10 @@ void ShardShapeOp::getAsmResultNames( void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<int64_t> dims, - ArrayRef<Value> dims_dyn, ::mlir::Value sharding, + ArrayRef<Value> dimsDyn, ::mlir::Value sharding, ::mlir::ValueRange device) { SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType()); - build(odsBuilder, odsState, resType, dims, dims_dyn, sharding, + build(odsBuilder, odsState, resType, dims, dimsDyn, sharding, SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device); } diff --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp index 3bfbf373..f954131 100644 --- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp @@ -184,7 +184,7 @@ ReshardingRquirementKind getReshardingRquirementKind( for (auto [result, sharding] : llvm::zip_equal(op->getResults(), resultShardings)) { - for (auto user : result.getUsers()) { + for (auto *user : result.getUsers()) { ShardOp shardOp = llvm::dyn_cast<ShardOp>(user); if (!shardOp) { continue; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index ae7eef2..9db9814 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -1365,8 +1365,8 @@ public: arith::SubIOp::create(rewriter, loc, capacity, newSize); Value fillValue = constantZero(rewriter, loc, value.getType()); Value subBuffer = memref::SubViewOp::create( - rewriter, loc, newBuffer, /*offset=*/ValueRange{newSize}, - /*size=*/ValueRange{fillSize}, + rewriter, loc, newBuffer, /*offsets=*/ValueRange{newSize}, + /*sizes=*/ValueRange{fillSize}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); linalg::FillOp::create(rewriter, loc, fillValue, subBuffer); } @@ -1386,8 +1386,8 @@ public: memref::StoreOp::create(rewriter, loc, value, buffer, size); } else { Value subBuffer = memref::SubViewOp::create( - rewriter, loc, buffer, /*offset=*/ValueRange{size}, - /*size=*/ValueRange{n}, + rewriter, loc, buffer, /*offsets=*/ValueRange{size}, + /*sizes=*/ValueRange{n}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); linalg::FillOp::create(rewriter, loc, value, subBuffer); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 7a26cd3..1fbcf5f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1050,7 +1050,7 @@ public: /// Sparse codegen rule for position accesses. 
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> { public: - using OpAdaptor = typename ToPositionsOp::Adaptor; + using OpAdaptor = ToPositionsOp::Adaptor; using OpConversionPattern<ToPositionsOp>::OpConversionPattern; LogicalResult matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor, @@ -1073,7 +1073,7 @@ public: class SparseToCoordinatesConverter : public OpConversionPattern<ToCoordinatesOp> { public: - using OpAdaptor = typename ToCoordinatesOp::Adaptor; + using OpAdaptor = ToCoordinatesOp::Adaptor; using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern; LogicalResult matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor, @@ -1099,7 +1099,7 @@ public: class SparseToCoordinatesBufferConverter : public OpConversionPattern<ToCoordinatesBufferOp> { public: - using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor; + using OpAdaptor = ToCoordinatesBufferOp::Adaptor; using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern; LogicalResult matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor, @@ -1121,7 +1121,7 @@ public: /// Sparse codegen rule for value accesses. class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> { public: - using OpAdaptor = typename ToValuesOp::Adaptor; + using OpAdaptor = ToValuesOp::Adaptor; using OpConversionPattern<ToValuesOp>::OpConversionPattern; LogicalResult matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index febec6d..23436a6 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -132,8 +132,8 @@ static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem, SmallVector<Value> scalarArgs(idxs); Value indexVec = idxs.back(); scalarArgs.back() = constantIndex(rewriter, loc, 0); - vector::ScatterOp::create(rewriter, loc, mem, scalarArgs, indexVec, vmask, - rhs); + vector::ScatterOp::create(rewriter, loc, /*resultType=*/nullptr, mem, + scalarArgs, indexVec, vmask, rhs); return; } vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 869d27a..7e8d360 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -22,7 +22,6 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp index ffa8b40..9904803 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp @@ -80,6 +80,53 @@ inline static bool includesDenseOutput(SortMask mask) { return includesAny(mask, SortMask::kIncludeDenseOutput); } +/// Returns a sparsity rank for loop ordering: lower values indicate +/// dimensions that should be 
placed in outer loops. +/// 0 = Dense, 1 = Compressed, 2 = Singleton, 3 = Other/Unknown. +static unsigned getLoopSparsityRank(unsigned loop, ArrayRef<Value> allTensors, + ArrayRef<AffineMap> allMaps) { + // Start with highest rank. + unsigned minRank = 3; + + for (auto [tensor, map] : llvm::zip(allTensors, allMaps)) { + // Check if this loop accesses this tensor. + bool loopAccessesTensor = false; + unsigned tensorDim = 0; + for (AffineExpr expr : map.getResults()) { + if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { + if (dimExpr.getPosition() == loop) { + loopAccessesTensor = true; + break; + } + } + tensorDim++; + } + + if (loopAccessesTensor) { + const auto enc = getSparseTensorEncoding(tensor.getType()); + if (!enc) { + // Dense tensor - lowest rank. + return 0; + } else { + // Sparse tensor - check the level type for this dimension. + auto lvlTypes = enc.getLvlTypes(); + if (tensorDim < lvlTypes.size()) { + auto lvlType = lvlTypes[tensorDim]; + if (isDenseLT(lvlType)) { + return 0; // Dense level. + } else if (isCompressedLT(lvlType)) { + minRank = std::min(minRank, 1u); // Compressed level. + } else if (isSingletonLT(lvlType)) { + minRank = std::min(minRank, 2u); // Singleton level. + } + } + } + } + } + + return minRank; +} + AffineMap IterationGraphSorter::topoSort() { // The sorted result will put the first Reduction iterator to the // latest possible position. @@ -107,10 +154,33 @@ AffineMap IterationGraphSorter::topoSort() { case sparse_tensor::LoopOrderingStrategy::kDefault: src = it.back(); break; + case sparse_tensor::LoopOrderingStrategy::kDenseOuter: { + // Prefer dense, then compressed, then singleton dimensions outermost. + // Create combined tensor and map lists for analysis. + SmallVector<Value> allTensors = ins; + allTensors.push_back(out); + SmallVector<AffineMap> allMaps = loop2InsLvl; + allMaps.push_back(loop2OutLvl); + + // Find loop with minimum (lowest) sparsity rank. + unsigned minLoop = it[0]; + unsigned minRank = getLoopSparsityRank(minLoop, allTensors, allMaps); + + for (auto candidateLoop : it) { + unsigned rank = getLoopSparsityRank(candidateLoop, allTensors, allMaps); + if (rank < minRank || (rank == minRank && candidateLoop < minLoop)) { + minLoop = candidateLoop; + minRank = rank; + } + } + src = minLoop; + break; + } } loopOrder.push_back(src); - it.pop_back(); + // Remove the selected loop from the worklist. + it.erase(std::find(it.begin(), it.end(), src)); // Update in-degree, and push 0-degree node into worklist. for (unsigned dst = 0; dst < numLoops; dst++) { if (itGraph[src][dst] && --inDegree[dst] == 0) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h index 3636f3f..46378b9 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h @@ -197,7 +197,7 @@ public: // Sets the iterate to the specified position. void seek(ValueRange vals) { assert(vals.size() == cursorValsCnt); - std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin()); + llvm::copy(vals, cursorValsStorageRef.begin()); // Now that the iterator is re-positioned, the coordinate becomes invalid. 
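Stepping back to the kDenseOuter strategy in the IterationGraphSorter hunk above: a small self-contained sketch of the min-rank selection it performs (loop indices and ranks below are invented for illustration):

  #include <vector>

  // 0 = dense, 1 = compressed, 2 = singleton, 3 = unknown; ties prefer the
  // smaller loop index, mirroring the candidate scan in topoSort().
  static unsigned pickDenseOuterLoop(const std::vector<unsigned> &worklist,
                                     const std::vector<unsigned> &rankOf) {
    unsigned best = worklist.front();
    for (unsigned loop : worklist)
      if (rankOf[loop] < rankOf[best] ||
          (rankOf[loop] == rankOf[best] && loop < best))
        best = loop;
    return best;
  }
  // pickDenseOuterLoop({2, 0, 1}, {1, 0, 2}) == 1: the dense loop is chosen to
  // be scheduled outermost.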
crd = nullptr; } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp index 4ec13e1..686f6ee 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -77,6 +77,9 @@ namespace { struct ReifyExpandShapeOp : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp, ExpandShapeOp> { + using Base = + ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp, + ExpandShapeOp>; LogicalResult reifyResultShapes(Operation *op, OpBuilder &b, ReifiedRankedShapedTypeDims &reifyResultShapes) const { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 110bfdc..204e9bb 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -551,9 +551,7 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results, RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) { assert(!inputTypes.empty() && "cannot concatenate 0 tensors"); auto tensorTypes = - llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) { - return llvm::cast<RankedTensorType>(type); - })); + llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>); int64_t concatRank = tensorTypes[0].getRank(); // The concatenation dim must be in the range [0, rank). @@ -2293,9 +2291,9 @@ void ExtractSliceOp::getAsmResultNames( /// An extract_slice result type can be inferred, when it is not /// rank-reduced, from the source type and the static representation of /// offsets, sizes and strides. Special sentinels encode the dynamic case. -RankedTensorType ExtractSliceOp::inferResultType( - RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets, - ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) { +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType, + ArrayRef<int64_t> staticSizes) { // An extract_slice op may specify only a leading subset of offset/sizes/ // strides in which case we complete with offset=0, sizes from memref type // and strides=1. @@ -2307,11 +2305,12 @@ RankedTensorType ExtractSliceOp::inferResultType( } // TODO: This uses neither offsets nor strides! -RankedTensorType ExtractSliceOp::inferResultType( - RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets, - ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) { +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType, + ArrayRef<OpFoldResult> sizes) { SmallVector<int64_t> staticSizes; std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes); + assert(static_cast<int64_t>(staticSizes.size()) == sourceTensorType.getRank() && "unexpected staticSizes not equal to rank of source"); @@ -2329,11 +2328,10 @@ RankedTensorType ExtractSliceOp::inferResultType( /// To disambiguate, this function always drops the first 1 sizes occurrences. RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( unsigned desiredResultRank, RankedTensorType sourceRankedTensorType, - ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, - ArrayRef<int64_t> strides) { + ArrayRef<int64_t> sizes) { // Type inferred in the absence of rank-reducing behavior. 
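A concrete, hypothetical instance of the rank-reduced inference with the new size-only signature, assuming an OpBuilder `b` is in scope:

  RankedTensorType src = RankedTensorType::get({8, 16, 4}, b.getF32Type());
  SmallVector<int64_t> sizes = {1, 16, 4};
  // Full-rank inference yields tensor<1x16x4xf32>; requesting rank 2 drops the
  // leading unit dimension, so the canonical rank-reduced type is
  // tensor<16x4xf32>.
  RankedTensorType reduced =
      tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
          /*desiredResultRank=*/2, src, sizes);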
auto inferredType = llvm::cast<RankedTensorType>( - inferResultType(sourceRankedTensorType, offsets, sizes, strides)); + inferResultType(sourceRankedTensorType, sizes)); int rankDiff = inferredType.getRank() - desiredResultRank; if (rankDiff > 0) { auto shape = inferredType.getShape(); @@ -2352,16 +2350,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( unsigned desiredResultRank, RankedTensorType sourceRankedTensorType, - ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, - ArrayRef<OpFoldResult> strides) { - SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; - SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + ArrayRef<OpFoldResult> sizes) { + SmallVector<int64_t> staticSizes; + SmallVector<Value> dynamicSizes; dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); return ExtractSliceOp::inferCanonicalRankReducedResultType( - desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes, - staticStrides); + desiredResultRank, sourceRankedTensorType, staticSizes); } /// Build an ExtractSliceOp with mixed static and dynamic entries and custom @@ -2380,8 +2374,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType()); // Structuring implementation this way avoids duplication between builders. if (!resultType) { - resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType( - sourceRankedTensorType, staticOffsets, staticSizes, staticStrides)); + resultType = llvm::cast<RankedTensorType>( + ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes)); } result.addAttributes(attrs); build(b, result, resultType, source, dynamicOffsets, dynamicSizes, @@ -2451,13 +2445,26 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, } } +/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred +/// result type, offsets set to 0 and strides set to 1. +void ExtractSliceOp::build(OpBuilder &b, OperationState &result, + RankedTensorType resultType, Value source, + ArrayRef<OpFoldResult> sizes, + ArrayRef<NamedAttribute> attrs) { + Attribute zeroIdxAttr = b.getIndexAttr(0); + Attribute oneIdxAttr = b.getIndexAttr(1); + SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr); + SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr); + build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs); +} + /// Verifier for ExtractSliceOp. LogicalResult ExtractSliceOp::verify() { RankedTensorType sourceType = getSourceType(); // Verify result type against inferred type. 
- RankedTensorType expectedType = ExtractSliceOp::inferResultType( - sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides()); + RankedTensorType expectedType = + ExtractSliceOp::inferResultType(sourceType, getMixedSizes()); SliceVerificationResult result = isRankReducedType(expectedType, getType()); if (result != SliceVerificationResult::Success) return produceSliceErrorMsg(result, *this, expectedType); @@ -2697,8 +2704,7 @@ struct SliceReturnTypeCanonicalizer { ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { return ExtractSliceOp::inferCanonicalRankReducedResultType( - op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes, - mixedStrides); + op.getType().getRank(), op.getSourceType(), mixedSizes); } }; @@ -2839,8 +2845,8 @@ static SliceVerificationResult verifyInsertSliceOp( ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) { // insert_slice is the inverse of extract_slice, use the same type // inference. - RankedTensorType expected = ExtractSliceOp::inferResultType( - dstType, staticOffsets, staticSizes, staticStrides); + RankedTensorType expected = + ExtractSliceOp::inferResultType(dstType, staticSizes); if (expectedType) *expectedType = expected; return isRankReducedType(expected, srcType); @@ -2968,7 +2974,7 @@ public: // Create the new op in canonical form. auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType( insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(), - mixedOffsets, mixedSizes, mixedStrides); + mixedSizes); Value toInsert = insertSliceOp.getSource(); if (sourceType != insertSliceOp.getSourceType()) { OpBuilder::InsertionGuard g(rewriter); @@ -3896,6 +3902,18 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, build(b, result, source, dest, offsetValues, sizeValues, strideValues); } +// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set +// to 0, strides set to 1 and inferred result type. 
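A hedged usage sketch for the size-only builder defined below (assuming `b`, `loc`, and tensor values `src` and `dest` exist; the attribute list is passed explicitly since whether it defaults to empty depends on the declaration):

  SmallVector<OpFoldResult> sizes = {b.getIndexAttr(4), b.getIndexAttr(8)};
  // Offsets are implicitly 0 and strides implicitly 1; only the sizes vary.
  auto inserted = tensor::InsertSliceOp::create(b, loc, src, dest, sizes,
                                                ArrayRef<NamedAttribute>{});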
+void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, + Value dest, ArrayRef<OpFoldResult> sizes, + ArrayRef<NamedAttribute> attrs) { + Attribute zeroIdxAttr = b.getIndexAttr(0); + Attribute oneIdxAttr = b.getIndexAttr(1); + SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr); + SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr); + build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs); +} + LogicalResult ParallelInsertSliceOp::verify() { if (!isa<InParallelOpInterface>(getOperation()->getParentOp())) return this->emitError("expected InParallelOpInterface parent, got:") diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index c607ece..310e725 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1132,35 +1132,22 @@ struct ConcatOpInterface // Extract the dimension for the concat op uint64_t concatDim = concatOp.getDim(); - bool dynamicConcatDim = false; SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1)); - SmallVector<OpFoldResult> sizes; - - for (const auto &[dimIdx, dimSize] : - llvm::enumerate(tensorType.getShape())) { - if (dimSize == ShapedType::kDynamic) { - auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx); - sizes.push_back(dimOp.getResult()); - if (dimIdx == concatDim) - dynamicConcatDim = true; - } else { - sizes.push_back(rewriter.getIndexAttr(dimSize)); - } - } - - int64_t concatDimOffset = 0; - std::optional<Value> dynamicOffset; - std::optional<Value> dynamicSize; - if (dynamicConcatDim) { - // One or more operands have dynamic size, so we must accumulate the - // offset with arith ops. - dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); - } + SmallVector<OpFoldResult> sizes = + memref::getMixedSizes(rewriter, loc, dstBuffer); + + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + auto sum = [&](OpFoldResult v1, OpFoldResult v2) { + return affine::makeComposedFoldedAffineApply(rewriter, loc, s0 + s1, + {v1, v2}); + }; + OpFoldResult concatDimOffset = rewriter.getIndexAttr(0); for (auto operand : concatOp.getInputs()) { // Get the buffer for the operand. FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state); @@ -1171,18 +1158,10 @@ struct ConcatOpInterface // so the offset on that axis must accumulate through the loop, and the // size must change to the size of the current operand. auto operandTensorType = cast<RankedTensorType>(operand.getType()); - int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim); - - if (dynamicConcatDim) { - offsets[concatDim] = dynamicOffset.value(); - dynamicSize = - memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim) - .getResult(); - sizes[concatDim] = dynamicSize.value(); - } else { - sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize); - offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset); - } + offsets[concatDim] = concatDimOffset; + OpFoldResult concatDimSize = + memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim); + sizes[concatDim] = concatDimSize; // Create a subview of the destination buffer. 
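An aside on the folded affine sum used to accumulate concatDimOffset in this hunk; the sketch assumes `rewriter` and `loc` are in scope and the Affine dialect is loaded:

  AffineExpr s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  OpFoldResult a = rewriter.getIndexAttr(4), c = rewriter.getIndexAttr(8);
  // Both operands are attributes, so this folds to the index attribute 12 and
  // no affine.apply is created; a dynamic operand would materialize one.
  OpFoldResult sum =
      affine::makeComposedFoldedAffineApply(rewriter, loc, s0 + s1, {a, c});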
auto dstMemrefType = cast<MemRefType>(memrefType); @@ -1197,12 +1176,7 @@ struct ConcatOpInterface if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview))) return failure(); - if (dynamicConcatDim) { - dynamicOffset = arith::AddIOp::create( - rewriter, loc, dynamicOffset.value(), dynamicSize.value()); - } else { - concatDimOffset += operandConcatDimSize; - } + concatDimOffset = sum(concatDimOffset, concatDimSize); } replaceOpWithBufferizedValues(rewriter, op, dstBuffer); diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index 7ec61c7..a53af98 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract // supported. Moreover, only simple cases where the resulting ExtractSliceOp // has no rank-reduction anymore are supported at the moment. RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType( - srcType, extractSliceOp.getStaticOffsets(), - extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides()); + srcType, extractSliceOp.getStaticSizes()); if (nonReducingExtractType != resultType) return failure(); @@ -533,8 +532,8 @@ LogicalResult mlir::tensor::getCollapsedExtractSliceInfo( getMixedSizes(b, loc, sliceOp.getSource()); // Helper variables and function for accumulating the size values. - AffineExpr d0, d1, d2; - bindDims(b.getContext(), d0, d1, d2); + AffineExpr d0, d1; + bindDims(b.getContext(), d0, d1); // Multiply two integers. auto mul = [&](OpFoldResult v1, OpFoldResult v2) { auto mulMap = AffineMap::get(2, 0, {d0 * d1}); diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp index 753cb95..d35f458 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp @@ -155,13 +155,15 @@ struct ExtractSliceOpInterface RankedTensorType sourceType = extractSliceOp.getSource().getType(); // For each dimension, assert that: - // 0 <= offset < dim_size - // 0 <= offset + (size - 1) * stride < dim_size + // For empty slices (size == 0) : 0 <= offset <= dim_size + // For non-empty slices (size > 0): 0 <= offset < dim_size + // 0 <= offset + (size - 1) * stride < + // dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) { - // Reset insertion point to before the operation for each dimension + builder.setInsertionPoint(extractSliceOp); Value offset = getValueOrCreateConstantIndexOp( @@ -170,46 +172,63 @@ struct ExtractSliceOpInterface builder, loc, extractSliceOp.getMixedSizes()[i]); Value stride = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedStrides()[i]); - - // Verify that offset is in-bounds. Value dimSize = builder.createOrFold<tensor::DimOp>( loc, extractSliceOp.getSource(), i); - Value offsetInBounds = - generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create(builder, loc, offsetInBounds, + + // Verify that offset is in-bounds (conditional on slice size). 
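Restating the bound generated below in plain C++ (illustrative only):

  auto offsetOk = [](int64_t offset, int64_t size, int64_t dimSize) {
    return size == 0 ? (offset >= 0 && offset <= dimSize)
                     : (offset >= 0 && offset < dimSize);
  };
  // With dimSize == 10: offsetOk(10, /*size=*/0, 10) holds (an empty slice may
  // start at the boundary), while offsetOk(10, /*size=*/1, 10) does not.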
+ Value sizeIsZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, size, zero); + auto offsetCheckIf = scf::IfOp::create( + builder, loc, sizeIsZero, + [&](OpBuilder &b, Location loc) { + // For empty slices, offset can be at the boundary: 0 <= offset <= + // dimSize. + Value offsetGEZero = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sge, offset, zero); + Value offsetLEDimSize = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sle, offset, dimSize); + Value emptyOffsetValid = + arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize); + scf::YieldOp::create(b, loc, emptyOffsetValid); + }, + [&](OpBuilder &b, Location loc) { + // For non-empty slices, offset must be a valid index: 0 <= offset < + // dimSize. + Value offsetInBounds = + generateInBoundsCheck(b, loc, offset, zero, dimSize); + scf::YieldOp::create(b, loc, offsetInBounds); + }); + + Value offsetCondition = offsetCheckIf.getResult(0); + cf::AssertOp::create(builder, loc, offsetCondition, generateErrorMessage(op, "offset " + std::to_string(i) + " is out-of-bounds")); - // Only verify if size > 0 + // Verify that the slice endpoint is in-bounds (only for non-empty + // slices). Value sizeIsNonZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::sgt, size, zero); + auto ifOp = scf::IfOp::create( + builder, loc, sizeIsNonZero, + [&](OpBuilder &b, Location loc) { + // Verify that slice does not run out-of-bounds. + Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one); + Value sizeMinusOneTimesStride = + arith::MulIOp::create(b, loc, sizeMinusOne, stride); + Value lastPos = + arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride); + Value lastPosInBounds = + generateInBoundsCheck(b, loc, lastPos, zero, dimSize); + scf::YieldOp::create(b, loc, lastPosInBounds); + }, + [&](OpBuilder &b, Location loc) { + Value trueVal = + arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); + scf::YieldOp::create(b, loc, trueVal); + }); - auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), - sizeIsNonZero, /*withElseRegion=*/true); - - // Populate the "then" region (for size > 0). - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - - // Verify that slice does not run out-of-bounds. - Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); - Value sizeMinusOneTimesStride = - arith::MulIOp::create(builder, loc, sizeMinusOne, stride); - Value lastPos = - arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); - Value lastPosInBounds = - generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - scf::YieldOp::create(builder, loc, lastPosInBounds); - - // Populate the "else" region (for size == 0). 
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - Value trueVal = - arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); - scf::YieldOp::create(builder, loc, trueVal); - - builder.setInsertionPointAfter(ifOp); Value finalCondition = ifOp.getResult(0); - cf::AssertOp::create( builder, loc, finalCondition, generateErrorMessage( diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 293c6af..c420a4c 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> { auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) { - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); } Attribute newMinValAttr, newMaxValAttr; @@ -1485,7 +1486,24 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { return {}; } +static bool +mayRequireBroadcast(ValueTypeRange<mlir::OperandRange> operandTypes) { + const auto isDynamic = [](Type ty) { + const auto shapedTy = llvm::dyn_cast<ShapedType>(ty); + return !shapedTy || !shapedTy.hasStaticShape(); + }; + + return llvm::any_of(operandTypes, isDynamic) || + failed(verifyCompatibleShapes(operandTypes)); +} + OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { + // Select allows operand shapes to be broadcast to the output shape. For + // now, don't support folding when we cannot prove no broadcasting is + // involved. 
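A concrete case the guard below protects against (shapes invented for illustration): if %x : tensor<1xf32> feeds both on_true and on_false but the predicate is tensor<4xi1>, the result is tensor<4xf32>, so folding to %x would change the type. A static check in the same spirit (the actual guard relies on verifyCompatibleShapes):

  // Returns true only when no implicit broadcast can be involved.
  static bool foldingIsShapeSafe(ShapedType onTrue, ShapedType onFalse,
                                 ShapedType result) {
    return onTrue.hasStaticShape() && onFalse.hasStaticShape() &&
           result.hasStaticShape() &&
           onTrue.getShape() == result.getShape() &&
           onFalse.getShape() == result.getShape();
  }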
+ if (mayRequireBroadcast(getOperandTypes())) + return {}; + if (getOnTrue() == getOnFalse()) return getOnTrue(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index bf3810f..1c175f9ab 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -563,7 +563,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) { static Type getStorageElementTypeOrSelf(Type type) { auto srcType = getElementTypeOrSelf(type); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType)) - srcType = quantType.getStorageType(); + srcType = getStorageElementTypeFromQuantized(quantType); return srcType; } @@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) { bool resultIsFloat = llvm::isa<FloatType>(resultEType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType)) - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType)) - weightEType = quantType.getStorageType(); + weightEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType)) - biasEType = quantType.getStorageType(); + biasEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType)) - resultEType = quantType.getStorageType(); + resultEType = getStorageElementTypeFromQuantized(quantType); if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) { // for now, only enforce bias element type == result element type for @@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() { if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>( outputType.getElementType())) { - if (result.getStorageType() == attrType.getElementType()) + if (getStorageElementTypeFromQuantized(result) == attrType.getElementType()) return success(); } @@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) { llvm::cast<ShapedType>(op.getInput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType)) - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); auto accType = op.getAccType(); if (inputEType.isInteger(8) && !accType.isInteger(32)) @@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) { llvm::cast<ShapedType>(op.getResult().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType)) - resultEType = quantType.getStorageType(); + resultEType = getStorageElementTypeFromQuantized(quantType); return success(); } @@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() { llvm::cast<ShapedType>(getInput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) { - inputETy = quantType.getStorageType(); + inputETy = getStorageElementTypeFromQuantized(quantType); } mlir::Type outputETy = llvm::cast<ShapedType>(getOutput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) { - outputETy = quantType.getStorageType(); + outputETy = getStorageElementTypeFromQuantized(quantType); } if (inputETy != outputETy) return emitOpError("input/output element types are incompatible."); @@ -1761,6 +1761,11 @@ LogicalResult tosa::ConcatOp::verify() { } } + const ShapeAdaptor outputShape(outType); + if 
(outputShape.hasRank() && outputShape.getRank() != firstInputRank) + return emitOpError("expect output rank to match inputs rank, got ") + << outputShape.getRank() << " vs " << firstInputRank; + // ERROR_IF(axis_sum != shape[axis]); int64_t axisSum = 0; for (const auto &input : inputList) { @@ -1772,7 +1777,7 @@ LogicalResult tosa::ConcatOp::verify() { } axisSum += inputShape.getDimSize(axis); } - const ShapeAdaptor outputShape(outType); + if (axisSum >= 0 && outputShape.hasRank() && !outputShape.isDynamicDim(axis) && axisSum != outputShape.getDimSize(axis)) @@ -2628,7 +2633,7 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, if (!zpElemType.isInteger(8) && zp != 0) { // convert operand to lower case for error message std::string lower = operand; - std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower); + llvm::transform(lower, lower.begin(), ::tolower); return op.emitOpError() << lower << " zero point must be zero for non-int8 integer types"; } diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 41b338d..091b481 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaAttachTarget.cpp + TosaArithConstantToConst.cpp TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp @@ -12,6 +13,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaTypeConverters.cpp TosaProfileCompliance.cpp TosaValidation.cpp + TosaNarrowI64ToI32.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms @@ -21,7 +23,9 @@ add_mlir_dialect_library(MLIRTosaTransforms LINK_LIBS PUBLIC MLIRFuncDialect + MLIRFuncTransformOps MLIRPass MLIRTosaDialect MLIRTransformUtils + MLIRFuncTransforms ) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp new file mode 100644 index 0000000..73e1e2b --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp @@ -0,0 +1,111 @@ +//===- TosaArithConstantToConst.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass that converts tensor-valued arith.constant ops +// into tosa.const so that TOSA pipelines operate on a uniform constant form. 
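As an illustration of the normalization this file provides, and one way it might be wired into a pipeline (assuming a ModuleOp `module`; the factory name below follows the usual TableGen-generated pattern and is an assumption, not taken from this patch):

  // Before: %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
  // After : an equivalent tosa.const carrying the same dense payload and type.
  PassManager pm(module.getContext());
  pm.addNestedPass<func::FuncOp>(
      tosa::createTosaArithConstantToTosaConstPass()); // assumed factory name
  if (failed(pm.run(module)))
    llvm::report_fatal_error("pass pipeline failed");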
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace tosa { +#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" +} // namespace tosa +} // namespace mlir + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +// NOTE: TOSA pipelines already lower their constants through shared Arith +// folding passes, so tensor literals often come back as `arith.constant` even +// after the IR is otherwise TOSA-only. Keep this normalization with the rest of +// the TOSA transforms so any client can re-establish a canonical `tosa.const` +// representation without needing a full Arith->TOSA conversion library. + +/// Returns true when `elementType` is natively representable by tosa.const. +static bool isSupportedElementType(Type elementType) { + if (isa<FloatType>(elementType)) + return true; + + if (auto intType = dyn_cast<IntegerType>(elementType)) + return intType.isSignless() || intType.isUnsigned(); + + if (isa<quant::QuantizedType>(elementType)) + return true; + + if (isa<tosa::mxint8Type>(elementType)) + return true; + + return false; +} + +class ArithConstantToTosaConst : public OpRewritePattern<arith::ConstantOp> { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::ConstantOp constOp, + PatternRewriter &rewriter) const override { + // TOSA constant verification requires a ranked, statically shaped tensor. 
+ auto resultType = dyn_cast<RankedTensorType>(constOp.getResult().getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + + if (!isSupportedElementType(resultType.getElementType())) + return failure(); + + Attribute attr = constOp.getValueAttr(); + auto elementsAttr = dyn_cast<ElementsAttr>(attr); + if (!elementsAttr) + return failure(); + + auto attrType = dyn_cast<RankedTensorType>(elementsAttr.getType()); + if (!attrType || !attrType.hasStaticShape()) + return failure(); + if (attrType != resultType) + return failure(); + + auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(), + resultType, elementsAttr); + rewriter.replaceOp(constOp, newConst.getResult()); + return success(); + } +}; + +struct TosaArithConstantToTosaConstPass + : public tosa::impl::TosaArithConstantToTosaConstPassBase< + TosaArithConstantToTosaConstPass> { + using Base::Base; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<arith::ArithDialect, tosa::TosaDialect>(); + } + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add<ArithConstantToTosaConst>(ctx); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 0bec0da..022476a2 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -33,8 +33,13 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { ShapedType weightType = cast<ShapedType>(weight.getType()); ShapedType resultType = cast<ShapedType>(op.getOutput().getType()); - if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && - resultType.hasStaticShape())) { + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i)) + return failure(); + } + + if (!weightType.hasStaticShape()) { return failure(); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index dc5c51b..8b23fd1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -49,8 +49,13 @@ public: if (llvm::any_of(stride, [](int64_t v) { return v != 1; })) return failure(); - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i)) + return failure(); + } + + if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return failure(); int64_t kernelHeight = weightTy.getDimSize(1); @@ -113,8 +118,13 @@ public: if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) return rewriter.notifyMatchFailure(op, "non-one stride found."); - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i)) + return failure(); + } + + 
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return failure(); int64_t batch = inputTy.getDimSize(0); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp new file mode 100644 index 0000000..ddaf7d8a --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp @@ -0,0 +1,310 @@ +//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass narrows TOSA operations with 64-bit integer tensor types to +// 32-bit integer tensor types. This can be useful for backends that do not +// support the EXT-INT64 extension of TOSA. The pass has two options: +// +// - aggressive-rewrite - If enabled, all TOSA operations are rewritten, +// regardless or whether the narrowing is safe. This option may lead to +// data loss if not used carefully. +// - convert-function-boundaries - If enabled, the pass will convert function +// I/O types as well. Otherwise casts will be inserted at the I/O +// boundaries. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tosa { +#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" +} // namespace tosa +} // namespace mlir + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +LogicalResult convertGenericOp(Operation *op, ValueRange operands, + ConversionPatternRewriter &rewriter, + const TypeConverter *typeConverter) { + // Convert types of results + SmallVector<Type, 4> newResults; + if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults))) + return failure(); + + // Create a new operation state + OperationState state(op->getLoc(), op->getName().getStringRef(), operands, + newResults, {}, op->getSuccessors()); + + for (const NamedAttribute &namedAttribute : op->getAttrs()) { + const Attribute attribute = namedAttribute.getValue(); + + // Convert integer attribute type + if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) { + const std::optional<Attribute> convertedAttribute = + typeConverter->convertTypeAttribute(intAttr.getType(), attribute); + state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); + continue; + } + + if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) { + Type type = typeAttr.getValue(); + const std::optional<Attribute> convertedAttribute = + typeConverter->convertTypeAttribute(type, attribute); + if (!convertedAttribute) + return rewriter.notifyMatchFailure(op, + "Failed to convert type attribute."); + state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); + continue; + } + + if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) { + const Type type = denseElementsAttr.getType(); + const std::optional<Attribute> convertedAttribute = + typeConverter->convertTypeAttribute(type, denseElementsAttr); + if (!convertedAttribute) + return rewriter.notifyMatchFailure( + op, "Failed to convert dense elements 
attribute."); + state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); + continue; + } + + state.addAttribute(namedAttribute.getName(), attribute); + } + + for (Region ®ion : op->getRegions()) { + Region *newRegion = state.addRegion(); + rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); + if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter))) + return failure(); + } + + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); +} + +// =========================== +// Aggressive rewrite patterns +// =========================== + +class ConvertGenericOp : public ConversionPattern { +public: + ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + if (!isa<tosa::TosaOp>(op)) + return rewriter.notifyMatchFailure( + op, + "Support for operations other than TOSA has not been implemented."); + + return convertGenericOp(op, operands, rewriter, typeConverter); + } +}; + +// =============================== +// Bounds checked rewrite patterns +// =============================== + +class ConvertArgMaxOpWithBoundsChecking + : public OpConversionPattern<tosa::ArgMaxOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // Output type can be narrowed based on the size of the axis dimension + const int32_t axis = op.getAxis(); + const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType()); + if (!inputType || !inputType.isStaticDim(axis)) + return rewriter.notifyMatchFailure( + op, "Requires a static axis dimension for bounds checking."); + const int64_t axisDim = inputType.getDimSize(axis); + if (axisDim >= std::numeric_limits<int32_t>::max()) + return rewriter.notifyMatchFailure( + op, "Axis dimension is too large to narrow safely."); + + const Type resultType = op.getOutput().getType(); + const Type newResultType = typeConverter->convertType(resultType); + rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType, + adaptor.getInput(), axis); + return success(); + } +}; + +class ConvertCastOpWithBoundsChecking + : public OpConversionPattern<tosa::CastOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType()); + const auto resultType = dyn_cast<ShapedType>(op.getResult().getType()); + if (!inputType || !resultType) + return failure(); + + const auto elementInputIntType = + dyn_cast<IntegerType>(inputType.getElementType()); + const auto elementResultIntType = + dyn_cast<IntegerType>(resultType.getElementType()); + if (elementInputIntType && elementResultIntType && + elementInputIntType.getWidth() > elementResultIntType.getWidth()) + return rewriter.notifyMatchFailure( + op, "Narrowing cast may lead to data loss."); + + rewriter.replaceOpWithNewOp<tosa::CastOp>( + op, typeConverter->convertType(resultType), adaptor.getInput()); + return success(); + } +}; + +template <typename OpTy> +class ConvertTypedOp : public OpConversionPattern<OpTy> { + using OpConversionPattern<OpTy>::OpConversionPattern; + + LogicalResult + matchAndRewrite(OpTy op, typename 
OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + return convertGenericOp(op, adaptor.getOperands(), rewriter, + this->getTypeConverter()); + } +}; + +struct TosaNarrowI64ToI32 + : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> { +public: + explicit TosaNarrowI64ToI32() = default; + explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options) + : TosaNarrowI64ToI32() { + this->aggressiveRewrite = options.aggressiveRewrite; + this->convertFunctionBoundaries = options.convertFunctionBoundaries; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) -> Type { return type; }); + typeConverter.addConversion([](IntegerType type) -> Type { + if (!type.isInteger(64)) + return type; + return IntegerType::get(type.getContext(), 32); + }); + typeConverter.addConversion( + [&typeConverter](RankedTensorType type) -> Type { + const Type elementType = type.getElementType(); + if (!elementType.isInteger(64)) + return type; + return RankedTensorType::get(type.getShape(), + typeConverter.convertType(elementType)); + }); + + const auto materializeCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return Value(); + return tosa::CastOp::create(builder, loc, resultType, inputs.front()); + }; + typeConverter.addSourceMaterialization(materializeCast); + typeConverter.addTargetMaterialization(materializeCast); + + typeConverter.addTypeAttributeConversion( + [](IntegerType type, IntegerAttr attribute) -> Attribute { + const APInt value = attribute.getValue().truncSSat(32); + return IntegerAttr::get(IntegerType::get(type.getContext(), 32), + value); + }); + typeConverter.addTypeAttributeConversion( + [&typeConverter](ShapedType type, + DenseIntElementsAttr attr) -> Attribute { + const ShapedType newType = + cast<ShapedType>(typeConverter.convertType(type)); + const auto oldElementType = cast<IntegerType>(type.getElementType()); + const auto newElementType = + cast<IntegerType>(newType.getElementType()); + if (oldElementType.getWidth() == newElementType.getWidth()) + return attr; + + DenseElementsAttr mapped = + attr.mapValues(newElementType, [&](const APInt &v) { + return v.truncSSat(newElementType.getWidth()); + }); + return mapped; + }); + + ConversionTarget target(*context); + target.addDynamicallyLegalDialect<tosa::TosaDialect>( + [&typeConverter](Operation *op) { + return typeConverter.isLegal(op->getResultTypes()) && + typeConverter.isLegal(op->getOperandTypes()); + }); + if (convertFunctionBoundaries) { + target.addDynamicallyLegalOp<func::FuncOp>( + [&typeConverter](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) { + const FunctionType funcType = + op->getParentOfType<func::FuncOp>().getFunctionType(); + return llvm::equal(op.getOperandTypes(), funcType.getResults()); + }); + } else { + target.addDynamicallyLegalOp<func::FuncOp>( + [](func::FuncOp op) { return true; }); + target.addDynamicallyLegalOp<func::ReturnOp>( + [](func::ReturnOp op) { return true; }); + } + + RewritePatternSet patterns(context); + if (convertFunctionBoundaries) { + populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( + patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + } + if (aggressiveRewrite) { + 
patterns.add<ConvertGenericOp>(typeConverter, context); + } else { + // Tensor + patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context); + // Data layout + patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context); + // Type conversion + patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context); + // Controlflow + patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context); + } + + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index ac5d620..36e8940 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -70,6 +70,8 @@ namespace { // If lower=[a], higher=[a, a], [a] reshaped into [1, a]. // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. +// If lower=[c], higher=[?, ?, c], [c] reshaped into [1, 1, c]. +// If lower=[?], higher=[?, ?, ?], [?] reshaped into [1, 1, ?]. LogicalResult computeReshapeOutput(ArrayRef<int64_t> higherRankShape, ArrayRef<int64_t> lowerRankShape, @@ -87,7 +89,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape, higherRankDim = higherRankShape[i + rankDiff]; lowerRankDim = lowerRankShape[i]; - if (lowerRankDim != 1 && higherRankDim != 1 && + auto isStaticDimAndNotEqualToOne = [](int64_t dim) { + return dim != 1 && dim != ShapedType::kDynamic; + }; + + if (isStaticDimAndNotEqualToOne(lowerRankDim) && + isStaticDimAndNotEqualToOne(higherRankDim) && lowerRankDim != higherRankDim) return failure(); @@ -216,22 +223,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) { bool mlir::tosa::hasUniqueConstantScatterIndices( ShapedType indicesType, DenseIntElementsAttr indicesAttr) { - llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape(); + const llvm::ArrayRef<int64_t> indicesShape = indicesType.getShape(); const unsigned int indicesRank = indicesShape.size(); const unsigned int lastDimSize = indicesShape[indicesRank - 1]; // check each batch of indices from the flat indicesAttr values // for duplicates - auto const indicesValues = indicesAttr.getValues<int32_t>(); + auto const indicesValues = indicesAttr.getValues<APInt>(); assert( (indicesValues.size() % lastDimSize == 0) && "Constant indices data length should be a multiple of indicesShape[-1]"); - std::vector<uint64_t> indices(lastDimSize); + std::vector<APInt> indices(lastDimSize); for (auto beg = indicesValues.begin(); beg < indicesValues.end(); beg += lastDimSize) { std::copy(beg, beg + lastDimSize, indices.begin()); - std::sort(indices.begin(), indices.end()); + std::sort(indices.begin(), indices.end(), + [](const APInt &a, const APInt &b) { return a.slt(b); }); if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) 
{ // found duplicate values in indices in batch return false; diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp index 02c86a0..c55b13d 100644 --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, maxAttr, quantBits, filterQuantDim, isSigned, narrowRange)); } + +Type mlir::tosa::getStorageElementTypeFromQuantized( + quant::QuantizedType quantType) { + auto quantEty = quantType.getStorageType(); + // StorageType doesn't capture the sign information + // Explicitly create unsigned type if needed + if (!quantType.isSigned()) { + quantEty = IntegerType::get(quantEty.getContext(), + quantEty.getIntOrFloatBitWidth(), + IntegerType::Unsigned); + } + return quantEty; +} diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 062606e..86233b0 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -2062,6 +2062,10 @@ transform::IncludeOp::apply(transform::TransformRewriter &rewriter, DiagnosedSilenceableFailure result = applySequenceBlock( callee.getBody().front(), getFailurePropagationMode(), state, results); + + if (!result.succeeded()) + return result; + mappings.clear(); detail::prepareValueMappings( mappings, callee.getBody().front().getTerminator()->getOperands(), state); diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index 4f4620a..24b0487 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -47,8 +47,6 @@ static bool happensBefore(Operation *a, Operation *b) { // TransformState //===----------------------------------------------------------------------===// -constexpr const Value transform::TransformState::kTopLevelValue; - transform::TransformState::TransformState( Region *region, Operation *payloadRoot, const RaggedArray<MappedValue> &extraMappings, @@ -1497,8 +1495,7 @@ transform::detail::checkApplyToOne(Operation *transformOp, template <typename T> static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) { - return llvm::to_vector(llvm::map_range( - range, [](transform::MappedValue value) { return cast<T>(value); })); + return llvm::map_to_vector(range, llvm::CastTo<T>); } void transform::detail::setApplyToOneResults( diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index f727118..2bd6205 100644 --- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -156,7 +156,7 @@ DiagnosedSilenceableFailure transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { - std::optional<size_t> selectedRegionIdx; + std::optional<int64_t> selectedRegionIdx; if (auto selectedRegionAttr = getSelectedRegionAttr()) selectedRegionIdx = selectedRegionAttr->getSExtValue(); @@ -232,7 +232,7 @@ LogicalResult transform::tune::AlternativesOp::verify() { } if (auto selectedRegionAttr = getSelectedRegionAttr()) { - size_t regionIdx = selectedRegionAttr->getSExtValue(); + int64_t regionIdx = selectedRegionAttr->getSExtValue(); if (regionIdx < 0 || regionIdx >= 
getNumRegions()) return emitOpError() << "'selected_region' attribute specifies region at index " diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp index a26edac..2986f4c 100644 --- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp +++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp @@ -106,14 +106,12 @@ ScalableValueBoundsConstraintSet::computeScalableBound( AffineMap bound = [&] { if (boundType == BoundType::EQ && !invalidBound(lowerBound) && - lowerBound[0] == upperBound[0]) { + lowerBound[0] == upperBound[0]) return lowerBound[0]; - } - if (boundType == BoundType::LB && !invalidBound(lowerBound)) { + if (boundType == BoundType::LB && !invalidBound(lowerBound)) return lowerBound[0]; - } else if (boundType == BoundType::UB && !invalidBound(upperBound)) { + if (boundType == BoundType::UB && !invalidBound(upperBound)) return upperBound[0]; - } return AffineMap{}; }(); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ae3423c..2789f63 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -717,7 +717,15 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, case arith::AtomicRMWKind::ori: return vector::ReductionOp::create(builder, vector.getLoc(), CombiningKind::OR, vector); - // TODO: Add remaining reduction operations. + case arith::AtomicRMWKind::minnumf: + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MINNUMF, vector); + case arith::AtomicRMWKind::maxnumf: + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::MAXNUMF, vector); + case arith::AtomicRMWKind::xori: + return vector::ReductionOp::create(builder, vector.getLoc(), + CombiningKind::XOR, vector); default: (void)emitOptionalError(loc, "Reduction operation type not supported"); break; @@ -6058,19 +6066,21 @@ LogicalResult ScatterOp::verify() { VectorType indVType = getIndexVectorType(); VectorType maskVType = getMaskVectorType(); VectorType valueVType = getVectorType(); - MemRefType memType = getMemRefType(); + ShapedType baseType = getBaseType(); - if (valueVType.getElementType() != memType.getElementType()) + if (!llvm::isa<MemRefType, RankedTensorType>(baseType)) + return emitOpError("requires base to be a memref or ranked tensor type"); + + if (valueVType.getElementType() != baseType.getElementType()) return emitOpError("base and valueToStore element type should match"); - if (llvm::size(getOffsets()) != memType.getRank()) - return emitOpError("requires ") << memType.getRank() << " indices"; + if (llvm::size(getOffsets()) != baseType.getRank()) + return emitOpError("requires ") << baseType.getRank() << " indices"; if (valueVType.getShape() != indVType.getShape()) return emitOpError("expected valueToStore dim to match indices dim"); if (valueVType.getShape() != maskVType.getShape()) return emitOpError("expected valueToStore dim to match mask dim"); return success(); } - namespace { class ScatterFolder final : public OpRewritePattern<ScatterOp> { public: @@ -6233,6 +6243,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, setResultRanges(getResult(), argRanges.front()); } +std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() { + return llvm::to_vector<4>(getResultVectorType().getShape()); +} + LogicalResult ShapeCastOp::verify() { VectorType sourceType = getSourceVectorType(); diff 
--git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 546099c..352f477 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" using namespace mlir; using namespace mlir::bufferization; @@ -126,6 +127,54 @@ struct TransferWriteOpInterface } }; +/// Bufferization of vector.scatter. Replaced with a new vector.scatter that +/// operates on a memref. +struct ScatterOpInterface + : public BufferizableOpInterface::ExternalModel<ScatterOpInterface, + vector::ScatterOp> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa<RankedTensorType>(opOperand.get().getType()) && + "only tensor types expected"); + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa<RankedTensorType>(opOperand.get().getType()) && + "only tensor types expected"); + return true; + } + + AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa<RankedTensorType>(opOperand.get().getType()) && + "only tensor types expected"); + auto scatterOp = cast<vector::ScatterOp>(op); + if (&opOperand != &scatterOp.getBaseMutable()) + return {}; + return {{scatterOp.getResult(), BufferRelation::Equivalent}}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options, + BufferizationState &state) const { + auto scatterOp = cast<vector::ScatterOp>(op); + assert(isa<TensorType>(scatterOp.getBaseType()) && + "only tensor types expected"); + FailureOr<Value> buffer = + getBuffer(rewriter, scatterOp.getBase(), options, state); + if (failed(buffer)) + return failure(); + vector::ScatterOp::create(rewriter, scatterOp.getLoc(), + /*resultType=*/nullptr, *buffer, + scatterOp.getOffsets(), scatterOp.getIndices(), + scatterOp.getMask(), scatterOp.getValueToStore()); + replaceOpWithBufferizedValues(rewriter, op, *buffer); + return success(); + } +}; + /// Bufferization of vector.gather. Replaced with a new vector.gather that /// operates on a memref. 
struct GatherOpInterface @@ -335,5 +384,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels( GatherOp::attachInterface<GatherOpInterface>(*ctx); MaskOp::attachInterface<MaskOpInterface>(*ctx); YieldOp::attachInterface<YieldOpInterface>(*ctx); + ScatterOp::attachInterface<ScatterOpInterface>(*ctx); }); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index 258f2cb..1af5523 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -111,7 +111,7 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { if (!isValidKind(isInt, scanOp.getKind())) return failure(); - VectorType resType = VectorType::get(destShape, elType); + VectorType resType = destType; Value result = arith::ConstantOp::create(rewriter, loc, resType, rewriter.getZeroAttr(resType)); int64_t reductionDim = scanOp.getReductionDim(); @@ -121,8 +121,18 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { int64_t initialValueRank = initialValueType.getRank(); SmallVector<int64_t> reductionShape(destShape); + SmallVector<bool> reductionScalableDims(destType.getScalableDims()); + + if (reductionScalableDims[reductionDim]) + return rewriter.notifyMatchFailure( + scanOp, "Trying to reduce scalable dimension - not yet supported!"); + + // The reduction dimension, after reducing, becomes 1. It's a fixed-width + // dimension - no need to touch the scalability flag. reductionShape[reductionDim] = 1; - VectorType reductionType = VectorType::get(reductionShape, elType); + VectorType reductionType = + VectorType::get(reductionShape, elType, reductionScalableDims); + SmallVector<int64_t> offsets(destRank, 0); SmallVector<int64_t> strides(destRank, 1); SmallVector<int64_t> sizes(destShape); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 726da1e..ad16b80 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -453,6 +453,8 @@ struct ReorderCastOpsOnBroadcast PatternRewriter &rewriter) const override { if (op->getNumOperands() != 1) return failure(); + if (!isa<VectorType>(op->getResult(0).getType())) + return failure(); auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>(); if (!bcastOp) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fbae098..462bd8c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1003,6 +1003,286 @@ private: vector::UnrollVectorOptions options; }; +/// This pattern unrolls `vector.create_mask` operations into smaller mask +/// operations based on the target unroll shape. Each unrolled slice computes +/// its local mask size in each dimension (d) as: +/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]). 
+/// Example: +/// Given a create_mask operation: +/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10 +/// elements +/// +/// and a target unroll shape of <4x8>, the pattern produces: +/// +/// %false = arith.constant dense<false> : vector<8x16xi1> +/// +/// Slice [0,0]: +/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8 +/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1> +/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [0,8]: +/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2 +/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1> +/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [4,0]: +/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8 +/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1> +/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [4,8]: +/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2 +/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1> +/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> { + UnrollCreateMaskPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern<vector::CreateMaskOp>(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp, + PatternRewriter &rewriter) const override { + auto targetShape = getTargetShape(options, createMaskOp); + if (!targetShape) + return failure(); + + VectorType resultType = createMaskOp.getVectorType(); + SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll(); + Location loc = createMaskOp.getLoc(); + + Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + VectorType targetVectorType = + VectorType::get(*targetShape, rewriter.getI1Type()); + SmallVector<int64_t> strides(targetShape->size(), 1); + + // In each dimension (d), each unrolled vector computes its mask size as: + // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]). 
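The per-dimension clamp named in the comment above is ordinary index arithmetic. As a minimal standalone sketch (plain C++ with no MLIR dependency; the 6x10 mask, 8x16 vector, and 4x8 unroll shape are the values from the pattern's own documented example), the loop below reproduces the 4x8, 4x2, 2x8, and 2x2 per-slice mask sizes:

```
// Standalone illustration of the per-slice mask-size formula:
//   size[d] = min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d])
// Values mirror the create_mask example in the pattern documentation.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t originalMask[2] = {6, 10}; // %c6, %c10
  const int64_t targetShape[2] = {4, 8};   // unroll shape
  const int64_t fullShape[2] = {8, 16};    // vector<8x16xi1>
  for (int64_t off0 = 0; off0 < fullShape[0]; off0 += targetShape[0]) {
    for (int64_t off1 = 0; off1 < fullShape[1]; off1 += targetShape[1]) {
      int64_t size0 = std::min(std::max(originalMask[0] - off0, int64_t(0)),
                               targetShape[0]);
      int64_t size1 = std::min(std::max(originalMask[1] - off1, int64_t(0)),
                               targetShape[1]);
      // Prints 4x8, 4x2, 2x8, 2x2 for slices [0,0], [0,8], [4,0], [4,8].
      std::printf("slice [%lld,%lld] -> %lldx%lld\n", (long long)off0,
                  (long long)off1, (long long)size0, (long long)size1);
    }
  }
  return 0;
}
```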
+ for (SmallVector<int64_t> offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { + SmallVector<Value> unrolledOperands; + + for (auto [i, originalMaskOperand] : + llvm::enumerate(createMaskOp.getOperands())) { + Value offsetVal = + arith::ConstantIndexOp::create(rewriter, loc, offsets[i]); + Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>( + loc, originalMaskOperand, offsetVal); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value unrolledDimSize = + arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]); + Value nonNegative = + rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero); + Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>( + loc, nonNegative, unrolledDimSize); + unrolledOperands.push_back(unrolledOperand); + } + + auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>( + loc, targetVectorType, unrolledOperands); + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( + loc, unrolledMask, result, offsets, strides); + } + rewriter.replaceOp(createMaskOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + +/// Checks whether extractShape is a contiguous slice of shape. +/// For extractShape to be contiguous in shape: +/// 1) All but the leading dimension of extractShape and shape must match +/// exactly. 2) The total number of elements in shape must be evenly divisible +/// by +/// the total number of elements in extractShape. +/// Examples: +/// isContiguous([4, 4], [8, 4]) == true +/// isContiguous([2, 4], [8, 4]) == true +/// isContiguous([2, 2], [8, 4]) == false +/// Removes leading unit dimensions to handle cases like: +/// isContiguous([1, 16], [1, 32]) == true +static bool isContiguous(ArrayRef<int64_t> extractShape, + ArrayRef<int64_t> shape) { + + if (extractShape.size() > shape.size()) + return false; + + while (!extractShape.empty() && extractShape.front() == 1) { + extractShape = extractShape.drop_front(); + } + + while (!shape.empty() && shape.front() == 1) { + shape = shape.drop_front(); + } + + size_t rankDiff = shape.size() - extractShape.size(); + if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1))) + return false; + + int64_t extractElements = ShapedType::getNumElements(extractShape); + int64_t shapeElements = ShapedType::getNumElements(shape); + return shapeElements % extractElements == 0; +} + +/// Determines what shape to use with `vector.extract_strided_slice` to extract +/// a contiguous memory region from a source vector. The extraction must be +/// contiguous and contain exactly the specified number of elements. If such an +/// extraction shape cannot be determined, returns std::nullopt. +/// EXAMPLE 1: +/// sourceShape = [16], targetElements = 8 +/// Working right-to-left: +/// - Take min(8, 16) = 8 from only dim → extractShape = [8], +/// remaining = 8/8 = 1 +/// Result: [8] +/// +/// EXAMPLE 2: +/// sourceShape = [4, 4], targetElements = 8 +/// Working right-to-left: +/// - Take min(8, 4) = 4 from last dim → extractShape = [4], +/// remaining = 8/4 = 2 +/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4], +/// remaining = 2/2 = 1 +/// Result: [2, 4] +static std::optional<SmallVector<int64_t>> +calculateSourceExtractShape(ArrayRef<int64_t> sourceShape, + int64_t targetElements) { + SmallVector<int64_t> extractShape; + int64_t remainingElements = targetElements; + + // Build extract shape from innermost dimension outward to ensure contiguity. 
+ for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) { + int64_t takeFromDim = std::min(remainingElements, sourceShape[i]); + extractShape.insert(extractShape.begin(), takeFromDim); + + if (remainingElements % takeFromDim != 0) + return std::nullopt; // Not evenly divisible. + remainingElements /= takeFromDim; + } + + // Fill remaining dimensions with 1. + while (extractShape.size() < sourceShape.size()) + extractShape.insert(extractShape.begin(), 1); + + if (ShapedType::getNumElements(extractShape) != targetElements) + return std::nullopt; + + return extractShape; +} + +// Convert result offsets to source offsets via linear position. +static SmallVector<int64_t> +calculateSourceOffsets(ArrayRef<int64_t> resultOffsets, + ArrayRef<int64_t> sourceShape, + ArrayRef<int64_t> resultShape) { + // Convert result offsets to linear position. + int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape)); + // Convert linear position to source offsets. + return delinearize(linearIndex, computeStrides(sourceShape)); +} + +/// This pattern unrolls `vector.shape_cast` operations according to the +/// provided target unroll shape. It unrolls a large shape cast into smaller +/// shape casts by extracting contiguous slices from the source vector, casting +/// each slice to the target shape, and assembling the result by inserting each +/// computed segment into the appropriate offset of the result vector. +/// +/// This pattern only applies when contiguous slices can be extracted from the +/// source vector and inserted into the result vector such that each slice +/// remains a valid vector (and not decompose to scalars). In these cases, the +/// unrolling proceeds as: +/// vector.extract_strided_slice -> vector.shape_cast (on the slice) -> +/// vector.insert_strided_slice. 
+/// +/// Example: +/// Given a shape cast operation: +/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32> +/// +/// and a target unroll shape of <2x4>, the pattern produces: +/// +/// %zero = arith.constant dense<0.0> : vector<4x4xf32> +/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1] +/// : vector<8x2xf32> to vector<4x2xf32> +/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32> +/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1] +/// : vector<2x4xf32> into vector<4x4xf32> +/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1] +/// : vector<8x2xf32> to vector<4x2xf32> +/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32> +/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1] +/// : vector<2x4xf32> into vector<4x4xf32> +/// +struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> { + UnrollShapeCastPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern<vector::ShapeCastOp>(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, + PatternRewriter &rewriter) const override { + std::optional<SmallVector<int64_t>> targetShape = + getTargetShape(options, shapeCastOp); + if (!targetShape) + return failure(); + + VectorType sourceType = shapeCastOp.getSourceVectorType(); + VectorType resultType = shapeCastOp.getResultVectorType(); + ArrayRef<int64_t> sourceShape = sourceType.getShape(); + ArrayRef<int64_t> resultShape = resultType.getShape(); + + if (!isContiguous(*targetShape, resultShape)) + return rewriter.notifyMatchFailure( + shapeCastOp, "Only supports cases where target shape is " + "contiguous in result vector shape"); + + int64_t targetElements = ShapedType::getNumElements(*targetShape); + + // Calculate the shape to extract from source. + std::optional<SmallVector<int64_t>> extractShape = + calculateSourceExtractShape(sourceShape, targetElements); + if (!extractShape) + return rewriter.notifyMatchFailure( + shapeCastOp, + "cannot extract target number of elements contiguously from source"); + + Location loc = shapeCastOp.getLoc(); + + // Create result vector initialized to zero. 
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + + VectorType targetType = + VectorType::get(*targetShape, sourceType.getElementType()); + + SmallVector<int64_t> extractStrides(extractShape->size(), 1); + SmallVector<int64_t> insertStrides(targetShape->size(), 1); + + for (SmallVector<int64_t> resultOffsets : + StaticTileOffsetRange(resultShape, *targetShape)) { + SmallVector<int64_t> sourceOffsets = + calculateSourceOffsets(resultOffsets, sourceShape, resultShape); + Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>( + loc, shapeCastOp.getSource(), sourceOffsets, *extractShape, + extractStrides); + Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>( + loc, targetType, sourceChunk); + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( + loc, targetChunk, result, resultOffsets, insertStrides); + } + + rewriter.replaceOp(shapeCastOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ -1013,8 +1293,9 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, - UnrollToElements, UnrollStepPattern>(patterns.getContext(), - options, benefit); + UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern, + UnrollCreateMaskPattern>(patterns.getContext(), options, + benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index c809c502..c307fb4 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -322,46 +322,61 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, std::optional<Value> padValue, bool useInBoundsInsteadOfMasking, ArrayRef<bool> inputScalableVecDims) { - assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) && + VectorType vecToReadTy = VectorType::get( + inputVectorSizes, cast<ShapedType>(source.getType()).getElementType(), + inputScalableVecDims); + + return createReadOrMaskedRead(builder, loc, source, vecToReadTy, padValue, + useInBoundsInsteadOfMasking); +} + +Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, + Value source, + const VectorType &vecToReadTy, + std::optional<Value> padValue, + bool useInBoundsInsteadOfMasking) { + assert(!llvm::is_contained(vecToReadTy.getScalableDims(), + ShapedType::kDynamic) && "invalid input vector sizes"); auto sourceShapedType = cast<ShapedType>(source.getType()); auto sourceShape = sourceShapedType.getShape(); - assert(sourceShape.size() == inputVectorSizes.size() && + + int64_t vecToReadRank = vecToReadTy.getRank(); + auto vecToReadShape = vecToReadTy.getShape(); + + assert(sourceShape.size() == static_cast<size_t>(vecToReadRank) && "expected same ranks."); - auto vectorType = - VectorType::get(inputVectorSizes, sourceShapedType.getElementType(), - inputScalableVecDims); assert((!padValue.has_value() || padValue.value().getType() == sourceShapedType.getElementType()) && "expected same pad element type to match source element type"); - int64_t readRank = inputVectorSizes.size(); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); - SmallVector<bool> inBoundsVal(readRank, true); + SmallVector<bool> 
inBoundsVal(vecToReadRank, true); if (useInBoundsInsteadOfMasking) { // Update the inBounds attribute. // FIXME: This computation is too weak - it ignores the read indices. - for (unsigned i = 0; i < readRank; i++) - inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) && + for (unsigned i = 0; i < vecToReadRank; i++) + inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) && ShapedType::isStatic(sourceShape[i]); } auto transferReadOp = vector::TransferReadOp::create( builder, loc, - /*vectorType=*/vectorType, + /*vectorType=*/vecToReadTy, /*source=*/source, - /*indices=*/SmallVector<Value>(readRank, zero), + /*indices=*/SmallVector<Value>(vecToReadRank, zero), /*padding=*/padValue, /*inBounds=*/inBoundsVal); - if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking) + if (llvm::equal(vecToReadTy.getShape(), sourceShape) || + useInBoundsInsteadOfMasking) return transferReadOp; SmallVector<OpFoldResult> mixedSourceDims = isa<MemRefType>(source.getType()) ? memref::getMixedSizes(builder, loc, source) : tensor::getMixedSizes(builder, loc, source); - auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(), - inputScalableVecDims); + auto maskType = vecToReadTy.cloneWith(/*shape=*/{}, builder.getI1Type()); Value mask = vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims); return mlir::vector::maskOperation(builder, transferReadOp, mask) diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt index 9f57627..cb1e9d0 100644 --- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..f4c9f8a --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRX86VectorTransformOps + X86VectorTransformOps.cpp + + DEPENDS + MLIRX86VectorTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRVectorDialect + MLIRSideEffectInterfaces + MLIRTransformDialect + MLIRTransformDialectUtils + MLIRX86VectorDialect + MLIRX86VectorTransforms + ) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp new file mode 100644 index 0000000..95db208 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp @@ -0,0 +1,64 @@ +//===- X86VectorTransformOps.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/RegionKindInterface.h" + +using namespace mlir; +using namespace mlir::x86vector; +using namespace mlir::transform; + +void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + x86vector::populateVectorContractToFMAPatterns(patterns); +} + +void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns); +} + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class X86VectorTransformDialectExtension + : public transform::TransformDialectExtension< + X86VectorTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + X86VectorTransformDialectExtension) + + X86VectorTransformDialectExtension() { + declareGeneratedDialect<x86vector::X86VectorDialect>(); + declareGeneratedDialect<LLVM::LLVMDialect>(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" + +void mlir::x86vector::registerTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions<X86VectorTransformDialectExtension>(); +} diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt index c51266a..2cab50f 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt @@ -1,11 +1,14 @@ add_mlir_dialect_library(MLIRX86VectorTransforms AVXTranspose.cpp LegalizeForLLVMExport.cpp + VectorContractToFMA.cpp + VectorContractToPackedTypeDotProduct.cpp LINK_LIBS PUBLIC MLIRArithDialect MLIRX86VectorDialect MLIRIR + MLIRLinalgDialect MLIRLLVMCommonConversion MLIRLLVMDialect MLIRVectorDialect diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp new file mode 100644 index 0000000..f3af5ca --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp @@ -0,0 +1,143 @@ +//===- VectorContractToFMA.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::x86vector; + +namespace { + +// Implements outer product contraction as a sequence of broadcast and +// FMA operations. +// +// For example - for F32 type: +// ``` +// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32> +// ``` +// to +// ``` +// vector.broadcast %lhs to <16xf32> +// vector.fma vector<16xf32> +// ``` +struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> { + using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + + if (contractOp.getKind() != vector::CombiningKind::ADD) + return rewriter.notifyMatchFailure(contractOp, + "Expects add combining kind."); + + VectorType lhsTy = contractOp.getLhsType(); + if (!lhsTy.getElementType().isF32()) + return rewriter.notifyMatchFailure(contractOp, + "Only F32 lowering is supported."); + + ArrayRef<int64_t> lhsShape = lhsTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimLhs; + llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs), + [](int64_t dim) { return dim != 1; }); + + VectorType rhsTy = contractOp.getRhsType(); + ArrayRef<int64_t> rhsShape = rhsTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimRhs; + llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs), + [](int64_t dim) { return dim != 1; }); + + if (nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0) + return rewriter.notifyMatchFailure( + contractOp, "Excepts unit dimensions for either LHS or RHS shape."); + + if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1) + return rewriter.notifyMatchFailure( + contractOp, + "Excepts a one non-unit A/B dimension for either LHS or RHS shape."); + + VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType()); + if (!accTy) + return rewriter.notifyMatchFailure(contractOp, + "Accmulator is not a vector type"); + + if (!accTy.getElementType().isF32()) + return rewriter.notifyMatchFailure(contractOp, + "Accmulator should be F32 type."); + + ArrayRef<int64_t> accShape = accTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimAcc; + llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc), + [](int64_t dim) { return dim != 1; }); + if (nonUnitDimAcc.size() != 1) + return rewriter.notifyMatchFailure( + contractOp, "A or B dimension should be non-unit."); + + // Lowers vector.contract into a broadcast+FMA sequence. + auto loc = contractOp.getLoc(); + auto castAcc = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()), + contractOp.getAcc()); + + vector::FMAOp fma; + + // Broadcast the unit-dimension LHS or RHS to match the vector length of the + // corresponding non-unit dimension on the other operand. For example, + // if LHS has type vector<1x1xf32> and RHS has type vector<1x16xf32>, we + // broadcast the LHS to vector<1x16xf32>. 
In the opposite case (non-unit + // dimension on the LHS), we broadcast the RHS instead. + if (nonUnitDimRhs.size() > 0) { + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, VectorType::get(1, lhsTy.getElementType()), + contractOp.getLhs()); + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()), + contractOp.getRhs()); + auto broadcastLhs = vector::BroadcastOp::create( + rewriter, loc, castRhs.getResult().getType(), castLhs); + fma = + vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc); + } else { + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()), + contractOp.getLhs()); + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, VectorType::get(1, rhsTy.getElementType()), + contractOp.getRhs()); + auto broadcastRhs = vector::BroadcastOp::create( + rewriter, loc, castLhs.getResult().getType(), castRhs); + fma = + vector::FMAOp::create(rewriter, loc, castLhs, broadcastRhs, castAcc); + } + + auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma); + rewriter.replaceOp(contractOp, castFma); + + return success(); + } +}; + +} // namespace + +void x86vector::populateVectorContractToFMAPatterns( + RewritePatternSet &patterns) { + patterns.add<VectorContractToFMA>(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp new file mode 100644 index 0000000..1e64811 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp @@ -0,0 +1,301 @@ +//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::x86vector; + +namespace { + +static FailureOr<SmallVector<mlir::utils::IteratorType>> +inferIteratorsFromOutMap(AffineMap map) { + if (!map.isProjectedPermutation()) + return failure(); + SmallVector<mlir::utils::IteratorType> iterators( + map.getNumDims(), mlir::utils::IteratorType::reduction); + for (auto expr : map.getResults()) + if (auto dim = dyn_cast<AffineDimExpr>(expr)) + iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel; + return iterators; +} + +// Returns true if the operation is in VNNI layout. +// Optionally, the check can be constrained to a specific VNNI blocking factor. +static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps, + std::optional<unsigned> blockingFactor) { + // Narrow down type operations - VNNI only applies to contractions. 
+ FailureOr<linalg::ContractionDimensions> dims = + linalg::inferContractionDims(indexingMaps); + if (failed(dims)) + return false; + + auto matA = op->getOperand(0); + auto matB = op->getOperand(1); + auto typeA = dyn_cast<ShapedType>(matA.getType()); + auto typeB = dyn_cast<ShapedType>(matB.getType()); + unsigned rankA = typeA.getRank(); + unsigned rankB = typeB.getRank(); + // VNNI format requires at least 1 parallel and 2 reduction dimensions. + if (rankA < 3 || rankB < 3) + return false; + + // At least two reduction dimensions are expected: + // one for the VNNI factor and one for the K dimension + if (dims->k.size() < 2) + return false; + + // Validate affine maps - VNNI computation should be defined by the two + // innermost reduction iterators. + // The input matrix dimensions layout must match the following: + // - matrix A - [...][K/vnniFactor][vnniFactor] + // - matrix B - [...][K/vnniFactor][N][vnniFactor] + auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]); + if (failed(maybeIters)) + return false; + SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters; + AffineMap mapA = indexingMaps[0]; + AffineMap mapB = indexingMaps[1]; + + auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1)); + auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1)); + if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB || + iteratorTypes[vnniDimA.getPosition()] != + mlir::utils::IteratorType::reduction) + return false; + auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2)); + auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3)); + if (!redDimA || !redDimB || redDimA != redDimB || + iteratorTypes[redDimA.getPosition()] != + mlir::utils::IteratorType::reduction) + return false; + auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2)); + if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] != + mlir::utils::IteratorType::parallel) + return false; + + // VNNI factor must be: + // - the innermost inputs' dimension + // - statically known + // - multiple of 2 or equal to the specified factor + auto vnniDimSize = typeB.getShape().back(); + if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 || + vnniDimSize % 2 != 0) + return false; + if (typeA.getShape().back() != vnniDimSize) + return false; + if (blockingFactor && vnniDimSize != *blockingFactor) + return false; + + // The split reduction dimension size should also match. + if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3]) + return false; + + return true; +} + +// Implements packed type outer product contraction as a sequence +// of broadcast and packed dot-product operations. 
+// +// For example - for F32 type: +// ``` +// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32> +// ``` +// to +// ``` +// vector.broadcast %lhs to <32xbf16> +// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32> +// ``` +struct VectorContractToPackedTypeDotProduct + : public OpRewritePattern<vector::ContractionOp> { + using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + + if (contractOp.getKind() != vector::CombiningKind::ADD) + return rewriter.notifyMatchFailure(contractOp, + "Expects add combining kind."); + + VectorType lhsTy = contractOp.getLhsType(); + if (!lhsTy.getElementType().isBF16() && + !lhsTy.getElementType().isSignlessInteger(8)) + return rewriter.notifyMatchFailure( + contractOp, "Only BF16/Int8 lowering is supported."); + + unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4; + if (!isInVnniLayout(contractOp.getOperation(), + contractOp.getIndexingMapsArray(), blockingFactor)) + return rewriter.notifyMatchFailure(contractOp, + "Input matrices not in VNNI format."); + + ArrayRef<int64_t> lhsShape = lhsTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimLhs; + llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs), + [](int64_t dim) { return dim != 1; }); + + VectorType rhsTy = contractOp.getRhsType(); + ArrayRef<int64_t> rhsShape = rhsTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimRhs; + llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs), + [](int64_t dim) { return dim != 1; }); + + if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0) + return rewriter.notifyMatchFailure(contractOp, + "Excepts unit dimensions for either " + "LHS or RHS shape other than VNNI."); + + if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1) + return rewriter.notifyMatchFailure( + contractOp, + "Excepts a one non-unit A/B dimension for either LHS or RHS shape."); + + VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType()); + if (!accTy) + return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type."); + + if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) || + (lhsTy.getElementType().isSignlessInteger(8) && + !accTy.getElementType().isSignlessInteger(32))) + return rewriter.notifyMatchFailure(contractOp, + "Only F32 for BF16 or Int32 for Int8 " + "accumulation type is supported."); + + ArrayRef<int64_t> accShape = accTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimAcc; + llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc), + [](int64_t dim) { return dim != 1; }); + if (nonUnitDimAcc.size() != 1) + return rewriter.notifyMatchFailure( + contractOp, "A or B should be a non-unit dim in acc."); + + // Non-unit dimensions should match the vector length of BF16 or Int8 + // dot-product. + unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? 
nonUnitDimLhs.front() + : nonUnitDimRhs.front(); + if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 && + nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim) + return rewriter.notifyMatchFailure( + contractOp, "BF16 dot-product operation expects non-unit (LHR or " + "RHS) dim and acc dim of size 4/8/16."); + + if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 && + nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim) + return rewriter.notifyMatchFailure( + contractOp, "Int8 dot-product operation expects non-unit (LHR or " + "RHS) dim and acc dim of size 4/8."); + + auto loc = contractOp.getLoc(); + auto castAcc = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()), + contractOp.getAcc()); + + Value dp; + + // Broadcast the unit-dimension LHS or RHS to match the vector length of the + // corresponding non-unit dimension on the other operand. For example, + // if LHS has type vector<1x1x2xbf16> and RHS has type vector<1x16x2xbf16>, + // we broadcast the LHS to vector<16x2xbf16>. In the opposite case (non-unit + // dimension on the LHS), we broadcast the RHS instead. + if ((nonUnitDimRhs.size() - 1) > 0) { + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(), + rhsTy.getElementType()), + contractOp.getRhs()); + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()), + contractOp.getLhs()); + auto bitcastLhs = vector::BitCastOp::create( + rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)), + castLhs); + auto broadcastLhs = vector::BroadcastOp::create( + rewriter, loc, + VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)), + bitcastLhs); + auto bitcastLhsPkType = vector::BitCastOp::create( + rewriter, loc, castRhs.getResult().getType(), broadcastLhs); + + if (lhsTy.getElementType().isBF16()) { + dp = x86vector::DotBF16Op::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()), + castAcc, bitcastLhsPkType, castRhs); + } + + if (lhsTy.getElementType().isSignlessInteger(8)) { + dp = x86vector::DotInt8Op::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)), + castAcc, bitcastLhsPkType, castRhs); + } + } else { + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(), + lhsTy.getElementType()), + contractOp.getLhs()); + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()), + contractOp.getRhs()); + auto bitcastRhs = vector::BitCastOp::create( + rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)), + castRhs); + auto broadcastRhs = vector::BroadcastOp::create( + rewriter, loc, + VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)), + bitcastRhs); + auto bitcastRhsPkType = vector::BitCastOp::create( + rewriter, loc, castLhs.getResult().getType(), broadcastRhs); + + if (lhsTy.getElementType().isBF16()) { + dp = x86vector::DotBF16Op::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()), + castAcc, castLhs, bitcastRhsPkType); + } + + if (lhsTy.getElementType().isSignlessInteger(8)) { + dp = x86vector::DotInt8Op::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)), + castAcc, castLhs, bitcastRhsPkType); + } 
+ } + + if (!dp) + return failure(); + + auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp); + rewriter.replaceOp(contractOp, castDp); + return success(); + } +}; + +} // namespace + +void x86vector::populateVectorContractToPackedTypeDotProductPatterns( + RewritePatternSet &patterns) { + patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt index 31167e6..46b8251 100644 --- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) add_subdirectory(Transforms) add_subdirectory(Utils) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 6b4c185..1a19ab5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -8,10 +8,8 @@ #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" @@ -38,55 +36,61 @@ void XeGPUDialect::initialize() { >(); } -/// Generates instructions to compute offsets for a subgroup identified by -/// its multidimensional indices (sgId), using the specified subgroup layout -/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data -/// dimensions (sizePerWg). +// A `srcShape` consists of N distribution units, each being `subShapesLayout` x +// `subShape`. A `delinearizedId` is used to identify a particular `subShape` +// within each distribution unit. +// Example: +// WG data is 128x256. SG data is 16x32, in 4x2 layout, this gives a +// distribution unit of shape 64x64, we have 2x4 such distribution units. +// `delinearizedId` is used to identify a 16x32 of a subgroup in each +// distribution unit. static SmallVector<SmallVector<Value>> -genOffsetsComputingInsts(OpBuilder &builder, Location loc, - SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout, - ArrayRef<int64_t> sizePerSg, - ArrayRef<int64_t> sizePerWg) { - - SmallVector<SmallVector<Value>> offsets; +genCoordinates(OpBuilder &builder, Location loc, + SmallVector<Value> delinearizedId, + ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape, + ArrayRef<int64_t> srcShape) { + SmallVector<SmallVector<Value>> coordinates; + + // A distribution unit must be less than or equal to `srcShape` + SmallVector<int64_t> distUnitShape = llvm::map_to_vector( + llvm::zip_equal(srcShape, + computeElementwiseMul(subShapesLayout, subShape)), + [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); - // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i] - SmallVector<Value> localOffsets = llvm::map_to_vector( - llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value { - return builder.createOrFold<index::MulOp>( + // Get the offset of `subShape` within a distribution unit. 
+ SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector( + llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value { + return builder.createOrFold<arith::MulIOp>( loc, std::get<0>(t), builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t))); }); - // distUnit[i] is the minimum value between sizePerWg[i] and - // sgLayout[i] * sizePerSg[i] - SmallVector<int64_t> distUnit = llvm::map_to_vector( - llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)), - [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); - + // For each dist unit for (SmallVector<int64_t> unitOffs : - StaticTileOffsetRange(sizePerWg, distUnit)) { + StaticTileOffsetRange(srcShape, distUnitShape)) { + // Get dist unit offset within `srcShape`. SmallVector<Value> base = llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value { return arith::ConstantIndexOp::create(builder, loc, d); }); - - SmallVector<Value> adds = llvm::map_to_vector( - llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value { - return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t), - std::get<1>(t)); - }); - + // Calculate `subShape` offset within `srcShape`. + SmallVector<Value> adds = + llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset), + [&](const auto &t) -> Value { + return builder.createOrFold<arith::AddIOp>( + loc, std::get<0>(t), std::get<1>(t)); + }); + // Do not go beyond `srcShape` bounds. SmallVector<Value> mods = llvm::map_to_vector( - llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value { - return builder.createOrFold<index::RemUOp>( + llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value { + return builder.createOrFold<arith::RemUIOp>( loc, std::get<0>(t), arith::ConstantIndexOp::create(builder, loc, std::get<1>(t))); }); - offsets.push_back(mods); + coordinates.push_back(mods); } - return offsets; + return coordinates; } // Checks if the given shape can be evenly distributed based on the layout @@ -273,56 +277,197 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, } FailureOr<SmallVector<Value>> -LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, - Value linearId) { - // delinearizeSubgroupId is only available for - // workgroup-level layout attribute - if (!isForWorkgroup()) +LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { + + SmallVector<int64_t> sgLayoutInt; + if (isForWorkgroup()) { + sgLayoutInt = getEffectiveSgLayoutAsInt(); + } else if (isForSubgroup()) { + sgLayoutInt = getEffectiveLaneLayoutAsInt(); + } else { return failure(); + } - // TODO: handle order attribute - auto hasDefaultOrder = [&]() { - DenseI32ArrayAttr order = getOrder(); - return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>( - llvm::reverse(order.asArrayRef()))); - }; - if (!hasDefaultOrder()) - return mlir::emitError(loc, "order attribute is currently not supported."); + DenseI32ArrayAttr orderAttr = getOrder(); - auto dims = - llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value { - return builder.createOrFold<arith::ConstantIndexOp>(loc, d); - }); + // Handle order attribute + SmallVector<int64_t> order; + if (orderAttr && !orderAttr.empty()) { + order = llvm::to_vector( + llvm::map_range(orderAttr.asArrayRef(), + [](int32_t idx) { return static_cast<int64_t>(idx); })); + } else { + // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc. 
+ order = llvm::to_vector( + llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size()))); + } - return affine::delinearizeIndex(builder, loc, linearId, dims); + if (order.size() != sgLayoutInt.size()) { + return failure(); + } + + SmallVector<Value> result(sgLayoutInt.size()); + Value remaining = linearId; + + /// Process dimensions in the order they appear in the order array + /// The first dimension in order is the fastest-changing + /// + /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]: + /// + /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx], + /// result=[?,?,?] + /// + /// i=0 (process columns, dimIdx=2, dimSize=4): + /// result[2] = 22 % 4 = 2 (column coordinate) + /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed) + /// + /// i=1 (process rows, dimIdx=1, dimSize=4): + /// result[1] = 5 % 4 = 1 (row coordinate) + /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed) + /// + /// i=2 (process layers, dimIdx=0, dimSize=2): + /// result[0] = 1 % 2 = 1 (layer coordinate) + /// (no remaining update - last iteration) + /// + /// Final result: [1,1,2] = Layer 1, Row 1, Column 2 + for (size_t i = 0; i < order.size(); ++i) { + int64_t dimIdx = order[i]; + int64_t dimSize = sgLayoutInt[dimIdx]; + + Value dimSizeVal = + builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize); + + /// Extract the coordinate for this dimension using modulo operation + /// This gives us "how far within this dimension" we are + /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within + /// this dimension) + result[dimIdx] = + builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal); + + /// Update remaining for the next dimension by removing what we've already + /// processed. Division tells us "how many complete groups of this dimension + /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've + /// completed 5 groups of 4) Skip this for the last iteration since there's + /// no next dimension to process + if (i < order.size() - 1) { + remaining = + builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal); + } + } + return result; } -/// Implements DistributeLayoutAttr::getOffsets to generate +/// Implements DistributeLayoutAttr::computeDistributedCoords to generate /// instructions for computing multi-dimensional offsets when distributed by /// LayoutAttr. 
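// --- Illustrative sketch (not part of the patch) -----------------------------
// A self-contained version of the delinearization arithmetic implemented by
// LayoutAttr::delinearizeId above, reproducing the documented walkthrough:
// linearId = 22 with layout [2, 4, 4] and order [2, 1, 0] yields [1, 1, 2].
// The plain-C++ setting and the function name are illustrative only.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> delinearize(int64_t linearId,
                                        const std::vector<int64_t> &layout,
                                        const std::vector<int64_t> &order) {
  std::vector<int64_t> coords(layout.size());
  int64_t remaining = linearId;
  for (int64_t dim : order) { // order[0] is the fastest-changing dimension.
    coords[dim] = remaining % layout[dim];
    remaining /= layout[dim];
  }
  return coords;
}

int main() {
  assert((delinearize(22, {2, 4, 4}, {2, 1, 0}) ==
          std::vector<int64_t>{1, 1, 2}));
  return 0;
}
// --- End of sketch ------------------------------------------------------------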
FailureOr<SmallVector<SmallVector<Value>>> -LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, - ArrayRef<int64_t> shape) { - if (!isForWorkgroup()) +LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc, + Value linearId, ArrayRef<int64_t> shape) { + SmallVector<int64_t> layout; + SmallVector<int64_t> subShape; + if (isForWorkgroup()) { + layout = getEffectiveSgLayoutAsInt(); + subShape = getEffectiveSgDataAsInt(); + } else if (isForSubgroup()) { + layout = getEffectiveLaneLayoutAsInt(); + subShape = getEffectiveLaneDataAsInt(); + } else { return failure(); - - SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); - SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); - if (sgShape.empty()) { - if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); + } + if (subShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, layout)) + subShape = derivedShape.value(); else return failure(); } // delinearize Ids - auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + auto maybeIds = delinearizeId(builder, loc, linearId); if (failed(maybeIds)) return failure(); - SmallVector<Value> sgIds = *maybeIds; + SmallVector<Value> ids = *maybeIds; + + return genCoordinates(builder, loc, ids, layout, subShape, shape); +} + +bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { + if (dyn_cast<xegpu::SliceAttr>(other)) + return false; + + return *this == dyn_cast<xegpu::LayoutAttr>(other); +} + +// set the layout for unit dims: sg_data, inst_data and lane_data to 1 +DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) { + auto sgDataOpt = getSgData(); + auto instDataOpt = getInstData(); + auto laneDataOpt = getLaneData(); + + SmallVector<int32_t> sgData; + SmallVector<int32_t> instData; + SmallVector<int32_t> laneData; + + if (sgDataOpt) { + sgData = llvm::to_vector(sgDataOpt.asArrayRef()); + } + if (instDataOpt) { + instData = llvm::to_vector(instDataOpt.asArrayRef()); + } + if (laneDataOpt) { + laneData = llvm::to_vector(laneDataOpt.asArrayRef()); + } + + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(sgData.size())) + sgData[dim] = 1; + if (dim < static_cast<int64_t>(instData.size())) + instData[dim] = 1; + if (dim < static_cast<int64_t>(laneData.size())) + laneData[dim] = 1; + } + + return LayoutAttr::get( + getContext(), getSgLayout(), + sgData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), sgData), + instData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), instData), + getLaneLayout(), + laneData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), laneData), + getOrder()); +} + +// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 +DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) { + auto sgLayoutOpt = getSgLayout(); + auto laneLayoutOpt = getLaneLayout(); + + SmallVector<int32_t> sgLayout; + SmallVector<int32_t> laneLayout; + + if (sgLayoutOpt) { + sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef()); + } + if (laneLayoutOpt) { + laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef()); + } + + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(sgLayout.size())) + sgLayout[dim] = 1; + if (dim < static_cast<int64_t>(laneLayout.size())) + laneLayout[dim] = 1; + } - return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, - shape); + return LayoutAttr::get( + getContext(), + sgLayout.empty() ? 
DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), sgLayout), + getSgData(), getInstData(), + laneLayout.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), laneLayout), + getLaneData(), getOrder()); } //===----------------------------------------------------------------------===// @@ -376,34 +521,43 @@ SliceAttr SliceAttr::flatten() const { } FailureOr<SmallVector<Value>> -SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, - Value linearId) { +SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { SliceAttr attr = flatten(); auto parent = dyn_cast<LayoutAttr>(attr.getParent()); - return parent.delinearizeSubgroupId(builder, loc, linearId); + return parent.delinearizeId(builder, loc, linearId); } -/// Implements DistributeLayoutAttr::getOffsets to generate -/// instructions for computing multi-dimensional offsets when distributed by -/// SliceAttr. +// Implements DistributeLayoutAttr::computeDistributedCoords to generate +// instructions for computing multi-dimensional offsets when distributed by +// LayoutAttr. FailureOr<SmallVector<SmallVector<Value>>> -SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, - ArrayRef<int64_t> shape) { +SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc, + Value linearId, ArrayRef<int64_t> shape) { assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape."); if (!isForWorkgroup()) return failure(); - SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); - SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); - if (sgShape.empty()) { - if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); + SmallVector<int64_t> layout; + SmallVector<int64_t> subShape; + if (isForWorkgroup()) { + layout = getEffectiveSgLayoutAsInt(); + subShape = getEffectiveSgDataAsInt(); + } else if (isForSubgroup()) { + layout = getEffectiveLaneLayoutAsInt(); + subShape = getEffectiveLaneDataAsInt(); + } else { + return failure(); + } + + if (subShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, layout)) + subShape = derivedShape.value(); else return failure(); } // delinearize Ids - auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + auto maybeIds = delinearizeId(builder, loc, linearId); if (failed(maybeIds)) return failure(); @@ -413,8 +567,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, SmallVector<Value> sgIds = XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims); - return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, - shape); + return genCoordinates(builder, loc, sgIds, layout, subShape, shape); } bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { @@ -437,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { [&](int64_t dim) { return thisDims.contains(dim); }); } +bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { + if (dyn_cast<xegpu::LayoutAttr>(other)) + return false; + + auto flattenedThis = flatten(); + auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten(); + + return ((flattenedThis.getParent() == flattenedOther.getParent()) && + (flattenedThis.getDims() == flattenedOther.getDims())); +} + +// Helper function to adjust unit dimensions from sliced space to parent space +static SetVector<int64_t> +adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims, + ArrayRef<int64_t> sliceDims) { + // Reconstruct parent's non-sliced dimensions + + int64_t 
parentRank = sliceDims.size() + unitDims.size(); + llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(), + sliceDims.end()); + SmallVector<int64_t> nonSlicedDims; + for (int64_t i = 0; i < parentRank; ++i) { + if (!slicedDimsSet.contains(i)) + nonSlicedDims.push_back(i); + } + + // Map unit dims from sliced space to parent space + SetVector<int64_t> adjustUnitDims; + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(nonSlicedDims.size())) { + adjustUnitDims.insert(nonSlicedDims[dim]); + } + } + + return adjustUnitDims; +} + +// set the layout for unit dims: sg_data, inst_data and lane_data to 1 +DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) { + SliceAttr attr = flatten(); + ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + + SetVector<int64_t> adjustUnitDims = + adjustUnitDimsWithSliceDims(unitDims, sliceDims); + + return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims), + attr.getDims()); +} + +// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 +DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) { + SliceAttr attr = flatten(); + ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + + SetVector<int64_t> adjustUnitDims = + adjustUnitDimsWithSliceDims(unitDims, sliceDims); + + return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims), + attr.getDims()); +} + //===----------------------------------------------------------------------===// // XeGPU_RangeAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index abd12e2..91ba07a 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, - UnitAttr subgroup_block_io, + UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref<InFlightDiagnostic()> emitError) { if (!dataTy) { if (subgroup_block_io) return emitError() << "subgroup_block_io " - "are only allowed when result is a 1D VectorType."; + "are only allowed when result is a VectorType."; else return success(); } @@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, ArrayRef<int64_t> dataShape = dataTy.getShape(); ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); + ArrayAttr strideAttr = mdescTy.getStrideAttr(); + SmallVector<int64_t> strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast<IntegerAttr>(attr).getInt()); + } + if (subgroup_block_io && layout) { + auto laneData = layout.getEffectiveLaneDataAsInt(); + auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); + if (!laneData.empty()) { + bool isLaneDataContiguous = + std::all_of(laneData.begin(), std::prev(laneData.end()), + [](int x) { return x == 1; }); + if (!isLaneDataContiguous) + return emitError() << "With subgroup_block_io, accessed data must be " + "contiguous and coalesced."; + for (size_t i = 0; i < laneData.size(); ++i) { + if (laneLayout[i] != blockShape[i]) + return emitError() << "With subgroup_block_io, the block shape must " + "match the lane layout."; + if (laneLayout[i] != 1 && strides[i] != 1) + return 
emitError() << "With subgroup_block_io, the distributed " + "dimensions must be contiguous."; + } + } + } if (dataShape.size() == 2) { - if (subgroup_block_io) - return emitError() << "subgroup_block_io " - "are only allowed when result is a 1D VectorType."; if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), [](auto p) { return std::get<0>(p) > std::get<1>(p); })) return emitError() << "data shape must not exceed mem_desc shape."; } else { - SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); // if the subgroup_block_io attribute is set, mdescTy must have block // attribute if (subgroup_block_io && !blockShape.size()) @@ -258,8 +280,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); // if shape and strides are from Memref, we don't need attributes for them - // to keep the IR print clean. - if (staticShape == memrefShape && staticStrides == memrefStrides) { + // to keep the IR print clean (only do so for full-static case, otherwise + // printer would fail trying to print empty array-attr). + if (staticShape == memrefShape && staticStrides == memrefStrides && + dynamicShape.empty() && dynamicStrides.empty()) { staticShapeAttr = DenseI64ArrayAttr(); staticStridesAttr = DenseI64ArrayAttr(); } @@ -320,8 +344,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); // if shape and strides are from Memref, we don't need attributes for them - // to keep the IR print clean. - if (staticShape == memrefShape && staticStrides == memrefStrides) { + // to keep the IR print clean (only do so for full-static case, otherwise + // printer would fail trying to print empty array-attr). + if (staticShape == memrefShape && staticStrides == memrefStrides && + dynamicShape.empty() && dynamicStrides.empty()) { staticShapeAttr = DenseI64ArrayAttr(); staticStridesAttr = DenseI64ArrayAttr(); } @@ -439,14 +465,15 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l3_hint) { return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, Value tensorDesc, ArrayRef<OpFoldResult> offsets, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -454,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint, - l2_hint, l3_hint); + l2_hint, l3_hint, /*anchor_layout=*/layout); } LogicalResult PrefetchNdOp::verify() { @@ -472,11 +499,8 @@ LogicalResult PrefetchNdOp::verify() { return emitOpError("invalid l3_hint: ") << getL3HintAttr(); int64_t tDescRank = tdescTy.getRank(); - int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); - int64_t constOffsetSize = - getConstOffsetsAttr() ? 
getConstOffsetsAttr().size() : 0; - if (((offsetSize != 0) && (offsetSize != tDescRank)) || - ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + int64_t offsetSize = getMixedOffsets().size(); + if (offsetSize != 0 && offsetSize != tDescRank) return emitOpError( "Mismatched ranks between offsets and tensor descriptor"); @@ -496,7 +520,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, return build(builder, state, retType, tensorDesc, ValueRange(), DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, - l3_hint); + l3_hint, /*anchor_layout=*/nullptr); } void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, @@ -504,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, UnitAttr packed, DenseI64ArrayAttr transpose, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -512,7 +537,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr, - packed, transpose, l1_hint, l2_hint, l3_hint); + packed, transpose, l1_hint, l2_hint, l3_hint, + /*anchor_layout=*/layout); } LogicalResult LoadNdOp::verify() { @@ -597,11 +623,8 @@ LogicalResult LoadNdOp::verify() { << tdescTy; int64_t tDescRank = tdescTy.getRank(); - int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); - int64_t constOffsetSize = - getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0; - if (((offsetSize != 0) && (offsetSize != tDescRank)) || - ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + int64_t offsetSize = getMixedOffsets().size(); + if (offsetSize != 0 && offsetSize != tDescRank) return emitOpError( "Mismatched ranks between offsets and tensor descriptor"); @@ -618,14 +641,16 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, xegpu::CachePolicyAttr l3_hint) { return build(builder, state, value, tensorDesc, ValueRange(), - DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint); + DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint, + /*anchor_layout=*/nullptr); } void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, Value tensorDesc, ArrayRef<OpFoldResult> offsets, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -633,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr, - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout); } LogicalResult StoreNdOp::verify() { @@ -691,11 +716,8 @@ LogicalResult StoreNdOp::verify() { << dstTy; int64_t tDescRank = dstTy.getRank(); - int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); - int64_t constOffsetSize = - getConstOffsetsAttr() ? 
getConstOffsetsAttr().size() : 0; - if (((offsetSize != 0) && (offsetSize != tDescRank)) || - ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + int64_t offsetSize = getMixedOffsets().size(); + if (offsetSize != 0 && offsetSize != tDescRank) return emitOpError( "Mismatched ranks between offsets and tensor descriptor"); @@ -809,7 +831,7 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint, - IntegerAttr{}); + IntegerAttr{}, /*anchor_layout=*/nullptr); } //===----------------------------------------------------------------------===// @@ -859,7 +881,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, valueType, source, Value(), mask, IntegerAttr(), - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void LoadGatherOp::build(OpBuilder &builder, OperationState &state, @@ -875,7 +897,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, auto offset = vector::FromElementsOp::create(builder, loc, type, values); build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, - l2_hint, l3_hint); + l2_hint, l3_hint, /*anchor_layout=*/nullptr); +} + +void LoadGatherOp::build(OpBuilder &builder, OperationState &state, + Type valueType, Value source, + ArrayRef<OpFoldResult> offsets, Value mask, + IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint, + DistributeLayoutAttr layout) { + auto loc = source.getLoc(); + int64_t size = static_cast<int64_t>(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, + l2_hint, l3_hint, layout); } //===----------------------------------------------------------------------===// @@ -926,7 +965,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint, - l2_hint, l3_hint); + l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void StoreScatterOp::build(OpBuilder &builder, OperationState &state, @@ -944,7 +983,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, // Call the correct builder overload that does not expect result types. build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, - l3_hint); + l3_hint, /*anchor_layout=*/nullptr); +} + +void StoreScatterOp::build( + OpBuilder &builder, OperationState &state, Value value, Value dest, + ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size, + xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) { + auto loc = dest.getLoc(); + int64_t size = static_cast<int64_t>(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + // Call the correct builder overload that does not expect result types. 
+ build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, + l3_hint, layout); } //===----------------------------------------------------------------------===// @@ -1105,7 +1160,7 @@ LogicalResult LoadMatrixOp::verify() { MemDescType mdescTy = getMemDesc().getType(); return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + getLayoutAttr(), [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1129,7 +1184,7 @@ LogicalResult StoreMatrixOp::verify() { UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); MemDescType mdescTy = getMemDesc().getType(); return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + getLayoutAttr(), [&]() { return emitError(); }); } namespace mlir { diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..48fe841 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRXeGPUTransformOps + XeGPUTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/ + + DEPENDS + MLIRXeGPUTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRXeGPUDialect + MLIRXeGPUTransforms + MLIRIR + MLIRTransformDialect + MLIRFuncDialect + MLIRSCFDialect +) diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp new file mode 100644 index 0000000..e6009d5 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -0,0 +1,695 @@ +//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" + +#include <optional> + +#include "llvm/Support/DebugLog.h" +#define DEBUG_TYPE "xegpu-transforms" + +using namespace mlir; +using namespace mlir::transform; + +/// Assuming that `ofr` is an index attr or a param of index type +/// or a transform dialect handle mapped to exactly one op +/// with one index result, get that value and cast it to int type. +static DiagnosedSilenceableFailure convertMixedValuesToInt( + transform::TransformState &state, TransformOpInterface transformOp, + SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) { + for (OpFoldResult ofr : ofrs) { + // Attribute case. + if (auto attr = dyn_cast<Attribute>(ofr)) { + if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { + result.push_back(intAttr.getInt()); + continue; + } + return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; + } + + // Transform param case. 
+ Value transformValue = cast<Value>(ofr); + if (isa<TransformParamTypeInterface>(transformValue.getType())) { + ArrayRef<Attribute> params = state.getParams(transformValue); + if (params.size() != 1) + return transformOp.emitDefiniteFailure() + << "requires exactly one parameter associated"; + result.push_back( + cast<IntegerAttr>(params.front()).getValue().getSExtValue()); + continue; + } + + // Payload value case. + auto payloadOps = state.getPayloadOps(transformValue); + if (!llvm::hasSingleElement(payloadOps)) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "handle must be mapped to exactly one payload op"; + diag.attachNote(transformValue.getLoc()) + << "mapped to " << llvm::range_size(payloadOps) << " payload ops"; + return diag; + } + + Operation *op = *payloadOps.begin(); + if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "payload op must have exactly 1 index result"; + diag.attachNote(op->getLoc()) + << "has " << op->getNumResults() << " results"; + return diag; + } + + IntegerAttr intAttr; + if (!matchPattern(op->getResult(0), m_Constant(&intAttr))) + return transformOp.emitSilenceableError() + << "requires param or handle to be the result of a constant like " + "op"; + + result.push_back(intAttr.getInt()); + } + return DiagnosedSilenceableFailure::success(); +} + +/// Find producer operation of type T for the given value. +/// It's assumed that producer ops are chained through their first operand. +/// Producer chain is traced trough loop block arguments (init values). +template <typename T> +static std::optional<T> findProducerOfType(Value val) { + Value currentValue = val; + if (!currentValue.getDefiningOp()) { + // Value may be a block argument initialized outside a loop. + if (val.getNumUses() == 0) { + LDBG() << "Failed to find producer op, value has no uses."; + return std::nullopt; + } + auto userOp = val.getUsers().begin(); + auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>(); + if (!parentLoop) { + LDBG() << "Failed to find producer op, not in a loop."; + return std::nullopt; + } + int64_t iterArgIdx; + if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) { + auto numInductionVars = parentLoop.getLoopInductionVars()->size(); + iterArgIdx = iterArg.getArgNumber() - numInductionVars; + currentValue = parentLoop.getInits()[iterArgIdx]; + } else { + LDBG() << "Failed to find producer op, value not in init values."; + return std::nullopt; + } + } + Operation *producerOp = currentValue.getDefiningOp(); + + if (auto matchingOp = dyn_cast<T>(producerOp)) + return matchingOp; + + if (producerOp->getNumOperands() == 0) + return std::nullopt; + + return findProducerOfType<T>(producerOp->getOperand(0)); +} + +/// Create a layout attribute from the given parameters. +static xegpu::LayoutAttr +createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout, + ArrayRef<int32_t> sgData, + std::optional<ArrayRef<int32_t>> instData) { + return xegpu::LayoutAttr::get( + ctx, DenseI32ArrayAttr::get(ctx, sgLayout), + DenseI32ArrayAttr::get(ctx, sgData), + instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr, + /*lane_layout=*/nullptr, + /*lane_data=*/nullptr, + /*order=*/nullptr); +} + +/// Generate `xegpu::LayoutAttr` from op mixed layout values. 
+DiagnosedSilenceableFailure +getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, + TransformOpInterface transformOp, + ArrayRef<::mlir::OpFoldResult> mixedSgLayout, + ArrayRef<::mlir::OpFoldResult> mixedSgData, + ArrayRef<::mlir::OpFoldResult> mixedInstData, + xegpu::LayoutAttr &layoutAttr) { + SmallVector<int32_t> sgLayout, sgData, instData; + auto status = + convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout); + if (!status.succeeded()) + return status; + + status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData); + if (!status.succeeded()) + return status; + + status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData); + if (!status.succeeded()) + return status; + auto maybeInstData = instData.empty() + ? std::nullopt + : std::optional<ArrayRef<int32_t>>(instData); + + layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData); + + return DiagnosedSilenceableFailure::success(); +} + +/// Replace xegpu.create_nd_desc op with a new one with the given layout. +static xegpu::CreateNdDescOp +setDescLayout(transform::TransformRewriter &rewriter, + xegpu::CreateNdDescOp descOp, + xegpu::DistributeLayoutAttr layout) { + assert(descOp.getMixedOffsets().size() == 0 && + "create desc op with offsets is not supported"); + auto oldTensorDesc = descOp.getType(); + auto descType = xegpu::TensorDescType::get( + oldTensorDesc.getShape(), oldTensorDesc.getElementType(), + /*array_length=*/oldTensorDesc.getArrayLength(), + /*boundary_check=*/oldTensorDesc.getBoundaryCheck(), + /*memory_space=*/oldTensorDesc.getMemorySpace(), + /*layout=*/layout); + + rewriter.setInsertionPointAfter(descOp); + auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>( + descOp, descType, descOp.getSource(), descOp.getMixedSizes(), + descOp.getMixedStrides()); + return newDescOp; +} + +DiagnosedSilenceableFailure +transform::GetDescOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) { + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + } + + auto maybeDescOp = + findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin()); + if (!maybeDescOp) { + return emitSilenceableFailure(getLoc()) + << "Could not find a matching descriptor op when walking the " + "producer chain of the first operand."; + } + + results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp}); + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetDescLayoutOp::build(OpBuilder &builder, + OperationState &result, Value target, + ArrayRef<OpFoldResult> mixedSgLayout, + ArrayRef<OpFoldResult> mixedSgData, + ArrayRef<OpFoldResult> mixedInstData, + ArrayRef<int64_t> sliceDims) { + SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; + SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; + dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); + dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); + dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); + build(builder, result, target.getType(), + /*target=*/target, + /*sg_layout=*/dynamicSgLayout, + /*sg_data=*/dynamicSgData, + /*inst_data=*/dynamicInstData, + /*static_sg_layout=*/staticSgLayout, + /*static_sg_data=*/staticSgData, + 
/*static_inst_data=*/staticInstData, + /*slice_dims=*/sliceDims); +} + +DiagnosedSilenceableFailure +transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(targetOps)) { + return emitDefiniteFailure() << "requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + Operation *target = *targetOps.begin(); + + xegpu::LayoutAttr layoutAttr = nullptr; + auto status = getLayoutAttrFromOperands(getContext(), state, (*this), + getMixedSgLayout(), getMixedSgData(), + getMixedInstData(), layoutAttr); + if (!status.succeeded()) + return status; + + xegpu::DistributeLayoutAttr layout = layoutAttr; + auto sliceDims = getSliceDims(); + if (sliceDims.size() > 0) { + // Wrap layoutAttr in a slice attribute. + layout = xegpu::SliceAttr::get( + getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); + } + + // For now only create_nd_desc op is supported. + auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target); + if (!descOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a xegpu.create_nd_desc op, but got: " + << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + // Set layout attr in desc op's return type. Replaces old desc op. + auto newdescOp = setDescLayout(rewriter, descOp, layout); + + // Map result handles. + results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()}); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetDescLayoutOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getTargetMutable(), effects); + onlyReadsHandle(getSgLayoutMutable(), effects); + onlyReadsHandle(getSgDataMutable(), effects); + onlyReadsHandle(getInstDataMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +void transform::SetOpLayoutAttrOp::build( + OpBuilder &builder, OperationState &ostate, Value target, int64_t index, + ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData, + ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims, + bool result) { + SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; + SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; + dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); + dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); + dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); + build(builder, ostate, target.getType(), + /*target=*/target, + /*index=*/index, + /*sg_layout=*/dynamicSgLayout, + /*sg_data=*/dynamicSgData, + /*inst_data=*/dynamicInstData, + /*static_sg_layout=*/staticSgLayout, + /*static_sg_data=*/staticSgData, + /*static_inst_data=*/staticInstData, + /*slice_dims=*/sliceDims, + /*result=*/result); +} + +DiagnosedSilenceableFailure +transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(targetOps)) { + return emitDefiniteFailure() << "Requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + Operation *target = *targetOps.begin(); + + bool resultTarget = getResult(); + + int64_t index = getIndex(); + if 
(resultTarget && index >= target->getNumResults()) { + return emitSilenceableFailure(getLoc()) + << "Index exceeds the number of op results"; + } + if (!resultTarget && index >= target->getNumOperands()) { + return emitSilenceableFailure(getLoc()) + << "Index exceeds the number of op operands"; + } + + xegpu::LayoutAttr layoutAttr = nullptr; + auto status = getLayoutAttrFromOperands(getContext(), state, (*this), + getMixedSgLayout(), getMixedSgData(), + getMixedInstData(), layoutAttr); + if (!status.succeeded()) + return status; + + xegpu::DistributeLayoutAttr layout = layoutAttr; + auto sliceDims = getSliceDims(); + if (sliceDims.size() > 0) { + // Wrap layoutAttr in a slice attribute. + layout = xegpu::SliceAttr::get( + getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); + } + + // Set layout attribute for the op result or operand + if (resultTarget) + xegpu::setDistributeLayoutAttr(target->getResult(index), layout); + else + xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout); + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetOpLayoutAttrOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getSgLayoutMutable(), effects); + onlyReadsHandle(getSgDataMutable(), effects); + onlyReadsHandle(getInstDataMutable(), effects); + modifiesPayload(effects); +} + +void transform::SetGPULaunchThreadsOp::build( + OpBuilder &builder, OperationState &ostate, Value target, + ArrayRef<OpFoldResult> mixedThreads) { + SmallVector<int64_t> staticThreads; + SmallVector<Value> dynamicThreads; + dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads); + build(builder, ostate, target.getType(), + /*target=*/target, + /*threads=*/dynamicThreads, + /*static_threads=*/staticThreads); +} + +DiagnosedSilenceableFailure +transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(targetOps)) { + return emitDefiniteFailure() << "Requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + Operation *target = *targetOps.begin(); + + auto launchOp = dyn_cast<gpu::LaunchOp>(target); + if (!launchOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a gpu.launch op, but got: " << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + SmallVector<int32_t> threads; + DiagnosedSilenceableFailure status = + convertMixedValuesToInt(state, (*this), threads, getMixedThreads()); + if (!status.succeeded()) + return status; + + if (threads.size() != 3) { + return emitSilenceableFailure(getLoc()) + << "Expected threads argument to consist of three values (got " + << threads.size() << ")"; + } + + rewriter.setInsertionPoint(launchOp); + auto createConstValue = [&](int value) { + return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value); + }; + + // Replace threads in-place. 
+ launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0])); + launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1])); + launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2])); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetGPULaunchThreadsOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getThreadsMutable(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + auto value = *targetValues.begin(); + + int64_t nbPrefetch = getStaticNbPrefetch(); + if (getDynamicNbPrefetch()) { + // Get dynamic prefetch count from transform param or handle. + SmallVector<int32_t> dynamicNbPrefetch; + auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch, + {getDynamicNbPrefetch()}); + if (!status.succeeded()) + return status; + if (dynamicNbPrefetch.size() != 1) + return emitDefiniteFailure() + << "requires exactly one value for dynamic_nb_prefetch"; + nbPrefetch = dynamicNbPrefetch[0]; + } + if (nbPrefetch <= 0) + return emitSilenceableFailure(getLoc()) + << "nb_prefetch must be a positive integer."; + + // Find load operation of the operand. + auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value); + if (!maybeLoadOp) + return emitSilenceableFailure(getLoc()) << "Could not find load op."; + auto loadOp = *maybeLoadOp; + if (loadOp.getMixedOffsets().size() == 0) { + auto diag = emitSilenceableFailure(getLoc()) + << "Load op must have offsets."; + diag.attachNote(loadOp.getLoc()) << "load op"; + return diag; + } + + // Find the parent scf.for loop. + auto forOp = loadOp->getParentOfType<scf::ForOp>(); + if (!forOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Load op is not contained in a scf.for loop."; + diag.attachNote(loadOp.getLoc()) << "load op"; + return diag; + } + + // Find descriptor op. + auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value); + if (!maybeDescOp) + return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; + auto descOp = *maybeDescOp; + if (descOp.getMixedOffsets().size() > 0) { + auto diag = emitSilenceableFailure(getLoc()) + << "desc op with offsets is not supported."; + diag.attachNote(descOp.getLoc()) << "desc op"; + } + + // Clone desc op outside the loop. + rewriter.setInsertionPoint(forOp); + auto newDescOp = + cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation())); + + // Clone reduction loop to emit initial prefetches. + // Compute upper bound of the init loop: start + nbPrefetch * step. 
+ auto nbPrefetchCst = + arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch); + auto nbStep = rewriter.createOrFold<arith::MulIOp>( + forOp.getLoc(), nbPrefetchCst, forOp.getStep()); + auto initUpBound = rewriter.createOrFold<arith::AddIOp>( + forOp.getLoc(), forOp.getLowerBound(), nbStep); + auto initForOp = + scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), + initUpBound, forOp.getStep()); + + auto ctx = rewriter.getContext(); + auto readCacheHint = + xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED); + + // Modify loadOp mixedOffsets by replacing the for loop induction variable + // with the given value. + auto getPrefetchOffsets = + [&](Value replacementVal) -> SmallVector<OpFoldResult> { + IRMapping mapping; + mapping.map(forOp.getInductionVar(), replacementVal); + SmallVector<Value> dynamicOffsets = + llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) { + return mapping.lookupOrDefault(v); + })); + auto constOffsets = loadOp.getConstOffsets().value(); + return getMixedValues(constOffsets, dynamicOffsets, ctx); + }; + + // Insert prefetch op in init loop. + // Replace induction var with the init loop induction var. + rewriter.setInsertionPointToStart(initForOp.getBody()); + xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(), + newDescOp.getResult(), + getPrefetchOffsets(initForOp.getInductionVar()), + readCacheHint, readCacheHint, readCacheHint, + /*layout=*/nullptr); + + // Insert prefetch op in main loop. + // Calculate prefetch offset after the init prefetches have been issued. + rewriter.setInsertionPointToStart(forOp.getBody()); + auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(), + forOp.getInductionVar(), nbStep); + // Replace induction var with correct offset. + xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(), + newDescOp.getResult(), + getPrefetchOffsets(prefetchOffset), readCacheHint, + readCacheHint, readCacheHint, /*layout=*/nullptr); + + // Unroll the init loop. 
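// --- Illustrative sketch (not part of the patch) -----------------------------
// A self-contained view of the prefetch schedule InsertPrefetchOp builds
// above: an init loop issues the first `nbPrefetch` prefetches, and the main
// loop then prefetches `nbPrefetch` iterations ahead of the current one. The
// concrete values (lb = 0, step = 16, nbPrefetch = 2) are illustrative only.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t lb = 0, step = 16, nbPrefetch = 2;
  const int64_t initUpperBound = lb + nbPrefetch * step; // 32

  std::vector<int64_t> initPrefetchOffsets;
  for (int64_t iv = lb; iv < initUpperBound; iv += step)
    initPrefetchOffsets.push_back(iv); // prefetch offsets 0 and 16

  // In the main loop, iteration `iv` prefetches the tile that iteration
  // `iv + nbPrefetch * step` will load.
  auto mainLoopPrefetchOffset = [&](int64_t iv) {
    return iv + nbPrefetch * step;
  };

  assert((initPrefetchOffsets == std::vector<int64_t>{0, 16}));
  assert(mainLoopPrefetchOffset(0) == 32);
  assert(mainLoopPrefetchOffset(16) == 48);
  return 0;
}
// --- End of sketch ------------------------------------------------------------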
+ if (failed(loopUnrollFull(initForOp))) + return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop"; + + results.set(llvm::cast<OpResult>(getResult()), {newDescOp}); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::InsertPrefetchOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getDynamicNbPrefetchMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +void transform::ConvertLayoutOp::build( + OpBuilder &builder, OperationState &ostate, Value target, + ArrayRef<OpFoldResult> mixedInputSgLayout, + ArrayRef<OpFoldResult> mixedInputSgData, + ArrayRef<OpFoldResult> mixedInputInstData, + ArrayRef<OpFoldResult> mixedTargetSgLayout, + ArrayRef<OpFoldResult> mixedTargetSgData, + ArrayRef<OpFoldResult> mixedTargetInstData) { + SmallVector<int64_t> staticInputSgLayout, staticInputSgData, + staticInputInstData; + SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData, + dynamicInputInstData; + dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout, + staticInputSgLayout); + dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData, + staticInputSgData); + dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData, + staticInputInstData); + SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData, + staticTargetInstData; + SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData, + dynamicTargetInstData; + dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout, + staticTargetSgLayout); + dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData, + staticTargetSgData); + dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData, + staticTargetInstData); + build(builder, ostate, target.getType(), + /*target=*/target, + /*input_sg_layout=*/dynamicInputSgLayout, + /*input_sg_data=*/dynamicInputSgData, + /*input_inst_data=*/dynamicInputInstData, + /*target_sg_layout=*/dynamicTargetSgLayout, + /*target_sg_data=*/dynamicTargetSgData, + /*target_inst_data=*/dynamicTargetInstData, + /*static_input_sg_layout=*/staticInputSgLayout, + /*static_input_sg_data=*/staticInputSgData, + /*static_input_inst_data=*/staticInputInstData, + /*static_target_sg_layout=*/staticTargetSgLayout, + /*static_target_sg_data=*/staticTargetSgData, + /*static_target_inst_data=*/staticTargetInstData); +} + +DiagnosedSilenceableFailure +transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + auto value = *targetValues.begin(); + + // Construct layout attributes. 
+ xegpu::LayoutAttr inputLayoutAttr = nullptr; + auto status = getLayoutAttrFromOperands( + getContext(), state, (*this), getMixedInputSgLayout(), + getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr); + if (!status.succeeded()) + return status; + + xegpu::LayoutAttr targetLayoutAttr = nullptr; + status = getLayoutAttrFromOperands( + getContext(), state, (*this), getMixedTargetSgLayout(), + getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr); + if (!status.succeeded()) + return status; + + // Find first user op to define insertion point for layout conversion. + if (value.use_empty()) + return emitSilenceableFailure(getLoc()) + << "Value has no users to insert layout conversion."; + Operation *userOp = *value.getUsers().begin(); + + // Emit convert_layout op. + rewriter.setInsertionPoint(userOp); + auto convLayoutOp = + xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(), + value, inputLayoutAttr, targetLayoutAttr); + // Replace load op result with the converted layout. + rewriter.replaceUsesWithIf( + value, convLayoutOp.getResult(), [&](OpOperand &use) { + return use.getOwner() != convLayoutOp.getOperation(); + }); + + results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp}); + return DiagnosedSilenceableFailure::success(); +} + +void transform::ConvertLayoutOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getInputSgLayoutMutable(), effects); + onlyReadsHandle(getInputSgDataMutable(), effects); + onlyReadsHandle(getInputInstDataMutable(), effects); + onlyReadsHandle(getTargetSgLayoutMutable(), effects); + onlyReadsHandle(getTargetSgDataMutable(), effects); + onlyReadsHandle(getTargetInstDataMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +namespace { +class XeGPUTransformDialectExtension + : public transform::TransformDialectExtension< + XeGPUTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension) + + using Base::Base; + + void init(); +}; + +void XeGPUTransformDialectExtension::init() { + declareGeneratedDialect<scf::SCFDialect>(); + declareGeneratedDialect<arith::ArithDialect>(); + declareGeneratedDialect<xegpu::XeGPUDialect>(); + + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" + >(); +} +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" + +void mlir::xegpu::registerTransformDialectExtension(DialectRegistry ®istry) { + registry.addExtensions<XeGPUTransformDialectExtension>(); +} diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt index e6f7606..29b645f 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms XeGPUWgToSgDistribute.cpp XeGPUPropagateLayout.cpp XeGPUVectorLinearize.cpp + XeGPUOptimizeBlockLoads.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp new file mode 100644 index 0000000..ab41fe4 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp @@ -0,0 +1,490 @@ +//===- XeGPUOptimizeBlockLoads.cpp - XeGPU 
optimize block loads -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Transforms/Passes.h" +#include "mlir/Dialect/XeGPU/Transforms/Transforms.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" +#include "mlir/Dialect/XeGPU/uArch/uArchBase.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include <optional> + +namespace mlir { +namespace xegpu { +#define GEN_PASS_DEF_XEGPUOPTIMIZEBLOCKLOADS +#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" +} // namespace xegpu +} // namespace mlir + +#define DEBUG_TYPE "xegpu-optimize-block-loads" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +using namespace mlir; + +namespace { + +/// Get the 2D lane data from a tensor desc type if it exists. +static std::optional<SmallVector<int64_t>> +getMaybeLaneData(xegpu::TensorDescType tdescType) { + auto layout = tdescType.getLayoutAttr(); + if (!layout) + return std::nullopt; + auto laneData = layout.getEffectiveLaneDataAsInt(); + if (laneData.size() != 2) + return std::nullopt; + return laneData; +} + +/// Get the 2D lane layout from a tensor desc type if it exists. +static std::optional<SmallVector<int64_t>> +getMaybeLaneLayout(xegpu::TensorDescType tdescType) { + auto layout = tdescType.getLayoutAttr(); + if (!layout) + return std::nullopt; + auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); + if (laneLayout.size() != 2) + return std::nullopt; + return laneLayout; +} + +/// A layout can be optimized if its lane layout is transposed (lane[0] != 1 && +/// lane[1] == 1), but inner lane data is not equal to [1, 1]. +/// Example: +/// !xegpu.tensor_desc<16x16xf16, +/// #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> +/// In this case, lane layout is transposed (from the usual [1, SG_SIZE] form) +/// indicating that this is a load that requires transpose effect. However, +/// lane data is [1, 2], meaning that each lane must grab 2 f16 elements from +/// the inner dimension. We convert this to a optimized form by converting the +/// tensor_desc to i32 type such that lane data becomes [1, 1]. This makes the +/// later lowering easily use the load with transpose instruction. +static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout, + ArrayRef<int64_t> laneData) { + if (laneLayout.size() != 2 || laneData.size() != 2) + return false; + if (laneLayout[0] == 1 || laneLayout[1] != 1) + return false; + if (laneData[0] != 1 || laneData[1] == 1) + return false; + return true; +} + +/// A tensor desc type can be optimized if its element type is less than 32 bits +/// and its layout can be optimized. 
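// --- Illustrative sketch (not part of the patch) -----------------------------
// A self-contained version of the lane-layout/lane-data check implemented by
// canBeOptimizedForTranspose above: the lane layout must be transposed
// ([N, 1] with N != 1) while the inner lane data is not 1. The example from
// the comment above, lane_layout = [16, 1] with lane_data = [1, 2], qualifies;
// a regular row-major load or a lane_data of [1, 1] does not. The function
// name here is illustrative only.
#include <array>
#include <cassert>
#include <cstdint>

static bool transposeLoadCanBePacked(std::array<int64_t, 2> laneLayout,
                                     std::array<int64_t, 2> laneData) {
  if (laneLayout[0] == 1 || laneLayout[1] != 1)
    return false; // not a transposed (column-per-lane) layout
  if (laneData[0] != 1 || laneData[1] == 1)
    return false; // inner lane data of 1 already maps to a plain transpose
  return true;
}

int main() {
  assert(transposeLoadCanBePacked({16, 1}, {1, 2}));  // f16 case from the doc
  assert(!transposeLoadCanBePacked({1, 16}, {1, 1})); // regular row-major load
  assert(!transposeLoadCanBePacked({16, 1}, {1, 1})); // lane data already unit
  return 0;
}
// --- End of sketch ------------------------------------------------------------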
+static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) { + // If the dtype is greater or equal to 32 bits, layout must be valid. + int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth(); + if (elementTyBitwidth >= 32) + return false; + auto maybeLaneLayout = getMaybeLaneLayout(tdescType); + auto maybeLaneData = getMaybeLaneData(tdescType); + if (!maybeLaneData || !maybeLaneLayout) + return false; + return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData); +} + +/// Check if a tensor desc type can be optimized for transpose, if so return the +/// new optimized tensor desc type with a valid transpose layout. +static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType, + const uArch *targetuArch) { + if (!canBeOptimizedForTranspose(tdescType)) + return tdescType; + auto laneData = getMaybeLaneData(tdescType) + .value(); // Lane data must exist if we reach here. + int64_t innerLaneData = laneData[1]; + int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth(); + // Required shape is total shape of the vector result that this tensor desc + // must eventually load after adjusting for the new bitwidth and array + // length. + SmallVector<int64_t> requiredShape(tdescType.getShape()); + requiredShape.back() = + requiredShape.back() * tdescType.getArrayLength() / innerLaneData; + int newBitWidth = elementTyBitwidth * innerLaneData; + Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth); + // Supported shape is the max transpose shape that can be supported by + // hardware that is less than or equal to required shape. + auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>( + targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad)); + auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount( + newElemTy, /** has transform */ false, /** has transpose */ true); + // If no HW params found, return the original type. + if (!maybeHWParams) + return tdescType; + auto [widths, heights, counts] = maybeHWParams.value(); + // TODO: Currently we expect array length to be 1 for transpose case. + if (counts.size() != 1 || counts[0] != 1) + return tdescType; + int arrayLen = counts[0]; + int supportedHeight = + xegpu::getLargestDivisor(static_cast<int>(requiredShape[0]), heights); + int supportedWidth = + xegpu::getLargestDivisor(static_cast<int>(requiredShape[1]), widths); + // If no supported height or width found, return the original type. + if (supportedHeight == -1 || supportedWidth == -1) + return tdescType; + + SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth}; + xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get( + tdescType.getContext(), + tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1}); + // Array length can not be larger than 1 for transpose case. + return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen, + tdescType.getBoundaryCheck(), + tdescType.getMemorySpace(), newLayout); +} + +/// Helper to convert an OpFoldResult to Value. +static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc, + OpFoldResult ofr) { + std::optional<int64_t> mayBeInt = getConstantIntValue(ofr); + if (mayBeInt) + return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt).getResult(); + return llvm::cast<Value>(ofr); +} + +/// Helper to divide a Value by a constant integer. 
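/// For power-of-two divisors the helper below emits a logical right shift
/// rather than an integer division; e.g. (illustrative) dividing an index
/// value by 4 becomes
///   %c2 = arith.constant 2 : index
///   %q  = arith.shrui %val, %c2 : index
/// while other divisors fall back to arith.divui.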
+static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc, + Value val, int64_t constant) { + // If the constant is a power of 2, use right shift for division. + if (llvm::isPowerOf2_64(constant)) { + int64_t shiftAmount = llvm::Log2_64(constant); + return arith::ShRUIOp::create( + rewriter, loc, val, + arith::ConstantIndexOp::create(rewriter, loc, shiftAmount) + .getResult()) + .getResult(); + } + auto constantOp = + arith::ConstantIndexOp::create(rewriter, loc, constant).getResult(); + return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult(); +} + +/// This function takes a larger register block `data` and generates multiple +/// smaller loads (size given by `newTensorDesc`) to fill in the `data` block +/// starting from `offsets`. +static Value generateLoads(ConversionPatternRewriter &rewriter, + TypedValue<VectorType> data, + SmallVector<OpFoldResult> offsets, + TypedValue<xegpu::TensorDescType> newTensorDesc, + xegpu::LoadNdOp origLoadOp) { + Location loc = data.getLoc(); + assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp"); + Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]); + Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]); + SmallVector<int64_t> supportedShape(newTensorDesc.getType().getShape()); + // Compute the ratio between original shape and supported shape. We need to + // generate loads in this ratio arrangement. + auto shapeRatio = computeShapeRatio(data.getType().getShape(), + supportedShape) + .value(); // `ratio` must be defined if we reach here. + for (int64_t h = 0; h < shapeRatio[0]; ++h) { + for (int64_t w = 0; w < shapeRatio[1]; ++w) { + int64_t localOffsetDim0 = h * supportedShape[0]; + int64_t localOffsetDim1 = w * supportedShape[1]; + Value loadOffsetX = arith::AddIOp::create( + rewriter, loc, offsetDim0, + arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0) + .getResult()); + Value loadOffsetY = arith::AddIOp::create( + rewriter, loc, offsetDim1, + arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1) + .getResult()); + auto loadOp = xegpu::LoadNdOp::create( + rewriter, loc, + VectorType::get(supportedShape, data.getType().getElementType()), + newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY}, + origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(), + origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(), + origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr()); + // Set the layout for the loadOp. + auto layoutAttr = newTensorDesc.getType().getLayoutAttr(); + xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr); + // Insert the loaded block into the right position in data. + auto insertOp = vector::InsertStridedSliceOp::create( + rewriter, loc, loadOp.getResult(), data, + ArrayRef<int64_t>{localOffsetDim0, localOffsetDim1}, + ArrayRef<int64_t>{1, 1}); + // InsertOp must have the same layout as newTensorDesc. + xegpu::setDistributeLayoutAttr(insertOp->getOpResult(0), layoutAttr); + data = insertOp.getResult(); + } + } + return data; +} + +/// Checks if a CreateNdDescOp can be optimized for transpose, if so creates a +/// new CreateNdDescOp with optimized tensor desc type. This involves extracting +/// the base pointer from the original memory source and adjusting the shape and +/// strides of the tensor desc to fit with the new optimized transpose layout. 
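/// Worked example (hypothetical sizes, assuming an inner lane data of 2): a
/// row-major f16 source described with sizes [32, 16] and strides [16, 1] is
/// re-described as an i32 tensor descriptor with sizes [32, 8] and strides
/// [8, 1]; for statically shaped memrefs the base address is first obtained
/// with memref.extract_aligned_pointer_as_index and cast to i64.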
+class XeGPUCreateNdDescOpPattern final
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+public:
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto tdescTy = createNdOp.getType();
+    // Get the target uArch info.
+    auto chipStr = xegpu::getChipStr(createNdOp);
+    // Check if the chip is supported.
+    assert(
+        chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg") &&
+        "Expecting target chip to be pvc or bmg for transpose optimization.");
+    const uArch *targetuArch = xegpu::uArch::getUArch(chipStr.value());
+
+    auto convertType = tryOptimize(tdescTy, targetuArch);
+    if (convertType == tdescTy)
+      return failure();
+    auto strides = createNdOp.getMixedStrides();
+    auto maybeConstInnerStride = getConstantIntValue(strides.back());
+    // Only row-major memrefs are expected for now.
+    if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
+      return rewriter.notifyMatchFailure(
+          createNdOp, "Expecting row-major memref for transpose optimization.");
+    Value source = createNdOp.getSource();
+    auto optionalLaneData = getMaybeLaneData(tdescTy);
+    assert(optionalLaneData && "Expected 2D lane data");
+    auto laneData = optionalLaneData.value();
+    int64_t innerLaneData = laneData[1];
+    auto memrefType = dyn_cast<MemRefType>(source.getType());
+    // Inner dimension of the shape must be adjusted based on innerLaneData.
+    SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
+    modifiedShape.back() = divideByConstant(
+        rewriter, createNdOp.getLoc(),
+        convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
+        innerLaneData);
+    // Similarly, the second-to-last stride must be adjusted.
+    assert(strides.size() >= 2 &&
+           "Expected at least 2 strides for CreateNdDescOp");
+    SmallVector<OpFoldResult> modifiedStrides(strides);
+    modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
+        rewriter, createNdOp.getLoc(),
+        convertToValue(rewriter, createNdOp.getLoc(),
+                       modifiedStrides[modifiedStrides.size() - 2]),
+        innerLaneData);
+
+    // If the source is a static memref, we need to extract the pointer to the
+    // base address.
+    if (memrefType && memrefType.hasStaticShape()) {
+      auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+          rewriter, createNdOp.getLoc(), source);
+      source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
+                                          rewriter.getI64Type(),
+                                          extractOp.getResult())
+                   .getResult();
+    }
+    // Create a new CreateNdDescOp with the modified shape and converted type.
+    auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
+        rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
+        modifiedStrides);
+    rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
+    return success();
+  }
+};
+
+/// Checks if a LoadNdOp consumes a tensor desc type that was rewritten for
+/// transpose optimization. If so, rewrites the LoadNdOp to align with the
+/// adjusted tensor desc type. This can result in multiple LoadNdOps being
+/// generated to fill in the original load shape.
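/// Illustrative example (hypothetical shapes): if the adjusted load type is
/// vector<32x8xi32> but the hardware-supported transpose block is 16x8, two
/// xegpu.load_nd ops producing vector<16x8xi32> are emitted, stitched into the
/// 32x8 accumulator with vector.insert_strided_slice, and the result is
/// vector.bitcast back to the original vector<32x16xf16> type.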
+class XeGPULoadNdDescOpPattern final + : public OpConversionPattern<xegpu::LoadNdOp> { +public: + using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto origTensorDescType = loadNdOp.getTensorDescType(); + auto adaptorType = + cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType()); + if (adaptorType == origTensorDescType) + return failure(); + // Offsets must be adjusted based on innerLaneData. + auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value(); + int64_t innerLaneData = laneData[1]; + auto offsets = loadNdOp.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(loadNdOp, + "Expecting offsets in LoadNd"); + SmallVector<OpFoldResult> modifiedOffsets(offsets); + modifiedOffsets.back() = divideByConstant( + rewriter, loadNdOp.getLoc(), + convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()), + innerLaneData); + // Get the 2D data shape of this loadNdOp in its original type including + // array length. + SmallVector<int64_t> origDataShape(origTensorDescType.getShape()); + // Adjust the data shape based on innerLaneData. + origDataShape.back() /= innerLaneData; + // HW supported shape is the new tensor desc shape after conversion. + SmallVector<int64_t> hwSupportedShape(adaptorType.getShape()); + VectorType origVectorType = + VectorType::get(origDataShape, adaptorType.getElementType()); + Value data; + // Orig data shape is 3D for the array length case. + if (origTensorDescType.getArrayLength() > 1) { + SmallVector<Value> arraySlices; + for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) { + Value slice = arith::ConstantOp::create( + rewriter, loadNdOp->getLoc(), origVectorType, + rewriter.getZeroAttr(origVectorType)); + // Increase the Y offset for each array slice. + Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(), + modifiedOffsets.back()); + modifiedOffsets.back() = + arith::AddIOp::create( + rewriter, loadNdOp->getLoc(), offsetY, + arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(), + i * origDataShape[1]) + .getResult()) + .getResult(); + slice = generateLoads( + rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets, + cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()), + loadNdOp); + // BitCast back to original load shape without array length. + auto bitcastType = VectorType::get(origTensorDescType.getShape(), + origTensorDescType.getElementType()); + auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(), + bitcastType, slice); + // BitCastOp must have the same layout as the original loadNdOp. + xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0), + origTensorDescType.getLayoutAttr()); + arraySlices.push_back(bitCastOp.getResult()); + } + rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices}); + return success(); + } + data = arith::ConstantOp::create( + rewriter, loadNdOp->getLoc(), + VectorType::get(origDataShape, adaptorType.getElementType()), + rewriter.getZeroAttr(origVectorType)); + data = generateLoads( + rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets, + cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()), + loadNdOp); + auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(), + loadNdOp.getType(), data); + // BitCastOp must have the same layout as the original loadNdOp. 
+ xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0), + origTensorDescType.getLayoutAttr()); + rewriter.replaceOp(loadNdOp, bitCastOp); + return success(); + } +}; + +/// Vector ExtractOp must be processed if the original tensor desc type has +/// array length greater than 1. In this case, the LoadNdOp is replaced with +/// multiple LoadNdOps for each array slice making the extraction unnecessary. +/// In this case, we simply remove the ExtractOp. +class VectorExtractOpPattern final + : public OpConversionPattern<vector::ExtractOp> { +public: + using OpConversionPattern<vector::ExtractOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Check if the source of the extraction is split to multiple values. + if (adaptor.getSource().size() == 1) + return failure(); + auto mixedPos = extractOp.getMixedPosition(); + if (mixedPos.size() != 1) + return failure(); + auto mayBeInt = getConstantIntValue(mixedPos[0]); + if (!mayBeInt) + return failure(); + rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]); + return success(); + } +}; + +} // namespace + +void xegpu::populateXeGPUOptimizeBlockLoadsPatterns( + RewritePatternSet &patterns) { + patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern, + VectorExtractOpPattern>(patterns.getContext()); +} + +namespace { + +struct XeGPUOptimizeBlockLoadsPass final + : public xegpu::impl::XeGPUOptimizeBlockLoadsBase< + XeGPUOptimizeBlockLoadsPass> { + void runOnOperation() override { + MLIRContext &context = getContext(); + TypeConverter converter; + RewritePatternSet patterns(&context); + ConversionTarget target(context); + + // This pass is only meant for PVC and BMG targets. If unsupported target + // is found, exit early. + bool isTargetSupported = false; + getOperation()->walk([&](gpu::GPUFuncOp funcOp) { + auto chipStr = xegpu::getChipStr(funcOp); + if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg")) + isTargetSupported = true; + }); + + if (!isTargetSupported) { + DBGS() << "XeGPUOptimizeBlockLoadsPass only supports PVC and BMG targets." + << "\n"; + return; + } + + // CreateNdDescOp and LoadNdOp with optimizable tensor desc types must be + // converted. + target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>( + [&](xegpu::CreateNdDescOp createNdOp) { + return !canBeOptimizedForTranspose(createNdOp.getType()); + }); + target.addDynamicallyLegalOp<xegpu::LoadNdOp>( + [&](xegpu::LoadNdOp loadNdOp) { + return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType()); + }); + // Vector ExtractOps can have optimizable layouts if they extract from + // LoadNdOps with array length greater than 1. These ExtractOps must be + // converted. 
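    // For example (illustrative): a vector.extract %load[1] that selects the
    // second slice of an array-length-2 load is, once the load has been split
    // into per-slice loads, replaced directly with the corresponding
    // distributed slice by VectorExtractOpPattern above.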
+ target.addDynamicallyLegalOp<vector::ExtractOp>( + [&](vector::ExtractOp extractOp) { + auto layout = xegpu::getDistributeLayoutAttr(extractOp.getResult()); + if (!layout) + return true; + auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); + auto laneData = layout.getEffectiveLaneDataAsInt(); + return !canBeOptimizedForTranspose(laneLayout, laneData); + }); + converter.addConversion([](Type type) { return type; }); + + target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect, + vector::VectorDialect>(); + scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, + target); + xegpu::populateXeGPUOptimizeBlockLoadsPatterns(patterns); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + DBGS() << "Optimize block loads pass failed.\n"; + return signalPassFailure(); + } + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index 90eae87..dc9eb96 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -53,6 +53,8 @@ using namespace mlir::dataflow; namespace { +enum class LayoutKind { Lane, InstData }; + //===----------------------------------------------------------------------===// // LayoutInfo //===----------------------------------------------------------------------===// @@ -166,7 +168,8 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) { llvm_unreachable("Join should not be triggered by layout propagation."); } -/// Construct a new layout with the transposed lane layout and lane data. +/// Construct a new layout with the transposed inst_data or lane_layout, +/// lane_data. LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const { if (!isAssigned()) return {}; @@ -186,12 +189,20 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const { SmallVector<int32_t> laneData; SmallVector<int32_t> instData; for (int64_t idx : permutation) { - laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx])); - laneData.push_back(static_cast<int32_t>(getLaneData()[idx])); - instData.push_back(static_cast<int32_t>(getInstData()[idx])); + if (getLaneLayout().size()) { + laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx])); + laneData.push_back(static_cast<int32_t>(getLaneData()[idx])); + } + if (getInstData().size()) + instData.push_back(static_cast<int32_t>(getInstData()[idx])); } - return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData, - laneLayout, laneData)); + xegpu::LayoutAttr layoutAttr; + if (getLaneLayout().size()) + layoutAttr = + xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData); + if (getInstData().size()) + layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData); + return LayoutInfo(layoutAttr); } //===----------------------------------------------------------------------===// @@ -204,28 +215,6 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> { using Lattice::Lattice; }; -/// Helper Function to find a proper instruction multiple for the user-supplied -/// sg-level data shape. `candidates` are uArch allowed shapes. -/// `candidateMultiples` are uArch multiples of such shapes (e.g., block count). 
-template <typename T> -int getLargestDivisor(T dim, ArrayRef<T> candidates, - ArrayRef<T> candidateMultiples = {}) { - static_assert(std::is_integral<T>::value, "T must be an integer type"); - int largest = -1; - SmallVector<T> multiples = {1}; - if (!candidateMultiples.empty()) - multiples = - SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end()); - for (T candidate : candidates) { - for (T multiple : multiples) { - int value = static_cast<int>(candidate * multiple); - if (value != 0 && dim % value == 0 && value > largest) - largest = value; - } - } - return largest; -} - /// Helper Functions to get default layouts. A `default layout` is a layout that /// is assigned to a value when the layout is not fixed by some anchor operation /// (like DPAS). @@ -235,15 +224,14 @@ int getLargestDivisor(T dim, ArrayRef<T> candidates, /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1]. static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, unsigned rank, - const xegpu::uArch::uArch *uArch, - ArrayRef<int> instData) { + const xegpu::uArch::uArch *uArch) { assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector."); if (rank == 1) { return LayoutInfo( - xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1})); + xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1})); } - return LayoutInfo(xegpu::LayoutAttr::get( - ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1})); + return LayoutInfo( + xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1})); } static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, @@ -258,7 +246,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, const xegpu::uArch::uArch *uArch, - ArrayRef<int> instData, unsigned packingSize, bool isScattered = false) { // Expecting a 1D or 2D vector. @@ -269,16 +256,16 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, "Expected int or float element type."); // If the rank is 1, then return default layout for 1D vector. if (vectorTy.getRank() == 1) - return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData); + return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch); // Packing factor is determined by the element type bitwidth. unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth(); int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1; if (isScattered) { - return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData, + return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), {uArch->getSubgroupSize(), 1}, {1, packingFactor})); } - return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData, + return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor})); } @@ -286,7 +273,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, const xegpu::uArch::uArch *uArch, - ArrayRef<int> instData, unsigned packingSize, bool isScattered = false) { // Expecting a 1D or 2D vector. @@ -297,18 +283,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, "Expected int or float element type."); // If the rank is 1, then return default layout for 1D vector. 
if (tdescTy.getRank() == 1) - return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData); + return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch); // Packing factor is determined by the element type bitwidth. unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth(); int subgroupSize = uArch->getSubgroupSize(); int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1; if (isScattered) { return LayoutInfo(xegpu::LayoutAttr::get( - tdescTy.getContext(), instData, {subgroupSize, 1}, {1, packingFactor})); + tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor})); } return LayoutInfo(xegpu::LayoutAttr::get( - tdescTy.getContext(), instData, {1, subgroupSize}, {1, packingFactor})); + tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor})); } /// Helper Function to get the expected layouts for DPAS operands. `lane_data` @@ -320,7 +306,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, const xegpu::uArch::uArch *uArch, - ArrayRef<int> instData, unsigned packingSize) { + unsigned packingSize) { Type elementTy = vectorTy.getElementType(); assert(elementTy.isIntOrFloat() && "Expected int or float type in DPAS operands"); @@ -332,10 +318,10 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()), 1}); return LayoutInfo( - xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data)); + xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data)); } // Otherwise, return the default layout for the vector type. - return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize); + return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize); } //===----------------------------------------------------------------------===// @@ -350,6 +336,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, class LayoutInfoPropagation : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> { private: + LayoutKind layoutKind; void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results); @@ -400,10 +387,14 @@ private: ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results); + bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout); + public: LayoutInfoPropagation(DataFlowSolver &solver, - SymbolTableCollection &symbolTable) - : SparseBackwardDataFlowAnalysis(solver, symbolTable) {} + SymbolTableCollection &symbolTable, + LayoutKind layoutKind) + : SparseBackwardDataFlowAnalysis(solver, symbolTable), + layoutKind(layoutKind) {} using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; LogicalResult @@ -486,43 +477,71 @@ LogicalResult LayoutInfoPropagation::visitOperation( return success(); } +bool LayoutInfoPropagation::hasParamsOfLayoutKind( + xegpu::DistributeLayoutAttr anchorLayout) { + if (anchorLayout == nullptr) { + return false; + } + if (layoutKind == LayoutKind::InstData) { + return !(anchorLayout.getEffectiveInstDataAsInt().empty()); + } else if (layoutKind == LayoutKind::Lane) { + return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() || + anchorLayout.getEffectiveLaneDataAsInt().empty()); + } + return false; +} + void LayoutInfoPropagation::visitPrefetchNdOp( xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // 
Here we assign the default layout to the tensor descriptor operand of - // prefetch. - auto tdescTy = prefetch.getTensorDescType(); - - auto uArch = getUArch(getChipStr(prefetch).value_or("")); - const auto *uArchInstruction = - dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>( - uArch->getInstruction( - xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch)); - - auto blockWHC = - uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType()); - if (!blockWHC) - prefetch.emitWarning("No known block params found for the element type."); - auto [bWidth, bHeight, bCount] = blockWHC.value(); - SmallVector<int> instData; - int instWidth = getLargestDivisor( - static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth, - bCount); - if (instWidth == -1) - prefetch.emitWarning( - "No suitable instruction multiple found for the given shape."); - if (tdescTy.getRank() == 1) - instData = {instWidth}; - else { - int instHeight = getLargestDivisor( - static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight); - if (instHeight == -1) + + LayoutInfo prefetchLayout; + xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + prefetchLayout = LayoutInfo(anchorLayout); + } else { + // Here we assign the default layout to the tensor descriptor operand of + // prefetch. + auto tdescTy = prefetch.getTensorDescType(); + + auto uArch = getUArch(getChipStr(prefetch).value_or("")); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>( + uArch->getInstruction( + xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch)); + + auto blockWHC = + uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType()); + if (!blockWHC) + prefetch.emitWarning("No known block params found for the element type."); + auto [bWidth, bHeight, bCount] = blockWHC.value(); + SmallVector<int> instData; + int instWidth = xegpu::getLargestDivisor( + static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth); + if (instWidth == -1) prefetch.emitWarning( "No suitable instruction multiple found for the given shape."); - instData = {instHeight, instWidth}; + if (tdescTy.getRank() == 1) + instData = {instWidth}; + else { + int instHeight = xegpu::getLargestDivisor( + static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight); + if (instHeight == -1) + prefetch.emitWarning( + "No suitable instruction multiple found for the given shape."); + instData = {instHeight, instWidth}; + } + + if (layoutKind == LayoutKind::InstData) + prefetchLayout = + LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData)); + else + prefetchLayout = getDefaultSIMTLayoutInfo( + tdescTy, uArch, uArchInstruction->getPackedFormatBitSize()); + + prefetch.setLayoutAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get())); } - auto prefetchLayout = getDefaultSIMTLayoutInfo( - tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize()); // Propagate the layout to the source tensor descriptor. propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout)); } @@ -561,23 +580,39 @@ void LayoutInfoPropagation::visitVectorBroadCastOp( // Only consider vector to vector broadcasts for now. VectorType resultTy = broadcast.getResultVectorType(); VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType()); - if (!sourceTy) { - broadcast.emitWarning("Expecting source type to be a vector type."); + // skip layout propagation for non-vector source operand. 
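  // Illustrative example (hypothetical shapes) for the lower-rank case handled
  // below: broadcasting vector<16xf32> to vector<8x16xf32> under a result
  // layout of #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> assigns
  // the source the slice layout
  //   #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
  //                dims = [0]>
  // i.e. the missing leading dimension is recorded as the broadcasted dim.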
+ if (!sourceTy) return; - } - // Only consider nD -> nD broadcast. + // Hanlding broadcast from low-rank to high-rank (e.g., 1D to 2D) case. if (sourceTy.getRank() != resultTy.getRank()) { - broadcast.emitWarning("Expecting source and result to have same rank."); + auto sourceDims = sourceTy.getShape(); + auto resultDims = resultTy.getShape(); + SmallVector<int64_t> bcastDims; + auto dimDiff = resultTy.getRank() - sourceTy.getRank(); + // adding the missing leading dims + for (int i = 0; i < dimDiff; i++) + bcastDims.push_back(i); + + // for the rest dims in the resultTy, if sourceTy dim is 1, then it's + // broadcasted dim + for (size_t i = 0; i < sourceDims.size(); i++) + if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1)) + bcastDims.push_back(i + dimDiff); + + // create a slice layout for the source + xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get( + broadcast->getContext(), + cast<xegpu::DistributeLayoutAttr>(resultLayout.get()), + DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims)); + + propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout))); return; } + SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims(); - if (broadcastUnitDims.size() != 1) { - broadcast.emitWarning("Expecting source type to be nD vector only with " - "one broadcasted dimension."); - return; - } - // Propagate the result layout to the source operand. + resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get()) + .setUnitDimData(broadcastUnitDims); propagateIfChanged(operands[0], operands[0]->meet(resultLayout)); } @@ -622,55 +657,97 @@ void LayoutInfoPropagation::visitUpdateNdOffsetOp( void LayoutInfoPropagation::visitDpasOp( xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - VectorType aTy = dpas.getLhsType(); - VectorType bTy = dpas.getRhsType(); - - auto uArch = getUArch(getChipStr(dpas).value_or("")); - const int subgroupSize = uArch->getSubgroupSize(); - const auto *uArchInstruction = - dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction( - xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)); - - const unsigned dataALen = aTy.getShape().front(); - auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType()); - const int maxALen = - getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen)); - if (maxALen == -1) - dpas.emitWarning( - "No suitable instruction multiple found for the given shape."); - - const unsigned dataBLen = bTy.getShape().back(); - auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType()); - const int maxBLen = - getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen)); - if (maxBLen == -1) - dpas.emitWarning( - "No suitable instruction multiple found for the given shape."); - SmallVector<int> instDataA = {maxALen, subgroupSize}; - SmallVector<int> instDataB = {subgroupSize, maxBLen}; - - propagateIfChanged(operands[0], - operands[0]->meet(getSIMTLayoutInfoForDPASOperand( - aTy, 0, uArch, instDataA, - uArchInstruction->getPackedFormatBitSizeA()))); - propagateIfChanged(operands[1], - operands[1]->meet(getSIMTLayoutInfoForDPASOperand( - bTy, 1, uArch, instDataB, - uArchInstruction->getPackedFormatBitSizeB()))); - if (operands.size() > 2) { - VectorType cTy = dpas.getAccType(); - const unsigned dataCLen = bTy.getShape().back(); - auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType()); - const int maxCLen = - getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen)); - 
if (maxCLen == -1) + + LayoutInfo dpasALayout; + LayoutInfo dpasBLayout; + LayoutInfo dpasCDLayout; + + xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr(); + if (hasParamsOfLayoutKind(anchorLayoutCD)) { + xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr(); + xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr(); + assert(hasParamsOfLayoutKind(anchorLayoutA) && + "Expected anchor layout for DPAS A operand."); + assert(hasParamsOfLayoutKind(anchorLayoutB) && + "Expected anchor layout for DPAS B operand."); + dpasALayout = LayoutInfo(anchorLayoutA); + dpasBLayout = LayoutInfo(anchorLayoutB); + dpasCDLayout = LayoutInfo(anchorLayoutCD); + + } else { + + VectorType aTy = dpas.getLhsType(); + VectorType bTy = dpas.getRhsType(); + + auto uArch = getUArch(getChipStr(dpas).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction( + xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)); + + const unsigned dataALen = aTy.getShape().front(); + auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType()); + const int maxALen = + xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen)); + if (maxALen == -1) dpas.emitWarning( "No suitable instruction multiple found for the given shape."); - SmallVector<int> instDataC = {maxALen, maxCLen}; - propagateIfChanged(operands[2], - operands[2]->meet(getSIMTLayoutInfoForDPASOperand( - cTy, 2, uArch, instDataC, - uArchInstruction->getPackedFormatBitSizeB()))); + + const unsigned dataBLen = bTy.getShape().back(); + auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType()); + + const int maxBLen = + xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen)); + + if (maxBLen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + SmallVector<int> instDataA = {maxALen, subgroupSize}; + SmallVector<int> instDataB = {subgroupSize, maxBLen}; + + if (layoutKind == LayoutKind::InstData) { + dpasALayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA)); + dpasBLayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB)); + } else { + dpasALayout = getSIMTLayoutInfoForDPASOperand( + aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA()); + dpasBLayout = getSIMTLayoutInfoForDPASOperand( + bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB()); + } + + if (operands.size() > 2) { + VectorType cTy = dpas.getAccType(); + if (layoutKind == LayoutKind::InstData) { + const unsigned dataCLen = bTy.getShape().back(); + auto supportedCLen = + uArchInstruction->getSupportedN(bTy.getElementType()); + const int maxCLen = xegpu::getLargestDivisor( + dataCLen, ArrayRef<unsigned>(supportedCLen)); + if (maxCLen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + SmallVector<int> instDataC = {maxALen, maxCLen}; + dpasCDLayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC)); + } else + dpasCDLayout = getSIMTLayoutInfoForDPASOperand( + cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB()); + + dpas.setLayoutCdAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get())); + } + dpas.setLayoutAAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get())); + dpas.setLayoutBAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get())); + } + + propagateIfChanged(operands[0], operands[0]->meet(dpasALayout)); + 
propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout)); + if (operands.size() > 2) { + propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout)); } } @@ -679,37 +756,50 @@ void LayoutInfoPropagation::visitStoreNdOp( xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - auto uArch = getUArch(getChipStr(store).value_or("")); - const auto *uArchInstruction = - dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>( - uArch->getInstruction( - xegpu::uArch::InstructionKind::Subgroup2DBlockStore)); - VectorType dataTy = store.getValueType(); - auto blockWHC = uArchInstruction->getBlockWidthHeightCount( - store.getValueType().getElementType()); - if (!blockWHC) - store.emitWarning("No known block params found for the element type."); - auto [bWidth, bHeight, bCount] = blockWHC.value(); - SmallVector<int> instData; - int instWidth = getLargestDivisor( - static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth, - bCount); - if (instWidth == -1) - store.emitWarning( - "No suitable instruction multiple found for the given shape."); - if (dataTy.getRank() == 1) - instData = {instWidth}; - else { - int instHeight = getLargestDivisor( - static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight); - if (instHeight == -1) + LayoutInfo storeLayout; + xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + storeLayout = LayoutInfo(anchorLayout); + } else { + auto uArch = getUArch(getChipStr(store).value_or("")); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>( + uArch->getInstruction( + xegpu::uArch::InstructionKind::Subgroup2DBlockStore)); + VectorType dataTy = store.getValueType(); + auto blockWHC = uArchInstruction->getBlockWidthHeightCount( + store.getValueType().getElementType()); + if (!blockWHC) + store.emitWarning("No known block params found for the element type."); + auto [bWidth, bHeight, bCount] = blockWHC.value(); + SmallVector<int> instData; + int instWidth = xegpu::getLargestDivisor( + static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth); + if (instWidth == -1) store.emitWarning( "No suitable instruction multiple found for the given shape."); - instData = {instHeight, instWidth}; + if (dataTy.getRank() == 1) + instData = {instWidth}; + else { + int instHeight = xegpu::getLargestDivisor( + static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight); + if (instHeight == -1) + store.emitWarning( + "No suitable instruction multiple found for the given shape."); + instData = {instHeight, instWidth}; + } + + if (layoutKind == LayoutKind::InstData) + storeLayout = + LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData)); + else + storeLayout = + getDefaultSIMTLayoutInfo(store.getValueType(), uArch, + uArchInstruction->getPackedFormatBitSize()); + store.setLayoutAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get())); } - LayoutInfo storeLayout = - getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData, - uArchInstruction->getPackedFormatBitSize()); + // Propagate the layout to the value operand. 
// Both operands should have the same layout for (LayoutInfoLattice *operand : operands) propagateIfChanged(operand, operand->meet(storeLayout)); @@ -720,21 +810,30 @@ void LayoutInfoPropagation::visitStoreNdOp( void LayoutInfoPropagation::visitLoadNdOp( xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - LayoutInfo valueLayout = results[0]->getValue(); - // Need the layout of the value to propagate to the tensor descriptor. - if (!valueLayout.isAssigned()) - return; - LayoutInfo tensorDescLayout = valueLayout; - // LoadNdOp has the transpose effect. However, at the stage of this analysis - // this effect is not expected and should be abstracted away. Emit a - // warning. - if (auto transpose = load.getTranspose()) { - load.emitWarning("Transpose effect is not expected for LoadNdOp at " - "LayoutInfoPropagation stage."); - tensorDescLayout = valueLayout.transpose(transpose.value()); + + LayoutInfo loadLayout; + xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + loadLayout = LayoutInfo(anchorLayout); + } else { + + LayoutInfo valueLayout = results[0]->getValue(); + // Need the layout of the value to propagate to the tensor descriptor. + if (!valueLayout.isAssigned()) + return; + loadLayout = valueLayout; + // LoadNdOp has the transpose effect. However, at the stage of this analysis + // this effect is not expected and should be abstracted away. Emit a + // warning. + if (auto transpose = load.getTranspose()) { + load.emitWarning("Transpose effect is not expected for LoadNdOp at " + "LayoutInfoPropagation stage."); + loadLayout = valueLayout.transpose(transpose.value()); + } + load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get())); } // Propagate the new layout to the tensor descriptor operand. - propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout)); + propagateIfChanged(operands[0], operands[0]->meet(loadLayout)); } /// For vector::TransposeOp, the layout of the result is transposed and @@ -824,33 +923,48 @@ void LayoutInfoPropagation::visitVectorBitcastOp( void LayoutInfoPropagation::visitLoadGatherOp( xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // The layout is strictly determined by the payload type. - auto payloadTy = dyn_cast<VectorType>(load.getValueType()); - if (!payloadTy) { - load.emitWarning("Not propagating, non-vector payload supplied."); - return; - } - auto uArch = getUArch(getChipStr(load).value_or("")); - const int subgroupSize = uArch->getSubgroupSize(); - SmallVector<int> instData{subgroupSize}; - if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1) - instData.push_back(chunkSize); - else if (auto srcTdescTy = - dyn_cast<xegpu::TensorDescType>(load.getSourceType())) { - if (srcTdescTy.getChunkSizeAsInt() > 1) + + LayoutInfo loadLayout; + LayoutInfo maskLayout; + xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + loadLayout = LayoutInfo(anchorLayout); + maskLayout = loadLayout; + } else { + + // The layout is strictly determined by the payload type. 
+ VectorType payloadTy = load.getValueType(); + if (!payloadTy) { + load.emitWarning("Not propagating, non-vector payload supplied."); + return; + } + auto uArch = getUArch(getChipStr(load).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + SmallVector<int> instData{subgroupSize}; + if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1) instData.push_back(chunkSize); - } - LayoutInfo layout = getDefaultSIMTLayoutInfo( - payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(), - /*scattered*/ true); + else if (auto srcTdescTy = + dyn_cast<xegpu::TensorDescType>(load.getSourceType())) { + if (srcTdescTy.getChunkSizeAsInt() > 1) + instData.push_back(chunkSize); + } - // Mask operand should have 1D default layout. - LayoutInfo maskLayout = - getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize); + if (layoutKind == LayoutKind::InstData) + loadLayout = + LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData)); + else + loadLayout = getDefaultSIMTLayoutInfo( + payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(), + /*scattered*/ true); + // Mask operand should have 1D default layout. + maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize); + + load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get())); + } // Propagate the new layout to the tensor descriptor operand. if (isa<xegpu::TensorDescType>(load.getSourceType())) - propagateIfChanged(operands[0], operands[0]->meet(layout)); + propagateIfChanged(operands[0], operands[0]->meet(loadLayout)); // Propagate the new layout to the mask and optional offset operand. propagateIfChanged(operands[1], operands[1]->meet(maskLayout)); if (load.getOffsets()) @@ -878,38 +992,56 @@ void LayoutInfoPropagation::visitCreateDescOp( void LayoutInfoPropagation::visitStoreScatterOp( xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // Currently, for 2D StoreScatterOp we expect that the height dimension of - // the tensor descriptor is equal to the subgroup size. This is ensured by - // the op verifier. 
- auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType()); - if (!payloadTy) { - storeScatter.emitWarning("Not propagating, non-vector payload supplied."); - return; - } - auto uArch = getUArch(getChipStr(storeScatter).value_or("")); - const int subgroupSize = uArch->getSubgroupSize(); - - auto payloadShape = payloadTy.getShape(); - if (payloadShape.size() > 1) - assert( - payloadShape[0] == subgroupSize && - "Expected the first dimension of 2D tensor descriptor to be equal to " - "subgroup size."); - - SmallVector<int> instData{subgroupSize}; - if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1) - instData.push_back(chunkSize); - else if (auto dstTdescTy = - dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) { - if (dstTdescTy.getChunkSizeAsInt() > 1) - instData.push_back(chunkSize); - } - LayoutInfo payloadLayout = getDefaultSIMTLayoutInfo( - payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(), - /*scattered=*/true); - LayoutInfo maskLayout = - getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize); + LayoutInfo payloadLayout; + LayoutInfo maskLayout; + xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + payloadLayout = LayoutInfo(anchorLayout); + maskLayout = payloadLayout; + } else { + // Currently, for 2D StoreScatterOp we expect that the height dimension of + // the tensor descriptor is equal to the subgroup size. This is ensured by + // the op verifier. + VectorType payloadTy = storeScatter.getValueType(); + if (!payloadTy) { + storeScatter.emitWarning("Not propagating, non-vector payload supplied."); + return; + } + + auto uArch = getUArch(getChipStr(storeScatter).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + + if (layoutKind == LayoutKind::InstData) { + SmallVector<int> instData{subgroupSize}; + if (auto chunkSize = storeScatter.getChunkSize().value_or(0); + chunkSize > 1) + instData.push_back(chunkSize); + else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>( + storeScatter.getDestType())) { + if (dstTdescTy.getChunkSizeAsInt() > 1) + instData.push_back(chunkSize); + } + payloadLayout = LayoutInfo( + xegpu::LayoutAttr::get(storeScatter.getContext(), instData)); + } else { + auto payloadShape = payloadTy.getShape(); + if (payloadShape.size() > 1) + assert(payloadShape[0] == subgroupSize && + "Expected the first dimension of 2D tensor descriptor to be " + "equal to " + "subgroup size."); + payloadLayout = getDefaultSIMTLayoutInfo( + payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(), + /*scattered=*/true); + } + + maskLayout = + getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize); + + storeScatter.setLayoutAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get())); + } // Propagate the payload operand layout propagateIfChanged(operands[0], operands[0]->meet(payloadLayout)); // Propagate the destination (if tdesc) operand layout @@ -931,10 +1063,10 @@ class RunLayoutInfoPropagation { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation) - RunLayoutInfoPropagation(Operation *op) : target(op) { + RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) { SymbolTableCollection symbolTable; loadBaselineAnalyses(solver); - solver.load<LayoutInfoPropagation>(symbolTable); + solver.load<LayoutInfoPropagation>(symbolTable, layoutKind); (void)solver.initializeAndRun(op); } @@ -1041,7 +1173,7 @@ static LogicalResult updateOp(mlir::OpBuilder 
&builder, mlir::Operation *op, } // If the result is a vector type, add a temporary layout attribute to the // op. - xegpu::setDistributeLayoutAttr(result, layout); + xegpu::setDistributeLayoutAttr(result, layout, /*respectPermLayout*/ true); } return success(); } @@ -1174,7 +1306,18 @@ struct XeGPUPropagateLayoutPass final } // namespace void XeGPUPropagateLayoutPass::runOnOperation() { - auto &analysis = getAnalysis<RunLayoutInfoPropagation>(); + LayoutKind layoutKind; + if (this->layoutKind == "lane") { + layoutKind = LayoutKind::Lane; + } else if (this->layoutKind == "inst") { + layoutKind = LayoutKind::InstData; + } else { + getOperation()->emitError("Unsupported layout kind option: " + + this->layoutKind); + signalPassFailure(); + return; + } + RunLayoutInfoPropagation analysis(getOperation(), layoutKind); // Print the analysis result and exit. (for debugging purposes) if (printOnly) { auto &os = llvm::outs(); @@ -1188,8 +1331,6 @@ void XeGPUPropagateLayoutPass::runOnOperation() { return {}; xegpu::DistributeLayoutAttr layoutAttr = cast<xegpu::DistributeLayoutAttr>(layout.get()); - if (this->layoutKind == "lane") - layoutAttr = layoutAttr.dropInstData(); if (layout.isSliceLayout()) return cast<xegpu::SliceAttr>(layoutAttr); return cast<xegpu::LayoutAttr>(layoutAttr); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 5a3b27e..ca81c3c 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Utils/DistributionUtils.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" @@ -98,7 +99,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout, for (auto [i, dim] : llvm::enumerate(originalType.getShape())) { if (i < distributionStart) continue; - // Check if the dimension can be distributed evenly. if (dim % effectiveLaneLayout[i - distributionStart] != 0) return failure(); @@ -173,6 +173,21 @@ static bool requireTranspose(const xegpu::LayoutAttr layout, return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1; } +/// Given a vector type and its distributed vector type, return the list of +/// dimensions that are distributed. +static SmallVector<int64_t> getDistributedDims(VectorType originalType, + VectorType distributedType) { + assert(originalType.getRank() == distributedType.getRank() && + "sequential and distributed vector types must have the same rank"); + SmallVector<int64_t> distributedDims; + for (int64_t i = 0; i < originalType.getRank(); ++i) { + if (distributedType.getDimSize(i) != originalType.getDimSize(i)) { + distributedDims.push_back(i); + } + } + return distributedDims; +} + /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body /// of the original GPUFuncOp to the new GPUFuncOp such that entire body is /// contained within a WarpExecuteOnLane0Op. 
@@ -912,6 +927,183 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern { } }; +static SmallVector<Value> computeDistributedCoordinatesForMatrixOp( + PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout, + Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) { + SmallVector<Value> newCoods; + auto maybeCoords = + layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape); + if (failed(maybeCoords)) + return {}; + assert(maybeCoords.value().size() == 1 && + "Expected one set of distributed offsets"); + SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned( + rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]), + getAsOpFoldResult(origOffsets)); + newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>); + return newCoods; +} + +/// Pattern for distributing xegpu::LoadMatrixOp. +struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + gpu::YieldOp yield = warpOp.getTerminator(); + Operation *lastNode = yield->getPrevNode(); + auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode); + if (!matrixOp) + return failure(); + + OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) { + return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op; + }); + if (!producedByLastLoad) + return rewriter.notifyMatchFailure( + warpOp, "The last op is not xegpu::LoadMatrixOp"); + const int operandIdx = producedByLastLoad->getOperandNumber(); + + VectorType sgPayloadTy = + dyn_cast<VectorType>(matrixOp.getResult().getType()); + VectorType warpResultTy = + cast<VectorType>(warpOp.getResult(operandIdx).getType()); + if (!sgPayloadTy) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix op payload must be a vector type"); + + auto loc = matrixOp.getLoc(); + auto offsets = matrixOp.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(matrixOp, + "the load op must have offsets"); + SmallVector<Value> offsetsAsValues = + vector::getAsValues(rewriter, matrixOp.getLoc(), offsets); + + auto layout = matrixOp.getLayoutAttr(); + if (!layout) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix operation lacks layout attribute"); + + FailureOr<VectorType> distPayloadByWarpOpOrFailure = + getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy); + if (failed(distPayloadByWarpOpOrFailure)) + return rewriter.notifyMatchFailure( + matrixOp, "Failed to distribute matrix op payload based on layout."); + + SmallVector<Value> operands = {matrixOp.getMemDesc()}; + const unsigned offsetsStartIdx = operands.size(); + operands.append(offsetsAsValues); + + SmallVector<Type> operandTypes = llvm::to_vector( + llvm::map_range(operands, [](Value v) { return v.getType(); })); + + SmallVector<size_t> newRetIndices; + gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, operands, operandTypes, newRetIndices); + SmallVector<Value> newOperands = llvm::map_to_vector( + newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); + + SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(), + ShapedType::kDynamic); + DenseI64ArrayAttr newConstOffsetsAttr = + rewriter.getDenseI64ArrayAttr(newConstOffsets); + ValueRange currentOffsets = + ValueRange(newOperands).drop_front(offsetsStartIdx); + + SmallVector<Value> newCoords = currentOffsets; + 
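    // Illustrative sketch (hypothetical numbers): for a 16x16 payload with
    // lane_layout = [1, 16] and lane_data = [1, 1], lane 5 owns a 16x1 slice,
    // so computeDistributedCoordinatesForMatrixOp yields a lane-relative
    // offset of [0, 5], which is added (right-aligned) to the subgroup-level
    // offsets carried by the original load_matrix op.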
rewriter.setInsertionPointAfter(newWarpOp); + + if (!matrixOp.getSubgroupBlockIoAttr()) { + newCoords = computeDistributedCoordinatesForMatrixOp( + rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(), + currentOffsets); + } + xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create( + rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure, + newOperands[0], ValueRange(newCoords), newConstOffsetsAttr, + matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{}); + // Resolve the output type and replace all uses. + rewriter.replaceAllUsesWith( + newWarpOp.getResult(operandIdx), + resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter)); + return success(); + } +}; + +/// Pattern for distributing xegpu::StoreMatrixOp. +struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + gpu::YieldOp yield = warpOp.getTerminator(); + Operation *lastNode = yield->getPrevNode(); + auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode); + if (!matrixOp) + return failure(); + + VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType()); + if (!sgPayloadTy) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix op payload must be a vector type"); + + auto loc = matrixOp.getLoc(); + auto offsets = matrixOp.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(matrixOp, + "the store op must have offsets"); + SmallVector<Value> offsetsAsValues = + vector::getAsValues(rewriter, matrixOp.getLoc(), offsets); + + auto layout = matrixOp.getLayoutAttr(); + if (!layout) + return rewriter.notifyMatchFailure( + matrixOp, "the matrix operation lacks layout attribute"); + + FailureOr<VectorType> distPayloadByWarpOpOrFailure = + getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy); + if (failed(distPayloadByWarpOpOrFailure)) + return rewriter.notifyMatchFailure( + matrixOp, "Failed to distribute matrix op payload based on layout."); + + SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()}; + const unsigned offsetsStartIdx = operands.size(); + operands.append(offsetsAsValues); + + SmallVector<Type> operandTypes = llvm::to_vector( + llvm::map_range(operands, [](Value v) { return v.getType(); })); + operandTypes[0] = *distPayloadByWarpOpOrFailure; + + SmallVector<size_t> newRetIndices; + gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, operands, operandTypes, newRetIndices); + SmallVector<Value> newOperands = llvm::map_to_vector( + newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); + + SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(), + ShapedType::kDynamic); + DenseI64ArrayAttr newConstOffsetsAttr = + rewriter.getDenseI64ArrayAttr(newConstOffsets); + ValueRange currentOffsets = + ValueRange(newOperands).drop_front(offsetsStartIdx); + + SmallVector<Value> newCoords = currentOffsets; + rewriter.setInsertionPointAfter(newWarpOp); + + if (!matrixOp.getSubgroupBlockIoAttr()) { + newCoords = computeDistributedCoordinatesForMatrixOp( + rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(), + currentOffsets); + } + + xegpu::StoreMatrixOp::create( + rewriter, loc, TypeRange{}, newOperands[0], newOperands[1], + ValueRange(newCoords), newConstOffsetsAttr, + matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{}); + 
rewriter.eraseOp(matrixOp); + return success(); + } +}; + /// Distribute a scattered load op. The logic and requirements are the same as /// for the scattered store distribution. The warpOp's payload vector is /// expected to be distributed by the load's result consumer. @@ -1231,6 +1423,166 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { } }; +/// This pattern distributes the `vector.broadcast` operation across lanes in a +/// warp. The pattern supports three use cases: +/// +/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input +/// vector +/// must have a slice layout of the result. If the distributed source and +/// target vector types are identical, this lowers to a no-op; otherwise, it +/// remains a broadcast but operates on distributed vectors. +/// +/// 2) Broadcast a same-rank vector with identical layouts for source and +/// target: +/// The source vector must have unit dimensions, and lane_data must be unit +/// size for those unit dims. This always lowers to a no-op. +/// +/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from +/// scalar to distributed result type. +/// +/// Example 1 (lowering to a broadcast with distributed types): +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [0]> } : () -> (vector<32xf32>) +/// %2 = vector.broadcast %0 {layout_result_0 = +/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>} +/// : vector<32xf32> to vector<8x32xf32> +/// gpu.yield %1 : vector<8x32xf32> +/// } +/// ``` +/// is lowered to: +/// ``` +/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [0]> } : () -> (vector<32xf32>) +/// gpu.yield %0 : vector<32xf32> +/// } +/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32> +/// +/// Example 2 (no-op): +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [1]> } : () -> (vector<8xf32>) +/// %1 = vector.shape_cast %0 +/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1, +/// 1]>}: vector<8xf32> to vector<8x1xf32> +/// %2 = vector.broadcast %1 +/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1, +/// 1]>}: vector<8x1xf32> to vector<8x32xf32> +/// gpu.yield %1 : vector<8x32xf32> +/// } +/// ``` +/// is lowered to: +/// ``` +/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [1]> } : () -> (vector<8xf32>) +/// %1 = vector.shape_cast %0 +/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1, +/// 1]>}: vector<8xf32> to vector<8x1xf32> +/// gpu.yield %1 : vector<8x1xf32> +/// } +/// // The broadcast is implicit through layout transformation (no-op) +/// "some_use"(%r#0) +/// ``` +struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *yieldOperand = + 
getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>); + if (!yieldOperand) + return failure(); + auto broadcastOp = + cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp()); + unsigned operandIdx = yieldOperand->getOperandNumber(); + + VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType()); + VectorType destType = + dyn_cast<VectorType>(broadcastOp.getResult().getType()); + + xegpu::DistributeLayoutAttr sourceLayout = + xegpu::getDistributeLayoutAttr(broadcastOp->getOpOperand(0)); + xegpu::DistributeLayoutAttr resultLayout = + xegpu::getDistributeLayoutAttr(broadcastOp.getResult()); + + FailureOr<VectorType> sourceDistType; + Type sourceElemOrDistType; + if (sourceType) { + + // Case 1 and 2: source is a vector type. + int64_t rankDiff = destType.getRank() - sourceType.getRank(); + if (rankDiff > 0) { + // Case 1: source is lower-rank than result. + bool isSliceOf = sourceLayout.isSliceOf(resultLayout); + if (!isSliceOf) + return rewriter.notifyMatchFailure( + warpOp, + "Broadcast input layout must be a slice of result layout."); + } + // case 2: source and result have same rank + if (rankDiff == 0) { + SetVector<int64_t> broadcastUnitDims = + broadcastOp.computeBroadcastedUnitDims(); + resultLayout = resultLayout.setUnitDimData(broadcastUnitDims); + bool isEqualTo = sourceLayout.isEqualTo(resultLayout); + if (!isEqualTo) + return rewriter.notifyMatchFailure( + warpOp, "For same-rank broadcast, source must be identical to " + "adjusted result layouts with unit dims."); + sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims); + } + + sourceDistType = + getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType); + if (failed(sourceDistType)) { + return rewriter.notifyMatchFailure( + warpOp, "Failed to distribute the source vector type."); + } + sourceElemOrDistType = sourceDistType.value(); + + } else { + // Case 3: source is a scalar type. + if (sourceLayout) { + return rewriter.notifyMatchFailure( + warpOp, "Broadcast from scalar must not have a layout attribute."); + } + sourceElemOrDistType = broadcastOp.getSourceType(); + } + FailureOr<VectorType> destDistType = + getDistVecTypeBasedOnLaneLayout(resultLayout, destType); + if (failed(destDistType)) { + return rewriter.notifyMatchFailure( + warpOp, "Failed to distribute the dest vector type."); + } + + SmallVector<size_t> newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType, + newRetIndices); + + Value distributedSource = newWarpOp.getResult(newRetIndices[0]); + + Value newBroadcast = distributedSource; + + if (sourceElemOrDistType != destDistType.value()) { + rewriter.setInsertionPointAfter(newWarpOp); + newBroadcast = + vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(), + destDistType.value(), distributedSource); + } + + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast); + return success(); + } +}; + /// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing /// `gpu.warp_execute_on_lane_0` region. struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern { @@ -1291,6 +1643,226 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern { } }; +// Distribute a `vector.extract_strided_slice` op feeding into yield op of an +// enclosing `gpu.warp_execute_on_lane_0` region. 
This pattern covers
+// advanced cases where the distributed dimension is partially extracted and
+// currently not supported by the generic vector distribution patterns.
+struct VectorExtractStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    auto extractOp =
+        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+    unsigned operandIdx = operand->getOperandNumber();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Find the distributed dimensions.
+    auto extractResultType = cast<VectorType>(operand->get().getType());
+    auto distributedDims =
+        getDistributedDims(extractResultType, distributedType);
+    // Collect updated source type, sizes and offsets. They may be adjusted
+    // later if the data is distributed to lanes (as opposed to being owned by
+    // all lanes uniformly).
+    VectorType updatedSourceType = extractOp.getSourceVectorType();
+    SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        extractOp.getOffsets(), [](Attribute attr) { return attr; });
+    // If the result is distributed, it must be distributed in exactly one
+    // dimension. In this case, we adjust the sourceDistType, distributedSizes
+    // and distributedOffsets accordingly.
+    if (distributedDims.size() > 0) {
+      if (distributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Source can not be distributed in multiple dimensions.");
+      int64_t distributedDim = distributedDims[0];
+      int sourceDistrDimSize =
+          extractOp.getSourceVectorType().getShape()[distributedDim];
+      auto sourceLayout =
+          xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            warpOp, "the source of extract_strided_slice op lacks distribution "
+                    "layout");
+      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+      // Because only single dimension distribution is supported, lane layout
+      // size at the distributed dim must be the subgroup size.
+      int subgroupSize = sourceLaneLayout[distributedDim];
+      // Check if the source size in the distributed dimension is a multiple of
+      // subgroup size.
+      if (sourceDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Source size along distributed dimension is not a multiple of "
+            "subgroup size.");
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      // We expect lane data to be all ones in this case.
+      if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Expecting unit lane data in source layout");
+      // The offsets in the distributed dimension must be a multiple of subgroup
+      // size.
+ int64_t distrDimOffset = + cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt(); + if (distrDimOffset % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, "Offset along distributed dimension " + "is not a multiple of subgroup size."); + updatedSourceType = getDistVecTypeBasedOnLaneLayout( + sourceLayout, extractOp.getSourceVectorType()) + .value(); + // Update the distributed sizes to match the distributed type. + updatedSizes[distributedDim] = rewriter.getI64IntegerAttr( + distributedType.getDimSize(distributedDim)); + // Update the distributed offsets to match round robin distribution (i.e. + // each lane owns data at `subgroupSize` stride given unit lane data). + updatedOffsets[distributedDim] = + rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize); + } + // Do the distribution by yielding the source of the extract op from + // the warp op and creating a new extract op outside the warp op. + SmallVector<size_t> newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType}, + newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + Value source = newWarpOp.getResult(newRetIndices[0]); + // Create a new extract op outside the warp op. + Value newExtractOp = vector::ExtractStridedSliceOp::create( + rewriter, extractOp.getLoc(), distributedType, source, + ArrayAttr::get(rewriter.getContext(), updatedOffsets), + ArrayAttr::get(rewriter.getContext(), updatedSizes), + extractOp.getStrides()); + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp); + return success(); + } +}; + +/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an +/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers +/// advanced cases where the distributed dimension is partially inserted and +/// currently not supported by the generic vector distribution patterns. +struct VectorInsertStridedSliceDistribution + : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = + getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>); + if (!operand) + return failure(); + unsigned int operandNumber = operand->getOperandNumber(); + auto insertOp = + operand->get().getDefiningOp<vector::InsertStridedSliceOp>(); + auto distributedType = + cast<VectorType>(warpOp.getResult(operandNumber).getType()); + // Find the distributed dimensions of the dest vector. + auto insertResultType = cast<VectorType>(operand->get().getType()); + auto destDistributedDims = + getDistributedDims(insertResultType, distributedType); + // Collect updated offsets, source type and dest type. They may be adjusted + // later if the data is distributed to lanes (as opposed to being owned by + // all lanes uniformly). + SmallVector<Attribute> updatedOffsets = llvm::map_to_vector( + insertOp.getOffsets(), [](Attribute attr) { return attr; }); + VectorType updatedSourceType = insertOp.getSourceVectorType(); + VectorType updatedDestType = insertOp.getDestVectorType(); + if (destDistributedDims.size() > 0) { + // Only single dimension distribution is supported. 
+ if (destDistributedDims.size() != 1) + return rewriter.notifyMatchFailure( + warpOp, + "Expecting source to be distributed in a single dimension."); + int64_t destDistributedDim = destDistributedDims[0]; + + VectorType srcType = insertOp.getSourceVectorType(); + VectorType destType = insertOp.getDestVectorType(); + // Currently we require that both source (kD) and dest (nD) vectors are + // distributed. This requires that distributedDim (d) is contained in the + // last k dims of the dest vector (d >= n - k). + int64_t sourceDistributedDim = + destDistributedDim - (destType.getRank() - srcType.getRank()); + if (sourceDistributedDim < 0) + return rewriter.notifyMatchFailure( + insertOp, + "distributed dimension must be in the last k (i.e. source " + "rank) dims of dest vector"); + int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim); + // Obtain the source and dest layouts. + auto destLayout = + xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1)); + auto sourceLayout = + xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0)); + if (!destLayout || !sourceLayout || + destLayout.getEffectiveLaneLayoutAsInt().empty() || + sourceLayout.getEffectiveLaneLayoutAsInt().empty()) + return rewriter.notifyMatchFailure( + warpOp, "the source or dest of insert_strided_slice op lacks " + "distribution layout"); + // Because only single dimension distribution is supported, lane layout + // size at the distributed dim must be the subgroup size. + int subgroupSize = + destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim]; + // We require that source and dest lane data are all ones to ensure + // uniform round robin distribution. + auto destLaneData = destLayout.getEffectiveLaneDataAsInt(); + auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt(); + if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) || + !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; })) + return rewriter.notifyMatchFailure( + warpOp, "Expecting unit lane data in source and dest layouts"); + // Source distributed dim size must be multiples of subgroup size. + if (srcDistrDimSize % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, "Distributed dimension size in source is not a multiple of " + "subgroup size."); + // Offsets in the distributed dimension must be multiples of subgroup + // size. + int64_t destDistrDimOffset = + cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt(); + if (destDistrDimOffset % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, + "Offset along distributed dimension in dest is not a multiple of " + "subgroup size."); + // Update the source and dest types based on their layouts. + updatedSourceType = getDistVecTypeBasedOnLaneLayout( + sourceLayout, insertOp.getSourceVectorType()) + .value(); + updatedDestType = getDistVecTypeBasedOnLaneLayout( + destLayout, insertOp.getDestVectorType()) + .value(); + // Update the distributed offsets to match round robin distribution (i.e. + // each lane owns data at `subgroupSize` stride given unit lane data). + updatedOffsets[destDistributedDim] = + rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize); + } + // Do the distribution by yielding the source and dest of the insert op + // from the warp op and creating a new insert op outside the warp op. 
+ SmallVector<size_t> newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()}, + {updatedSourceType, updatedDestType}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + + Value valueToStore = newWarpOp.getResult(newRetIndices[0]); + Value dest = newWarpOp.getResult(newRetIndices[1]); + // Create a new insert op outside the warp op. + Value newInsertOp = vector::InsertStridedSliceOp::create( + rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest, + ArrayAttr::get(rewriter.getContext(), updatedOffsets), + insertOp.getStrides()); + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), + newInsertOp); + return success(); + } +}; + /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op /// outside of the warp op. @@ -1443,13 +2015,18 @@ void xegpu::populateXeGPUSubgroupDistributePatterns( LoadNdDistribution, DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution, VectorMultiReductionDistribution, LoadDistribution, StoreDistribution, VectorTransposeDistribution, - VectorBitcastDistribution, + VectorBitcastDistribution, LoadMatrixDistribution, + StoreMatrixDistribution, MemrefExtractAlignedPointerAsIndexDistribution>( patterns.getContext(), /*pattern benefit=*/regularPatternBenefit); - patterns.add<VectorShapeCastDistribution>( - patterns.getContext(), - /*pattern benefit=*/highPatternBenefit); + // For following patterns, we need to override the regular vector distribution + // patterns. Therefore, assign higher benefit. + patterns + .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution, + VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>( + patterns.getContext(), + /*pattern benefit=*/highPatternBenefit); } void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns( @@ -1468,6 +2045,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() { // Layouts are needed for vector type only. 
if (!isa<VectorType>(operand.get().getType())) continue; + if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op)) + continue; auto layout = xegpu::getDistributeLayoutAttr(operand.get()); if (!layout) { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index e6e71cc..af63f09 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -238,6 +238,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -255,7 +258,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> { auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value { xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); // return dummy Value to satisfy function's signature return nullptr; }; @@ -282,6 +285,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -306,7 +312,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> { return xegpu::LoadNdOp::create( rewriter, loc, newValueTy, convertedTdescs[0], offsets, op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + op.getL2HintAttr(), op.getL3HintAttr(), layout); }; newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape, createLoad, loc, rewriter); @@ -331,6 +337,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -354,7 +363,7 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> { xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++], convertedTdescs[0], offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); // return dummy Value to satisfy function's signature return nullptr; }; @@ -678,12 +687,16 @@ struct UnrollLoadGatherOpWithOffset pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter); } + auto layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); + SmallVector<Value> newOps; for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) { auto newOp = xegpu::LoadGatherOp::create( rewriter, loc, newValueTy, op.getSource(), o, m, rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + op.getL2HintAttr(), op.getL3HintAttr(), layout); newOps.push_back(newOp); } @@ -774,12 +787,16 @@ struct UnrollStoreScatterOpWithOffsets SmallVector<Value> convertedValues = pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter); + auto layout = op.getLayoutAttr(); + if (layout) + layout = 
layout.dropInstData(); + for (auto [v, o, m] : llvm::zip(convertedValues, convertedOffsets, convertedMasks)) { xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m, rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); } rewriter.eraseOp(op); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9fc5ad9..be82cda 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -86,8 +86,16 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, if (origOffsets.empty()) return failure(); + // if op is xegpu::CreateNdDescOp, call op.getDescLayoutAttr() + xegpu::DistributeLayoutAttr layout; + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> || + std::is_same_v<OpType, xegpu::StoreMatrixOp>) { + layout = op.getLayoutAttr(); + } else { + layout = op.getDescLayoutAttr(); + } + // not applicable to ops without workgroup layout attributes - xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -114,7 +122,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory // descriptors to be accessed, based on the layout information. ArrayRef<int64_t> wgShape = op.getDataShape(); - auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + auto maybeDescOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(maybeDescOffsets)) return failure(); @@ -189,7 +198,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { xegpu::TensorDescType tdescTy = op.getType(); ArrayRef<int64_t> wgShape = tdescTy.getShape(); Type elemTy = tdescTy.getElementType(); - xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr(); SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; auto newTdescTy = xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), @@ -308,6 +317,9 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); SmallVector<Value> newOps; for (auto [tdesc, offsets] : llvm::zip(adaptor.getTensorDesc(), offsetsList)) { @@ -317,7 +329,7 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { auto newOp = xegpu::LoadNdOp::create( rewriter, op.getLoc(), newResTy, tdesc, offsets, /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + op.getL2HintAttr(), op.getL3HintAttr(), layout); newOps.push_back(newOp); } rewriter.replaceOpWithMultiple(op, {newOps}); @@ -338,11 +350,14 @@ struct WgToSgStoreNdOpWithOffset if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); for (auto [v, tdesc, offsets] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) { xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + 
op.getL3HintAttr(), layout); } rewriter.eraseOp(op); @@ -362,11 +377,14 @@ struct WgToSgPrefetchNdOpWithOffset if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); for (auto [tdesc, offsets] : llvm::zip(adaptor.getTensorDesc(), offsetsList)) { xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); } rewriter.eraseOp(op); @@ -488,10 +506,8 @@ struct WgToSgVectorBroadcastOp for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), newResultType, operand); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), + layout.dropSgLayoutAndData()); newBroadcastOps.push_back(newBroadcast.getResult()); } @@ -737,12 +753,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { Location loc = op.getLoc(); auto eltType = vecType.getElementType(); - auto setLayoutIfNeeded = [&](Value val) { - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val), - layout.dropSgLayoutAndData()); - } + auto setLayout = [&](Value val) { + xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val), + layout.dropSgLayoutAndData()); }; if (vecAttr.isSplat()) { @@ -750,14 +763,14 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { Attribute singleVal = vecAttr.getSplatValue<Attribute>(); auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); - setLayoutIfNeeded(cstOp->getResult(0)); + setLayout(cstOp->getResult(0)); rewriter.replaceOp(op, cstOp); return success(); } else if (sgShape == wgShape) { // if the entire vector is shared by all // subgroups, don't distribute auto newConstOp = arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); - setLayoutIfNeeded(newConstOp->getResult(0)); + setLayout(newConstOp->getResult(0)); rewriter.replaceOp(op, newConstOp); return success(); } else { @@ -830,8 +843,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { // Get subgroup id Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - - auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + auto sgOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(sgOffsets)) return failure(); @@ -859,9 +872,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { rewriter, loc, baseConstVec.getType(), mulOffset); auto finalConst = arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); - setLayoutIfNeeded(baseConstVec); - setLayoutIfNeeded(bcastOffset); - setLayoutIfNeeded(finalConst); + setLayout(baseConstVec); + setLayout(bcastOffset); + setLayout(finalConst); newConstOps.push_back(finalConst); } rewriter.replaceOpWithMultiple(op, {newConstOps}); @@ -912,11 +925,12 @@ struct WgToSgLoadGatherOpWithOffset VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); for (auto [offsets, mask] : llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { + 
auto newLayout = layout.dropSgLayoutAndData(); auto newLoadOp = xegpu::LoadGatherOp::create( rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, - op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); - xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), - layout.dropSgLayoutAndData()); + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), + newLayout); + xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout); newLoadOps.push_back(newLoadOp); } rewriter.replaceOpWithMultiple(op, {newLoadOps}); @@ -964,16 +978,14 @@ struct WgToSgStoreScatterOpWithOffset adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { auto store = xegpu::StoreScatterOp::create( rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr, - op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), + layout.dropSgLayoutAndData()); // Update the layout attribute to drop sg_layout and sg_data. - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - for (OpOperand &operand : store->getOpOperands()) { - // Skip for operand one (memref) - if (operand.getOperandNumber() == 1) - continue; - xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); - } + for (OpOperand &operand : store->getOpOperands()) { + // Skip for operand one (memref) + if (operand.getOperandNumber() == 1) + continue; + xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); } } rewriter.eraseOp(op); @@ -1052,7 +1064,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> { Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + auto sgOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(sgOffsets)) return failure(); @@ -1065,15 +1078,12 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> { vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]); auto finalSteps = arith::AddIOp::create(rewriter, loc, steps, bcastOffset); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - xegpu::setDistributeLayoutAttr(steps->getResult(0), - layout.dropSgLayoutAndData()); - xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0), - layout.dropSgLayoutAndData()); - xegpu::setDistributeLayoutAttr(finalSteps->getResult(0), - layout.dropSgLayoutAndData()); - } + xegpu::setDistributeLayoutAttr(steps->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(finalSteps->getResult(0), + layout.dropSgLayoutAndData()); newOps.push_back(finalSteps); } @@ -1141,10 +1151,8 @@ struct WgToSgVectorShapeCastOp for (auto src : adaptor.getSource()) { auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), newResultType, src); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), + layout.dropSgLayoutAndData()); newShapeCastOps.push_back(newShapeCast.getResult()); } @@ -1205,10 +1213,8 @@ struct WgToSgMultiDimReductionOp auto newOp = vector::MultiDimReductionOp::create( rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, 
adaptor.getAcc()[0], op.getReductionDims()); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newOp->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newOp->getResult(0), + layout.dropSgLayoutAndData()); newReductions.push_back(newOp.getResult()); } @@ -1217,6 +1223,142 @@ struct WgToSgMultiDimReductionOp } }; +// This pattern transforms vector.transpose ops to work at subgroup level. +struct WgToSgVectorTransposeOp + : public OpConversionPattern<vector::TransposeOp> { + using OpConversionPattern<vector::TransposeOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType resultType = op.getResultVectorType(); + + ArrayRef<int64_t> wgShape = resultType.getShape(); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + xegpu::DistributeLayoutAttr sourceLayout = + xegpu::getDistributeLayoutAttr(op.getVector()); + if (!sourceLayout || !sourceLayout.isForWorkgroup()) + return failure(); + + SmallVector<int64_t> sourceSgLayout = + sourceLayout.getEffectiveSgLayoutAsInt(); + SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt(); + DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder(); + DenseI32ArrayAttr resultOrder = layout.getOrder(); + + if (!sourceOrder || !resultOrder) { + return rewriter.notifyMatchFailure( + op, "Both source and result must have order attributes"); + } + + ArrayRef<int64_t> permutation = op.getPermutation(); + size_t permutationSize = permutation.size(); + if (sourceSgLayout.size() != permutationSize || + resultSgLayout.size() != permutationSize) { + return rewriter.notifyMatchFailure( + op, "Layouts and permutation must have the same rank"); + } + + // Check that sgLayout, sgData & order are properly transposed for source + // and result + if (!layout.isTransposeOf(sourceLayout, permutation)) + return rewriter.notifyMatchFailure( + op, "Result layout is not a valid transpose of source layout " + "according to permutation"); + + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType newResultType = + VectorType::get(sgShape, resultType.getElementType()); + SmallVector<Value> newTransposeOps; + for (auto src : adaptor.getVector()) { + auto newTranspose = vector::TransposeOp::create( + rewriter, op.getLoc(), newResultType, src, permutation); + xegpu::setDistributeLayoutAttr(newTranspose->getResult(0), + layout.dropSgLayoutAndData()); + newTransposeOps.push_back(newTranspose.getResult()); + } + + rewriter.replaceOpWithMultiple(op, {newTransposeOps}); + return success(); + } +}; + +// Distribute vector mask ops to work at subgroup level. 
+template <typename MaskOpType> +struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> { + using OpConversionPattern<MaskOpType>::OpConversionPattern; + + LogicalResult matchAndRewrite( + MaskOpType op, + typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + Location loc = op.getLoc(); + VectorType type = op.getResult().getType(); + auto wgShape = type.getShape(); + + SmallVector<Value> wgMaskDimSizes; + if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) { + for (int64_t maskSize : op.getMaskDimSizes()) { + wgMaskDimSizes.push_back( + arith::ConstantIndexOp::create(rewriter, loc, maskSize)); + } + } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) { + wgMaskDimSizes = llvm::to_vector(op.getOperands()); + } + + Value sgId = + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); + auto sgOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); + if (failed(sgOffsets)) + return failure(); + + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType resultType = VectorType::get(sgShape, type.getElementType()); + + // In each dimension, each subgroup computes its local mask size as: + // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d]) + SmallVector<Value> newCreateMaskOps; + for (auto offsetSet : *sgOffsets) { + SmallVector<Value> maskOperands; + + for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) { + Value dimSizeVal = + arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); + Value offset = offsetSet[i]; + Value adjustedMaskSize = + arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value nonNegative = + arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero); + Value sgMaskSize = + arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal); + maskOperands.push_back(sgMaskSize); + } + + auto newCreateMaskOp = + vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands); + xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0), + layout.dropSgLayoutAndData()); + newCreateMaskOps.push_back(newCreateMaskOp.getResult()); + } + + rewriter.replaceOpWithMultiple(op, {newCreateMaskOps}); + return success(); + } +}; + +using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>; +using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>; } // namespace namespace mlir { @@ -1231,7 +1373,9 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp, - WgToSgMultiDimReductionOp>(patterns.getContext()); + WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp, + WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>( + patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -1358,7 +1502,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); - target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>( + target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp, + vector::TransposeOp, vector::BroadcastOp, + vector::MultiDimReductionOp, + vector::ConstantMaskOp, 
vector::CreateMaskOp>( [=](Operation *op) -> bool { // Check for either a SliceAttr or LayoutAttr on the result. auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0)); @@ -1377,16 +1524,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); - target.addDynamicallyLegalOp<vector::BroadcastOp>( - [=](vector::BroadcastOp op) -> bool { - return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); - }); - - target.addDynamicallyLegalOp<vector::MultiDimReductionOp>( - [=](vector::MultiDimReductionOp op) -> bool { - return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); - }); - target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>( [=](xegpu::ConvertLayoutOp op) -> bool { return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout()); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index a38993e..9f126fe 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" @@ -140,10 +139,14 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) { // for StoreMatrixOp, the layout is attached to the property of the op if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp)) return storeOp.getLayoutAttr(); - std::string layoutName = getLayoutName(result); if (defOp->hasAttr(layoutName)) return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); + + // check for "permament" layout only after "temporary" layout name lookup + // for backward compatibility + if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp)) + return loadGatherOp.getLayoutAttr(); } if (auto arg = dyn_cast<BlockArgument>(value)) { @@ -171,27 +174,77 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) { std::string layoutName = xegpu::getLayoutName(opr); if (op->hasAttr(layoutName)) return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); + + // check for "permament" layout only after "temporary" layout name lookup + if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op)) + if (auto layout = storeScatterOp.getLayoutAttr()) + return layout; + return getDistributeLayoutAttr(opr.get()); } +// Returns the permanent layout attribute for the given result if it's +// available on the defining op. Otherwise returns the provided layout. +xegpu::DistributeLayoutAttr +maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout, + const OpResult &result, mlir::Operation *owner, + const std::string &name) { + xegpu::DistributeLayoutAttr candidate = layout; + + if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) { + if (auto perm = loadOp.getLayoutAttr()) + candidate = perm; + } + + return candidate; +} + +// Returns the permanent layout attribute for the given operand if it's +// available on the defining op. Otherwise returns the provided layout. 
+xegpu::DistributeLayoutAttr +maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout, + const OpOperand &operand, mlir::Operation *owner, + const std::string &name) { + xegpu::DistributeLayoutAttr candidate = layout; + unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber(); + + if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) { + if (idx == 0) { + if (auto perm = storeOp.getLayoutAttr()) + candidate = perm; + } + } + + return candidate; +} + template <typename T, typename> void xegpu::setDistributeLayoutAttr(const T &operandOrResult, - const DistributeLayoutAttr layout) { + const DistributeLayoutAttr layout, + bool respectPermLayout) { Operation *owner = operandOrResult.getOwner(); std::string name = xegpu::getLayoutName(operandOrResult); - if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name)) - owner->setAttr(name, layout); + + if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) + return; + + DistributeLayoutAttr candidate = layout; + if (respectPermLayout) + candidate = maybePickPermanentLayout(layout, operandOrResult, owner, name); + + if (candidate) + owner->setAttr(name, candidate); } // Explicit instantiation for OpResult template void xegpu::setDistributeLayoutAttr<mlir::OpResult>( const mlir::OpResult &result, - const mlir::xegpu::DistributeLayoutAttr layout); + const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout); // Explicit instantiation for OpOperand template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>( const mlir::OpOperand &operand, - const mlir::xegpu::DistributeLayoutAttr layout); + const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout); void xegpu::setDistributeLayoutAttrs( Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) { @@ -253,7 +306,7 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, int64_t rankDiff = srcShapeRank - targetShapeRank; std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff, 1); - std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff); + llvm::copy(shape, adjustedTargetShape.begin() + rankDiff); SmallVector<Value> result; for (SmallVector<int64_t> offsets : @@ -473,7 +526,7 @@ SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder, for (auto [l, r] : llvm::zip_equal(lhs, rhs)) { auto lval = getValueOrCreateConstantIndexOp(builder, loc, l); auto rval = getValueOrCreateConstantIndexOp(builder, loc, r); - results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval)); + results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval)); } return results; } @@ -500,3 +553,29 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc, results.append(addElementwise(builder, loc, a, b)); return results; } + +template <typename T> +int xegpu::getLargestDivisor(T dim, ArrayRef<T> candidates, + ArrayRef<T> candidateMultiples) { + static_assert(std::is_integral<T>::value, "T must be an integer type"); + int largest = -1; + SmallVector<T> multiples = {1}; + if (!candidateMultiples.empty()) + multiples = + SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end()); + for (T candidate : candidates) { + for (T multiple : multiples) { + int value = static_cast<int>(candidate * multiple); + if (value != 0 && dim % value == 0 && value > largest) + largest = value; + } + } + return largest; +} + +/// Explicit instantiations +template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates, + ArrayRef<int> candidateMultiples); +template int 
+xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
+                                   ArrayRef<unsigned> candidateMultiples);
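The new xegpu::getLargestDivisor helper above returns the largest candidate * multiple product that evenly divides dim, or -1 when no such product exists. The standalone C++ sketch below mirrors that selection rule so it can be checked outside of MLIR; the sketch's function name and the input values are hypothetical and not part of the patch.

#include <cstdio>
#include <vector>

// Illustrative re-implementation of the selection rule shown in the diff
// above: return the largest candidate * multiple that evenly divides `dim`,
// or -1 if none does.
static int largestDivisorSketch(int dim, const std::vector<int> &candidates,
                                const std::vector<int> &candidateMultiples) {
  int largest = -1;
  std::vector<int> multiples =
      candidateMultiples.empty() ? std::vector<int>{1} : candidateMultiples;
  for (int candidate : candidates) {
    for (int multiple : multiples) {
      int value = candidate * multiple;
      if (value != 0 && dim % value == 0 && value > largest)
        largest = value;
    }
  }
  return largest;
}

int main() {
  // Hypothetical tile sizes: for dim = 64, candidates {8, 16, 32} and
  // multiples {1, 2}, the largest product dividing 64 is 32 * 2 = 64.
  std::printf("%d\n", largestDivisorSketch(64, {8, 16, 32}, {1, 2}));
  // No candidate * multiple divides 24 here, so the result is -1.
  std::printf("%d\n", largestDivisorSketch(24, {16, 32}, {1}));
  return 0;
}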