diff options
author | Aviad Cohen <aviadcohen7@gmail.com> | 2025-06-04 10:01:20 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-04 10:01:20 +0300 |
commit | 4c6449044a943441adf160cb77010c855dcee73c (patch) | |
tree | 148888d4c80083913b64268cdb701aef0daeb4ee | |
parent | 3894bdc3c94eae51e7587cb03f456a71fd03d0e1 (diff) | |
download | llvm-4c6449044a943441adf160cb77010c855dcee73c.zip llvm-4c6449044a943441adf160cb77010c855dcee73c.tar.gz llvm-4c6449044a943441adf160cb77010c855dcee73c.tar.bz2 |
[mlir]: Added properties/attributes ignore flags to OperationEquivalence (#142623)
Those flags are useful for cases and operation which we may consider equivalent even when their attributes/properties are not the same.
-rw-r--r-- | mlir/include/mlir/IR/OperationSupport.h | 9 | ||||
-rw-r--r-- | mlir/lib/IR/OperationSupport.cpp | 21 | ||||
-rw-r--r-- | mlir/unittests/IR/OperationSupportTest.cpp | 25 |
3 files changed, 48 insertions, 7 deletions
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 0046d97..65e6d4f 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -1322,7 +1322,14 @@ struct OperationEquivalence { // When provided, the location attached to the operation are ignored. IgnoreLocations = 1, - LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreLocations) + // When provided, the discardable attributes attached to the operation are + // ignored. + IgnoreDiscardableAttrs = 2, + + // When provided, the properties attached to the operation are ignored. + IgnoreProperties = 4, + + LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreProperties) }; /// Compute a hash for the given operation. diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 7c9e6c8..d3e9aeb 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -680,9 +680,13 @@ llvm::hash_code OperationEquivalence::computeHash( // - Operation Name // - Attributes // - Result Types + DictionaryAttr dictAttrs; + if (!(flags & Flags::IgnoreDiscardableAttrs)) + dictAttrs = op->getRawDictionaryAttrs(); llvm::hash_code hash = - llvm::hash_combine(op->getName(), op->getRawDictionaryAttrs(), - op->getResultTypes(), op->hashProperties()); + llvm::hash_combine(op->getName(), dictAttrs, op->getResultTypes()); + if (!(flags & Flags::IgnoreProperties)) + hash = llvm::hash_combine(hash, op->hashProperties()); // - Location if required if (!(flags & Flags::IgnoreLocations)) @@ -836,14 +840,19 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs, return true; // 1. Compare the operation properties. + if (!(flags & IgnoreDiscardableAttrs) && + lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs()) + return false; + if (lhs->getName() != rhs->getName() || - lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs() || lhs->getNumRegions() != rhs->getNumRegions() || lhs->getNumSuccessors() != rhs->getNumSuccessors() || lhs->getNumOperands() != rhs->getNumOperands() || - lhs->getNumResults() != rhs->getNumResults() || - !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(), - rhs->getPropertiesStorage())) + lhs->getNumResults() != rhs->getNumResults()) + return false; + if (!(flags & IgnoreProperties) && + !(lhs->getName().compareOpProperties(lhs->getPropertiesStorage(), + rhs->getPropertiesStorage()))) return false; if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc()) return false; diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp index bac2b72..18ee9d7 100644 --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -315,6 +315,7 @@ TEST(OperandStorageTest, PopulateDefaultAttrs) { TEST(OperationEquivalenceTest, HashWorksWithFlags) { MLIRContext context; context.getOrLoadDialect<test::TestDialect>(); + OpBuilder b(&context); auto *op1 = createOp(&context); // `op1` has an unknown loc. @@ -325,12 +326,36 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) { op, OperationEquivalence::ignoreHashValue, OperationEquivalence::ignoreHashValue, flags); }; + // Check ignore location. EXPECT_EQ(getHash(op1, OperationEquivalence::IgnoreLocations), getHash(op2, OperationEquivalence::IgnoreLocations)); EXPECT_NE(getHash(op1, OperationEquivalence::None), getHash(op2, OperationEquivalence::None)); + op1->setLoc(NameLoc::get(StringAttr::get(&context, "foo"))); + // Check ignore discardable dictionary attributes. + SmallVector<NamedAttribute> newAttrs = { + b.getNamedAttr("foo", b.getStringAttr("f"))}; + op1->setAttrs(newAttrs); + EXPECT_EQ(getHash(op1, OperationEquivalence::IgnoreDiscardableAttrs), + getHash(op2, OperationEquivalence::IgnoreDiscardableAttrs)); + EXPECT_NE(getHash(op1, OperationEquivalence::None), + getHash(op2, OperationEquivalence::None)); op1->destroy(); op2->destroy(); + + // Check ignore properties. + auto req1 = b.getI32IntegerAttr(10); + Operation *opWithProperty1 = b.create<test::OpAttrMatch1>( + b.getUnknownLoc(), req1, nullptr, nullptr, req1); + auto req2 = b.getI32IntegerAttr(60); + Operation *opWithProperty2 = b.create<test::OpAttrMatch1>( + b.getUnknownLoc(), req2, nullptr, nullptr, req2); + EXPECT_EQ(getHash(opWithProperty1, OperationEquivalence::IgnoreProperties), + getHash(opWithProperty2, OperationEquivalence::IgnoreProperties)); + EXPECT_NE(getHash(opWithProperty1, OperationEquivalence::None), + getHash(opWithProperty2, OperationEquivalence::None)); + opWithProperty1->destroy(); + opWithProperty2->destroy(); } } // namespace |