diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Bindings/Python/Rewrite.cpp | 120 | ||||
-rw-r--r-- | mlir/lib/CAPI/Transforms/Rewrite.cpp | 132 | ||||
-rw-r--r-- | mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp | 13 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 33 |
4 files changed, 226 insertions, 72 deletions
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 9e3d970..47685567 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -45,6 +45,16 @@ public: return PyInsertionPoint(PyOperation::forOperation(ctx, op)); } + void replaceOp(MlirOperation op, MlirOperation newOp) { + mlirRewriterBaseReplaceOpWithOperation(base, op, newOp); + } + + void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) { + mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data()); + } + + void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); } + private: MlirRewriterBase base; PyMlirContextRef ctx; @@ -165,13 +175,115 @@ private: MlirFrozenRewritePatternSet set; }; +class PyRewritePatternSet { +public: + PyRewritePatternSet(MlirContext ctx) + : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {} + ~PyRewritePatternSet() { + if (set.ptr) + mlirRewritePatternSetDestroy(set); + } + + void add(MlirStringRef rootName, unsigned benefit, + const nb::callable &matchAndRewrite) { + MlirRewritePatternCallbacks callbacks; + callbacks.construct = [](void *userData) { + nb::handle(static_cast<PyObject *>(userData)).inc_ref(); + }; + callbacks.destruct = [](void *userData) { + nb::handle(static_cast<PyObject *>(userData)).dec_ref(); + }; + callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op, + MlirPatternRewriter rewriter, + void *userData) -> MlirLogicalResult { + nb::handle f(static_cast<PyObject *>(userData)); + nb::object res = f(op, PyPatternRewriter(rewriter)); + return logicalResultFromObject(res); + }; + MlirRewritePattern pattern = mlirOpRewritePattenCreate( + rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(), + /* nGeneratedNames */ 0, + /* generatedNames */ nullptr); + mlirRewritePatternSetAdd(set, pattern); + } + + PyFrozenRewritePatternSet freeze() { + MlirRewritePatternSet s = set; + set.ptr = nullptr; + return mlirFreezeRewritePattern(s); + } + +private: + MlirRewritePatternSet set; + MlirContext ctx; +}; + } // namespace /// Create the `mlir.rewrite` here. void mlir::python::populateRewriteSubmodule(nb::module_ &m) { - nb::class_<PyPatternRewriter>(m, "PatternRewriter") - .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, - "The current insertion point of the PatternRewriter."); + //---------------------------------------------------------------------------- + // Mapping of the PatternRewriter + //---------------------------------------------------------------------------- + nb:: + class_<PyPatternRewriter>(m, "PatternRewriter") + .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, + "The current insertion point of the PatternRewriter.") + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + MlirOperation newOp) { self.replaceOp(op, newOp); }, + "Replace an operation with a new operation.", nb::arg("op"), + nb::arg("new_op"), + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ) + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + const std::vector<MlirValue> &values) { + self.replaceOp(op, values); + }, + "Replace an operation with a list of values.", nb::arg("op"), + nb::arg("values"), + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None") + // clang-format on + ) + .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", + nb::arg("op"), + // clang-format off + nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ); + + //---------------------------------------------------------------------------- + // Mapping of the RewritePatternSet + //---------------------------------------------------------------------------- + nb::class_<PyRewritePatternSet>(m, "RewritePatternSet") + .def( + "__init__", + [](PyRewritePatternSet &self, DefaultingPyMlirContext context) { + new (&self) PyRewritePatternSet(context.get()->get()); + }, + "context"_a = nb::none()) + .def( + "add", + [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn, + unsigned benefit) { + std::string opName = + nb::cast<std::string>(root.attr("OPERATION_NAME")); + self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit, + fn); + }, + "root"_a, "fn"_a, "benefit"_a = 1, + "Add a new rewrite pattern on the given root operation with the " + "callable as the matching and rewriting function and the given " + "benefit.") + .def("freeze", &PyRewritePatternSet::freeze, + "Freeze the pattern set into a frozen one."); + //---------------------------------------------------------------------------- // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- @@ -237,7 +349,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { .def( "freeze", [](PyPDLPatternModule &self) { - return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( + return PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }, nb::keep_alive<0, 1>()) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index c15a73b..46c329d 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -270,35 +270,16 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// -static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { - assert(module.ptr && "unexpected null module"); - return *(static_cast<mlir::RewritePatternSet *>(module.ptr)); -} - -static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { - return {module}; -} - -static inline mlir::FrozenRewritePatternSet * -unwrap(MlirFrozenRewritePatternSet module) { - assert(module.ptr && "unexpected null module"); - return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr); -} - -static inline MlirFrozenRewritePatternSet -wrap(mlir::FrozenRewritePatternSet *module) { - return {module}; -} - -MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) { - auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op))); - op.ptr = nullptr; +MlirFrozenRewritePatternSet +mlirFreezeRewritePattern(MlirRewritePatternSet set) { + auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set))); + set.ptr = nullptr; return wrap(m); } -void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) { - delete unwrap(op); - op.ptr = nullptr; +void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) { + delete unwrap(set); + set.ptr = nullptr; } MlirLogicalResult @@ -319,33 +300,86 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, /// PatternRewriter API //===----------------------------------------------------------------------===// -inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) { - assert(rewriter.ptr && "unexpected null rewriter"); - return static_cast<mlir::PatternRewriter *>(rewriter.ptr); +MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { + return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter))); } -inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { - return {rewriter}; -} +//===----------------------------------------------------------------------===// +/// RewritePattern API +//===----------------------------------------------------------------------===// -MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { - return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter))); +namespace mlir { + +class ExternalRewritePattern : public mlir::RewritePattern { +public: + ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData, + StringRef rootName, PatternBenefit benefit, + MLIRContext *context, + ArrayRef<StringRef> generatedNames) + : RewritePattern(rootName, benefit, context, generatedNames), + callbacks(callbacks), userData(userData) { + if (callbacks.construct) + callbacks.construct(userData); + } + + ~ExternalRewritePattern() { + if (callbacks.destruct) + callbacks.destruct(userData); + } + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + return unwrap(callbacks.matchAndRewrite( + wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op), + wrap(&rewriter), userData)); + } + +private: + MlirRewritePatternCallbacks callbacks; + void *userData; +}; + +} // namespace mlir + +MlirRewritePattern mlirOpRewritePattenCreate( + MlirStringRef rootName, unsigned benefit, MlirContext context, + MlirRewritePatternCallbacks callbacks, void *userData, + size_t nGeneratedNames, MlirStringRef *generatedNames) { + std::vector<mlir::StringRef> generatedNamesVec; + generatedNamesVec.reserve(nGeneratedNames); + for (size_t i = 0; i < nGeneratedNames; ++i) { + generatedNamesVec.push_back(unwrap(generatedNames[i])); + } + return wrap(new mlir::ExternalRewritePattern( + callbacks, userData, unwrap(rootName), PatternBenefit(benefit), + unwrap(context), generatedNamesVec)); } //===----------------------------------------------------------------------===// -/// PDLPatternModule API +/// RewritePatternSet API //===----------------------------------------------------------------------===// -#if MLIR_ENABLE_PDL_IN_PATTERNMATCH -static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { - assert(module.ptr && "unexpected null module"); - return static_cast<mlir::PDLPatternModule *>(module.ptr); +MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) { + return wrap(new mlir::RewritePatternSet(unwrap(context))); +} + +void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) { + delete unwrap(set); } -static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { - return {module}; +void mlirRewritePatternSetAdd(MlirRewritePatternSet set, + MlirRewritePattern pattern) { + std::unique_ptr<mlir::RewritePattern> patternPtr( + const_cast<mlir::RewritePattern *>(unwrap(pattern))); + pattern.ptr = nullptr; + unwrap(set)->add(std::move(patternPtr)); } +//===----------------------------------------------------------------------===// +/// PDLPatternModule API +//===----------------------------------------------------------------------===// + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { return wrap(new mlir::PDLPatternModule( mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op)))); @@ -363,22 +397,6 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { return wrap(m); } -inline const mlir::PDLValue *unwrap(MlirPDLValue value) { - assert(value.ptr && "unexpected null PDL value"); - return static_cast<const mlir::PDLValue *>(value.ptr); -} - -inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; } - -inline mlir::PDLResultList *unwrap(MlirPDLResultList results) { - assert(results.ptr && "unexpected null PDL results"); - return static_cast<mlir::PDLResultList *>(results.ptr); -} - -inline MlirPDLResultList wrap(mlir::PDLResultList *results) { - return {results}; -} - MlirValue mlirPDLValueAsValue(MlirPDLValue value) { return wrap(unwrap(value)->dyn_cast<mlir::Value>()); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp index a50ddbe..624519f 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -55,16 +55,6 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { return returnOp; } -/// Return the func::FuncOp called by `callOp`. -static func::FuncOp getCalledFunction(CallOpInterface callOp) { - SymbolRefAttr sym = - llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); - if (!sym) - return nullptr; - return dyn_cast_or_null<func::FuncOp>( - SymbolTable::lookupNearestSymbolFrom(callOp, sym)); -} - LogicalResult mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { IRRewriter rewriter(module.getContext()); @@ -72,7 +62,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap; // Collect the mapping of functions to their call sites. module.walk([&](func::CallOp callOp) { - if (func::FuncOp calledFunc = getCalledFunction(callOp)) { + if (func::FuncOp calledFunc = + dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) { callerMap[calledFunc].insert(callOp); } }); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 7f419a0..5edcc40b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1593,6 +1593,39 @@ mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs( return {id, std::move(args)}; } +mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op); + llvm::SmallVector<llvm::Value *> args; + + // Fill the Intrinsic Args: dst, mbar, src, size. + args.push_back(mt.lookupValue(thisOp.getDstMem())); + args.push_back(mt.lookupValue(thisOp.getMbar())); + args.push_back(mt.lookupValue(thisOp.getSrcMem())); + args.push_back(mt.lookupValue(thisOp.getSize())); + + // Multicast mask, if available. + mlir::Value multicastMask = thisOp.getMulticastMask(); + const bool hasMulticastMask = static_cast<bool>(multicastMask); + llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); + args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused); + + // Cache hint, if available. + mlir::Value cacheHint = thisOp.getL2CacheHint(); + const bool hasCacheHint = static_cast<bool>(cacheHint); + llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0); + args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused); + + // Flag arguments for multicast and cachehint. + args.push_back(builder.getInt1(hasMulticastMask)); + args.push_back(builder.getInt1(hasCacheHint)); + + llvm::Intrinsic::ID id = + llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; + + return {id, std::move(args)}; +} + mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op); |