aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td2
-rw-r--r--mlir/lib/Dialect/Complex/IR/ComplexOps.cpp10
-rw-r--r--mlir/test/Dialect/Complex/ops.mlir3
3 files changed, 10 insertions, 5 deletions
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index ada6c14..e19d714 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -145,7 +145,7 @@ def ConstantOp : Complex_Op<"constant", [
}];
let arguments = (ins ArrayAttr:$value);
- let results = (outs Complex<AnyFloat>:$complex);
+ let results = (outs AnyComplex:$complex);
let assemblyFormat = "$value attr-dict `:` type($complex)";
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 6d87067..5529dcc 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -58,10 +58,12 @@ LogicalResult ConstantOp::verify() {
}
auto complexEltTy = getType().getElementType();
- auto re = llvm::dyn_cast<FloatAttr>(arrayAttr[0]);
- auto im = llvm::dyn_cast<FloatAttr>(arrayAttr[1]);
- if (!re || !im)
- return emitOpError("requires attribute's elements to be float attributes");
+ if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
+ !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
+ return emitOpError(
+ "requires attribute's elements to be float or integer attributes");
+ auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
+ auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
return emitOpError()
<< "requires attribute's element types (" << re.getType() << ", "
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index 1050ad0..96f17b2 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -11,6 +11,9 @@ func.func @ops(%f: f32) {
// CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex<f32>
%cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex<f32>
+ // CHECK: complex.constant [true, false] : complex<i1>
+ %cst_i1 = complex.constant [1 : i1, 0 : i1] : complex<i1>
+
// CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
%complex = complex.create %f, %f : complex<f32>