diff options
author | Benjamin Maxwell <benjamin.maxwell@arm.com> | 2023-12-06 14:31:05 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-06 14:31:05 +0000 |
commit | b0b69fd879a03f3d37b8cd78049d27939de23ce2 (patch) | |
tree | 23cbfb69f399f5675008c655a448fdb9aeeb23a8 /mlir | |
parent | c4a77bfb62b7caeb8a4d73a09df7e18e438b890f (diff) | |
download | llvm-b0b69fd879a03f3d37b8cd78049d27939de23ce2.zip llvm-b0b69fd879a03f3d37b8cd78049d27939de23ce2.tar.gz llvm-b0b69fd879a03f3d37b8cd78049d27939de23ce2.tar.bz2 |
[mlir][ArmSME] More precisely model dataflow in ArmSME to SCF lowerings (#73922)
Since #73253, loops over tiles in SSA form (i.e. loops that take
`iter_args` and yield a new tile) are supported, so this patch updates
ArmSME lowerings to this form. This is an NFC, as it still lowers to the
same intrinsics, but this makes the IR less 'surprising' at a higher level,
and may be recognised by more transforms.
Example:
IR before:
```mlir
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1
{
arm_sme.move_vector_to_tile_slice
%broadcast_to_1d, %tile, %tile_slice_index :
vector<[4]xi32> into vector<[4]x[4]xi32>
}
// ... later use %tile
```
IR now:
```mlir
%broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
{
%tile_update = arm_sme.move_vector_to_tile_slice
%broadcast_to_1d, %iter_tile, %tile_slice_index :
vector<[4]xi32> into vector<[4]x[4]xi32>
scf.yield %tile_update : vector<[4]x[4]xi32>
}
// ... later use %broadcast_to_tile
```
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 81 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 117 | ||||
-rw-r--r-- | mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir | 19 | ||||
-rw-r--r-- | mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir | 14 | ||||
-rw-r--r-- | mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir | 7 |
5 files changed, 135 insertions, 103 deletions
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index fece030..c3c9780 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -61,16 +61,18 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex, /// AFTER: /// ```mlir /// %ptrue_s = arith.constant dense<true> : vector<[4]xi1> -/// %tile = arm_sme.get_tile : vector<[4]x[4]xi32> +/// %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32> /// %vscale = vector.vscale /// %c0 = arith.constant 0 : index /// %c1 = arith.constant 1 : index /// %min_svl_s = arith.constant 4 : index /// %svl_s = arith.muli %min_svl_s, %vscale : index -/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 { +/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1 +/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) { /// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx], -/// %ptrue_s, %tile, %tile_slice_idx +/// %ptrue_s, %iter_tile, %tile_slice_idx /// : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32> +/// scf.yield %tile_update : vector<[4]x[4]xi32> /// } /// ``` struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { @@ -88,7 +90,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { auto tileElementType = tileType.getElementType(); // Allocate a new SME tile. - auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>( + auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>( rewriter, loc, tileType); // Create a loop that loads each ZA tile slice from memory. @@ -103,8 +105,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { // ..., SVL_Q). 
auto numTileSlices = rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); - auto forOp = - rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step); + auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, + step, ValueRange{initTile}); rewriter.setInsertionPointToStart(forOp.getBody()); @@ -121,14 +123,17 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { getMemrefIndices(tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(), tileSliceIndex, numTileSlices, memrefIndices, loc, rewriter); - tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>( - rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile, - memrefIndices, tileSliceIndex, tileLoadOp.getLayout()); + auto currentTile = forOp.getRegionIterArg(0); + auto loadSlice = + tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>( + rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, + currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout()); + rewriter.create<scf::YieldOp>(loc, loadSlice.getResult()); rewriter.setInsertionPointAfter(forOp); - // Replace 'arm_sme.tile_load' with the tile. - rewriter.replaceOp(tileLoadOp, tile); + // Replace 'arm_sme.tile_load' with the result. 
+ rewriter.replaceOp(tileLoadOp, forOp.getResult(0)); return success(); } @@ -150,13 +155,15 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { /// ```mlir /// %c0 = arith.constant 0 : index /// %c1 = arith.constant 1 : index -/// %tile = arm_sme.zero : vector<[4]x[4]xi32> +/// %init_tile = arm_sme.zero : vector<[4]x[4]xi32> /// %num_rows = arith.constant 2 : index /// %num_cols = vector.create_mask %c4 : vector<[4]xi1> -/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 { +/// %tile = scf.for %tile_slice_idx = %c0 to %num_rows step %c1 +/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) { /// %tile_update = arm_sme.load_tile_slice -/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx : +/// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx : /// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32> +/// scf.yield %tile_update : vector<[4]x[4]xi32> /// } /// ``` /// @@ -202,14 +209,15 @@ struct TileLoadOpWithMaskAndPadZeroConversion // Initialize tile with zero to satisfy padding. Inactive cols will be // zeroed anyway since the loads use zeroing predication. For inactive rows // however, no load will occur so these need to be zeroed. - auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>( + auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>( rewriter, loc, tileType); // Create a loop to load the active tile slices from memory. auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); auto upperBound = numRows; - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); + auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step, + ValueRange{initTile}); rewriter.setInsertionPointToStart(forOp.getBody()); @@ -217,17 +225,20 @@ struct TileLoadOpWithMaskAndPadZeroConversion // tile. 
SmallVector<Value> memrefIndices; auto tileSliceIndex = forOp.getInductionVar(); + auto currentTile = forOp.getRegionIterArg(0); getMemrefIndices(tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(), tileSliceIndex, upperBound, memrefIndices, loc, rewriter); - tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>( - rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, tile, - memrefIndices, tileSliceIndex, tileLoadOp.getLayout()); + auto loadSlice = + tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>( + rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, + currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout()); + rewriter.create<scf::YieldOp>(loc, loadSlice.getResult()); rewriter.setInsertionPointAfter(forOp); - // Replace 'arm_sme.tile_load' with the tile. - rewriter.replaceOp(tileLoadOp, tile); + // Replace 'arm_sme.tile_load' with the result. + rewriter.replaceOp(tileLoadOp, forOp.getResult(0)); return success(); } @@ -249,15 +260,18 @@ struct TileLoadOpWithMaskAndPadZeroConversion /// ```mlir /// ... /// %pad_1d = arith.constant dense<1> : vector<[4]xi32> -/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 { +/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1 +/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) { /// ... 
/// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1> /// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d /// : memref<?x?xi32>, vector<[4]xi1>, /// vector<[4]xi32> into vector<[4]xi32> /// // Insert slice into tile -/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx -/// : vector<[4]xi32> into vector<[4]x[4]xi32> +/// %tile_update = arm_sme.move_vector_to_tile_slice +/// %slice, %iter_tile, %tile_slice_idx : +/// vector<[4]xi32> into vector<[4]x[4]xi32> +/// scf.yield %tile_update : vector<[4]x[4]xi32> /// } /// ``` struct TileLoadOpWithMaskAndPadNonZeroConversion @@ -298,7 +312,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion loc, rewriter.getI32Type(), numCols); // Allocate a new SME tile. - auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>( + auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>( rewriter, loc, tileType); // Create a loop that loads each ZA tile slice from memory. @@ -310,12 +324,13 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); auto numTileSlices = rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); - auto forOp = - rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step); + auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, + step, ValueRange{initTile}); rewriter.setInsertionPointToStart(forOp.getBody()); auto tileSliceIndex = forOp.getInductionVar(); + auto currentTile = forOp.getRegionIterArg(0); // Combine masks. auto rowIsActive = rewriter.create<arith::CmpIOp>( @@ -344,14 +359,16 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion /*passthru=*/pad1DOp); // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile. 
- tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>( - rewriter, loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex, - tileLoadOp.getLayout()); + auto moveSlice = + tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>( + rewriter, loc, tileType, loadSlice->getResult(0), currentTile, + tileSliceIndex, tileLoadOp.getLayout()); + rewriter.create<scf::YieldOp>(loc, moveSlice.getResult()); rewriter.setInsertionPointAfter(forOp); - // Replace 'arm_sme.tile_load' with the tile. - rewriter.replaceOp(tileLoadOp, tile); + // Replace 'arm_sme.tile_load' with the result. + rewriter.replaceOp(tileLoadOp, forOp.getResult(0)); return success(); } diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index 4b3fd26..312e89c 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -26,21 +26,26 @@ static bool isSplatZero(Type elemType, DenseElementsAttr val) { } /// Generates a for loop over ZA tile slices where the induction variable is -/// the tile slice index. Sets the IR Builder insertion point as the loop body. -/// Callers of this method are responsible for restoring it if needed. -static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc, - Type eltType) { +/// the tile slice index and each iteration yields a new tile. Loop body is +/// built via the callback, which returns the next tile value. 
+template <typename LoopBodyCallback> +static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, + Location loc, Value initTile, + LoopBodyCallback callback) { + OpBuilder::InsertionGuard g(rewriter); auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); auto minTileSlices = rewriter.create<arith::ConstantIndexOp>( - loc, arm_sme::getSMETileSliceMinNumElts(eltType)); + loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0)); auto vscale = rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); auto numTileSlices = rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); - auto forOp = - rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step); + auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step, + ValueRange{initTile}); rewriter.setInsertionPointToStart(forOp.getBody()); + auto nextTile = callback(forOp); + rewriter.create<scf::YieldOp>(loc, nextTile.getResult()); return forOp; } @@ -242,27 +247,27 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> { // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice' // ops that broadcast the constant to each tile slice. - OpBuilder::InsertionGuard g(rewriter); auto loc = constantOp.getLoc(); - // Unpack 1-d vector type from 2-d vector type. - auto tileSliceType = - VectorType::get(tileType.getShape().drop_front(), tileElementType, - /*scalableDims=*/{true}); + // To fill a tile with a constant, we create a 1-D splat of the constant, + // then move that into each tile slice (the largest unit we can set at once, + // outside of operations like the outerproduct). 
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); auto denseAttr1D = DenseElementsAttr::get( tileSliceType, denseAttr.getSplatValue<Attribute>()); auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D); - auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); - - auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType); - auto tileSliceIndex = forOp.getInductionVar(); - - // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice. - rewriter.create<arm_sme::MoveVectorToTileSliceOp>( - loc, tileType, constantOp1D, tile, tileSliceIndex); - - rewriter.replaceOp(constantOp, tile); + auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + auto forOp = + createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) { + auto tileSliceIndex = forOp.getInductionVar(); + auto currentTile = forOp.getRegionIterArg(0); + // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile + // slice. + return rewriter.create<arm_sme::MoveVectorToTileSliceOp>( + loc, tileType, constantOp1D, currentTile, tileSliceIndex); + }); + rewriter.replaceOp(constantOp, forOp.getResult(0)); return success(); } @@ -277,9 +282,13 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> { /// is converted to: /// /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> -/// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 { -/// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile, -/// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32> +/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices +/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) +/// { +/// %tile_update = arm_sme.move_vector_to_tile_slice +/// %broadcast_to_1d, %iter_tile, %tile_slice_index : +/// vector<[4]xi32> into vector<[4]x[4]xi32> +/// scf.yield %tile_update : vector<[4]x[4]xi32> /// } /// /// Supports scalar, 0-d vector, 
and 1-d vector broadcasts. @@ -293,20 +302,16 @@ struct BroadcastOpToArmSMELowering if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) return failure(); - OpBuilder::InsertionGuard g(rewriter); auto loc = broadcastOp.getLoc(); auto srcType = broadcastOp.getSourceType(); auto srcVectorType = dyn_cast<VectorType>(srcType); - auto tileElementType = tileType.getElementType(); Value broadcastOp1D; if (srcType.isIntOrFloat() || (srcVectorType && (srcVectorType.getRank() == 0))) { // Broadcast scalar or 0-d vector to 1-d vector. - auto tileSliceType = - VectorType::get(tileType.getShape().drop_front(), tileElementType, - /*scalableDims=*/{true}); + VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); broadcastOp1D = rewriter.create<vector::BroadcastOp>( loc, tileSliceType, broadcastOp.getSource()); } else if (srcVectorType && (srcVectorType.getRank() == 1)) @@ -315,18 +320,20 @@ struct BroadcastOpToArmSMELowering else return failure(); - auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); // Create a loop over ZA tile slices. - auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType); - auto tileSliceIndex = forOp.getInductionVar(); - - // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value to each - // tile slice. - rewriter.create<arm_sme::MoveVectorToTileSliceOp>( - loc, tileType, broadcastOp1D, tile, tileSliceIndex); - - rewriter.replaceOp(broadcastOp, tile); + auto forOp = + createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) { + auto tileSliceIndex = forOp.getInductionVar(); + auto currentTile = forOp.getRegionIterArg(0); + // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value + // to each tile slice. 
+ return rewriter.create<arm_sme::MoveVectorToTileSliceOp>( + loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); + }); + + rewriter.replaceOp(broadcastOp, forOp.getResult(0)); return success(); } @@ -341,9 +348,13 @@ struct BroadcastOpToArmSMELowering /// is converted to: /// /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> -/// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 { -/// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile, -/// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32> +/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices +/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) +/// { +/// %tile_update = arm_sme.move_vector_to_tile_slice +/// %broadcast_to_1d, %iter_tile, %tile_slice_index : +/// vector<[4]xi32> into vector<[4]x[4]xi32> +/// scf.yield %tile_update : vector<[4]x[4]xi32> /// } /// /// This is identical to vector.broadcast of a scalar. @@ -356,11 +367,8 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> { if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) return failure(); - OpBuilder::InsertionGuard g(rewriter); auto loc = splatOp.getLoc(); - auto srcType = splatOp.getOperand().getType(); - auto tileElementType = tileType.getElementType(); assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat"); // Avoid unused-variable warning when building without assertions. @@ -371,17 +379,19 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> { Value broadcastOp1D = rewriter.create<vector::BroadcastOp>( loc, tileSliceType, splatOp.getInput()); - auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); // Next, create a loop over ZA tile slices and "move" the generated 1-d // vector to each slice. 
- auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType); - auto tileSliceIndex = forOp.getInductionVar(); - - rewriter.create<arm_sme::MoveVectorToTileSliceOp>( - loc, tileType, broadcastOp1D, tile, tileSliceIndex); + auto forOp = + createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) { + auto tileSliceIndex = forOp.getInductionVar(); + auto currentTile = forOp.getRegionIterArg(0); + return rewriter.create<arm_sme::MoveVectorToTileSliceOp>( + loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); + }); - rewriter.replaceOp(splatOp, tile); + rewriter.replaceOp(splatOp, forOp.getResult(0)); return success(); } @@ -424,7 +434,6 @@ struct TransposeOpToArmSMELowering if (permutation[0] != 1 || permutation[1] != 0) return failure(); - OpBuilder::InsertionGuard g(rewriter); auto loc = transposeOp.getLoc(); // Allocate buffer to store input tile to. diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir index efefc6c..5d79a04 100644 --- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir +++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir @@ -6,16 +6,17 @@ // CHECK-LABEL: func.func @arm_sme_tile_load_hor( // CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) { -// CHECK-DAG: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> +// CHECK-DAG: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale // CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index -// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { +// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[4]x[4]xi32>) { // CHECK-NEXT: %[[PTRUE_S:.*]] = 
arith.constant dense<true> : vector<[4]xi1> // CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index -// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32> +// CHECK-NEXT: %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32> +// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32> func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) { %c0 = arith.constant 0 : index %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32> @@ -40,10 +41,11 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index // CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1> -// CHECK-DAG: %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32> -// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] { +// CHECK-DAG: %[[TILE_ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32> +// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) { // CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index -// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32> +// CHECK-NEXT: %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32> +// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32> func.func 
@arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) { %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index @@ -68,7 +70,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) // CHECK-DAG: %[[NUM_COLS_I32:.*]] = arith.index_castui %[[NUM_COLS]] : index to i32 // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale // CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index -// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { +// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE]]) -> (vector<[4]x[4]xi32>) { // CHECK-NEXT: %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index // CHECK-NEXT: %[[ROW_IS_ACTIVE_SEXT_I32:.*]] = arith.extsi %[[ROW_IS_ACTIVE]] : i1 to i32 // CHECK-NEXT: %[[MASK:.*]] = arith.andi %[[ROW_IS_ACTIVE_SEXT_I32]], %[[NUM_COLS_I32]] : i32 @@ -77,7 +79,8 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) // CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index // CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32> // CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32> -// CHECK: arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32> +// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32> +// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32> func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) { %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index diff --git 
a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir index b8db105..ae2d0f4 100644 --- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir +++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir @@ -95,11 +95,12 @@ func.func @arith_constant_dense_2d_zero_f64() { // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[C16:.*]] = arith.constant 16 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8> +// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8> // CHECK: %[[VSCALE:.*]] = vector.vscale // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index -// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { -// CHECK: arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[16]xi8> into vector<[16]x[16]xi8> +// CHECK: %[[TILE:.*]] = scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[16]x[16]xi8>) { +// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[16]xi8> into vector<[16]x[16]xi8> +// CHECK: scf.yield %[[TILE_UPDATE]] : vector<[16]x[16]xi8> // CHECK: "prevent.dce"(%[[TILE]]) : (vector<[16]x[16]xi8>) -> () func.func @arith_constant_dense_2d_nonzero_i8() { %two = arith.constant dense<2> : vector<[16]x[16]xi8> @@ -114,11 +115,12 @@ func.func @arith_constant_dense_2d_nonzero_i8() { // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64> +// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64> // CHECK: %[[VSCALE:.*]] = vector.vscale // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index -// CHECK: 
scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { -// CHECK: arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[2]xf64> into vector<[2]x[2]xf64> +// CHECK: %[[TILE:.*]] = scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[2]x[2]xf64>) { +// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[2]xf64> into vector<[2]x[2]xf64> +// CHECK: scf.yield %[[TILE_UPDATE]] : vector<[2]x[2]xf64> // CHECK: "prevent.dce"(%[[TILE]]) : (vector<[2]x[2]xf64>) -> () func.func @arith_constant_dense_2d_nonzero_f64() { %two = arith.constant dense<2.0> : vector<[2]x[2]xf64> diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir index 5bc147c..6ea949d 100644 --- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir @@ -452,11 +452,12 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32> -// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> +// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> // CHECK: %[[VSCALE:.*]] = vector.vscale // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index -// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { -// CHECK: %[[C10:.*]] = arm_sme.move_vector_to_tile_slice %[[SRC_1D]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32> +// CHECK: %[[TILE:.*]] = scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[4]x[4]xi32>) { 
+// CHECK: %[[NEW_TILE:.*]] = arm_sme.move_vector_to_tile_slice %[[SRC_1D]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32> +// CHECK: scf.yield %[[NEW_TILE]] : vector<[4]x[4]xi32> // CHECK: "prevent.dce"(%[[TILE]]) : (vector<[4]x[4]xi32>) -> () func.func @broadcast_vec2d_from_i32(%arg0: i32) { %0 = vector.broadcast %arg0 : i32 to vector<[4]x[4]xi32> |