From c68fcbf785119789ac3354f1d0efa5de0b71ae41 Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Fri, 5 Apr 2024 02:25:30 -0700 Subject: =?UTF-8?q?Revert=20"[mlir][SCF]=20`ValueBoundsConstraintSet`:=20S?= =?UTF-8?q?upport=20`scf.if`=20(branches)=20(=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 6b30ffef28c35c24bfd8190e06eeaa0c5cd73cbd. --- .../mlir/Interfaces/ValueBoundsOpInterface.h | 43 ++----- .../Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 61 --------- .../Vector/IR/ScalableValueBoundsConstraintSet.cpp | 14 +- mlir/lib/Interfaces/ValueBoundsOpInterface.cpp | 143 +++++---------------- .../SCF/value-bounds-op-interface-impl.mlir | 119 +---------------- 5 files changed, 46 insertions(+), 334 deletions(-) diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 3543ab5..83107a3 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -203,26 +203,6 @@ public: std::optional dim1 = std::nullopt, std::optional dim2 = std::nullopt); - /// Traverse the IR starting from the given value/dim and populate constraints - /// as long as the stop condition holds. Also process all values/dims that are - /// already on the worklist. - void populateConstraints(Value value, std::optional dim); - - /// Comparison operator for `ValueBoundsConstraintSet::compare`. - enum ComparisonOperator { LT, LE, EQ, GT, GE }; - - /// Try to prove that, based on the current state of this constraint set - /// (i.e., without analyzing additional IR or adding new constraints), the - /// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim. - /// - /// Return "true" if the specified relation between the two values/dims was - /// proven to hold. Return "false" if the specified relation could not be - /// proven. This could be because the specified relation does in fact not hold - /// or because there is not enough information in the constraint set. In other - /// words, if we do not know for sure, this function returns "false". - bool compare(Value lhs, std::optional lhsDim, ComparisonOperator cmp, - Value rhs, std::optional rhsDim); - /// Compute whether the given values/dimensions are equal. Return "failure" if /// equality could not be determined. /// @@ -294,13 +274,13 @@ protected: ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition); - /// Given an affine map with a single result (and map operands), add a new - /// column to the constraint set that represents the result of the map. - /// Traverse additional IR starting from the map operands as needed (as long - /// as the stop condition is not satisfied). Also process all values/dims that - /// are already on the worklist. Return the position of the newly added - /// column. - int64_t populateConstraints(AffineMap map, ValueDimList mapOperands); + /// Populates the constraint set for a value/map without actually computing + /// the bound. Returns the position for the value/map (via the return value + /// and `posOut` output parameter). + int64_t populateConstraintsSet(Value value, + std::optional dim = std::nullopt); + int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands, + int64_t *posOut = nullptr); /// Iteratively process all elements on the worklist until an index-typed /// value or shaped value meets `stopCondition`. Such values are not processed @@ -315,19 +295,14 @@ protected: /// value/dimension exists in the constraint set. int64_t getPos(Value value, std::optional dim = std::nullopt) const; - /// Return an affine expression that represents column `pos` in the constraint - /// set. - AffineExpr getPosExpr(int64_t pos); - /// Insert a value/dimension into the constraint set. If `isSymbol` is set to /// "false", a dimension is added. The value/dimension is added to the - /// worklist if `addToWorklist` is set. + /// worklist. /// /// Note: There are certain affine restrictions wrt. dimensions. E.g., they /// cannot be multiplied. Furthermore, bounds can only be queried for /// dimensions but not for symbols. - int64_t insert(Value value, std::optional dim, bool isSymbol = true, - bool addToWorklist = true); + int64_t insert(Value value, std::optional dim, bool isSymbol = true); /// Insert an anonymous column into the constraint set. The column is not /// bound to any value/dimension. If `isSymbol` is set to "false", a dimension diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index 72c5aaa..8128408 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -103,66 +103,6 @@ struct ForOpInterface } }; -struct IfOpInterface - : public ValueBoundsOpInterface::ExternalModel { - - static void populateBounds(scf::IfOp ifOp, Value value, - std::optional dim, - ValueBoundsConstraintSet &cstr) { - unsigned int resultNum = cast(value).getResultNumber(); - Value thenValue = ifOp.thenYield().getResults()[resultNum]; - Value elseValue = ifOp.elseYield().getResults()[resultNum]; - - // Populate constraints for the yielded value (and all values on the - // backward slice, as long as the current stop condition is not satisfied). - cstr.populateConstraints(thenValue, dim); - cstr.populateConstraints(elseValue, dim); - auto boundsBuilder = cstr.bound(value); - if (dim) - boundsBuilder[*dim]; - - // Compare yielded values. - // If thenValue <= elseValue: - // * result <= elseValue - // * result >= thenValue - if (cstr.compare(thenValue, dim, - ValueBoundsConstraintSet::ComparisonOperator::LE, - elseValue, dim)) { - if (dim) { - cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim); - cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim); - } else { - cstr.bound(value) >= thenValue; - cstr.bound(value) <= elseValue; - } - } - // If elseValue <= thenValue: - // * result <= thenValue - // * result >= elseValue - if (cstr.compare(elseValue, dim, - ValueBoundsConstraintSet::ComparisonOperator::LE, - thenValue, dim)) { - if (dim) { - cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim); - cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim); - } else { - cstr.bound(value) >= elseValue; - cstr.bound(value) <= thenValue; - } - } - } - - void populateBoundsForIndexValue(Operation *op, Value value, - ValueBoundsConstraintSet &cstr) const { - populateBounds(cast(op), value, /*dim=*/std::nullopt, cstr); - } - - void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, - ValueBoundsConstraintSet &cstr) const { - populateBounds(cast(op), value, dim, cstr); - } -}; - } // namespace } // namespace scf } // namespace mlir @@ -171,6 +111,5 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) { scf::ForOp::attachInterface(*ctx); - scf::IfOp::attachInterface(*ctx); }); } diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp index f8df348..52359fa8 100644 --- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp +++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp @@ -59,16 +59,12 @@ ScalableValueBoundsConstraintSet::computeScalableBound( ScalableValueBoundsConstraintSet scalableCstr( value.getContext(), stopCondition ? stopCondition : defaultStopCondition, vscaleMin, vscaleMax); - int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false); - scalableCstr.processWorklist(); + int64_t pos = scalableCstr.populateConstraintsSet(value, dim); - // Project out all columns apart from vscale and the starting point - // (value/dim). This should result in constraints in terms of vscale only. + // Project out all variables apart from vscale. + // This should result in constraints in terms of vscale only. auto projectOutFn = [&](ValueDim p) { - bool isStartingPoint = - p.first == value && - p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue); - return p.first != scalableCstr.getVscaleValue() && !isStartingPoint; + return p.first != scalableCstr.getVscaleValue(); }; scalableCstr.projectOut(projectOutFn); @@ -76,7 +72,7 @@ ScalableValueBoundsConstraintSet::computeScalableBound( scalableCstr.positionToValueDim.size() && "inconsistent mapping state"); - // Check that the only columns left are vscale and the starting point. + // Check that the only symbols left are vscale. for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) { if (i == pos) continue; diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index 6e3d6dd..0d362c7 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -110,47 +110,25 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value, assertValidValueDim(value, dim); #endif // NDEBUG - // Check if the value/dim is statically known. In that case, an affine - // constant expression should be returned. This allows us to support - // multiplications with constants. (Multiplications of two columns in the - // constraint set is not supported.) - std::optional constSize = std::nullopt; auto shapedType = dyn_cast(value.getType()); if (shapedType) { + // Static dimension: return constant directly. if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim)) - constSize = shapedType.getDimSize(*dim); - } else if (auto constInt = ::getConstantIntValue(value)) { - constSize = *constInt; + return builder.getAffineConstantExpr(shapedType.getDimSize(*dim)); + } else { + // Constant index value: return directly. + if (auto constInt = ::getConstantIntValue(value)) + return builder.getAffineConstantExpr(*constInt); } - // If the value/dim is already mapped, return the corresponding expression - // directly. + // Dynamic value: add to constraint set. ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); - if (valueDimToPosition.contains(valueDim)) { - // If it is a constant, return an affine constant expression. Otherwise, - // return an affine expression that represents the respective column in the - // constraint set. - if (constSize) - return builder.getAffineConstantExpr(*constSize); - return getPosExpr(getPos(value, dim)); - } - - if (constSize) { - // Constant index value/dim: add column to the constraint set, add EQ bound - // and return an affine constant expression without pushing the newly added - // column to the worklist. - (void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false); - if (shapedType) - bound(value)[*dim] == *constSize; - else - bound(value) == *constSize; - return builder.getAffineConstantExpr(*constSize); - } - - // Dynamic value/dim: insert column to the constraint set and put it on the - // worklist. Return an affine expression that represents the newly inserted - // column in the constraint set. - return getPosExpr(insert(value, dim, /*isSymbol=*/true)); + if (!valueDimToPosition.contains(valueDim)) + (void)insert(value, dim); + int64_t pos = getPos(value, dim); + return pos < cstr.getNumDimVars() + ? builder.getAffineDimExpr(pos) + : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars()); } AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) { @@ -167,7 +145,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) { int64_t ValueBoundsConstraintSet::insert(Value value, std::optional dim, - bool isSymbol, bool addToWorklist) { + bool isSymbol) { #ifndef NDEBUG assertValidValueDim(value, dim); #endif // NDEBUG @@ -182,12 +160,7 @@ int64_t ValueBoundsConstraintSet::insert(Value value, if (positionToValueDim[i].has_value()) valueDimToPosition[*positionToValueDim[i]] = i; - if (addToWorklist) { - LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value - << " (dim: " << dim.value_or(kIndexValue) << ")\n"); - worklist.push(pos); - } - + worklist.push(pos); return pos; } @@ -217,13 +190,6 @@ int64_t ValueBoundsConstraintSet::getPos(Value value, return it->second; } -AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) { - assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position"); - return pos < cstr.getNumDimVars() - ? builder.getAffineDimExpr(pos) - : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars()); -} - static Operation *getOwnerOfValue(Value value) { if (auto bbArg = dyn_cast(value)) return bbArg.getOwner()->getParentOp(); @@ -526,7 +492,7 @@ FailureOr ValueBoundsConstraintSet::computeConstantBound( // Default stop condition if none was specified: Keep adding constraints until // a bound could be computed. - int64_t pos = 0; + int64_t pos; auto defaultStopCondition = [&](Value v, std::optional dim, ValueBoundsConstraintSet &cstr) { return cstr.cstr.getConstantBound64(type, pos).has_value(); @@ -534,8 +500,7 @@ FailureOr ValueBoundsConstraintSet::computeConstantBound( ValueBoundsConstraintSet cstr( map.getContext(), stopCondition ? stopCondition : defaultStopCondition); - pos = cstr.populateConstraints(map, operands); - assert(pos == 0 && "expected `map` is the first column"); + cstr.populateConstraintsSet(map, operands, &pos); // Compute constant bound for `valueDim`. int64_t ubAdjustment = closedUB ? 0 : 1; @@ -544,28 +509,29 @@ FailureOr ValueBoundsConstraintSet::computeConstantBound( return failure(); } -void ValueBoundsConstraintSet::populateConstraints(Value value, - std::optional dim) { +int64_t +ValueBoundsConstraintSet::populateConstraintsSet(Value value, + std::optional dim) { #ifndef NDEBUG assertValidValueDim(value, dim); #endif // NDEBUG - // `getExpr` pushes the value/dim onto the worklist (unless it was already - // analyzed). - (void)getExpr(value, dim); - // Process all values/dims on the worklist. This may traverse and analyze - // additional IR, depending the current stop function. - processWorklist(); + AffineMap map = + AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, + Builder(value.getContext()).getAffineDimExpr(0)); + return populateConstraintsSet(map, {{value, dim}}); } -int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map, - ValueDimList operands) { +int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map, + ValueDimList operands, + int64_t *posOut) { assert(map.getNumResults() == 1 && "expected affine map with one result"); int64_t pos = insert(/*isSymbol=*/false); + if (posOut) + *posOut = pos; // Add map and operands to the constraint set. Dimensions are converted to - // symbols. All operands are added to the worklist (unless they were already - // processed). + // symbols. All operands are added to the worklist. auto mapper = [&](std::pair> v) { return getExpr(v.first, v.second); }; @@ -600,55 +566,6 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2, {{value1, dim1}, {value2, dim2}}); } -bool ValueBoundsConstraintSet::compare(Value lhs, std::optional lhsDim, - ComparisonOperator cmp, Value rhs, - std::optional rhsDim) { - // This function returns "true" if "lhs CMP rhs" is proven to hold. - // - // Example for ComparisonOperator::LE and index-typed values: We would like to - // prove that lhs <= rhs. Proof by contradiction: add the inverse - // relation (lhs > rhs) to the constraint set and check if the resulting - // constraint set is "empty" (i.e. has no solution). In that case, - // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds. - - // We cannot prove anything if the constraint set is already empty. - if (cstr.isEmpty()) { - LLVM_DEBUG( - llvm::dbgs() - << "cannot compare value/dims: constraint system is already empty"); - return false; - } - - // EQ can be expressed as LE and GE. - if (cmp == EQ) - return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) && - compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim); - - // Construct inequality. For the above example: lhs > rhs. - // `IntegerRelation` inequalities are expressed in the "flattened" form and - // with ">= 0". I.e., lhs - rhs - 1 >= 0. - SmallVector eq(cstr.getNumDimAndSymbolVars() + 1, 0); - if (cmp == LT || cmp == LE) { - ++eq[getPos(lhs, lhsDim)]; - --eq[getPos(rhs, rhsDim)]; - } else if (cmp == GT || cmp == GE) { - --eq[getPos(lhs, lhsDim)]; - ++eq[getPos(rhs, rhsDim)]; - } else { - llvm_unreachable("unsupported comparison operator"); - } - if (cmp == LE || cmp == GE) - eq[cstr.getNumDimAndSymbolVars()] -= 1; - - // Add inequality to the constraint set and check if it made the constraint - // set empty. - int64_t ineqPos = cstr.getNumInequalities(); - cstr.addInequality(eq); - bool isEmpty = cstr.isEmpty(); - cstr.removeInequality(ineqPos); - return isEmpty; -} - FailureOr ValueBoundsConstraintSet::areEqual(Value value1, Value value2, std::optional dim1, diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir index 0ea0673..e4d7141 100644 --- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \ -// RUN: -verify-diagnostics -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \ +// RUN: -split-input-file | FileCheck %s // CHECK-LABEL: func @scf_for( // CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index @@ -104,118 +104,3 @@ func.func @scf_for_swapping_yield(%t1: tensor, %t2: tensor, %a: in "test.some_use"(%reify1) : (index) -> () return } - -// ----- - -// CHECK-LABEL: func @scf_if_constant( -func.func @scf_if_constant(%c : i1) { - // CHECK: arith.constant 4 : index - // CHECK: arith.constant 9 : index - %c4 = arith.constant 4 : index - %c9 = arith.constant 9 : index - %r = scf.if %c -> index { - scf.yield %c4 : index - } else { - scf.yield %c9 : index - } - - // CHECK: %[[c4:.*]] = arith.constant 4 : index - // CHECK: %[[c10:.*]] = arith.constant 10 : index - %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index) - %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index) - // CHECK: "test.some_use"(%[[c4]], %[[c10]]) - "test.some_use"(%reify1, %reify2) : (index, index) -> () - return -} - -// ----- - -// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> -// CHECK: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)> -// CHECK-LABEL: func @scf_if_dynamic( -// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1) -func.func @scf_if_dynamic(%a: index, %b: index, %c : i1) { - %c4 = arith.constant 4 : index - %r = scf.if %c -> index { - %add1 = arith.addi %a, %b : index - scf.yield %add1 : index - } else { - %add2 = arith.addi %b, %c4 : index - %add3 = arith.addi %add2, %a : index - scf.yield %add3 : index - } - - // CHECK: %[[lb:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]] - // CHECK: %[[ub:.*]] = affine.apply #[[$map1]]()[%[[a]], %[[b]]] - %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index) - %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index) - // CHECK: "test.some_use"(%[[lb]], %[[ub]]) - "test.some_use"(%reify1, %reify2) : (index, index) -> () - return -} - -// ----- - -func.func @scf_if_no_affine_bound(%a: index, %b: index, %c : i1) { - %r = scf.if %c -> index { - scf.yield %a : index - } else { - scf.yield %b : index - } - // The reified bound would be min(%a, %b). min/max expressions are not - // supported in reified bounds. - // expected-error @below{{could not reify bound}} - %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index) - "test.some_use"(%reify1) : (index) -> () - return -} - -// ----- - -// CHECK-LABEL: func @scf_if_tensor_dim( -func.func @scf_if_tensor_dim(%c : i1) { - // CHECK: arith.constant 4 : index - // CHECK: arith.constant 9 : index - %c4 = arith.constant 4 : index - %c9 = arith.constant 9 : index - %t1 = tensor.empty(%c4) : tensor - %t2 = tensor.empty(%c9) : tensor - %r = scf.if %c -> tensor { - scf.yield %t1 : tensor - } else { - scf.yield %t2 : tensor - } - - // CHECK: %[[c4:.*]] = arith.constant 4 : index - // CHECK: %[[c10:.*]] = arith.constant 10 : index - %reify1 = "test.reify_bound"(%r) {type = "LB", dim = 0} - : (tensor) -> (index) - %reify2 = "test.reify_bound"(%r) {type = "UB", dim = 0} - : (tensor) -> (index) - // CHECK: "test.some_use"(%[[c4]], %[[c10]]) - "test.some_use"(%reify1, %reify2) : (index, index) -> () - return -} - -// ----- - -// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> -// CHECK-LABEL: func @scf_if_eq( -// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1) -func.func @scf_if_eq(%a: index, %b: index, %c : i1) { - %c0 = arith.constant 0 : index - %r = scf.if %c -> index { - %add1 = arith.addi %a, %b : index - scf.yield %add1 : index - } else { - %add2 = arith.addi %b, %c0 : index - %add3 = arith.addi %add2, %a : index - scf.yield %add3 : index - } - - // CHECK: %[[eq:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]] - %reify1 = "test.reify_bound"(%r) {type = "EQ"} : (index) -> (index) - // CHECK: "test.some_use"(%[[eq]]) - "test.some_use"(%reify1) : (index) -> () - return -} -- cgit v1.1