diff options
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r-- | mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 4 | ||||
-rw-r--r-- | mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 5 | ||||
-rw-r--r-- | mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp | 4 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp | 7 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 22 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 15 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 4 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 23 | ||||
-rw-r--r-- | mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 5 | ||||
-rw-r--r-- | mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 146 |
11 files changed, 166 insertions, 71 deletions
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index dcbaa56..247dba1 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) { current = op.getSource(); return false; }) - .Case<vector::SplatOp>([&current](auto op) { - current = op.getInput(); - return false; - }) .Default([](Operation *) { return false; }); if (!skipOp) { diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index bad53c0..1002ebe 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { /// AFTER: /// ```mlir /// ... -/// %pad_1d = vector.splat %pad : vector<[4]xi32> +/// %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32> /// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) { /// ... 
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 2b7bdc9..11f866c 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" #include <cstdint> #include <numeric> @@ -110,9 +111,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, {TypeAttr::get(memrefType.getElementType())})); IndexType indexType = builder.getIndexType(); - int64_t numElements = std::accumulate(memrefType.getShape().begin(), - memrefType.getShape().end(), int64_t{1}, - std::multiplies<int64_t>()); + int64_t numElements = llvm::product_of(memrefType.getShape()); emitc::ConstantOp numElementsValue = emitc::ConstantOp::create( builder, loc, indexType, builder.getIndexAttr(numElements)); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 802691c..9bf9ca3 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" #include <numeric> @@ -70,8 +71,7 @@ TensorType inferReshapeExpandedType(TensorType inputType, // Calculate the product of all elements in 'newShape' except for the -1 // placeholder, which we discard by negating the result. - int64_t totalSizeNoPlaceholder = -std::accumulate( - newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>()); + int64_t totalSizeNoPlaceholder = -llvm::product_of(newShape); // If there is a 0 component in 'newShape', resolve the placeholder as // 0. 
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp index 79c2f23..245a3ef 100644 --- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp +++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp @@ -20,6 +20,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/DebugLog.h" #include <numeric> @@ -265,8 +266,7 @@ loadStoreFromTransfer(PatternRewriter &rewriter, if (isPacked) src = collapseLastDim(rewriter, src); int64_t rows = vecShape[0]; - int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1, - std::multiplies<int64_t>()); + int64_t cols = llvm::product_of(vecShape.drop_front()); auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); @@ -336,8 +336,7 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter, ArrayRef<int64_t> shape = vecTy.getShape(); int64_t rows = shape[0]; - int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1, - std::multiplies<int64_t>()); + int64_t cols = llvm::product_of(shape.drop_front()); auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); return amx::TileLoadOp::create(rewriter, loc, tileType, buf, diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index 363685a..778c616 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering } }; -// Convert all `vector.splat` to `vector.broadcast`. There is a path from -// `vector.broadcast` to ArmSME via another pattern. 
-struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> { - using Base::Base; - - LogicalResult matchAndRewrite(vector::SplatOp splatOp, - PatternRewriter &rewriter) const final { - - rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(), - splatOp.getInput()); - return success(); - } -}; - } // namespace void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { - patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast, - TransferReadToArmSMELowering, TransferWriteToArmSMELowering, - TransposeOpToArmSMELowering, VectorLoadToArmSMELowering, - VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering, + patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering, + TransferWriteToArmSMELowering, TransposeOpToArmSMELowering, + VectorLoadToArmSMELowering, VectorStoreToArmSMELowering, + VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering, VectorInsertToArmSMELowering, VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice, ExtractFromCreateMaskToPselLowering>(&ctx); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5461646..5355909 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -2161,19 +2161,6 @@ public: } }; -/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from -/// `vector.broadcast` through other patterns. 
-struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult - matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(), - adaptor.getInput()); - return success(); - } -}; - } // namespace void mlir::vector::populateVectorRankReducingFMAPattern( @@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorInsertOpConversion, VectorPrintOpConversion, VectorTypeCastOpConversion, VectorScaleOpConversion, VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, - VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering, + VectorBroadcastScalarToLowRankLowering, VectorBroadcastScalarToNdLowering, VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, MaskedReductionOpConversion, VectorInterleaveOpLowering, diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index c45c45e..c9eba69 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -26,6 +26,7 @@ #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" namespace mlir { #define GEN_PASS_DEF_CONVERTVECTORTOSCF @@ -760,8 +761,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { if (vectorType.getRank() != 1) { // Flatten n-D vectors to 1D. This is done to allow indexing with a // non-constant value. 
- auto flatLength = std::accumulate(shape.begin(), shape.end(), 1, - std::multiplies<int64_t>()); + int64_t flatLength = llvm::product_of(shape); auto flatVectorType = VectorType::get({flatLength}, vectorType.getElementType()); value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 311ff6f..56e8fee 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -22,7 +22,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" -#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" @@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> { } }; -// Convert `vector.splat` to `vector.broadcast`. There is a path from -// `vector.broadcast` to SPIRV via other patterns. 
-struct VectorSplatToBroadcast final - : public OpConversionPattern<vector::SplatOp> { - using Base::Base; - LogicalResult - matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(), - adaptor.getInput()); - return success(); - } -}; - struct VectorBitcastConvert final : public OpConversionPattern<vector::BitCastOp> { using Base::Base; @@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns( VectorReductionPattern<CL_INT_MAX_MIN_OPS>, VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, - VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert, - VectorShuffleOpConvert, VectorInterleaveOpConvert, - VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern, - VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>( + VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, + VectorInterleaveOpConvert, VectorDeinterleaveOpConvert, + VectorScalarBroadcastPattern, VectorLoadOpConverter, + VectorStoreOpConverter, VectorStepOpConvert>( typeConverter, patterns.getContext(), PatternBenefit(1)); // Make sure that the more specialized dot product pattern has higher benefit diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 9ead1d8..71687b1 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/BuiltinTypes.h" @@ -774,9 +775,7 @@ struct ConvertXeGPUToXeVMPass if (rank < 1 || type.getNumElements() == 1) return elemType; // Otherwise, convert the vector to a flat vector type. 
- int64_t sum = - std::accumulate(type.getShape().begin(), type.getShape().end(), - int64_t{1}, std::multiplies<int64_t>()); + int64_t sum = llvm::product_of(type.getShape()); return VectorType::get(sum, elemType); }); typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type { diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index f449d90..f276984 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -715,6 +715,135 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> { }; //===----------------------------------------------------------------------===// +// GPU index id operations +//===----------------------------------------------------------------------===// +/* +// Launch Config ops +// dimidx - x, y, z - is fixed to i32 +// return type is set by XeVM type converter +// get_local_id +xevm::WorkitemIdXOp; +xevm::WorkitemIdYOp; +xevm::WorkitemIdZOp; +// get_local_size +xevm::WorkgroupDimXOp; +xevm::WorkgroupDimYOp; +xevm::WorkgroupDimZOp; +// get_group_id +xevm::WorkgroupIdXOp; +xevm::WorkgroupIdYOp; +xevm::WorkgroupIdZOp; +// get_num_groups +xevm::GridDimXOp; +xevm::GridDimYOp; +xevm::GridDimZOp; +// get_global_id : to be added if needed +*/ + +// Helpers to get the OpenCL function name and dimension argument for each op. 
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) { + return {"get_local_id", 0}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) { + return {"get_local_id", 1}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) { + return {"get_local_id", 2}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) { + return {"get_local_size", 0}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) { + return {"get_local_size", 1}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) { + return {"get_local_size", 2}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) { + return {"get_group_id", 0}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) { + return {"get_group_id", 1}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) { + return {"get_group_id", 2}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) { + return {"get_num_groups", 0}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) { + return {"get_num_groups", 1}; +} +static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) { + return {"get_num_groups", 2}; +} +/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with +/// a constant argument for the dimension - x, y or z. 
+template <typename OpType> +class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto [baseName, dim] = getConfig(op); + Type dimTy = rewriter.getI32Type(); + Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy, + static_cast<int64_t>(dim)); + std::string func = mangle(baseName, {dimTy}, {true}); + Type resTy = op.getType(); + auto call = + createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {}, + noUnwindWillReturnAttrs, op.getOperation()); + constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; + auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( + /*other=*/noModRef, + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + call.setMemoryEffectsAttr(memAttr); + rewriter.replaceOp(op, call); + return success(); + } +}; + +/* +// Subgroup ops +// get_sub_group_local_id +xevm::LaneIdOp; +// get_sub_group_id +xevm::SubgroupIdOp; +// get_sub_group_size +xevm::SubgroupSizeOp; +// get_num_sub_groups : to be added if needed +*/ + +// Helpers to get the OpenCL function name for each op. 
+static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; } +static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; } +static StringRef getConfig(xevm::SubgroupSizeOp) { + return "get_sub_group_size"; +} +template <typename OpType> +class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + std::string func = mangle(getConfig(op).str(), {}); + Type resTy = op.getType(); + auto call = + createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {}, + noUnwindWillReturnAttrs, op.getOperation()); + constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; + auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( + /*other=*/noModRef, + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + call.setMemoryEffectsAttr(memAttr); + rewriter.replaceOp(op, call); + return success(); + } +}; + +//===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -775,7 +904,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target, LLVMLoadStoreToOCLPattern<LLVM::LoadOp>, LLVMLoadStoreToOCLPattern<LLVM::StoreOp>, BlockLoadStore1DToOCLPattern<BlockLoadOp>, - BlockLoadStore1DToOCLPattern<BlockStoreOp>>( + BlockLoadStore1DToOCLPattern<BlockStoreOp>, + LaunchConfigOpToOCLPattern<WorkitemIdXOp>, + LaunchConfigOpToOCLPattern<WorkitemIdYOp>, + LaunchConfigOpToOCLPattern<WorkitemIdZOp>, + LaunchConfigOpToOCLPattern<WorkgroupDimXOp>, + LaunchConfigOpToOCLPattern<WorkgroupDimYOp>, + LaunchConfigOpToOCLPattern<WorkgroupDimZOp>, + LaunchConfigOpToOCLPattern<WorkgroupIdXOp>, + LaunchConfigOpToOCLPattern<WorkgroupIdYOp>, + LaunchConfigOpToOCLPattern<WorkgroupIdZOp>, + LaunchConfigOpToOCLPattern<GridDimXOp>, + 
LaunchConfigOpToOCLPattern<GridDimYOp>, + LaunchConfigOpToOCLPattern<GridDimZOp>, + SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>, + SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>, + SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>( patterns.getContext()); } |