aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/XeGPU
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/XeGPU')
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp7
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp7
-rw-r--r--mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp2
3 files changed, 13 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 24e9095..f9aa28d5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -113,9 +113,12 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
if (layout.size() != shape.size())
return std::nullopt;
auto ratio = computeShapeRatio(shape, layout);
- if (!ratio.has_value())
+ if (ratio.has_value()) {
+ newShape = ratio.value();
+ } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
return std::nullopt;
- newShape = ratio.value();
+ }
+ // Round-robin case: continue with original newShape
}
if (data.size()) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 2c37140..ec5feb8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -344,6 +344,13 @@ void XeGPUBlockingPass::runOnOperation() {
xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
+ // Remove leading unit dimensions from vector ops and then
+ // do the unrolling.
+ {
+ RewritePatternSet patterns(ctx);
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ (void)applyPatternsGreedily(op, std::move(patterns));
+ }
xegpu::UnrollOptions options;
options.setFilterConstraint(
[&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index b4605cd..a38993e 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -147,7 +147,7 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
- auto parentOp = arg.getOwner()->getParentOp();
+ auto *parentOp = arg.getOwner()->getParentOp();
if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
if (tiedInit)