aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIvan Butygin <ivan.butygin@gmail.com>2023-08-28 20:51:37 +0200
committerIvan Butygin <ivan.butygin@gmail.com>2023-09-07 12:30:29 +0200
commit5dce74817b71a1f646fb2857c037b3a66f41c7cd (patch)
tree11855288aa2eeb1abe59279f8c46a258b9d1ea4b
parent0dd4d3b5cc47410c41535d540712cc2b5d99f21b (diff)
downloadllvm-5dce74817b71a1f646fb2857c037b3a66f41c7cd.zip
llvm-5dce74817b71a1f646fb2857c037b3a66f41c7cd.tar.gz
llvm-5dce74817b71a1f646fb2857c037b3a66f41c7cd.tar.bz2
[mlir][ub] Add poison support to CommonFolders.h
Return poison from foldBinary/unary if argument(s) is poison. Add ub dialect as dependency to affected dialects (arith, math, spirv, shape). Add poison materialization to dialects. Add tests for some ops from each dialect. Not all affected ops are covered as it will involve a huge copypaste. Differential Revision: https://reviews.llvm.org/D159013
-rw-r--r--mlir/include/mlir/Dialect/CommonFolders.h85
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithDialect.cpp4
-rw-r--r--mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp1
-rw-r--r--mlir/lib/Dialect/Math/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Math/IR/MathDialect.cpp1
-rw-r--r--mlir/lib/Dialect/Math/IR/MathOps.cpp4
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp1
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp4
-rw-r--r--mlir/lib/Dialect/Shape/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Shape/IR/Shape.cpp5
-rw-r--r--mlir/test/Dialect/Arith/canonicalize.mlir55
-rw-r--r--mlir/test/Dialect/Math/canonicalize.mlir9
-rw-r--r--mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir9
-rw-r--r--mlir/test/Dialect/Shape/canonicalize.mlir13
15 files changed, 179 insertions, 15 deletions
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index d3fbc72..6257e4a 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -22,17 +22,35 @@
#include <optional>
namespace mlir {
+namespace ub {
+class PoisonAttr;
+}
/// Performs constant folding `calculate` with element-wise behavior on the two
/// attributes in `operands` and returns the result if possible.
/// Uses `resultType` for the type of the returned attribute.
+/// Optional PoisonAttr template argument allows to specify 'poison' attribute
+/// which will be directly propagated to result.
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
+ class PoisonAttr = ub::PoisonAttr,
class CalculationT = function_ref<
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type resultType,
- const CalculationT &calculate) {
+ CalculationT &&calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
+ static_assert(
+ std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
+ "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
+ "void as template argument to opt-out from poison semantics.");
+ if constexpr (!std::is_void_v<PoisonAttr>) {
+ if (isa_and_nonnull<PoisonAttr>(operands[0]))
+ return operands[0];
+
+ if (isa_and_nonnull<PoisonAttr>(operands[1]))
+ return operands[1];
+ }
+
if (!resultType || !operands[0] || !operands[1])
return {};
@@ -95,13 +113,28 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
/// attributes in `operands` and returns the result if possible.
/// Uses the operand element type for the element type of the returned
/// attribute.
+/// Optional PoisonAttr template argument allows to specify 'poison' attribute
+/// which will be directly propagated to result.
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
+ class PoisonAttr = ub::PoisonAttr,
class CalculationT = function_ref<
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
- const CalculationT &calculate) {
+ CalculationT &&calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
+ static_assert(
+ std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
+ "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
+ "void as template argument to opt-out from poison semantics.");
+ if constexpr (!std::is_void_v<PoisonAttr>) {
+ if (isa_and_nonnull<PoisonAttr>(operands[0]))
+ return operands[0];
+
+ if (isa_and_nonnull<PoisonAttr>(operands[1]))
+ return operands[1];
+ }
+
auto getResultType = [](Attribute attr) -> Type {
if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
return typed.getType();
@@ -115,18 +148,19 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (lhsType != rhsType)
return {};
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT,
- CalculationT>(operands, lhsType,
- calculate);
+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
+ CalculationT>(
+ operands, lhsType, std::forward<CalculationT>(calculate));
}
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
+ class PoisonAttr = void,
class CalculationT =
function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
- const CalculationT &calculate) {
- return constFoldBinaryOpConditional<AttrElementT>(
+ CalculationT &&calculate) {
+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
operands, resultType,
[&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
return calculate(a, b);
@@ -135,11 +169,12 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
+ class PoisonAttr = ub::PoisonAttr,
class CalculationT =
function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
- const CalculationT &calculate) {
- return constFoldBinaryOpConditional<AttrElementT>(
+ CalculationT &&calculate) {
+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
operands,
[&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
return calculate(a, b);
@@ -148,16 +183,28 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
/// Performs constant folding `calculate` with element-wise behavior on the one
/// attributes in `operands` and returns the result if possible.
+/// Optional PoisonAttr template argument allows to specify 'poison' attribute
+/// which will be directly propagated to result.
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
+ class PoisonAttr = ub::PoisonAttr,
class CalculationT =
function_ref<std::optional<ElementValueT>(ElementValueT)>>
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
- const CalculationT &&calculate) {
+ CalculationT &&calculate) {
assert(operands.size() == 1 && "unary op takes one operands");
if (!operands[0])
return {};
+ static_assert(
+ std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
+ "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
+ "void as template argument to opt-out from poison semantics.");
+ if constexpr (!std::is_void_v<PoisonAttr>) {
+ if (isa<PoisonAttr>(operands[0]))
+ return operands[0];
+ }
+
if (isa<AttrElementT>(operands[0])) {
auto op = cast<AttrElementT>(operands[0]);
@@ -196,10 +243,11 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
+ class PoisonAttr = ub::PoisonAttr,
class CalculationT = function_ref<ElementValueT(ElementValueT)>>
Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
- const CalculationT &&calculate) {
- return constFoldUnaryOpConditional<AttrElementT>(
+ CalculationT &&calculate) {
+ return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
return calculate(a);
});
@@ -209,13 +257,23 @@ template <
class AttrElementT, class TargetAttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class TargetElementValueT = typename TargetAttrElementT::ValueType,
+ class PoisonAttr = ub::PoisonAttr,
class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
- const CalculationT &calculate) {
+ CalculationT &&calculate) {
assert(operands.size() == 1 && "Cast op takes one operand");
if (!operands[0])
return {};
+ static_assert(
+ std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
+ "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
+ "void as template argument to opt-out from poison semantics.");
+ if constexpr (!std::is_void_v<PoisonAttr>) {
+ if (isa<PoisonAttr>(operands[0]))
+ return operands[0];
+ }
+
if (isa<AttrElementT>(operands[0])) {
auto op = cast<AttrElementT>(operands[0]);
bool castStatus = true;
@@ -254,7 +312,6 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
}
return {};
}
-
} // namespace mlir
#endif // MLIR_DIALECT_COMMONFOLDERS_H
diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index 7f2d793..ed4b91c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -49,5 +50,8 @@ void arith::ArithDialect::initialize() {
Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
+ if (auto poison = dyn_cast<ub::PoisonAttr>(value))
+ return builder.create<ub::PoisonOp>(loc, type, poison);
+
return ConstantOp::materialize(builder, value, type, loc);
}
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 0a86d8f..fab6f34 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -9,7 +9,6 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
diff --git a/mlir/lib/Dialect/Math/IR/CMakeLists.txt b/mlir/lib/Dialect/Math/IR/CMakeLists.txt
index 3b7b65e..ed95bf8 100644
--- a/mlir/lib/Dialect/Math/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/IR/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRMathDialect
MLIRArithDialect
MLIRDialect
MLIRIR
+ MLIRUBDialect
)
diff --git a/mlir/lib/Dialect/Math/IR/MathDialect.cpp b/mlir/lib/Dialect/Math/IR/MathDialect.cpp
index 54a8cc1..9cf47ac 100644
--- a/mlir/lib/Dialect/Math/IR/MathDialect.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathDialect.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index ae9dc08..28d1c06 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include <optional>
@@ -522,5 +523,8 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
+ if (auto poison = dyn_cast<ub::PoisonAttr>(value))
+ return builder.create<ub::PoisonOp>(loc, type, poison);
+
return arith::ConstantOp::materialize(builder, value, type, loc);
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 0189e79..2b5ceda 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -43,4 +43,5 @@ add_mlir_dialect_library(MLIRSPIRVDialect
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransforms
+ MLIRUBDialect
)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index def62b4..9acd982 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 76e7039..a51d77d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
@@ -949,6 +950,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
+ if (auto poison = dyn_cast<ub::PoisonAttr>(value))
+ return builder.create<ub::PoisonOp>(loc, type, poison);
+
if (!spirv::ConstantOp::isBuildableWith(type))
return nullptr;
diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
index ba41f1a..32a86b4 100644
--- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
@@ -23,4 +23,5 @@ add_mlir_dialect_library(MLIRShapeDialect
MLIRIR
MLIRSideEffectInterfaces
MLIRTensorDialect
+ MLIRUBDialect
)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index e4efa093..2444556 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
@@ -147,6 +148,9 @@ void ShapeDialect::initialize() {
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
+ if (auto poison = dyn_cast<ub::PoisonAttr>(value))
+ return builder.create<ub::PoisonOp>(loc, type, poison);
+
if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
return builder.create<ConstShapeOp>(
loc, type, llvm::cast<DenseIntElementsAttr>(value));
@@ -156,6 +160,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
if (llvm::isa<WitnessType>(type))
return builder.create<ConstWitnessOp>(loc, type,
llvm::cast<BoolAttr>(value));
+
return arith::ConstantOp::materialize(builder, value, type, loc);
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 0c8e097..347b634 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2584,3 +2584,58 @@ func.func @selectOfPoison(%cond : i1, %arg: i32) -> (i32, i32, i32, i32) {
%select4 = arith.select %false, %poison, %arg : i32
return %select1, %select2, %select3, %select4 : i32, i32, i32, i32
}
+
+// CHECK-LABEL: @addi_poison1
+// CHECK: %[[P:.*]] = ub.poison : i32
+// CHECK: return %[[P]]
+func.func @addi_poison1(%arg: i32) -> i32 {
+ %0 = ub.poison : i32
+ %1 = arith.addi %0, %arg : i32
+ return %1 : i32
+}
+
+// CHECK-LABEL: @addi_poison2
+// CHECK: %[[P:.*]] = ub.poison : i32
+// CHECK: return %[[P]]
+func.func @addi_poison2(%arg: i32) -> i32 {
+ %0 = ub.poison : i32
+ %1 = arith.addi %arg, %0 : i32
+ return %1 : i32
+}
+
+// CHECK-LABEL: @addf_poison1
+// CHECK: %[[P:.*]] = ub.poison : f32
+// CHECK: return %[[P]]
+func.func @addf_poison1(%arg: f32) -> f32 {
+ %0 = ub.poison : f32
+ %1 = arith.addf %0, %arg : f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @addf_poison2
+// CHECK: %[[P:.*]] = ub.poison : f32
+// CHECK: return %[[P]]
+func.func @addf_poison2(%arg: f32) -> f32 {
+ %0 = ub.poison : f32
+ %1 = arith.addf %arg, %0 : f32
+ return %1 : f32
+}
+
+
+// CHECK-LABEL: @negf_poison
+// CHECK: %[[P:.*]] = ub.poison : f32
+// CHECK: return %[[P]]
+func.func @negf_poison() -> f32 {
+ %0 = ub.poison : f32
+ %1 = arith.negf %0 : f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @extsi_poison
+// CHECK: %[[P:.*]] = ub.poison : i64
+// CHECK: return %[[P]]
+func.func @extsi_poison() -> i64 {
+ %0 = ub.poison : i32
+ %1 = arith.extsi %0 : i32 to i64
+ return %1 : i64
+}
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 7a5194b..d24f764 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -483,3 +483,12 @@ func.func @erf_fold_vec() -> (vector<4xf32>) {
%0 = math.erf %v1 : vector<4xf32>
return %0 : vector<4xf32>
}
+
+// CHECK-LABEL: @abs_poison
+// CHECK: %[[P:.*]] = ub.poison : f32
+// CHECK: return %[[P]]
+func.func @abs_poison() -> f32 {
+ %0 = ub.poison : f32
+ %1 = math.absf %0 : f32
+ return %1 : f32
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 52607d7..0200805 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -325,6 +325,15 @@ func.func @const_fold_vector_iadd() -> vector<3xi32> {
return %0: vector<3xi32>
}
+// CHECK-LABEL: @iadd_poison
+// CHECK: %[[P:.*]] = ub.poison : i32
+// CHECK: return %[[P]]
+func.func @iadd_poison(%arg0: i32) -> i32 {
+ %0 = ub.poison : i32
+ %1 = spirv.IAdd %arg0, %0 : i32
+ return %1: i32
+}
+
// -----
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index aec5f32..8edbae3 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1479,3 +1479,16 @@ func.func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
// CHECK: return %[[DIM]]
return %result : index
}
+
+
+// -----
+
+// CHECK-LABEL: @add_poison
+// CHECK: %[[P:.*]] = ub.poison : !shape.siz
+// CHECK: return %[[P]]
+func.func @add_poison() -> !shape.size {
+ %1 = shape.const_size 2
+ %2 = ub.poison : !shape.size
+ %result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size
+ return %result : !shape.size
+}