aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHideto Ueno <uenoku.tokotoko@gmail.com>2024-04-17 15:09:47 +0900
committerGitHub <noreply@github.com>2024-04-17 15:09:47 +0900
commit47148832d4e3bf4901430732f1af6673147accb2 (patch)
tree5ab14b2dad51756797651fcddf46c5bb6dffe903
parentb851c7f1fc4fd83ea84d565bbdc30fd0d356788c (diff)
downloadllvm-47148832d4e3bf4901430732f1af6673147accb2.zip
llvm-47148832d4e3bf4901430732f1af6673147accb2.tar.gz
llvm-47148832d4e3bf4901430732f1af6673147accb2.tar.bz2
[mlir][python] Add `walk` method to PyOperationBase (#87962)
This commit adds `walk` method to PyOperationBase that uses a python object as a callback, e.g. `op.walk(callback)`. Currently callback must return a walk result explicitly. We(SiFive) have implemented walk method with python in our internal python tool for a while. However the overhead of python is expensive and it didn't scale well for large MLIR files. Just replacing walk with this version reduced the entire execution time of the tool by 30~40% and there are a few configs that the tool takes several hours to finish so this commit significantly improves tool performance.
-rw-r--r--mlir/include/mlir-c/IR.h10
-rw-r--r--mlir/include/mlir/Bindings/Python/PybindAdaptors.h1
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp32
-rw-r--r--mlir/lib/Bindings/Python/IRModule.h4
-rw-r--r--mlir/lib/CAPI/IR/IR.cpp21
-rw-r--r--mlir/test/CAPI/ir.c58
-rw-r--r--mlir/test/python/ir/operation.py75
7 files changed, 184 insertions, 17 deletions
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82da511..32abacf 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -705,6 +705,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
MlirOperation other);
+/// Operation walk result.
+typedef enum MlirWalkResult {
+ MlirWalkResultAdvance,
+ MlirWalkResultInterrupt,
+ MlirWalkResultSkip
+} MlirWalkResult;
+
/// Traversal order for operation walk.
typedef enum MlirWalkOrder {
MlirWalkPreOrder,
@@ -713,7 +720,8 @@ typedef enum MlirWalkOrder {
/// Operation walker type. The handler is passed an (opaque) reference to an
/// operation and a pointer to a `userData`.
-typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
+typedef MlirWalkResult (*MlirOperationWalkCallback)(MlirOperation,
+ void *userData);
/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
/// `*userData` is passed to the callback as well and can be used to tunnel some
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 52f6321..d8f22c7a 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -18,6 +18,7 @@
#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
+#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 734f2f7..d875f4e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -674,6 +674,7 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
data->rootOp.getOperation().getContext()->clearOperation(op);
else
data->rootSeen = true;
+ return MlirWalkResult::MlirWalkResultAdvance;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
static_cast<void *>(&data), MlirWalkPreOrder);
@@ -1249,6 +1250,21 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
.str());
}
+void PyOperationBase::walk(
+ std::function<MlirWalkResult(MlirOperation)> callback,
+ MlirWalkOrder walkOrder) {
+ PyOperation &operation = getOperation();
+ operation.checkValid();
+ MlirOperationWalkCallback walkCallback = [](MlirOperation op,
+ void *userData) {
+ auto *fn =
+ static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
+ return (*fn)(op);
+ };
+
+ mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
+}
+
py::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
@@ -2511,6 +2527,15 @@ void mlir::python::populateIRCore(py::module &m) {
.value("NOTE", MlirDiagnosticNote)
.value("REMARK", MlirDiagnosticRemark);
+ py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
+ .value("PRE_ORDER", MlirWalkPreOrder)
+ .value("POST_ORDER", MlirWalkPostOrder);
+
+ py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
+ .value("ADVANCE", MlirWalkResultAdvance)
+ .value("INTERRUPT", MlirWalkResultInterrupt)
+ .value("SKIP", MlirWalkResultSkip);
+
//----------------------------------------------------------------------------
// Mapping of Diagnostics.
//----------------------------------------------------------------------------
@@ -2989,8 +3014,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
- bool, py::object, bool>(
- &PyOperationBase::print),
+ bool, py::object, bool>(&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
@@ -3038,7 +3062,9 @@ void mlir::python::populateIRCore(py::module &m) {
return operation.createOpView();
},
"Detaches the operation from its parent block.")
- .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
+ .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+ .def("walk", &PyOperationBase::walk, py::arg("callback"),
+ py::arg("walk_order") = MlirWalkPostOrder);
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
.def_static("create", &PyOperation::create, py::arg("name"),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9acfdde..b038a0c 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -579,6 +579,10 @@ public:
void writeBytecode(const pybind11::object &fileObject,
std::optional<int64_t> bytecodeVersion);
+ // Implement the walk method.
+ void walk(std::function<MlirWalkResult(MlirOperation)> callback,
+ MlirWalkOrder walkOrder);
+
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index cdb64f4..a72cd24 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -717,17 +717,34 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
return unwrap(op)->moveBefore(unwrap(other));
}
+static mlir::WalkResult unwrap(MlirWalkResult result) {
+ switch (result) {
+ case MlirWalkResultAdvance:
+ return mlir::WalkResult::advance();
+
+ case MlirWalkResultInterrupt:
+ return mlir::WalkResult::interrupt();
+
+ case MlirWalkResultSkip:
+ return mlir::WalkResult::skip();
+ }
+}
+
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
void *userData, MlirWalkOrder walkOrder) {
switch (walkOrder) {
case MlirWalkPreOrder:
unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
- [callback, userData](Operation *op) { callback(wrap(op), userData); });
+ [callback, userData](Operation *op) {
+ return unwrap(callback(wrap(op), userData));
+ });
break;
case MlirWalkPostOrder:
unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
- [callback, userData](Operation *op) { callback(wrap(op), userData); });
+ [callback, userData](Operation *op) {
+ return unwrap(callback(wrap(op), userData));
+ });
}
}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 8e79338..3d05b2a 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2244,9 +2244,22 @@ typedef struct {
const char *x;
} callBackData;
-void walkCallBack(MlirOperation op, void *rootOpVoid) {
+MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) {
fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
mlirIdentifierStr(mlirOperationGetName(op)).data);
+ return MlirWalkResultAdvance;
+}
+
+MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) {
+ fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
+ mlirIdentifierStr(mlirOperationGetName(op)).data);
+ if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "func.func") ==
+ 0)
+ return MlirWalkResultSkip;
+ if (strcmp(mlirIdentifierStr(mlirOperationGetName(op)).data, "arith.addi") ==
+ 0)
+ return MlirWalkResultInterrupt;
+ return MlirWalkResultAdvance;
}
int testOperationWalk(MlirContext ctx) {
@@ -2259,6 +2272,9 @@ int testOperationWalk(MlirContext ctx) {
" arith.addi %1, %1: i32\n"
" return\n"
" }\n"
+ " func.func @bar() {\n"
+ " return\n"
+ " }\n"
"}";
MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
@@ -2266,22 +2282,42 @@ int testOperationWalk(MlirContext ctx) {
callBackData data;
data.x = "i love you";
- // CHECK: i love you: arith.constant
- // CHECK: i love you: arith.addi
- // CHECK: i love you: func.return
- // CHECK: i love you: func.func
- // CHECK: i love you: builtin.module
+ // CHECK-NEXT: i love you: arith.constant
+ // CHECK-NEXT: i love you: arith.addi
+ // CHECK-NEXT: i love you: func.return
+ // CHECK-NEXT: i love you: func.func
+ // CHECK-NEXT: i love you: func.return
+ // CHECK-NEXT: i love you: func.func
+ // CHECK-NEXT: i love you: builtin.module
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
(void *)(&data), MlirWalkPostOrder);
data.x = "i don't love you";
- // CHECK: i don't love you: builtin.module
- // CHECK: i don't love you: func.func
- // CHECK: i don't love you: arith.constant
- // CHECK: i don't love you: arith.addi
- // CHECK: i don't love you: func.return
+ // CHECK-NEXT: i don't love you: builtin.module
+ // CHECK-NEXT: i don't love you: func.func
+ // CHECK-NEXT: i don't love you: arith.constant
+ // CHECK-NEXT: i don't love you: arith.addi
+ // CHECK-NEXT: i don't love you: func.return
+ // CHECK-NEXT: i don't love you: func.func
+ // CHECK-NEXT: i don't love you: func.return
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
(void *)(&data), MlirWalkPreOrder);
+
+ data.x = "interrupt";
+ // Interrupted at `arith.addi`
+ // CHECK-NEXT: interrupt: arith.constant
+ // CHECK-NEXT: interrupt: arith.addi
+ mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
+ (void *)(&data), MlirWalkPostOrder);
+
+ data.x = "skip";
+ // Skip at `func.func`
+ // CHECK-NEXT: skip: builtin.module
+ // CHECK-NEXT: skip: func.func
+ // CHECK-NEXT: skip: func.func
+ mlirOperationWalk(mlirModuleGetOperation(module), walkCallBackTestWalkResult,
+ (void *)(&data), MlirWalkPreOrder);
+
mlirModuleDestroy(module);
return 0;
}
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a99..9666e63 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1015,3 +1015,78 @@ def testOperationParse():
print(
f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
)
+
+
+# CHECK-LABEL: TEST: testOpWalk
+@run
+def testOpWalk():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
+ builtin.module {
+ func.func @f() {
+ func.return
+ }
+ }
+ """,
+ ctx,
+ )
+
+ def callback(op):
+ print(op.name)
+ return WalkResult.ADVANCE
+
+ # Test post-order walk (default).
+ # CHECK-NEXT: Post-order
+ # CHECK-NEXT: func.return
+ # CHECK-NEXT: func.func
+ # CHECK-NEXT: builtin.module
+ print("Post-order")
+ module.operation.walk(callback)
+
+ # Test pre-order walk.
+ # CHECK-NEXT: Pre-order
+ # CHECK-NEXT: builtin.module
+ # CHECK-NEXT: func.fun
+ # CHECK-NEXT: func.return
+ print("Pre-order")
+ module.operation.walk(callback, WalkOrder.PRE_ORDER)
+
+ # Test interrput.
+ # CHECK-NEXT: Interrupt post-order
+ # CHECK-NEXT: func.return
+ print("Interrupt post-order")
+
+ def callback(op):
+ print(op.name)
+ return WalkResult.INTERRUPT
+
+ module.operation.walk(callback)
+
+ # Test skip.
+ # CHECK-NEXT: Skip pre-order
+ # CHECK-NEXT: builtin.module
+ print("Skip pre-order")
+
+ def callback(op):
+ print(op.name)
+ return WalkResult.SKIP
+
+ module.operation.walk(callback, WalkOrder.PRE_ORDER)
+
+ # Test exception.
+ # CHECK: Exception
+ # CHECK-NEXT: func.return
+ # CHECK-NEXT: Exception raised
+ print("Exception")
+
+ def callback(op):
+ print(op.name)
+ raise ValueError
+ return WalkResult.ADVANCE
+
+ try:
+ module.operation.walk(callback)
+ except ValueError:
+ print("Exception raised")