aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Test/TestDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect/Test/TestDialect.cpp')
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.cpp37
1 files changed, 37 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 25c5190..a23ed89 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -549,6 +549,12 @@ LogicalResult ReifyBoundOp::verify() {
return success();
}
+::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
+ if (getDim().has_value())
+ return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
+ return ValueBoundsConstraintSet::Variable(getVar());
+}
+
::mlir::ValueBoundsConstraintSet::ComparisonOperator
CompareOp::getComparisonOperator() {
if (getCmp() == "EQ")
@@ -564,6 +570,37 @@ CompareOp::getComparisonOperator() {
llvm_unreachable("invalid comparison operator");
}
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
+ if (!getLhsMap())
+ return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
+ SmallVector<Value> mapOperands(
+ getVarOperands().slice(0, getLhsMap()->getNumInputs()));
+ return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
+}
+
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
+ int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+ if (!getRhsMap())
+ return ValueBoundsConstraintSet::Variable(
+ getVarOperands()[rhsOperandsBegin]);
+ SmallVector<Value> mapOperands(
+ getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
+ return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
+}
+
+LogicalResult CompareOp::verify() {
+ if (getCompose() && (getLhsMap() || getRhsMap()))
+ return emitOpError(
+ "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
+ int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+ expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
+ if (getVarOperands().size() != size_t(expectedNumOperands))
+ return emitOpError("expected ")
+ << expectedNumOperands << " operands, but got "
+ << getVarOperands().size();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//