//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;

#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace {

/// Casts the given vector value `v` to the expected vector type `expectedTy`.
static Value castValueTo(ConversionPatternRewriter &rewriter,
                         TypedValue<VectorType> v, VectorType expectedTy) {
  // If the type already matches, simply return the value itself.
  if (v.getType() == expectedTy)
    return v;
  // If only the shape differs, use a shape cast.
  if (isa<VectorType>(v.getType()) &&
      v.getType().getNumElements() == expectedTy.getNumElements())
    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
  // Otherwise, create an unrealized cast.
  auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
                                                  expectedTy, ValueRange{v});
  return newOp.getResult(0);
}

/// Checks that all XeGPU anchor ops and vector results have valid layouts.
static LogicalResult verifyLayouts(Operation *root) {
  auto walkResult = root->walk([&](Operation *nestedOp) -> WalkResult {
    if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(nestedOp)) {
      auto layout = anchorOp.getAnchorLayout();
      if (!layout) {
        nestedOp->emitError("expected anchor layout attribute on operation");
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    }
    // For each vector result, check if the op carries a result layout
    // attribute.
    for (OpResult result : nestedOp->getResults()) {
      if (isa<VectorType>(result.getType())) {
        auto layout = xegpu::getDistributeLayoutAttr(result);
        if (!layout) {
          nestedOp->emitError(
              "expected result layout attribute on vector result");
          return WalkResult::interrupt();
        }
      }
    }
    return WalkResult::advance();
  });
  return walkResult.wasInterrupted() ? failure() : success();
}

/// Distributes a subgroup-level CreateNdDesc op to a workitem-level
/// CreateNdDesc op. This simply drops the layout attribute from the tensor
/// descriptor type.
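///
/// A minimal illustration (shapes and layout values are assumed, not taken
/// from a real test): a subgroup-level descriptor such as
///
///   %td = xegpu.create_nd_tdesc %src : memref<16x16xf32> ->
///     !xegpu.tensor_desc<16x16xf32,
///       #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
///
/// is rewritten to
///
///   %td = xegpu.create_nd_tdesc %src : memref<16x16xf32> ->
///     !xegpu.tensor_desc<16x16xf32>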
struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::TensorDescType resultType = op.getType();
    // If there is no layout, nothing to do.
    if (!resultType.getLayout())
      return failure();
    auto newOp = xegpu::CreateNdDescOp::create(
        rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
        op->getAttrs());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// Distributes a subgroup-level LoadNd op to a workitem-level LoadNd op. The
/// output of the workitem-level LoadNd op is 1D. A ShapeCast is added to
/// restore the original rank.
struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If there is no layout, nothing to do.
    if (!layout)
      return failure();
    // Check that the layout attached to the tensor descriptor is the same as
    // the anchor layout. Otherwise, this is a conflict.
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");

    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          op, "xegpu::LoadNdOp requires a target attribute attached to "
              "determine the transpose requirement");

    auto supportedWiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    auto expectedWiResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
    if (failed(supportedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute the workitem vector type for LoadNdOp");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute expected workitem vector type from lane layout");

    auto newOp = xegpu::LoadNdOp::create(
        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
        op.getL3HintAttr(), /*layout=*/nullptr);
    // Set the packed attribute if the layout requires it.
    newOp.setPacked(xegpu::requirePacked(cast<xegpu::LayoutAttr>(layout)));
    // Set the transpose attribute if the layout requires it.
    if (xegpu::requireTranspose(cast<xegpu::LayoutAttr>(layout), uArch))
      newOp.setTranspose(
          DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));

    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};

/// Distributes a subgroup-level StoreNd op to a workitem-level StoreNd op. The
/// stored value in the workitem-level StoreNd op is 1D. A ShapeCast is added
/// to cast the incoming value to 1D.
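///
/// A minimal illustration (shapes and layout are assumed): with a 16-lane
/// layout, a subgroup-level
///
///   xegpu.store_nd %val, %td : vector<16x16xf32>,
///     !xegpu.tensor_desc<16x16xf32, #layout>
///
/// becomes a per-workitem store of a 1D fragment:
///
///   %v = vector.shape_cast %frag : vector<16x1xf32> to vector<16xf32>
///   xegpu.store_nd %v, %td : vector<16xf32>, !xegpu.tensor_desc<16x16xf32>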
struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If there is no layout, nothing to do.
    if (!layout)
      return failure();
    // Check that the layouts attached to the tensor descriptor and the stored
    // value are the same as the anchor layout. Otherwise, this is a conflict.
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    auto valueLayout = xegpu::getDistributeLayoutAttr(op->getOpOperand(0));
    if (valueLayout != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on value and anchor");

    auto supportedWiValueTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    if (failed(supportedWiValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute wi vector type for StoreNdOp value from "
              "tensor descriptor");

    xegpu::StoreNdOp::create(
        rewriter, op.getLoc(),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getValue()),
                    supportedWiValueTyOrFailure.value()),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
        op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Distributes a subgroup-level Dpas op to a workitem-level Dpas op. All
/// inputs and the output of the workitem-level Dpas op are 1D. The necessary
/// casts are added to convert the inputs and output to/from 1D.
struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check that the op has A, B and CD layouts attached.
    auto layoutA = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAAttr());
    auto layoutB = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutBAttr());
    auto layoutCd =
        dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutCdAttr());
    if (!layoutA || !layoutB || !layoutCd)
      return failure();

    auto wiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getType(), layoutCd);
    auto wiATypeOrFailure =
        xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
    auto wiBTypeOrFailure =
        xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
    auto expectedWiResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
    if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
        failed(wiBTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "failed to calculate supported workitem vector types for DpasOp "
              "from layouts");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute expected workitem vector type for DpasOp "
              "from lane layout");

    auto newOp = xegpu::DpasOp::create(
        rewriter, op->getLoc(), wiResultTyOrFailure.value(),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
                    wiATypeOrFailure.value()),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
                    wiBTypeOrFailure.value()),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getAcc()),
                    wiResultTyOrFailure.value()),
        /*layoutA=*/nullptr, /*layoutB=*/nullptr, /*layoutCd=*/nullptr);
    // Cast the result back to the expected workitem type so that later type
    // materializations line up.
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};

/// Distributes elementwise ops to workitem-level elementwise ops. This
/// currently handles elementwise ops with a single result only.
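///
/// A minimal illustration (shapes and layout are assumed): with a lane layout
/// of [1, 16], a subgroup-level
///
///   %r = arith.addf %a, %b : vector<16x16xf32>
///
/// is rewritten to operate on the per-workitem fragment:
///
///   %r = arith.addf %a0, %b0 : vector<16x1xf32>
///
/// where %a0 and %b0 are the already-converted operands from the adaptor.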
struct SgToWiElementWise : public ConversionPattern {
  SgToWiElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match ops with the elementwise trait and a single result.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(
          op, "operation result is not a vector type");

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");

    auto wiShapeOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");
    VectorType newResultType = wiShapeOrFailure.value();

    OperationState state(op->getLoc(), op->getName());
    state.addOperands(operands);
    state.addTypes(newResultType);
    // Copy all attributes except DistributeLayoutAttr.
    for (auto attr : op->getAttrs()) {
      if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
        state.addAttribute(attr.getName(), attr.getValue());
    }
    Operation *newOp = rewriter.create(state);
    rewriter.replaceOp(op, newOp->getResult(0));
    return success();
  }
};

/// Distributes a subgroup-level arith ConstantOp to a workitem-level arith
/// ConstantOp.
struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<VectorType>(op.getType());
    if (!resultType)
      return failure();

    // Only handle dense splat vector constants.
    auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
    if (!dense)
      return rewriter.notifyMatchFailure(
          op, "only dense splat vector constants are supported");

    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");

    auto wiShapeOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");
    VectorType newResultType = wiShapeOrFailure.value();

    auto scalarValue = dense.getSplatValue<Attribute>();
    auto newDenseAttr = DenseElementsAttr::get(newResultType, scalarValue);
    auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(),
                                           newResultType, newDenseAttr);
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};

/// Distributes a subgroup-level PrefetchNd op to a workitem-level PrefetchNd
/// op.
struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If there is no layout, nothing to do.
    if (!layout)
      return failure();
    xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
                                op.getMixedOffsets(), op.getL1HintAttr(),
                                op.getL2HintAttr(), op.getL3HintAttr(),
                                /*layout=*/nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};

struct XeGPUSgToWiDistributeExperimentalPass
    : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
          XeGPUSgToWiDistributeExperimentalPass> {
  void runOnOperation() override;
};

} // namespace
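/// Runs subgroup-to-workitem distribution on the pass root. An illustrative
/// end-to-end effect (shapes and layouts are assumed): a subgroup-level
///
///   %td = xegpu.create_nd_tdesc %src : memref<16x16xf32> ->
///     !xegpu.tensor_desc<16x16xf32, #layout>
///   %v  = xegpu.load_nd %td : !xegpu.tensor_desc<16x16xf32, #layout>
///           -> vector<16x16xf32>
///   xegpu.store_nd %v, %td : vector<16x16xf32>,
///     !xegpu.tensor_desc<16x16xf32, #layout>
///
/// is rewritten so that each workitem of a 16-lane subgroup handles its own
/// fragment (vector<16xf32> at the op boundary), with vector.shape_cast and
/// unrealized_conversion_cast ops bridging intermediate types until the
/// cleanup below removes the redundant ones.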
void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
  // Verify that all XeGPU anchor ops and vector ops have result layouts.
  // TODO: This can be removed once the full layout refactoring is done.
  Operation *root = getOperation();
  if (failed(verifyLayouts(root))) {
    LLVM_DEBUG(DBGS() << "XeGPUSgToWiDistributeExperimentalPass: layout "
                         "verification failed\n");
    signalPassFailure();
    return;
  }

  // Collect existing UnrealizedConversionCastOps. These must be preserved.
  llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
  root->walk(
      [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });

  // Perform a structural type conversion so that structural ops carry WI
  // types. This will insert UnrealizedConversionCastOps to keep the IR valid.
  auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
                             mlir::ValueRange inputs,
                             mlir::Location loc) -> mlir::Value {
    UnrealizedConversionCastOp castOp =
        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
    return castOp.getResult(0);
  };
  {
    ConversionTarget target(getContext());
    TypeConverter typeConverter;
    RewritePatternSet patterns(&getContext());
    typeConverter.addSourceMaterialization(materializeCast);
    typeConverter.addTargetMaterialization(materializeCast);
    xegpu::populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
                                                         patterns, target);
    xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
        typeConverter, patterns, target);
    target.addLegalOp<UnrealizedConversionCastOp>();
    (void)applyPartialConversion(root, target, std::move(patterns));
  }

  // The structural type conversion can generate redundant
  // UnrealizedConversionCastOps that materialize the SG type from the
  // converted WI type. These are redundant at this point and can be
  // eliminated by inserting shape casts instead.
  // Example:
  // %1 = UnrealizedConversionCastOp %0 : vector<16x1xf32> to vector<16x16xf32>
  // %2 = UnrealizedConversionCastOp %1 : vector<16x16xf32> to vector<16xf32>
  // This can be replaced with:
  // %2 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
  OpBuilder builder(root);
  root->walk([&](UnrealizedConversionCastOp op) {
    // If this op existed before, nothing to do.
    if (existingCasts.contains(op))
      return;
    // The number of inputs and outputs must be 1.
    if (op.getNumOperands() != 1 || op.getNumResults() != 1)
      return;
    // Both input and output types must be vector types.
    auto singleInput = op.getInputs()[0];
    auto inputTy = dyn_cast<VectorType>(singleInput.getType());
    auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
    if (!inputTy || !outputTy)
      return;
    // Check if the defining op of the input is also an
    // UnrealizedConversionCastOp and it has a single user (which is this op).
    auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
    if (!definingOp || !definingOp->hasOneUse())
      return;
    auto inputOfDefiningOp = definingOp.getInputs()[0];
    // If the input of the defining op and the output are both vector types
    // with the same number of elements, insert a shape cast.
    auto inputOfDefiningOpTy =
        dyn_cast<VectorType>(inputOfDefiningOp.getType());
    if (inputOfDefiningOpTy &&
        inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
      builder.setInsertionPoint(op);
      auto shapeCast = vector::ShapeCastOp::create(
          builder, op.getLoc(), outputTy, inputOfDefiningOp);
      op.replaceAllUsesWith(ValueRange{shapeCast.getResult()});
      return;
    }
  });

  // At this point, there may be some dead UnrealizedConversionCastOps. Just
  // erase them.
  bool changed = true;
  while (changed) {
    changed = false;
    root->walk([&](UnrealizedConversionCastOp op) {
      // Skip existing casts.
      if (existingCasts.contains(op))
        return;
      if (op.use_empty()) {
        op.erase();
        changed = true;
      }
    });
  }
}

void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
    TypeConverter &typeConverter) {
  // Any type other than TensorDescType and VectorType is legal as is.
  typeConverter.addConversion([](Type type) -> std::optional<Type> {
    if (!isa<TensorDescType, VectorType>(type))
      return type;
    return std::nullopt;
  });
  // For TensorDescType, drop the layout attribute if any.
  typeConverter.addConversion([](TensorDescType type) -> Type {
    if (type.getLayoutAttr())
      return type.dropLayouts();
    return type;
  });
  // For VectorType, check if there is a distribute layout attribute on the
  // value. If so, convert to the distributed vector type based on the layout.
  typeConverter.addConversion([](Value v) -> std::optional<Type> {
    auto type = v.getType();
    // If the value is not a vector type, nothing to do.
    if (!isa<VectorType>(type))
      return std::nullopt;
    auto layout = xegpu::getDistributeLayoutAttr(v);
    if (!layout || !layout.isForSubgroup())
      return type;
    // The vector type is distributed based on the lane layout.
    auto newTyOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, cast<VectorType>(type));
    if (failed(newTyOrFailure))
      return type;
    return *newTyOrFailure;
  });
}

void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  populateXeGPUSgToWiDistributeTypeConversions(typeConverter);

  // CreateNdDescOp is legal only if its result type has no layout attribute.
  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
      [&](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });

  // Any anchor XeGPU op is legal only if it has no anchor layout.
  target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
    auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op);
    if (!anchorOp)
      return true;
    return !anchorOp.getAnchorLayout();
  });

  // Arith constants are legal only if they have no temporary layout attribute.
  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        // If the result type is not a vector, it is legal.
        if (!isa<VectorType>(op.getResult().getType()))
          return true;
        return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
      });

  // In the math and arith dialects, only handle elementwise ops with a single
  // vector result and a result layout attribute.
  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // Only handle elementwise mappable ops.
        if (!OpTrait::hasElementwiseMappableTraits(op))
          return true;
        // Only handle ops with a single vector result.
        if (op->getNumResults() != 1)
          return true;
        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType)
          return true;
        // Check that all operands are vectors of the same shape.
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape())
            return true;
        }
        return !xegpu::getTemporaryLayout(
            dyn_cast<OpResult>(op->getResult(0)));
      });

  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });

  patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd>(
      typeConverter, patterns.getContext());
}
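
// A minimal usage sketch for the populate entry points above (it mirrors, in
// reduced form, what runOnOperation does; the local variable names are
// illustrative only):
//
//   ConversionTarget target(ctx);
//   TypeConverter converter;
//   RewritePatternSet patterns(&ctx);
//   xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
//       converter, patterns, target);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();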