diff options
Diffstat (limited to 'mlir/lib/Dialect/XeGPU/IR')
| -rw-r--r-- | mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 24e9095..83406c8 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -11,7 +11,6 @@ #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" @@ -113,9 +112,12 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, if (layout.size() != shape.size()) return std::nullopt; auto ratio = computeShapeRatio(shape, layout); - if (!ratio.has_value()) + if (ratio.has_value()) { + newShape = ratio.value(); + } else if (!rr || !computeShapeRatio(layout, shape).has_value()) { return std::nullopt; - newShape = ratio.value(); + } + // Round-robin case: continue with original newShape } if (data.size()) { @@ -226,8 +228,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, } if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) { - return emitError() - << "expected inst_data and lane_layout to have the same rank"; + return emitError() << "expected inst_data and lane_layout to have the same " + "rank, got inst_data " + << inst_data.size() << ", lane_layout " + << lane_layout.size(); } // sg_data is optional for Workgroup layout, but its presence requires @@ -566,8 +570,8 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError, // for gather and scatter ops, Low-precision types are packed in 32-bit units. unsigned bitWidth = elementType.getIntOrFloatBitWidth(); int chunkAlignmentFactor = - bitWidth < targetinfo::packedSizeInBitsForGatherScatter - ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth + bitWidth < xegpu::uArch::generalPackedFormatBitSize + ? xegpu::uArch::generalPackedFormatBitSize / bitWidth : 1; auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding); if (scatterAttr) { |
