diff options
| author | Matthias Springer <me@m-sp.org> | 2025-12-09 08:26:41 +0000 |
|---|---|---|
| committer | Matthias Springer <me@m-sp.org> | 2025-12-09 08:27:38 +0000 |
| commit | bfd97aa344dfd781f4548608c281a074ea61d2fc (patch) | |
| tree | 3209490093289a50690f66ba25a27a05faf9e033 | |
| parent | 13012fe20816c7ce7a4eec62b325a131b6570593 (diff) | |
| download | llvm-users/matthias-springer/float_zero_static_utils.zip llvm-users/matthias-springer/float_zero_static_utils.tar.gz llvm-users/matthias-springer/float_zero_static_utils.tar.bz2 | |
[mlir][DialectUtils] Add helper for matching zero int/float valuesusers/matthias-springer/float_zero_static_utils
4 files changed, 27 insertions, 11 deletions
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index 2e7f85c..ba8a030 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -24,9 +24,16 @@ namespace mlir { -/// Return true if `v` is an IntegerAttr with value `0`. +/// Return "true" if `v` is an integer value/attribute with constant value `0`. bool isZeroInteger(OpFoldResult v); +/// Return "true" if `v` is a float value/attribute with constant value `0.0`. +bool isZeroFloat(OpFoldResult v); + +/// Return "true" if `v` is an integer/float value/attribute with constant +/// value zero. +bool isZeroIntegerOrFloat(OpFoldResult v); + /// Return true if `v` is an IntegerAttr with value `1`. bool isOneInteger(OpFoldResult v); diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp index e940b07..6f81702 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp @@ -21,7 +21,7 @@ static bool isDefinedAsZero(Value val) { // Check whether val is a constant scalar / vector splat / tensor splat float // or integer zero. - if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero())) + if (isZeroIntegerOrFloat(val)) return true; return TypeSwitch<Operation *, bool>(val.getDefiningOp()) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 79f4e7f..24290bd 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -39,11 +39,6 @@ using namespace mlir::sparse_tensor; // Helper methods for the actual rewriting rules. //===---------------------------------------------------------------------===// -// Helper method to match any typed zero. -static bool isZeroValue(Value val) { - return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()); -} - // Helper to detect a sparse tensor type operand. static bool isSparseTensor(Value v) { auto enc = getSparseTensorEncoding(v.getType()); @@ -59,14 +54,14 @@ static bool isMaterializing(OpOperand *op, bool isZero) { if (auto alloc = val.getDefiningOp<AllocTensorOp>()) { Value copy = alloc.getCopy(); if (isZero) - return copy && isZeroValue(copy); + return copy && isZeroIntegerOrFloat(copy); return !copy; } // Check for empty tensor materialization. if (auto empty = val.getDefiningOp<tensor::EmptyOp>()) return !isZero; // Last resort for zero alloc: the whole value is zero. - return isZero && isZeroValue(val); + return isZero && isZeroIntegerOrFloat(val); } // Helper to detect sampling operation. @@ -114,10 +109,10 @@ static bool isZeroYield(GenericOp op) { auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) { if (arg.getOwner()->getParentOp() == op) { - return isZeroValue(op->getOperand(arg.getArgNumber())); + return isZeroIntegerOrFloat(op->getOperand(arg.getArgNumber())); } } - return isZeroValue(yieldOp.getOperand(0)); + return isZeroIntegerOrFloat(yieldOp.getOperand(0)); } /// Populates given sizes array from type (for static sizes) and from diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 8d3944f8..089c551c 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -19,6 +19,20 @@ namespace mlir { bool isZeroInteger(OpFoldResult v) { return isConstantIntValue(v, 0); } +bool isZeroFloat(OpFoldResult v) { + if (auto attr = dyn_cast<Attribute>(v)) { + if (auto floatAttr = dyn_cast<FloatAttr>(attr)) + return floatAttr.getValue().isZero(); + } else { + return matchPattern(cast<Value>(v), m_AnyZeroFloat()); + } + return false; +} + +bool isZeroIntegerOrFloat(OpFoldResult v) { + return isZeroInteger(v) || isZeroFloat(v); +} + bool isOneInteger(OpFoldResult v) { return isConstantIntValue(v, 1); } std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, |
