diff options
-rw-r--r-- | mlir/include/mlir-c/IR.h | 10 | ||||
-rw-r--r-- | mlir/include/mlir/Bindings/Python/PybindAdaptors.h | 1 | ||||
-rw-r--r-- | mlir/lib/Bindings/Python/IRCore.cpp | 32 | ||||
-rw-r--r-- | mlir/lib/Bindings/Python/IRModule.h | 4 | ||||
-rw-r--r-- | mlir/lib/CAPI/IR/IR.cpp | 21 | ||||
-rw-r--r-- | mlir/test/CAPI/ir.c | 58 | ||||
-rw-r--r-- | mlir/test/python/ir/operation.py | 75 |
7 files changed, 184 insertions, 17 deletions
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 82da511..32abacf 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -705,6 +705,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other); +/// Operation walk result. +typedef enum MlirWalkResult { + MlirWalkResultAdvance, + MlirWalkResultInterrupt, + MlirWalkResultSkip +} MlirWalkResult; + /// Traversal order for operation walk. typedef enum MlirWalkOrder { MlirWalkPreOrder, @@ -713,7 +720,8 @@ typedef enum MlirWalkOrder { /// Operation walker type. The handler is passed an (opaque) reference to an /// operation and a pointer to a `userData`. -typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData); +typedef MlirWalkResult (*MlirOperationWalkCallback)(MlirOperation, + void *userData); /// Walks operation `op` in `walkOrder` and calls `callback` on that operation. /// `*userData` is passed to the callback as well and can be used to tunnel some diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 52f6321..d8f22c7a 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -18,6 +18,7 @@ #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H +#include <pybind11/functional.h> #include <pybind11/pybind11.h> #include <pybind11/pytypes.h> #include <pybind11/stl.h> diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 734f2f7..d875f4e 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -674,6 +674,7 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) { data->rootOp.getOperation().getContext()->clearOperation(op); else data->rootSeen = true; + return MlirWalkResult::MlirWalkResultAdvance; }; mlirOperationWalk(op.getOperation(), invalidatingCallback, static_cast<void *>(&data), MlirWalkPreOrder); @@ -1249,6 +1250,21 @@ void PyOperationBase::writeBytecode(const py::object &fileObject, .str()); } +void PyOperationBase::walk( + std::function<MlirWalkResult(MlirOperation)> callback, + MlirWalkOrder walkOrder) { + PyOperation &operation = getOperation(); + operation.checkValid(); + MlirOperationWalkCallback walkCallback = [](MlirOperation op, + void *userData) { + auto *fn = + static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData); + return (*fn)(op); + }; + + mlirOperationWalk(operation, walkCallback, &callback, walkOrder); +} + py::object PyOperationBase::getAsm(bool binary, std::optional<int64_t> largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, @@ -2511,6 +2527,15 @@ void mlir::python::populateIRCore(py::module &m) { .value("NOTE", MlirDiagnosticNote) .value("REMARK", MlirDiagnosticRemark); + py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local()) + .value("PRE_ORDER", MlirWalkPreOrder) + .value("POST_ORDER", MlirWalkPostOrder); + + py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local()) + .value("ADVANCE", MlirWalkResultAdvance) + .value("INTERRUPT", MlirWalkResultInterrupt) + .value("SKIP", MlirWalkResultSkip); + //---------------------------------------------------------------------------- // Mapping of Diagnostics. //---------------------------------------------------------------------------- @@ -2989,8 +3014,7 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("binary") = false, kOperationPrintStateDocstring) .def("print", py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool, - bool, py::object, bool>( - &PyOperationBase::print), + bool, py::object, bool>(&PyOperationBase::print), // Careful: Lots of arguments must match up with print method. py::arg("large_elements_limit") = py::none(), py::arg("enable_debug_info") = false, @@ -3038,7 +3062,9 @@ void mlir::python::populateIRCore(py::module &m) { return operation.createOpView(); }, "Detaches the operation from its parent block.") - .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }); + .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) + .def("walk", &PyOperationBase::walk, py::arg("callback"), + py::arg("walk_order") = MlirWalkPostOrder); py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) .def_static("create", &PyOperation::create, py::arg("name"), diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 9acfdde..b038a0c 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -579,6 +579,10 @@ public: void writeBytecode(const pybind11::object &fileObject, std::optional<int64_t> bytecodeVersion); + // Implement the walk method. + void walk(std::function<MlirWalkResult(MlirOperation)> callback, + MlirWalkOrder walkOrder); + /// Moves the operation before or after the other operation. void moveAfter(PyOperationBase &other); void moveBefore(PyOperationBase &other); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index cdb64f4..a72cd24 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -717,17 +717,34 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { return unwrap(op)->moveBefore(unwrap(other)); } +static mlir::WalkResult unwrap(MlirWalkResult result) { + switch (result) { + case MlirWalkResultAdvance: + return mlir::WalkResult::advance(); + + case MlirWalkResultInterrupt: + return mlir::WalkResult::interrupt(); + + case MlirWalkResultSkip: + return mlir::WalkResult::skip(); + } +} + void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder) { switch (walkOrder) { case MlirWalkPreOrder: unwrap(op)->walk<mlir::WalkOrder::PreOrder>( - [callback, userData](Operation *op) { callback(wrap(op), userData); }); + [callback, userData](Operation *op) { + return unwrap(callback(wrap(op), userData)); + }); break; case MlirWalkPostOrder: unwrap(op)->walk<mlir::WalkOrder::PostOrder>( - [callback, userData](Operation *op) { callback(wrap(op), userData); }); + [callback, userData](Operation *op) { + return unwrap(callback(wrap(op), userData)); + }); } } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index 8e79338..3d05b2a 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -2244,9 +2244,22 @@ typedef struct { const char *x; } callBackData; -void walkCallBack(MlirOperation op, void *rootOpVoid) { +MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) { fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x, mlirIdentifierStr(mlirOperationGetName(op)).data); + return MlirWalkResultAdvance; +} + +MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) { + fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x, + mlirIdentifierStr(mlirOperationGetName(op)).data); + if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "func.func") == + 0) + return MlirWalkResultSkip; + if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "arith.addi") == + 0) + return MlirWalkResultInterrupt; + return MlirWalkResultAdvance; } int testOperationWalk(MlirContext ctx) { @@ -2259,6 +2272,9 @@ int testOperationWalk(MlirContext ctx) { " arith.addi %1, %1: i32\n" " return\n" " }\n" + " func.func @bar() {\n" + " return\n" + " }\n" "}"; MlirModule module = mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString)); @@ -2266,22 +2282,42 @@ int testOperationWalk(MlirContext ctx) { callBackData data; data.x = "i love you"; - // CHECK: i love you: arith.constant - // CHECK: i love you: arith.addi - // CHECK: i love you: func.return - // CHECK: i love you: func.func - // CHECK: i love you: builtin.module + // CHECK-NEXT: i love you: arith.constant + // CHECK-NEXT: i love you: arith.addi + // CHECK-NEXT: i love you: func.return + // CHECK-NEXT: i love you: func.func + // CHECK-NEXT: i love you: func.return + // CHECK-NEXT: i love you: func.func + // CHECK-NEXT: i love you: builtin.module mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack, (void *)(&data), MlirWalkPostOrder); data.x = "i don't love you"; - // CHECK: i don't love you: builtin.module - // CHECK: i don't love you: func.func - // CHECK: i don't love you: arith.constant - // CHECK: i don't love you: arith.addi - // CHECK: i don't love you: func.return + // CHECK-NEXT: i don't love you: builtin.module + // CHECK-NEXT: i don't love you: func.func + // CHECK-NEXT: i don't love you: arith.constant + // CHECK-NEXT: i don't love you: arith.addi + // CHECK-NEXT: i don't love you: func.return + // CHECK-NEXT: i don't love you: func.func + // CHECK-NEXT: i don't love you: func.return mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack, (void *)(&data), MlirWalkPreOrder); + + data.x = "interrupt"; + // Interrupted at `arith.addi` + // CHECK-NEXT: interrupt: arith.constant + // CHECK-NEXT: interrupt: arith.addi + mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult, + (void *)(&data), MlirWalkPostOrder); + + data.x = "skip"; + // Skip at `func.func` + // CHECK-NEXT: skip: builtin.module + // CHECK-NEXT: skip: func.func + // CHECK-NEXT: skip: func.func + mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult, + (void *)(&data), MlirWalkPreOrder); + mlirModuleDestroy(module); return 0; } diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 04f8a99..9666e63 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -1015,3 +1015,78 @@ def testOperationParse(): print( f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}" ) + + +# CHECK-LABEL: TEST: testOpWalk +@run +def testOpWalk(): + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" + builtin.module { + func.func @f() { + func.return + } + } + """, + ctx, + ) + + def callback(op): + print(op.name) + return WalkResult.ADVANCE + + # Test post-order walk (default). + # CHECK-NEXT: Post-order + # CHECK-NEXT: func.return + # CHECK-NEXT: func.func + # CHECK-NEXT: builtin.module + print("Post-order") + module.operation.walk(callback) + + # Test pre-order walk. + # CHECK-NEXT: Pre-order + # CHECK-NEXT: builtin.module + # CHECK-NEXT: func.fun + # CHECK-NEXT: func.return + print("Pre-order") + module.operation.walk(callback, WalkOrder.PRE_ORDER) + + # Test interrput. + # CHECK-NEXT: Interrupt post-order + # CHECK-NEXT: func.return + print("Interrupt post-order") + + def callback(op): + print(op.name) + return WalkResult.INTERRUPT + + module.operation.walk(callback) + + # Test skip. + # CHECK-NEXT: Skip pre-order + # CHECK-NEXT: builtin.module + print("Skip pre-order") + + def callback(op): + print(op.name) + return WalkResult.SKIP + + module.operation.walk(callback, WalkOrder.PRE_ORDER) + + # Test exception. + # CHECK: Exception + # CHECK-NEXT: func.return + # CHECK-NEXT: Exception raised + print("Exception") + + def callback(op): + print(op.name) + raise ValueError + return WalkResult.ADVANCE + + try: + module.operation.walk(callback) + except ValueError: + print("Exception raised") |