aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTres Popp <tpopp@google.com>2020-05-19 09:05:09 +0200
committerTres Popp <tpopp@google.com>2020-05-20 10:39:26 +0200
commitfb6986ef69ae7e2dc60b5ae445e6e055a682c53c (patch)
tree9229dcb3e928705b090f608df1f75d4175729110
parent44226c1fea7b5aa407e7e0a17cf71cab9082fd0f (diff)
downloadllvm-fb6986ef69ae7e2dc60b5ae445e6e055a682c53c.zip
llvm-fb6986ef69ae7e2dc60b5ae445e6e055a682c53c.tar.gz
llvm-fb6986ef69ae7e2dc60b5ae445e6e055a682c53c.tar.bz2
[mlir] Custom printing/parsing for Shape::AssumingOp
Summary: Additionally, this adds traits and builder methods to AssumingYieldOp and names the input witness to the AssumingOp. Differential Revision: https://reviews.llvm.org/D80187
-rw-r--r--mlir/include/mlir/Dialect/Shape/IR/Shape.h1
-rw-r--r--mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td16
-rw-r--r--mlir/lib/Dialect/Shape/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Shape/IR/Shape.cpp44
-rw-r--r--mlir/test/Dialect/Shape/ops.mlir4
5 files changed, 61 insertions, 5 deletions
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index e58cc3b..9949758 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -17,6 +17,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index d06a1ce..2e521db 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -14,6 +14,7 @@
#define SHAPE_OPS
include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -426,12 +427,16 @@ def Shape_AssumingOp : Shape_Op<"assuming",
nothing else. They should not exist after a program is fully lowered and
ready to execute.
}];
- let arguments = (ins Shape_WitnessType);
- let regions = (region SizedRegion<1>:$thenRegion);
+ let arguments = (ins Shape_WitnessType:$witness);
+ let regions = (region SizedRegion<1>:$doRegion);
let results = (outs Variadic<AnyType>:$results);
+
+ let printer = [{ return ::print(p, *this); }];
+ let parser = [{ return ::parse$cppClass(parser, result); }];
}
-def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", [Terminator]> {
+def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
+ [NoSideEffect, ReturnLike, Terminator]> {
let summary = "Yield operation";
let description = [{
This yield operation represents a return operation within the assert_and_exec
@@ -441,6 +446,11 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", [Terminator]> {
}];
let arguments = (ins Variadic<AnyType>:$operands);
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &result",
+ [{ /* nothing to do */ }]>
+ ];
}
def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
index 11a7b255..2af3de8 100644
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRShape
MLIRShapeOpsIncGen
LINK_LIBS PUBLIC
+ MLIRControlFlowInterfaces
MLIRDialect
MLIRInferTypeOpInterface
MLIRIR
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index e1d1b33..5c6a0c2 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -108,6 +108,50 @@ AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
}
//===----------------------------------------------------------------------===//
+// AssumingOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAssumingOp(OpAsmParser &parser,
+ OperationState &result) {
+ result.regions.reserve(1);
+ Region *doRegion = result.addRegion();
+
+ auto &builder = parser.getBuilder();
+ OpAsmParser::OperandType cond;
+ if (parser.parseOperand(cond) ||
+ parser.resolveOperand(cond, builder.getType<WitnessType>(),
+ result.operands))
+ return failure();
+
+ // Parse optional results type list.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+
+ // Parse the region and add a terminator if elided.
+ if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
+ return failure();
+ AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+ return success();
+}
+
+static void print(OpAsmPrinter &p, AssumingOp op) {
+ bool yieldsResults = !op.results().empty();
+
+ p << AssumingOp::getOperationName() << " " << op.witness();
+ if (yieldsResults) {
+ p << " -> (" << op.getResultTypes() << ")";
+ }
+ p.printRegion(op.doRegion(),
+ /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/yieldsResults);
+ p.printOptionalAttrDict(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index dc07e66..7934f4ee0d 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -74,9 +74,9 @@ func @test_constraints() {
%w0 = "shape.cstr_broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
%w1 = "shape.cstr_eq"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
%w3 = "shape.assuming_all"(%w0, %w1) : (!shape.witness, !shape.witness) -> !shape.witness
- "shape.assuming"(%w3) ( {
+ shape.assuming %w3 -> !shape.shape {
%2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
"shape.assuming_yield"(%2) : (!shape.shape) -> ()
- }) : (!shape.witness) -> !shape.shape
+ }
return
}