Diffstat (limited to 'mlir/lib/Dialect')
9 files changed, 84 insertions, 119 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index d5c7190..f405d0c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -40,6 +41,15 @@ using namespace mlir::amdgpu;
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
 
+namespace {
+struct AMDGPUInlinerInterface final : DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
+    return true;
+  }
+};
+} // namespace
+
 void AMDGPUDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
@@ -49,6 +59,7 @@ void AMDGPUDialect::initialize() {
 #define GET_ATTRDEF_LIST
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
       >();
+  addInterfaces<AMDGPUInlinerInterface>();
 }
 
 //===----------------------------------------------------------------------===//
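Note: registering a DialectInlinerInterface is what makes function bodies containing amdgpu ops eligible for the generic inliner, and with isLegalToInline returning true unconditionally, every amdgpu op is treated as inlinable. A minimal sketch of the effect (hypothetical IR; amdgpu.lds_barrier is chosen only because it takes no operands):

```mlir
func.func @callee() {
  amdgpu.lds_barrier
  return
}
func.func @caller() {
  // Before this change, -inline left this call in place because the
  // amdgpu dialect provided no inliner interface.
  call @callee() : () -> ()
  return
}
// With the interface registered, -inline replaces the call with the
// callee body, so amdgpu.lds_barrier executes directly in @caller.
```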
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index c64e10f5..d018cdd 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
       vector::OuterProductOp, vector::ScanOp>(
       [&](Operation *op) { return converter.isLegal(op); });
   target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
-                    arith::ConstantOp, arith::SelectOp, vector::SplatOp,
-                    vector::BroadcastOp>();
+                    arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
 }
 
 void EmulateUnsupportedFloatsPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index a50ddbe..624519f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -55,16 +55,6 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
   return returnOp;
 }
 
-/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(CallOpInterface callOp) {
-  SymbolRefAttr sym =
-      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
-  if (!sym)
-    return nullptr;
-  return dyn_cast_or_null<func::FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
-}
-
 LogicalResult
 mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
   IRRewriter rewriter(module.getContext());
@@ -72,7 +62,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
   DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
   // Collect the mapping of functions to their call sites.
   module.walk([&](func::CallOp callOp) {
-    if (func::FuncOp calledFunc = getCalledFunction(callOp)) {
+    if (func::FuncOp calledFunc =
+            dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
       callerMap[calledFunc].insert(callOp);
     }
   });
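Note: CallOpInterface::resolveCallable() performs the same nearest-symbol-table lookup as the deleted getCalledFunction helper, so the walk collects the same caller map. A small illustration of what gets recorded (hypothetical function names):

```mlir
module {
  func.func private @f(%arg0: memref<4xf32>) -> memref<4xf32>

  func.func @g(%m: memref<4xf32>) -> memref<4xf32> {
    // callOp.resolveCallable() returns @f here, so callerMap[@f]
    // ends up containing this call site.
    %r = call @f(%m) : (memref<4xf32>) -> memref<4xf32>
    return %r : memref<4xf32>
  }
}
```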
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7f419a0..5edcc40b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1593,6 +1593,39 @@ mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
+mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args: dst, mbar, src, size.
+  args.push_back(mt.lookupValue(thisOp.getDstMem()));
+  args.push_back(mt.lookupValue(thisOp.getMbar()));
+  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+  args.push_back(mt.lookupValue(thisOp.getSize()));
+
+  // Multicast mask, if available.
+  mlir::Value multicastMask = thisOp.getMulticastMask();
+  const bool hasMulticastMask = static_cast<bool>(multicastMask);
+  llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+  args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+
+  // Cache hint, if available.
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+
+  // Flag arguments for multicast and cachehint.
+  args.push_back(builder.getInt1(hasMulticastMask));
+  args.push_back(builder.getInt1(hasCacheHint));
+
+  llvm::Intrinsic::ID id =
+      llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+
+  return {id, std::move(args)};
+}
+
 mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 14e235f..a7e3ba8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1665,10 +1665,10 @@ static bool hasZeroDimVectors(Operation *op) {
          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
 }
 
-/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
-/// 1s, are considered to be 'broadcastlike'.
+/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
+/// considered to be 'broadcastlike'.
 static bool isBroadcastLike(Operation *op) {
-  if (isa<BroadcastOp, SplatOp>(op))
+  if (isa<BroadcastOp>(op))
     return true;
 
   auto shapeCast = dyn_cast<ShapeCastOp>(op);
@@ -3249,12 +3249,11 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
 };
 
 /// Consider the defining operation `defOp` of `value`. If `defOp` is a
-/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
-/// value that is splatted. Otherwise return null.
+/// vector.broadcast with a scalar operand, return the scalar value that is
+/// splatted. Otherwise return null.
 ///
-/// Examples:
+/// Example:
 ///
-/// scalar_source --> vector.splat --> value - return scalar_source
 /// scalar_source --> vector.broadcast --> value - return scalar_source
 static Value getScalarSplatSource(Value value) {
   // Block argument:
@@ -3262,10 +3261,6 @@
   if (!defOp)
     return {};
 
-  // Splat:
-  if (auto splat = dyn_cast<vector::SplatOp>(defOp))
-    return splat.getInput();
-
   auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
 
   // Not broadcast (and not splat):
@@ -7511,41 +7506,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
       patterns.getContext(), benefit);
 }
 
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
-  auto constOperand = adaptor.getInput();
-  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
-    return {};
-
-  // SplatElementsAttr::get treats single value for second arg as being a splat.
-  return SplatElementsAttr::get(getType(), {constOperand});
-}
-
-// Canonicalizer for vector.splat. It always gets canonicalized to a
-// vector.broadcast.
-class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
-public:
-  using Base::Base;
-  LogicalResult matchAndRewrite(SplatOp splatOp,
-                                PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
-                                                     splatOp.getOperand());
-    return success();
-  }
-};
-void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                          MLIRContext *context) {
-  results.add<SplatToBroadcastPattern>(context);
-}
-
-void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                SetIntRangeFn setResultRanges) {
-  setResultRanges(getResult(), argRanges.front());
-}
-
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                        CombiningKind kind, Value v1, Value acc,
                                        arith::FastMathFlagsAttr fastmath,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 255f2bf..3a3231d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -90,7 +90,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
-  // TODO: add support to `vector.splat`.
+  // TODO: add support to `vector.broadcast`.
   // Finding the mask creation operation.
   while (maskOp &&
          !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
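Note: the vector.splat removals above all follow from vector.splat being subsumed by vector.broadcast, which already covers the scalar-to-vector case. The rewrite is one-to-one (sketch):

```mlir
%s = arith.constant 1.0 : f32
// Formerly: %v = vector.splat %s : vector<4x4xf32>
%v = vector.broadcast %s : f32 to vector<4x4xf32>
```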
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 71fba71c..1b656d8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -590,32 +590,6 @@ struct LinearizeVectorBitCast final
   }
 };
 
-/// This pattern converts the SplatOp to work on a linearized vector.
-/// Following,
-///   vector.splat %value : vector<4x4xf32>
-/// is converted to:
-///   %out_1d = vector.splat %value : vector<16xf32>
-///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
-struct LinearizeVectorSplat final
-    : public OpConversionPattern<vector::SplatOp> {
-  using Base::Base;
-
-  LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
-                       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit) {}
-
-  LogicalResult
-  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
-    if (!dstTy)
-      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
-    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
-                                                 dstTy);
-    return success();
-  }
-};
-
 /// This pattern converts the CreateMaskOp to work on a linearized vector.
 /// It currently supports only 2D masks with a unit outer dimension.
 /// Following,
@@ -934,9 +908,9 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
     RewritePatternSet &patterns) {
   patterns
      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
-           LinearizeVectorStore, LinearizeVectorFromElements,
-           LinearizeVectorToElements>(typeConverter, patterns.getContext());
+           LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
+           LinearizeVectorFromElements, LinearizeVectorToElements>(
+          typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
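Note: with LinearizeVectorSplat gone, a scalar broadcast is presumably linearized by the remaining patterns into the same shape the deleted pattern documented, i.e. a 1-D broadcast followed by a shape_cast (sketch, adapted from the removed comment):

```mlir
%out_1d = vector.broadcast %value : f32 to vector<16xf32>
%out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
```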
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d6a6d7cd..726da1e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -878,7 +878,7 @@ struct BubbleUpBitCastForStridedSliceInsert
 // This transforms IR like:
 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
 // Into:
-//   %cst = vector.splat %c0_f32 : vector<4xf32>
+//   %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32>
 //   %1 = vector.extract_strided_slice %0 {
 //          offsets = [0], sizes = [4], strides = [1]
 //        } : vector<8xf16> to vector<4xf16>
@@ -987,8 +987,8 @@ static Type cloneOrReplace(Type type, Type newElementType) {
   return newElementType;
 }
 
-/// If `value` is the result of a splat or broadcast operation, return the input
-/// of the splat/broadcast operation.
+/// If `value` is the result of a broadcast operation, return the input
+/// of the broadcast operation.
 static Value getBroadcastLikeSource(Value value) {
   Operation *op = value.getDefiningOp();
 
@@ -998,13 +998,10 @@
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
     return broadcast.getSource();
 
-  if (auto splat = dyn_cast<vector::SplatOp>(op))
-    return splat.getInput();
-
   return {};
 }
 
-/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
 ///
 /// Example:
 /// ```
@@ -1017,9 +1014,6 @@
 ///   %r = arith.addi %arg0, %arg1 : index
 ///   %b = vector.broadcast %r : index to vector<1x4xindex>
 /// ```
-///
-/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
-/// ops.
 struct ReorderElementwiseOpsOnBroadcast final
     : public OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -1045,29 +1039,29 @@ struct ReorderElementwiseOpsOnBroadcast final
     Type resultElemType = resultType.getElementType();
 
     // Get the type of the first non-constant operand
-    Value splatSource;
+    Value broadcastSource;
     for (Value operand : op->getOperands()) {
       Operation *definingOp = operand.getDefiningOp();
       if (!definingOp)
         return failure();
       if (definingOp->hasTrait<OpTrait::ConstantLike>())
         continue;
-      splatSource = getBroadcastLikeSource(operand);
+      broadcastSource = getBroadcastLikeSource(operand);
       break;
     }
-    if (!splatSource)
+    if (!broadcastSource)
       return failure();
     Type unbroadcastResultType =
-        cloneOrReplace(splatSource.getType(), resultElemType);
+        cloneOrReplace(broadcastSource.getType(), resultElemType);
 
     // Make sure that all operands are broadcast from identically-shaped types:
-    //  * scalar (`vector.broadcast` + `vector.splat`), or
+    //  * scalar (`vector.broadcast`), or
     //  * vector (`vector.broadcast`).
     // Otherwise the re-ordering wouldn't be safe.
-    if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
+    if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
           if (auto source = getBroadcastLikeSource(val))
             return haveSameShapeAndScaling(source.getType(),
-                                           splatSource.getType());
+                                           broadcastSource.getType());
           SplatElementsAttr splatConst;
           return matchPattern(val, m_Constant(&splatConst));
         })) {
@@ -1271,19 +1265,18 @@ public:
   }
 };
 
-/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store.
 ///
 /// Example:
 /// ```
-///   %0 = vector.splat %arg2 : vector<1xf32>
+///   %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
 ///   vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
 /// ```
 /// Gets converted to:
 /// ```
 ///   memref.store %arg2, %arg0[%arg1] : memref<?xf32>
 /// ```
-class StoreOpFromSplatOrBroadcast final
-    : public OpRewritePattern<vector::StoreOp> {
+class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
 public:
   using Base::Base;
 
@@ -1308,9 +1301,9 @@
       return rewriter.notifyMatchFailure(
          op, "value to store is not from a broadcast");
 
-    // Checking for single use so we can remove splat.
-    Operation *splat = toStore.getDefiningOp();
-    if (!splat->hasOneUse())
+    // Checking for single use so we can remove broadcast.
+    Operation *broadcast = toStore.getDefiningOp();
+    if (!broadcast->hasOneUse())
       return rewriter.notifyMatchFailure(op, "expected single op use");
 
     Value base = op.getBase();
@@ -1321,7 +1314,7 @@
     } else {
       rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
     }
-    rewriter.eraseOp(splat);
+    rewriter.eraseOp(broadcast);
     return success();
   }
 };
@@ -2391,8 +2384,8 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
 void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit) {
   // TODO: Consider converting these patterns to canonicalizations.
-  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
-      patterns.getContext(), benefit);
+  patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
+                                                        benefit);
 }
 
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index f1dbc5d..26770b3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -195,8 +195,7 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
 ///   }
 ///   return %0
 /// }
-struct MoveFuncBodyToWarpExecuteOnLane0
-    : public OpRewritePattern<gpu::GPUFuncOp> {
+struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
   using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                 PatternRewriter &rewriter) const override {
@@ -1447,6 +1446,11 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
       /*pattern benefit=*/highPatternBenefit);
 }
 
+void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
+}
+
 void XeGPUSubgroupDistributePass::runOnOperation() {
   // Step 1: Attach layouts to op operands.
   // TODO: Following assumptions are made:
@@ -1473,7 +1477,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   // gpu.warp_execute_on_lane_0 operation.
   {
     RewritePatternSet patterns(&getContext());
-    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+    xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
 
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       signalPassFailure();
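Note: exposing MoveFuncBodyToWarpOp through populateXeGPUMoveFuncBodyToWarpOpPatterns lets other passes request the same rewrite without naming the pattern class. Roughly, the pattern wraps a gpu.func body in a gpu.warp_execute_on_lane_0 region (sketch; the subgroup size of 16 and the memref type are illustrative, and the authoritative before/after IR is in the pattern's doc comment above):

```mlir
gpu.func @kernel(%arg0: memref<256xf16>) kernel {
  %laneid = gpu.lane_id
  gpu.warp_execute_on_lane_0(%laneid)[16] {
    // ... original body of @kernel, now executed as a warp-level region ...
  }
  gpu.return
}
```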