about summary refs log tree commit diff
path: root/mlir/lib/Conversion
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp4
-rw-r--r--mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp2
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp5
-rw-r--r--mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp4
-rw-r--r--mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp7
-rw-r--r--mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp22
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp15
-rw-r--r--mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp4
-rw-r--r--mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp23
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp5
-rw-r--r--mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp146
11 files changed, 166 insertions, 71 deletions
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index dcbaa56..247dba1 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) {
current = op.getSource();
return false;
})
- .Case<vector::SplatOp>([&current](auto op) {
- current = op.getInput();
- return false;
- })
.Default([](Operation *) { return false; });
if (!skipOp) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index bad53c0..1002ebe 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// AFTER:
/// ```mlir
/// ...
-/// %pad_1d = vector.splat %pad : vector<[4]xi32>
+/// %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// ...
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 2b7bdc9..11f866c 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
#include <cstdint>
#include <numeric>
@@ -110,9 +111,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
{TypeAttr::get(memrefType.getElementType())}));
IndexType indexType = builder.getIndexType();
- int64_t numElements = std::accumulate(memrefType.getShape().begin(),
- memrefType.getShape().end(), int64_t{1},
- std::multiplies<int64_t>());
+ int64_t numElements = llvm::product_of(memrefType.getShape());
emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
builder, loc, indexType, builder.getIndexAttr(numElements));
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 802691c..9bf9ca3 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
#include <numeric>
@@ -70,8 +71,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
// Calculate the product of all elements in 'newShape' except for the -1
// placeholder, which we discard by negating the result.
- int64_t totalSizeNoPlaceholder = -std::accumulate(
- newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
+ int64_t totalSizeNoPlaceholder = -llvm::product_of(newShape);
// If there is a 0 component in 'newShape', resolve the placeholder as
// 0.
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index 79c2f23..245a3ef 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -20,6 +20,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include <numeric>
@@ -265,8 +266,7 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
if (isPacked)
src = collapseLastDim(rewriter, src);
int64_t rows = vecShape[0];
- int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t cols = llvm::product_of(vecShape.drop_front());
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
@@ -336,8 +336,7 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
ArrayRef<int64_t> shape = vecTy.getShape();
int64_t rows = shape[0];
- int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t cols = llvm::product_of(shape.drop_front());
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 363685a..778c616 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering
}
};
-// Convert all `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to ArmSME via another pattern.
-struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
- using Base::Base;
-
- LogicalResult matchAndRewrite(vector::SplatOp splatOp,
- PatternRewriter &rewriter) const final {
-
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
- splatOp.getInput());
- return success();
- }
-};
-
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
- TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
- TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
- VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+ patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering,
+ TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
+ VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+ VectorOuterProductToArmSMELowering,
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
ExtractFromCreateMaskToPselLowering>(&ctx);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5461646..5355909 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -2161,19 +2161,6 @@ public:
}
};
-/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
-/// `vector.broadcast` through other patterns.
-struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
- LogicalResult
- matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
- adaptor.getInput());
- return success();
- }
-};
-
} // namespace
void mlir::vector::populateVectorRankReducingFMAPattern(
@@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
- VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
+ VectorBroadcastScalarToLowRankLowering,
VectorBroadcastScalarToNdLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index c45c45e..c9eba69 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOSCF
@@ -760,8 +761,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() != 1) {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
// non-constant value.
- auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t flatLength = llvm::product_of(shape);
auto flatVectorType =
VectorType::get({flatLength}, vectorType.getElementType());
value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value);
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 311ff6f..56e8fee 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -22,7 +22,6 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
}
};
-// Convert `vector.splat` to `vector.broadcast`. There is a path from
-// `vector.broadcast` to SPIRV via other patterns.
-struct VectorSplatToBroadcast final
- : public OpConversionPattern<vector::SplatOp> {
- using Base::Base;
- LogicalResult
- matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
- adaptor.getInput());
- return success();
- }
-};
-
struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using Base::Base;
@@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns(
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
- VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
- VectorShuffleOpConvert, VectorInterleaveOpConvert,
- VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
- VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+ VectorScalarBroadcastPattern, VectorLoadOpConverter,
+ VectorStoreOpConverter, VectorStepOpConvert>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9ead1d8..71687b1 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -774,9 +775,7 @@ struct ConvertXeGPUToXeVMPass
if (rank < 1 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
- int64_t sum =
- std::accumulate(type.getShape().begin(), type.getShape().end(),
- int64_t{1}, std::multiplies<int64_t>());
+ int64_t sum = llvm::product_of(type.getShape());
return VectorType::get(sum, elemType);
});
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index f449d90..f276984 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -715,6 +715,135 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
};
//===----------------------------------------------------------------------===//
+// GPU index id operations
+//===----------------------------------------------------------------------===//
+/*
+// Launch Config ops
+// dimidx - x, y, z - is fixed to i32
+// return type is set by XeVM type converter
+// get_local_id
+xevm::WorkitemIdXOp;
+xevm::WorkitemIdYOp;
+xevm::WorkitemIdZOp;
+// get_local_size
+xevm::WorkgroupDimXOp;
+xevm::WorkgroupDimYOp;
+xevm::WorkgroupDimZOp;
+// get_group_id
+xevm::WorkgroupIdXOp;
+xevm::WorkgroupIdYOp;
+xevm::WorkgroupIdZOp;
+// get_num_groups
+xevm::GridDimXOp;
+xevm::GridDimYOp;
+xevm::GridDimZOp;
+// get_global_id : to be added if needed
+*/
+
+// Helpers to get the OpenCL function name and dimension argument for each op.
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
+ return {"get_local_id", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
+ return {"get_local_id", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
+ return {"get_local_id", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
+ return {"get_local_size", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
+ return {"get_local_size", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
+ return {"get_local_size", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
+ return {"get_group_id", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
+ return {"get_group_id", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
+ return {"get_group_id", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
+ return {"get_num_groups", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
+ return {"get_num_groups", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
+ return {"get_num_groups", 2};
+}
+/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
+/// a constant argument for the dimension - x, y or z.
+template <typename OpType>
+class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ auto [baseName, dim] = getConfig(op);
+ Type dimTy = rewriter.getI32Type();
+ Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
+ static_cast<int64_t>(dim));
+ std::string func = mangle(baseName, {dimTy}, {true});
+ Type resTy = op.getType();
+ auto call =
+ createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
+ noUnwindWillReturnAttrs, op.getOperation());
+ constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/noModRef,
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ call.setMemoryEffectsAttr(memAttr);
+ rewriter.replaceOp(op, call);
+ return success();
+ }
+};
+
+/*
+// Subgroup ops
+// get_sub_group_local_id
+xevm::LaneIdOp;
+// get_sub_group_id
+xevm::SubgroupIdOp;
+// get_sub_group_size
+xevm::SubgroupSizeOp;
+// get_num_sub_groups : to be added if needed
+*/
+
+// Helpers to get the OpenCL function name for each op.
+static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
+static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
+static StringRef getConfig(xevm::SubgroupSizeOp) {
+ return "get_sub_group_size";
+}
+template <typename OpType>
+class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ std::string func = mangle(getConfig(op).str(), {});
+ Type resTy = op.getType();
+ auto call =
+ createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
+ noUnwindWillReturnAttrs, op.getOperation());
+ constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/noModRef,
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ call.setMemoryEffectsAttr(memAttr);
+ rewriter.replaceOp(op, call);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -775,7 +904,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
BlockLoadStore1DToOCLPattern<BlockLoadOp>,
- BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
+ BlockLoadStore1DToOCLPattern<BlockStoreOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
+ LaunchConfigOpToOCLPattern<GridDimXOp>,
+ LaunchConfigOpToOCLPattern<GridDimYOp>,
+ LaunchConfigOpToOCLPattern<GridDimZOp>,
+ SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
+ SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
+ SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
patterns.getContext());
}