diff options
-rw-r--r-- | mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td | 2 | ||||
-rw-r--r-- | mlir/lib/Dialect/Complex/IR/ComplexOps.cpp | 10 | ||||
-rw-r--r-- | mlir/test/Dialect/Complex/ops.mlir | 3 |
3 files changed, 10 insertions, 5 deletions
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index ada6c14..e19d714 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -145,7 +145,7 @@ def ConstantOp : Complex_Op<"constant", [ }]; let arguments = (ins ArrayAttr:$value); - let results = (outs Complex<AnyFloat>:$complex); + let results = (outs AnyComplex:$complex); let assemblyFormat = "$value attr-dict `:` type($complex)"; let hasFolder = 1; diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index 6d87067..5529dcc 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -58,10 +58,12 @@ LogicalResult ConstantOp::verify() { } auto complexEltTy = getType().getElementType(); - auto re = llvm::dyn_cast<FloatAttr>(arrayAttr[0]); - auto im = llvm::dyn_cast<FloatAttr>(arrayAttr[1]); - if (!re || !im) - return emitOpError("requires attribute's elements to be float attributes"); + if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) || + !isa<FloatAttr, IntegerAttr>(arrayAttr[1])) + return emitOpError( + "requires attribute's elements to be float or integer attributes"); + auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]); + auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]); if (complexEltTy != re.getType() || complexEltTy != im.getType()) { return emitOpError() << "requires attribute's element types (" << re.getType() << ", " diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir index 1050ad0..96f17b2 100644 --- a/mlir/test/Dialect/Complex/ops.mlir +++ b/mlir/test/Dialect/Complex/ops.mlir @@ -11,6 +11,9 @@ func.func @ops(%f: f32) { // CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex<f32> %cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex<f32> + // CHECK: complex.constant [true, false] : complex<i1> + %cst_i1 = complex.constant [1 : i1, 0 : i1] : complex<i1> + // CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32> %complex = complex.create %f, %f : complex<f32> |