diff options
Diffstat (limited to 'mlir/test/lib')
-rw-r--r-- | mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp | 26 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestDialect.cpp | 37 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestOps.td | 16 |
3 files changed, 58 insertions, 21 deletions
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index 6730f9b..b098a5a 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -109,7 +109,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, FailureOr<OpFoldResult> reified = failure(); if (constant) { auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound( - boundType, value, dim, /*stopCondition=*/nullptr); + boundType, {value, dim}, /*stopCondition=*/nullptr); if (succeeded(reifiedConst)) reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst)); } else if (scalable) { @@ -128,22 +128,12 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, rewriter, loc, reifiedScalable->map, vscaleOperand); } } else { - if (dim) { - if (useArithOps) { - reified = arith::reifyShapedValueDimBound( - rewriter, op->getLoc(), boundType, value, *dim, stopCondition); - } else { - reified = reifyShapedValueDimBound(rewriter, op->getLoc(), boundType, - value, *dim, stopCondition); - } + if (useArithOps) { + reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType, + op.getVariable(), stopCondition); } else { - if (useArithOps) { - reified = arith::reifyIndexValueBound( - rewriter, op->getLoc(), boundType, value, stopCondition); - } else { - reified = reifyIndexValueBound(rewriter, op->getLoc(), boundType, - value, stopCondition); - } + reified = reifyValueBound(rewriter, op->getLoc(), boundType, + op.getVariable(), stopCondition); } } if (failed(reified)) { @@ -188,9 +178,7 @@ static LogicalResult testEquality(func::FuncOp funcOp) { } auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) { - return ValueBoundsConstraintSet::compare( - /*lhs=*/op.getLhs(), /*lhsDim=*/std::nullopt, cmp, - /*rhs=*/op.getRhs(), /*rhsDim=*/std::nullopt); + return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs()); }; if (compare(cmpType)) { op->emitRemark("true"); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 25c5190..a23ed89 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -549,6 +549,12 @@ LogicalResult ReifyBoundOp::verify() { return success(); } +::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() { + if (getDim().has_value()) + return ValueBoundsConstraintSet::Variable(getVar(), *getDim()); + return ValueBoundsConstraintSet::Variable(getVar()); +} + ::mlir::ValueBoundsConstraintSet::ComparisonOperator CompareOp::getComparisonOperator() { if (getCmp() == "EQ") @@ -564,6 +570,37 @@ CompareOp::getComparisonOperator() { llvm_unreachable("invalid comparison operator"); } +::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() { + if (!getLhsMap()) + return ValueBoundsConstraintSet::Variable(getVarOperands()[0]); + SmallVector<Value> mapOperands( + getVarOperands().slice(0, getLhsMap()->getNumInputs())); + return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands); +} + +::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() { + int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1; + if (!getRhsMap()) + return ValueBoundsConstraintSet::Variable( + getVarOperands()[rhsOperandsBegin]); + SmallVector<Value> mapOperands( + getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs())); + return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands); +} + +LogicalResult CompareOp::verify() { + if (getCompose() && (getLhsMap() || getRhsMap())) + return emitOpError( + "'compose' not supported when 'lhs_map' or 'rhs_map' is present"); + int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1; + expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1; + if (getVarOperands().size() != size_t(expectedNumOperands)) + return emitOpError("expected ") + << expectedNumOperands << " operands, but got " + << getVarOperands().size(); + return success(); +} + //===----------------------------------------------------------------------===// // Test removing op with inner ops. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index ebf158b..b641b3d 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2207,6 +2207,7 @@ def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> { let extraClassDeclaration = [{ ::mlir::presburger::BoundType getBoundType(); + ::mlir::ValueBoundsConstraintSet::Variable getVariable(); }]; let hasVerifier = 1; @@ -2217,18 +2218,29 @@ def CompareOp : TEST_Op<"compare"> { Compare `lhs` and `rhs`. A remark is emitted which indicates whether the specified comparison operator was proven to hold. The remark also indicates whether the opposite comparison operator was proven to hold. + + `var_operands` must have exactly two operands: one for the LHS operand and + one for the RHS operand. If `lhs_map` is specified, as many operands as + `lhs_map` has inputs are expected instead of the first operand. If `rhs_map` + is specified, as many operands as `rhs_map` has inputs are expected instead + of the second operand. }]; - let arguments = (ins Index:$lhs, - Index:$rhs, + let arguments = (ins Variadic<Index>:$var_operands, DefaultValuedAttr<StrAttr, "\"EQ\"">:$cmp, + OptionalAttr<AffineMapAttr>:$lhs_map, + OptionalAttr<AffineMapAttr>:$rhs_map, UnitAttr:$compose); let results = (outs); let extraClassDeclaration = [{ ::mlir::ValueBoundsConstraintSet::ComparisonOperator getComparisonOperator(); + ::mlir::ValueBoundsConstraintSet::Variable getLhs(); + ::mlir::ValueBoundsConstraintSet::Variable getRhs(); }]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// |