diff options
Diffstat (limited to 'mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp')
| -rw-r--r-- | mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 173 | 
1 files changed, 97 insertions, 76 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index f9aa28d5..397107b 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -11,7 +11,6 @@  #include "mlir/Dialect/Index/IR/IndexOps.h"  #include "mlir/Dialect/Utils/IndexingUtils.h"  #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"  #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"  #include "mlir/IR/Builders.h"  #include "mlir/IR/DialectImplementation.h" @@ -38,55 +37,61 @@ void XeGPUDialect::initialize() {        >();  } -/// Generates instructions to compute offsets for a subgroup identified by -/// its multidimensional indices (sgId), using the specified subgroup layout -/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data -/// dimensions (sizePerWg). +// A `srcShape` consists of N distribution units, each being `subShapesLayout` x +// `subShape`. A `delinearizedId` is used to identify a particular `subShape` +// within each distribution unit. +// Example: +// WG data is 128x256. SG data is 16x32, in 4x2 layout, this gives a +// distribution unit of shape 64x64, we have 2x4 such distribution units. +// `delinearizedId` is used to identify a 16x32 of a subgroup in each +// distribution unit.  
static SmallVector<SmallVector<Value>> -genOffsetsComputingInsts(OpBuilder &builder, Location loc, -                         SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout, -                         ArrayRef<int64_t> sizePerSg, -                         ArrayRef<int64_t> sizePerWg) { - -  SmallVector<SmallVector<Value>> offsets; +genCoordinates(OpBuilder &builder, Location loc, +               SmallVector<Value> delinearizedId, +               ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape, +               ArrayRef<int64_t> srcShape) { +  SmallVector<SmallVector<Value>> coordinates; + +  // A distribution unit must be less than or equal to `srcShape` +  SmallVector<int64_t> distUnitShape = llvm::map_to_vector( +      llvm::zip_equal(srcShape, +                      computeElementwiseMul(subShapesLayout, subShape)), +      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); -  // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i] -  SmallVector<Value> localOffsets = llvm::map_to_vector( -      llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value { +  // Get the offset of `subShape` within a distribution unit. 
+  SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector( +      llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {          return builder.createOrFold<index::MulOp>(              loc, std::get<0>(t),              builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));        }); -  // distUnit[i] is the minimum value between sizePerWg[i] and -  // sgLayout[i] * sizePerSg[i] -  SmallVector<int64_t> distUnit = llvm::map_to_vector( -      llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)), -      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); - +  // For each dist unit    for (SmallVector<int64_t> unitOffs : -       StaticTileOffsetRange(sizePerWg, distUnit)) { +       StaticTileOffsetRange(srcShape, distUnitShape)) { +    // Get dist unit offset within `srcShape`.      SmallVector<Value> base =          llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {            return arith::ConstantIndexOp::create(builder, loc, d);          }); - -    SmallVector<Value> adds = llvm::map_to_vector( -        llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value { -          return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t), -                                                     std::get<1>(t)); -        }); - +    // Calculate `subShape` offset within `srcShape`. +    SmallVector<Value> adds = +        llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset), +                            [&](const auto &t) -> Value { +                              return builder.createOrFold<arith::AddIOp>( +                                  loc, std::get<0>(t), std::get<1>(t)); +                            }); +    // Do not go beyond `srcShape` bounds.      
SmallVector<Value> mods = llvm::map_to_vector( -        llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value { +        llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {            return builder.createOrFold<index::RemUOp>(                loc, std::get<0>(t),                arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));          }); -    offsets.push_back(mods); +    coordinates.push_back(mods);    } -  return offsets; +  return coordinates;  }  // Checks if the given shape can be evenly distributed based on the layout @@ -229,8 +234,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,    }    if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) { -    return emitError() -           << "expected inst_data and lane_layout to have the same rank"; +    return emitError() << "expected inst_data and lane_layout to have the same " +                          "rank, got inst_data " +                       << inst_data.size() << ", lane_layout " +                       << lane_layout.size();    }    // sg_data is optional for Workgroup layout, but its presence requires @@ -271,12 +278,7 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,  }  FailureOr<SmallVector<Value>> -LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, -                                  Value linearId) { -  // delinearizeSubgroupId is only available for -  // workgroup-level layout attribute -  if (!isForWorkgroup()) -    return failure(); +LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {    // TODO: handle order attribute    auto hasDefaultOrder = [&]() { @@ -286,41 +288,52 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,    };    if (!hasDefaultOrder())      return mlir::emitError(loc, "order attribute is currently not supported."); - -  auto dims = -      llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> 
Value { -        return builder.createOrFold<arith::ConstantIndexOp>(loc, d); -      }); +  SmallVector<int64_t> layout; +  if (isForWorkgroup()) { +    layout = getEffectiveSgLayoutAsInt(); +  } else if (isForSubgroup()) { +    layout = getEffectiveLaneLayoutAsInt(); +  } else { +    return failure(); +  } +  auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value { +    return builder.createOrFold<arith::ConstantIndexOp>(loc, d); +  });    return affine::delinearizeIndex(builder, loc, linearId, dims);  } -/// Implements DistributeLayoutAttr::getOffsets to generate +/// Implements DistributeLayoutAttr::computeDistributedCoords to generate  /// instructions for computing multi-dimensional offsets when distributed by  /// LayoutAttr.  FailureOr<SmallVector<SmallVector<Value>>> -LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, -                       ArrayRef<int64_t> shape) { -  if (!isForWorkgroup()) +LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc, +                                     Value linearId, ArrayRef<int64_t> shape) { +  SmallVector<int64_t> layout; +  SmallVector<int64_t> subShape; +  if (isForWorkgroup()) { +    layout = getEffectiveSgLayoutAsInt(); +    subShape = getEffectiveSgDataAsInt(); +  } else if (isForSubgroup()) { +    layout = getEffectiveLaneLayoutAsInt(); +    subShape = getEffectiveLaneDataAsInt(); +  } else {      return failure(); - -  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); -  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); -  if (sgShape.empty()) { -    if (auto derivedShape = computeShapeRatio(shape, sgLayout)) -      sgShape = derivedShape.value(); +  } +  if (subShape.empty()) { +    if (auto derivedShape = computeShapeRatio(shape, layout)) +      subShape = derivedShape.value();      else        return failure();    }    // delinearize Ids -  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); +  auto maybeIds = delinearizeId(builder, 
loc, linearId);    if (failed(maybeIds))      return failure(); -  SmallVector<Value> sgIds = *maybeIds; +  SmallVector<Value> ids = *maybeIds; -  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, -                                  shape); +  return genCoordinates(builder, loc, ids, layout, subShape, shape);  }  //===----------------------------------------------------------------------===// @@ -374,34 +387,43 @@ SliceAttr SliceAttr::flatten() const {  }  FailureOr<SmallVector<Value>> -SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, -                                 Value linearId) { +SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {    SliceAttr attr = flatten();    auto parent = dyn_cast<LayoutAttr>(attr.getParent()); -  return parent.delinearizeSubgroupId(builder, loc, linearId); +  return parent.delinearizeId(builder, loc, linearId);  }  -/// Implements DistributeLayoutAttr::getOffsets to generate -/// instructions for computing multi-dimensional offsets when distributed by -/// SliceAttr. +// Implements DistributeLayoutAttr::computeDistributedCoords to generate +// instructions for computing multi-dimensional offsets when distributed by +// SliceAttr.  
FailureOr<SmallVector<SmallVector<Value>>> -SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, -                      ArrayRef<int64_t> shape) { +SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc, +                                    Value linearId, ArrayRef<int64_t> shape) {    assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");    if (!isForWorkgroup())      return failure(); -  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); -  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); -  if (sgShape.empty()) { -    if (auto derivedShape = computeShapeRatio(shape, sgLayout)) -      sgShape = derivedShape.value(); +  SmallVector<int64_t> layout; +  SmallVector<int64_t> subShape; +  if (isForWorkgroup()) { +    layout = getEffectiveSgLayoutAsInt(); +    subShape = getEffectiveSgDataAsInt(); +  } else if (isForSubgroup()) { +    layout = getEffectiveLaneLayoutAsInt(); +    subShape = getEffectiveLaneDataAsInt(); +  } else { +    return failure(); +  } + +  if (subShape.empty()) { +    if (auto derivedShape = computeShapeRatio(shape, layout)) +      subShape = derivedShape.value();      else        return failure();    }    // delinearize Ids -  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); +  auto maybeIds = delinearizeId(builder, loc, linearId);    if (failed(maybeIds))      return failure(); @@ -411,8 +433,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,    SmallVector<Value> sgIds =        XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims); -  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, -                                  shape); +  return genCoordinates(builder, loc, sgIds, layout, subShape, shape);  }  bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { @@ -569,8 +590,8 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,    // for gather and scatter ops, Low-precision 
types are packed in 32-bit units.    unsigned bitWidth = elementType.getIntOrFloatBitWidth();    int chunkAlignmentFactor = -      bitWidth < targetinfo::packedSizeInBitsForGatherScatter -          ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth +      bitWidth < xegpu::uArch::generalPackedFormatBitSize +          ? xegpu::uArch::generalPackedFormatBitSize / bitWidth            : 1;    auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);    if (scatterAttr) {  | 
