Diffstat (limited to 'mlir/lib/Dialect/XeGPU/IR')
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp  390
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp       137
2 files changed, 400 insertions, 127 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index f9aa28d5..1a19ab5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -8,10 +8,8 @@ #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" @@ -38,55 +36,61 @@ void XeGPUDialect::initialize() { >(); } -/// Generates instructions to compute offsets for a subgroup identified by -/// its multidimensional indices (sgId), using the specified subgroup layout -/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data -/// dimensions (sizePerWg). +// A `srcShape` consists of N distribution units, each being `subShapesLayout` x +// `subShape`. A `delinearizedId` is used to identify a particular `subShape` +// within each distribution unit. +// Example: +// WG data is 128x256. SG data is 16x32, in 4x2 layout, this gives a +// distribution unit of shape 64x64, we have 2x4 such distribution units. +// `delinearizedId` is used to identify a 16x32 of a subgroup in each +// distribution unit. static SmallVector<SmallVector<Value>> -genOffsetsComputingInsts(OpBuilder &builder, Location loc, - SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout, - ArrayRef<int64_t> sizePerSg, - ArrayRef<int64_t> sizePerWg) { - - SmallVector<SmallVector<Value>> offsets; +genCoordinates(OpBuilder &builder, Location loc, + SmallVector<Value> delinearizedId, + ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape, + ArrayRef<int64_t> srcShape) { + SmallVector<SmallVector<Value>> coordinates; + + // A distribution unit must be less than or equal to `srcShape` + SmallVector<int64_t> distUnitShape = llvm::map_to_vector( + llvm::zip_equal(srcShape, + computeElementwiseMul(subShapesLayout, subShape)), + [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); - // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i] - SmallVector<Value> localOffsets = llvm::map_to_vector( - llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value { - return builder.createOrFold<index::MulOp>( + // Get the offset of `subShape` within a distribution unit. + SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector( + llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value { + return builder.createOrFold<arith::MulIOp>( loc, std::get<0>(t), builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t))); }); - // distUnit[i] is the minimum value between sizePerWg[i] and - // sgLayout[i] * sizePerSg[i] - SmallVector<int64_t> distUnit = llvm::map_to_vector( - llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)), - [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); - + // For each dist unit for (SmallVector<int64_t> unitOffs : - StaticTileOffsetRange(sizePerWg, distUnit)) { + StaticTileOffsetRange(srcShape, distUnitShape)) { + // Get dist unit offset within `srcShape`. 
SmallVector<Value> base = llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value { return arith::ConstantIndexOp::create(builder, loc, d); }); - - SmallVector<Value> adds = llvm::map_to_vector( - llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value { - return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t), - std::get<1>(t)); - }); - + // Calculate `subShape` offset within `srcShape`. + SmallVector<Value> adds = + llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset), + [&](const auto &t) -> Value { + return builder.createOrFold<arith::AddIOp>( + loc, std::get<0>(t), std::get<1>(t)); + }); + // Do not go beyond `srcShape` bounds. SmallVector<Value> mods = llvm::map_to_vector( - llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value { - return builder.createOrFold<index::RemUOp>( + llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value { + return builder.createOrFold<arith::RemUIOp>( loc, std::get<0>(t), arith::ConstantIndexOp::create(builder, loc, std::get<1>(t))); }); - offsets.push_back(mods); + coordinates.push_back(mods); } - return offsets; + return coordinates; } // Checks if the given shape can be evenly distributed based on the layout @@ -229,8 +233,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, } if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) { - return emitError() - << "expected inst_data and lane_layout to have the same rank"; + return emitError() << "expected inst_data and lane_layout to have the same " + "rank, got inst_data " + << inst_data.size() << ", lane_layout " + << lane_layout.size(); } // sg_data is optional for Workgroup layout, but its presence requires @@ -271,56 +277,197 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, } FailureOr<SmallVector<Value>> -LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, - Value linearId) { - // delinearizeSubgroupId is only available for - // workgroup-level layout attribute - if (!isForWorkgroup()) +LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { + + SmallVector<int64_t> sgLayoutInt; + if (isForWorkgroup()) { + sgLayoutInt = getEffectiveSgLayoutAsInt(); + } else if (isForSubgroup()) { + sgLayoutInt = getEffectiveLaneLayoutAsInt(); + } else { return failure(); + } - // TODO: handle order attribute - auto hasDefaultOrder = [&]() { - DenseI32ArrayAttr order = getOrder(); - return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>( - llvm::reverse(order.asArrayRef()))); - }; - if (!hasDefaultOrder()) - return mlir::emitError(loc, "order attribute is currently not supported."); + DenseI32ArrayAttr orderAttr = getOrder(); - auto dims = - llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value { - return builder.createOrFold<arith::ConstantIndexOp>(loc, d); - }); + // Handle order attribute + SmallVector<int64_t> order; + if (orderAttr && !orderAttr.empty()) { + order = llvm::to_vector( + llvm::map_range(orderAttr.asArrayRef(), + [](int32_t idx) { return static_cast<int64_t>(idx); })); + } else { + // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc. 
+ order = llvm::to_vector( + llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size()))); + } - return affine::delinearizeIndex(builder, loc, linearId, dims); + if (order.size() != sgLayoutInt.size()) { + return failure(); + } + + SmallVector<Value> result(sgLayoutInt.size()); + Value remaining = linearId; + + /// Process dimensions in the order they appear in the order array + /// The first dimension in order is the fastest-changing + /// + /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]: + /// + /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx], + /// result=[?,?,?] + /// + /// i=0 (process columns, dimIdx=2, dimSize=4): + /// result[2] = 22 % 4 = 2 (column coordinate) + /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed) + /// + /// i=1 (process rows, dimIdx=1, dimSize=4): + /// result[1] = 5 % 4 = 1 (row coordinate) + /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed) + /// + /// i=2 (process layers, dimIdx=0, dimSize=2): + /// result[0] = 1 % 2 = 1 (layer coordinate) + /// (no remaining update - last iteration) + /// + /// Final result: [1,1,2] = Layer 1, Row 1, Column 2 + for (size_t i = 0; i < order.size(); ++i) { + int64_t dimIdx = order[i]; + int64_t dimSize = sgLayoutInt[dimIdx]; + + Value dimSizeVal = + builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize); + + /// Extract the coordinate for this dimension using modulo operation + /// This gives us "how far within this dimension" we are + /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within + /// this dimension) + result[dimIdx] = + builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal); + + /// Update remaining for the next dimension by removing what we've already + /// processed. Division tells us "how many complete groups of this dimension + /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've + /// completed 5 groups of 4) Skip this for the last iteration since there's + /// no next dimension to process + if (i < order.size() - 1) { + remaining = + builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal); + } + } + return result; } -/// Implements DistributeLayoutAttr::getOffsets to generate +/// Implements DistributeLayoutAttr::computeDistributedCoords to generate /// instructions for computing multi-dimensional offsets when distributed by /// LayoutAttr. 
FailureOr<SmallVector<SmallVector<Value>>> -LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, - ArrayRef<int64_t> shape) { - if (!isForWorkgroup()) +LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc, + Value linearId, ArrayRef<int64_t> shape) { + SmallVector<int64_t> layout; + SmallVector<int64_t> subShape; + if (isForWorkgroup()) { + layout = getEffectiveSgLayoutAsInt(); + subShape = getEffectiveSgDataAsInt(); + } else if (isForSubgroup()) { + layout = getEffectiveLaneLayoutAsInt(); + subShape = getEffectiveLaneDataAsInt(); + } else { return failure(); - - SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); - SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); - if (sgShape.empty()) { - if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); + } + if (subShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, layout)) + subShape = derivedShape.value(); else return failure(); } // delinearize Ids - auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + auto maybeIds = delinearizeId(builder, loc, linearId); if (failed(maybeIds)) return failure(); - SmallVector<Value> sgIds = *maybeIds; + SmallVector<Value> ids = *maybeIds; + + return genCoordinates(builder, loc, ids, layout, subShape, shape); +} + +bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { + if (dyn_cast<xegpu::SliceAttr>(other)) + return false; + + return *this == dyn_cast<xegpu::LayoutAttr>(other); +} + +// set the layout for unit dims: sg_data, inst_data and lane_data to 1 +DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) { + auto sgDataOpt = getSgData(); + auto instDataOpt = getInstData(); + auto laneDataOpt = getLaneData(); + + SmallVector<int32_t> sgData; + SmallVector<int32_t> instData; + SmallVector<int32_t> laneData; + + if (sgDataOpt) { + sgData = llvm::to_vector(sgDataOpt.asArrayRef()); + } + if (instDataOpt) { + instData = llvm::to_vector(instDataOpt.asArrayRef()); + } + if (laneDataOpt) { + laneData = llvm::to_vector(laneDataOpt.asArrayRef()); + } - return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, - shape); + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(sgData.size())) + sgData[dim] = 1; + if (dim < static_cast<int64_t>(instData.size())) + instData[dim] = 1; + if (dim < static_cast<int64_t>(laneData.size())) + laneData[dim] = 1; + } + + return LayoutAttr::get( + getContext(), getSgLayout(), + sgData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), sgData), + instData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), instData), + getLaneLayout(), + laneData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), laneData), + getOrder()); +} + +// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 +DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) { + auto sgLayoutOpt = getSgLayout(); + auto laneLayoutOpt = getLaneLayout(); + + SmallVector<int32_t> sgLayout; + SmallVector<int32_t> laneLayout; + + if (sgLayoutOpt) { + sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef()); + } + if (laneLayoutOpt) { + laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef()); + } + + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(sgLayout.size())) + sgLayout[dim] = 1; + if (dim < static_cast<int64_t>(laneLayout.size())) + laneLayout[dim] = 1; + } + + return LayoutAttr::get( + getContext(), + sgLayout.empty() ? 
DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), sgLayout), + getSgData(), getInstData(), + laneLayout.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), laneLayout), + getLaneData(), getOrder()); } //===----------------------------------------------------------------------===// @@ -374,34 +521,43 @@ SliceAttr SliceAttr::flatten() const { } FailureOr<SmallVector<Value>> -SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, - Value linearId) { +SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { SliceAttr attr = flatten(); auto parent = dyn_cast<LayoutAttr>(attr.getParent()); - return parent.delinearizeSubgroupId(builder, loc, linearId); + return parent.delinearizeId(builder, loc, linearId); } -/// Implements DistributeLayoutAttr::getOffsets to generate -/// instructions for computing multi-dimensional offsets when distributed by -/// SliceAttr. +// Implements DistributeLayoutAttr::computeDistributedCoords to generate +// instructions for computing multi-dimensional offsets when distributed by +// LayoutAttr. FailureOr<SmallVector<SmallVector<Value>>> -SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, - ArrayRef<int64_t> shape) { +SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc, + Value linearId, ArrayRef<int64_t> shape) { assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape."); if (!isForWorkgroup()) return failure(); - SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); - SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); - if (sgShape.empty()) { - if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); + SmallVector<int64_t> layout; + SmallVector<int64_t> subShape; + if (isForWorkgroup()) { + layout = getEffectiveSgLayoutAsInt(); + subShape = getEffectiveSgDataAsInt(); + } else if (isForSubgroup()) { + layout = getEffectiveLaneLayoutAsInt(); + subShape = getEffectiveLaneDataAsInt(); + } else { + return failure(); + } + + if (subShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, layout)) + subShape = derivedShape.value(); else return failure(); } // delinearize Ids - auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); + auto maybeIds = delinearizeId(builder, loc, linearId); if (failed(maybeIds)) return failure(); @@ -411,8 +567,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, SmallVector<Value> sgIds = XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims); - return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, - shape); + return genCoordinates(builder, loc, sgIds, layout, subShape, shape); } bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { @@ -435,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { [&](int64_t dim) { return thisDims.contains(dim); }); } +bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { + if (dyn_cast<xegpu::LayoutAttr>(other)) + return false; + + auto flattenedThis = flatten(); + auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten(); + + return ((flattenedThis.getParent() == flattenedOther.getParent()) && + (flattenedThis.getDims() == flattenedOther.getDims())); +} + +// Helper function to adjust unit dimensions from sliced space to parent space +static SetVector<int64_t> +adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims, + ArrayRef<int64_t> sliceDims) { + // Reconstruct parent's non-sliced dimensions + + int64_t 
parentRank = sliceDims.size() + unitDims.size(); + llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(), + sliceDims.end()); + SmallVector<int64_t> nonSlicedDims; + for (int64_t i = 0; i < parentRank; ++i) { + if (!slicedDimsSet.contains(i)) + nonSlicedDims.push_back(i); + } + + // Map unit dims from sliced space to parent space + SetVector<int64_t> adjustUnitDims; + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(nonSlicedDims.size())) { + adjustUnitDims.insert(nonSlicedDims[dim]); + } + } + + return adjustUnitDims; +} + +// set the layout for unit dims: sg_data, inst_data and lane_data to 1 +DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) { + SliceAttr attr = flatten(); + ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + + SetVector<int64_t> adjustUnitDims = + adjustUnitDimsWithSliceDims(unitDims, sliceDims); + + return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims), + attr.getDims()); +} + +// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 +DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) { + SliceAttr attr = flatten(); + ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + + SetVector<int64_t> adjustUnitDims = + adjustUnitDimsWithSliceDims(unitDims, sliceDims); + + return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims), + attr.getDims()); +} + //===----------------------------------------------------------------------===// // XeGPU_RangeAttr //===----------------------------------------------------------------------===// @@ -569,8 +787,8 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError, // for gather and scatter ops, Low-precision types are packed in 32-bit units. unsigned bitWidth = elementType.getIntOrFloatBitWidth(); int chunkAlignmentFactor = - bitWidth < targetinfo::packedSizeInBitsForGatherScatter - ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth + bitWidth < xegpu::uArch::generalPackedFormatBitSize + ? 
xegpu::uArch::generalPackedFormatBitSize / bitWidth : 1; auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding); if (scatterAttr) { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index abd12e2..91ba07a 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, - UnitAttr subgroup_block_io, + UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref<InFlightDiagnostic()> emitError) { if (!dataTy) { if (subgroup_block_io) return emitError() << "subgroup_block_io " - "are only allowed when result is a 1D VectorType."; + "are only allowed when result is a VectorType."; else return success(); } @@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, ArrayRef<int64_t> dataShape = dataTy.getShape(); ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); + ArrayAttr strideAttr = mdescTy.getStrideAttr(); + SmallVector<int64_t> strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast<IntegerAttr>(attr).getInt()); + } + if (subgroup_block_io && layout) { + auto laneData = layout.getEffectiveLaneDataAsInt(); + auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); + if (!laneData.empty()) { + bool isLaneDataContiguous = + std::all_of(laneData.begin(), std::prev(laneData.end()), + [](int x) { return x == 1; }); + if (!isLaneDataContiguous) + return emitError() << "With subgroup_block_io, accessed data must be " + "contiguous and coalesced."; + for (size_t i = 0; i < laneData.size(); ++i) { + if (laneLayout[i] != blockShape[i]) + return emitError() << "With subgroup_block_io, the block shape must " + "match the lane layout."; + if (laneLayout[i] != 1 && strides[i] != 1) + return emitError() << "With subgroup_block_io, the distributed " + "dimensions must be contiguous."; + } + } + } if (dataShape.size() == 2) { - if (subgroup_block_io) - return emitError() << "subgroup_block_io " - "are only allowed when result is a 1D VectorType."; if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), [](auto p) { return std::get<0>(p) > std::get<1>(p); })) return emitError() << "data shape must not exceed mem_desc shape."; } else { - SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); // if the subgroup_block_io attribute is set, mdescTy must have block // attribute if (subgroup_block_io && !blockShape.size()) @@ -258,8 +280,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); // if shape and strides are from Memref, we don't need attributes for them - // to keep the IR print clean. - if (staticShape == memrefShape && staticStrides == memrefStrides) { + // to keep the IR print clean (only do so for full-static case, otherwise + // printer would fail trying to print empty array-attr). + if (staticShape == memrefShape && staticStrides == memrefStrides && + dynamicShape.empty() && dynamicStrides.empty()) { staticShapeAttr = DenseI64ArrayAttr(); staticStridesAttr = DenseI64ArrayAttr(); } @@ -320,8 +344,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); // if shape and strides are from Memref, we don't need attributes for them - // to keep the IR print clean. 
- if (staticShape == memrefShape && staticStrides == memrefStrides) { + // to keep the IR print clean (only do so for full-static case, otherwise + // printer would fail trying to print empty array-attr). + if (staticShape == memrefShape && staticStrides == memrefStrides && + dynamicShape.empty() && dynamicStrides.empty()) { staticShapeAttr = DenseI64ArrayAttr(); staticStridesAttr = DenseI64ArrayAttr(); } @@ -439,14 +465,15 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l3_hint) { return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, Value tensorDesc, ArrayRef<OpFoldResult> offsets, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -454,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint, - l2_hint, l3_hint); + l2_hint, l3_hint, /*anchor_layout=*/layout); } LogicalResult PrefetchNdOp::verify() { @@ -472,11 +499,8 @@ LogicalResult PrefetchNdOp::verify() { return emitOpError("invalid l3_hint: ") << getL3HintAttr(); int64_t tDescRank = tdescTy.getRank(); - int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); - int64_t constOffsetSize = - getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0; - if (((offsetSize != 0) && (offsetSize != tDescRank)) || - ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + int64_t offsetSize = getMixedOffsets().size(); + if (offsetSize != 0 && offsetSize != tDescRank) return emitOpError( "Mismatched ranks between offsets and tensor descriptor"); @@ -496,7 +520,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, return build(builder, state, retType, tensorDesc, ValueRange(), DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, - l3_hint); + l3_hint, /*anchor_layout=*/nullptr); } void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, @@ -504,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, UnitAttr packed, DenseI64ArrayAttr transpose, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -512,7 +537,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr, - packed, transpose, l1_hint, l2_hint, l3_hint); + packed, transpose, l1_hint, l2_hint, l3_hint, + /*anchor_layout=*/layout); } LogicalResult LoadNdOp::verify() { @@ -597,11 +623,8 @@ LogicalResult LoadNdOp::verify() { << tdescTy; int64_t tDescRank = tdescTy.getRank(); - int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); - int64_t constOffsetSize = - getConstOffsetsAttr() ? 
getConstOffsetsAttr().size() : 0; - if (((offsetSize != 0) && (offsetSize != tDescRank)) || - ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + int64_t offsetSize = getMixedOffsets().size(); + if (offsetSize != 0 && offsetSize != tDescRank) return emitOpError( "Mismatched ranks between offsets and tensor descriptor"); @@ -618,14 +641,16 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, xegpu::CachePolicyAttr l3_hint) { return build(builder, state, value, tensorDesc, ValueRange(), - DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint); + DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint, + /*anchor_layout=*/nullptr); } void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, Value tensorDesc, ArrayRef<OpFoldResult> offsets, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -633,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr, - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout); } LogicalResult StoreNdOp::verify() { @@ -691,11 +716,8 @@ LogicalResult StoreNdOp::verify() { << dstTy; int64_t tDescRank = dstTy.getRank(); - int64_t offsetSize = static_cast<int64_t>(getOffsets().size()); - int64_t constOffsetSize = - getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0; - if (((offsetSize != 0) && (offsetSize != tDescRank)) || - ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + int64_t offsetSize = getMixedOffsets().size(); + if (offsetSize != 0 && offsetSize != tDescRank) return emitOpError( "Mismatched ranks between offsets and tensor descriptor"); @@ -809,7 +831,7 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint, - IntegerAttr{}); + IntegerAttr{}, /*anchor_layout=*/nullptr); } //===----------------------------------------------------------------------===// @@ -859,7 +881,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, valueType, source, Value(), mask, IntegerAttr(), - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void LoadGatherOp::build(OpBuilder &builder, OperationState &state, @@ -875,7 +897,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, auto offset = vector::FromElementsOp::create(builder, loc, type, values); build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, - l2_hint, l3_hint); + l2_hint, l3_hint, /*anchor_layout=*/nullptr); +} + +void LoadGatherOp::build(OpBuilder &builder, OperationState &state, + Type valueType, Value source, + ArrayRef<OpFoldResult> offsets, Value mask, + IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint, + DistributeLayoutAttr layout) { + auto loc = source.getLoc(); + int64_t size = static_cast<int64_t>(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + 
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, + l2_hint, l3_hint, layout); } //===----------------------------------------------------------------------===// @@ -926,7 +965,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint, - l2_hint, l3_hint); + l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void StoreScatterOp::build(OpBuilder &builder, OperationState &state, @@ -944,7 +983,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, // Call the correct builder overload that does not expect result types. build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, - l3_hint); + l3_hint, /*anchor_layout=*/nullptr); +} + +void StoreScatterOp::build( + OpBuilder &builder, OperationState &state, Value value, Value dest, + ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size, + xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) { + auto loc = dest.getLoc(); + int64_t size = static_cast<int64_t>(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + // Call the correct builder overload that does not expect result types. + build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, + l3_hint, layout); } //===----------------------------------------------------------------------===// @@ -1105,7 +1160,7 @@ LogicalResult LoadMatrixOp::verify() { MemDescType mdescTy = getMemDesc().getType(); return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + getLayoutAttr(), [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1129,7 +1184,7 @@ LogicalResult StoreMatrixOp::verify() { UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); MemDescType mdescTy = getMemDesc().getType(); return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + getLayoutAttr(), [&]() { return emitError(); }); } namespace mlir { |
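Note on the coordinate computation introduced above: the genCoordinates comment describes how workgroup data of 128x256 with subgroup data 16x32 in a 4x2 layout yields 64x64 distribution units arranged 2x4, and each subgroup gets one offset per unit computed as (unit base + id * subShape) mod srcShape. The following is a minimal plain-integer sketch of that arithmetic, not the Value-based code in the patch; the static shapes and the main() driver are illustrative only.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Example from the genCoordinates comment: srcShape 128x256,
  // subShape 16x32, subShapesLayout 4x2.
  const int64_t srcShape[2] = {128, 256};
  const int64_t subShape[2] = {16, 32};
  const int64_t layout[2] = {4, 2};

  // A distribution unit is min(srcShape, layout * subShape) per dim,
  // here 64x64, so the source splits into 2x4 such units.
  int64_t distUnit[2], numUnits[2];
  for (int d = 0; d < 2; ++d) {
    distUnit[d] = std::min(srcShape[d], layout[d] * subShape[d]);
    numUnits[d] = srcShape[d] / distUnit[d];
  }

  // Offsets assigned to the subgroup whose delinearized id is (1, 0):
  // one offset per distribution unit, wrapped into srcShape bounds.
  const int64_t id[2] = {1, 0};
  for (int64_t u0 = 0; u0 < numUnits[0]; ++u0)
    for (int64_t u1 = 0; u1 < numUnits[1]; ++u1) {
      int64_t off0 = (u0 * distUnit[0] + id[0] * subShape[0]) % srcShape[0];
      int64_t off1 = (u1 * distUnit[1] + id[1] * subShape[1]) % srcShape[1];
      std::printf("unit (%lld,%lld) -> offset (%lld,%lld)\n",
                  (long long)u0, (long long)u1,
                  (long long)off0, (long long)off1);
    }
  return 0;
}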
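The delinearization scheme in LayoutAttr::delinearizeId walks the order array with its first entry as the fastest-varying dimension. A standalone sketch of that loop follows, using plain integers instead of MLIR Values and reproducing the linearId=22, sgLayout=[2,4,4], order=[2,1,0] walkthrough from the comment; the helper name and driver are illustrative, not part of the patch.

#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> delinearize(int64_t linearId,
                                        const std::vector<int64_t> &layout,
                                        const std::vector<int64_t> &order) {
  assert(layout.size() == order.size() && "rank mismatch");
  std::vector<int64_t> result(layout.size());
  int64_t remaining = linearId;
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];            // dimension handled at this step
    int64_t dimSize = layout[dimIdx];     // its extent in the layout
    result[dimIdx] = remaining % dimSize; // coordinate within that dimension
    if (i + 1 < order.size())
      remaining /= dimSize;               // strip the completed groups
  }
  return result;
}

int main() {
  // linearId=22, sgLayout=[2,4,4], order=[2,1,0] -> [1, 1, 2],
  // i.e. layer 1, row 1, column 2, as in the comment's walkthrough.
  auto coords = delinearize(22, {2, 4, 4}, {2, 1, 0});
  assert((coords == std::vector<int64_t>{1, 1, 2}));
  return 0;
}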
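The new subgroup_block_io checks in IsValidMatrixOpParams amount to three rules: lane_data must be 1 in all but the innermost dimension, lane_layout must match the mem_desc block shape, and every distributed dimension must have stride 1. A hedged summary of that predicate as a free function over plain integers (illustrative only, not the verifier itself):

#include <cstdint>
#include <vector>

static bool isValidSubgroupBlockIo(const std::vector<int64_t> &laneLayout,
                                   const std::vector<int64_t> &laneData,
                                   const std::vector<int64_t> &blockShape,
                                   const std::vector<int64_t> &strides) {
  // Accessed data must be contiguous and coalesced: all but the last
  // lane_data entry must be 1.
  for (size_t i = 0; i + 1 < laneData.size(); ++i)
    if (laneData[i] != 1)
      return false;
  for (size_t i = 0; i < laneData.size(); ++i) {
    // The mem_desc block shape must match the lane layout.
    if (laneLayout[i] != blockShape[i])
      return false;
    // Distributed dimensions must be contiguous in memory.
    if (laneLayout[i] != 1 && strides[i] != 1)
      return false;
  }
  return true;
}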
