Diffstat (limited to 'mlir/lib/Dialect/XeGPU')
-rw-r--r--  mlir/lib/Dialect/XeGPU/CMakeLists.txt                          1
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp                    85
-rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt            17
-rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp    225
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp   81
5 files changed, 382 insertions, 27 deletions
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6..46b8251 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 397107b..fb5d1e7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -280,27 +280,82 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
 FailureOr<SmallVector<Value>>
 LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
-  // TODO: handle order attribute
-  auto hasDefaultOrder = [&]() {
-    DenseI32ArrayAttr order = getOrder();
-    return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
-                         llvm::reverse(order.asArrayRef())));
-  };
-  if (!hasDefaultOrder())
-    return mlir::emitError(loc, "order attribute is currently not supported.");
-  SmallVector<int64_t> layout;
+  SmallVector<int64_t> sgLayoutInt;
   if (isForWorkgroup()) {
-    layout = getEffectiveSgLayoutAsInt();
+    sgLayoutInt = getEffectiveSgLayoutAsInt();
   } else if (isForSubgroup()) {
-    layout = getEffectiveLaneLayoutAsInt();
+    sgLayoutInt = getEffectiveLaneLayoutAsInt();
   } else {
     return failure();
   }
-  auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
-    return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
-  });
-  return affine::delinearizeIndex(builder, loc, linearId, dims);
+
+  DenseI32ArrayAttr orderAttr = getOrder();
+
+  // Handle the order attribute.
+  SmallVector<int64_t> order;
+  if (orderAttr && !orderAttr.empty()) {
+    order = llvm::to_vector(
+        llvm::map_range(orderAttr.asArrayRef(),
+                        [](int32_t idx) { return static_cast<int64_t>(idx); }));
+  } else {
+    // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
+    order = llvm::to_vector(
+        llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
+  }
+
+  if (order.size() != sgLayoutInt.size()) {
+    return failure();
+  }
+
+  SmallVector<Value> result(sgLayoutInt.size());
+  Value remaining = linearId;
+
+  /// Process dimensions in the order they appear in the order array.
+  /// The first dimension in order is the fastest-changing.
+  ///
+  /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
+  ///
+  /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
+  /// result=[?,?,?]
+  ///
+  /// i=0 (process columns, dimIdx=2, dimSize=4):
+  ///   result[2] = 22 % 4 = 2   (column coordinate)
+  ///   remaining = 22 / 4 = 5   (5 complete groups of 4 columns processed)
+  ///
+  /// i=1 (process rows, dimIdx=1, dimSize=4):
+  ///   result[1] = 5 % 4 = 1    (row coordinate)
+  ///   remaining = 5 / 4 = 1    (1 complete group of 4 rows processed)
+  ///
+  /// i=2 (process layers, dimIdx=0, dimSize=2):
+  ///   result[0] = 1 % 2 = 1    (layer coordinate)
+  ///   (no remaining update - last iteration)
+  ///
+  /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
+  for (size_t i = 0; i < order.size(); ++i) {
+    int64_t dimIdx = order[i];
+    int64_t dimSize = sgLayoutInt[dimIdx];
+
+    Value dimSizeVal =
+        builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
+
+    /// Extract the coordinate for this dimension with a modulo operation;
+    /// this gives "how far within this dimension" we are, e.g., for
+    /// remaining=22, dimSize=4: 22 % 4 = 2 (position 2 within this dimension).
+    result[dimIdx] =
+        builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);
+
+    /// Update remaining for the next dimension by dividing out what has
+    /// already been processed; the quotient is "how many complete groups of
+    /// this dimension" were traversed, e.g., for remaining=22, dimSize=4:
+    /// 22 / 4 = 5 (5 complete groups of 4). Skip this on the last iteration
+    /// since there is no next dimension to process.
+    if (i < order.size() - 1) {
+      remaining =
+          builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
+    }
+  }
+  return result;
 }
 
 /// Implements DistributeLayoutAttr::computeDistributedCoords to generate
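For the walkthrough above (sgLayout=[2,4,4], order=[2,1,0]), the new lowering emits plain index arithmetic rather than a single affine.delinearize_index. A sketch of the generated IR, with invented value names and assuming %sg_id is the linear id:

    %c4 = arith.constant 4 : index
    %id2 = index.remu %sg_id, %c4    // result[2] = 22 % 4 = 2
    %r0  = index.divu %sg_id, %c4    // remaining = 22 / 4 = 5
    %id1 = index.remu %r0, %c4       // result[1] = 5 % 4 = 1
    %r1  = index.divu %r0, %c4       // remaining = 5 / 4 = 1
    %c2 = arith.constant 2 : index
    %id0 = index.remu %r1, %c2       // result[0] = 1 % 2 = 1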
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000..48fe841
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRXeGPUTransformOps
+  XeGPUTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/
+
+  DEPENDS
+  MLIRXeGPUTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRXeGPUDialect
+  MLIRXeGPUTransforms
+  MLIRIR
+  MLIRTransformDialect
+  MLIRFuncDialect
+  MLIRSCFDialect
+)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
new file mode 100644
index 0000000..8943ba0
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -0,0 +1,225 @@
+//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+
+#include <optional>
+
+using namespace mlir;
+using namespace mlir::transform;
+
+/// Assuming that `ofr` is an index attr, a param of index type, or a
+/// transform dialect handle mapped to exactly one op with one index result,
+/// get that value and cast it to an int type.
+static DiagnosedSilenceableFailure convertMixedValuesToInt(
+    transform::TransformState &state, TransformOpInterface transformOp,
+    SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
+  for (OpFoldResult ofr : ofrs) {
+    // Attribute case.
+    if (auto attr = dyn_cast<Attribute>(ofr)) {
+      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+        result.push_back(intAttr.getInt());
+        continue;
+      }
+      return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+    }
+
+    // Transform param case.
+    Value transformValue = cast<Value>(ofr);
+    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (params.size() != 1)
+        return transformOp.emitDefiniteFailure()
+               << "requires exactly one parameter associated";
+      result.push_back(
+          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+      continue;
+    }
+
+    // Payload value case.
+    auto payloadOps = state.getPayloadOps(transformValue);
+    if (!llvm::hasSingleElement(payloadOps)) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "handle must be mapped to exactly one payload op";
+      diag.attachNote(transformValue.getLoc())
+          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
+      return diag;
+    }
+
+    Operation *op = *payloadOps.begin();
+    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "payload op must have exactly 1 index result";
+      diag.attachNote(op->getLoc())
+          << "has " << op->getNumResults() << " results";
+      return diag;
+    }
+
+    IntegerAttr intAttr;
+    if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
+      return transformOp.emitSilenceableError()
+             << "requires param or handle to be the result of a constant-like "
+                "op";
+
+    result.push_back(intAttr.getInt());
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Create a layout attribute from the given parameters.
+static xegpu::LayoutAttr
+createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
+                 ArrayRef<int32_t> sgData,
+                 std::optional<ArrayRef<int32_t>> instData) {
+  return xegpu::LayoutAttr::get(
+      ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
+      DenseI32ArrayAttr::get(ctx, sgData),
+      instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
+      /*lane_layout=*/nullptr,
+      /*lane_data=*/nullptr,
+      /*order=*/nullptr);
+}
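For example, with sgLayout=[8,4], sgData=[32,32], and instData=[8,16], createLayoutAttr would produce an attribute printed roughly as below (values illustrative; lane_layout, lane_data, and order are left unset):

    #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>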
+/// Replace an xegpu.create_nd_desc op with a new one carrying the given
+/// layout.
+static xegpu::CreateNdDescOp
+setDescLayout(transform::TransformRewriter &rewriter,
+              xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
+  assert(descOp.getMixedOffsets().size() == 0 &&
+         "create desc op with offsets is not supported");
+  auto oldTensorDesc = descOp.getType();
+  auto descType = xegpu::TensorDescType::get(
+      oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
+      /*array_length=*/oldTensorDesc.getArrayLength(),
+      /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
+      /*memory_space=*/oldTensorDesc.getMemorySpace(),
+      /*layout=*/layout);
+
+  rewriter.setInsertionPointAfter(descOp);
+  auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+      descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
+      descOp.getMixedStrides());
+  return newDescOp;
+}
+
+void transform::SetDescLayoutOp::build(OpBuilder &builder,
+                                       OperationState &result, Value target,
+                                       ArrayRef<OpFoldResult> mixedSgLayout,
+                                       ArrayRef<OpFoldResult> mixedSgData,
+                                       ArrayRef<OpFoldResult> mixedInstData) {
+  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+  build(builder, result, target.getType(),
+        /*target=*/target,
+        /*sg_layout=*/dynamicSgLayout,
+        /*sg_data=*/dynamicSgData,
+        /*inst_data=*/dynamicInstData,
+        /*static_sg_layout=*/staticSgLayout,
+        /*static_sg_data=*/staticSgData,
+        /*static_inst_data=*/staticInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  auto targetOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one target handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+  Operation *target = *targetOps.begin();
+
+  SmallVector<int32_t> sgLayout;
+  DiagnosedSilenceableFailure status =
+      convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
+  if (!status.succeeded())
+    return status;
+
+  SmallVector<int32_t> sgData;
+  status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
+  if (!status.succeeded())
+    return status;
+
+  SmallVector<int32_t> instData;
+  status =
+      convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
+  if (!status.succeeded())
+    return status;
+  auto maybeInstData = instData.empty()
+                           ? std::nullopt
+                           : std::optional<ArrayRef<int32_t>>(instData);
+
+  // For now only the create_nd_desc op is supported.
+  auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
+  if (!descOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "expected an xegpu.create_nd_desc op, but got: "
+                << target->getName();
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  // Set the layout attr in the desc op's return type, replacing the old op.
+  auto layoutAttr =
+      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
+  auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
+
+  // Map result handles.
+  results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetDescLayoutOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getSgLayoutMutable(), effects);
+  onlyReadsHandle(getSgDataMutable(), effects);
+  onlyReadsHandle(getInstDataMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
+namespace {
+class XeGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          XeGPUTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
+
+  using Base::Base;
+
+  void init();
+};
+
+void XeGPUTransformDialectExtension::init() {
+  declareGeneratedDialect<scf::SCFDialect>();
+  declareGeneratedDialect<arith::ArithDialect>();
+  declareGeneratedDialect<xegpu::XeGPUDialect>();
+
+  registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+      >();
+}
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+
+void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<XeGPUTransformDialectExtension>();
+}
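A sketch of how the new op might be driven from a transform script; the op's mnemonic and assembly format are defined in TableGen outside this diff, so the syntax below (and the matched payload op name) is an assumption:

    module attributes {transform.with_named_sequence} {
      transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
        %desc = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %root
            : (!transform.any_op) -> !transform.any_op
        // Attach a workgroup-level layout to the matched descriptor's result type.
        %new = transform.xegpu.set_desc_layout %desc sg_layout = [8, 4]
            sg_data = [32, 32] inst_data = [8, 16]
            : (!transform.any_op) -> !transform.any_op
        transform.yield
      }
    }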
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d12a04df..0a9ef0a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1219,6 +1219,70 @@ struct WgToSgMultiDimReductionOp
   }
 };
 
+// This pattern transforms vector.transpose ops to work at the subgroup level.
+struct WgToSgVectorTransposeOp
+    : public OpConversionPattern<vector::TransposeOp> {
+  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(op.getVector());
+    if (!sourceLayout || !sourceLayout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sourceSgLayout =
+        sourceLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
+    DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
+    DenseI32ArrayAttr resultOrder = layout.getOrder();
+
+    if (!sourceOrder || !resultOrder) {
+      return rewriter.notifyMatchFailure(
+          op, "both source and result must have order attributes");
+    }
+
+    ArrayRef<int64_t> permutation = op.getPermutation();
+    size_t permutationSize = permutation.size();
+    if (sourceSgLayout.size() != permutationSize ||
+        resultSgLayout.size() != permutationSize) {
+      return rewriter.notifyMatchFailure(
+          op, "layouts and permutation must have the same rank");
+    }
+
+    // Check that sgLayout, sgData, and order are properly transposed for the
+    // source and result.
+    if (!layout.isTransposeOf(sourceLayout, permutation))
+      return rewriter.notifyMatchFailure(
+          op, "result layout is not a valid transpose of the source layout "
+              "according to the permutation");
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+    SmallVector<Value> newTransposeOps;
+    for (auto src : adaptor.getVector()) {
+      auto newTranspose = vector::TransposeOp::create(
+          rewriter, op.getLoc(), newResultType, src, permutation);
+      xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
+                                     layout.dropSgLayoutAndData());
+      newTransposeOps.push_back(newTranspose.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newTransposeOps});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -1233,7 +1297,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
       WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
       WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
-      WgToSgMultiDimReductionOp>(patterns.getContext());
+      WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
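An illustration of the kind of op the pattern accepts: a workgroup-level transpose whose result layout is the source layout transposed by the permutation. Shapes and the layout_result_0 attribute placement are assumptions for illustration:

    // Source layout: sg_layout = [8, 4], order = [1, 0] on vector<256x128xf32>.
    %t = vector.transpose %v, [1, 0]
        {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], order = [0, 1]>}
        : vector<256x128xf32> to vector<128x256xf32>
    // After distribution, each subgroup transposes its own vector<32x32xf32> tile.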
@@ -1360,7 +1425,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+                               vector::TransposeOp, vector::BroadcastOp,
+                               vector::MultiDimReductionOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
@@ -1379,16 +1446,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::BroadcastOp>(
-      [=](vector::BroadcastOp op) -> bool {
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
-      });
-
-  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
-      [=](vector::MultiDimReductionOp op) -> bool {
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
-      });
-
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
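With the consolidated legality rule, these vector ops remain illegal only while their result still carries a workgroup-level layout; once a pattern rewrites them and drops sg_layout/sg_data, they become legal. A sketch, with attribute placement assumed as above:

    // Illegal: result still carries sg_layout at workgroup level.
    %b = vector.broadcast %s
        {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}
        : f32 to vector<256x128xf32>
    // Legal after distribution: sg_layout/sg_data dropped.
    %b_sg = vector.broadcast %s
        {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>}
        : f32 to vector<32x32xf32>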