aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp')
-rw-r--r--mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp29
1 files changed, 12 insertions, 17 deletions
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 8525543..fd40e7c 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -177,9 +177,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
auto type = RankedTensorType::get({nSplits, 2}, i64);
Value resHaloSizes =
haloSizes.empty()
- ? rewriter
- .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
- i64)
+ ? tensor::EmptyOp::create(rewriter, loc,
+ std::array<int64_t, 2>{0, 0}, i64)
.getResult()
: tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
.getResult();
@@ -306,13 +305,11 @@ public:
auto ctx = op.getContext();
Value commWorld =
mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
- auto rank =
- rewriter
- .create<mpi::CommRankOp>(
- loc,
- TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
- commWorld)
- .getRank();
+ auto rank = mpi::CommRankOp::create(
+ rewriter, loc,
+ TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
+ commWorld)
+ .getRank();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
rank);
return success();
@@ -703,10 +700,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
// subviews need Index values
for (auto &sz : haloSizes) {
if (auto value = dyn_cast<Value>(sz))
- sz =
- rewriter
- .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
- .getResult();
+ sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
+ value)
+ .getResult();
}
// most of the offset/size/stride data is the same for all dims
@@ -758,9 +754,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
// Get the linearized ids of the neighbors (down and up) for the
// given split
- auto tmp = rewriter
- .create<NeighborsLinearIndicesOp>(loc, grid, myMultiIndex,
- splitAxes)
+ auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
+ myMultiIndex, splitAxes)
.getResults();
// MPI operates on i32...
Value neighbourIDs[2] = {