aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAviad Cohen <aviadcohen7@gmail.com>2025-06-04 10:01:20 +0300
committerGitHub <noreply@github.com>2025-06-04 10:01:20 +0300
commit4c6449044a943441adf160cb77010c855dcee73c (patch)
tree148888d4c80083913b64268cdb701aef0daeb4ee
parent3894bdc3c94eae51e7587cb03f456a71fd03d0e1 (diff)
downloadllvm-4c6449044a943441adf160cb77010c855dcee73c.zip
llvm-4c6449044a943441adf160cb77010c855dcee73c.tar.gz
llvm-4c6449044a943441adf160cb77010c855dcee73c.tar.bz2
[mlir]: Added properties/attributes ignore flags to OperationEquivalence (#142623)
Those flags are useful for cases and operation which we may consider equivalent even when their attributes/properties are not the same.
-rw-r--r--mlir/include/mlir/IR/OperationSupport.h9
-rw-r--r--mlir/lib/IR/OperationSupport.cpp21
-rw-r--r--mlir/unittests/IR/OperationSupportTest.cpp25
3 files changed, 48 insertions, 7 deletions
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 0046d97..65e6d4f 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1322,7 +1322,14 @@ struct OperationEquivalence {
// When provided, the location attached to the operation are ignored.
IgnoreLocations = 1,
- LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreLocations)
+ // When provided, the discardable attributes attached to the operation are
+ // ignored.
+ IgnoreDiscardableAttrs = 2,
+
+ // When provided, the properties attached to the operation are ignored.
+ IgnoreProperties = 4,
+
+ LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreProperties)
};
/// Compute a hash for the given operation.
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 7c9e6c8..d3e9aeb 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -680,9 +680,13 @@ llvm::hash_code OperationEquivalence::computeHash(
// - Operation Name
// - Attributes
// - Result Types
+ DictionaryAttr dictAttrs;
+ if (!(flags & Flags::IgnoreDiscardableAttrs))
+ dictAttrs = op->getRawDictionaryAttrs();
llvm::hash_code hash =
- llvm::hash_combine(op->getName(), op->getRawDictionaryAttrs(),
- op->getResultTypes(), op->hashProperties());
+ llvm::hash_combine(op->getName(), dictAttrs, op->getResultTypes());
+ if (!(flags & Flags::IgnoreProperties))
+ hash = llvm::hash_combine(hash, op->hashProperties());
// - Location if required
if (!(flags & Flags::IgnoreLocations))
@@ -836,14 +840,19 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
return true;
// 1. Compare the operation properties.
+ if (!(flags & IgnoreDiscardableAttrs) &&
+ lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs())
+ return false;
+
if (lhs->getName() != rhs->getName() ||
- lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs() ||
lhs->getNumRegions() != rhs->getNumRegions() ||
lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
lhs->getNumOperands() != rhs->getNumOperands() ||
- lhs->getNumResults() != rhs->getNumResults() ||
- !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
- rhs->getPropertiesStorage()))
+ lhs->getNumResults() != rhs->getNumResults())
+ return false;
+ if (!(flags & IgnoreProperties) &&
+ !(lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
+ rhs->getPropertiesStorage())))
return false;
if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
return false;
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index bac2b72..18ee9d7 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -315,6 +315,7 @@ TEST(OperandStorageTest, PopulateDefaultAttrs) {
TEST(OperationEquivalenceTest, HashWorksWithFlags) {
MLIRContext context;
context.getOrLoadDialect<test::TestDialect>();
+ OpBuilder b(&context);
auto *op1 = createOp(&context);
// `op1` has an unknown loc.
@@ -325,12 +326,36 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) {
op, OperationEquivalence::ignoreHashValue,
OperationEquivalence::ignoreHashValue, flags);
};
+ // Check ignore location.
EXPECT_EQ(getHash(op1, OperationEquivalence::IgnoreLocations),
getHash(op2, OperationEquivalence::IgnoreLocations));
EXPECT_NE(getHash(op1, OperationEquivalence::None),
getHash(op2, OperationEquivalence::None));
+ op1->setLoc(NameLoc::get(StringAttr::get(&context, "foo")));
+ // Check ignore discardable dictionary attributes.
+ SmallVector<NamedAttribute> newAttrs = {
+ b.getNamedAttr("foo", b.getStringAttr("f"))};
+ op1->setAttrs(newAttrs);
+ EXPECT_EQ(getHash(op1, OperationEquivalence::IgnoreDiscardableAttrs),
+ getHash(op2, OperationEquivalence::IgnoreDiscardableAttrs));
+ EXPECT_NE(getHash(op1, OperationEquivalence::None),
+ getHash(op2, OperationEquivalence::None));
op1->destroy();
op2->destroy();
+
+ // Check ignore properties.
+ auto req1 = b.getI32IntegerAttr(10);
+ Operation *opWithProperty1 = b.create<test::OpAttrMatch1>(
+ b.getUnknownLoc(), req1, nullptr, nullptr, req1);
+ auto req2 = b.getI32IntegerAttr(60);
+ Operation *opWithProperty2 = b.create<test::OpAttrMatch1>(
+ b.getUnknownLoc(), req2, nullptr, nullptr, req2);
+ EXPECT_EQ(getHash(opWithProperty1, OperationEquivalence::IgnoreProperties),
+ getHash(opWithProperty2, OperationEquivalence::IgnoreProperties));
+ EXPECT_NE(getHash(opWithProperty1, OperationEquivalence::None),
+ getHash(opWithProperty2, OperationEquivalence::None));
+ opWithProperty1->destroy();
+ opWithProperty2->destroy();
}
} // namespace