From 1612993788c3d8049f1e43a6f48ea3264f8a434f Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 7 Dec 2023 11:22:53 +0900 Subject: [mlir][complex] Allow integer element types in `complex.constant` ops (#74564) The op used to support only float element types. This was inconsistent with `ConstantOp::isBuildableWith`, which allows integer element types. The complex type allows any float/integer element type. Note: The other complex dialect ops do not support non-float element types yet. The main purpose of this change to fix `Tensor/canonicalize.mlir`, which is currently failing when verifying the IR after each pattern application (#74270). ``` within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: error: 'complex.constant' op result #0 must be complex type with floating-point elements, but got 'complex' %complex1 = tensor.extract %c1[] : tensor> ^ within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: note: see current operation: %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex "func.func"() <{function_type = () -> tensor<3xcomplex>, sym_name = "extract_from_elements_complex_i"}> ({ %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex %1 = "arith.constant"() <{value = dense<(3,2)> : tensor>}> : () -> tensor> %2 = "arith.constant"() <{value = dense<(1,2)> : tensor>}> : () -> tensor> %3 = "tensor.extract"(%1) : (tensor>) -> complex %4 = "tensor.from_elements"(%0, %3, %0) : (complex, complex, complex) -> tensor<3xcomplex> "func.return"(%4) : (tensor<3xcomplex>) -> () }) : () -> () ``` --- mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td | 2 +- mlir/lib/Dialect/Complex/IR/ComplexOps.cpp | 10 ++++++---- mlir/test/Dialect/Complex/ops.mlir | 3 +++ 3 files changed, 10 insertions(+), 5 deletions(-) (limited to 'mlir') diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index ada6c14..e19d714 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -145,7 +145,7 @@ def ConstantOp : Complex_Op<"constant", [ }]; let arguments = (ins ArrayAttr:$value); - let results = (outs Complex:$complex); + let results = (outs AnyComplex:$complex); let assemblyFormat = "$value attr-dict `:` type($complex)"; let hasFolder = 1; diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index 6d87067..5529dcc 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -58,10 +58,12 @@ LogicalResult ConstantOp::verify() { } auto complexEltTy = getType().getElementType(); - auto re = llvm::dyn_cast(arrayAttr[0]); - auto im = llvm::dyn_cast(arrayAttr[1]); - if (!re || !im) - return emitOpError("requires attribute's elements to be float attributes"); + if (!isa(arrayAttr[0]) || + !isa(arrayAttr[1])) + return emitOpError( + "requires attribute's elements to be float or integer attributes"); + auto re = llvm::dyn_cast(arrayAttr[0]); + auto im = llvm::dyn_cast(arrayAttr[1]); if (complexEltTy != re.getType() || complexEltTy != im.getType()) { return emitOpError() << "requires attribute's element types (" << re.getType() << ", " diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir index 1050ad0..96f17b2 100644 --- a/mlir/test/Dialect/Complex/ops.mlir +++ b/mlir/test/Dialect/Complex/ops.mlir @@ -11,6 +11,9 @@ func.func @ops(%f: f32) { // CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex %cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex + // CHECK: complex.constant [true, false] : complex + %cst_i1 = complex.constant [1 : i1, 0 : i1] : complex + // CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex %complex = complex.create %f, %f : complex -- cgit v1.1