Diffstat (limited to 'mlir/lib')
17 files changed, 430 insertions, 237 deletions
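The headline change in this patch is a C API (and Python binding) for defining rewrite patterns outside of C++. Below is a minimal sketch, not part of the patch, of how the new entry points from mlir/lib/CAPI/Transforms/Rewrite.cpp compose; the "test.dummy" root op and the erase-everything callback are purely illustrative, and the greedy-driver application is only referenced in a comment since its full signature is not shown in this diff.

```c++
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"

// Illustrative matchAndRewrite callback: unconditionally erase the matched op.
static MlirLogicalResult eraseMatched(MlirRewritePattern, MlirOperation op,
                                      MlirPatternRewriter rewriter,
                                      void * /*userData*/) {
  // Pattern-rewriter handles are downcast to the common rewriter base.
  MlirRewriterBase base = mlirPatternRewriterAsBase(rewriter);
  mlirRewriterBaseEraseOp(base, op);
  return mlirLogicalResultSuccess();
}

static void runSketch(MlirContext ctx) {
  MlirRewritePatternCallbacks callbacks = {};
  callbacks.matchAndRewrite = eraseMatched; // construct/destruct stay null

  // Spelled "Patten" in this patch's C API.
  MlirRewritePattern pattern = mlirOpRewritePattenCreate(
      mlirStringRefCreateFromCString("test.dummy"), /*benefit=*/1, ctx,
      callbacks, /*userData=*/nullptr,
      /*nGeneratedNames=*/0, /*generatedNames=*/nullptr);

  MlirRewritePatternSet set = mlirRewritePatternSetCreate(ctx);
  mlirRewritePatternSetAdd(set, pattern); // the set takes ownership

  // Freezing consumes the set, so only the frozen set needs destruction.
  MlirFrozenRewritePatternSet frozen = mlirFreezeRewritePattern(set);
  // The frozen set would then be applied with, e.g.,
  // mlirApplyPatternsAndFoldGreedilyWithOp (exact signature not shown here).
  mlirFrozenRewritePatternSetDestroy(frozen);
}
```

Ownership mirrors what the patch implements: mlirRewritePatternSetAdd wraps the pattern in a unique_ptr and nulls the caller's handle, and mlirFreezeRewritePattern moves out of the set, which is why the Python PyRewritePatternSet::freeze also nulls its own pointer.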
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 9e3d970..47685567 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -45,6 +45,16 @@ public: return PyInsertionPoint(PyOperation::forOperation(ctx, op)); } + void replaceOp(MlirOperation op, MlirOperation newOp) { + mlirRewriterBaseReplaceOpWithOperation(base, op, newOp); + } + + void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) { + mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data()); + } + + void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); } + private: MlirRewriterBase base; PyMlirContextRef ctx; @@ -165,13 +175,115 @@ private: MlirFrozenRewritePatternSet set; }; +class PyRewritePatternSet { +public: + PyRewritePatternSet(MlirContext ctx) + : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {} + ~PyRewritePatternSet() { + if (set.ptr) + mlirRewritePatternSetDestroy(set); + } + + void add(MlirStringRef rootName, unsigned benefit, + const nb::callable &matchAndRewrite) { + MlirRewritePatternCallbacks callbacks; + callbacks.construct = [](void *userData) { + nb::handle(static_cast<PyObject *>(userData)).inc_ref(); + }; + callbacks.destruct = [](void *userData) { + nb::handle(static_cast<PyObject *>(userData)).dec_ref(); + }; + callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op, + MlirPatternRewriter rewriter, + void *userData) -> MlirLogicalResult { + nb::handle f(static_cast<PyObject *>(userData)); + nb::object res = f(op, PyPatternRewriter(rewriter)); + return logicalResultFromObject(res); + }; + MlirRewritePattern pattern = mlirOpRewritePattenCreate( + rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(), + /* nGeneratedNames */ 0, + /* generatedNames */ nullptr); + mlirRewritePatternSetAdd(set, pattern); + } + + PyFrozenRewritePatternSet freeze() { + MlirRewritePatternSet s = set; + set.ptr = nullptr; + return mlirFreezeRewritePattern(s); + } + +private: + MlirRewritePatternSet set; + MlirContext ctx; +}; + } // namespace /// Create the `mlir.rewrite` here. 
void mlir::python::populateRewriteSubmodule(nb::module_ &m) { - nb::class_<PyPatternRewriter>(m, "PatternRewriter") - .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, - "The current insertion point of the PatternRewriter."); + //---------------------------------------------------------------------------- + // Mapping of the PatternRewriter + //---------------------------------------------------------------------------- + nb:: + class_<PyPatternRewriter>(m, "PatternRewriter") + .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, + "The current insertion point of the PatternRewriter.") + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + MlirOperation newOp) { self.replaceOp(op, newOp); }, + "Replace an operation with a new operation.", nb::arg("op"), + nb::arg("new_op"), + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ) + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + const std::vector<MlirValue> &values) { + self.replaceOp(op, values); + }, + "Replace an operation with a list of values.", nb::arg("op"), + nb::arg("values"), + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None") + // clang-format on + ) + .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", + nb::arg("op"), + // clang-format off + nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ); + + //---------------------------------------------------------------------------- + // Mapping of the RewritePatternSet + //---------------------------------------------------------------------------- + nb::class_<PyRewritePatternSet>(m, "RewritePatternSet") + .def( + "__init__", + [](PyRewritePatternSet &self, DefaultingPyMlirContext context) { + new (&self) PyRewritePatternSet(context.get()->get()); + }, + "context"_a = nb::none()) + .def( + "add", + [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn, + unsigned benefit) { + std::string opName = + nb::cast<std::string>(root.attr("OPERATION_NAME")); + self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit, + fn); + }, + "root"_a, "fn"_a, "benefit"_a = 1, + "Add a new rewrite pattern on the given root operation with the " + "callable as the matching and rewriting function and the given " + "benefit.") + .def("freeze", &PyRewritePatternSet::freeze, + "Freeze the pattern set into a frozen one."); + //---------------------------------------------------------------------------- // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- @@ -237,7 +349,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { .def( "freeze", [](PyPDLPatternModule &self) { - return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( + return PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }, nb::keep_alive<0, 1>()) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index c15a73b..46c329d 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -270,35 +270,16 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API 
//===----------------------------------------------------------------------===// -static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { - assert(module.ptr && "unexpected null module"); - return *(static_cast<mlir::RewritePatternSet *>(module.ptr)); -} - -static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { - return {module}; -} - -static inline mlir::FrozenRewritePatternSet * -unwrap(MlirFrozenRewritePatternSet module) { - assert(module.ptr && "unexpected null module"); - return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr); -} - -static inline MlirFrozenRewritePatternSet -wrap(mlir::FrozenRewritePatternSet *module) { - return {module}; -} - -MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) { - auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op))); - op.ptr = nullptr; +MlirFrozenRewritePatternSet +mlirFreezeRewritePattern(MlirRewritePatternSet set) { + auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set))); + set.ptr = nullptr; return wrap(m); } -void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) { - delete unwrap(op); - op.ptr = nullptr; +void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) { + delete unwrap(set); + set.ptr = nullptr; } MlirLogicalResult @@ -319,33 +300,86 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, /// PatternRewriter API //===----------------------------------------------------------------------===// -inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) { - assert(rewriter.ptr && "unexpected null rewriter"); - return static_cast<mlir::PatternRewriter *>(rewriter.ptr); +MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { + return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter))); } -inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { - return {rewriter}; -} +//===----------------------------------------------------------------------===// +/// RewritePattern API +//===----------------------------------------------------------------------===// -MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { - return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter))); +namespace mlir { + +class ExternalRewritePattern : public mlir::RewritePattern { +public: + ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData, + StringRef rootName, PatternBenefit benefit, + MLIRContext *context, + ArrayRef<StringRef> generatedNames) + : RewritePattern(rootName, benefit, context, generatedNames), + callbacks(callbacks), userData(userData) { + if (callbacks.construct) + callbacks.construct(userData); + } + + ~ExternalRewritePattern() { + if (callbacks.destruct) + callbacks.destruct(userData); + } + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + return unwrap(callbacks.matchAndRewrite( + wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op), + wrap(&rewriter), userData)); + } + +private: + MlirRewritePatternCallbacks callbacks; + void *userData; +}; + +} // namespace mlir + +MlirRewritePattern mlirOpRewritePattenCreate( + MlirStringRef rootName, unsigned benefit, MlirContext context, + MlirRewritePatternCallbacks callbacks, void *userData, + size_t nGeneratedNames, MlirStringRef *generatedNames) { + std::vector<mlir::StringRef> generatedNamesVec; + generatedNamesVec.reserve(nGeneratedNames); + for (size_t i = 0; i < nGeneratedNames; ++i) { + 
generatedNamesVec.push_back(unwrap(generatedNames[i])); } return wrap(new mlir::ExternalRewritePattern( callbacks, userData, unwrap(rootName), PatternBenefit(benefit), unwrap(context), generatedNamesVec)); } //===----------------------------------------------------------------------===// -/// PDLPatternModule API +/// RewritePatternSet API //===----------------------------------------------------------------------===// -#if MLIR_ENABLE_PDL_IN_PATTERNMATCH -static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { - assert(module.ptr && "unexpected null module"); - return static_cast<mlir::PDLPatternModule *>(module.ptr); +MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) { + return wrap(new mlir::RewritePatternSet(unwrap(context))); +} + +void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) { + delete unwrap(set); } -static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { - return {module}; +void mlirRewritePatternSetAdd(MlirRewritePatternSet set, + MlirRewritePattern pattern) { + std::unique_ptr<mlir::RewritePattern> patternPtr( + const_cast<mlir::RewritePattern *>(unwrap(pattern))); + pattern.ptr = nullptr; + unwrap(set)->add(std::move(patternPtr)); } +//===----------------------------------------------------------------------===// +/// PDLPatternModule API +//===----------------------------------------------------------------------===// + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { return wrap(new mlir::PDLPatternModule( mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op)))); @@ -363,22 +397,6 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { return wrap(m); } -inline const mlir::PDLValue *unwrap(MlirPDLValue value) { - assert(value.ptr && "unexpected null PDL value"); - return static_cast<const mlir::PDLValue *>(value.ptr); -} - -inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; } - -inline mlir::PDLResultList *unwrap(MlirPDLResultList results) { - assert(results.ptr && "unexpected null PDL results"); - return static_cast<mlir::PDLResultList *>(results.ptr); -} - -inline MlirPDLResultList wrap(mlir::PDLResultList *results) { - return {results}; -} - MlirValue mlirPDLValueAsValue(MlirPDLValue value) { return wrap(unwrap(value)->dyn_cast<mlir::Value>()); } diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index dcbaa56..247dba1 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) { current = op.getSource(); return false; }) - .Case<vector::SplatOp>([&current](auto op) { - current = op.getInput(); - return false; - }) .Default([](Operation *) { return false; }); if (!skipOp) { diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index bad53c0..1002ebe 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { /// AFTER: /// ```mlir /// ... -/// %pad_1d = vector.splat %pad : vector<[4]xi32> +/// %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32> /// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) { /// ...
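A recurring theme in the hunks above and in the Conversion/Vector diffs that follow is the retirement of vector.splat: dedicated SplatOp cases are dropped because scalar-to-vector vector.broadcast is now the canonical form. A hedged sketch of what downstream matching code migrates to, mirroring the getScalarSplatSource() helper kept in VectorOps.cpp; the helper name below is hypothetical:

```c++
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

/// Return the splatted scalar if `value` is produced by a scalar-to-vector
/// broadcast; return null otherwise.
static Value getSplattedScalar(Value value) {
  auto broadcast = value.getDefiningOp<vector::BroadcastOp>();
  if (!broadcast)
    return {};
  // Only a scalar source counts as a splat: vector-to-vector broadcasts do
  // not replicate a single element.
  if (isa<VectorType>(broadcast.getSource().getType()))
    return {};
  return broadcast.getSource();
}
```

The scalar-source check matters because vector.broadcast, unlike the removed vector.splat, also accepts vector operands.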
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index 363685a..778c616 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering } }; -// Convert all `vector.splat` to `vector.broadcast`. There is a path from -// `vector.broadcast` to ArmSME via another pattern. -struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> { - using Base::Base; - - LogicalResult matchAndRewrite(vector::SplatOp splatOp, - PatternRewriter &rewriter) const final { - - rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(), - splatOp.getInput()); - return success(); - } -}; - } // namespace void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { - patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast, - TransferReadToArmSMELowering, TransferWriteToArmSMELowering, - TransposeOpToArmSMELowering, VectorLoadToArmSMELowering, - VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering, + patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering, + TransferWriteToArmSMELowering, TransposeOpToArmSMELowering, + VectorLoadToArmSMELowering, VectorStoreToArmSMELowering, + VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering, VectorInsertToArmSMELowering, VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice, ExtractFromCreateMaskToPselLowering>(&ctx); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5461646..5355909 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -2161,19 +2161,6 @@ public: } }; -/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from -/// `vector.broadcast` through other patterns. 
-struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult - matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(), - adaptor.getInput()); - return success(); - } -}; - } // namespace void mlir::vector::populateVectorRankReducingFMAPattern( @@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorInsertOpConversion, VectorPrintOpConversion, VectorTypeCastOpConversion, VectorScaleOpConversion, VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, - VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering, + VectorBroadcastScalarToLowRankLowering, VectorBroadcastScalarToNdLowering, VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, MaskedReductionOpConversion, VectorInterleaveOpLowering, diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 311ff6f..56e8fee 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -22,7 +22,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" -#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" @@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> { } }; -// Convert `vector.splat` to `vector.broadcast`. There is a path from -// `vector.broadcast` to SPIRV via other patterns. -struct VectorSplatToBroadcast final - : public OpConversionPattern<vector::SplatOp> { - using Base::Base; - LogicalResult - matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(), - adaptor.getInput()); - return success(); - } -}; - struct VectorBitcastConvert final : public OpConversionPattern<vector::BitCastOp> { using Base::Base; @@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns( VectorReductionPattern<CL_INT_MAX_MIN_OPS>, VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, - VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert, - VectorShuffleOpConvert, VectorInterleaveOpConvert, - VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern, - VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>( + VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, + VectorInterleaveOpConvert, VectorDeinterleaveOpConvert, + VectorScalarBroadcastPattern, VectorLoadOpConverter, + VectorStoreOpConverter, VectorStepOpConvert>( typeConverter, patterns.getContext(), PatternBenefit(1)); // Make sure that the more specialized dot product pattern has higher benefit diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index f449d90..f276984 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -715,6 +715,135 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> { }; //===----------------------------------------------------------------------===// +// GPU index id operations 
+//===----------------------------------------------------------------------===// +/* +// Launch Config ops +// dimidx - x, y, z - is fixed to i32 +// return type is set by XeVM type converter +// get_local_id +xevm::WorkitemIdXOp; +xevm::WorkitemIdYOp; +xevm::WorkitemIdZOp; +// get_local_size +xevm::WorkgroupDimXOp; +xevm::WorkgroupDimYOp; +xevm::WorkgroupDimZOp; +// get_group_id +xevm::WorkgroupIdXOp; +xevm::WorkgroupIdYOp; +xevm::WorkgroupIdZOp; +// get_num_groups +xevm::GridDimXOp; +xevm::GridDimYOp; +xevm::GridDimZOp; +// get_global_id : to be added if needed +*/ + +// Helpers to get the OpenCL function name and dimension argument for each op. +static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) { + return {"get_local_id", 0}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) { + return {"get_local_id", 1}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) { + return {"get_local_id", 2}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) { + return {"get_local_size", 0}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) { + return {"get_local_size", 1}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) { + return {"get_local_size", 2}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) { + return {"get_group_id", 0}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) { + return {"get_group_id", 1}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) { + return {"get_group_id", 2}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) { + return {"get_num_groups", 0}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) { + return {"get_num_groups", 1}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) { + return {"get_num_groups", 2}; +} +/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with +/// a constant argument for the dimension - x, y or z. +template <typename OpType> +class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto [baseName, dim] = getConfig(op); + Type dimTy = rewriter.getI32Type(); + Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy, + static_cast<int64_t>(dim)); + std::string func = mangle(baseName, {dimTy}, {true}); + Type resTy = op.getType(); + auto call = + createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {}, + noUnwindWillReturnAttrs, op.getOperation()); + constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; + auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( + /*other=*/noModRef, + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + call.setMemoryEffectsAttr(memAttr); + rewriter.replaceOp(op, call); + return success(); + } +}; + +/* +// Subgroup ops +// get_sub_group_local_id +xevm::LaneIdOp; +// get_sub_group_id +xevm::SubgroupIdOp; +// get_sub_group_size +xevm::SubgroupSizeOp; +// get_num_sub_groups : to be added if needed +*/ + +// Helpers to get the OpenCL function name for each op. 
+static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; } +static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; } +static StringRef getConfig(xevm::SubgroupSizeOp) { + return "get_sub_group_size"; +} +template <typename OpType> +class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + std::string func = mangle(getConfig(op).str(), {}); + Type resTy = op.getType(); + auto call = + createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {}, + noUnwindWillReturnAttrs, op.getOperation()); + constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; + auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( + /*other=*/noModRef, + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + call.setMemoryEffectsAttr(memAttr); + rewriter.replaceOp(op, call); + return success(); + } +}; + +//===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -775,7 +904,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target, LLVMLoadStoreToOCLPattern<LLVM::LoadOp>, LLVMLoadStoreToOCLPattern<LLVM::StoreOp>, BlockLoadStore1DToOCLPattern<BlockLoadOp>, - BlockLoadStore1DToOCLPattern<BlockStoreOp>>( + BlockLoadStore1DToOCLPattern<BlockStoreOp>, + LaunchConfigOpToOCLPattern<WorkitemIdXOp>, + LaunchConfigOpToOCLPattern<WorkitemIdYOp>, + LaunchConfigOpToOCLPattern<WorkitemIdZOp>, + LaunchConfigOpToOCLPattern<WorkgroupDimXOp>, + LaunchConfigOpToOCLPattern<WorkgroupDimYOp>, + LaunchConfigOpToOCLPattern<WorkgroupDimZOp>, + LaunchConfigOpToOCLPattern<WorkgroupIdXOp>, + LaunchConfigOpToOCLPattern<WorkgroupIdYOp>, + LaunchConfigOpToOCLPattern<WorkgroupIdZOp>, + LaunchConfigOpToOCLPattern<GridDimXOp>, + LaunchConfigOpToOCLPattern<GridDimYOp>, + LaunchConfigOpToOCLPattern<GridDimZOp>, + SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>, + SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>, + SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>( patterns.getContext()); } diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index d5c7190..f405d0c 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -26,6 +26,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" @@ -40,6 +41,15 @@ using namespace mlir::amdgpu; #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc" +namespace { +struct AMDGPUInlinerInterface final : DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } +}; +} // namespace + void AMDGPUDialect::initialize() { addOperations< #define GET_OP_LIST @@ -49,6 +59,7 @@ void AMDGPUDialect::initialize() { #define GET_ATTRDEF_LIST #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" >(); + addInterfaces<AMDGPUInlinerInterface>(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp 
b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index c64e10f5..d018cdd 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality( vector::OuterProductOp, vector::ScanOp>( [&](Operation *op) { return converter.isLegal(op); }); target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp, - arith::ConstantOp, arith::SelectOp, vector::SplatOp, - vector::BroadcastOp>(); + arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>(); } void EmulateUnsupportedFloatsPass::runOnOperation() { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp index a50ddbe..624519f 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -55,16 +55,6 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { return returnOp; } -/// Return the func::FuncOp called by `callOp`. -static func::FuncOp getCalledFunction(CallOpInterface callOp) { - SymbolRefAttr sym = - llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); - if (!sym) - return nullptr; - return dyn_cast_or_null<func::FuncOp>( - SymbolTable::lookupNearestSymbolFrom(callOp, sym)); -} - LogicalResult mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { IRRewriter rewriter(module.getContext()); @@ -72,7 +62,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap; // Collect the mapping of functions to their call sites. module.walk([&](func::CallOp callOp) { - if (func::FuncOp calledFunc = getCalledFunction(callOp)) { + if (func::FuncOp calledFunc = + dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) { callerMap[calledFunc].insert(callOp); } }); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 7f419a0..5edcc40b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1593,6 +1593,39 @@ mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs( return {id, std::move(args)}; } +mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op); + llvm::SmallVector<llvm::Value *> args; + + // Fill the Intrinsic Args: dst, mbar, src, size. + args.push_back(mt.lookupValue(thisOp.getDstMem())); + args.push_back(mt.lookupValue(thisOp.getMbar())); + args.push_back(mt.lookupValue(thisOp.getSrcMem())); + args.push_back(mt.lookupValue(thisOp.getSize())); + + // Multicast mask, if available. + mlir::Value multicastMask = thisOp.getMulticastMask(); + const bool hasMulticastMask = static_cast<bool>(multicastMask); + llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); + args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused); + + // Cache hint, if available. + mlir::Value cacheHint = thisOp.getL2CacheHint(); + const bool hasCacheHint = static_cast<bool>(cacheHint); + llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0); + args.push_back(hasCacheHint ? 
mt.lookupValue(cacheHint) : i64Unused); + + // Flag arguments for multicast and cachehint. + args.push_back(builder.getInt1(hasMulticastMask)); + args.push_back(builder.getInt1(hasCacheHint)); + + llvm::Intrinsic::ID id = + llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; + + return {id, std::move(args)}; +} + mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 14e235f..a7e3ba8 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1665,10 +1665,10 @@ static bool hasZeroDimVectors(Operation *op) { llvm::any_of(op->getResultTypes(), hasZeroDimVectorType); } -/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend -/// 1s, are considered to be 'broadcastlike'. +/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are +/// considered to be 'broadcastlike'. static bool isBroadcastLike(Operation *op) { - if (isa<BroadcastOp, SplatOp>(op)) + if (isa<BroadcastOp>(op)) return true; auto shapeCast = dyn_cast<ShapeCastOp>(op); @@ -3249,12 +3249,11 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> { }; /// Consider the defining operation `defOp` of `value`. If `defOp` is a -/// vector.splat or a vector.broadcast with a scalar operand, return the scalar -/// value that is splatted. Otherwise return null. +/// vector.broadcast with a scalar operand, return the scalar value that is +/// splatted. Otherwise return null. /// -/// Examples: +/// Example: /// -/// scalar_source --> vector.splat --> value - return scalar_source /// scalar_source --> vector.broadcast --> value - return scalar_source static Value getScalarSplatSource(Value value) { // Block argument: @@ -3262,10 +3261,6 @@ static Value getScalarSplatSource(Value value) { if (!defOp) return {}; - // Splat: - if (auto splat = dyn_cast<vector::SplatOp>(defOp)) - return splat.getInput(); - auto broadcast = dyn_cast<vector::BroadcastOp>(defOp); // Not broadcast (and not splat): @@ -7511,41 +7506,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns( patterns.getContext(), benefit); } -//===----------------------------------------------------------------------===// -// SplatOp -//===----------------------------------------------------------------------===// - -OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { - auto constOperand = adaptor.getInput(); - if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand)) - return {}; - - // SplatElementsAttr::get treats single value for second arg as being a splat. - return SplatElementsAttr::get(getType(), {constOperand}); -} - -// Canonicalizer for vector.splat. It always gets canonicalized to a -// vector.broadcast. 
-class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> { -public: - using Base::Base; - LogicalResult matchAndRewrite(SplatOp splatOp, - PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(), - splatOp.getOperand()); - return success(); - } -}; -void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add<SplatToBroadcastPattern>(context); -} - -void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges.front()); -} - Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath, diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 255f2bf..3a3231d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -90,7 +90,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter, Operation *maskOp = mask.getDefiningOp(); SmallVector<vector::ExtractOp, 2> extractOps; - // TODO: add support to `vector.splat`. + // TODO: add support to `vector.broadcast`. // Finding the mask creation operation. while (maskOp && !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 71fba71c..1b656d8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -590,32 +590,6 @@ struct LinearizeVectorBitCast final } }; -/// This pattern converts the SplatOp to work on a linearized vector. -/// Following, -/// vector.splat %value : vector<4x4xf32> -/// is converted to: -/// %out_1d = vector.splat %value : vector<16xf32> -/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> -struct LinearizeVectorSplat final - : public OpConversionPattern<vector::SplatOp> { - using Base::Base; - - LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context, - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} - - LogicalResult - matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto dstTy = getTypeConverter()->convertType(splatOp.getType()); - if (!dstTy) - return rewriter.notifyMatchFailure(splatOp, "cannot convert type."); - rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(), - dstTy); - return success(); - } -}; - /// This pattern converts the CreateMaskOp to work on a linearized vector. /// It currently supports only 2D masks with a unit outer dimension. 
/// Following, @@ -934,9 +908,9 @@ void mlir::vector::populateVectorLinearizeBasePatterns( RewritePatternSet &patterns) { patterns .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast, - LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad, - LinearizeVectorStore, LinearizeVectorFromElements, - LinearizeVectorToElements>(typeConverter, patterns.getContext()); + LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore, + LinearizeVectorFromElements, LinearizeVectorToElements>( + typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index d6a6d7cd..726da1e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -878,7 +878,7 @@ struct BubbleUpBitCastForStridedSliceInsert // This transforms IR like: // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> // Into: -// %cst = vector.splat %c0_f32 : vector<4xf32> +// %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32> // %1 = vector.extract_strided_slice %0 { // offsets = [0], sizes = [4], strides = [1] // } : vector<8xf16> to vector<4xf16> @@ -987,8 +987,8 @@ static Type cloneOrReplace(Type type, Type newElementType) { return newElementType; } -/// If `value` is the result of a splat or broadcast operation, return the input -/// of the splat/broadcast operation. +/// If `value` is the result of a broadcast operation, return the input +/// of the broadcast operation. static Value getBroadcastLikeSource(Value value) { Operation *op = value.getDefiningOp(); @@ -998,13 +998,10 @@ static Value getBroadcastLikeSource(Value value) { if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) return broadcast.getSource(); - if (auto splat = dyn_cast<vector::SplatOp>(op)) - return splat.getInput(); - return {}; } -/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex: +/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex: /// /// Example: /// ``` @@ -1017,9 +1014,6 @@ static Value getBroadcastLikeSource(Value value) { /// %r = arith.addi %arg0, %arg1 : index /// %b = vector.broadcast %r : index to vector<1x4xindex> /// ``` -/// -/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting -/// ops. struct ReorderElementwiseOpsOnBroadcast final : public OpTraitRewritePattern<OpTrait::Elementwise> { using OpTraitRewritePattern::OpTraitRewritePattern; @@ -1045,29 +1039,29 @@ struct ReorderElementwiseOpsOnBroadcast final Type resultElemType = resultType.getElementType(); // Get the type of the first non-constant operand - Value splatSource; + Value broadcastSource; for (Value operand : op->getOperands()) { Operation *definingOp = operand.getDefiningOp(); if (!definingOp) return failure(); if (definingOp->hasTrait<OpTrait::ConstantLike>()) continue; - splatSource = getBroadcastLikeSource(operand); + broadcastSource = getBroadcastLikeSource(operand); break; } - if (!splatSource) + if (!broadcastSource) return failure(); Type unbroadcastResultType = - cloneOrReplace(splatSource.getType(), resultElemType); + cloneOrReplace(broadcastSource.getType(), resultElemType); // Make sure that all operands are broadcast from identically-shaped types: - // * scalar (`vector.broadcast` + `vector.splat`), or + // * scalar (`vector.broadcast`), or // * vector (`vector.broadcast`). // Otherwise the re-ordering wouldn't be safe. 
- if (!llvm::all_of(op->getOperands(), [splatSource](Value val) { + if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) { if (auto source = getBroadcastLikeSource(val)) return haveSameShapeAndScaling(source.getType(), - splatSource.getType()); + broadcastSource.getType()); SplatElementsAttr splatConst; return matchPattern(val, m_Constant(&splatConst)); })) { @@ -1271,19 +1265,18 @@ public: } }; -/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store. +/// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store. /// /// Example: /// ``` -/// %0 = vector.splat %arg2 : vector<1xf32> +/// %0 = vector.broadcast %arg2 : f32 to vector<1xf32> /// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32> /// ``` /// Gets converted to: /// ``` /// memref.store %arg2, %arg0[%arg1] : memref<?xf32> /// ``` -class StoreOpFromSplatOrBroadcast final - : public OpRewritePattern<vector::StoreOp> { +class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> { public: using Base::Base; @@ -1308,9 +1301,9 @@ public: return rewriter.notifyMatchFailure( op, "value to store is not from a broadcast"); - // Checking for single use so we can remove splat. - Operation *splat = toStore.getDefiningOp(); - if (!splat->hasOneUse()) + // Checking for single use so we can remove broadcast. + Operation *broadcast = toStore.getDefiningOp(); + if (!broadcast->hasOneUse()) return rewriter.notifyMatchFailure(op, "expected single op use"); Value base = op.getBase(); @@ -1321,7 +1314,7 @@ public: } else { rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices); } - rewriter.eraseOp(splat); + rewriter.eraseOp(broadcast); return success(); } }; @@ -2391,8 +2384,8 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns, void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit) { // TODO: Consider converting these patterns to canonicalizations. - patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>( - patterns.getContext(), benefit); + patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(), + benefit); } void mlir::vector::populateChainedVectorReductionFoldingPatterns( diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index f1dbc5d..26770b3 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -195,8 +195,7 @@ static bool requireTranspose(const xegpu::LayoutAttr layout, /// } /// return %0 /// } -struct MoveFuncBodyToWarpExecuteOnLane0 - : public OpRewritePattern<gpu::GPUFuncOp> { +struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> { using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern; LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, PatternRewriter &rewriter) const override { @@ -1447,6 +1446,11 @@ void xegpu::populateXeGPUSubgroupDistributePatterns( /*pattern benefit=*/highPatternBenefit); } +void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns( + RewritePatternSet &patterns) { + patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext()); +} + void XeGPUSubgroupDistributePass::runOnOperation() { // Step 1: Attach layouts to op operands. // TODO: Following assumptions are made: @@ -1473,7 +1477,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() { // gpu.warp_execute_on_lane_0 operation. 
{ RewritePatternSet patterns(&getContext()); - patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext()); + xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure();
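Beyond the pass itself, the new xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns entry point lets other drivers reuse the rewrite without copying the pattern. A minimal sketch of such reuse; the header path and the wrapper function are assumptions, not part of the patch:

```c++
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h" // assumed location
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

/// Hypothetical driver: wrap gpu.func bodies in warp_execute_on_lane_0 using
/// the newly exposed populate function.
static LogicalResult moveBodiesToWarpOps(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
  return applyPatternsGreedily(op, std::move(patterns));
}
```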