aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Complex/IR/ComplexOps.cpp')
-rw-r--r--mlir/lib/Dialect/Complex/IR/ComplexOps.cpp10
1 files changed, 6 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 6d87067..5529dcc 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -58,10 +58,12 @@ LogicalResult ConstantOp::verify() {
}
auto complexEltTy = getType().getElementType();
- auto re = llvm::dyn_cast<FloatAttr>(arrayAttr[0]);
- auto im = llvm::dyn_cast<FloatAttr>(arrayAttr[1]);
- if (!re || !im)
- return emitOpError("requires attribute's elements to be float attributes");
+ if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
+ !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
+ return emitOpError(
+ "requires attribute's elements to be float or integer attributes");
+ auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
+ auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
return emitOpError()
<< "requires attribute's element types (" << re.getType() << ", "