aboutsummaryrefslogtreecommitdiff
path: root/mlir/tools/mlir-tblgen/RewriterGen.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/tools/mlir-tblgen/RewriterGen.cpp')
-rw-r--r--mlir/tools/mlir-tblgen/RewriterGen.cpp49
1 files changed, 26 insertions, 23 deletions
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 605033d..40bc1a9 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1024,6 +1024,32 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
int depth = 0;
emitMatch(tree, opName, depth);
+ // Some of the operands could be bound to the same symbol name, we need
+ // to enforce equality constraint on those.
+ // This has to happen before user provided constraints, which may assume the
+ // same name checks are already performed, since in the pattern source code
+ // the user provided constraints appear later.
+ // TODO: we should be able to emit equality checks early
+ // and short circuit unnecessary work if vars are not equal.
+ for (auto symbolInfoIt = symbolInfoMap.begin();
+ symbolInfoIt != symbolInfoMap.end();) {
+ auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
+ auto startRange = range.first;
+ auto endRange = range.second;
+
+ auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
+ for (++startRange; startRange != endRange; ++startRange) {
+ auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
+ emitMatchCheck(
+ opName,
+ formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
+ formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
+ secondOperand));
+ }
+
+ symbolInfoIt = endRange;
+ }
+
for (auto &appliedConstraint : pattern.getConstraints()) {
auto &constraint = appliedConstraint.constraint;
auto &entities = appliedConstraint.entities;
@@ -1068,29 +1094,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
}
}
- // Some of the operands could be bound to the same symbol name, we need
- // to enforce equality constraint on those.
- // TODO: we should be able to emit equality checks early
- // and short circuit unnecessary work if vars are not equal.
- for (auto symbolInfoIt = symbolInfoMap.begin();
- symbolInfoIt != symbolInfoMap.end();) {
- auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
- auto startRange = range.first;
- auto endRange = range.second;
-
- auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
- for (++startRange; startRange != endRange; ++startRange) {
- auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
- emitMatchCheck(
- opName,
- formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
- formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
- secondOperand));
- }
-
- symbolInfoIt = endRange;
- }
-
LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}