aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2023-12-07 11:22:53 +0900
committerGitHub <noreply@github.com>2023-12-07 03:22:53 +0100
commit1612993788c3d8049f1e43a6f48ea3264f8a434f (patch)
treefb280bde3afabd778d93c2b0156f4fd04f532ec5 /mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
parentb683709ea6eec7d0a388bd50c571774c9b9ffdb7 (diff)
downloadllvm-1612993788c3d8049f1e43a6f48ea3264f8a434f.zip
llvm-1612993788c3d8049f1e43a6f48ea3264f8a434f.tar.gz
llvm-1612993788c3d8049f1e43a6f48ea3264f8a434f.tar.bz2
[mlir][complex] Allow integer element types in `complex.constant` ops (#74564)
The op used to support only float element types. This was inconsistent with `ConstantOp::isBuildableWith`, which allows integer element types. The complex type allows any float/integer element type. Note: The other complex dialect ops do not support non-float element types yet. The main purpose of this change to fix `Tensor/canonicalize.mlir`, which is currently failing when verifying the IR after each pattern application (#74270). ``` within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: error: 'complex.constant' op result #0 must be complex type with floating-point elements, but got 'complex<i32>' %complex1 = tensor.extract %c1[] : tensor<complex<i32>> ^ within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: note: see current operation: %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32> "func.func"() <{function_type = () -> tensor<3xcomplex<i32>>, sym_name = "extract_from_elements_complex_i"}> ({ %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32> %1 = "arith.constant"() <{value = dense<(3,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>> %2 = "arith.constant"() <{value = dense<(1,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>> %3 = "tensor.extract"(%1) : (tensor<complex<i32>>) -> complex<i32> %4 = "tensor.from_elements"(%0, %3, %0) : (complex<i32>, complex<i32>, complex<i32>) -> tensor<3xcomplex<i32>> "func.return"(%4) : (tensor<3xcomplex<i32>>) -> () }) : () -> () ```
Diffstat (limited to 'mlir/lib/Dialect/Complex/IR/ComplexOps.cpp')
-rw-r--r--mlir/lib/Dialect/Complex/IR/ComplexOps.cpp10
1 files changed, 6 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 6d87067..5529dcc 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -58,10 +58,12 @@ LogicalResult ConstantOp::verify() {
}
auto complexEltTy = getType().getElementType();
- auto re = llvm::dyn_cast<FloatAttr>(arrayAttr[0]);
- auto im = llvm::dyn_cast<FloatAttr>(arrayAttr[1]);
- if (!re || !im)
- return emitOpError("requires attribute's elements to be float attributes");
+ if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
+ !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
+ return emitOpError(
+ "requires attribute's elements to be float or integer attributes");
+ auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
+ auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
return emitOpError()
<< "requires attribute's element types (" << re.getType() << ", "