diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Bindings/Python/Globals.h | 25 | ||||
-rw-r--r-- | mlir/lib/Bindings/Python/Pass.cpp | 20 | ||||
-rw-r--r-- | mlir/lib/CAPI/IR/Pass.cpp | 15 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 167 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 34 | ||||
-rw-r--r-- | mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 147 |
6 files changed, 332 insertions, 76 deletions
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 71a051c..1e81f53 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -17,6 +17,7 @@ #include "NanobindUtils.h" #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" @@ -151,6 +152,29 @@ public: TracebackLoc &getTracebackLoc() { return tracebackLoc; } + class TypeIDAllocator { + public: + TypeIDAllocator() : allocator(mlirTypeIDAllocatorCreate()) {} + ~TypeIDAllocator() { + if (allocator.ptr) + mlirTypeIDAllocatorDestroy(allocator); + } + TypeIDAllocator(const TypeIDAllocator &) = delete; + TypeIDAllocator(TypeIDAllocator &&other) : allocator(other.allocator) { + other.allocator.ptr = nullptr; + } + + MlirTypeIDAllocator get() { return allocator; } + MlirTypeID allocate() { + return mlirTypeIDAllocatorAllocateTypeID(allocator); + } + + private: + MlirTypeIDAllocator allocator; + }; + + MlirTypeID allocateTypeID() { return typeIDAllocator.allocate(); } + private: static PyGlobals *instance; @@ -173,6 +197,7 @@ private: llvm::StringSet<> loadedDialectModules; TracebackLoc tracebackLoc; + TypeIDAllocator typeIDAllocator; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index e489585..572afa9 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,6 +8,7 @@ #include "Pass.h" +#include "Globals.h" #include "IRModule.h" #include "mlir-c/Pass.h" // clang-format off @@ -57,6 +58,13 @@ private: /// Create the `mlir.passmanager` here. void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- + // Mapping of enumerated types + //---------------------------------------------------------------------------- + nb::enum_<MlirPassDisplayMode>(m, "PassDisplayMode") + .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST) + .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE); + + //---------------------------------------------------------------------------- // Mapping of MlirExternalPass //---------------------------------------------------------------------------- nb::class_<MlirExternalPass>(m, "ExternalPass") @@ -138,6 +146,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { mlirPassManagerEnableTiming(passManager.get()); }, "Enable pass timing.") + .def( + "enable_statistics", + [](PyPassManager &passManager, MlirPassDisplayMode displayMode) { + mlirPassManagerEnableStatistics(passManager.get(), displayMode); + }, + "displayMode"_a = + MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE, + "Enable pass statistics.") .def_static( "parse", [](const std::string &pipeline, DefaultingPyMlirContext context) { @@ -181,9 +197,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { name = nb::cast<std::string>( nb::borrow<nb::str>(run.attr("__name__"))); } - MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate(); - MlirTypeID passID = - mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); + MlirTypeID passID = PyGlobals::get().allocateTypeID(); MlirExternalPassCallbacks callbacks; callbacks.construct = [](void *obj) { (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref(); diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index b0a6ec1..72bec11 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -13,6 +13,7 @@ #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" #include "mlir/Pass/PassManager.h" +#include "llvm/Support/ErrorHandling.h" #include <optional> using namespace mlir; @@ -79,6 +80,20 @@ void mlirPassManagerEnableTiming(MlirPassManager passManager) { unwrap(passManager)->enableTiming(); } +void mlirPassManagerEnableStatistics(MlirPassManager passManager, + MlirPassDisplayMode displayMode) { + PassDisplayMode mode; + switch (displayMode) { + case MLIR_PASS_DISPLAY_MODE_LIST: + mode = PassDisplayMode::List; + break; + case MLIR_PASS_DISPLAY_MODE_PIPELINE: + mode = PassDisplayMode::Pipeline; + break; + } + unwrap(passManager)->enableStatistics(mode); +} + MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName) { return wrap(&unwrap(passManager)->nest(unwrap(operationName))); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index e8f8824..7f419a0 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -216,6 +216,18 @@ LogicalResult ConvertFloatToTF32Op::verify() { return success(); } +LogicalResult ConvertF32x2ToF6x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) { + return emitOpError("Only ") + << mlir::Float6E2M3FNType::get(ctx) << " and " + << mlir::Float6E3M2FNType::get(ctx) + << " types are supported for conversions from f32x2 to f6x2."; + } + return success(); +} + LogicalResult ConvertF32x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; using SatMode = NVVM::SaturationMode; @@ -227,41 +239,67 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() { bool hasRelu = getRelu(); - switch (getType()) { - case ConvertFP8Type::E4M3: - case ConvertFP8Type::E5M2: - if (!isRoundingModeRN) - return emitOpError("Only RN rounding mode is supported for conversions " - "from f32x2 to .e4m3x2 or .e5m2x2 types"); - if (!isSatFinite) - return emitOpError("Only SATFINITE saturation mode is supported for " - "conversions from f32x2 to .e4m3x2 or .e5m2x2 types"); - break; - case ConvertFP8Type::UE8M0: - if (!(isRoundingModeRZ || isRoundingModeRP)) - return emitOpError("Only RZ or RP rounding modes are supported for " - "conversions from f32x2 to .ue8m0x2 type"); - if (hasRelu) - return emitOpError("relu not supported for conversions to .ue8m0x2 type"); - break; - } - return success(); + mlir::MLIRContext *ctx = getContext(); + + return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy()) + .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>( + [&](mlir::Type) -> LogicalResult { + if (!isRoundingModeRN) { + return emitOpError("Only RN rounding mode is supported for " + "conversions from f32x2 to ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) << " types"; + } + if (!isSatFinite) { + return emitOpError("Only SATFINITE saturation mode is supported " + "for conversions " + "from f32x2 to ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) << " types"; + } + return success(); + }) + .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult { + if (!(isRoundingModeRZ || isRoundingModeRP)) { + return emitOpError("Only RZ and RP rounding modes are supported for " + "conversions from f32x2 to ") + << mlir::Float8E8M0FNUType::get(ctx) << " type"; + } + if (hasRelu) { + return emitOpError("relu not supported for conversions to ") + << mlir::Float8E8M0FNUType::get(ctx) << " type"; + } + return success(); + }) + .Default([&](mlir::Type) { + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << ", " + << mlir::Float8E5M2Type::get(ctx) << ", and " + << mlir::Float8E8M0FNUType::get(ctx) + << " types are " + "supported for conversions from f32x2 to f8x2"; + }); } LogicalResult ConvertF16x2ToF8x2Op::verify() { - if (getType() == ConvertFP8Type::UE8M0) - return emitOpError("Only .e4m3 or .e5m2 types are supported for " - "conversions from f16x2 to f8x2."); + mlir::MLIRContext *ctx = getContext(); + if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) { + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) + << " types are supported for conversions from f16x2 to f8x2."; + } return success(); } LogicalResult ConvertBF16x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; - if (getType() != ConvertFP8Type::UE8M0) - return emitOpError( - "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2."); + if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy())) + return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext()) + << " type is supported for conversions from " + "bf16x2 to f8x2."; auto rnd = getRnd(); if (!(rnd == RndMode::RZ || rnd == RndMode::RP)) @@ -1980,15 +2018,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \ : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite -llvm::Intrinsic::ID -ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) { - switch (type) { - case NVVM::ConvertFP6Type::E2M3: - return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu); - case NVVM::ConvertFP6Type::E3M2: - return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu); - } - llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op"); +llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy, + bool hasRelu) { + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) { + return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu); + }) + .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) { + return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \ @@ -2000,41 +2042,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) { : llvm::Intrinsic::nvvm_ff_to_##type##_rn llvm::Intrinsic::ID -ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, - NVVM::FPRoundingMode rnd, +ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd, NVVM::SaturationMode sat, bool hasRelu) { bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE); bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ); bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP); - switch (type) { - case NVVM::ConvertFP8Type::E4M3: - return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu); - case NVVM::ConvertFP8Type::E5M2: - return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu); - case NVVM::ConvertFP8Type::UE8M0: - if (hasRoundingModeRZ) - return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite); - else if (hasRoundingModeRP) - return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite); - } - llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op"); + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) { + return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu); + }) + .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) { + return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu); + }) + .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) { + if (hasRoundingModeRZ) + return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite); + else if (hasRoundingModeRP) + return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite); + + llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op"); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_F16x2_TO_F8X2_ID(type, has_relu) \ has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \ : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn -llvm::Intrinsic::ID -ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) { - switch (type) { - case NVVM::ConvertFP8Type::E4M3: - return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu); - case NVVM::ConvertFP8Type::E5M2: - return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu); - default: - llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op"); - } +llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, + bool hasRelu) { + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) { + return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu); + }) + .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) { + return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \ diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 4919d9a..9d62491 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -524,6 +524,40 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, if (!mask) { LDBG() << "No mask required"; + if (assumeDynamicDimsMatchVecSizes) { + llvm::TypeSwitch<Operation *>(opToMask) + .Case<vector::TransferReadOp, vector::TransferWriteOp>( + [&](auto xferOp) { + // For vector.transfer_read and vector.transfer_write, there is + // also the `in-bounds` attribute that has to be set explicitly + // to true. Otherwise, "out-of-bounds" access will be assumed + // and masks will be generated while lowering these. + LDBG() << "Assuming dynamic dimensions match vector sizes and " + "setting their in-bounds to true!"; + SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues(); + ShapedType xferType = xferOp.getShapedType(); + AffineMap permMap = xferOp.getPermutationMap(); + // Only set the in-bounds values to true for dynamic dims. + // Different mechanisms will set these accordingly for the + // static dims. + for (unsigned i = 0; i < xferOp.getTransferRank(); i++) { + auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i)); + // Skip broadcast dimensions. + if (!dimExpr) + continue; + unsigned pos = dimExpr.getPosition(); + if (xferType.isDynamicDim(pos)) + inBoundsMap[i] = true; + } + rewriter.modifyOpInPlace(xferOp, [&]() { + xferOp.setInBoundsAttr( + rewriter.getBoolArrayAttr(inBoundsMap)); + }); + }) + .Default([](Operation *op) { + // No-op if the operation is not an xfer read or write. + }); + } return opToMask; } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 784e5d6..c28d2fc 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { ConversionPatternRewriter &rewriter) const override { auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue()); auto vecType = dyn_cast<VectorType>(op.getType()); - if (!vecAttr || !vecAttr.isSplat() || !vecType) + if (!vecAttr || !vecType) return failure(); xegpu::DistributeLayoutAttr layout = @@ -733,22 +733,139 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { int count; std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); - // Current limitation: constant of vector with single value. - // TODO: support more complex cases, e.g., vector with multiple values. - Attribute singleVal = vecAttr.getSplatValue<Attribute>(); - auto newType = VectorType::get(sgShape, vecType.getElementType()); - auto sgAttr = DenseElementsAttr::get(newType, singleVal); - auto cstOp = - arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(cstOp->getResult(0), - layout.dropSgLayoutAndData()); - SmallVector<Value> newConsts(count, cstOp); + Location loc = op.getLoc(); + auto eltType = vecType.getElementType(); - rewriter.replaceOpWithMultiple(op, {newConsts}); - return success(); + auto setLayoutIfNeeded = [&](Value val) { + if (!layout.getEffectiveLaneLayoutAsInt().empty() || + !layout.getEffectiveInstDataAsInt().empty()) { + xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val), + layout.dropSgLayoutAndData()); + } + }; + + if (vecAttr.isSplat()) { + // Splat: single value for all subgroups + Attribute singleVal = vecAttr.getSplatValue<Attribute>(); + auto sgAttr = DenseElementsAttr::get(newType, singleVal); + auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); + setLayoutIfNeeded(cstOp->getResult(0)); + rewriter.replaceOp(op, cstOp); + return success(); + } else if (sgShape == wgShape) { // if the entire vector is shared by all + // subgroups, don't distribute + auto newConstOp = + arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); + setLayoutIfNeeded(newConstOp->getResult(0)); + rewriter.replaceOp(op, newConstOp); + return success(); + } else { + // Non-splat constant + // Only supports 1D & 2D + // TODO: support other cases that require SLM access + if (!eltType.isIndex()) + return rewriter.notifyMatchFailure( + op, "Unsupported element type for non-splat constant op."); + + if (wgShape.size() > 2) + return rewriter.notifyMatchFailure( + op, "Only 1D & 2D vector constant supported"); + + SmallVector<Attribute> values(vecAttr.getValues<Attribute>()); + int64_t rowStride = 0, colStride = 0; + int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0]; + int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1]; + + // Compute colStride and rowStride, and check for constant strides. + if (cols > 1) { + colStride = cast<IntegerAttr>(values[1]).getInt() - + cast<IntegerAttr>(values[0]).getInt(); + } + if (rows > 1) { + rowStride = cast<IntegerAttr>(values[cols]).getInt() - + cast<IntegerAttr>(values[0]).getInt(); + } + + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + int64_t idx = r * cols + c; + // Check column stride + if (c > 0 && cols > 1) { + int64_t prevIdx = r * cols + (c - 1); + int64_t diff = cast<IntegerAttr>(values[idx]).getInt() - + cast<IntegerAttr>(values[prevIdx]).getInt(); + if (diff != colStride) + return rewriter.notifyMatchFailure( + op, "Non-constant column stride in constant op."); + } + // Check row stride + if (r > 0 && rows > 1) { + int64_t prevIdx = (r - 1) * cols + c; + int64_t diff = cast<IntegerAttr>(values[idx]).getInt() - + cast<IntegerAttr>(values[prevIdx]).getInt(); + if (diff != rowStride) + return rewriter.notifyMatchFailure( + op, "Non-constant row stride in constant op."); + } + } + } + + // Create a constant for the base tile. + // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix. + // For 1D case, extract the first sgShape[0] elements. + SmallVector<Attribute> baseTileValues; + int baseTileCols = sgShape[sgShape.size() - 1]; + int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0]; + for (int64_t r = 0; r < baseTileRows; ++r) { + for (int64_t c = 0; c < baseTileCols; ++c) { + baseTileValues.push_back(values[r * cols + c]); + } + } + + auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType), + baseTileValues); + auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr); + + // Get subgroup id + Value sgId = + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); + + auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + if (failed(sgOffsets)) + return failure(); + + SmallVector<Value, 2> strideConsts; + strideConsts.push_back( + rewriter.create<arith::ConstantIndexOp>(loc, colStride)); + if (rows > 1) + strideConsts.insert( + strideConsts.begin(), + rewriter.create<arith::ConstantIndexOp>(loc, rowStride)); + + SmallVector<Value> newConstOps; + for (auto offsets : *sgOffsets) { + // Multiply offset with stride, broadcast it and add to baseConstVec + Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0); + for (size_t i = 0; i < strideConsts.size(); ++i) { + Value mul = rewriter.create<arith::MulIOp>( + loc, rewriter.getIndexType(), offsets[i], strideConsts[i]); + mulOffset = rewriter.create<arith::AddIOp>( + loc, rewriter.getIndexType(), mulOffset, mul); + } + // Broadcast to baseConstVec size + auto bcastOffset = rewriter.create<vector::BroadcastOp>( + loc, baseConstVec.getType(), mulOffset); + auto finalConst = + arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); + setLayoutIfNeeded(baseConstVec); + setLayoutIfNeeded(bcastOffset); + setLayoutIfNeeded(finalConst); + newConstOps.push_back(finalConst); + } + rewriter.replaceOpWithMultiple(op, {newConstOps}); + return success(); + } } }; |