Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 74
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 167
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 64
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 34
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp | 28
-rw-r--r--  mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 15
-rw-r--r--  mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp | 218
-rw-r--r--  mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp | 1
-rw-r--r--  mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp | 7
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 111
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 120
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 147
12 files changed, 735 insertions, 251 deletions
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 5c8564b..4754f0b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -974,10 +974,10 @@ LogicalResult emitc::YieldOp::verify() {
Value result = getResult();
Operation *containingOp = getOperation()->getParentOp();
- if (result && containingOp->getNumResults() != 1)
+ if (!isa<DoOp>(containingOp) && result && containingOp->getNumResults() != 1)
return emitOpError() << "yields a value not returned by parent";
- if (!result && containingOp->getNumResults() != 0)
+ if (!isa<DoOp>(containingOp) && !result && containingOp->getNumResults() != 0)
return emitOpError() << "does not yield a value to be returned by parent";
return success();
@@ -1562,6 +1562,76 @@ LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
//===----------------------------------------------------------------------===//
+// DoOp
+//===----------------------------------------------------------------------===//
+
+void DoOp::print(OpAsmPrinter &p) {
+ p << ' ';
+ p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
+ p << " while ";
+ p.printRegion(getConditionRegion());
+ p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
+}
+
+LogicalResult emitc::DoOp::verify() {
+ Block &condBlock = getConditionRegion().front();
+
+ if (condBlock.getOperations().size() != 2)
+ return emitOpError(
+ "condition region must contain exactly two operations: "
+ "'emitc.expression' followed by 'emitc.yield', but found ")
+ << condBlock.getOperations().size() << " operations";
+
+ Operation &first = condBlock.front();
+ auto exprOp = dyn_cast<emitc::ExpressionOp>(first);
+ if (!exprOp)
+ return emitOpError("expected first op in condition region to be "
+ "'emitc.expression', but got ")
+ << first.getName();
+
+ if (!exprOp.getResult().getType().isInteger(1))
+ return emitOpError("emitc.expression in condition region must return "
+ "'i1', but returns ")
+ << exprOp.getResult().getType();
+
+ Operation &last = condBlock.back();
+ auto condYield = dyn_cast<emitc::YieldOp>(last);
+ if (!condYield)
+ return emitOpError("expected last op in condition region to be "
+ "'emitc.yield', but got ")
+ << last.getName();
+
+ if (condYield.getNumOperands() != 1)
+ return emitOpError("expected condition region to return 1 value, but "
+ "it returns ")
+ << condYield.getNumOperands() << " values";
+
+ if (condYield.getOperand(0) != exprOp.getResult())
+ return emitError("'emitc.yield' must return result of "
+ "'emitc.expression' from this condition region");
+
+ Block &bodyBlock = getBodyRegion().front();
+ if (bodyBlock.mightHaveTerminator())
+ return emitOpError("body region must not contain terminator");
+
+ return success();
+}
+
+ParseResult DoOp::parse(OpAsmParser &parser, OperationState &result) {
+ Region *bodyRegion = result.addRegion();
+ Region *condRegion = result.addRegion();
+
+ if (parser.parseRegion(*bodyRegion) || parser.parseKeyword("while") ||
+ parser.parseRegion(*condRegion))
+ return failure();
+
+ if (bodyRegion->empty())
+ bodyRegion->emplaceBlock();
+
+ return parser.parseOptionalAttrDictWithKeyword(result.attributes);
+}
+
+//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
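
Aside (not part of the patch): the custom assembly above prints the body region, the keyword `while`, the condition region, and an optional keyword attr-dict, while the verifier pins the condition region down to exactly an emitc.expression returning i1 followed by an emitc.yield of that expression's result, with a terminator-free body. A condensed restatement of that structural check, as a hypothetical free function using the same accessors as the verifier:

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/Block.h"

// Hypothetical helper (not in the patch): true iff `cond` has the exact shape
// DoOp::verify() accepts for its condition region.
static bool isValidDoConditionBlock(mlir::Block &cond) {
  if (cond.getOperations().size() != 2)
    return false;
  auto expr = llvm::dyn_cast<mlir::emitc::ExpressionOp>(cond.front());
  auto yield = llvm::dyn_cast<mlir::emitc::YieldOp>(cond.back());
  return expr && yield && expr.getResult().getType().isInteger(1) &&
         yield.getNumOperands() == 1 &&
         yield.getOperand(0) == expr.getResult();
}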
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e8f8824..7f419a0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -216,6 +216,18 @@ LogicalResult ConvertFloatToTF32Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF6x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f32x2 to f6x2.";
+ }
+ return success();
+}
+
LogicalResult ConvertF32x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
using SatMode = NVVM::SaturationMode;
@@ -227,41 +239,67 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() {
bool hasRelu = getRelu();
- switch (getType()) {
- case ConvertFP8Type::E4M3:
- case ConvertFP8Type::E5M2:
- if (!isRoundingModeRN)
- return emitOpError("Only RN rounding mode is supported for conversions "
- "from f32x2 to .e4m3x2 or .e5m2x2 types");
- if (!isSatFinite)
- return emitOpError("Only SATFINITE saturation mode is supported for "
- "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
- break;
- case ConvertFP8Type::UE8M0:
- if (!(isRoundingModeRZ || isRoundingModeRP))
- return emitOpError("Only RZ or RP rounding modes are supported for "
- "conversions from f32x2 to .ue8m0x2 type");
- if (hasRelu)
- return emitOpError("relu not supported for conversions to .ue8m0x2 type");
- break;
- }
- return success();
+ mlir::MLIRContext *ctx = getContext();
+
+ return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
+ .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
+ [&](mlir::Type) -> LogicalResult {
+ if (!isRoundingModeRN) {
+ return emitOpError("Only RN rounding mode is supported for "
+ "conversions from f32x2 to ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
+ }
+ if (!isSatFinite) {
+ return emitOpError("Only SATFINITE saturation mode is supported "
+ "for conversions "
+ "from f32x2 to ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
+ }
+ return success();
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
+ if (!(isRoundingModeRZ || isRoundingModeRP)) {
+ return emitOpError("Only RZ and RP rounding modes are supported for "
+ "conversions from f32x2 to ")
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
+ }
+ if (hasRelu) {
+ return emitOpError("relu not supported for conversions to ")
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
+ }
+ return success();
+ })
+ .Default([&](mlir::Type) {
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << ", "
+ << mlir::Float8E5M2Type::get(ctx) << ", and "
+ << mlir::Float8E8M0FNUType::get(ctx)
+ << " types are "
+ "supported for conversions from f32x2 to f8x2";
+ });
}
LogicalResult ConvertF16x2ToF8x2Op::verify() {
- if (getType() == ConvertFP8Type::UE8M0)
- return emitOpError("Only .e4m3 or .e5m2 types are supported for "
- "conversions from f16x2 to f8x2.");
+ mlir::MLIRContext *ctx = getContext();
+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f16x2 to f8x2.";
+ }
return success();
}
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
- if (getType() != ConvertFP8Type::UE8M0)
- return emitOpError(
- "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
+ if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
+ return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
+ << " type is supported for conversions from "
+ "bf16x2 to f8x2.";
auto rnd = getRnd();
if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
@@ -1980,15 +2018,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
-llvm::Intrinsic::ID
-ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
- switch (type) {
- case NVVM::ConvertFP6Type::E2M3:
- return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
- case NVVM::ConvertFP6Type::E3M2:
- return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
- }
- llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
+llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
@@ -2000,41 +2042,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
: llvm::Intrinsic::nvvm_ff_to_##type##_rn
llvm::Intrinsic::ID
-ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
- NVVM::FPRoundingMode rnd,
+ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat, bool hasRelu) {
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
- switch (type) {
- case NVVM::ConvertFP8Type::E4M3:
- return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
- case NVVM::ConvertFP8Type::E5M2:
- return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
- case NVVM::ConvertFP8Type::UE8M0:
- if (hasRoundingModeRZ)
- return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
- else if (hasRoundingModeRP)
- return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
- }
- llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
+ if (hasRoundingModeRZ)
+ return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
+ else if (hasRoundingModeRP)
+ return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
+
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
-llvm::Intrinsic::ID
-ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
- switch (type) {
- case NVVM::ConvertFP8Type::E4M3:
- return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
- case NVVM::ConvertFP8Type::E5M2:
- return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
- default:
- llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
- }
+llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
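
Aside (not part of the patch): the rewritten verifiers and intrinsic selectors above key their dispatch on the MLIR destination float type instead of the old ConvertFP8Type/ConvertFP6Type enums, using llvm::TypeSwitch. A minimal, self-contained sketch of that idiom, with a hypothetical helper name:

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"

// Hypothetical helper: map a supported FP8 destination type to its PTX-style
// suffix. TypeSwitch performs isa<>-based dispatch, so unlisted types fall
// through to Default.
static llvm::StringRef fp8Suffix(mlir::Type dstTy) {
  return llvm::TypeSwitch<mlir::Type, llvm::StringRef>(dstTy)
      .Case<mlir::Float8E4M3FNType>([](mlir::Type) { return "e4m3x2"; })
      .Case<mlir::Float8E5M2Type>([](mlir::Type) { return "e5m2x2"; })
      .Case<mlir::Float8E8M0FNUType>([](mlir::Type) { return "ue8m0x2"; })
      .Default([](mlir::Type) { return "unsupported"; });
}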
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7863c21..0dac688 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
- Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
- DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
- packOp.getDimAndTileMapping();
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
- int64_t numTiles = destRank - srcRank;
+ ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+ int64_t numberOfTiles = innerDimsPos.size();
- // 1. Extract the inner tile sizes.
- // Where possible, values are replaced with constant attributes (to match the
- // behaviour of `getPackOpSourceOrPaddedSource`).
- SmallVector<OpFoldResult> tileSizes;
- for (auto i : llvm::seq<unsigned>(0, srcRank)) {
- if (dimAndTileMapping.count(i)) {
- // Rather than taking the tile size as is, extact the actual constant
- // value Attribute where possible, e.g.:
- // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
- auto [_, tileSize] =
- getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
- tileSizes.push_back(tileSize);
- }
- }
+ // 1. Get the input that is going to be packed. If the input requires padding,
+ // add a padding operation and return that as the input.
+ Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
// 2. Transpose the input to match the inner tile order:
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
- // 1. All outer dims are 1 - the corresponding transposition order doesn't
+ // - All outer dims are 1 - the corresponding transposition order doesn't
// matter, but requires all dim indices to be present.
+
+ // 2.1 Get the permutation for linalg.transpose
SmallVector<int64_t> srcPermForTranspose;
- ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
// rank of the inner tiling, correspond to the last `k` indices of the
@@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// rank of the source tensor. For example if we have a source tensor with
// indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
// indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
- if (llvm::is_contained(innerDimPos, i))
+ if (llvm::is_contained(innerDimsPos, i))
continue;
srcPermForTranspose.push_back(i);
}
- srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
+ srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
+
+ // 2.2 Create the init tensor for linalg.transpose with the correct shape
+ SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
+ oneIdxAttr);
+ shapeForEmptyOp.append(packOp.getMixedTiles());
+
+ // getMixedTiles() may contain Values pointing to constant ops, not the
+ // constant attributes. Replace them with a true OpFoldResult.
+ llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
+ [&](OpFoldResult ofr) {
+ if (auto val = llvm::dyn_cast<Value>(ofr))
+ return getAsOpFoldResult(val);
+ return ofr;
+ });
LDBG() << "Pack permutation: " << packOp;
LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
+ LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
- // 2.1 Create tensor.empty (init value for TransposeOp)
- SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
- oneIdxAttr);
- transShapeForEmptyOp.append(tileSizes);
-
- applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
- srcPermForTranspose);
- Value empty =
- tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
- packOp.getSourceType().getElementType());
+ Value empty = tensor::EmptyOp::create(
+ rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
- // 2.2 Create linalg.transpose
+ // 2.3 Create linalg.transpose
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
srcPermForTranspose);
@@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
// Outer dims are all 1s!
- SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
- oneIdxAttr);
+ SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
SmallVector<int64_t> writeShape;
for (auto tileSize : packOp.getMixedTiles()) {
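
Aside (not part of the patch): the decomposition above now builds the tensor.empty shape directly from getMixedTiles() and normalizes any Value-backed entries with getAsOpFoldResult, so constant tile sizes end up as attributes. That normalization step in isolation, as a hypothetical helper:

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helper: fold OpFoldResults that wrap constant-producing Values
// back to their Attribute form, leaving genuinely dynamic sizes untouched.
static void foldConstantOfrs(llvm::SmallVectorImpl<mlir::OpFoldResult> &ofrs) {
  for (mlir::OpFoldResult &ofr : ofrs)
    if (auto val = llvm::dyn_cast<mlir::Value>(ofr))
      ofr = mlir::getAsOpFoldResult(val);
}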
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4919d9a..9d62491 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -524,6 +524,40 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
if (!mask) {
LDBG() << "No mask required";
+ if (assumeDynamicDimsMatchVecSizes) {
+ llvm::TypeSwitch<Operation *>(opToMask)
+ .Case<vector::TransferReadOp, vector::TransferWriteOp>(
+ [&](auto xferOp) {
+ // For vector.transfer_read and vector.transfer_write, there is
+ // also the `in-bounds` attribute that has to be set explicitly
+ // to true. Otherwise, "out-of-bounds" access will be assumed
+ // and masks will be generated while lowering these.
+ LDBG() << "Assuming dynamic dimensions match vector sizes and "
+ "setting their in-bounds to true!";
+ SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
+ ShapedType xferType = xferOp.getShapedType();
+ AffineMap permMap = xferOp.getPermutationMap();
+ // Only set the in-bounds values to true for dynamic dims.
+ // Different mechanisms will set these accordingly for the
+ // static dims.
+ for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
+ auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
+ // Skip broadcast dimensions.
+ if (!dimExpr)
+ continue;
+ unsigned pos = dimExpr.getPosition();
+ if (xferType.isDynamicDim(pos))
+ inBoundsMap[i] = true;
+ }
+ rewriter.modifyOpInPlace(xferOp, [&]() {
+ xferOp.setInBoundsAttr(
+ rewriter.getBoolArrayAttr(inBoundsMap));
+ });
+ })
+ .Default([](Operation *op) {
+ // No-op if the operation is not an xfer read or write.
+ });
+ }
return opToMask;
}
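
Aside (not part of the patch): the masking change above only flips in_bounds entries whose permutation-map result is an AffineDimExpr pointing at a dynamic dimension of the transfer's shaped type; broadcast results and static dims are left to the existing logic. A distilled sketch of that selection, with a hypothetical helper name:

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helper: return the transfer-result positions that index a
// dynamic dimension of `ty`; non-dim (broadcast) results are skipped.
static llvm::SmallVector<unsigned>
dynamicInBoundsPositions(mlir::AffineMap permMap, mlir::ShapedType ty) {
  llvm::SmallVector<unsigned> positions;
  for (unsigned i = 0, e = permMap.getNumResults(); i < e; ++i) {
    auto dimExpr = llvm::dyn_cast<mlir::AffineDimExpr>(permMap.getResult(i));
    if (dimExpr && ty.isDynamicDim(dimExpr.getPosition()))
      positions.push_back(i);
  }
  return positions;
}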
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index ef172c1..37bdd8b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -186,11 +186,11 @@ constexpr float A_2x2_5x5[] = {
/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
- TransformMatrix(const float *table, int64_t rows, int64_t cols,
+ TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
int64_t scalarFactor = 1)
: table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
- const float *table;
+ ArrayRef<float> table;
int64_t rows;
int64_t cols;
int64_t scalarFactor;
@@ -199,14 +199,20 @@ struct TransformMatrix {
/// Utility function to convert constant array to arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
TransformMatrix transform, Type type) {
- ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
-
+ assert(transform.table.size() ==
+ static_cast<size_t>(transform.rows * transform.cols));
+ assert(type.isFloat() && "Only floats are supported by Winograd");
+ ArrayRef<float> constVec(transform.table.data(),
+ transform.rows * transform.cols);
+ auto constAttrVec =
+ llvm::map_to_vector<>(constVec, [&](const float v) -> Attribute {
+ return builder.getFloatAttr(type, v);
+ });
+ SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
return arith::ConstantOp::create(
builder, loc,
- DenseFPElementsAttr::get(
- RankedTensorType::get(
- SmallVector<int64_t>{transform.rows, transform.cols}, type),
- constVec));
+ DenseFPElementsAttr::get(RankedTensorType::get(shape, type),
+ constAttrVec));
}
/// Extract height x width data from 4D tensors.
@@ -551,8 +557,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
- Value BT =
- create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+ Value BT = create2DTransformMatrix(builder, loc, BTMatrix, elementType);
// Multiply BT x d.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{BT, matmulRetValue},
@@ -574,8 +579,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
.getResult();
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
- Value B =
- create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+ Value B = create2DTransformMatrix(builder, loc, BMatrix, elementType);
// Multiply v = (BT x d) x B.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{matmulRetValue, B},
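
Aside (not part of the patch): create2DTransformMatrix above now materializes the Winograd transform matrix for whatever float element type the convolution uses, wrapping each table entry in a FloatAttr of that type before building the DenseElementsAttr. A standalone sketch of that construction, with a hypothetical helper name:

#include <cassert>
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVectorExtras.h"

// Hypothetical helper: build a rows x cols dense constant attribute of an
// arbitrary float element type from a flat float table.
static mlir::DenseElementsAttr makeTransformAttr(mlir::OpBuilder &b,
                                                 llvm::ArrayRef<float> table,
                                                 int64_t rows, int64_t cols,
                                                 mlir::Type elemType) {
  assert(table.size() == static_cast<size_t>(rows * cols) &&
         "table size must match rows * cols");
  auto attrs = llvm::map_to_vector(table, [&](float v) -> mlir::Attribute {
    return b.getFloatAttr(elemType, v); // converts the f32 constant as needed
  });
  return mlir::DenseElementsAttr::get(
      mlir::RankedTensorType::get({rows, cols}, elemType), attrs);
}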
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 697cb35..237aab4 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -27,7 +27,7 @@ using namespace mlir::nvgpu;
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
-void nvgpu::NVGPUDialect::initialize() {
+void NVGPUDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
@@ -42,7 +42,7 @@ void nvgpu::NVGPUDialect::initialize() {
>();
}
-bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
if (!memorySpace)
return false;
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
@@ -52,7 +52,7 @@ bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
return false;
}
-bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
Attribute memorySpace = type.getMemorySpace();
return isSharedMemoryAddressSpace(memorySpace);
}
@@ -140,7 +140,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
TypedValue<VectorType> matrixC,
const std::array<int64_t, 3> &mmaShape,
bool tf32Enabled, bool sparse = false) {
-
// The verification for mma.sync covering various shapes and data types is
// based on the fundamental tensor core shape.
@@ -292,7 +291,6 @@ LogicalResult MmaSparseSyncOp::verify() {
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
LogicalResult LdMatrixOp::verify() {
-
// ldmatrix reads data from source in shared memory
auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());
@@ -345,7 +343,7 @@ LogicalResult LdMatrixOp::verify() {
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//
-unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
+static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
switch (kind) {
case TensorMapSwizzleKind::SWIZZLE_32B:
return 32;
@@ -359,7 +357,7 @@ unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
}
std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
- Operation *op, nvgpu::TensorMapDescriptorType descType,
+ Operation *op, TensorMapDescriptorType descType,
std::optional<MemRefType> memrefType = std::nullopt) {
MemRefType descMemref = descType.getTensor();
// Limitation
@@ -655,8 +653,7 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
-
- nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
+ WarpgroupAccumulatorType accType = getMatrixC().getType();
int64_t sizeM = accType.getFragmented().getDimSize(0);
int64_t sizeN = accType.getFragmented().getDimSize(1);
Type elemType = accType.getFragmented().getElementType();
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 46e82bd..2a857ed 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -43,7 +43,7 @@ using namespace mlir::transform;
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//
-void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
+void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
/// device-side async tokens cannot be materialized in nvvm. We just
@@ -62,62 +62,58 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
llvm_unreachable("unknown address space enum value");
return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
});
- llvmTypeConverter.addConversion(
- [&](nvgpu::DeviceAsyncTokenType type) -> Type {
- return llvmTypeConverter.convertType(
- IntegerType::get(type.getContext(), 32));
- });
- llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
+ llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type {
+ return llvmTypeConverter.convertType(
+ IntegerType::get(type.getContext(), 32));
+ });
+ llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type {
return llvmTypeConverter.convertType(
IntegerType::get(type.getContext(), 64));
});
- llvmTypeConverter.addConversion(
- [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
- Type elemType = type.getFragmented().getElementType();
- int64_t sizeM = type.getFragmented().getDimSize(0);
- int64_t sizeN = type.getFragmented().getDimSize(1);
-
- unsigned numMembers;
- if (elemType.isF32() || elemType.isInteger(32))
- numMembers = sizeN / 2;
- else if (elemType.isF16())
- numMembers = sizeN / 4;
- else
- llvm_unreachable("unsupported type for warpgroup accumulator");
-
- SmallVector<Type> innerStructBody;
- for (unsigned i = 0; i < numMembers; i++)
- innerStructBody.push_back(elemType);
- auto innerStructType = LLVM::LLVMStructType::getLiteral(
- type.getContext(), innerStructBody);
-
- SmallVector<Type> structBody;
- for (int i = 0; i < sizeM; i += kWgmmaSizeM)
- structBody.push_back(innerStructType);
-
- auto convertedType =
- LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
- return llvmTypeConverter.convertType(convertedType);
- });
- llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
+ llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type {
+ Type elemType = type.getFragmented().getElementType();
+ int64_t sizeM = type.getFragmented().getDimSize(0);
+ int64_t sizeN = type.getFragmented().getDimSize(1);
+
+ unsigned numMembers;
+ if (elemType.isF32() || elemType.isInteger(32))
+ numMembers = sizeN / 2;
+ else if (elemType.isF16())
+ numMembers = sizeN / 4;
+ else
+ llvm_unreachable("unsupported type for warpgroup accumulator");
+
+ SmallVector<Type> innerStructBody;
+ for (unsigned i = 0; i < numMembers; i++)
+ innerStructBody.push_back(elemType);
+ auto innerStructType =
+ LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
+
+ SmallVector<Type> structBody;
+ for (int i = 0; i < sizeM; i += kWgmmaSizeM)
+ structBody.push_back(innerStructType);
+
+ auto convertedType =
+ LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
+ return llvmTypeConverter.convertType(convertedType);
+ });
+ llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type {
return llvmTypeConverter.convertType(
getMBarrierMemrefType(type.getContext(), type));
});
llvmTypeConverter.addConversion(
- [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
+ [&](WarpgroupMatrixDescriptorType type) -> Type {
return llvmTypeConverter.convertType(
IntegerType::get(type.getContext(), 64));
});
- llvmTypeConverter.addConversion(
- [&](nvgpu::TensorMapDescriptorType type) -> Type {
- return LLVM::LLVMPointerType::get(type.getContext());
- });
+ llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type {
+ return LLVM::LLVMPointerType::get(type.getContext());
+ });
populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
-LogicalResult
-transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
- transform::TypeConverterBuilderOpInterface builder) {
+LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
+ TypeConverterBuilderOpInterface builder) {
if (builder.getTypeConverterType() != "LLVMTypeConverter")
return emitOpError("expected LLVMTypeConverter");
return success();
@@ -127,17 +123,18 @@ transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
// CreateAsyncGroupsOp
//===---------------------------------------------------------------------===//
-void transform::CreateAsyncGroupsOp::getEffects(
+void CreateAsyncGroupsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::consumesHandle(getTargetMutable(), effects);
- transform::producesHandle(getOperation()->getOpResults(), effects);
- transform::modifiesPayload(effects);
+ consumesHandle(getTargetMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
}
-DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
- TransformRewriter &rewriter, Operation *target,
- ApplyToEachResultList &results, TransformState &state) {
- nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
+DiagnosedSilenceableFailure
+CreateAsyncGroupsOp::applyToOne(TransformRewriter &rewriter, Operation *target,
+ ApplyToEachResultList &results,
+ TransformState &state) {
+ createAsyncGroups(rewriter, target, getBypassL1());
results.push_back(target);
return DiagnosedSilenceableFailure::success();
}
@@ -218,7 +215,7 @@ collectStage0PipeliningOps(scf::ForOp forOp,
continue;
}
- if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
+ if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) {
ops.insert(&op);
ops.insert(std::make_move_iterator(barriers.begin()),
std::make_move_iterator(barriers.end()));
@@ -246,7 +243,7 @@ setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
unsigned iteration, unsigned depth) {
// Based on the order of copies within the loop we need to set the number
// of copies in flight, unless it is already set.
- auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
+ auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op);
if (!waitOp || waitOp.getNumGroups())
return;
@@ -312,13 +309,12 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
// original number of iterations, in particular side-effect free operations
// and barriers, even if they cannot be predicated.
if (isMemoryEffectFree(op) ||
- isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
- nvgpu::DeviceAsyncWaitOp>(op)) {
+ isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) {
return op;
}
// Otherwise, only async copies can currently be predicated.
- auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
+ auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op);
if (!asyncCopyOp)
return nullptr;
@@ -335,8 +331,8 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
originalSrcElement, c0Index);
- auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create(
- rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
+ auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create(
+ rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
UnitAttr());
@@ -805,17 +801,16 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
rhsIndexFn, rhsShape);
Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
resIndexFn, resShape);
- res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape,
- info.tf32Enabled);
+ res =
+ MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled);
buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
resShape);
return res.getDefiningOp();
}
-DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
- transform::TransformRewriter &rewriter, LinalgOp linalgOp,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
+DiagnosedSilenceableFailure RewriteMatmulAsMmaSyncOp::applyToOne(
+ TransformRewriter &rewriter, LinalgOp linalgOp,
+ ApplyToEachResultList &results, TransformState &state) {
bool fail = true;
// TODO: more robust detection of matmulOp, with transposes etc.
if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
@@ -854,43 +849,42 @@ struct HopperBuilder {
HopperBuilder(RewriterBase &rewriter, Location loc)
: rewriter(rewriter), loc(loc) {}
- TypedValue<nvgpu::MBarrierGroupType>
+ TypedValue<MBarrierGroupType>
buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
/// Create tma descriptor op to initiate transfer from global to shared
/// memory. This must be done before the launch op, on the host.
- TypedValue<nvgpu::TensorMapDescriptorType>
+ TypedValue<TensorMapDescriptorType>
buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
gpu::LaunchOp launchOp);
/// Build a tma load from global memory to shared memory using `barrier` to
/// synchronize. Return the number of bytes that will be transferred.
- OpFoldResult
- buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
- TypedValue<MemRefType> sharedMemref,
- TypedValue<nvgpu::MBarrierGroupType> barrier,
- SmallVectorImpl<Operation *> &loadOps);
- void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
+ OpFoldResult buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc,
+ TypedValue<MemRefType> sharedMemref,
+ TypedValue<MBarrierGroupType> barrier,
+ SmallVectorImpl<Operation *> &loadOps);
+ void buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
ArrayRef<OpFoldResult> sizes);
/// If threadIdx.x == 0 does TMA request + wait, else just wait.
/// Return the operation that performs the transfer on thread0.
// TODO: In the future, don't hardcode to thread 0 but elect a leader.
SmallVector<Operation *> buildPredicateLoadsOnThread0(
- ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+ ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
- TypedValue<nvgpu::MBarrierGroupType> barrier);
+ TypedValue<MBarrierGroupType> barrier);
- void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);
+ void buildTryWaitParity(TypedValue<MBarrierGroupType> barrier);
RewriterBase &rewriter;
Location loc;
};
SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
- ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+ ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
- TypedValue<nvgpu::MBarrierGroupType> barrier) {
+ TypedValue<MBarrierGroupType> barrier) {
SmallVector<Operation *> loadOps;
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
@@ -931,22 +925,22 @@ static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
// return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
}
-TypedValue<nvgpu::MBarrierGroupType>
+TypedValue<MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
- Value barrier = nvgpu::MBarrierCreateOp::create(
+ Value barrier = MBarrierCreateOp::create(
rewriter, loc,
- nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
+ MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
nvgpu::MBarrierInitOp::create(
rewriter, loc, barrier,
getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
Value());
gpu::BarrierOp::create(rewriter, loc);
- return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
+ return cast<TypedValue<MBarrierGroupType>>(barrier);
}
-TypedValue<nvgpu::TensorMapDescriptorType>
+TypedValue<TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
gpu::LaunchOp launchOp) {
OpBuilder::InsertionGuard guard(rewriter);
@@ -962,29 +956,29 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
- Value desc = nvgpu::TmaCreateDescriptorOp::create(
+ Value desc = TmaCreateDescriptorOp::create(
rewriter, loc,
- nvgpu::TensorMapDescriptorType::get(
- rewriter.getContext(),
- MemRefType::Builder(memref.getType())
- .setMemorySpace(sharedMemorySpace),
- TensorMapSwizzleKind::SWIZZLE_NONE,
- TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
- TensorMapInterleaveKind::INTERLEAVE_NONE),
+ TensorMapDescriptorType::get(rewriter.getContext(),
+ MemRefType::Builder(memref.getType())
+ .setMemorySpace(sharedMemorySpace),
+ TensorMapSwizzleKind::SWIZZLE_NONE,
+ TensorMapL2PromoKind::L2PROMO_NONE,
+ TensorMapOOBKind::OOB_ZERO,
+ TensorMapInterleaveKind::INTERLEAVE_NONE),
unrankedMemRef, sizes);
- return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
+ return cast<TypedValue<TensorMapDescriptorType>>(desc);
}
-OpFoldResult HopperBuilder::buildTmaAsyncLoad(
- TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
- TypedValue<MemRefType> sharedMemref,
- TypedValue<nvgpu::MBarrierGroupType> barrier,
- SmallVectorImpl<Operation *> &loadOps) {
+OpFoldResult
+HopperBuilder::buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc,
+ TypedValue<MemRefType> sharedMemref,
+ TypedValue<MBarrierGroupType> barrier,
+ SmallVectorImpl<Operation *> &loadOps) {
MLIRContext *ctx = rewriter.getContext();
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
- Operation *loadOp = nvgpu::TmaAsyncLoadOp::create(
- rewriter, loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero},
- zero, Value(), Value());
+ Operation *loadOp =
+ TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc,
+ ValueRange{zero, zero}, zero, Value(), Value());
loadOps.push_back(loadOp);
auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
SmallVector<AffineExpr> symbols(mixedSizes.size());
@@ -997,9 +991,8 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad(
return res;
}
-void HopperBuilder::buildBarrierArriveTx(
- TypedValue<nvgpu::MBarrierGroupType> barrier,
- ArrayRef<OpFoldResult> mixedSizes) {
+void HopperBuilder::buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
+ ArrayRef<OpFoldResult> mixedSizes) {
assert(!mixedSizes.empty() && "expecte non-empty sizes");
MLIRContext *ctx = rewriter.getContext();
SmallVector<AffineExpr> symbols(mixedSizes.size());
@@ -1013,8 +1006,7 @@ void HopperBuilder::buildBarrierArriveTx(
Value());
}
-void HopperBuilder::buildTryWaitParity(
- TypedValue<nvgpu::MBarrierGroupType> barrier) {
+void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) {
Type i1 = rewriter.getI1Type();
Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
// 10M is an arbitrary, not too small or too big number to specify the number
@@ -1058,11 +1050,11 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
launchOp.getBlockSizeZ()});
- TypedValue<nvgpu::MBarrierGroupType> barrier =
+ TypedValue<MBarrierGroupType> barrier =
buildAndInitBarrierInSharedMemory(numThreads);
SmallVector<TypedValue<MemRefType>> shmems;
- SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
+ SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs;
for (Operation *op : copyOps) {
auto copyOp = cast<linalg::CopyOp>(op);
auto inMemRef =
@@ -1071,7 +1063,7 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
"expected in to be a 2D memref");
// 2. Build global memory descriptor.
- TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
+ TypedValue<TensorMapDescriptorType> globalDesc =
buildGlobalMemRefDescriptor(inMemRef, launchOp);
globalDescs.push_back(globalDesc);
@@ -1098,9 +1090,8 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
}
DiagnosedSilenceableFailure
-transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
- transform::TransformResults &results,
- transform::TransformState &state) {
+RewriteCopyAsTmaOp::apply(TransformRewriter &rewriter,
+ TransformResults &results, TransformState &state) {
auto payloadOps = state.getPayloadOps(getTarget());
gpu::LaunchOp commonLaunchOp;
Operation *firstOp, *failingOp;
@@ -1137,15 +1128,14 @@ transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
namespace {
class NVGPUTransformDialectExtension
- : public transform::TransformDialectExtension<
- NVGPUTransformDialectExtension> {
+ : public TransformDialectExtension<NVGPUTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
NVGPUTransformDialectExtension() {
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<affine::AffineDialect>();
- declareGeneratedDialect<nvgpu::NVGPUDialect>();
+ declareGeneratedDialect<NVGPUDialect>();
declareGeneratedDialect<NVVM::NVVMDialect>();
declareGeneratedDialect<vector::VectorDialect>();
registerTransformOps<
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
index 5b89c87..7f626a6 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -64,6 +64,5 @@ private:
void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
-
patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
}
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 809d634..9e5ea93 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -168,8 +168,7 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
const WarpMatrixInfo &fragmentType) {
Type elementType = fragmentType.vectorType.getElementType();
ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
- FailureOr<nvgpu::FragmentElementInfo> regInfo =
- getMmaSyncRegisterType(fragmentType);
+ FailureOr<FragmentElementInfo> regInfo = getMmaSyncRegisterType(fragmentType);
if (failed(regInfo))
return failure();
@@ -199,8 +198,8 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
(logicalValueIdDim % elementsPerRegister)});
}
-FailureOr<nvgpu::LdMatrixParams>
-nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
+FailureOr<LdMatrixParams> nvgpu::getLdMatrixParams(const WarpMatrixInfo &type,
+ bool transpose) {
LdMatrixParams params;
Type elType = type.vectorType.getElementType();
params.fragmentType = type.vectorType;
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6598ac1..6564a4e 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -7,6 +7,7 @@
// =============================================================================
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -44,6 +45,7 @@ struct MemRefPointerLikeModel
Type getElementType(Type pointer) const {
return cast<MemRefType>(pointer).getElementType();
}
+
mlir::acc::VariableTypeCategory
getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
Type varType) const {
@@ -70,6 +72,115 @@ struct MemRefPointerLikeModel
assert(memrefTy.getRank() > 0 && "rank expected to be positive");
return mlir::acc::VariableTypeCategory::array;
}
+
+ mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
+ StringRef varName, Type varType,
+ Value originalVar) const {
+ auto memrefTy = cast<MemRefType>(pointer);
+
+ // Check if this is a static memref (all dimensions are known) - if yes
+ // then we can generate an alloca operation.
+ if (memrefTy.hasStaticShape())
+ return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
+
+ // For dynamic memrefs, extract sizes from the original variable if
+ // provided. Otherwise they cannot be handled.
+ if (originalVar && originalVar.getType() == memrefTy &&
+ memrefTy.hasRank()) {
+ SmallVector<Value> dynamicSizes;
+ for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
+ if (memrefTy.isDynamicDim(i)) {
+ // Extract the size of dimension i from the original variable
+ auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
+ auto dimSize =
+ memref::DimOp::create(builder, loc, originalVar, indexValue);
+ dynamicSizes.push_back(dimSize);
+ }
+ // Note: We only add dynamic sizes to the dynamicSizes array
+ // Static dimensions are handled automatically by AllocOp
+ }
+ return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
+ .getResult();
+ }
+
+ // TODO: Unranked not yet supported.
+ return {};
+ }
+
+ bool genFree(Type pointer, OpBuilder &builder, Location loc,
+ TypedValue<PointerLikeType> varPtr, Type varType) const {
+ if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) {
+ // Walk through casts to find the original allocation
+ Value currentValue = memrefValue;
+ Operation *originalAlloc = nullptr;
+
+ // Follow the chain of operations to find the original allocation
+ // even if a casted result is provided.
+ while (currentValue) {
+ if (auto *definingOp = currentValue.getDefiningOp()) {
+ // Check if this is an allocation operation
+ if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
+ originalAlloc = definingOp;
+ break;
+ }
+
+ // Check if this is a cast operation we can look through
+ if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
+ currentValue = castOp.getSource();
+ continue;
+ }
+
+ // Check for other cast-like operations
+ if (auto reinterpretCastOp =
+ dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
+ currentValue = reinterpretCastOp.getSource();
+ continue;
+ }
+
+ // If we can't look through this operation, stop
+ break;
+ }
+ // This is a block argument or similar - can't trace further.
+ break;
+ }
+
+ if (originalAlloc) {
+ if (isa<memref::AllocaOp>(originalAlloc)) {
+ // This is an alloca - no dealloc needed, but return true (success)
+ return true;
+ }
+ if (isa<memref::AllocOp>(originalAlloc)) {
+ // This is an alloc - generate dealloc
+ memref::DeallocOp::create(builder, loc, memrefValue);
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ bool genCopy(Type pointer, OpBuilder &builder, Location loc,
+ TypedValue<PointerLikeType> destination,
+ TypedValue<PointerLikeType> source, Type varType) const {
+ // Generate a copy operation between two memrefs
+ auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
+ auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
+
+ // As per memref documentation, source and destination must have same
+ // element type and shape in order to be compatible. We do not want to fail
+ // with an IR verification error - thus check that before generating the
+ // copy operation.
+ if (destMemref && srcMemref &&
+ destMemref.getType().getElementType() ==
+ srcMemref.getType().getElementType() &&
+ destMemref.getType().getShape() == srcMemref.getType().getShape()) {
+ memref::CopyOp::create(builder, loc, srcMemref, destMemref);
+ return true;
+ }
+
+ return false;
+ }
};
struct LLVMPointerPointerLikeModel
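
Aside (not part of the patch): genAllocate above handles statically shaped memrefs with memref.alloca and dynamically shaped ones with memref.alloc, recovering the missing extents from the original variable via memref.dim. A condensed sketch of the dynamic path, with a hypothetical helper name and no error handling:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helper: allocate a memref of the same type as `original`,
// reading every dynamic extent off the original value with memref.dim.
static mlir::Value allocateLike(mlir::OpBuilder &b, mlir::Location loc,
                                mlir::TypedValue<mlir::MemRefType> original) {
  mlir::MemRefType ty = original.getType();
  llvm::SmallVector<mlir::Value> dynamicSizes;
  for (int64_t i = 0, e = ty.getRank(); i < e; ++i) {
    if (ty.isDynamicDim(i)) {
      auto idx = mlir::arith::ConstantIndexOp::create(b, loc, i);
      auto dimSize = mlir::memref::DimOp::create(b, loc, original, idx);
      dynamicSizes.push_back(dimSize);
    }
  }
  return mlir::memref::AllocOp::create(b, loc, ty, dynamicSizes).getResult();
}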
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e8..14e235f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -47,6 +47,7 @@
#include <cassert>
#include <cstdint>
+#include <numeric>
#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
@@ -2412,9 +2413,38 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
return success();
}
+/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
+///
+/// Example:
+/// %b = vector.broadcast %x : f32 to vector<3xf32>
+/// %e:3 = vector.to_elements %b : vector<3xf32>
+/// user_op %e#0, %e#1, %e#2
+/// becomes:
+/// user_op %x, %x, %x
+///
+/// The vector source case is handled by a canonicalization pattern.
+static LogicalResult
+foldToElementsOfBroadcast(ToElementsOp toElementsOp,
+ SmallVectorImpl<OpFoldResult> &results) {
+ auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+ if (!bcastOp)
+ return failure();
+ // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
+ if (isa<VectorType>(bcastOp.getSource().getType()))
+ return failure();
+
+ auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
+
+ Value scalar = bcastOp.getSource();
+ results.assign(resultVecType.getNumElements(), scalar);
+ return success();
+}
+
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
- return foldToElementsFromElements(*this, results);
+ if (succeeded(foldToElementsFromElements(*this, results)))
+ return success();
+ return foldToElementsOfBroadcast(*this, results);
}
LogicalResult
@@ -2427,6 +2457,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
return success();
}
+/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
+/// vector.
+/// - Build `vector.to_elements %v` and remap each destination element to the
+/// corresponding source element using broadcast rules (match or 1 →
+/// replicate).
+///
+/// Example:
+/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
+/// %e:6 = vector.to_elements %v : vector<3x2xf32>
+/// becomes:
+/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
+/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+/// // %src_elems#1, %src_elems#0, %src_elems#1
+struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+ PatternRewriter &rewriter) const override {
+ auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+ if (!bcastOp)
+ return failure();
+
+ // Only handle broadcasts from a vector source here.
+ auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
+ if (!srcType)
+ return failure();
+
+ auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
+
+ ArrayRef<int64_t> dstShape = dstType.getShape();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+
+ int64_t dstRank = dstShape.size();
+ int64_t srcRank = srcShape.size();
+
+ // Create elements for the broadcast source vector.
+ auto srcElems = vector::ToElementsOp::create(
+ rewriter, toElementsOp.getLoc(), bcastOp.getSource());
+
+ int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1,
+ std::multiplies<int64_t>());
+
+ SmallVector<Value> replacements;
+ replacements.reserve(dstCount);
+
+ // For each element of the destination, determine which element of the
+ // source should be used. We walk all destination positions using a single
+ // counter, decode it into per-dimension indices, then build the matching
+ // source position: use the same index where sizes match, and use 0 where
+ // the source size is 1 (replication). This mapping is needed so we can
+ // replace each result of to_elements with the corresponding element from
+ // the broadcast source.
+ // Inner-dimension stretch example:
+ // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
+ // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
+ // becomes:
+ // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
+ // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+ // // %src_elems#1, %src_elems#0, %src_elems#1,
+ // // %src_elems#2, %src_elems#3, %src_elems#2,
+ // // %src_elems#3, %src_elems#2, %src_elems#3
+
+ // Row-major strides for the destination shape.
+ SmallVector<int64_t> dstStrides = computeStrides(dstShape);
+ // Row-major strides for the source shape.
+ SmallVector<int64_t> srcStrides = computeStrides(srcShape);
+ SmallVector<int64_t> dstIdx(dstRank);
+ SmallVector<int64_t> srcIdx(srcRank);
+ for (int64_t lin = 0; lin < dstCount; ++lin) {
+ // Convert linear destination index to per-dimension indices.
+ dstIdx = delinearize(lin, dstStrides);
+ for (int64_t k = 0; k < srcRank; ++k)
+ srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
+ // Convert per-dimension source indices back to a linear index.
+ int64_t srcLin = linearize(srcIdx, srcStrides);
+ replacements.push_back(srcElems.getResult(srcLin));
+ }
+
+ rewriter.replaceOp(toElementsOp, replacements);
+ return success();
+ }
+};
+
+void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ToElementsOfBroadcast>(context);
+}
+
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
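
Aside (not part of the patch): the ToElementsOfBroadcast pattern above maps every destination element back to a source element with row-major stride arithmetic, where size-1 source dims replicate and leading dims added by the broadcast are dropped. The same mapping as a self-contained sketch, with a hypothetical helper name, reusing the IndexingUtils helpers from the pattern:

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helper: given a linear index into the broadcast result, return
// the linear index of the source element that supplies its value.
static int64_t broadcastSourceIndex(int64_t dstLin,
                                    llvm::ArrayRef<int64_t> dstShape,
                                    llvm::ArrayRef<int64_t> srcShape) {
  llvm::SmallVector<int64_t> dstIdx =
      mlir::delinearize(dstLin, mlir::computeStrides(dstShape));
  int64_t rankDiff = dstShape.size() - srcShape.size();
  llvm::SmallVector<int64_t> srcIdx(srcShape.size());
  for (size_t k = 0, e = srcShape.size(); k < e; ++k)
    srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[rankDiff + k];
  return mlir::linearize(srcIdx, mlir::computeStrides(srcShape));
}
// For dstShape = {2, 3, 2} and srcShape = {2, 1, 2}, dstLin = 0..11 maps to
// 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3 -- matching the comment in the pattern.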
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 784e5d6..c28d2fc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
auto vecType = dyn_cast<VectorType>(op.getType());
- if (!vecAttr || !vecAttr.isSplat() || !vecType)
+ if (!vecAttr || !vecType)
return failure();
xegpu::DistributeLayoutAttr layout =
@@ -733,22 +733,139 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
int count;
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
- // Current limitation: constant of vector with single value.
- // TODO: support more complex cases, e.g., vector with multiple values.
- Attribute singleVal = vecAttr.getSplatValue<Attribute>();
-
auto newType = VectorType::get(sgShape, vecType.getElementType());
- auto sgAttr = DenseElementsAttr::get(newType, singleVal);
- auto cstOp =
- arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
- layout.dropSgLayoutAndData());
- SmallVector<Value> newConsts(count, cstOp);
+ Location loc = op.getLoc();
+ auto eltType = vecType.getElementType();
- rewriter.replaceOpWithMultiple(op, {newConsts});
- return success();
+ auto setLayoutIfNeeded = [&](Value val) {
+ if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+ !layout.getEffectiveInstDataAsInt().empty()) {
+ xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+ layout.dropSgLayoutAndData());
+ }
+ };
+
+ if (vecAttr.isSplat()) {
+ // Splat: single value for all subgroups
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+ auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+ auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+ setLayoutIfNeeded(cstOp->getResult(0));
+ rewriter.replaceOp(op, cstOp);
+ return success();
+ } else if (sgShape == wgShape) { // if the entire vector is shared by all
+ // subgroups, don't distribute
+ auto newConstOp =
+ arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
+ setLayoutIfNeeded(newConstOp->getResult(0));
+ rewriter.replaceOp(op, newConstOp);
+ return success();
+ } else {
+ // Non-splat constant
+ // Only supports 1D & 2D
+ // TODO: support other cases that require SLM access
+ if (!eltType.isIndex())
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported element type for non-splat constant op.");
+
+ if (wgShape.size() > 2)
+ return rewriter.notifyMatchFailure(
+ op, "Only 1D & 2D vector constant supported");
+
+ SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
+ int64_t rowStride = 0, colStride = 0;
+ int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
+ int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
+
+ // Compute colStride and rowStride, and check for constant strides.
+ if (cols > 1) {
+ colStride = cast<IntegerAttr>(values[1]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ }
+ if (rows > 1) {
+ rowStride = cast<IntegerAttr>(values[cols]).getInt() -
+ cast<IntegerAttr>(values[0]).getInt();
+ }
+
+ for (int64_t r = 0; r < rows; ++r) {
+ for (int64_t c = 0; c < cols; ++c) {
+ int64_t idx = r * cols + c;
+ // Check column stride
+ if (c > 0 && cols > 1) {
+ int64_t prevIdx = r * cols + (c - 1);
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+ cast<IntegerAttr>(values[prevIdx]).getInt();
+ if (diff != colStride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant column stride in constant op.");
+ }
+ // Check row stride
+ if (r > 0 && rows > 1) {
+ int64_t prevIdx = (r - 1) * cols + c;
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+ cast<IntegerAttr>(values[prevIdx]).getInt();
+ if (diff != rowStride)
+ return rewriter.notifyMatchFailure(
+ op, "Non-constant row stride in constant op.");
+ }
+ }
+ }
+
+ // Create a constant for the base tile.
+ // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
+ // For 1D case, extract the first sgShape[0] elements.
+ SmallVector<Attribute> baseTileValues;
+ int baseTileCols = sgShape[sgShape.size() - 1];
+ int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
+ for (int64_t r = 0; r < baseTileRows; ++r) {
+ for (int64_t c = 0; c < baseTileCols; ++c) {
+ baseTileValues.push_back(values[r * cols + c]);
+ }
+ }
+
+ auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
+ baseTileValues);
+ auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+
+ // Get subgroup id
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+ auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ if (failed(sgOffsets))
+ return failure();
+
+ SmallVector<Value, 2> strideConsts;
+ strideConsts.push_back(
+ rewriter.create<arith::ConstantIndexOp>(loc, colStride));
+ if (rows > 1)
+ strideConsts.insert(
+ strideConsts.begin(),
+ rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
+
+ SmallVector<Value> newConstOps;
+ for (auto offsets : *sgOffsets) {
+ // Multiply offset with stride, broadcast it and add to baseConstVec
+ Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ for (size_t i = 0; i < strideConsts.size(); ++i) {
+ Value mul = rewriter.create<arith::MulIOp>(
+ loc, rewriter.getIndexType(), offsets[i], strideConsts[i]);
+ mulOffset = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), mulOffset, mul);
+ }
+ // Broadcast to baseConstVec size
+ auto bcastOffset = rewriter.create<vector::BroadcastOp>(
+ loc, baseConstVec.getType(), mulOffset);
+ auto finalConst =
+ arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
+ setLayoutIfNeeded(baseConstVec);
+ setLayoutIfNeeded(bcastOffset);
+ setLayoutIfNeeded(finalConst);
+ newConstOps.push_back(finalConst);
+ }
+ rewriter.replaceOpWithMultiple(op, {newConstOps});
+ return success();
+ }
}
};
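
Aside (not part of the patch): the non-splat path above only accepts constants that form a constant-stride index grid, so the value at (r, c) is values[0] + r * rowStride + c * colStride and each subgroup's tile is the top-left base tile shifted by one scalar. That arithmetic as a tiny sketch in plain C++:

#include <cstdint>

// Hypothetical helper: value of a constant-stride index grid at (r, c).
static int64_t gridValue(int64_t base, int64_t rowStride, int64_t colStride,
                         int64_t r, int64_t c) {
  return base + r * rowStride + c * colStride;
}
// A subgroup whose tile starts at offset (offR, offC) sees
//   gridValue(base, rowStride, colStride, offR + i, offC + j)
//     == gridValue(base, rowStride, colStride, i, j)
//        + offR * rowStride + offC * colStride
// for every (i, j) in its tile, which is why the pattern adds the broadcast
// scalar offR * rowStride + offC * colStride to the base-tile constant.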