From fca8ef57d8f0c9431866c1d7988dfb6c7fb7e57f Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 25 Mar 2024 09:20:52 +0000 Subject: [mlir][Interfaces][NFC] `ValueBoundsConstraintSet`: Add columns for constant values/dims `ValueBoundsConstraintSet` maintains an internal constraint set (`IntegerRelation`), where every analyzed index-typed SSA value or dimension of a shaped type is represented with a dimension/symbol. Prior to this change, index-typed values with a statically known constant value and static shaped type dimensions were not added to the constraint set. Instead, `getExpr` directly returned an affine constrant expression. With this commit, dynamic and static values/dimension sizes are treated in the same way: in either case, a dimension/symbol is added to the constraint set. This is needed for a subsequent commit that adds support for branches. --- .../mlir/Interfaces/ValueBoundsOpInterface.h | 9 ++- mlir/lib/Interfaces/ValueBoundsOpInterface.cpp | 64 +++++++++++++++++----- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 83107a3..651659b 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -295,14 +295,19 @@ protected: /// value/dimension exists in the constraint set. int64_t getPos(Value value, std::optional dim = std::nullopt) const; + /// Return an affine expression that represents column `pos` in the constraint + /// set. + AffineExpr getPosExpr(int64_t pos); + /// Insert a value/dimension into the constraint set. If `isSymbol` is set to /// "false", a dimension is added. The value/dimension is added to the - /// worklist. + /// worklist if `addToWorklist` is set. /// /// Note: There are certain affine restrictions wrt. dimensions. E.g., they /// cannot be multiplied. Furthermore, bounds can only be queried for /// dimensions but not for symbols. - int64_t insert(Value value, std::optional dim, bool isSymbol = true); + int64_t insert(Value value, std::optional dim, bool isSymbol = true, + bool addToWorklist = true); /// Insert an anonymous column into the constraint set. The column is not /// bound to any value/dimension. If `isSymbol` is set to "false", a dimension diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index 0d362c7..3965cb4 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -110,25 +110,47 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value, assertValidValueDim(value, dim); #endif // NDEBUG + // Check if the value/dim is statically known. In that case, an affine + // constant expression should be returned. This allows us to support + // multiplications with constants. (Multiplications of two columns in the + // constraint set is not supported.) + std::optional constSize = std::nullopt; auto shapedType = dyn_cast(value.getType()); if (shapedType) { - // Static dimension: return constant directly. if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim)) - return builder.getAffineConstantExpr(shapedType.getDimSize(*dim)); - } else { - // Constant index value: return directly. - if (auto constInt = ::getConstantIntValue(value)) - return builder.getAffineConstantExpr(*constInt); + constSize = shapedType.getDimSize(*dim); + } else if (auto constInt = ::getConstantIntValue(value)) { + constSize = *constInt; } - // Dynamic value: add to constraint set. + // If the value/dim is already mapped, return the corresponding expression + // directly. ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); - if (!valueDimToPosition.contains(valueDim)) - (void)insert(value, dim); - int64_t pos = getPos(value, dim); - return pos < cstr.getNumDimVars() - ? builder.getAffineDimExpr(pos) - : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars()); + if (valueDimToPosition.contains(valueDim)) { + // If it is a constant, return an affine constant expression. Otherwise, + // return an affine expression that represents the respective column in the + // constraint set. + if (constSize) + return builder.getAffineConstantExpr(*constSize); + return getPosExpr(getPos(value, dim)); + } + + if (constSize) { + // Constant index value/dim: add column to the constraint set, add EQ bound + // and return an affine constant expression without pushing the newly added + // column to the worklist. + (void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false); + if (shapedType) + bound(value)[*dim] == *constSize; + else + bound(value) == *constSize; + return builder.getAffineConstantExpr(*constSize); + } + + // Dynamic value/dim: insert column to the constraint set and put it on the + // worklist. Return an affine expression that represents the newly inserted + // column in the constraint set. + return getPosExpr(insert(value, dim, /*isSymbol=*/true)); } AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) { @@ -145,7 +167,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) { int64_t ValueBoundsConstraintSet::insert(Value value, std::optional dim, - bool isSymbol) { + bool isSymbol, bool addToWorklist) { #ifndef NDEBUG assertValidValueDim(value, dim); #endif // NDEBUG @@ -160,7 +182,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value, if (positionToValueDim[i].has_value()) valueDimToPosition[*positionToValueDim[i]] = i; - worklist.push(pos); + if (addToWorklist) { + LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value + << " (dim: " << dim.value_or(kIndexValue) << ")\n"); + worklist.push(pos); + } + return pos; } @@ -190,6 +217,13 @@ int64_t ValueBoundsConstraintSet::getPos(Value value, return it->second; } +AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) { + assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position"); + return pos < cstr.getNumDimVars() + ? builder.getAffineDimExpr(pos) + : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars()); +} + static Operation *getOwnerOfValue(Value value) { if (auto bbArg = dyn_cast(value)) return bbArg.getOwner()->getParentOp(); -- cgit v1.1