aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2025-12-09 08:26:41 +0000
committerMatthias Springer <me@m-sp.org>2025-12-09 08:27:38 +0000
commitbfd97aa344dfd781f4548608c281a074ea61d2fc (patch)
tree3209490093289a50690f66ba25a27a05faf9e033
parent13012fe20816c7ce7a4eec62b325a131b6570593 (diff)
downloadllvm-users/matthias-springer/float_zero_static_utils.zip
llvm-users/matthias-springer/float_zero_static_utils.tar.gz
llvm-users/matthias-springer/float_zero_static_utils.tar.bz2
[mlir][DialectUtils] Add helper for matching zero int/float valuesusers/matthias-springer/float_zero_static_utils
-rw-r--r--mlir/include/mlir/Dialect/Utils/StaticValueUtils.h9
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp2
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp13
-rw-r--r--mlir/lib/Dialect/Utils/StaticValueUtils.cpp14
4 files changed, 27 insertions, 11 deletions
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 2e7f85c..ba8a030 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -24,9 +24,16 @@
namespace mlir {
-/// Return true if `v` is an IntegerAttr with value `0`.
+/// Return "true" if `v` is an integer value/attribute with constant value `0`.
bool isZeroInteger(OpFoldResult v);
+/// Return "true" if `v` is a float value/attribute with constant value `0.0`.
+bool isZeroFloat(OpFoldResult v);
+
+/// Return "true" if `v` is an integer/float value/attribute with constant
+/// value zero.
+bool isZeroIntegerOrFloat(OpFoldResult v);
+
/// Return true if `v` is an IntegerAttr with value `1`.
bool isOneInteger(OpFoldResult v);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
index e940b07..6f81702 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
@@ -21,7 +21,7 @@ static bool isDefinedAsZero(Value val) {
// Check whether val is a constant scalar / vector splat / tensor splat float
// or integer zero.
- if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()))
+ if (isZeroIntegerOrFloat(val))
return true;
return TypeSwitch<Operation *, bool>(val.getDefiningOp())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 79f4e7f..24290bd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -39,11 +39,6 @@ using namespace mlir::sparse_tensor;
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//
-// Helper method to match any typed zero.
-static bool isZeroValue(Value val) {
- return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
-}
-
// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(Value v) {
auto enc = getSparseTensorEncoding(v.getType());
@@ -59,14 +54,14 @@ static bool isMaterializing(OpOperand *op, bool isZero) {
if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
Value copy = alloc.getCopy();
if (isZero)
- return copy && isZeroValue(copy);
+ return copy && isZeroIntegerOrFloat(copy);
return !copy;
}
// Check for empty tensor materialization.
if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
return !isZero;
// Last resort for zero alloc: the whole value is zero.
- return isZero && isZeroValue(val);
+ return isZero && isZeroIntegerOrFloat(val);
}
// Helper to detect sampling operation.
@@ -114,10 +109,10 @@ static bool isZeroYield(GenericOp op) {
auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
if (arg.getOwner()->getParentOp() == op) {
- return isZeroValue(op->getOperand(arg.getArgNumber()));
+ return isZeroIntegerOrFloat(op->getOperand(arg.getArgNumber()));
}
}
- return isZeroValue(yieldOp.getOperand(0));
+ return isZeroIntegerOrFloat(yieldOp.getOperand(0));
}
/// Populates given sizes array from type (for static sizes) and from
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8d3944f8..089c551c 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -19,6 +19,20 @@ namespace mlir {
bool isZeroInteger(OpFoldResult v) { return isConstantIntValue(v, 0); }
+bool isZeroFloat(OpFoldResult v) {
+ if (auto attr = dyn_cast<Attribute>(v)) {
+ if (auto floatAttr = dyn_cast<FloatAttr>(attr))
+ return floatAttr.getValue().isZero();
+ } else {
+ return matchPattern(cast<Value>(v), m_AnyZeroFloat());
+ }
+ return false;
+}
+
+bool isZeroIntegerOrFloat(OpFoldResult v) {
+ return isZeroInteger(v) || isZeroFloat(v);
+}
+
bool isOneInteger(OpFoldResult v) { return isConstantIntValue(v, 1); }
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,