Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/AsmParser/Parser.cpp                                    |   4
-rw-r--r--  mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp                  |   3
-rw-r--r--  mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp     |   8
-rw-r--r--  mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp                  |   2
-rw-r--r--  mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp                  |   2
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp              |   5
-rw-r--r--  mlir/lib/Dialect/EmitC/IR/EmitC.cpp                              |   4
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp                       | 215
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp                         |  10
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp                          |   8
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp                                  |  47
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp |   8
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp                         |  10
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp                       | 162
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp                           |  38
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp    | 186
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp      |  10
-rw-r--r--  mlir/lib/IR/Operation.cpp                                        |  18
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  |   3
-rw-r--r--  mlir/lib/Transforms/Utils/DialectConversion.cpp                  | 112
20 files changed, 630 insertions(+), 225 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 82bdb84..74936e3 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -407,8 +407,8 @@ Parser::parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
"hexadecimal float constant out of range for type");
}
- APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
- intValue.getRawData());
+ APInt truncatedValue(typeSizeInBits,
+ ArrayRef(intValue.getRawData(), intValue.getNumWords()));
result.emplace(semantics, truncatedValue);
return success();
}
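
Note on the change above: it swaps APInt's (numBits, numWords, uint64_t*) constructor for the ArrayRef overload, which carries the word pointer and count together and avoids mismatched-length bugs. A minimal standalone sketch of the new call shape, assuming LLVM's ADT headers:

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"

// Rebuild an APInt at `typeSizeInBits` from the words of `intValue`,
// mirroring the updated Parser.cpp call; ArrayRef's deduction guide infers
// ArrayRef<uint64_t> from the (pointer, length) pair.
llvm::APInt truncateToBits(const llvm::APInt &intValue, unsigned typeSizeInBits) {
  return llvm::APInt(typeSizeInBits,
                     llvm::ArrayRef(intValue.getRawData(), intValue.getNumWords()));
}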
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index ba57155..03ed4d5 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -240,8 +240,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
- using Adaptor =
- typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+ using Adaptor = ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
LogicalResult
matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 798d8b0..b75968e 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -137,8 +137,7 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
/// op to llvm.br.
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
- using Adaptor =
- typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
+ using Adaptor = ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
LogicalResult
matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
@@ -163,8 +162,7 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
/// branch op to llvm.cond_br.
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
- using Adaptor =
- typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
+ using Adaptor = ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
LogicalResult
matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
@@ -204,7 +202,7 @@ struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
+ matchAndRewrite(cf::SwitchOp op, cf::SwitchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Get or convert default block.
FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index a2dfc12..a922338 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -68,7 +68,7 @@ struct ClampFOpConversion final
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
- typename math::ClampFOp::Adaptor adaptor(operands);
+ math::ClampFOp::Adaptor adaptor(operands);
return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
adaptor.getValue(), adaptor.getMin(),
adaptor.getMax());
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2e..de552ce 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
if (!valOrResVecTy)
valOrResVecTy = VectorType::get(1, data.getType());
+ if (valOrResVecTy.getShape().size() != 1)
+ return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 4d2d873..3d1a734 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -66,9 +66,10 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
.Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
.Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
.Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
+ .Case([](arith::XOrIOp) { return arith::AtomicRMWKind::xori; })
+ .Case([](arith::MaxNumFOp) { return arith::AtomicRMWKind::maxnumf; })
+ .Case([](arith::MinNumFOp) { return arith::AtomicRMWKind::minnumf; })
.Default([](Operation *) -> std::optional<arith::AtomicRMWKind> {
- // TODO: AtomicRMW supports other kinds of reductions this is
- // currently not detecting, add those when the need arises.
return std::nullopt;
});
if (!maybeKind)
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 0992ce14..d478220 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -584,6 +584,10 @@ void ForOp::print(OpAsmPrinter &p) {
LogicalResult ForOp::verifyRegions() {
// Check that the body defines a single block argument for the induction
// variable.
+ if (getBody()->getNumArguments() != 1)
+ return emitOpError("expected body to have a single block argument for the "
+ "induction variable");
+
if (getInductionVar().getType() != getLowerBound().getType())
return emitOpError(
"expected induction variable to be same type as bounds and step");
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index a5ffb9e..262d9b7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -365,6 +365,59 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Stochastic Rounding Conversion Ops
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConvertF32x2ToF16x2Op::verify() {
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x2 to f16x2.");
+ return success();
+}
+
+LogicalResult ConvertF32x2ToBF16x2Op::verify() {
+ if (getRnd() != FPRoundingMode::RS)
+ return emitOpError("Only RS rounding mode is supported for "
+ "conversions from f32x2 to bf16x2.");
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF8x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f32x4 to f8x4.";
+
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF6x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f32x4 to f6x4.";
+
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF4x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from "
+ "f32x4 to f4x4.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -867,15 +920,40 @@ LogicalResult MmaOp::verify() {
}
LogicalResult ShflOp::verify() {
- if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
- return success();
- auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
- auto elementType = (type && type.getBody().size() == 2)
- ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
- : nullptr;
- if (!elementType || elementType.getWidth() != 1)
- return emitError("expected return type to be a two-element struct with "
- "i1 as the second element");
+ auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
+
+ auto verifyTypeError = [&](Twine desc, Type expectedType,
+ Type actualType) -> LogicalResult {
+ return emitOpError("expected " + desc + " to be of type ")
+ << expectedType << " but got " << actualType << " instead";
+ };
+
+ if (returnStructType) {
+ if (!getReturnValueAndIsValid())
+ return emitOpError("\"return_value_and_is_valid\" attribute must be "
+ "specified when the return type is a struct type");
+
+ if (returnStructType.getBody().size() != 2)
+ return emitOpError("expected return type to be a two-element struct");
+
+ llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
+ auto resultType = returnStruct[0];
+ if (resultType != getVal().getType())
+ return verifyTypeError("first element in the returned struct",
+ getVal().getType(), resultType);
+
+ auto predicateType = returnStruct[1];
+ if (!predicateType.isInteger(1))
+ return verifyTypeError("second element in the returned struct",
+ mlir::IntegerType::get(getContext(), 1),
+ predicateType);
+ } else {
+ if (getReturnValueAndIsValid())
+ return emitOpError("expected return type to be a two-element struct");
+
+ if (getType() != getVal().getType())
+ return verifyTypeError("return type", getVal().getType(), getType());
+ }
return success();
}
@@ -1577,6 +1655,43 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
return success();
}
+LogicalResult NVVM::ReduxOp::verify() {
+ mlir::Type reduxType = getType();
+
+ if (!reduxType.isF32()) {
+ if (getAbs())
+ return emitOpError("abs attribute is supported only for f32 type");
+ if (getNan())
+ return emitOpError("nan attribute is supported only for f32 type");
+ }
+
+ NVVM::ReduxKind kind = getKind();
+ switch (kind) {
+ case NVVM::ReduxKind::ADD:
+ case NVVM::ReduxKind::AND:
+ case NVVM::ReduxKind::OR:
+ case NVVM::ReduxKind::XOR:
+ case NVVM::ReduxKind::MAX:
+ case NVVM::ReduxKind::MIN:
+ case NVVM::ReduxKind::UMAX:
+ case NVVM::ReduxKind::UMIN:
+ if (!reduxType.isInteger(32))
+ return emitOpError("'")
+ << stringifyEnum(kind) << "' redux kind unsupported with "
+ << reduxType << " type. Only supported type is 'i32'.";
+ break;
+ case NVVM::ReduxKind::FMIN:
+ case NVVM::ReduxKind::FMAX:
+ if (!reduxType.isF32())
+ return emitOpError("'")
+ << stringifyEnum(kind) << "' redux kind unsupported with "
+ << reduxType << " type. Only supported type is 'f32'.";
+ break;
+ }
+
+ return success();
+}
+
/// Packs the given `field` into the `result`.
/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
static llvm::Value *
@@ -2469,6 +2584,85 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
+llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
+ bool hasRelu = getRelu();
+ bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+
+ if (hasRelu && hasSatFinite)
+ return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
+ if (hasRelu)
+ return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
+ if (hasSatFinite)
+ return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
+ return llvm::Intrinsic::nvvm_ff2f16x2_rs;
+}
+
+llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
+ bool hasRelu = getRelu();
+ bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+
+ if (hasRelu && hasSatFinite)
+ return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
+ if (hasRelu)
+ return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
+ if (hasSatFinite)
+ return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
+ return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
@@ -2508,6 +2702,9 @@ LogicalResult Tcgen05LdOp::verify() {
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
result = emitError("shape 16x32bx2 requires offset argument");
+ if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
+ result = emitError("offset argument is only supported for shape 16x32bx2");
+
auto resTy = getRes().getType();
unsigned resLen = isa<VectorType>(resTy)
? llvm::cast<VectorType>(resTy).getNumElements()
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f..e271ac5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2568,6 +2568,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
for (int64_t idx : llvm::reverse(trailingReassocs)) {
+ // Dimensions of size 1 should be skipped, because their strides are
+ // meaningless and could have any arbitrary value.
+ if (srcShape[idx - 1] == 1)
+ continue;
+
stride = stride * SaturatedInteger::wrap(srcShape[idx]);
// Both source and result stride must have the same static value. In that
@@ -2582,11 +2587,6 @@ computeCollapsedLayoutMap(MemRefType srcType,
if (strict && (stride.saturated || srcStride.saturated))
return failure();
- // Dimensions of size 1 should be skipped, because their strides are
- // meaningless and could have any arbitrary value.
- if (srcShape[idx - 1] == 1)
- continue;
-
if (!stride.saturated && !srcStride.saturated && stride != srcStride)
return failure();
}
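
The reordering above moves the unit-dimension skip ahead of the stride accumulation and the strict-mode saturation check, so the meaningless stride of a size-1 dimension can no longer fail either check. A simplified standalone sketch of the rule (hypothetical shapes and strides; the real code also tracks dynamic values via SaturatedInteger):

#include <cstdint>
#include <vector>

// Can dims [begin, end) of `shape` be collapsed into one dimension given
// `strides`? Unit dims are skipped: their strides may be arbitrary.
bool collapsible(const std::vector<int64_t> &shape,
                 const std::vector<int64_t> &strides, size_t begin, size_t end) {
  int64_t expected = strides[end - 1];
  for (size_t i = end - 1; i > begin; --i) {
    expected *= shape[i];
    if (shape[i - 1] == 1)
      continue; // a size-1 dim's stride is meaningless; don't compare it
    if (strides[i - 1] != expected)
      return false;
  }
  return true;
}

// e.g. shape {4, 1, 8}, strides {8, 128, 1}: the bogus stride 128 on the
// unit dim is ignored, so the three dims still collapse into one of size 32.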
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 35eba72..b2f1d84 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -3068,8 +3068,12 @@ LogicalResult acc::LoopOp::verify() {
if (getRegion().empty())
return emitError("expected non-empty body.");
- // When it is container-like - it is expected to hold a loop-like operation.
- if (isContainerLike()) {
+ if (getUnstructured()) {
+ if (!isContainerLike())
+ return emitError(
+ "unstructured acc.loop must not have induction variables");
+ } else if (isContainerLike()) {
+ // When it is container-like, it is expected to hold a loop-like operation.
// Obtain the maximum collapse count - we use this to check that there
// are enough loops contained.
uint64_t collapseCount = getCollapseValue().value_or(1);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2946b53..881e256 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2565,6 +2565,39 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
struct ConditionPropagation : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
+ /// Kind of parent region in the ancestor cache.
+ enum class Parent { Then, Else, None };
+
+ /// Returns the kind of region ("then", "else", or "none") of the
+ /// IfOp that the given region is transitively nested in. Updates
+ /// the cache accordingly.
+ static Parent getParentType(Region *toCheck, IfOp op,
+ DenseMap<Region *, Parent> &cache,
+ Region *endRegion) {
+ SmallVector<Region *> seen;
+ while (toCheck != endRegion) {
+ auto found = cache.find(toCheck);
+ if (found != cache.end())
+ return found->second;
+ seen.push_back(toCheck);
+ if (&op.getThenRegion() == toCheck) {
+ for (Region *region : seen)
+ cache[region] = Parent::Then;
+ return Parent::Then;
+ }
+ if (&op.getElseRegion() == toCheck) {
+ for (Region *region : seen)
+ cache[region] = Parent::Else;
+ return Parent::Else;
+ }
+ toCheck = toCheck->getParentRegion();
+ }
+
+ for (Region *region : seen)
+ cache[region] = Parent::None;
+ return Parent::None;
+ }
+
LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
// Early exit if the condition is constant since replacing a constant
@@ -2580,9 +2613,12 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
Value constantTrue = nullptr;
Value constantFalse = nullptr;
+ DenseMap<Region *, Parent> cache;
for (OpOperand &use :
llvm::make_early_inc_range(op.getCondition().getUses())) {
- if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
+ switch (getParentType(use.getOwner()->getParentRegion(), op, cache,
+ op.getCondition().getParentRegion())) {
+ case Parent::Then: {
changed = true;
if (!constantTrue)
@@ -2591,8 +2627,9 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantTrue); });
- } else if (op.getElseRegion().isAncestor(
- use.getOwner()->getParentRegion())) {
+ break;
+ }
+ case Parent::Else: {
changed = true;
if (!constantFalse)
@@ -2601,6 +2638,10 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantFalse); });
+ break;
+ }
+ case Parent::None:
+ break;
}
}
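
The new getParentType helper above turns repeated Region::isAncestor queries (a parent-chain walk per use) into one memoized walk: climb until the cache or one of the two classifying regions answers, then back-fill every region seen on the way. The same pattern in a generic, self-contained form (hypothetical Node type):

#include <unordered_map>
#include <vector>

enum class Kind { Then, Else, None };

// `kind` is Then/Else only on the two classifying roots, None elsewhere.
struct Node {
  Node *parent;
  Kind kind;
};

Kind classify(Node *n, Node *end, std::unordered_map<Node *, Kind> &cache) {
  std::vector<Node *> seen;
  for (; n != end; n = n->parent) {
    auto it = cache.find(n);
    if (it != cache.end())
      return it->second;
    seen.push_back(n);
    if (n->kind != Kind::None) {
      for (Node *s : seen)
        cache[s] = n->kind; // back-fill so later queries stop early
      return n->kind;
    }
  }
  for (Node *s : seen)
    cache[s] = Kind::None;
  return Kind::None;
}

With the cache shared across all uses of the condition, the total work is linear in the number of distinct regions visited rather than quadratic.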
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 7a26cd3..1fbcf5f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1050,7 +1050,7 @@ public:
/// Sparse codegen rule for position accesses.
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
public:
- using OpAdaptor = typename ToPositionsOp::Adaptor;
+ using OpAdaptor = ToPositionsOp::Adaptor;
using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
@@ -1073,7 +1073,7 @@ public:
class SparseToCoordinatesConverter
: public OpConversionPattern<ToCoordinatesOp> {
public:
- using OpAdaptor = typename ToCoordinatesOp::Adaptor;
+ using OpAdaptor = ToCoordinatesOp::Adaptor;
using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
@@ -1099,7 +1099,7 @@ public:
class SparseToCoordinatesBufferConverter
: public OpConversionPattern<ToCoordinatesBufferOp> {
public:
- using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
+ using OpAdaptor = ToCoordinatesBufferOp::Adaptor;
using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
@@ -1121,7 +1121,7 @@ public:
/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
- using OpAdaptor = typename ToValuesOp::Adaptor;
+ using OpAdaptor = ToValuesOp::Adaptor;
using OpConversionPattern<ToValuesOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae3423c..daef0ba 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -717,7 +717,15 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
case arith::AtomicRMWKind::ori:
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::OR, vector);
- // TODO: Add remaining reduction operations.
+ case arith::AtomicRMWKind::minnumf:
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MINNUMF, vector);
+ case arith::AtomicRMWKind::maxnumf:
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MAXNUMF, vector);
+ case arith::AtomicRMWKind::xori:
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::XOR, vector);
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
break;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 83406c8..397107b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -37,55 +37,61 @@ void XeGPUDialect::initialize() {
>();
}
-/// Generates instructions to compute offsets for a subgroup identified by
-/// its multidimensional indices (sgId), using the specified subgroup layout
-/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
-/// dimensions (sizePerWg).
+// A `srcShape` consists of N distribution units, each of shape
+// `subShapesLayout` x `subShape`. A `delinearizedId` identifies a particular
+// `subShape` within each distribution unit.
+// Example:
+// WG data is 128x256 and SG data is 16x32 in a 4x2 layout; this gives a
+// distribution unit of shape 64x64, and there are 2x4 such distribution units.
+// `delinearizedId` identifies the 16x32 tile of a subgroup in each
+// distribution unit.
static SmallVector<SmallVector<Value>>
-genOffsetsComputingInsts(OpBuilder &builder, Location loc,
- SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
- ArrayRef<int64_t> sizePerSg,
- ArrayRef<int64_t> sizePerWg) {
-
- SmallVector<SmallVector<Value>> offsets;
+genCoordinates(OpBuilder &builder, Location loc,
+ SmallVector<Value> delinearizedId,
+ ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
+ ArrayRef<int64_t> srcShape) {
+ SmallVector<SmallVector<Value>> coordinates;
+
+ // A distribution unit must be less than or equal to `srcShape`
+ SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
+ llvm::zip_equal(srcShape,
+ computeElementwiseMul(subShapesLayout, subShape)),
+ [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
- // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
- SmallVector<Value> localOffsets = llvm::map_to_vector(
- llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
+ // Get the offset of `subShape` within a distribution unit.
+ SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
+ llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
return builder.createOrFold<index::MulOp>(
loc, std::get<0>(t),
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});
- // distUnit[i] is the minimum value between sizePerWg[i] and
- // sgLayout[i] * sizePerSg[i]
- SmallVector<int64_t> distUnit = llvm::map_to_vector(
- llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
- [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
-
+ // For each dist unit
for (SmallVector<int64_t> unitOffs :
- StaticTileOffsetRange(sizePerWg, distUnit)) {
+ StaticTileOffsetRange(srcShape, distUnitShape)) {
+ // Get dist unit offset within `srcShape`.
SmallVector<Value> base =
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
return arith::ConstantIndexOp::create(builder, loc, d);
});
-
- SmallVector<Value> adds = llvm::map_to_vector(
- llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
- return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
- std::get<1>(t));
- });
-
+ // Calculate `subShape` offset within `srcShape`.
+ SmallVector<Value> adds =
+ llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
+ [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::AddIOp>(
+ loc, std::get<0>(t), std::get<1>(t));
+ });
+ // Do not go beyond `srcShape` bounds.
SmallVector<Value> mods = llvm::map_to_vector(
- llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
+ llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
return builder.createOrFold<index::RemUOp>(
loc, std::get<0>(t),
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
});
- offsets.push_back(mods);
+ coordinates.push_back(mods);
}
- return offsets;
+ return coordinates;
}
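
A self-contained arithmetic sketch of the example in the comment above (WG 128x256, SG tile 16x32 in a 4x2 layout, so a 64x64 distribution unit and 2x4 units); the real genCoordinates emits arith/index ops instead of computing numbers, and the subgroup id (3, 1) here is a sample:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t srcShape[2] = {128, 256}, layout[2] = {4, 2},
                subShape[2] = {16, 32}, sgId[2] = {3, 1};
  // Distribution unit: min(srcShape, layout * subShape) per dim => 64x64.
  const int64_t unit[2] = {std::min(srcShape[0], layout[0] * subShape[0]),
                           std::min(srcShape[1], layout[1] * subShape[1])};
  for (int64_t base0 = 0; base0 < srcShape[0]; base0 += unit[0])
    for (int64_t base1 = 0; base1 < srcShape[1]; base1 += unit[1]) {
      // Tile coordinate = unit base + sgId * subShape, wrapped into srcShape.
      int64_t c0 = (base0 + sgId[0] * subShape[0]) % srcShape[0];
      int64_t c1 = (base1 + sgId[1] * subShape[1]) % srcShape[1];
      std::printf("(%lld, %lld)\n", (long long)c0, (long long)c1);
    }
  return 0;
}

This prints one coordinate per distribution unit, eight in total, matching the 2x4 units in the comment.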
// Checks if the given shape can be evenly distributed based on the layout
@@ -272,12 +278,7 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}
FailureOr<SmallVector<Value>>
-LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
- // delinearizeSubgroupId is only available for
- // workgroup-level layout attribute
- if (!isForWorkgroup())
- return failure();
+LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
// TODO: handle order attribute
auto hasDefaultOrder = [&]() {
@@ -287,41 +288,52 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
};
if (!hasDefaultOrder())
return mlir::emitError(loc, "order attribute is currently not supported.");
-
- auto dims =
- llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
- return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
- });
+ SmallVector<int64_t> layout;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ } else {
+ return failure();
+ }
+ auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
+ return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+ });
return affine::delinearizeIndex(builder, loc, linearId, dims);
}
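
delinearizeId now serves both levels: it picks the subgroup layout for workgroup-level attributes and the lane layout for subgroup-level ones, then defers to affine::delinearizeIndex. For a row-major layout that call is the usual mixed-radix decomposition; a plain sketch:

#include <cstdint>
#include <vector>

// Decompose `linearId` into per-dimension ids for a row-major `dims`
// layout (innermost dimension varies fastest).
std::vector<int64_t> delinearize(int64_t linearId,
                                 const std::vector<int64_t> &dims) {
  std::vector<int64_t> ids(dims.size());
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    ids[i] = linearId % dims[i];
    linearId /= dims[i];
  }
  return ids;
}

// e.g. linearId 5 with sg_layout {4, 2} yields ids {2, 1}.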
-/// Implements DistributeLayoutAttr::getOffsets to generate
+/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
- if (!isForWorkgroup())
+LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape) {
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
return failure();
-
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ }
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
- SmallVector<Value> sgIds = *maybeIds;
+ SmallVector<Value> ids = *maybeIds;
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ return genCoordinates(builder, loc, ids, layout, subShape, shape);
}
//===----------------------------------------------------------------------===//
@@ -375,34 +387,43 @@ SliceAttr SliceAttr::flatten() const {
}
FailureOr<SmallVector<Value>>
-SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
+SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
- return parent.delinearizeSubgroupId(builder, loc, linearId);
+ return parent.delinearizeId(builder, loc, linearId);
}
-/// Implements DistributeLayoutAttr::getOffsets to generate
-/// instructions for computing multi-dimensional offsets when distributed by
-/// SliceAttr.
+// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+// instructions for computing multi-dimensional offsets when distributed by
+// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
+SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape) {
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
if (!isForWorkgroup())
return failure();
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
+ return failure();
+ }
+
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
@@ -412,8 +433,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
SmallVector<Value> sgIds =
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
}
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2..7b6c4b6 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
- UnitAttr subgroup_block_io,
+ UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
function_ref<InFlightDiagnostic()> emitError) {
if (!dataTy) {
if (subgroup_block_io)
return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
+ "are only allowed when result is a VectorType.";
else
return success();
}
@@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
ArrayRef<int64_t> dataShape = dataTy.getShape();
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
+ ArrayAttr strideAttr = mdescTy.getStrideAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+ if (subgroup_block_io && layout) {
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ if (!laneData.empty()) {
+ bool isLaneDataContiguous =
+ std::all_of(laneData.begin(), std::prev(laneData.end()),
+ [](int x) { return x == 1; });
+ if (!isLaneDataContiguous)
+ return emitError() << "With subgroup_block_io, accessed data must be "
+ "contiguous and coalesced.";
+ for (size_t i = 0; i < laneData.size(); ++i) {
+ if (laneLayout[i] != blockShape[i])
+ return emitError() << "With subgroup_block_io, the block shape must "
+ "match the lane layout.";
+ if (laneLayout[i] != 1 && strides[i] != 1)
+ return emitError() << "With subgroup_block_io, the distributed "
+ "dimensions must be contiguous.";
+ }
+ }
+ }
if (dataShape.size() == 2) {
- if (subgroup_block_io)
- return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitError() << "data shape must not exceed mem_desc shape.";
} else {
- SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
// if the subgroup_block_io attribute is set, mdescTy must have block
// attribute
if (subgroup_block_io && !blockShape.size())
@@ -1105,7 +1127,7 @@ LogicalResult LoadMatrixOp::verify() {
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ getLayoutAttr(), [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
@@ -1129,7 +1151,7 @@ LogicalResult StoreMatrixOp::verify() {
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ getLayoutAttr(), [&]() { return emitError(); });
}
namespace mlir {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 5a3b27e..bbd7733 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
@@ -912,6 +913,186 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
}
};
+static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
+ PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
+ Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
+ SmallVector<Value> newCoords;
+ auto maybeCoords =
+ layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
+ if (failed(maybeCoords))
+ return {};
+ assert(maybeCoords.value().size() == 1 &&
+ "Expected one set of distributed offsets");
+ SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+ rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
+ getAsOpFoldResult(origOffsets));
+ newCoords = llvm::to_vector(llvm::map_range(
+ ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+ return newCoords;
+}
+
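
The helper above folds the layout-derived per-lane coordinates into the op's original offsets via xegpu::addWithRightAligned. Assuming "right-aligned" means the shorter list is added element-wise onto the trailing dimensions of the longer one (an assumption about the helper's convention; the real code operates on OpFoldResults), the combination looks like:

#include <cstdint>
#include <vector>

// Hypothetical scalar model of a right-aligned add: requires
// lhs.size() >= rhs.size(); rhs lands on the trailing dims of lhs.
std::vector<int64_t> addRightAligned(std::vector<int64_t> lhs,
                                     const std::vector<int64_t> &rhs) {
  const size_t off = lhs.size() - rhs.size();
  for (size_t i = 0; i < rhs.size(); ++i)
    lhs[off + i] += rhs[i];
  return lhs;
}

// e.g. offsets {o0, o1, o2} plus distributed coords {c0, c1}
// combine to {o0, o1 + c0, o2 + c1}.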
+/// Pattern for distributing xegpu::LoadMatrixOp.
+struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp yield = warpOp.getTerminator();
+ Operation *lastNode = yield->getPrevNode();
+ auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
+ if (!matrixOp)
+ return failure();
+
+ OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+ return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+ });
+ if (!producedByLastLoad)
+ return rewriter.notifyMatchFailure(
+ warpOp, "The last op is not xegpu::LoadMatrixOp");
+ const int operandIdx = producedByLastLoad->getOperandNumber();
+
+ VectorType sgPayloadTy =
+ dyn_cast<VectorType>(matrixOp.getResult().getType());
+ VectorType warpResultTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ if (!sgPayloadTy)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix op payload must be a vector type");
+
+ auto loc = matrixOp.getLoc();
+ auto offsets = matrixOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(matrixOp,
+ "the load op must have offsets");
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+ auto layout = matrixOp.getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix operation lacks layout attribute");
+
+ FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+ if (failed(distPayloadByWarpOpOrFailure))
+ return rewriter.notifyMatchFailure(
+ matrixOp, "Failed to distribute matrix op payload based on layout.");
+
+ SmallVector<Value> operands = {matrixOp.getMemDesc()};
+ const unsigned offsetsStartIdx = operands.size();
+ operands.append(offsetsAsValues);
+
+ SmallVector<Type> operandTypes = llvm::to_vector(
+ llvm::map_range(operands, [](Value v) { return v.getType(); }));
+
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+ std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+ ShapedType::kDynamic);
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+ ValueRange currentOffsets =
+ ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ SmallVector<Value> newCoords = currentOffsets;
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ if (!matrixOp.getSubgroupBlockIoAttr()) {
+ newCoords = computeDistributedCoordinatesForMatrixOp(
+ rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+ currentOffsets);
+ }
+ xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+ rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+ newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ // Resolve the output type and replace all uses.
+ rewriter.replaceAllUsesWith(
+ newWarpOp.getResult(operandIdx),
+ resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+ return success();
+ }
+};
+
+/// Pattern for distributing xegpu::StoreMatrixOp.
+struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp yield = warpOp.getTerminator();
+ Operation *lastNode = yield->getPrevNode();
+ auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
+ if (!matrixOp)
+ return failure();
+
+ VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+ if (!sgPayloadTy)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix op payload must be a vector type");
+
+ auto loc = matrixOp.getLoc();
+ auto offsets = matrixOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(matrixOp,
+ "the store op must have offsets");
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+ auto layout = matrixOp.getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix operation lacks layout attribute");
+
+ FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+ if (failed(distPayloadByWarpOpOrFailure))
+ return rewriter.notifyMatchFailure(
+ matrixOp, "Failed to distribute matrix op payload based on layout.");
+
+ SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+ const unsigned offsetsStartIdx = operands.size();
+ operands.append(offsetsAsValues);
+
+ SmallVector<Type> operandTypes = llvm::to_vector(
+ llvm::map_range(operands, [](Value v) { return v.getType(); }));
+ operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+ std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+ ShapedType::kDynamic);
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+ ValueRange currentOffsets =
+ ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ SmallVector<Value> newCoords = currentOffsets;
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ if (!matrixOp.getSubgroupBlockIoAttr()) {
+ newCoords = computeDistributedCoordinatesForMatrixOp(
+ rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+ currentOffsets);
+ }
+
+ xegpu::StoreMatrixOp::create(
+ rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+ ValueRange(newCoords), newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ rewriter.eraseOp(matrixOp);
+ return success();
+ }
+};
+
/// Distribute a scattered load op. The logic and requirements are the same as
/// for the scattered store distribution. The warpOp's payload vector is
/// expected to be distributed by the load's result consumer.
@@ -1443,7 +1624,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
GpuBarrierDistribution, VectorMultiReductionDistribution,
LoadDistribution, StoreDistribution, VectorTransposeDistribution,
- VectorBitcastDistribution,
+ VectorBitcastDistribution, LoadMatrixDistribution,
+ StoreMatrixDistribution,
MemrefExtractAlignedPointerAsIndexDistribution>(
patterns.getContext(),
/*pattern benefit=*/regularPatternBenefit);
@@ -1468,6 +1650,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Layouts are needed for vector type only.
if (!isa<VectorType>(operand.get().getType()))
continue;
+ if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+ continue;
auto layout = xegpu::getDistributeLayoutAttr(operand.get());
if (!layout) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9..79eea55 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -114,7 +114,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
// Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
// descriptors to be accessed, based on the layout information.
ArrayRef<int64_t> wgShape = op.getDataShape();
- auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto maybeDescOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(maybeDescOffsets))
return failure();
@@ -830,8 +831,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
// Get subgroup id
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
- auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();
@@ -1052,7 +1053,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index ce421f4..8212d6d 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -463,28 +463,26 @@ void Operation::updateOrderIfNecessary() {
//===----------------------------------------------------------------------===//
auto llvm::ilist_detail::SpecificNodeAccess<
- typename llvm::ilist_detail::compute_node_options<
- ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
+ llvm::ilist_detail::compute_node_options<::mlir::Operation>::type>::
+ getNodePtr(pointer n) -> node_type * {
return NodeAccess::getNodePtr<OptionsT>(n);
}
auto llvm::ilist_detail::SpecificNodeAccess<
- typename llvm::ilist_detail::compute_node_options<
- ::mlir::Operation>::type>::getNodePtr(const_pointer n)
- -> const node_type * {
+ llvm::ilist_detail::compute_node_options<::mlir::Operation>::type>::
+ getNodePtr(const_pointer n) -> const node_type * {
return NodeAccess::getNodePtr<OptionsT>(n);
}
auto llvm::ilist_detail::SpecificNodeAccess<
- typename llvm::ilist_detail::compute_node_options<
- ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
+ llvm::ilist_detail::compute_node_options<::mlir::Operation>::type>::
+ getValuePtr(node_type *n) -> pointer {
return NodeAccess::getValuePtr<OptionsT>(n);
}
auto llvm::ilist_detail::SpecificNodeAccess<
- typename llvm::ilist_detail::compute_node_options<
- ::mlir::Operation>::type>::getValuePtr(const node_type *n)
- -> const_pointer {
+ llvm::ilist_detail::compute_node_options<::mlir::Operation>::type>::
+ getValuePtr(const node_type *n) -> const_pointer {
return NodeAccess::getValuePtr<OptionsT>(n);
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 3d86b09..0964e1b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -36,9 +36,6 @@ using mlir::LLVM::detail::createIntrinsicCall;
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
NVVM::ReduxKind kind,
bool hasAbs, bool hasNaN) {
- if (!(resultType->isIntegerTy(32) || resultType->isFloatTy()))
- llvm_unreachable("unsupported data type for redux");
-
switch (kind) {
case NVVM::ReduxKind::ADD:
return llvm::Intrinsic::nvvm_redux_sync_add;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3a23bbf..2fe0697 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// A set of operations that were modified by the current pattern.
SetVector<Operation *> patternModifiedOps;
- /// A set of blocks that were inserted (newly-created blocks or moved blocks)
- /// by the current pattern.
- SetVector<Block *> patternInsertedBlocks;
-
/// A list of unresolved materializations that were created by the current
/// pattern.
DenseSet<UnrealizedConversionCastOp> patternMaterializations;
@@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
if (!config.allowPatternRollback && config.listener)
config.listener->notifyBlockInserted(block, previous, previousIt);
- patternInsertedBlocks.insert(block);
-
if (wasDetached) {
// If the block was detached, it is most likely a newly created block.
if (config.allowPatternRollback) {
@@ -2399,17 +2393,12 @@ private:
bool canApplyPattern(Operation *op, const Pattern &pattern);
/// Legalize the resultant IR after successfully applying the given pattern.
- LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
- const RewriterState &curState,
- const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks);
-
- /// Legalizes the actions registered during the execution of a pattern.
LogicalResult
- legalizePatternBlockRewrites(Operation *op,
- const SetVector<Block *> &insertedBlocks,
- const SetVector<Operation *> &newOps);
+ legalizePatternResult(Operation *op, const Pattern &pattern,
+ const RewriterState &curState,
+ const SetVector<Operation *> &newOps,
+ const SetVector<Operation *> &modifiedOps);
+
LogicalResult
legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
@@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto cleanup = llvm::make_scope_exit([&]() {
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
- rewriterImpl.patternInsertedBlocks.clear();
});
// Upon failure, undo all changes made by the folder.
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
static void
reportNewIrLegalizationFatalError(const Pattern &pattern,
const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks) {
+ const SetVector<Operation *> &modifiedOps) {
auto newOpNames = llvm::map_range(
newOps, [](Operation *op) { return op->getName().getStringRef(); });
auto modifiedOpNames = llvm::map_range(
modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
- StringRef detachedBlockStr = "(detached block)";
- auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
- if (block->getParentOp())
- return block->getParentOp()->getName().getStringRef();
- return detachedBlockStr;
- });
- llvm::report_fatal_error(
- "pattern '" + pattern.getDebugName() +
- "' produced IR that could not be legalized. " + "new ops: {" +
- llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
- llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
- llvm::join(insertedBlockNames, ", ") + "}");
+ llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
+ "' produced IR that could not be legalized. " +
+ "new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
+ "modified ops: {" +
+ llvm::join(modifiedOpNames, ", ") + "}");
}
LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
@@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
}
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
- rewriterImpl.patternInsertedBlocks.clear();
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
if (rewriterImpl.config.notifyCallback) {
@@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
SetVector<Operation *> modifiedOps =
moveAndReset(rewriterImpl.patternModifiedOps);
- SetVector<Block *> insertedBlocks =
- moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, curState, newOps,
- modifiedOps, insertedBlocks);
+ auto result =
+ legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
appliedPatterns.erase(&pattern);
if (failed(result)) {
if (!rewriterImpl.config.allowPatternRollback)
- reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
- insertedBlocks);
+ reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
rewriterImpl.resetState(curState, pattern.getDebugName());
}
if (config.listener)
@@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
LogicalResult OperationLegalizer::legalizePatternResult(
Operation *op, const Pattern &pattern, const RewriterState &curState,
const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks) {
+ const SetVector<Operation *> &modifiedOps) {
[[maybe_unused]] auto &impl = rewriter.getImpl();
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
@@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
- failed(legalizePatternRootUpdates(modifiedOps)) ||
+ if (failed(legalizePatternRootUpdates(modifiedOps)) ||
failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(
return success();
}
-LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
- Operation *op, const SetVector<Block *> &insertedBlocks,
- const SetVector<Operation *> &newOps) {
- ConversionPatternRewriterImpl &impl = rewriter.getImpl();
- SmallPtrSet<Operation *, 16> alreadyLegalized;
-
- // If the pattern moved or created any blocks, make sure the types of block
- // arguments get legalized.
- for (Block *block : insertedBlocks) {
- if (impl.erasedBlocks.contains(block))
- continue;
-
- // Only check blocks outside of the current operation.
- Operation *parentOp = block->getParentOp();
- if (!parentOp || parentOp == op || block->getNumArguments() == 0)
- continue;
-
- // If the region of the block has a type converter, try to convert the block
- // directly.
- if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
- std::optional<TypeConverter::SignatureConversion> conversion =
- converter->convertBlockSignature(block);
- if (!conversion) {
- LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
- "block"));
- return failure();
- }
- impl.applySignatureConversion(block, converter, *conversion);
- continue;
- }
-
- // Otherwise, try to legalize the parent operation if it was not generated
- // by this pattern. This is because we will attempt to legalize the parent
- // operation, and blocks in regions created by this pattern will already be
- // legalized later on.
- if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
- if (failed(legalize(parentOp))) {
- LLVM_DEBUG(logFailure(
- impl.logger, "operation '{0}'({1}) became illegal after rewrite",
- parentOp->getName(), parentOp));
- return failure();
- }
- }
- }
- return success();
-}
-
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
const SetVector<Operation *> &newOps) {
for (Operation *op : newOps) {
@@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
TypeConverter::SignatureConversion result(type.getNumInputs());
SmallVector<Type, 1> newResults;
if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
- failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
- failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
- typeConverter, &result)))
+ failed(typeConverter.convertTypes(type.getResults(), newResults)))
return failure();
+ if (!funcOp.getFunctionBody().empty())
+ rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
+ &typeConverter);
// Update the function signature in-place.
auto newType = FunctionType::get(rewriter.getContext(),