//===- TestXeGPUTransforms.cpp -- Test XeGPU transforms and lowerings -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::xegpu;

namespace {

#define DEBUG_TYPE "test-xegpu-unroll"

struct TestXeGPUUnrollingPatterns
    : public PassWrapper<TestXeGPUUnrollingPatterns,
                         OperationPass<gpu::GPUModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUUnrollingPatterns)

  StringRef getArgument() const final {
    return "test-xegpu-unrolling-patterns";
  }

  StringRef getDescription() const final {
    return "Test lowering patterns to unroll ops in the xegpu dialect";
  }

  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
    registry.insert<xegpu::XeGPUDialect>();
    registry.insert<vector::VectorDialect>();
  }

  TestXeGPUUnrollingPatterns() = default;
  TestXeGPUUnrollingPatterns(const TestXeGPUUnrollingPatterns &pass)
      : PassWrapper(pass) {}

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    xegpu::UnrollOptions options;
    // Return the native (instruction-level) tile shape for an op: taken from
    // the inst_data of its layout when available, and a fixed 8x16x16 shape
    // for DPAS.
    options.setNativeShapeFn(
        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
          if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
                  xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
                  xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
            xegpu::TensorDescType tdescTy;
            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
              tdescTy = createNdOp.getType();
            } else if (auto updateNdOp =
                           dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
              tdescTy = updateNdOp.getTensorDescType();
            } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
              tdescTy = prefetchNdOp.getTensorDescType();
            } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
              tdescTy = loadNdOp.getTensorDescType();
            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
              tdescTy = storeNdOp.getTensorDescType();
            } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
              tdescTy = createOp.getType();
            } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
              tdescTy = updateOp.getTensorDescType();
            } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
              tdescTy = prefetchOp.getTensorDescType();
            } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
              if (loadOp.getOffsets()) {
                auto layout =
                    xegpu::getDistributeLayoutAttr(loadOp.getResult());
                if (layout && layout.isForSubgroup()) {
                  auto inst_data = layout.getEffectiveInstDataAsInt();
                  if (!inst_data.empty())
                    return SmallVector<int64_t>(inst_data.begin(),
                                                inst_data.end());
                }
                return std::nullopt;
              }
              tdescTy = loadOp.getTensorDescType();
            } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
              if (storeOp.getOffsets()) {
                auto layout = llvm::dyn_cast_or_null<xegpu::LayoutAttr>(
                    op->getAttr("layout"));
                if (layout && layout.isForSubgroup()) {
                  auto inst_data = layout.getEffectiveInstDataAsInt();
                  if (!inst_data.empty())
                    return SmallVector<int64_t>(inst_data.begin(),
                                                inst_data.end());
                }
                return std::nullopt;
              }
              tdescTy = storeOp.getTensorDescType();
            }

            if (auto layout = tdescTy.getLayoutAttr()) {
              auto inst_data = layout.getInstData();
              if (inst_data && layout.isForSubgroup())
                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
                                            inst_data.asArrayRef().end());
            }
          }

          if (isa<xegpu::DpasOp>(op))
            return SmallVector<int64_t>{8, 16, 16};

          return std::nullopt;
        });
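
    // Compute the unrolled result types for a given tile shape: tensor
    // descriptors drop inst_data from their layout (re-deriving the chunk
    // size of scattered descriptors), while any other shaped type is simply
    // cloned with the tile shape.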
    options.setUnrolledTypesFn(
        [&](ShapedType type, ArrayRef<int64_t> tileShape,
            bool returnSingleType = false) -> SmallVector<Type> {
          Type elemTy = type.getElementType();
          Type newTy;

          // TensorDescType needs to drop the inst_data field in the layout
          // attribute.
          if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
            Attribute encoding = tdescTy.getEncoding();
            auto layout = tdescTy.getLayoutAttr();

            // If the encoding is a ScatterTensorDescAttr, we need to
            // potentially adjust the chunk size based on the inst_data.
            if (tdescTy.isScattered()) {
              int64_t chunkSize = tdescTy.getChunkSizeAsInt();

              if (chunkSize > 1) {
                int64_t blockedChunkSize = chunkSize;
                auto instData = layout.getInstData();
                if (!instData.empty())
                  blockedChunkSize = instData.asArrayRef().back();

                // Create a new encoding with the adjusted chunk_size.
                auto newEncoding = xegpu::ScatterTensorDescAttr::get(
                    ctx, tdescTy.getMemorySpace(), blockedChunkSize);
                encoding = newEncoding;
              }
            }

            if (layout) {
              if (layout.getLaneLayout() == nullptr)
                layout = xegpu::LayoutAttr();
              else
                layout = layout.dropInstData();
            }

            newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                               layout);
          } else {
            newTy = type.clone(tileShape, elemTy);
          }

          if (returnSingleType)
            return SmallVector<Type>{newTy};

          std::optional<SmallVector<int64_t>> ratio =
              computeShapeRatio(type.getShape(), tileShape);
          assert(ratio && "Expecting the ratio to be valid.");
          return SmallVector<Type>(computeProduct(*ratio), newTy);
        });

    RewritePatternSet patterns(ctx);
    populateXeGPUUnrollPatterns(patterns, options);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

#undef DEBUG_TYPE
#define DEBUG_TYPE "test-xegpu-layout-interface"

// Test pattern for distributing vector::StepOp from workgroup to subgroup.
// Validates DistributeLayoutAttr interfaces for offset computation
// abstraction between LayoutAttr and SliceAttr.
class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto layoutName = xegpu::getLayoutName(op->getResult(0));
    auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
    if (!sliceAttr || sliceAttr.getRank() != 1)
      return failure();

    std::optional<SmallVector<int64_t>> sgShape =
        sliceAttr.getEffectiveSgDataAsInt();
    if (!sgShape)
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape);
    if (failed(maybeOffsets))
      return failure();

    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
    Value base = vector::StepOp::create(rewriter, loc, newTy);
    SmallVector<Value> newOps;
    for (auto offsets : *maybeOffsets) {
      // Shift the base step vector by this subgroup's offset along the
      // distributed dimension.
      Value bcast =
          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
      Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
      newOps.push_back(add);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
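
// Test pass exercising the XeGPU subgroup-distribution pattern set in
// isolation: it populates the patterns and applies them greedily, so the
// rewrites can be checked from mlir-opt via -test-xegpu-sg-distribute.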
struct TestXeGPUSGDistribute
    : public PassWrapper<TestXeGPUSGDistribute,
                         OperationPass<gpu::GPUModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUSGDistribute)

  StringRef getArgument() const final { return "test-xegpu-sg-distribute"; }

  StringRef getDescription() const final {
    return "Test the implementation of XeGPU Subgroup Distribution";
  }

  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<xegpu::XeGPUDialect>();
    registry.insert<gpu::GPUDialect>();
    registry.insert<index::IndexDialect>();
    registry.insert<memref::MemRefDialect>();
    registry.insert<vector::VectorDialect>();
  }

  TestXeGPUSGDistribute() = default;
  TestXeGPUSGDistribute(const TestXeGPUSGDistribute &pass) = default;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct TestXeGPULayoutInterface
    : public PassWrapper<TestXeGPULayoutInterface,
                         OperationPass<gpu::GPUModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPULayoutInterface)

  StringRef getArgument() const final { return "test-xegpu-layout-interface"; }

  StringRef getDescription() const final {
    return "Test the implementation of XeGPU Layout interfaces";
  }

  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<xegpu::XeGPUDialect>();
    registry.insert<gpu::GPUDialect>();
    registry.insert<index::IndexDialect>();
    registry.insert<vector::VectorDialect>();
    registry.insert<arith::ArithDialect>();
  }

  TestXeGPULayoutInterface() = default;
  TestXeGPULayoutInterface(const TestXeGPULayoutInterface &pass)
      : PassWrapper(pass) {}

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    TypeConverter typeConverter;
    auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
                               mlir::ValueRange inputs,
                               mlir::Location loc) -> mlir::Value {
      return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
          .getResult(0);
    };
    typeConverter.addSourceMaterialization(materializeCast);
    typeConverter.addTargetMaterialization(materializeCast);

    RewritePatternSet patterns(ctx);
    patterns.add<TestStepOpPattern>(typeConverter, ctx);

    ConversionTarget target(*ctx);
    // A step op is legal once it no longer carries a workgroup-level slice
    // layout.
    auto isLegal = [&](xegpu::SliceAttr layout) -> bool {
      return !layout || !layout.isForWorkgroup();
    };

    target.addDynamicallyLegalOp<vector::StepOp>(
        [&](vector::StepOp op) -> bool {
          auto layoutName = xegpu::getLayoutName(op->getResult(0));
          auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
          return isLegal(sliceAttr);
        });

    target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });

    (void)applyPartialConversion(getOperation(), target, std::move(patterns));
  }
};

} // namespace

namespace mlir {
namespace test {
void registerTestXeGPULowerings() {
  PassRegistration<TestXeGPUUnrollingPatterns>();
  PassRegistration<TestXeGPUSGDistribute>();
  PassRegistration<TestXeGPULayoutInterface>();
}
} // namespace test
} // namespace mlir