aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Bindings/Python/Rewrite.cpp120
-rw-r--r--mlir/lib/CAPI/Transforms/Rewrite.cpp132
-rw-r--r--mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp4
-rw-r--r--mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp2
-rw-r--r--mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp22
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp15
-rw-r--r--mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp23
-rw-r--r--mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp146
-rw-r--r--mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp11
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp3
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp13
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp33
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp52
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp32
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp47
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp10
17 files changed, 430 insertions, 237 deletions
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9e3d970..47685567 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -45,6 +45,16 @@ public:
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}
+ void replaceOp(MlirOperation op, MlirOperation newOp) {
+ mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+ }
+
+ void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+ mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+ }
+
+ void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+
private:
MlirRewriterBase base;
PyMlirContextRef ctx;
@@ -165,13 +175,115 @@ private:
MlirFrozenRewritePatternSet set;
};
+class PyRewritePatternSet {
+public:
+ PyRewritePatternSet(MlirContext ctx)
+ : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+ ~PyRewritePatternSet() {
+ if (set.ptr)
+ mlirRewritePatternSetDestroy(set);
+ }
+
+ void add(MlirStringRef rootName, unsigned benefit,
+ const nb::callable &matchAndRewrite) {
+ MlirRewritePatternCallbacks callbacks;
+ callbacks.construct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+ };
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
+ MlirPatternRewriter rewriter,
+ void *userData) -> MlirLogicalResult {
+ nb::handle f(static_cast<PyObject *>(userData));
+ nb::object res = f(op, PyPatternRewriter(rewriter));
+ return logicalResultFromObject(res);
+ };
+ MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+ rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+ /* nGeneratedNames */ 0,
+ /* generatedNames */ nullptr);
+ mlirRewritePatternSetAdd(set, pattern);
+ }
+
+ PyFrozenRewritePatternSet freeze() {
+ MlirRewritePatternSet s = set;
+ set.ptr = nullptr;
+ return mlirFreezeRewritePattern(s);
+ }
+
+private:
+ MlirRewritePatternSet set;
+ MlirContext ctx;
+};
+
} // namespace
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
- nb::class_<PyPatternRewriter>(m, "PatternRewriter")
- .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
- "The current insertion point of the PatternRewriter.");
+ //----------------------------------------------------------------------------
+ // Mapping of the PatternRewriter
+ //----------------------------------------------------------------------------
+ nb::
+ class_<PyPatternRewriter>(m, "PatternRewriter")
+ .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+ "The current insertion point of the PatternRewriter.")
+ .def(
+ "replace_op",
+ [](PyPatternRewriter &self, MlirOperation op,
+ MlirOperation newOp) { self.replaceOp(op, newOp); },
+ "Replace an operation with a new operation.", nb::arg("op"),
+ nb::arg("new_op"),
+ // clang-format off
+ nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+ // clang-format on
+ )
+ .def(
+ "replace_op",
+ [](PyPatternRewriter &self, MlirOperation op,
+ const std::vector<MlirValue> &values) {
+ self.replaceOp(op, values);
+ },
+ "Replace an operation with a list of values.", nb::arg("op"),
+ nb::arg("values"),
+ // clang-format off
+ nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
+ // clang-format on
+ )
+ .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
+ nb::arg("op"),
+ // clang-format off
+ nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+ // clang-format on
+ );
+
+ //----------------------------------------------------------------------------
+ // Mapping of the RewritePatternSet
+ //----------------------------------------------------------------------------
+ nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+ .def(
+ "__init__",
+ [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+ new (&self) PyRewritePatternSet(context.get()->get());
+ },
+ "context"_a = nb::none())
+ .def(
+ "add",
+ [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+ unsigned benefit) {
+ std::string opName =
+ nb::cast<std::string>(root.attr("OPERATION_NAME"));
+ self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
+ fn);
+ },
+ "root"_a, "fn"_a, "benefit"_a = 1,
+ "Add a new rewrite pattern on the given root operation with the "
+ "callable as the matching and rewriting function and the given "
+ "benefit.")
+ .def("freeze", &PyRewritePatternSet::freeze,
+ "Freeze the pattern set into a frozen one.");
+
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
@@ -237,7 +349,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(
"freeze",
[](PyPDLPatternModule &self) {
- return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+ return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
},
nb::keep_alive<0, 1>())
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index c15a73b..46c329d 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -270,35 +270,16 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
/// RewritePatternSet and FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//
-static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
- assert(module.ptr && "unexpected null module");
- return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
-}
-
-static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
- return {module};
-}
-
-static inline mlir::FrozenRewritePatternSet *
-unwrap(MlirFrozenRewritePatternSet module) {
- assert(module.ptr && "unexpected null module");
- return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
-}
-
-static inline MlirFrozenRewritePatternSet
-wrap(mlir::FrozenRewritePatternSet *module) {
- return {module};
-}
-
-MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
- auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
- op.ptr = nullptr;
+MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet set) {
+ auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set)));
+ set.ptr = nullptr;
return wrap(m);
}
-void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
- delete unwrap(op);
- op.ptr = nullptr;
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
+ delete unwrap(set);
+ set.ptr = nullptr;
}
MlirLogicalResult
@@ -319,33 +300,86 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
/// PatternRewriter API
//===----------------------------------------------------------------------===//
-inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
- assert(rewriter.ptr && "unexpected null rewriter");
- return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
+MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
+ return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}
-inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
- return {rewriter};
-}
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
-MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
- return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
+namespace mlir {
+
+class ExternalRewritePattern : public mlir::RewritePattern {
+public:
+ ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
+ StringRef rootName, PatternBenefit benefit,
+ MLIRContext *context,
+ ArrayRef<StringRef> generatedNames)
+ : RewritePattern(rootName, benefit, context, generatedNames),
+ callbacks(callbacks), userData(userData) {
+ if (callbacks.construct)
+ callbacks.construct(userData);
+ }
+
+ ~ExternalRewritePattern() {
+ if (callbacks.destruct)
+ callbacks.destruct(userData);
+ }
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ return unwrap(callbacks.matchAndRewrite(
+ wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
+ wrap(&rewriter), userData));
+ }
+
+private:
+ MlirRewritePatternCallbacks callbacks;
+ void *userData;
+};
+
+} // namespace mlir
+
+MlirRewritePattern mlirOpRewritePattenCreate(
+ MlirStringRef rootName, unsigned benefit, MlirContext context,
+ MlirRewritePatternCallbacks callbacks, void *userData,
+ size_t nGeneratedNames, MlirStringRef *generatedNames) {
+ std::vector<mlir::StringRef> generatedNamesVec;
+ generatedNamesVec.reserve(nGeneratedNames);
+ for (size_t i = 0; i < nGeneratedNames; ++i) {
+ generatedNamesVec.push_back(unwrap(generatedNames[i]));
+ }
+ return wrap(new mlir::ExternalRewritePattern(
+ callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+ unwrap(context), generatedNamesVec));
}
//===----------------------------------------------------------------------===//
-/// PDLPatternModule API
+/// RewritePatternSet API
//===----------------------------------------------------------------------===//
-#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
-static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
- assert(module.ptr && "unexpected null module");
- return static_cast<mlir::PDLPatternModule *>(module.ptr);
+MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
+ return wrap(new mlir::RewritePatternSet(unwrap(context)));
+}
+
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
+ delete unwrap(set);
}
-static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
- return {module};
+void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+ MlirRewritePattern pattern) {
+ std::unique_ptr<mlir::RewritePattern> patternPtr(
+ const_cast<mlir::RewritePattern *>(unwrap(pattern)));
+ pattern.ptr = nullptr;
+ unwrap(set)->add(std::move(patternPtr));
}
+//===----------------------------------------------------------------------===//
+/// PDLPatternModule API
+//===----------------------------------------------------------------------===//
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
return wrap(new mlir::PDLPatternModule(
mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
@@ -363,22 +397,6 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
return wrap(m);
}
-inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
- assert(value.ptr && "unexpected null PDL value");
- return static_cast<const mlir::PDLValue *>(value.ptr);
-}
-
-inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
-
-inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
- assert(results.ptr && "unexpected null PDL results");
- return static_cast<mlir::PDLResultList *>(results.ptr);
-}
-
-inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
- return {results};
-}
-
MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
return wrap(unwrap(value)->dyn_cast<mlir::Value>());
}
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index dcbaa56..247dba1 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) {
current = op.getSource();
return false;
})
- .Case<vector::SplatOp>([&current](auto op) {
- current = op.getInput();
- return false;
- })
.Default([](Operation *) { return false; });
if (!skipOp) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index bad53c0..1002ebe 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// AFTER:
/// ```mlir
/// ...
-/// %pad_1d = vector.splat %pad : vector<[4]xi32>
+/// %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// ...
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 363685a..778c616 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering
}
};
-// Convert all `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to ArmSME via another pattern.
-struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
- using Base::Base;
-
- LogicalResult matchAndRewrite(vector::SplatOp splatOp,
- PatternRewriter &rewriter) const final {
-
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
- splatOp.getInput());
- return success();
- }
-};
-
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
- TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
- TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
- VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+ patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering,
+ TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
+ VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+ VectorOuterProductToArmSMELowering,
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
ExtractFromCreateMaskToPselLowering>(&ctx);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5461646..5355909 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -2161,19 +2161,6 @@ public:
}
};
-/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
-/// `vector.broadcast` through other patterns.
-struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
- adaptor.getInput());
- return success();
- }
-};
-
} // namespace
void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
- VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+ VectorBroadcastScalarToLowRankLowering,
VectorBroadcastScalarToNdLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 311ff6f..56e8fee 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -22,7 +22,6 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
}
};
-// Convert `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to SPIRV via other patterns.
-struct VectorSplatToBroadcast final
- : public OpConversionPattern<vector::SplatOp> {
- using Base::Base;
- LogicalResult
- matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
- adaptor.getInput());
- return success();
- }
-};
-
struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using Base::Base;
@@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns(
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
- VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
- VectorShuffleOpConvert, VectorInterleaveOpConvert,
- VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
- VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+ VectorScalarBroadcastPattern, VectorLoadOpConverter,
+ VectorStoreOpConverter, VectorStepOpConvert>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index f449d90..f276984 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -715,6 +715,135 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
};
//===----------------------------------------------------------------------===//
+// GPU index id operations
+//===----------------------------------------------------------------------===//
+/*
+// Launch Config ops
+// dimidx - x, y, z - is fixed to i32
+// return type is set by XeVM type converter
+// get_local_id
+xevm::WorkitemIdXOp;
+xevm::WorkitemIdYOp;
+xevm::WorkitemIdZOp;
+// get_local_size
+xevm::WorkgroupDimXOp;
+xevm::WorkgroupDimYOp;
+xevm::WorkgroupDimZOp;
+// get_group_id
+xevm::WorkgroupIdXOp;
+xevm::WorkgroupIdYOp;
+xevm::WorkgroupIdZOp;
+// get_num_groups
+xevm::GridDimXOp;
+xevm::GridDimYOp;
+xevm::GridDimZOp;
+// get_global_id : to be added if needed
+*/
+
+// Helpers to get the OpenCL function name and dimension argument for each op.
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
+ return {"get_local_id", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
+ return {"get_local_id", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
+ return {"get_local_id", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
+ return {"get_local_size", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
+ return {"get_local_size", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
+ return {"get_local_size", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
+ return {"get_group_id", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
+ return {"get_group_id", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
+ return {"get_group_id", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
+ return {"get_num_groups", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
+ return {"get_num_groups", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
+ return {"get_num_groups", 2};
+}
+/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
+/// a constant argument for the dimension - x, y or z.
+template <typename OpType>
+class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ auto [baseName, dim] = getConfig(op);
+ Type dimTy = rewriter.getI32Type();
+ Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
+ static_cast<int64_t>(dim));
+ std::string func = mangle(baseName, {dimTy}, {true});
+ Type resTy = op.getType();
+ auto call =
+ createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
+ noUnwindWillReturnAttrs, op.getOperation());
+ constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/noModRef,
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ call.setMemoryEffectsAttr(memAttr);
+ rewriter.replaceOp(op, call);
+ return success();
+ }
+};
+
+/*
+// Subgroup ops
+// get_sub_group_local_id
+xevm::LaneIdOp;
+// get_sub_group_id
+xevm::SubgroupIdOp;
+// get_sub_group_size
+xevm::SubgroupSizeOp;
+// get_num_sub_groups : to be added if needed
+*/
+
+// Helpers to get the OpenCL function name for each op.
+static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
+static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
+static StringRef getConfig(xevm::SubgroupSizeOp) {
+ return "get_sub_group_size";
+}
+template <typename OpType>
+class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ std::string func = mangle(getConfig(op).str(), {});
+ Type resTy = op.getType();
+ auto call =
+ createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
+ noUnwindWillReturnAttrs, op.getOperation());
+ constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/noModRef,
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ call.setMemoryEffectsAttr(memAttr);
+ rewriter.replaceOp(op, call);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -775,7 +904,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
BlockLoadStore1DToOCLPattern<BlockLoadOp>,
- BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
+ BlockLoadStore1DToOCLPattern<BlockStoreOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
+ LaunchConfigOpToOCLPattern<GridDimXOp>,
+ LaunchConfigOpToOCLPattern<GridDimYOp>,
+ LaunchConfigOpToOCLPattern<GridDimZOp>,
+ SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
+ SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
+ SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
patterns.getContext());
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index d5c7190..f405d0c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -40,6 +41,15 @@ using namespace mlir::amdgpu;
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
+namespace {
+struct AMDGPUInlinerInterface final : DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+ bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
+ return true;
+ }
+};
+} // namespace
+
void AMDGPUDialect::initialize() {
addOperations<
#define GET_OP_LIST
@@ -49,6 +59,7 @@ void AMDGPUDialect::initialize() {
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
>();
+ addInterfaces<AMDGPUInlinerInterface>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index c64e10f5..d018cdd 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
vector::OuterProductOp, vector::ScanOp>(
[&](Operation *op) { return converter.isLegal(op); });
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
- arith::ConstantOp, arith::SelectOp, vector::SplatOp,
- vector::BroadcastOp>();
+ arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
}
void EmulateUnsupportedFloatsPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index a50ddbe..624519f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -55,16 +55,6 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
return returnOp;
}
-/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym =
- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
- if (!sym)
- return nullptr;
- return dyn_cast_or_null<func::FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
-}
-
LogicalResult
mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
IRRewriter rewriter(module.getContext());
@@ -72,7 +62,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
// Collect the mapping of functions to their call sites.
module.walk([&](func::CallOp callOp) {
- if (func::FuncOp calledFunc = getCalledFunction(callOp)) {
+ if (func::FuncOp calledFunc =
+ dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
callerMap[calledFunc].insert(callOp);
}
});
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7f419a0..5edcc40b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1593,6 +1593,39 @@ mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
+mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args: dst, mbar, src, size.
+ args.push_back(mt.lookupValue(thisOp.getDstMem()));
+ args.push_back(mt.lookupValue(thisOp.getMbar()));
+ args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+ args.push_back(mt.lookupValue(thisOp.getSize()));
+
+ // Multicast mask, if available.
+ mlir::Value multicastMask = thisOp.getMulticastMask();
+ const bool hasMulticastMask = static_cast<bool>(multicastMask);
+ llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+ args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+
+ // Cache hint, if available.
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+
+ // Flag arguments for multicast and cachehint.
+ args.push_back(builder.getInt1(hasMulticastMask));
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ llvm::Intrinsic::ID id =
+ llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+
+ return {id, std::move(args)};
+}
+
mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 14e235f..a7e3ba8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1665,10 +1665,10 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
-/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
-/// 1s, are considered to be 'broadcastlike'.
+/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
+/// considered to be 'broadcastlike'.
static bool isBroadcastLike(Operation *op) {
- if (isa<BroadcastOp, SplatOp>(op))
+ if (isa<BroadcastOp>(op))
return true;
auto shapeCast = dyn_cast<ShapeCastOp>(op);
@@ -3249,12 +3249,11 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
};
/// Consider the defining operation `defOp` of `value`. If `defOp` is a
-/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
-/// value that is splatted. Otherwise return null.
+/// vector.broadcast with a scalar operand, return the scalar value that is
+/// splatted. Otherwise return null.
///
-/// Examples:
+/// Example:
///
-/// scalar_source --> vector.splat --> value - return scalar_source
/// scalar_source --> vector.broadcast --> value - return scalar_source
static Value getScalarSplatSource(Value value) {
// Block argument:
@@ -3262,10 +3261,6 @@ static Value getScalarSplatSource(Value value) {
if (!defOp)
return {};
- // Splat:
- if (auto splat = dyn_cast<vector::SplatOp>(defOp))
- return splat.getInput();
-
auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
// Not broadcast (and not splat):
@@ -7511,41 +7506,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
patterns.getContext(), benefit);
}
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
- auto constOperand = adaptor.getInput();
- if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
- return {};
-
- // SplatElementsAttr::get treats single value for second arg as being a splat.
- return SplatElementsAttr::get(getType(), {constOperand});
-}
-
-// Canonicalizer for vector.splat. It always gets canonicalized to a
-// vector.broadcast.
-class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
-public:
- using Base::Base;
- LogicalResult matchAndRewrite(SplatOp splatOp,
- PatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
- splatOp.getOperand());
- return success();
- }
-};
-void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<SplatToBroadcastPattern>(context);
-}
-
-void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges.front());
-}
-
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value acc,
arith::FastMathFlagsAttr fastmath,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 255f2bf..3a3231d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -90,7 +90,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
- // TODO: add support to `vector.splat`.
+ // TODO: add support to `vector.broadcast`.
// Finding the mask creation operation.
while (maskOp &&
!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 71fba71c..1b656d8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -590,32 +590,6 @@ struct LinearizeVectorBitCast final
}
};
-/// This pattern converts the SplatOp to work on a linearized vector.
-/// Following,
-/// vector.splat %value : vector<4x4xf32>
-/// is converted to:
-/// %out_1d = vector.splat %value : vector<16xf32>
-/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
-struct LinearizeVectorSplat final
- : public OpConversionPattern<vector::SplatOp> {
- using Base::Base;
-
- LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit) {}
-
- LogicalResult
- matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto dstTy = getTypeConverter()->convertType(splatOp.getType());
- if (!dstTy)
- return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
- rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
- dstTy);
- return success();
- }
-};
-
/// This pattern converts the CreateMaskOp to work on a linearized vector.
/// It currently supports only 2D masks with a unit outer dimension.
/// Following,
@@ -934,9 +908,9 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
- LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore, LinearizeVectorFromElements,
- LinearizeVectorToElements>(typeConverter, patterns.getContext());
+ LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
+ LinearizeVectorFromElements, LinearizeVectorToElements>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d6a6d7cd..726da1e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -878,7 +878,7 @@ struct BubbleUpBitCastForStridedSliceInsert
// This transforms IR like:
// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
// Into:
-// %cst = vector.splat %c0_f32 : vector<4xf32>
+// %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32>
// %1 = vector.extract_strided_slice %0 {
// offsets = [0], sizes = [4], strides = [1]
// } : vector<8xf16> to vector<4xf16>
@@ -987,8 +987,8 @@ static Type cloneOrReplace(Type type, Type newElementType) {
return newElementType;
}
-/// If `value` is the result of a splat or broadcast operation, return the input
-/// of the splat/broadcast operation.
+/// If `value` is the result of a broadcast operation, return the input
+/// of the broadcast operation.
static Value getBroadcastLikeSource(Value value) {
Operation *op = value.getDefiningOp();
@@ -998,13 +998,10 @@ static Value getBroadcastLikeSource(Value value) {
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
return broadcast.getSource();
- if (auto splat = dyn_cast<vector::SplatOp>(op))
- return splat.getInput();
-
return {};
}
-/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
///
/// Example:
/// ```
@@ -1017,9 +1014,6 @@ static Value getBroadcastLikeSource(Value value) {
/// %r = arith.addi %arg0, %arg1 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
-///
-/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
-/// ops.
struct ReorderElementwiseOpsOnBroadcast final
: public OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -1045,29 +1039,29 @@ struct ReorderElementwiseOpsOnBroadcast final
Type resultElemType = resultType.getElementType();
// Get the type of the first non-constant operand
- Value splatSource;
+ Value broadcastSource;
for (Value operand : op->getOperands()) {
Operation *definingOp = operand.getDefiningOp();
if (!definingOp)
return failure();
if (definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
- splatSource = getBroadcastLikeSource(operand);
+ broadcastSource = getBroadcastLikeSource(operand);
break;
}
- if (!splatSource)
+ if (!broadcastSource)
return failure();
Type unbroadcastResultType =
- cloneOrReplace(splatSource.getType(), resultElemType);
+ cloneOrReplace(broadcastSource.getType(), resultElemType);
// Make sure that all operands are broadcast from identically-shaped types:
- // * scalar (`vector.broadcast` + `vector.splat`), or
+ // * scalar (`vector.broadcast`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
+ if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
if (auto source = getBroadcastLikeSource(val))
return haveSameShapeAndScaling(source.getType(),
- splatSource.getType());
+ broadcastSource.getType());
SplatElementsAttr splatConst;
return matchPattern(val, m_Constant(&splatConst));
})) {
@@ -1271,19 +1265,18 @@ public:
}
};
-/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store.
///
/// Example:
/// ```
-/// %0 = vector.splat %arg2 : vector<1xf32>
+/// %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
/// ```
/// Gets converted to:
/// ```
/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
/// ```
-class StoreOpFromSplatOrBroadcast final
- : public OpRewritePattern<vector::StoreOp> {
+class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
public:
using Base::Base;
@@ -1308,9 +1301,9 @@ public:
return rewriter.notifyMatchFailure(
op, "value to store is not from a broadcast");
- // Checking for single use so we can remove splat.
- Operation *splat = toStore.getDefiningOp();
- if (!splat->hasOneUse())
+ // Checking for single use so we can remove broadcast.
+ Operation *broadcast = toStore.getDefiningOp();
+ if (!broadcast->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");
Value base = op.getBase();
@@ -1321,7 +1314,7 @@ public:
} else {
rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
}
- rewriter.eraseOp(splat);
+ rewriter.eraseOp(broadcast);
return success();
}
};
@@ -2391,8 +2384,8 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
// TODO: Consider converting these patterns to canonicalizations.
- patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
- patterns.getContext(), benefit);
+ patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
+ benefit);
}
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index f1dbc5d..26770b3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -195,8 +195,7 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
/// }
/// return %0
/// }
-struct MoveFuncBodyToWarpExecuteOnLane0
- : public OpRewritePattern<gpu::GPUFuncOp> {
+struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
PatternRewriter &rewriter) const override {
@@ -1447,6 +1446,11 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
/*pattern benefit=*/highPatternBenefit);
}
+void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
+}
+
void XeGPUSubgroupDistributePass::runOnOperation() {
// Step 1: Attach layouts to op operands.
// TODO: Following assumptions are made:
@@ -1473,7 +1477,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// gpu.warp_execute_on_lane_0 operation.
{
RewritePatternSet patterns(&getContext());
- patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+ xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();