about | summary | refs | log | tree | commit | diff
path: root/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp')
-rw-r--r-- mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp 113
1 files changed, 59 insertions, 54 deletions
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 458628c..e28d5122 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -39,7 +39,7 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
auto tileSliceOffset = tileSliceIndex;
auto baseIndexPlusTileSliceOffset =
- rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
+ arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset);
outIndices.push_back(baseIndexPlusTileSliceOffset);
outIndices.push_back(indices[1]);
@@ -59,10 +59,11 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
if (memrefIndices.size() != 2)
return rewriter.notifyMatchFailure(loc, "invalid number of indices");
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
+ auto minTileSlices = arith::ConstantIndexOp::create(
+ rewriter, loc,
+ arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
@@ -70,7 +71,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
// elements in a vector of SVL bits for a given element type (SVL_B,
// SVL_H, ..., SVL_Q).
auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
Value predicate;
Value upperBound;
@@ -82,30 +83,30 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
// The upper bound of the loop must be clamped at `numTileSlices` as
// `vector.create_mask` allows operands to be greater than the size of a
// dimension.
- auto numRowI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), maskDim0);
- auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), numTileSlices);
+ auto numRowI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), maskDim0);
+ auto numTileSlicesI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), numTileSlices);
auto upperBoundI64 =
- rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
- upperBound = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), upperBoundI64);
+ arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64);
+ upperBound = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getIndexType(), upperBoundI64);
predicate =
- rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
+ vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1);
} else {
upperBound = numTileSlices;
// No mask. Create an 'all true' predicate for the tile slice.
- predicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ predicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
}
bool hasCarriedArgs = bool(initTile);
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
- hasCarriedArgs ? ValueRange{initTile}
- : ValueRange{});
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto forOp =
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
+ hasCarriedArgs ? ValueRange{initTile} : ValueRange{});
rewriter.setInsertionPointToStart(forOp.getBody());
Value tileSliceIndex = forOp.getInductionVar();
@@ -118,7 +119,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
assert(bool(nextTile) == hasCarriedArgs);
if (nextTile)
- rewriter.create<scf::YieldOp>(loc, nextTile);
+ scf::YieldOp::create(rewriter, loc, nextTile);
return forOp;
}
@@ -194,9 +195,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
// Initialize tile with zero to satisfy padding. Inactive cols will be
// zeroed anyway since the loads use zeroing predication. For inactive
// rows however, no load will occur so these need to be zeroed.
- initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+ initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType);
} else {
- initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
}
// Create a loop to load the active tile slices from memory.
@@ -207,9 +208,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
Value currentTile) -> Value {
// Create 'arm_sme.load_tile_slice' to load tile slice from memory
// into tile.
- return rewriter.create<arm_sme::LoadTileSliceOp>(
- loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
- memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ return arm_sme::LoadTileSliceOp::create(
+ rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
+ currentTile, memrefIndices, tileSliceIndex,
+ tileLoadOp.getLayout());
});
if (failed(forOp))
@@ -283,22 +285,22 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto numRows = createMaskOp.getOperands()[0];
auto numCols = createMaskOp.getOperands()[1];
- auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), numCols);
+ auto numColsI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), numCols);
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto minTileSlices = arith::ConstantIndexOp::create(
+ rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
- step, ValueRange{initTile});
+ arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
+ auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
+ step, ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
@@ -306,17 +308,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto currentTile = forOp.getRegionIterArg(0);
// Combine masks.
- auto rowIsActive = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
- auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
- loc, rewriter.getI32Type(), rowIsActive);
- auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
- auto maskIndex =
- rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
+ auto rowIsActive = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+ auto rowIsActiveI32 = arith::ExtSIOp::create(
+ rewriter, loc, rewriter.getI32Type(), rowIsActive);
+ auto mask =
+ arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32);
+ auto maskIndex = arith::IndexCastOp::create(rewriter, loc,
+ rewriter.getIndexType(), mask);
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
- auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
- loc, predicateType, maskIndex.getResult());
+ auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType,
+ maskIndex.getResult());
auto memrefIndices = getMemrefIndices(
tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
@@ -324,17 +327,19 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+ auto pad1DOp =
+ vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
- auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
- loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
- /*passthru=*/pad1DOp);
+ auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
+ tileLoadOp.getBase(),
+ memrefIndices, maskOp1D,
+ /*passthru=*/pad1DOp);
// Create 'arm_sme.insert_tile_slice' to insert slice into tile.
- auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
- tileLoadOp.getLayout());
- rewriter.create<scf::YieldOp>(loc, insertSlice.getResult());
+ auto insertSlice = arm_sme::InsertTileSliceOp::create(
+ rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
+ tileSliceIndex, tileLoadOp.getLayout());
+ scf::YieldOp::create(rewriter, loc, insertSlice.getResult());
rewriter.setInsertionPointAfter(forOp);