aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/CAPI/Transforms/Rewrite.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/CAPI/Transforms/Rewrite.cpp')
-rw-r--r--mlir/lib/CAPI/Transforms/Rewrite.cpp132
1 files changed, 75 insertions, 57 deletions
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index c15a73b..46c329d 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -270,35 +270,16 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
/// RewritePatternSet and FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//
-static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
- assert(module.ptr && "unexpected null module");
- return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
-}
-
-static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
- return {module};
-}
-
-static inline mlir::FrozenRewritePatternSet *
-unwrap(MlirFrozenRewritePatternSet module) {
- assert(module.ptr && "unexpected null module");
- return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
-}
-
-static inline MlirFrozenRewritePatternSet
-wrap(mlir::FrozenRewritePatternSet *module) {
- return {module};
-}
-
-MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
- auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
- op.ptr = nullptr;
+MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet set) {
+ auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set)));
+ set.ptr = nullptr;
return wrap(m);
}
-void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
- delete unwrap(op);
- op.ptr = nullptr;
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
+ delete unwrap(set);
+ set.ptr = nullptr;
}
MlirLogicalResult
@@ -319,33 +300,86 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
/// PatternRewriter API
//===----------------------------------------------------------------------===//
-inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
- assert(rewriter.ptr && "unexpected null rewriter");
- return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
+MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
+ return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}
-inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
- return {rewriter};
-}
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
-MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
- return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
+namespace mlir {
+
+class ExternalRewritePattern : public mlir::RewritePattern {
+public:
+ ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
+ StringRef rootName, PatternBenefit benefit,
+ MLIRContext *context,
+ ArrayRef<StringRef> generatedNames)
+ : RewritePattern(rootName, benefit, context, generatedNames),
+ callbacks(callbacks), userData(userData) {
+ if (callbacks.construct)
+ callbacks.construct(userData);
+ }
+
+ ~ExternalRewritePattern() {
+ if (callbacks.destruct)
+ callbacks.destruct(userData);
+ }
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ return unwrap(callbacks.matchAndRewrite(
+ wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
+ wrap(&rewriter), userData));
+ }
+
+private:
+ MlirRewritePatternCallbacks callbacks;
+ void *userData;
+};
+
+} // namespace mlir
+
+MlirRewritePattern mlirOpRewritePattenCreate(
+ MlirStringRef rootName, unsigned benefit, MlirContext context,
+ MlirRewritePatternCallbacks callbacks, void *userData,
+ size_t nGeneratedNames, MlirStringRef *generatedNames) {
+ std::vector<mlir::StringRef> generatedNamesVec;
+ generatedNamesVec.reserve(nGeneratedNames);
+ for (size_t i = 0; i < nGeneratedNames; ++i) {
+ generatedNamesVec.push_back(unwrap(generatedNames[i]));
+ }
+ return wrap(new mlir::ExternalRewritePattern(
+ callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+ unwrap(context), generatedNamesVec));
}
//===----------------------------------------------------------------------===//
-/// PDLPatternModule API
+/// RewritePatternSet API
//===----------------------------------------------------------------------===//
-#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
-static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
- assert(module.ptr && "unexpected null module");
- return static_cast<mlir::PDLPatternModule *>(module.ptr);
+MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
+ return wrap(new mlir::RewritePatternSet(unwrap(context)));
+}
+
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
+ delete unwrap(set);
}
-static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
- return {module};
+void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+ MlirRewritePattern pattern) {
+ std::unique_ptr<mlir::RewritePattern> patternPtr(
+ const_cast<mlir::RewritePattern *>(unwrap(pattern)));
+ pattern.ptr = nullptr;
+ unwrap(set)->add(std::move(patternPtr));
}
+//===----------------------------------------------------------------------===//
+/// PDLPatternModule API
+//===----------------------------------------------------------------------===//
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
return wrap(new mlir::PDLPatternModule(
mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
@@ -363,22 +397,6 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
return wrap(m);
}
-inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
- assert(value.ptr && "unexpected null PDL value");
- return static_cast<const mlir::PDLValue *>(value.ptr);
-}
-
-inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
-
-inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
- assert(results.ptr && "unexpected null PDL results");
- return static_cast<mlir::PDLResultList *>(results.ptr);
-}
-
-inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
- return {results};
-}
-
MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
return wrap(unwrap(value)->dyn_cast<mlir::Value>());
}