aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Bindings/Python/Rewrite.cpp120
-rw-r--r--mlir/lib/CAPI/Transforms/Rewrite.cpp132
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp13
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp33
4 files changed, 226 insertions, 72 deletions
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9e3d970..47685567 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -45,6 +45,16 @@ public:
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}
+ void replaceOp(MlirOperation op, MlirOperation newOp) {
+ mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+ }
+
+ void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+ mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+ }
+
+ void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+
private:
MlirRewriterBase base;
PyMlirContextRef ctx;
@@ -165,13 +175,115 @@ private:
MlirFrozenRewritePatternSet set;
};
+class PyRewritePatternSet {
+public:
+ PyRewritePatternSet(MlirContext ctx)
+ : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+ ~PyRewritePatternSet() {
+ if (set.ptr)
+ mlirRewritePatternSetDestroy(set);
+ }
+
+ void add(MlirStringRef rootName, unsigned benefit,
+ const nb::callable &matchAndRewrite) {
+ MlirRewritePatternCallbacks callbacks;
+ callbacks.construct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+ };
+ callbacks.destruct = [](void *userData) {
+ nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ };
+ callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
+ MlirPatternRewriter rewriter,
+ void *userData) -> MlirLogicalResult {
+ nb::handle f(static_cast<PyObject *>(userData));
+ nb::object res = f(op, PyPatternRewriter(rewriter));
+ return logicalResultFromObject(res);
+ };
+ MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+ rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+ /* nGeneratedNames */ 0,
+ /* generatedNames */ nullptr);
+ mlirRewritePatternSetAdd(set, pattern);
+ }
+
+ PyFrozenRewritePatternSet freeze() {
+ MlirRewritePatternSet s = set;
+ set.ptr = nullptr;
+ return mlirFreezeRewritePattern(s);
+ }
+
+private:
+ MlirRewritePatternSet set;
+ MlirContext ctx;
+};
+
} // namespace
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
- nb::class_<PyPatternRewriter>(m, "PatternRewriter")
- .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
- "The current insertion point of the PatternRewriter.");
+ //----------------------------------------------------------------------------
+ // Mapping of the PatternRewriter
+ //----------------------------------------------------------------------------
+ nb::
+ class_<PyPatternRewriter>(m, "PatternRewriter")
+ .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+ "The current insertion point of the PatternRewriter.")
+ .def(
+ "replace_op",
+ [](PyPatternRewriter &self, MlirOperation op,
+ MlirOperation newOp) { self.replaceOp(op, newOp); },
+ "Replace an operation with a new operation.", nb::arg("op"),
+ nb::arg("new_op"),
+ // clang-format off
+ nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+ // clang-format on
+ )
+ .def(
+ "replace_op",
+ [](PyPatternRewriter &self, MlirOperation op,
+ const std::vector<MlirValue> &values) {
+ self.replaceOp(op, values);
+ },
+ "Replace an operation with a list of values.", nb::arg("op"),
+ nb::arg("values"),
+ // clang-format off
+ nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
+ // clang-format on
+ )
+ .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
+ nb::arg("op"),
+ // clang-format off
+ nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+ // clang-format on
+ );
+
+ //----------------------------------------------------------------------------
+ // Mapping of the RewritePatternSet
+ //----------------------------------------------------------------------------
+ nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+ .def(
+ "__init__",
+ [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+ new (&self) PyRewritePatternSet(context.get()->get());
+ },
+ "context"_a = nb::none())
+ .def(
+ "add",
+ [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+ unsigned benefit) {
+ std::string opName =
+ nb::cast<std::string>(root.attr("OPERATION_NAME"));
+ self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
+ fn);
+ },
+ "root"_a, "fn"_a, "benefit"_a = 1,
+ "Add a new rewrite pattern on the given root operation with the "
+ "callable as the matching and rewriting function and the given "
+ "benefit.")
+ .def("freeze", &PyRewritePatternSet::freeze,
+ "Freeze the pattern set into a frozen one.");
+
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
@@ -237,7 +349,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(
"freeze",
[](PyPDLPatternModule &self) {
- return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+ return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
},
nb::keep_alive<0, 1>())
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index c15a73b..46c329d 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -270,35 +270,16 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
/// RewritePatternSet and FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//
-static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
- assert(module.ptr && "unexpected null module");
- return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
-}
-
-static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
- return {module};
-}
-
-static inline mlir::FrozenRewritePatternSet *
-unwrap(MlirFrozenRewritePatternSet module) {
- assert(module.ptr && "unexpected null module");
- return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
-}
-
-static inline MlirFrozenRewritePatternSet
-wrap(mlir::FrozenRewritePatternSet *module) {
- return {module};
-}
-
-MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
- auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
- op.ptr = nullptr;
+MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet set) {
+ auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set)));
+ set.ptr = nullptr;
return wrap(m);
}
-void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
- delete unwrap(op);
- op.ptr = nullptr;
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
+ delete unwrap(set);
+ set.ptr = nullptr;
}
MlirLogicalResult
@@ -319,33 +300,86 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
/// PatternRewriter API
//===----------------------------------------------------------------------===//
-inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
- assert(rewriter.ptr && "unexpected null rewriter");
- return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
+MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
+ return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}
-inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
- return {rewriter};
-}
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
-MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
- return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
+namespace mlir {
+
+class ExternalRewritePattern : public mlir::RewritePattern {
+public:
+ ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
+ StringRef rootName, PatternBenefit benefit,
+ MLIRContext *context,
+ ArrayRef<StringRef> generatedNames)
+ : RewritePattern(rootName, benefit, context, generatedNames),
+ callbacks(callbacks), userData(userData) {
+ if (callbacks.construct)
+ callbacks.construct(userData);
+ }
+
+ ~ExternalRewritePattern() {
+ if (callbacks.destruct)
+ callbacks.destruct(userData);
+ }
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ return unwrap(callbacks.matchAndRewrite(
+ wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
+ wrap(&rewriter), userData));
+ }
+
+private:
+ MlirRewritePatternCallbacks callbacks;
+ void *userData;
+};
+
+} // namespace mlir
+
+MlirRewritePattern mlirOpRewritePattenCreate(
+ MlirStringRef rootName, unsigned benefit, MlirContext context,
+ MlirRewritePatternCallbacks callbacks, void *userData,
+ size_t nGeneratedNames, MlirStringRef *generatedNames) {
+ std::vector<mlir::StringRef> generatedNamesVec;
+ generatedNamesVec.reserve(nGeneratedNames);
+ for (size_t i = 0; i < nGeneratedNames; ++i) {
+ generatedNamesVec.push_back(unwrap(generatedNames[i]));
+ }
+ return wrap(new mlir::ExternalRewritePattern(
+ callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+ unwrap(context), generatedNamesVec));
}
//===----------------------------------------------------------------------===//
-/// PDLPatternModule API
+/// RewritePatternSet API
//===----------------------------------------------------------------------===//
-#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
-static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
- assert(module.ptr && "unexpected null module");
- return static_cast<mlir::PDLPatternModule *>(module.ptr);
+MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
+ return wrap(new mlir::RewritePatternSet(unwrap(context)));
+}
+
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
+ delete unwrap(set);
}
-static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
- return {module};
+void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+ MlirRewritePattern pattern) {
+ std::unique_ptr<mlir::RewritePattern> patternPtr(
+ const_cast<mlir::RewritePattern *>(unwrap(pattern)));
+ pattern.ptr = nullptr;
+ unwrap(set)->add(std::move(patternPtr));
}
+//===----------------------------------------------------------------------===//
+/// PDLPatternModule API
+//===----------------------------------------------------------------------===//
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
return wrap(new mlir::PDLPatternModule(
mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
@@ -363,22 +397,6 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
return wrap(m);
}
-inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
- assert(value.ptr && "unexpected null PDL value");
- return static_cast<const mlir::PDLValue *>(value.ptr);
-}
-
-inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
-
-inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
- assert(results.ptr && "unexpected null PDL results");
- return static_cast<mlir::PDLResultList *>(results.ptr);
-}
-
-inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
- return {results};
-}
-
MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
return wrap(unwrap(value)->dyn_cast<mlir::Value>());
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index a50ddbe..624519f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -55,16 +55,6 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
return returnOp;
}
-/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym =
- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
- if (!sym)
- return nullptr;
- return dyn_cast_or_null<func::FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
-}
-
LogicalResult
mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
IRRewriter rewriter(module.getContext());
@@ -72,7 +62,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
// Collect the mapping of functions to their call sites.
module.walk([&](func::CallOp callOp) {
- if (func::FuncOp calledFunc = getCalledFunction(callOp)) {
+ if (func::FuncOp calledFunc =
+ dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
callerMap[calledFunc].insert(callOp);
}
});
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7f419a0..5edcc40b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1593,6 +1593,39 @@ mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
+mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args: dst, mbar, src, size.
+ args.push_back(mt.lookupValue(thisOp.getDstMem()));
+ args.push_back(mt.lookupValue(thisOp.getMbar()));
+ args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+ args.push_back(mt.lookupValue(thisOp.getSize()));
+
+ // Multicast mask, if available.
+ mlir::Value multicastMask = thisOp.getMulticastMask();
+ const bool hasMulticastMask = static_cast<bool>(multicastMask);
+ llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+ args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+
+ // Cache hint, if available.
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+
+ // Flag arguments for multicast and cachehint.
+ args.push_back(builder.getInt1(hasMulticastMask));
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ llvm::Intrinsic::ID id =
+ llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+
+ return {id, std::move(args)};
+}
+
mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);