diff options
Diffstat (limited to 'mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp')
-rw-r--r-- | mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp index 82a9fb0..e93b99b 100644 --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -91,6 +91,64 @@ struct AffineMaxOpInterface }; }; +struct AffineDelinearizeIndexOpInterface + : public ValueBoundsOpInterface::ExternalModel< + AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> { + void populateBoundsForIndexValue(Operation *rawOp, Value value, + ValueBoundsConstraintSet &cstr) const { + auto op = cast<AffineDelinearizeIndexOp>(rawOp); + auto result = cast<OpResult>(value); + assert(result.getOwner() == rawOp && + "bounded value isn't a result of this delinearize_index"); + unsigned resIdx = result.getResultNumber(); + + AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex()); + + SmallVector<OpFoldResult> basis = op.getPaddedBasis(); + AffineExpr divisor = cstr.getExpr(1); + for (OpFoldResult basisElem : llvm::drop_begin(basis, resIdx + 1)) + divisor = divisor * cstr.getExpr(basisElem); + + if (resIdx == 0) { + cstr.bound(value) == linearIdx.floorDiv(divisor); + if (!basis.front().isNull()) + cstr.bound(value) < cstr.getExpr(basis.front()); + return; + } + AffineExpr thisBasis = cstr.getExpr(basis[resIdx]); + cstr.bound(value) == (linearIdx % (thisBasis * divisor)).floorDiv(divisor); + } +}; + +struct AffineLinearizeIndexOpInterface + : public ValueBoundsOpInterface::ExternalModel< + AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> { + void populateBoundsForIndexValue(Operation *rawOp, Value value, + ValueBoundsConstraintSet &cstr) const { + auto op = cast<AffineLinearizeIndexOp>(rawOp); + assert(value == op.getResult() && + "value isn't the result of this linearize"); + + AffineExpr bound = cstr.getExpr(0); + AffineExpr stride = cstr.getExpr(1); + SmallVector<OpFoldResult> basis = op.getPaddedBasis(); + OperandRange multiIndex = op.getMultiIndex(); + unsigned numArgs = multiIndex.size(); + for (auto [revArgNum, length] : llvm::enumerate(llvm::reverse(basis))) { + unsigned argNum = numArgs - (revArgNum + 1); + if (argNum == 0) + break; + OpFoldResult indexAsFoldRes = getAsOpFoldResult(multiIndex[argNum]); + bound = bound + cstr.getExpr(indexAsFoldRes) * stride; + stride = stride * cstr.getExpr(length); + } + bound = bound + cstr.getExpr(op.getMultiIndex().front()) * stride; + cstr.bound(value) == bound; + if (op.getDisjoint() && !basis.front().isNull()) { + cstr.bound(value) < stride *cstr.getExpr(basis.front()); + } + } +}; } // namespace } // namespace mlir @@ -100,6 +158,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels( AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx); AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx); AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx); + AffineDelinearizeIndexOp::attachInterface< + AffineDelinearizeIndexOpInterface>(*ctx); + AffineLinearizeIndexOp::attachInterface<AffineLinearizeIndexOpInterface>( + *ctx); }); } |