aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStella Laurenzo <stellaraccident@gmail.com>2021-11-28 15:33:03 -0800
committerStella Laurenzo <stellaraccident@gmail.com>2021-11-28 18:02:01 -0800
commitace1d0ad3dc43e28715cbe2f3e0a5a76578bda9f (patch)
tree8b627355ce6a9f3d71b4a7f8c6553ab5897b2842
parent1164c4b37583eca98866853ed22149f1a1b55a3d (diff)
downloadllvm-ace1d0ad3dc43e28715cbe2f3e0a5a76578bda9f.zip
llvm-ace1d0ad3dc43e28715cbe2f3e0a5a76578bda9f.tar.gz
llvm-ace1d0ad3dc43e28715cbe2f3e0a5a76578bda9f.tar.bz2
[mlir][python] Normalize asm-printing IR behavior.
While working on an integration, I found a lot of inconsistencies in IR printing and verification. It turns out that we were: * Only doing "soft fail" verification on IR printing of Operation, not of a Module. * Failed verification was interacting badly with binary=True IR printing (causing a TypeError trying to pass a `str` to a `bytes`-based handle). * For systematic integrations, it is often desirable to control verification yourself so that you can explicitly handle errors. This patch: * Trues up the "soft fail" semantics by having `Module.__str__` delegate to `Operation.__str__` vs having a shortcut implementation. * Fixes soft fail in the presence of binary=True (and adds an additional happy path test case to make sure the binary functionality works). * Adds an `assume_verified` boolean flag to the `print`/`get_asm` methods which disables internal verification, presupposing that the caller has taken care of it. It turns out that we had a number of tests which were generating illegal IR but it wasn't being caught because they were doing a print on the `Module` vs operation. All except two were trivially fixed: * linalg/ops.py : Had two tests that directly constructed a Matmul incorrectly. Fixing them made them just like the next two tests, so I just deleted them (no need to test the verifier only at this level). * linalg/opdsl/emit_structured_generic.py : Hand-coded conv and pooling tests appear to be using illegally shaped inputs/outputs, causing a verification failure. I just used the `assume_verified=` flag to restore the original behavior and left a TODO. Will get someone who owns that to fix it properly in a follow-up (would also be nice to break this file up into multiple test modules as it is hard to tell exactly what is failing). Notes to downstreams: * If, like some of our tests, you get verification failures after this patch, it is likely that your IR was always invalid and you will need to fix the root cause. 
To temporarily revert to prior (broken) behavior, replace calls like `print(module)` with `print(module.operation.get_asm(assume_verified=True))`. Differential Revision: https://reviews.llvm.org/D114680
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp45
-rw-r--r--mlir/lib/Bindings/Python/IRModule.h6
-rw-r--r--mlir/test/python/dialects/builtin.py29
-rw-r--r--mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py5
-rw-r--r--mlir/test/python/dialects/linalg/ops.py43
-rw-r--r--mlir/test/python/dialects/shape.py3
-rw-r--r--mlir/test/python/dialects/std.py9
-rw-r--r--mlir/test/python/ir/module.py35
-rw-r--r--mlir/test/python/ir/operation.py45
9 files changed, 122 insertions, 98 deletions
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4c25fd4..c70cfc5 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -93,6 +93,13 @@ Args:
use_local_Scope: Whether to print in a way that is more optimized for
multi-threaded access but may not be consistent with how the overall
module prints.
+ assume_verified: By default, if not printing generic form, the verifier
+ will be run and if it fails, generic form will be printed with a comment
+ about failed verification. While a reasonable default for interactive use,
+ for systematic use, it is often better for the caller to verify explicitly
+ and report failures in a more robust fashion. Set this to True if doing this
+ in order to avoid running a redundant verification. If the IR is actually
+ invalid, behavior is undefined.
)";
static const char kOperationGetAsmDocstring[] =
@@ -828,14 +835,21 @@ void PyOperation::checkValid() const {
void PyOperationBase::print(py::object fileObject, bool binary,
llvm::Optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
- bool printGenericOpForm, bool useLocalScope) {
+ bool printGenericOpForm, bool useLocalScope,
+ bool assumeVerified) {
PyOperation &operation = getOperation();
operation.checkValid();
if (fileObject.is_none())
fileObject = py::module::import("sys").attr("stdout");
- if (!printGenericOpForm && !mlirOperationVerify(operation)) {
- fileObject.attr("write")("// Verification failed, printing generic form\n");
+ if (!assumeVerified && !printGenericOpForm &&
+ !mlirOperationVerify(operation)) {
+ std::string message("// Verification failed, printing generic form\n");
+ if (binary) {
+ fileObject.attr("write")(py::bytes(message));
+ } else {
+ fileObject.attr("write")(py::str(message));
+ }
printGenericOpForm = true;
}
@@ -857,8 +871,8 @@ void PyOperationBase::print(py::object fileObject, bool binary,
py::object PyOperationBase::getAsm(bool binary,
llvm::Optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
- bool printGenericOpForm,
- bool useLocalScope) {
+ bool printGenericOpForm, bool useLocalScope,
+ bool assumeVerified) {
py::object fileObject;
if (binary) {
fileObject = py::module::import("io").attr("BytesIO")();
@@ -870,7 +884,8 @@ py::object PyOperationBase::getAsm(bool binary,
/*enableDebugInfo=*/enableDebugInfo,
/*prettyDebugInfo=*/prettyDebugInfo,
/*printGenericOpForm=*/printGenericOpForm,
- /*useLocalScope=*/useLocalScope);
+ /*useLocalScope=*/useLocalScope,
+ /*assumeVerified=*/assumeVerified);
return fileObject.attr("getvalue")();
}
@@ -2149,12 +2164,9 @@ void mlir::python::populateIRCore(py::module &m) {
kDumpDocstring)
.def(
"__str__",
- [](PyModule &self) {
- MlirOperation operation = mlirModuleGetOperation(self.get());
- PyPrintAccumulator printAccum;
- mlirOperationPrint(operation, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
+ [](py::object self) {
+ // Defer to the operation's __str__.
+ return self.attr("operation").attr("__str__")();
},
kOperationStrDunderDocstring);
@@ -2234,7 +2246,8 @@ void mlir::python::populateIRCore(py::module &m) {
/*enableDebugInfo=*/false,
/*prettyDebugInfo=*/false,
/*printGenericOpForm=*/false,
- /*useLocalScope=*/false);
+ /*useLocalScope=*/false,
+ /*assumeVerified=*/false);
},
"Returns the assembly form of the operation.")
.def("print", &PyOperationBase::print,
@@ -2244,7 +2257,8 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("enable_debug_info") = false,
py::arg("pretty_debug_info") = false,
py::arg("print_generic_op_form") = false,
- py::arg("use_local_scope") = false, kOperationPrintDocstring)
+ py::arg("use_local_scope") = false,
+ py::arg("assume_verified") = false, kOperationPrintDocstring)
.def("get_asm", &PyOperationBase::getAsm,
// Careful: Lots of arguments must match up with get_asm method.
py::arg("binary") = false,
@@ -2252,7 +2266,8 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("enable_debug_info") = false,
py::arg("pretty_debug_info") = false,
py::arg("print_generic_op_form") = false,
- py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
+ py::arg("use_local_scope") = false,
+ py::arg("assume_verified") = false, kOperationGetAsmDocstring)
.def(
"verify",
[](PyOperationBase &self) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index eb5c238..dc024a2 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -394,11 +394,13 @@ public:
/// Implements the bound 'print' method and helps with others.
void print(pybind11::object fileObject, bool binary,
llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
- bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
+ bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
+ bool assumeVerified);
pybind11::object getAsm(bool binary,
llvm::Optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
- bool printGenericOpForm, bool useLocalScope);
+ bool printGenericOpForm, bool useLocalScope,
+ bool assumeVerified);
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 8f3a041..7caf5b5 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -175,7 +175,8 @@ def testBuildFuncOp():
# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
- with Context(), Location.unknown():
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
module = Module.create()
f32 = F32Type.get()
f64 = F64Type.get()
@@ -185,38 +186,38 @@ def testFuncArgumentAccess():
std.ReturnOp(func.arguments)
func.arg_attrs = ArrayAttr.get([
DictAttr.get({
- "foo": StringAttr.get("bar"),
- "baz": UnitAttr.get()
+ "custom_dialect.foo": StringAttr.get("bar"),
+ "custom_dialect.baz": UnitAttr.get()
}),
- DictAttr.get({"qux": ArrayAttr.get([])})
+ DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
])
func.result_attrs = ArrayAttr.get([
- DictAttr.get({"res1": FloatAttr.get(f32, 42.0)}),
- DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
+ DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
+ DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)})
])
other = builtin.FuncOp("other_func", ([f32, f32], []))
with InsertionPoint(other.add_entry_block()):
std.ReturnOp([])
other.arg_attrs = [
- DictAttr.get({"foo": StringAttr.get("qux")}),
+ DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
DictAttr.get()
]
- # CHECK: [{baz, foo = "bar"}, {qux = []}]
+ # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
print(func.arg_attrs)
- # CHECK: [{res1 = 4.200000e+01 : f32}, {res2 = 2.560000e+02 : f64}]
+ # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
print(func.result_attrs)
# CHECK: func @some_func(
- # CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
- # CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
- # CHECK: f32 {res1 = 4.200000e+01 : f32},
- # CHECK: f32 {res2 = 2.560000e+02 : f64})
+ # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
+ # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
+ # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
+ # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
# CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
#
# CHECK: func @other_func(
- # CHECK: %{{.*}}: f32 {foo = "qux"},
+ # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
# CHECK: %{{.*}}: f32)
print(module)
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index d0c7427..115c227 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -405,4 +405,7 @@ with Context() as ctx, Location.unknown():
return non_default_op_name(input, outs=[init_result])
-print(module)
+# TODO: Fix me! Conv and pooling ops above do not verify, which was uncovered
+# when switching to more robust module verification. For now, reverting to the
+# old behavior which does not verify on module print.
+print(module.operation.get_asm(assume_verified=True))
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index e5b96c2..4f9f138 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -83,49 +83,6 @@ def testFill():
print(module)
-# CHECK-LABEL: TEST: testStructuredOpOnTensors
-@run
-def testStructuredOpOnTensors():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- tensor_type = RankedTensorType.get((2, 3, 4), f32)
- with InsertionPoint(module.body):
- func = builtin.FuncOp(
- name="matmul_test",
- type=FunctionType.get(
- inputs=[tensor_type, tensor_type], results=[tensor_type]))
- with InsertionPoint(func.add_entry_block()):
- lhs, rhs = func.entry_block.arguments
- result = linalg.MatmulOp([lhs, rhs], results=[tensor_type]).result
- std.ReturnOp([result])
-
- # CHECK: %[[R:.*]] = linalg.matmul ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
- print(module)
-
-
-# CHECK-LABEL: TEST: testStructuredOpOnBuffers
-@run
-def testStructuredOpOnBuffers():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- memref_type = MemRefType.get((2, 3, 4), f32)
- with InsertionPoint(module.body):
- func = builtin.FuncOp(
- name="matmul_test",
- type=FunctionType.get(
- inputs=[memref_type, memref_type, memref_type], results=[]))
- with InsertionPoint(func.add_entry_block()):
- lhs, rhs, result = func.entry_block.arguments
- # TODO: prperly hook up the region.
- linalg.MatmulOp([lhs, rhs], outputs=[result])
- std.ReturnOp([])
-
- # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
- print(module)
-
-
# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOpCustomForm():
diff --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py
index 7c1c5d6..a798b85 100644
--- a/mlir/test/python/dialects/shape.py
+++ b/mlir/test/python/dialects/shape.py
@@ -22,7 +22,8 @@ def testConstShape():
@builtin.FuncOp.from_py_func(
RankedTensorType.get((12, -1), f32))
def const_shape_tensor(arg):
- return shape.ConstShapeOp(DenseElementsAttr.get(np.array([10, 20])))
+ return shape.ConstShapeOp(
+ DenseElementsAttr.get(np.array([10, 20]), type=IndexType.get()))
# CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
# CHECK: shape.const_shape [10, 20] : tensor<2xindex>
diff --git a/mlir/test/python/dialects/std.py b/mlir/test/python/dialects/std.py
index f6e77ca..2a3b2df 100644
--- a/mlir/test/python/dialects/std.py
+++ b/mlir/test/python/dialects/std.py
@@ -78,8 +78,11 @@ def testConstantIndexOp():
@constructAndPrintInModule
def testFunctionCalls():
foo = builtin.FuncOp("foo", ([], []))
+ foo.sym_visibility = StringAttr.get("private")
bar = builtin.FuncOp("bar", ([], [IndexType.get()]))
+ bar.sym_visibility = StringAttr.get("private")
qux = builtin.FuncOp("qux", ([], [F32Type.get()]))
+ qux.sym_visibility = StringAttr.get("private")
with InsertionPoint(builtin.FuncOp("caller", ([], [])).add_entry_block()):
std.CallOp(foo, [])
@@ -88,9 +91,9 @@ def testFunctionCalls():
std.ReturnOp([])
-# CHECK: func @foo()
-# CHECK: func @bar() -> index
-# CHECK: func @qux() -> f32
+# CHECK: func private @foo()
+# CHECK: func private @bar() -> index
+# CHECK: func private @qux() -> f32
# CHECK: func @caller() {
# CHECK: call @foo() : () -> ()
# CHECK: %0 = call @bar() : () -> index
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index abddc66..76358eb 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -8,11 +8,13 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
+ return f
# Verify successful parse.
# CHECK-LABEL: TEST: testParseSuccess
# CHECK: module @successfulParse
+@run
def testParseSuccess():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
@@ -23,12 +25,11 @@ def testParseSuccess():
module.dump() # Just outputs to stderr. Verifies that it functions.
print(str(module))
-run(testParseSuccess)
-
# Verify parse error.
# CHECK-LABEL: TEST: testParseError
# CHECK: testParseError: Unable to parse module assembly (see diagnostics)
+@run
def testParseError():
ctx = Context()
try:
@@ -38,12 +39,11 @@ def testParseError():
else:
print("Exception not produced")
-run(testParseError)
-
# Verify successful parse.
# CHECK-LABEL: TEST: testCreateEmpty
# CHECK: module {
+@run
def testCreateEmpty():
ctx = Context()
loc = Location.unknown(ctx)
@@ -53,8 +53,6 @@ def testCreateEmpty():
gc.collect()
print(str(module))
-run(testCreateEmpty)
-
# Verify round-trip of ASM that contains unicode.
# Note that this does not test that the print path converts unicode properly
@@ -62,6 +60,7 @@ run(testCreateEmpty)
# CHECK-LABEL: TEST: testRoundtripUnicode
# CHECK: func private @roundtripUnicode()
# CHECK: foo = "\F0\9F\98\8A"
+@run
def testRoundtripUnicode():
ctx = Context()
module = Module.parse(r"""
@@ -69,11 +68,28 @@ def testRoundtripUnicode():
""", ctx)
print(str(module))
-run(testRoundtripUnicode)
+
+# Verify round-trip of ASM that contains unicode through the binary (bytes)
+# printing path: the module is printed with binary=True and re-parsed.
+# Note that this does not test that the print path converts unicode properly
+# because MLIR asm always normalizes it to the hex encoding.
+# CHECK-LABEL: TEST: testRoundtripBinary
+# CHECK: func private @roundtripUnicode()
+# CHECK: foo = "\F0\9F\98\8A"
+@run
+def testRoundtripBinary():
+ with Context():
+ module = Module.parse(r"""
+ func private @roundtripUnicode() attributes { foo = "😊" }
+ """)
+ binary_asm = module.operation.get_asm(binary=True)
+ assert isinstance(binary_asm, bytes)
+ module = Module.parse(binary_asm)
+ print(module)
# Tests that module.operation works and correctly interns instances.
# CHECK-LABEL: TEST: testModuleOperation
+@run
def testModuleOperation():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
@@ -101,10 +117,9 @@ def testModuleOperation():
assert ctx._get_live_operation_count() == 0
assert ctx._get_live_module_count() == 0
-run(testModuleOperation)
-
# CHECK-LABEL: TEST: testModuleCapsule
+@run
def testModuleCapsule():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
@@ -122,5 +137,3 @@ def testModuleCapsule():
gc.collect()
assert ctx._get_live_module_count() == 0
-
-run(testModuleCapsule)
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 8771ca0..133edc2 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -630,21 +630,50 @@ def testSingleResultProperty():
print(module.body.operations[2])
-# CHECK-LABEL: TEST: testPrintInvalidOperation
+def create_invalid_operation():
+ # This module has two regions and is invalid; verify that we fall back
+ # to the generic printer for safety.
+ op = Operation.create("builtin.module", regions=2)
+ op.regions[0].blocks.append()
+ return op
+
+# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
-def testPrintInvalidOperation():
+def testInvalidOperationStrSoftFails():
ctx = Context()
with Location.unknown(ctx):
- module = Operation.create("builtin.module", regions=2)
- # This module has two region and is invalid verify that we fallback
- # to the generic printer for safety.
- block = module.regions[0].blocks.append()
+ invalid_op = create_invalid_operation()
+ # Verify that we fall back to the generic printer for safety.
# CHECK: // Verification failed, printing generic form
# CHECK: "builtin.module"() ( {
# CHECK: }) : () -> ()
- print(module)
+ print(invalid_op)
# CHECK: .verify = False
- print(f".verify = {module.operation.verify()}")
+ print(f".verify = {invalid_op.operation.verify()}")
+
+
+# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
+@run
+def testInvalidModuleStrSoftFails():
+ ctx = Context()
+ with Location.unknown(ctx):
+ module = Module.create()
+ with InsertionPoint(module.body):
+ invalid_op = create_invalid_operation()
+ # Verify that we fall back to the generic printer for safety.
+ # CHECK: // Verification failed, printing generic form
+ print(module)
+
+
+# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
+@run
+def testInvalidOperationGetAsmBinarySoftFails():
+ ctx = Context()
+ with Location.unknown(ctx):
+ invalid_op = create_invalid_operation()
+ # Verify that we fall back to the generic printer for safety.
+ # CHECK: b'// Verification failed, printing generic form\n
+ print(invalid_op.get_asm(binary=True))
# CHECK-LABEL: TEST: testCreateWithInvalidAttributes