Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Bindings/Python/Globals.h                             25
-rw-r--r--  mlir/lib/Bindings/Python/Pass.cpp                              20
-rw-r--r--  mlir/lib/CAPI/IR/Pass.cpp                                      15
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp                    167
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp           34
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp   147
6 files changed, 332 insertions(+), 76 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 71a051c..1e81f53 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -17,6 +17,7 @@
#include "NanobindUtils.h"
#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
@@ -151,6 +152,29 @@ public:
TracebackLoc &getTracebackLoc() { return tracebackLoc; }
+ class TypeIDAllocator {
+ public:
+ TypeIDAllocator() : allocator(mlirTypeIDAllocatorCreate()) {}
+ ~TypeIDAllocator() {
+ if (allocator.ptr)
+ mlirTypeIDAllocatorDestroy(allocator);
+ }
+ TypeIDAllocator(const TypeIDAllocator &) = delete;
+ TypeIDAllocator(TypeIDAllocator &&other) : allocator(other.allocator) {
+ other.allocator.ptr = nullptr;
+ }
+
+ MlirTypeIDAllocator get() { return allocator; }
+ MlirTypeID allocate() {
+ return mlirTypeIDAllocatorAllocateTypeID(allocator);
+ }
+
+ private:
+ MlirTypeIDAllocator allocator;
+ };
+
+ MlirTypeID allocateTypeID() { return typeIDAllocator.allocate(); }
+
private:
static PyGlobals *instance;
@@ -173,6 +197,7 @@ private:
llvm::StringSet<> loadedDialectModules;
TracebackLoc tracebackLoc;
+ TypeIDAllocator typeIDAllocator;
};
} // namespace python
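
For context, a minimal sketch of the C API contract the new RAII wrapper manages (assuming a linked MLIR build; not part of the patch). TypeIDs stay valid only while their allocator lives, which is why the allocator now sits on the PyGlobals singleton rather than being created per call:

    #include "mlir-c/Support.h"

    int main() {
      MlirTypeIDAllocator alloc = mlirTypeIDAllocatorCreate();
      MlirTypeID a = mlirTypeIDAllocatorAllocateTypeID(alloc);
      MlirTypeID b = mlirTypeIDAllocatorAllocateTypeID(alloc);
      // Each allocation yields a distinct ID for the allocator's lifetime.
      bool distinct = !mlirTypeIDEqual(a, b);
      (void)distinct;
      // Destroying the allocator ends the validity of a and b; keeping it
      // process-global keeps pass TypeIDs alive for the Python session.
      mlirTypeIDAllocatorDestroy(alloc);
      return 0;
    }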
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index e489585..572afa9 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,6 +8,7 @@
#include "Pass.h"
+#include "Globals.h"
#include "IRModule.h"
#include "mlir-c/Pass.h"
// clang-format off
@@ -57,6 +58,13 @@ private:
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
+ // Mapping of enumerated types
+ //----------------------------------------------------------------------------
+ nb::enum_<MlirPassDisplayMode>(m, "PassDisplayMode")
+ .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST)
+ .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE);
+
+ //----------------------------------------------------------------------------
// Mapping of MlirExternalPass
//----------------------------------------------------------------------------
nb::class_<MlirExternalPass>(m, "ExternalPass")
@@ -138,6 +146,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
mlirPassManagerEnableTiming(passManager.get());
},
"Enable pass timing.")
+ .def(
+ "enable_statistics",
+ [](PyPassManager &passManager, MlirPassDisplayMode displayMode) {
+ mlirPassManagerEnableStatistics(passManager.get(), displayMode);
+ },
+ "displayMode"_a =
+ MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE,
+ "Enable pass statistics.")
.def_static(
"parse",
[](const std::string &pipeline, DefaultingPyMlirContext context) {
@@ -181,9 +197,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
name = nb::cast<std::string>(
nb::borrow<nb::str>(run.attr("__name__")));
}
- MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
- MlirTypeID passID =
- mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+ MlirTypeID passID = PyGlobals::get().allocateTypeID();
MlirExternalPassCallbacks callbacks;
callbacks.construct = [](void *obj) {
(void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
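
The allocated passID ultimately feeds mlirCreateExternalPass from mlir-c/Pass.h. A hedged skeleton of that call with no-op callbacks (illustrative names, not from the patch; construct/destruct/run are kept non-null since the C API may invoke them unconditionally):

    #include "mlir-c/Pass.h"

    // Illustrative no-op callbacks for an external pass skeleton.
    static void noopConstruct(void *userData) {}
    static void noopDestruct(void *userData) {}
    static void noopRun(MlirOperation op, MlirExternalPass pass,
                        void *userData) {}

    MlirPass makeNoopPass(MlirTypeID passID) {
      MlirExternalPassCallbacks callbacks = {};
      callbacks.construct = noopConstruct;
      callbacks.destruct = noopDestruct;
      callbacks.run = noopRun;
      return mlirCreateExternalPass(
          passID, mlirStringRefCreateFromCString("noop"),
          mlirStringRefCreateFromCString("noop"),
          mlirStringRefCreateFromCString("Illustrative no-op external pass."),
          /*opName=*/mlirStringRefCreateFromCString(""),
          /*nDependentDialects=*/0, /*dependentDialects=*/nullptr, callbacks,
          /*userData=*/nullptr);
    }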
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index b0a6ec1..72bec11 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -13,6 +13,7 @@
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/Pass/PassManager.h"
+#include "llvm/Support/ErrorHandling.h"
#include <optional>
using namespace mlir;
@@ -79,6 +80,20 @@ void mlirPassManagerEnableTiming(MlirPassManager passManager) {
unwrap(passManager)->enableTiming();
}
+void mlirPassManagerEnableStatistics(MlirPassManager passManager,
+ MlirPassDisplayMode displayMode) {
+ PassDisplayMode mode;
+ switch (displayMode) {
+ case MLIR_PASS_DISPLAY_MODE_LIST:
+ mode = PassDisplayMode::List;
+ break;
+ case MLIR_PASS_DISPLAY_MODE_PIPELINE:
+ mode = PassDisplayMode::Pipeline;
+ break;
+ }
+ unwrap(passManager)->enableStatistics(mode);
+}
+
MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
MlirStringRef operationName) {
return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
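
A hedged usage sketch of the new entry point from the C API side (pass setup elided; all other calls are existing mlir-c functions):

    #include "mlir-c/IR.h"
    #include "mlir-c/Pass.h"

    int main() {
      MlirContext ctx = mlirContextCreate();
      MlirPassManager pm = mlirPassManagerCreate(ctx);
      // New in this patch: dump pass statistics after a run, either as a
      // flat list or grouped by pipeline structure.
      mlirPassManagerEnableStatistics(pm, MLIR_PASS_DISPLAY_MODE_PIPELINE);
      // ... add passes and run on a module here ...
      mlirPassManagerDestroy(pm);
      mlirContextDestroy(ctx);
      return 0;
    }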
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e8f8824..7f419a0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -216,6 +216,18 @@ LogicalResult ConvertFloatToTF32Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF6x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f32x2 to f6x2.";
+ }
+ return success();
+}
+
LogicalResult ConvertF32x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
using SatMode = NVVM::SaturationMode;
@@ -227,41 +239,67 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
bool hasRelu = getRelu();
- switch (getType()) {
- case ConvertFP8Type::E4M3:
- case ConvertFP8Type::E5M2:
- if (!isRoundingModeRN)
- return emitOpError("Only RN rounding mode is supported for conversions "
- "from f32x2 to .e4m3x2 or .e5m2x2 types");
- if (!isSatFinite)
- return emitOpError("Only SATFINITE saturation mode is supported for "
- "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
- break;
- case ConvertFP8Type::UE8M0:
- if (!(isRoundingModeRZ || isRoundingModeRP))
- return emitOpError("Only RZ or RP rounding modes are supported for "
- "conversions from f32x2 to .ue8m0x2 type");
- if (hasRelu)
- return emitOpError("relu not supported for conversions to .ue8m0x2 type");
- break;
- }
- return success();
+ mlir::MLIRContext *ctx = getContext();
+
+ return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
+ .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
+ [&](mlir::Type) -> LogicalResult {
+ if (!isRoundingModeRN) {
+ return emitOpError("Only RN rounding mode is supported for "
+ "conversions from f32x2 to ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
+ }
+ if (!isSatFinite) {
+ return emitOpError("Only SATFINITE saturation mode is supported "
+ "for conversions "
+ "from f32x2 to ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
+ }
+ return success();
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
+ if (!(isRoundingModeRZ || isRoundingModeRP)) {
+ return emitOpError("Only RZ and RP rounding modes are supported for "
+ "conversions from f32x2 to ")
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
+ }
+ if (hasRelu) {
+ return emitOpError("relu not supported for conversions to ")
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
+ }
+ return success();
+ })
+ .Default([&](mlir::Type) {
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << ", "
+ << mlir::Float8E5M2Type::get(ctx) << ", and "
+ << mlir::Float8E8M0FNUType::get(ctx)
+ << " types are "
+ "supported for conversions from f32x2 to f8x2";
+ });
}
LogicalResult ConvertF16x2ToF8x2Op::verify() {
- if (getType() == ConvertFP8Type::UE8M0)
- return emitOpError("Only .e4m3 or .e5m2 types are supported for "
- "conversions from f16x2 to f8x2.");
+ mlir::MLIRContext *ctx = getContext();
+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f16x2 to f8x2.";
+ }
return success();
}
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
- if (getType() != ConvertFP8Type::UE8M0)
- return emitOpError(
- "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
+ if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
+ return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
+ << " type is supported for conversions from "
+ "bf16x2 to f8x2.";
auto rnd = getRnd();
if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
@@ -1980,15 +2018,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
-llvm::Intrinsic::ID
-ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
- switch (type) {
- case NVVM::ConvertFP6Type::E2M3:
- return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
- case NVVM::ConvertFP6Type::E3M2:
- return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
- }
- llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
+llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
@@ -2000,41 +2042,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
: llvm::Intrinsic::nvvm_ff_to_##type##_rn
llvm::Intrinsic::ID
-ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
- NVVM::FPRoundingMode rnd,
+ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat, bool hasRelu) {
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
- switch (type) {
- case NVVM::ConvertFP8Type::E4M3:
- return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
- case NVVM::ConvertFP8Type::E5M2:
- return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
- case NVVM::ConvertFP8Type::UE8M0:
- if (hasRoundingModeRZ)
- return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
- else if (hasRoundingModeRP)
- return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
- }
- llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
+ if (hasRoundingModeRZ)
+ return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
+ else if (hasRoundingModeRP)
+ return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
+
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
-llvm::Intrinsic::ID
-ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
- switch (type) {
- case NVVM::ConvertFP8Type::E4M3:
- return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
- case NVVM::ConvertFP8Type::E5M2:
- return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
- default:
- llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
- }
+llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
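
The rewritten verifiers and intrinsic selectors above all lean on the same llvm::TypeSwitch idiom: each .Case fires on a successful isa<> test against the destination type, and .Default handles everything else. A minimal self-contained sketch of the pattern (the classify function is illustrative, not from the patch):

    #include "llvm/ADT/TypeSwitch.h"
    #include "mlir/IR/BuiltinTypes.h"

    // Tags two FP8 types; returns -1 for anything else, mirroring how the
    // verifiers above reject unsupported destination types.
    static int classify(mlir::Type ty) {
      return llvm::TypeSwitch<mlir::Type, int>(ty)
          .Case<mlir::Float8E4M3FNType>(
              [](mlir::Float8E4M3FNType) { return 0; })
          .Case<mlir::Float8E5M2Type>([](mlir::Float8E5M2Type) { return 1; })
          .Default([](mlir::Type) { return -1; });
    }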
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4919d9a..9d62491 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -524,6 +524,40 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
if (!mask) {
LDBG() << "No mask required";
+ if (assumeDynamicDimsMatchVecSizes) {
+ llvm::TypeSwitch<Operation *>(opToMask)
+ .Case<vector::TransferReadOp, vector::TransferWriteOp>(
+ [&](auto xferOp) {
+ // For vector.transfer_read and vector.transfer_write, there is
+ // also the `in-bounds` attribute that has to be set explicitly
+ // to true. Otherwise, "out-of-bounds" access will be assumed
+ // and masks will be generated while lowering these.
+ LDBG() << "Assuming dynamic dimensions match vector sizes and "
+ "setting their in-bounds to true!";
+ SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
+ ShapedType xferType = xferOp.getShapedType();
+ AffineMap permMap = xferOp.getPermutationMap();
+ // Only set the in-bounds values to true for dynamic dims.
+ // Different mechanisms will set these accordingly for the
+ // static dims.
+ for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
+ auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
+ // Skip broadcast dimensions.
+ if (!dimExpr)
+ continue;
+ unsigned pos = dimExpr.getPosition();
+ if (xferType.isDynamicDim(pos))
+ inBoundsMap[i] = true;
+ }
+ rewriter.modifyOpInPlace(xferOp, [&]() {
+ xferOp.setInBoundsAttr(
+ rewriter.getBoolArrayAttr(inBoundsMap));
+ });
+ })
+ .Default([](Operation *op) {
+ // No-op if the operation is not an xfer read or write.
+ });
+ }
return opToMask;
}
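
The broadcast check above rests on how transfer permutation maps encode dimensions: a result that is an AffineDimExpr addresses a (possibly dynamic) memref dimension, while a constant-0 result encodes a broadcast with no in-bounds question to answer. A small sketch of that distinction (the map here is illustrative):

    #include "mlir/IR/AffineExpr.h"
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"

    int main() {
      mlir::MLIRContext ctx;
      // (d0, d1) -> (d0, 0): the second result is a broadcast, not a dim.
      mlir::AffineMap map = mlir::AffineMap::get(
          /*dimCount=*/2, /*symbolCount=*/0,
          {mlir::getAffineDimExpr(0, &ctx),
           mlir::getAffineConstantExpr(0, &ctx)},
          &ctx);
      for (mlir::AffineExpr result : map.getResults()) {
        if (auto dim = mlir::dyn_cast<mlir::AffineDimExpr>(result)) {
          // Addresses memref dimension dim.getPosition(); this is the case
          // eligible for the in-bounds = true rewrite when it is dynamic.
          (void)dim;
          continue;
        }
        // Broadcast dimension: skipped, exactly as in the patch above.
      }
      return 0;
    }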
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 784e5d6..c28d2fc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
auto vecType = dyn_cast<VectorType>(op.getType());
- if (!vecAttr || !vecAttr.isSplat() || !vecType)
+ if (!vecAttr || !vecType)
return failure();
xegpu::DistributeLayoutAttr layout =
@@ -733,22 +733,139 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
int count;
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
- // Current limitation: constant of vector with single value.
- // TODO: support more complex cases, e.g., vector with multiple values.
- Attribute singleVal = vecAttr.getSplatValue<Attribute>();
-
auto newType = VectorType::get(sgShape, vecType.getElementType());
- auto sgAttr = DenseElementsAttr::get(newType, singleVal);
- auto cstOp =
- arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
- layout.dropSgLayoutAndData());
- SmallVector<Value> newConsts(count, cstOp);
+ Location loc = op.getLoc();
+ auto eltType = vecType.getElementType();
- rewriter.replaceOpWithMultiple(op, {newConsts});
- return success();
+ auto setLayoutIfNeeded = [&](Value val) {
+ if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+ !layout.getEffectiveInstDataAsInt().empty()) {
+ xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+ layout.dropSgLayoutAndData());
+ }
+ };
+
+ if (vecAttr.isSplat()) {
+ // Splat: single value for all subgroups
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+ auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+ auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+ setLayoutIfNeeded(cstOp->getResult(0));
+ rewriter.replaceOp(op, cstOp);
+ return success();
+ } else if (sgShape == wgShape) { // if the entire vector is shared by all
+ // subgroups, don't distribute
+ auto newConstOp =
+ arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
+ setLayoutIfNeeded(newConstOp->getResult(0));
+ rewriter.replaceOp(op, newConstOp);
+ return success();
+ } else {
+ // Non-splat constant
+ // Only supports 1D & 2D
+ // TODO: support other cases that require SLM access
+ if (!eltType.isIndex())
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported element type for non-splat constant op.");
+
+ if (wgShape.size() > 2)
+ return rewriter.notifyMatchFailure(
+ op, "Only 1D & 2D vector constant supported");
+
+ SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
+ int64_t rowStride = 0, colStride = 0;
+ int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
+ int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
+
+ // Compute colStride and rowStride, and check for constant strides.
+ if (cols > 1) {
+ colStride = cast<IntegerAttr>(values[1]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ }
+ if (rows > 1) {
+ rowStride = cast<IntegerAttr>(values[cols]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ }
+
+ for (int64_t r = 0; r < rows; ++r) {
+ for (int64_t c = 0; c < cols; ++c) {
+ int64_t idx = r * cols + c;
+ // Check column stride
+ if (c > 0 && cols > 1) {
+ int64_t prevIdx = r * cols + (c - 1);
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+ cast<IntegerAttr>(values[prevIdx]).getInt();
+ if (diff != colStride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant column stride in constant op.");
+ }
+ // Check row stride
+ if (r > 0 && rows > 1) {
+ int64_t prevIdx = (r - 1) * cols + c;
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+ cast<IntegerAttr>(values[prevIdx]).getInt();
+ if (diff != rowStride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant row stride in constant op.");
+ }
+ }
+ }
+
+ // Create a constant for the base tile.
+ // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
+ // For 1D case, extract the first sgShape[0] elements.
+ SmallVector<Attribute> baseTileValues;
+ int baseTileCols = sgShape[sgShape.size() - 1];
+ int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
+ for (int64_t r = 0; r < baseTileRows; ++r) {
+ for (int64_t c = 0; c < baseTileCols; ++c) {
+ baseTileValues.push_back(values[r * cols + c]);
+ }
+ }
+
+ auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
+ baseTileValues);
+ auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+
+ // Get subgroup id
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+ auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ if (failed(sgOffsets))
+ return failure();
+
+ SmallVector<Value, 2> strideConsts;
+ strideConsts.push_back(
+ rewriter.create<arith::ConstantIndexOp>(loc, colStride));
+ if (rows > 1)
+ strideConsts.insert(
+ strideConsts.begin(),
+ rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
+
+ SmallVector<Value> newConstOps;
+ for (auto offsets : *sgOffsets) {
+ // Multiply offset with stride, broadcast it and add to baseConstVec
+ Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ for (size_t i = 0; i < strideConsts.size(); ++i) {
+ Value mul = rewriter.create<arith::MulIOp>(
+ loc, rewriter.getIndexType(), offsets[i], strideConsts[i]);
+ mulOffset = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), mulOffset, mul);
+ }
+ // Broadcast to baseConstVec size
+ auto bcastOffset = rewriter.create<vector::BroadcastOp>(
+ loc, baseConstVec.getType(), mulOffset);
+ auto finalConst =
+ arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
+ setLayoutIfNeeded(baseConstVec);
+ setLayoutIfNeeded(bcastOffset);
+ setLayoutIfNeeded(finalConst);
+ newConstOps.push_back(finalConst);
+ }
+ rewriter.replaceOpWithMultiple(op, {newConstOps});
+ return success();
+ }
}
};
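
A worked example of the decomposition this pattern performs, under assumed shapes: wgShape = [2, 4] with row-major values 0..7 (so rowStride = 4, colStride = 1) and sgShape = [1, 2]. Each subgroup reconstructs its slice as the shared base tile plus its offset dotted with the strides, which is what the broadcast-and-add above computes. Plain C++ sketch, independent of MLIR:

    #include <cstdio>

    int main() {
      // Assumed wg constant: [[0,1,2,3],[4,5,6,7]]; base tile is 1x2.
      const int rowStride = 4, colStride = 1;
      const int baseTile[2] = {0, 1}; // top-left sgShape[0] x sgShape[1] slice
      // Per-subgroup element offsets into the wg vector.
      const int offsets[4][2] = {{0, 0}, {0, 2}, {1, 0}, {1, 2}};
      for (const auto &off : offsets) {
        int shift = off[0] * rowStride + off[1] * colStride; // scalar offset
        // Broadcast-add reproduces this subgroup's slice of the wg constant.
        std::printf("sg@(%d,%d): [%d, %d]\n", off[0], off[1],
                    baseTile[0] + shift, baseTile[1] + shift);
      }
      return 0;
    }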