diff options
author | James Newling <james.newling@gmail.com> | 2025-07-23 13:18:09 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-23 10:18:09 -0700 |
commit | e67f3237d6242d1c362fa52e782ddfd5ae54a8af (patch) | |
tree | 09bdcc4db2f25b972fb3f9cd37ab179161bacf67 | |
parent | 8ef0c50ecac8f1e707c02bee855f43eda114f8db (diff) | |
download | llvm-e67f3237d6242d1c362fa52e782ddfd5ae54a8af.zip llvm-e67f3237d6242d1c362fa52e782ddfd5ae54a8af.tar.gz llvm-e67f3237d6242d1c362fa52e782ddfd5ae54a8af.tar.bz2 |
[mlir][armsme][vector] Replace splat with broadcast (#148024)
Part of deprecation of vector.splat
RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
4 files changed, 20 insertions, 64 deletions
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 9bc3fa3..8a2e3b63 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -607,7 +607,8 @@ struct InsertTileSliceConversion rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), /*scalableDims=*/{true}); - auto allActiveMask = vector::SplatOp::create(rewriter, loc, predTy, one); + auto allActiveMask = + vector::BroadcastOp::create(rewriter, loc, predTy, one); // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice. switch (insertTileSliceOp.getLayout()) { diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 9a37b30..e28d5122 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -327,7 +327,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion // Splat pad into 1-D vector matching type of tile slice. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - auto pad1DOp = vector::SplatOp::create(rewriter, loc, tileSliceType, padOp); + auto pad1DOp = + vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp); auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType, tileLoadOp.getBase(), diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index 125ea1e..9efa34a 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering } }; -/// Conversion pattern for vector.splat. -/// -/// Example: -/// -/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32> -/// -/// is converted to: -/// -/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> -/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices -/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) -/// { -/// %tile_update = arm_sme.insert_tile_slice -/// %broadcast_to_1d, %iter_tile[%tile_slice_index] : -/// vector<[4]xi32> into vector<[4]x[4]xi32> -/// scf.yield %tile_update : vector<[4]x[4]xi32> -/// } -/// -/// This is identical to vector.broadcast of a scalar. -struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> { - using OpRewritePattern<vector::SplatOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::SplatOp splatOp, - PatternRewriter &rewriter) const final { - auto tileType = splatOp.getResult().getType(); - if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) - return failure(); - - auto loc = splatOp.getLoc(); - auto srcType = splatOp.getOperand().getType(); - - assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat"); - // Avoid unused-variable warning when building without assertions. - (void)srcType; - - // First, broadcast the scalar to a 1-d vector. - VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - Value broadcastOp1D = vector::BroadcastOp::create( - rewriter, loc, tileSliceType, splatOp.getInput()); - - auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); - - auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, - Value currentTile) { - auto nextTile = arm_sme::InsertTileSliceOp::create( - b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); - return nextTile.getResult(); - }; - - // Next, create a loop over ZA tile slices and "move" the generated 1-d - // vector to each slice. - auto forOp = - createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); - - rewriter.replaceOp(splatOp, forOp.getResult(0)); - - return success(); - } -}; - /// Conversion pattern for vector.transpose. /// /// Stores the input tile to memory and reloads vertically. @@ -791,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering } }; +// Convert all `vector.splat` to `vector.broadcast`. There is a path from +// `vector.broadcast` to ArmSME via another pattern. +struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> { + using OpRewritePattern<vector::SplatOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::SplatOp splatOp, + PatternRewriter &rewriter) const final { + + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(), + splatOp.getInput()); + return success(); + } +}; + } // namespace void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { - patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering, + patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast, TransferReadToArmSMELowering, TransferWriteToArmSMELowering, TransposeOpToArmSMELowering, VectorLoadToArmSMELowering, VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering, diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir index 4ae710a..6f2766d 100644 --- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir +++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir @@ -87,7 +87,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) // CHECK-NEXT: %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index // CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1> // CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index -// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32> +// CHECK: %[[PAD_1D:.*]] = vector.broadcast %[[PAD]] : i32 to vector<[4]xi32> // CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32> // CHECK: %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32> // CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32> |