aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir-c/IR.h10
-rw-r--r--mlir/include/mlir/Bindings/Python/PybindAdaptors.h1
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp32
-rw-r--r--mlir/lib/Bindings/Python/IRModule.h4
-rw-r--r--mlir/lib/CAPI/IR/IR.cpp21
-rw-r--r--mlir/test/CAPI/ir.c58
-rw-r--r--mlir/test/python/ir/operation.py75
7 files changed, 184 insertions, 17 deletions
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82da511..32abacf 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -705,6 +705,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
MlirOperation other);
+/// Operation walk result.
+typedef enum MlirWalkResult {
+ MlirWalkResultAdvance,
+ MlirWalkResultInterrupt,
+ MlirWalkResultSkip
+} MlirWalkResult;
+
/// Traversal order for operation walk.
typedef enum MlirWalkOrder {
MlirWalkPreOrder,
@@ -713,7 +720,8 @@ typedef enum MlirWalkOrder {
/// Operation walker type. The handler is passed an (opaque) reference to an
/// operation and a pointer to a `userData`.
-typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
+typedef MlirWalkResult (*MlirOperationWalkCallback)(MlirOperation,
+ void *userData);
/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
/// `*userData` is passed to the callback as well and can be used to tunnel some
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 52f6321..d8f22c7a 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -18,6 +18,7 @@
#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
+#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 734f2f7..d875f4e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -674,6 +674,7 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
data->rootOp.getOperation().getContext()->clearOperation(op);
else
data->rootSeen = true;
+ return MlirWalkResult::MlirWalkResultAdvance;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
static_cast<void *>(&data), MlirWalkPreOrder);
@@ -1249,6 +1250,21 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
.str());
}
+void PyOperationBase::walk(
+ std::function<MlirWalkResult(MlirOperation)> callback,
+ MlirWalkOrder walkOrder) {
+ PyOperation &operation = getOperation();
+ operation.checkValid();
+ MlirOperationWalkCallback walkCallback = [](MlirOperation op,
+ void *userData) {
+ auto *fn =
+ static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
+ return (*fn)(op);
+ };
+
+ mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
+}
+
py::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
@@ -2511,6 +2527,15 @@ void mlir::python::populateIRCore(py::module &m) {
.value("NOTE", MlirDiagnosticNote)
.value("REMARK", MlirDiagnosticRemark);
+ py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
+ .value("PRE_ORDER", MlirWalkPreOrder)
+ .value("POST_ORDER", MlirWalkPostOrder);
+
+ py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
+ .value("ADVANCE", MlirWalkResultAdvance)
+ .value("INTERRUPT", MlirWalkResultInterrupt)
+ .value("SKIP", MlirWalkResultSkip);
+
//----------------------------------------------------------------------------
// Mapping of Diagnostics.
//----------------------------------------------------------------------------
@@ -2989,8 +3014,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
- bool, py::object, bool>(
- &PyOperationBase::print),
+ bool, py::object, bool>(&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
@@ -3038,7 +3062,9 @@ void mlir::python::populateIRCore(py::module &m) {
return operation.createOpView();
},
"Detaches the operation from its parent block.")
- .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
+ .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+ .def("walk", &PyOperationBase::walk, py::arg("callback"),
+ py::arg("walk_order") = MlirWalkPostOrder);
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
.def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9acfdde..b038a0c 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -579,6 +579,10 @@ public:
void writeBytecode(const pybind11::object &fileObject,
std::optional<int64_t> bytecodeVersion);
+ // Implement the walk method.
+ void walk(std::function<MlirWalkResult(MlirOperation)> callback,
+ MlirWalkOrder walkOrder);
+
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index cdb64f4..a72cd24 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -717,17 +717,34 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
return unwrap(op)->moveBefore(unwrap(other));
}
+static mlir::WalkResult unwrap(MlirWalkResult result) {
+ switch (result) {
+ case MlirWalkResultAdvance:
+ return mlir::WalkResult::advance();
+
+ case MlirWalkResultInterrupt:
+ return mlir::WalkResult::interrupt();
+
+ case MlirWalkResultSkip:
+ return mlir::WalkResult::skip();
+ }
+}
+
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
void *userData, MlirWalkOrder walkOrder) {
switch (walkOrder) {
case MlirWalkPreOrder:
unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
- [callback, userData](Operation *op) { callback(wrap(op), userData); });
+ [callback, userData](Operation *op) {
+ return unwrap(callback(wrap(op), userData));
+ });
break;
case MlirWalkPostOrder:
unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
- [callback, userData](Operation *op) { callback(wrap(op), userData); });
+ [callback, userData](Operation *op) {
+ return unwrap(callback(wrap(op), userData));
+ });
}
}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 8e79338..3d05b2a 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2244,9 +2244,22 @@ typedef struct {
const char *x;
} callBackData;
-void walkCallBack(MlirOperation op, void *rootOpVoid) {
+MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) {
fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
mlirIdentifierStr(mlirOperationGetName(op)).data);
+ return MlirWalkResultAdvance;
+}
+
+MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) {
+ fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
+ mlirIdentifierStr(mlirOperationGetName(op)).data);
+ if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "func.func") ==
+ 0)
+ return MlirWalkResultSkip;
+ if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "arith.addi") ==
+ 0)
+ return MlirWalkResultInterrupt;
+ return MlirWalkResultAdvance;
}
int testOperationWalk(MlirContext ctx) {
@@ -2259,6 +2272,9 @@ int testOperationWalk(MlirContext ctx) {
" arith.addi %1, %1: i32\n"
" return\n"
" }\n"
+ " func.func @bar() {\n"
+ " return\n"
+ " }\n"
"}";
MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
@@ -2266,22 +2282,42 @@ int testOperationWalk(MlirContext ctx) {
callBackData data;
data.x = "i love you";
- // CHECK: i love you: arith.constant
- // CHECK: i love you: arith.addi
- // CHECK: i love you: func.return
- // CHECK: i love you: func.func
- // CHECK: i love you: builtin.module
+ // CHECK-NEXT: i love you: arith.constant
+ // CHECK-NEXT: i love you: arith.addi
+ // CHECK-NEXT: i love you: func.return
+ // CHECK-NEXT: i love you: func.func
+ // CHECK-NEXT: i love you: func.return
+ // CHECK-NEXT: i love you: func.func
+ // CHECK-NEXT: i love you: builtin.module
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
(void *)(&data), MlirWalkPostOrder);
data.x = "i don't love you";
- // CHECK: i don't love you: builtin.module
- // CHECK: i don't love you: func.func
- // CHECK: i don't love you: arith.constant
- // CHECK: i don't love you: arith.addi
- // CHECK: i don't love you: func.return
+ // CHECK-NEXT: i don't love you: builtin.module
+ // CHECK-NEXT: i don't love you: func.func
+ // CHECK-NEXT: i don't love you: arith.constant
+ // CHECK-NEXT: i don't love you: arith.addi
+ // CHECK-NEXT: i don't love you: func.return
+ // CHECK-NEXT: i don't love you: func.func
+ // CHECK-NEXT: i don't love you: func.return
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
(void *)(&data), MlirWalkPreOrder);
+
+ data.x = "interrupt";
+ // Interrupted at `arith.addi`
+ // CHECK-NEXT: interrupt: arith.constant
+ // CHECK-NEXT: interrupt: arith.addi
+ mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
+ (void *)(&data), MlirWalkPostOrder);
+
+ data.x = "skip";
+ // Skip at `func.func`
+ // CHECK-NEXT: skip: builtin.module
+ // CHECK-NEXT: skip: func.func
+ // CHECK-NEXT: skip: func.func
+ mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
+ (void *)(&data), MlirWalkPreOrder);
+
mlirModuleDestroy(module);
return 0;
}
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a99..9666e63 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1015,3 +1015,78 @@ def testOperationParse():
print(
f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
)
+
+
+# CHECK-LABEL: TEST: testOpWalk
+@run
+def testOpWalk():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
+ builtin.module {
+ func.func @f() {
+ func.return
+ }
+ }
+ """,
+ ctx,
+ )
+
+ def callback(op):
+ print(op.name)
+ return WalkResult.ADVANCE
+
+ # Test post-order walk (default).
+ # CHECK-NEXT: Post-order
+ # CHECK-NEXT: func.return
+ # CHECK-NEXT: func.func
+ # CHECK-NEXT: builtin.module
+ print("Post-order")
+ module.operation.walk(callback)
+
+ # Test pre-order walk.
+ # CHECK-NEXT: Pre-order
+ # CHECK-NEXT: builtin.module
+ # CHECK-NEXT: func.fun
+ # CHECK-NEXT: func.return
+ print("Pre-order")
+ module.operation.walk(callback, WalkOrder.PRE_ORDER)
+
+ # Test interrput.
+ # CHECK-NEXT: Interrupt post-order
+ # CHECK-NEXT: func.return
+ print("Interrupt post-order")
+
+ def callback(op):
+ print(op.name)
+ return WalkResult.INTERRUPT
+
+ module.operation.walk(callback)
+
+ # Test skip.
+ # CHECK-NEXT: Skip pre-order
+ # CHECK-NEXT: builtin.module
+ print("Skip pre-order")
+
+ def callback(op):
+ print(op.name)
+ return WalkResult.SKIP
+
+ module.operation.walk(callback, WalkOrder.PRE_ORDER)
+
+ # Test exception.
+ # CHECK: Exception
+ # CHECK-NEXT: func.return
+ # CHECK-NEXT: Exception raised
+ print("Exception")
+
+ def callback(op):
+ print(op.name)
+ raise ValueError
+ return WalkResult.ADVANCE
+
+ try:
+ module.operation.walk(callback)
+ except ValueError:
+ print("Exception raised")