aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp')
-rw-r--r--mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp62
1 files changed, 62 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index 82a9fb0..e93b99b 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -91,6 +91,64 @@ struct AffineMaxOpInterface
};
};
+struct AffineDelinearizeIndexOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<
+ AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
+ void populateBoundsForIndexValue(Operation *rawOp, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto op = cast<AffineDelinearizeIndexOp>(rawOp);
+ auto result = cast<OpResult>(value);
+ assert(result.getOwner() == rawOp &&
+ "bounded value isn't a result of this delinearize_index");
+ unsigned resIdx = result.getResultNumber();
+
+ AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex());
+
+ SmallVector<OpFoldResult> basis = op.getPaddedBasis();
+ AffineExpr divisor = cstr.getExpr(1);
+ for (OpFoldResult basisElem : llvm::drop_begin(basis, resIdx + 1))
+ divisor = divisor * cstr.getExpr(basisElem);
+
+ if (resIdx == 0) {
+ cstr.bound(value) == linearIdx.floorDiv(divisor);
+ if (!basis.front().isNull())
+ cstr.bound(value) < cstr.getExpr(basis.front());
+ return;
+ }
+ AffineExpr thisBasis = cstr.getExpr(basis[resIdx]);
+ cstr.bound(value) == (linearIdx % (thisBasis * divisor)).floorDiv(divisor);
+ }
+};
+
+struct AffineLinearizeIndexOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<
+ AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
+ void populateBoundsForIndexValue(Operation *rawOp, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto op = cast<AffineLinearizeIndexOp>(rawOp);
+ assert(value == op.getResult() &&
+ "value isn't the result of this linearize");
+
+ AffineExpr bound = cstr.getExpr(0);
+ AffineExpr stride = cstr.getExpr(1);
+ SmallVector<OpFoldResult> basis = op.getPaddedBasis();
+ OperandRange multiIndex = op.getMultiIndex();
+ unsigned numArgs = multiIndex.size();
+ for (auto [revArgNum, length] : llvm::enumerate(llvm::reverse(basis))) {
+ unsigned argNum = numArgs - (revArgNum + 1);
+ if (argNum == 0)
+ break;
+ OpFoldResult indexAsFoldRes = getAsOpFoldResult(multiIndex[argNum]);
+ bound = bound + cstr.getExpr(indexAsFoldRes) * stride;
+ stride = stride * cstr.getExpr(length);
+ }
+ bound = bound + cstr.getExpr(op.getMultiIndex().front()) * stride;
+ cstr.bound(value) == bound;
+ if (op.getDisjoint() && !basis.front().isNull()) {
+ cstr.bound(value) < stride *cstr.getExpr(basis.front());
+ }
+ }
+};
} // namespace
} // namespace mlir
@@ -100,6 +158,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
+ AffineDelinearizeIndexOp::attachInterface<
+ AffineDelinearizeIndexOpInterface>(*ctx);
+ AffineLinearizeIndexOp::attachInterface<AffineLinearizeIndexOpInterface>(
+ *ctx);
});
}