aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp')
-rw-r--r--mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp250
1 files changed, 125 insertions, 125 deletions
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index bbe1490..0ff9fb3 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
@@ -82,40 +81,40 @@ struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
// number of extent tensors and shifted offsets into them.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
ValueRange rankDiffs, Value outputDimension) {
- Value one = lb.create<arith::ConstantIndexOp>(1);
+ Value one = arith::ConstantIndexOp::create(lb, 1);
Value broadcastedDim = one;
for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
Value shape = std::get<0>(tup);
Value rankDiff = std::get<1>(tup);
- Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
- outputDimension, rankDiff);
+ Value outOfBounds = arith::CmpIOp::create(lb, arith::CmpIPredicate::ult,
+ outputDimension, rankDiff);
Type indexTy = lb.getIndexType();
broadcastedDim =
- lb.create<IfOp>(
- outOfBounds,
- [&](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(loc, broadcastedDim);
- },
- [&](OpBuilder &b, Location loc) {
- // The broadcasting logic is:
- // - if one extent (here we arbitrarily choose the
- // extent from the greater-rank operand) is equal to 1,
- // then take the extent from the other operand
- // - otherwise, take the extent as-is.
- // Note that this logic remains correct in the presence
- // of dimensions of zero extent.
- Value lesserRankOperandDimension = b.create<arith::SubIOp>(
- loc, indexTy, outputDimension, rankDiff);
- Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
- loc, shape, ValueRange{lesserRankOperandDimension});
-
- Value dimIsOne =
- b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- lesserRankOperandExtent, one);
- Value dim = b.create<arith::SelectOp>(
- loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
- b.create<scf::YieldOp>(loc, dim);
- })
+ IfOp::create(
+ lb, outOfBounds,
+ [&](OpBuilder &b, Location loc) {
+ scf::YieldOp::create(b, loc, broadcastedDim);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // The broadcasting logic is:
+ // - if one extent (here we arbitrarily choose the
+ // extent from the greater-rank operand) is equal to 1,
+ // then take the extent from the other operand
+ // - otherwise, take the extent as-is.
+ // Note that this logic remains correct in the presence
+ // of dimensions of zero extent.
+ Value lesserRankOperandDimension = arith::SubIOp::create(
+ b, loc, indexTy, outputDimension, rankDiff);
+ Value lesserRankOperandExtent = tensor::ExtractOp::create(
+ b, loc, shape, ValueRange{lesserRankOperandDimension});
+
+ Value dimIsOne =
+ arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
+ lesserRankOperandExtent, one);
+ Value dim = arith::SelectOp::create(
+ b, loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
+ scf::YieldOp::create(b, loc, dim);
+ })
.getResult(0);
}
return broadcastedDim;
@@ -133,7 +132,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
auto loc = op.getLoc();
ImplicitLocOpBuilder lb(loc, rewriter);
- Value zero = lb.create<arith::ConstantIndexOp>(0);
+ Value zero = arith::ConstantIndexOp::create(lb, 0);
Type indexTy = lb.getIndexType();
// Save all the ranks for bounds checking. Because this is a tensor
@@ -141,31 +140,31 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
// dimension in the tensor.
SmallVector<Value> ranks, rankDiffs;
llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
- return lb.create<tensor::DimOp>(v, zero);
+ return tensor::DimOp::create(lb, v, zero);
}));
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
- maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
+ maxRank = arith::MaxUIOp::create(lb, v, maxRank);
}
// Calculate the difference of ranks and the maximum rank for later offsets.
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
- return lb.create<arith::SubIOp>(indexTy, maxRank, v);
+ return arith::SubIOp::create(lb, indexTy, maxRank, v);
}));
- Value replacement = lb.create<tensor::GenerateOp>(
- getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+ Value replacement = tensor::GenerateOp::create(
+ lb, getExtentTensorType(lb.getContext()), ValueRange{maxRank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value broadcastedDim =
getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
rankDiffs, args[0]);
- b.create<tensor::YieldOp>(loc, broadcastedDim);
+ tensor::YieldOp::create(b, loc, broadcastedDim);
});
if (replacement.getType() != op.getType())
- replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
+ replacement = tensor::CastOp::create(lb, op.getType(), replacement);
rewriter.replaceOp(op, replacement);
return success();
}
@@ -193,13 +192,13 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
auto loc = op.getLoc();
SmallVector<Value, 4> extentOperands;
for (auto extent : op.getShape()) {
- extentOperands.push_back(
- rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
+ extentOperands.push_back(arith::ConstantIndexOp::create(
+ rewriter, loc, extent.getLimitedValue()));
}
Type resultTy =
RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
Value tensor =
- rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
+ tensor::FromElementsOp::create(rewriter, loc, resultTy, extentOperands);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
return success();
}
@@ -245,8 +244,8 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
auto loc = op.getLoc();
ImplicitLocOpBuilder lb(loc, rewriter);
- Value zero = lb.create<arith::ConstantIndexOp>(0);
- Value one = lb.create<arith::ConstantIndexOp>(1);
+ Value zero = arith::ConstantIndexOp::create(lb, 0);
+ Value one = arith::ConstantIndexOp::create(lb, 1);
Type indexTy = lb.getIndexType();
// Save all the ranks for bounds checking. Because this is a tensor
@@ -254,26 +253,26 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
// dimension in the tensor.
SmallVector<Value> ranks, rankDiffs;
llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
- return lb.create<tensor::DimOp>(v, zero);
+ return tensor::DimOp::create(lb, v, zero);
}));
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
- maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
+ maxRank = arith::MaxUIOp::create(lb, v, maxRank);
}
// Calculate the difference of ranks and the maximum rank for later offsets.
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
- return lb.create<arith::SubIOp>(indexTy, maxRank, v);
+ return arith::SubIOp::create(lb, indexTy, maxRank, v);
}));
Type i1Ty = rewriter.getI1Type();
- Value trueVal =
- rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
+ Value trueVal = arith::ConstantOp::create(rewriter, loc, i1Ty,
+ rewriter.getBoolAttr(true));
- auto reduceResult = lb.create<ForOp>(
- loc, zero, maxRank, one, ValueRange{trueVal},
+ auto reduceResult = ForOp::create(
+ lb, loc, zero, maxRank, one, ValueRange{trueVal},
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
// Find a non-1 dim, if it exists. Note that the first part of this
// could reuse the Broadcast lowering entirely, but we redo the work
@@ -285,38 +284,38 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
Value shape, rankDiff;
std::tie(shape, rankDiff) = tup;
- Value outOfBounds = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, iv, rankDiff);
+ Value outOfBounds = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::ult, iv, rankDiff);
broadcastable =
- b.create<IfOp>(
- loc, outOfBounds,
- [&](OpBuilder &b, Location loc) {
- // Non existent dimensions are always broadcastable
- b.create<scf::YieldOp>(loc, broadcastable);
- },
- [&](OpBuilder &b, Location loc) {
- // Every value needs to be either 1, or the same non-1
- // value to be broadcastable in this dim.
- Value operandDimension =
- b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
- Value dimensionExtent = b.create<tensor::ExtractOp>(
- loc, shape, ValueRange{operandDimension});
-
- Value equalOne = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, dimensionExtent, one);
- Value equalBroadcasted = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, dimensionExtent,
- broadcastedDim);
- Value result = b.create<arith::AndIOp>(
- loc, broadcastable,
- b.create<arith::OrIOp>(loc, equalOne,
- equalBroadcasted));
- b.create<scf::YieldOp>(loc, result);
- })
+ IfOp::create(
+ b, loc, outOfBounds,
+ [&](OpBuilder &b, Location loc) {
+ // Non existent dimensions are always broadcastable
+ scf::YieldOp::create(b, loc, broadcastable);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // Every value needs to be either 1, or the same non-1
+ // value to be broadcastable in this dim.
+ Value operandDimension =
+ arith::SubIOp::create(b, loc, indexTy, iv, rankDiff);
+ Value dimensionExtent = tensor::ExtractOp::create(
+ b, loc, shape, ValueRange{operandDimension});
+
+ Value equalOne = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::eq, dimensionExtent, one);
+ Value equalBroadcasted =
+ arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
+ dimensionExtent, broadcastedDim);
+ Value result = arith::AndIOp::create(
+ b, loc, broadcastable,
+ arith::OrIOp::create(b, loc, equalOne,
+ equalBroadcasted));
+ scf::YieldOp::create(b, loc, result);
+ })
.getResult(0);
}
- b.create<scf::YieldOp>(loc, broadcastable);
+ scf::YieldOp::create(b, loc, broadcastable);
});
rewriter.replaceOp(op, reduceResult.getResults().front());
@@ -339,7 +338,7 @@ DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
// Lower to dim(X, i) to get_extent(shape_of(X), i) and rely on further
// lowerings. This can be further optimized if needed to avoid intermediate
// steps.
- auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
+ auto shapeOf = shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getValue());
rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
op.getIndex());
return success();
@@ -421,16 +420,17 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
auto loc = op.getLoc();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
Type indexTy = rewriter.getIndexType();
Value rank =
- rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);
+ tensor::DimOp::create(rewriter, loc, indexTy, adaptor.getShape(), zero);
- auto loop = rewriter.create<scf::ForOp>(
- loc, zero, rank, one, op.getInitVals(),
+ auto loop = scf::ForOp::create(
+ rewriter, loc, zero, rank, one, op.getInitVals(),
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
- Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);
+ Value extent =
+ tensor::ExtractOp::create(b, loc, adaptor.getShape(), iv);
SmallVector<Value, 2> mappedValues{iv, extent};
mappedValues.append(args.begin(), args.end());
@@ -444,7 +444,7 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
SmallVector<Value, 2> mappedResults;
for (auto result : reduceBody->getTerminator()->getOperands())
mappedResults.push_back(mapping.lookup(result));
- b.create<scf::YieldOp>(loc, mappedResults);
+ scf::YieldOp::create(b, loc, mappedResults);
});
rewriter.replaceOp(op, loop.getResults());
@@ -507,44 +507,44 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
auto loc = op.getLoc();
Type indexTy = rewriter.getIndexType();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value firstShape = adaptor.getShapes().front();
Value firstRank =
- rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
+ tensor::DimOp::create(rewriter, loc, indexTy, firstShape, zero);
Value result = nullptr;
// Generate a linear sequence of compares, all with firstShape as lhs.
for (Value shape : adaptor.getShapes().drop_front(1)) {
- Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
- Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- firstRank, rank);
- auto same = rewriter.create<IfOp>(
- loc, eqRank,
+ Value rank = tensor::DimOp::create(rewriter, loc, indexTy, shape, zero);
+ Value eqRank = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, firstRank, rank);
+ auto same = IfOp::create(
+ rewriter, loc, eqRank,
[&](OpBuilder &b, Location loc) {
- Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+ Value one = arith::ConstantIndexOp::create(b, loc, 1);
Value init =
- b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
- auto loop = b.create<scf::ForOp>(
- loc, zero, firstRank, one, ValueRange{init},
+ arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(true));
+ auto loop = scf::ForOp::create(
+ b, loc, zero, firstRank, one, ValueRange{init},
[&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
Value conj = args[0];
Value lhsExtent =
- b.create<tensor::ExtractOp>(loc, firstShape, iv);
- Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
- Value eqExtent = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
- Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
- b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
+ tensor::ExtractOp::create(b, loc, firstShape, iv);
+ Value rhsExtent = tensor::ExtractOp::create(b, loc, shape, iv);
+ Value eqExtent = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
+ Value conjNext = arith::AndIOp::create(b, loc, conj, eqExtent);
+ scf::YieldOp::create(b, loc, ValueRange({conjNext}));
});
- b.create<scf::YieldOp>(loc, loop.getResults());
+ scf::YieldOp::create(b, loc, loop.getResults());
},
[&](OpBuilder &b, Location loc) {
Value result =
- b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
- b.create<scf::YieldOp>(loc, result);
+ arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(false));
+ scf::YieldOp::create(b, loc, result);
});
result = !result ? same.getResult(0)
- : rewriter.create<arith::AndIOp>(loc, result,
- same.getResult(0));
+ : arith::AndIOp::create(rewriter, loc, result,
+ same.getResult(0));
}
rewriter.replaceOp(op, result);
return success();
@@ -581,18 +581,18 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
int64_t rank = rankedTensorTy.getRank();
for (int64_t i = 0; i < rank; i++) {
if (rankedTensorTy.isDynamicDim(i)) {
- Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
+ Value extent = tensor::DimOp::create(rewriter, loc, tensor, i);
extentValues.push_back(extent);
} else {
- Value extent = rewriter.create<arith::ConstantIndexOp>(
- loc, rankedTensorTy.getDimSize(i));
+ Value extent = arith::ConstantIndexOp::create(
+ rewriter, loc, rankedTensorTy.getDimSize(i));
extentValues.push_back(extent);
}
}
// Materialize extent tensor.
- Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
- loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
+ Value staticExtentTensor = tensor::FromElementsOp::create(
+ rewriter, loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
extentValues);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
staticExtentTensor);
@@ -601,13 +601,13 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
// Lower to `tensor.generate` otherwise.
auto *ctx = rewriter.getContext();
- Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
+ Value rank = tensor::RankOp::create(rewriter, loc, tensor);
rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
op, getExtentTensorType(ctx), ValueRange{rank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value dim = args.front();
- Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
- b.create<tensor::YieldOp>(loc, extent);
+ Value extent = tensor::DimOp::create(b, loc, tensor, dim);
+ tensor::YieldOp::create(b, loc, extent);
});
return success();
@@ -634,22 +634,22 @@ LogicalResult SplitAtOpConversion::matchAndRewrite(
return failure();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value zero = b.create<arith::ConstantIndexOp>(0);
- Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);
+ Value zero = arith::ConstantIndexOp::create(b, 0);
+ Value rank = tensor::DimOp::create(b, adaptor.getOperand(), zero);
// index < 0 ? index + rank : index
Value originalIndex = adaptor.getIndex();
- Value add = b.create<arith::AddIOp>(originalIndex, rank);
+ Value add = arith::AddIOp::create(b, originalIndex, rank);
Value indexIsNegative =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
- Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::slt, originalIndex, zero);
+ Value index = arith::SelectOp::create(b, indexIsNegative, add, originalIndex);
- Value one = b.create<arith::ConstantIndexOp>(1);
+ Value one = arith::ConstantIndexOp::create(b, 1);
Value head =
- b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
- Value tailSize = b.create<arith::SubIOp>(rank, index);
- Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
- tailSize, one);
+ tensor::ExtractSliceOp::create(b, adaptor.getOperand(), zero, index, one);
+ Value tailSize = arith::SubIOp::create(b, rank, index);
+ Value tail = tensor::ExtractSliceOp::create(b, adaptor.getOperand(), index,
+ tailSize, one);
rewriter.replaceOp(op, {head, tail});
return success();
}