Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r-- | mlir/lib/Dialect/EmitC/IR/EmitC.cpp                          |  74
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp                   | 167
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp            |  64
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp         |  34
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp        |  28
-rw-r--r-- | mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp                   |  15
-rw-r--r-- | mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp    | 218
-rw-r--r-- | mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp   |   1
-rw-r--r-- | mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp                    |   7
-rw-r--r-- | mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp                      | 111
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp                     | 120
-rw-r--r-- | mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp  | 147
12 files changed, 735 insertions(+), 251 deletions(-)
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 5c8564b..4754f0b 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -974,10 +974,10 @@ LogicalResult emitc::YieldOp::verify() { Value result = getResult(); Operation *containingOp = getOperation()->getParentOp(); - if (result && containingOp->getNumResults() != 1) + if (!isa<DoOp>(containingOp) && result && containingOp->getNumResults() != 1) return emitOpError() << "yields a value not returned by parent"; - if (!result && containingOp->getNumResults() != 0) + if (!isa<DoOp>(containingOp) && !result && containingOp->getNumResults() != 0) return emitOpError() << "does not yield a value to be returned by parent"; return success(); @@ -1562,6 +1562,76 @@ LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } //===----------------------------------------------------------------------===// +// DoOp +//===----------------------------------------------------------------------===// + +void DoOp::print(OpAsmPrinter &p) { + p << ' '; + p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false); + p << " while "; + p.printRegion(getConditionRegion()); + p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); +} + +LogicalResult emitc::DoOp::verify() { + Block &condBlock = getConditionRegion().front(); + + if (condBlock.getOperations().size() != 2) + return emitOpError( + "condition region must contain exactly two operations: " + "'emitc.expression' followed by 'emitc.yield', but found ") + << condBlock.getOperations().size() << " operations"; + + Operation &first = condBlock.front(); + auto exprOp = dyn_cast<emitc::ExpressionOp>(first); + if (!exprOp) + return emitOpError("expected first op in condition region to be " + "'emitc.expression', but got ") + << first.getName(); + + if (!exprOp.getResult().getType().isInteger(1)) + return emitOpError("emitc.expression in condition region must return " + "'i1', but returns ") + << exprOp.getResult().getType(); + + Operation &last = condBlock.back(); + auto condYield = dyn_cast<emitc::YieldOp>(last); + if (!condYield) + return emitOpError("expected last op in condition region to be " + "'emitc.yield', but got ") + << last.getName(); + + if (condYield.getNumOperands() != 1) + return emitOpError("expected condition region to return 1 value, but " + "it returns ") + << condYield.getNumOperands() << " values"; + + if (condYield.getOperand(0) != exprOp.getResult()) + return emitError("'emitc.yield' must return result of " + "'emitc.expression' from this condition region"); + + Block &bodyBlock = getBodyRegion().front(); + if (bodyBlock.mightHaveTerminator()) + return emitOpError("body region must not contain terminator"); + + return success(); +} + +ParseResult DoOp::parse(OpAsmParser &parser, OperationState &result) { + Region *bodyRegion = result.addRegion(); + Region *condRegion = result.addRegion(); + + if (parser.parseRegion(*bodyRegion) || parser.parseKeyword("while") || + parser.parseRegion(*condRegion)) + return failure(); + + if (bodyRegion->empty()) + bodyRegion->emplaceBlock(); + + return parser.parseOptionalAttrDictWithKeyword(result.attributes); +} + +//===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index e8f8824..7f419a0 100644 --- 
a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -216,6 +216,18 @@ LogicalResult ConvertFloatToTF32Op::verify() { return success(); } +LogicalResult ConvertF32x2ToF6x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) { + return emitOpError("Only ") + << mlir::Float6E2M3FNType::get(ctx) << " and " + << mlir::Float6E3M2FNType::get(ctx) + << " types are supported for conversions from f32x2 to f6x2."; + } + return success(); +} + LogicalResult ConvertF32x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; using SatMode = NVVM::SaturationMode; @@ -227,41 +239,67 @@ LogicalResult ConvertF32x2ToF8x2Op::verify() { bool hasRelu = getRelu(); - switch (getType()) { - case ConvertFP8Type::E4M3: - case ConvertFP8Type::E5M2: - if (!isRoundingModeRN) - return emitOpError("Only RN rounding mode is supported for conversions " - "from f32x2 to .e4m3x2 or .e5m2x2 types"); - if (!isSatFinite) - return emitOpError("Only SATFINITE saturation mode is supported for " - "conversions from f32x2 to .e4m3x2 or .e5m2x2 types"); - break; - case ConvertFP8Type::UE8M0: - if (!(isRoundingModeRZ || isRoundingModeRP)) - return emitOpError("Only RZ or RP rounding modes are supported for " - "conversions from f32x2 to .ue8m0x2 type"); - if (hasRelu) - return emitOpError("relu not supported for conversions to .ue8m0x2 type"); - break; - } - return success(); + mlir::MLIRContext *ctx = getContext(); + + return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy()) + .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>( + [&](mlir::Type) -> LogicalResult { + if (!isRoundingModeRN) { + return emitOpError("Only RN rounding mode is supported for " + "conversions from f32x2 to ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) << " types"; + } + if (!isSatFinite) { + return emitOpError("Only SATFINITE saturation mode is supported " + "for conversions " + "from f32x2 to ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) << " types"; + } + return success(); + }) + .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult { + if (!(isRoundingModeRZ || isRoundingModeRP)) { + return emitOpError("Only RZ and RP rounding modes are supported for " + "conversions from f32x2 to ") + << mlir::Float8E8M0FNUType::get(ctx) << " type"; + } + if (hasRelu) { + return emitOpError("relu not supported for conversions to ") + << mlir::Float8E8M0FNUType::get(ctx) << " type"; + } + return success(); + }) + .Default([&](mlir::Type) { + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << ", " + << mlir::Float8E5M2Type::get(ctx) << ", and " + << mlir::Float8E8M0FNUType::get(ctx) + << " types are " + "supported for conversions from f32x2 to f8x2"; + }); } LogicalResult ConvertF16x2ToF8x2Op::verify() { - if (getType() == ConvertFP8Type::UE8M0) - return emitOpError("Only .e4m3 or .e5m2 types are supported for " - "conversions from f16x2 to f8x2."); + mlir::MLIRContext *ctx = getContext(); + if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) { + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) + << " types are supported for conversions from f16x2 to f8x2."; + } return success(); } LogicalResult ConvertBF16x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; - if (getType() != ConvertFP8Type::UE8M0) - return emitOpError( - "Only .ue8m0 
type is supported for conversions from bf16x2 to f8x2."); + if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy())) + return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext()) + << " type is supported for conversions from " + "bf16x2 to f8x2."; auto rnd = getRnd(); if (!(rnd == RndMode::RZ || rnd == RndMode::RP)) @@ -1980,15 +2018,19 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \ : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite -llvm::Intrinsic::ID -ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) { - switch (type) { - case NVVM::ConvertFP6Type::E2M3: - return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu); - case NVVM::ConvertFP6Type::E3M2: - return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu); - } - llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op"); +llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy, + bool hasRelu) { + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) { + return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu); + }) + .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) { + return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \ @@ -2000,41 +2042,50 @@ ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) { : llvm::Intrinsic::nvvm_ff_to_##type##_rn llvm::Intrinsic::ID -ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, - NVVM::FPRoundingMode rnd, +ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd, NVVM::SaturationMode sat, bool hasRelu) { bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE); bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ); bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP); - switch (type) { - case NVVM::ConvertFP8Type::E4M3: - return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu); - case NVVM::ConvertFP8Type::E5M2: - return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu); - case NVVM::ConvertFP8Type::UE8M0: - if (hasRoundingModeRZ) - return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite); - else if (hasRoundingModeRP) - return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite); - } - llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op"); + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) { + return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu); + }) + .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) { + return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu); + }) + .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) { + if (hasRoundingModeRZ) + return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite); + else if (hasRoundingModeRP) + return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite); + + llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op"); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_F16x2_TO_F8X2_ID(type, has_relu) \ has_relu ? 
llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \ : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn -llvm::Intrinsic::ID -ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) { - switch (type) { - case NVVM::ConvertFP8Type::E4M3: - return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu); - case NVVM::ConvertFP8Type::E5M2: - return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu); - default: - llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op"); - } +llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, + bool hasRelu) { + return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) + .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) { + return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu); + }) + .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) { + return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu); + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } #define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \ diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 7863c21..0dac688 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); - Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); - DenseMap<int64_t, OpFoldResult> dimAndTileMapping = - packOp.getDimAndTileMapping(); int64_t srcRank = packOp.getSourceRank(); int64_t destRank = packOp.getDestRank(); - int64_t numTiles = destRank - srcRank; + ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); + int64_t numberOfTiles = innerDimsPos.size(); - // 1. Extract the inner tile sizes. - // Where possible, values are replaced with constant attributes (to match the - // behaviour of `getPackOpSourceOrPaddedSource`). - SmallVector<OpFoldResult> tileSizes; - for (auto i : llvm::seq<unsigned>(0, srcRank)) { - if (dimAndTileMapping.count(i)) { - // Rather than taking the tile size as is, extact the actual constant - // value Attribute where possible, e.g.: - // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8] - auto [_, tileSize] = - getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter); - tileSizes.push_back(tileSize); - } - } + // 1. Get the input that is going to be packed. If the input requires padding, + // add a padding operation and return that as the input. + Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); // 2. Transpose the input to match the inner tile order: // %init = tensor.empty() // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) // Assumptions made: - // 1. All outer dims are 1 - the corresponding transposition order doesn't + // - All outer dims are 1 - the corresponding transposition order doesn't // matter, but requires all dim indices to be present. + + // 2.1 Get the permutation for linalg.transpose SmallVector<int64_t> srcPermForTranspose; - ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos()); for (int64_t i = 0; i < srcRank; i++) { // We assume the `k` dimensions of the inner dim position, where `k` is the // rank of the inner tiling, correspond to the last `k` indices of the @@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // rank of the source tensor. 
For example if we have a source tensor with // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining // indices are [1, 2]. and the transpose will be [1, 2, 3, 0]. - if (llvm::is_contained(innerDimPos, i)) + if (llvm::is_contained(innerDimsPos, i)) continue; srcPermForTranspose.push_back(i); } - srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end()); + srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end()); + + // 2.2 Create the init tensor for linalg.transpose with the correct shape + SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles, + oneIdxAttr); + shapeForEmptyOp.append(packOp.getMixedTiles()); + + // getMixedTiles() may contain Values pointing to constant ops, not the + // constant attributes. Replace them with a true OpFoldResult. + llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(), + [&](OpFoldResult ofr) { + if (auto val = llvm::dyn_cast<Value>(ofr)) + return getAsOpFoldResult(val); + return ofr; + }); LDBG() << "Pack permutation: " << packOp; LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose); + LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp); - // 2.1 Create tensor.empty (init value for TransposeOp) - SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles, - oneIdxAttr); - transShapeForEmptyOp.append(tileSizes); - - applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, - srcPermForTranspose); - Value empty = - tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp, - packOp.getSourceType().getElementType()); + Value empty = tensor::EmptyOp::create( + rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType()); - // 2.2 Create linalg.transpose + // 2.3 Create linalg.transpose auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, srcPermForTranspose); @@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); // Outer dims are all 1s! - SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(), - oneIdxAttr); + SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr); SmallVector<int64_t> writeShape; for (auto tileSize : packOp.getMixedTiles()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 4919d9a..9d62491 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -524,6 +524,40 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, if (!mask) { LDBG() << "No mask required"; + if (assumeDynamicDimsMatchVecSizes) { + llvm::TypeSwitch<Operation *>(opToMask) + .Case<vector::TransferReadOp, vector::TransferWriteOp>( + [&](auto xferOp) { + // For vector.transfer_read and vector.transfer_write, there is + // also the `in-bounds` attribute that has to be set explicitly + // to true. Otherwise, "out-of-bounds" access will be assumed + // and masks will be generated while lowering these. + LDBG() << "Assuming dynamic dimensions match vector sizes and " + "setting their in-bounds to true!"; + SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues(); + ShapedType xferType = xferOp.getShapedType(); + AffineMap permMap = xferOp.getPermutationMap(); + // Only set the in-bounds values to true for dynamic dims. 
+ // Different mechanisms will set these accordingly for the + // static dims. + for (unsigned i = 0; i < xferOp.getTransferRank(); i++) { + auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i)); + // Skip broadcast dimensions. + if (!dimExpr) + continue; + unsigned pos = dimExpr.getPosition(); + if (xferType.isDynamicDim(pos)) + inBoundsMap[i] = true; + } + rewriter.modifyOpInPlace(xferOp, [&]() { + xferOp.setInBoundsAttr( + rewriter.getBoolArrayAttr(inBoundsMap)); + }); + }) + .Default([](Operation *op) { + // No-op if the operation is not an xfer read or write. + }); + } return opToMask; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp index ef172c1..37bdd8b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -186,11 +186,11 @@ constexpr float A_2x2_5x5[] = { /// Structure to keep information of constant transform matrices. struct TransformMatrix { - TransformMatrix(const float *table, int64_t rows, int64_t cols, + TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols, int64_t scalarFactor = 1) : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {} - const float *table; + ArrayRef<float> table; int64_t rows; int64_t cols; int64_t scalarFactor; @@ -199,14 +199,20 @@ struct TransformMatrix { /// Utility function to convert constant array to arith.constant Value. Value create2DTransformMatrix(OpBuilder &builder, Location loc, TransformMatrix transform, Type type) { - ArrayRef<float> constVec(transform.table, transform.rows * transform.cols); - + assert(transform.table.size() == + static_cast<size_t>(transform.rows * transform.cols)); + assert(type.isFloat() && "Only floats are supported by Winograd"); + ArrayRef<float> constVec(transform.table.data(), + transform.rows * transform.cols); + auto constAttrVec = + llvm::map_to_vector<>(constVec, [&](const float v) -> Attribute { + return builder.getFloatAttr(type, v); + }); + SmallVector<int64_t, 2> shape{transform.rows, transform.cols}; return arith::ConstantOp::create( builder, loc, - DenseFPElementsAttr::get( - RankedTensorType::get( - SmallVector<int64_t>{transform.rows, transform.cols}, type), - constVec)); + DenseFPElementsAttr::get(RankedTensorType::get(shape, type), + constAttrVec)); } /// Extract height x width data from 4D tensors. @@ -551,8 +557,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); - Value BT = - create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type()); + Value BT = create2DTransformMatrix(builder, loc, BTMatrix, elementType); // Multiply BT x d. auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, ValueRange{BT, matmulRetValue}, @@ -574,8 +579,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, .getResult(); auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); - Value B = - create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type()); + Value B = create2DTransformMatrix(builder, loc, BMatrix, elementType); // Multiply v = (BT x d) x B. 
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, ValueRange{matmulRetValue, B}, diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index 697cb35..237aab4 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -27,7 +27,7 @@ using namespace mlir::nvgpu; #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc" -void nvgpu::NVGPUDialect::initialize() { +void NVGPUDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST #include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc" @@ -42,7 +42,7 @@ void nvgpu::NVGPUDialect::initialize() { >(); } -bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) { +bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) { if (!memorySpace) return false; if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace)) @@ -52,7 +52,7 @@ bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) { return false; } -bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) { +bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) { Attribute memorySpace = type.getMemorySpace(); return isSharedMemoryAddressSpace(memorySpace); } @@ -140,7 +140,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op, TypedValue<VectorType> matrixC, const std::array<int64_t, 3> &mmaShape, bool tf32Enabled, bool sparse = false) { - // The verification for mma.sync covering various shapes and data types is // based on the fundamental tensor core shape. @@ -292,7 +291,6 @@ LogicalResult MmaSparseSyncOp::verify() { // NVGPU_LdMatrixOp //===----------------------------------------------------------------------===// LogicalResult LdMatrixOp::verify() { - // ldmatrix reads data from source in shared memory auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType()); @@ -345,7 +343,7 @@ LogicalResult LdMatrixOp::verify() { // NVGPU_TmaAsyncLoadOp //===----------------------------------------------------------------------===// -unsigned getSwizzleBytes(TensorMapSwizzleKind kind) { +static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) { switch (kind) { case TensorMapSwizzleKind::SWIZZLE_32B: return 32; @@ -359,7 +357,7 @@ unsigned getSwizzleBytes(TensorMapSwizzleKind kind) { } std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref( - Operation *op, nvgpu::TensorMapDescriptorType descType, + Operation *op, TensorMapDescriptorType descType, std::optional<MemRefType> memrefType = std::nullopt) { MemRefType descMemref = descType.getTensor(); // Limitation @@ -655,8 +653,7 @@ LogicalResult WarpgroupMmaStoreOp::verify() { //===----------------------------------------------------------------------===// LogicalResult WarpgroupMmaInitAccumulatorOp::verify() { - - nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType(); + WarpgroupAccumulatorType accType = getMatrixC().getType(); int64_t sizeM = accType.getFragmented().getDimSize(0); int64_t sizeN = accType.getFragmented().getDimSize(1); Type elemType = accType.getFragmented().getElementType(); diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 46e82bd..2a857ed 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -43,7 +43,7 @@ using namespace mlir::transform; // Apply...ConversionPatternsOp //===----------------------------------------------------------------------===// 
-void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( +void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( TypeConverter &typeConverter, RewritePatternSet &patterns) { auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); /// device-side async tokens cannot be materialized in nvvm. We just @@ -62,62 +62,58 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( llvm_unreachable("unknown address space enum value"); return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic); }); - llvmTypeConverter.addConversion( - [&](nvgpu::DeviceAsyncTokenType type) -> Type { - return llvmTypeConverter.convertType( - IntegerType::get(type.getContext(), 32)); - }); - llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { + llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type { + return llvmTypeConverter.convertType( + IntegerType::get(type.getContext(), 32)); + }); + llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type { return llvmTypeConverter.convertType( IntegerType::get(type.getContext(), 64)); }); - llvmTypeConverter.addConversion( - [&](nvgpu::WarpgroupAccumulatorType type) -> Type { - Type elemType = type.getFragmented().getElementType(); - int64_t sizeM = type.getFragmented().getDimSize(0); - int64_t sizeN = type.getFragmented().getDimSize(1); - - unsigned numMembers; - if (elemType.isF32() || elemType.isInteger(32)) - numMembers = sizeN / 2; - else if (elemType.isF16()) - numMembers = sizeN / 4; - else - llvm_unreachable("unsupported type for warpgroup accumulator"); - - SmallVector<Type> innerStructBody; - for (unsigned i = 0; i < numMembers; i++) - innerStructBody.push_back(elemType); - auto innerStructType = LLVM::LLVMStructType::getLiteral( - type.getContext(), innerStructBody); - - SmallVector<Type> structBody; - for (int i = 0; i < sizeM; i += kWgmmaSizeM) - structBody.push_back(innerStructType); - - auto convertedType = - LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); - return llvmTypeConverter.convertType(convertedType); - }); - llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type { + llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type { + Type elemType = type.getFragmented().getElementType(); + int64_t sizeM = type.getFragmented().getDimSize(0); + int64_t sizeN = type.getFragmented().getDimSize(1); + + unsigned numMembers; + if (elemType.isF32() || elemType.isInteger(32)) + numMembers = sizeN / 2; + else if (elemType.isF16()) + numMembers = sizeN / 4; + else + llvm_unreachable("unsupported type for warpgroup accumulator"); + + SmallVector<Type> innerStructBody; + for (unsigned i = 0; i < numMembers; i++) + innerStructBody.push_back(elemType); + auto innerStructType = + LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody); + + SmallVector<Type> structBody; + for (int i = 0; i < sizeM; i += kWgmmaSizeM) + structBody.push_back(innerStructType); + + auto convertedType = + LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); + return llvmTypeConverter.convertType(convertedType); + }); + llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type { return llvmTypeConverter.convertType( getMBarrierMemrefType(type.getContext(), type)); }); llvmTypeConverter.addConversion( - [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { + [&](WarpgroupMatrixDescriptorType type) -> Type { return llvmTypeConverter.convertType( IntegerType::get(type.getContext(), 64)); }); - 
llvmTypeConverter.addConversion( - [&](nvgpu::TensorMapDescriptorType type) -> Type { - return LLVM::LLVMPointerType::get(type.getContext()); - }); + llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type { + return LLVM::LLVMPointerType::get(type.getContext()); + }); populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns); } -LogicalResult -transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter( - transform::TypeConverterBuilderOpInterface builder) { +LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter( + TypeConverterBuilderOpInterface builder) { if (builder.getTypeConverterType() != "LLVMTypeConverter") return emitOpError("expected LLVMTypeConverter"); return success(); @@ -127,17 +123,18 @@ transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter( // CreateAsyncGroupsOp //===---------------------------------------------------------------------===// -void transform::CreateAsyncGroupsOp::getEffects( +void CreateAsyncGroupsOp::getEffects( SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { - transform::consumesHandle(getTargetMutable(), effects); - transform::producesHandle(getOperation()->getOpResults(), effects); - transform::modifiesPayload(effects); + consumesHandle(getTargetMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); } -DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne( - TransformRewriter &rewriter, Operation *target, - ApplyToEachResultList &results, TransformState &state) { - nvgpu::createAsyncGroups(rewriter, target, getBypassL1()); +DiagnosedSilenceableFailure +CreateAsyncGroupsOp::applyToOne(TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, + TransformState &state) { + createAsyncGroups(rewriter, target, getBypassL1()); results.push_back(target); return DiagnosedSilenceableFailure::success(); } @@ -218,7 +215,7 @@ collectStage0PipeliningOps(scf::ForOp forOp, continue; } - if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) { + if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) { ops.insert(&op); ops.insert(std::make_move_iterator(barriers.begin()), std::make_move_iterator(barriers.end())); @@ -246,7 +243,7 @@ setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op, unsigned iteration, unsigned depth) { // Based on the order of copies within the loop we need to set the number // of copies in flight, unless it is already set. - auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op); + auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op); if (!waitOp || waitOp.getNumGroups()) return; @@ -312,13 +309,12 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, // original number of iterations, in particular side-effect free operations // and barriers, even if they cannot be predicated. if (isMemoryEffectFree(op) || - isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp, - nvgpu::DeviceAsyncWaitOp>(op)) { + isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) { return op; } // Otherwise, only async copies can currently be predicated. 
- auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op); + auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op); if (!asyncCopyOp) return nullptr; @@ -335,8 +331,8 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0); auto srcElements = arith::SelectOp::create(rewriter, loc, predicate, originalSrcElement, c0Index); - auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create( - rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()), + auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create( + rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()), asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(), asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements, UnitAttr()); @@ -805,17 +801,16 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { rhsIndexFn, rhsShape); Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef, resIndexFn, resShape); - res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, - info.tf32Enabled); + res = + MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled); buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn, resShape); return res.getDefiningOp(); } -DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne( - transform::TransformRewriter &rewriter, LinalgOp linalgOp, - transform::ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure RewriteMatmulAsMmaSyncOp::applyToOne( + TransformRewriter &rewriter, LinalgOp linalgOp, + ApplyToEachResultList &results, TransformState &state) { bool fail = true; // TODO: more robust detection of matmulOp, with transposes etc. if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) { @@ -854,43 +849,42 @@ struct HopperBuilder { HopperBuilder(RewriterBase &rewriter, Location loc) : rewriter(rewriter), loc(loc) {} - TypedValue<nvgpu::MBarrierGroupType> + TypedValue<MBarrierGroupType> buildAndInitBarrierInSharedMemory(OpFoldResult numThreads); /// Create tma descriptor op to initiate transfer from global to shared /// memory. This must be done before the launch op, on the host. - TypedValue<nvgpu::TensorMapDescriptorType> + TypedValue<TensorMapDescriptorType> buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, gpu::LaunchOp launchOp); /// Build a tma load from global memory to shared memory using `barrier` to /// synchronize. Return the number of bytes that will be transferred. - OpFoldResult - buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc, - TypedValue<MemRefType> sharedMemref, - TypedValue<nvgpu::MBarrierGroupType> barrier, - SmallVectorImpl<Operation *> &loadOps); - void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier, + OpFoldResult buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc, + TypedValue<MemRefType> sharedMemref, + TypedValue<MBarrierGroupType> barrier, + SmallVectorImpl<Operation *> &loadOps); + void buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier, ArrayRef<OpFoldResult> sizes); /// If threadIdx.x == 0 does TMA request + wait, else just wait. /// Return the operation that performs the transfer on thread0. // TODO: In the future, don't hardcode to thread 0 but elect a leader. 
SmallVector<Operation *> buildPredicateLoadsOnThread0( - ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, + ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors, ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, - TypedValue<nvgpu::MBarrierGroupType> barrier); + TypedValue<MBarrierGroupType> barrier); - void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier); + void buildTryWaitParity(TypedValue<MBarrierGroupType> barrier); RewriterBase &rewriter; Location loc; }; SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0( - ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, + ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors, ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, - TypedValue<nvgpu::MBarrierGroupType> barrier) { + TypedValue<MBarrierGroupType> barrier) { SmallVector<Operation *> loadOps; Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x); @@ -931,22 +925,22 @@ static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) { // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace)); } -TypedValue<nvgpu::MBarrierGroupType> +TypedValue<MBarrierGroupType> HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) { auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); - Value barrier = nvgpu::MBarrierCreateOp::create( + Value barrier = MBarrierCreateOp::create( rewriter, loc, - nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace)); + MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace)); Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); nvgpu::MBarrierInitOp::create( rewriter, loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero, Value()); gpu::BarrierOp::create(rewriter, loc); - return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier); + return cast<TypedValue<MBarrierGroupType>>(barrier); } -TypedValue<nvgpu::TensorMapDescriptorType> +TypedValue<TensorMapDescriptorType> HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, gpu::LaunchOp launchOp) { OpBuilder::InsertionGuard guard(rewriter); @@ -962,29 +956,29 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes); auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); - Value desc = nvgpu::TmaCreateDescriptorOp::create( + Value desc = TmaCreateDescriptorOp::create( rewriter, loc, - nvgpu::TensorMapDescriptorType::get( - rewriter.getContext(), - MemRefType::Builder(memref.getType()) - .setMemorySpace(sharedMemorySpace), - TensorMapSwizzleKind::SWIZZLE_NONE, - TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO, - TensorMapInterleaveKind::INTERLEAVE_NONE), + TensorMapDescriptorType::get(rewriter.getContext(), + MemRefType::Builder(memref.getType()) + .setMemorySpace(sharedMemorySpace), + TensorMapSwizzleKind::SWIZZLE_NONE, + TensorMapL2PromoKind::L2PROMO_NONE, + TensorMapOOBKind::OOB_ZERO, + TensorMapInterleaveKind::INTERLEAVE_NONE), unrankedMemRef, sizes); - return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc); + return cast<TypedValue<TensorMapDescriptorType>>(desc); } -OpFoldResult HopperBuilder::buildTmaAsyncLoad( - TypedValue<nvgpu::TensorMapDescriptorType> globalDesc, - TypedValue<MemRefType> sharedMemref, - TypedValue<nvgpu::MBarrierGroupType> barrier, - SmallVectorImpl<Operation *> &loadOps) { +OpFoldResult 
+HopperBuilder::buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc, + TypedValue<MemRefType> sharedMemref, + TypedValue<MBarrierGroupType> barrier, + SmallVectorImpl<Operation *> &loadOps) { MLIRContext *ctx = rewriter.getContext(); Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); - Operation *loadOp = nvgpu::TmaAsyncLoadOp::create( - rewriter, loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, - zero, Value(), Value()); + Operation *loadOp = + TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc, + ValueRange{zero, zero}, zero, Value(), Value()); loadOps.push_back(loadOp); auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref); SmallVector<AffineExpr> symbols(mixedSizes.size()); @@ -997,9 +991,8 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad( return res; } -void HopperBuilder::buildBarrierArriveTx( - TypedValue<nvgpu::MBarrierGroupType> barrier, - ArrayRef<OpFoldResult> mixedSizes) { +void HopperBuilder::buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier, + ArrayRef<OpFoldResult> mixedSizes) { assert(!mixedSizes.empty() && "expecte non-empty sizes"); MLIRContext *ctx = rewriter.getContext(); SmallVector<AffineExpr> symbols(mixedSizes.size()); @@ -1013,8 +1006,7 @@ void HopperBuilder::buildBarrierArriveTx( Value()); } -void HopperBuilder::buildTryWaitParity( - TypedValue<nvgpu::MBarrierGroupType> barrier) { +void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) { Type i1 = rewriter.getI1Type(); Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0); // 10M is an arbitrary, not too small or too big number to specify the number @@ -1058,11 +1050,11 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) { ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(), launchOp.getBlockSizeZ()}); - TypedValue<nvgpu::MBarrierGroupType> barrier = + TypedValue<MBarrierGroupType> barrier = buildAndInitBarrierInSharedMemory(numThreads); SmallVector<TypedValue<MemRefType>> shmems; - SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs; + SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs; for (Operation *op : copyOps) { auto copyOp = cast<linalg::CopyOp>(op); auto inMemRef = @@ -1071,7 +1063,7 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) { "expected in to be a 2D memref"); // 2. Build global memory descriptor. 
- TypedValue<nvgpu::TensorMapDescriptorType> globalDesc = + TypedValue<TensorMapDescriptorType> globalDesc = buildGlobalMemRefDescriptor(inMemRef, launchOp); globalDescs.push_back(globalDesc); @@ -1098,9 +1090,8 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) { } DiagnosedSilenceableFailure -transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter, - transform::TransformResults &results, - transform::TransformState &state) { +RewriteCopyAsTmaOp::apply(TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { auto payloadOps = state.getPayloadOps(getTarget()); gpu::LaunchOp commonLaunchOp; Operation *firstOp, *failingOp; @@ -1137,15 +1128,14 @@ transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter, namespace { class NVGPUTransformDialectExtension - : public transform::TransformDialectExtension< - NVGPUTransformDialectExtension> { + : public TransformDialectExtension<NVGPUTransformDialectExtension> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension) NVGPUTransformDialectExtension() { declareGeneratedDialect<arith::ArithDialect>(); declareGeneratedDialect<affine::AffineDialect>(); - declareGeneratedDialect<nvgpu::NVGPUDialect>(); + declareGeneratedDialect<NVGPUDialect>(); declareGeneratedDialect<NVVM::NVVMDialect>(); declareGeneratedDialect<vector::VectorDialect>(); registerTransformOps< diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp index 5b89c87..7f626a6 100644 --- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp @@ -64,6 +64,5 @@ private: void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns( RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) { - patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision); } diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp index 809d634..9e5ea93 100644 --- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp +++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp @@ -168,8 +168,7 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType) { Type elementType = fragmentType.vectorType.getElementType(); ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape(); - FailureOr<nvgpu::FragmentElementInfo> regInfo = - getMmaSyncRegisterType(fragmentType); + FailureOr<FragmentElementInfo> regInfo = getMmaSyncRegisterType(fragmentType); if (failed(regInfo)) return failure(); @@ -199,8 +198,8 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, (logicalValueIdDim % elementsPerRegister)}); } -FailureOr<nvgpu::LdMatrixParams> -nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) { +FailureOr<LdMatrixParams> nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, + bool transpose) { LdMatrixParams params; Type elType = type.vectorType.getElementType(); params.fragmentType = type.vectorType; diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 6598ac1..6564a4e 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -7,6 +7,7 @@ // ============================================================================= #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -44,6 +45,7 @@ struct MemRefPointerLikeModel Type getElementType(Type pointer) const { return cast<MemRefType>(pointer).getElementType(); } + mlir::acc::VariableTypeCategory getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr, Type varType) const { @@ -70,6 +72,115 @@ struct MemRefPointerLikeModel assert(memrefTy.getRank() > 0 && "rank expected to be positive"); return mlir::acc::VariableTypeCategory::array; } + + mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc, + StringRef varName, Type varType, + Value originalVar) const { + auto memrefTy = cast<MemRefType>(pointer); + + // Check if this is a static memref (all dimensions are known) - if yes + // then we can generate an alloca operation. + if (memrefTy.hasStaticShape()) + return memref::AllocaOp::create(builder, loc, memrefTy).getResult(); + + // For dynamic memrefs, extract sizes from the original variable if + // provided. Otherwise they cannot be handled. + if (originalVar && originalVar.getType() == memrefTy && + memrefTy.hasRank()) { + SmallVector<Value> dynamicSizes; + for (int64_t i = 0; i < memrefTy.getRank(); ++i) { + if (memrefTy.isDynamicDim(i)) { + // Extract the size of dimension i from the original variable + auto indexValue = arith::ConstantIndexOp::create(builder, loc, i); + auto dimSize = + memref::DimOp::create(builder, loc, originalVar, indexValue); + dynamicSizes.push_back(dimSize); + } + // Note: We only add dynamic sizes to the dynamicSizes array + // Static dimensions are handled automatically by AllocOp + } + return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes) + .getResult(); + } + + // TODO: Unranked not yet supported. + return {}; + } + + bool genFree(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> varPtr, Type varType) const { + if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) { + // Walk through casts to find the original allocation + Value currentValue = memrefValue; + Operation *originalAlloc = nullptr; + + // Follow the chain of operations to find the original allocation + // even if a casted result is provided. + while (currentValue) { + if (auto *definingOp = currentValue.getDefiningOp()) { + // Check if this is an allocation operation + if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) { + originalAlloc = definingOp; + break; + } + + // Check if this is a cast operation we can look through + if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) { + currentValue = castOp.getSource(); + continue; + } + + // Check for other cast-like operations + if (auto reinterpretCastOp = + dyn_cast<memref::ReinterpretCastOp>(definingOp)) { + currentValue = reinterpretCastOp.getSource(); + continue; + } + + // If we can't look through this operation, stop + break; + } + // This is a block argument or similar - can't trace further. 
+ break; + } + + if (originalAlloc) { + if (isa<memref::AllocaOp>(originalAlloc)) { + // This is an alloca - no dealloc needed, but return true (success) + return true; + } + if (isa<memref::AllocOp>(originalAlloc)) { + // This is an alloc - generate dealloc + memref::DeallocOp::create(builder, loc, memrefValue); + return true; + } + } + } + + return false; + } + + bool genCopy(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> destination, + TypedValue<PointerLikeType> source, Type varType) const { + // Generate a copy operation between two memrefs + auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination); + auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source); + + // As per memref documentation, source and destination must have same + // element type and shape in order to be compatible. We do not want to fail + // with an IR verification error - thus check that before generating the + // copy operation. + if (destMemref && srcMemref && + destMemref.getType().getElementType() == + srcMemref.getType().getElementType() && + destMemref.getType().getShape() == srcMemref.getType().getShape()) { + memref::CopyOp::create(builder, loc, srcMemref, destMemref); + return true; + } + + return false; + } }; struct LLVMPointerPointerLikeModel diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index b0132e8..14e235f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -47,6 +47,7 @@ #include <cassert> #include <cstdint> +#include <numeric> #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc" // Pull in all enum type and utility function definitions. @@ -2412,9 +2413,38 @@ foldToElementsFromElements(ToElementsOp toElementsOp, return success(); } +/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only. +/// +/// Example: +/// %b = vector.broadcast %x : i32 to vector<3xf32> +/// %e:3 = vector.to_elements %b : vector<3xf32> +/// user_op %e#0, %e#1, %e#2 +/// becomes: +/// user_op %x, %x, %x +/// +/// The vector source case is handled by a canonicalization pattern. +static LogicalResult +foldToElementsOfBroadcast(ToElementsOp toElementsOp, + SmallVectorImpl<OpFoldResult> &results) { + auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>(); + if (!bcastOp) + return failure(); + // Vectors are handled in the ToElementsOfBroadcast RewritePattern. + if (isa<VectorType>(bcastOp.getSource().getType())) + return failure(); + + auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType()); + + Value scalar = bcastOp.getSource(); + results.assign(resultVecType.getNumElements(), scalar); + return success(); +} + LogicalResult ToElementsOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { - return foldToElementsFromElements(*this, results); + if (succeeded(foldToElementsFromElements(*this, results))) + return success(); + return foldToElementsOfBroadcast(*this, results); } LogicalResult @@ -2427,6 +2457,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, return success(); } +/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a +/// vector. +/// - Build `vector.to_elements %v` and remap each destination element to the +/// corresponding source element using broadcast rules (match or 1 → +/// replicate). 
+/// +/// Example: +/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32> +/// %e:6 = vector.to_elements %v : vector<3x2xf32> +/// becomes: +/// %src_elems:2 = vector.to_elements %src : vector<2xf32> +/// // uses: %src_elems#0, %src_elems#1, %src_elems#0, +/// // %src_elems#1, %src_elems#0, %src_elems#1 +struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> { + using Base::Base; + + LogicalResult matchAndRewrite(ToElementsOp toElementsOp, + PatternRewriter &rewriter) const override { + auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>(); + if (!bcastOp) + return failure(); + + // Only handle broadcasts from a vector source here. + auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType()); + if (!srcType) + return failure(); + + auto dstType = cast<VectorType>(toElementsOp.getSource().getType()); + + ArrayRef<int64_t> dstShape = dstType.getShape(); + ArrayRef<int64_t> srcShape = srcType.getShape(); + + int64_t dstRank = dstShape.size(); + int64_t srcRank = srcShape.size(); + + // Create elements for the broadcast source vector. + auto srcElems = vector::ToElementsOp::create( + rewriter, toElementsOp.getLoc(), bcastOp.getSource()); + + int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1, + std::multiplies<int64_t>()); + + SmallVector<Value> replacements; + replacements.reserve(dstCount); + + // For each element of the destination, determine which element of the + // source should be used. We walk all destination positions using a single + // counter, decode it into per-dimension indices, then build the matching + // source position: use the same index where sizes match, and use 0 where + // the source size is 1 (replication). This mapping is needed so we can + // replace each result of to_elements with the corresponding element from + // the broadcast source. + // Inner-dimension stretch example: + // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32> + // %e:12 = vector.to_elements %v : vector<2x3x2xf32> + // becomes: + // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32> + // // uses: %src_elems#0, %src_elems#1, %src_elems#0, + // // %src_elems#1, %src_elems#0, %src_elems#1, + // // %src_elems#2, %src_elems#3, %src_elems#2, + // // %src_elems#3, %src_elems#2, %src_elems#3 + + // Row-major strides for the destination shape. + SmallVector<int64_t> dstStrides = computeStrides(dstShape); + // Row-major strides for the source shape. + SmallVector<int64_t> srcStrides = computeStrides(srcShape); + SmallVector<int64_t> dstIdx(dstRank); + SmallVector<int64_t> srcIdx(srcRank); + for (int64_t lin = 0; lin < dstCount; ++lin) { + // Convert linear destination index to per-dimension indices. + dstIdx = delinearize(lin, dstStrides); + for (int64_t k = 0; k < srcRank; ++k) + srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]; + // Convert per-dimension source indices back to a linear index. 
+ int64_t srcLin = linearize(srcIdx, srcStrides); + replacements.push_back(srcElems.getResult(srcLin)); + } + + rewriter.replaceOp(toElementsOp, replacements); + return success(); + } +}; + +void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<ToElementsOfBroadcast>(context); +} + //===----------------------------------------------------------------------===// // FromElementsOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 784e5d6..c28d2fc 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { ConversionPatternRewriter &rewriter) const override { auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue()); auto vecType = dyn_cast<VectorType>(op.getType()); - if (!vecAttr || !vecAttr.isSplat() || !vecType) + if (!vecAttr || !vecType) return failure(); xegpu::DistributeLayoutAttr layout = @@ -733,22 +733,139 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { int count; std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); - // Current limitation: constant of vector with single value. - // TODO: support more complex cases, e.g., vector with multiple values. - Attribute singleVal = vecAttr.getSplatValue<Attribute>(); - auto newType = VectorType::get(sgShape, vecType.getElementType()); - auto sgAttr = DenseElementsAttr::get(newType, singleVal); - auto cstOp = - arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(cstOp->getResult(0), - layout.dropSgLayoutAndData()); - SmallVector<Value> newConsts(count, cstOp); + Location loc = op.getLoc(); + auto eltType = vecType.getElementType(); - rewriter.replaceOpWithMultiple(op, {newConsts}); - return success(); + auto setLayoutIfNeeded = [&](Value val) { + if (!layout.getEffectiveLaneLayoutAsInt().empty() || + !layout.getEffectiveInstDataAsInt().empty()) { + xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val), + layout.dropSgLayoutAndData()); + } + }; + + if (vecAttr.isSplat()) { + // Splat: single value for all subgroups + Attribute singleVal = vecAttr.getSplatValue<Attribute>(); + auto sgAttr = DenseElementsAttr::get(newType, singleVal); + auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); + setLayoutIfNeeded(cstOp->getResult(0)); + rewriter.replaceOp(op, cstOp); + return success(); + } else if (sgShape == wgShape) { // if the entire vector is shared by all + // subgroups, don't distribute + auto newConstOp = + arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); + setLayoutIfNeeded(newConstOp->getResult(0)); + rewriter.replaceOp(op, newConstOp); + return success(); + } else { + // Non-splat constant + // Only supports 1D & 2D + // TODO: support other cases that require SLM access + if (!eltType.isIndex()) + return rewriter.notifyMatchFailure( + op, "Unsupported element type for non-splat constant op."); + + if (wgShape.size() > 2) + return rewriter.notifyMatchFailure( + op, "Only 1D & 2D vector constant supported"); + + SmallVector<Attribute> values(vecAttr.getValues<Attribute>()); + int64_t rowStride = 0, 
colStride = 0; + int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0]; + int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1]; + + // Compute colStride and rowStride, and check for constant strides. + if (cols > 1) { + colStride = cast<IntegerAttr>(values[1]).getInt() - + cast<IntegerAttr>(values[0]).getInt(); + } + if (rows > 1) { + rowStride = cast<IntegerAttr>(values[cols]).getInt() - + cast<IntegerAttr>(values[0]).getInt(); + } + + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + int64_t idx = r * cols + c; + // Check column stride + if (c > 0 && cols > 1) { + int64_t prevIdx = r * cols + (c - 1); + int64_t diff = cast<IntegerAttr>(values[idx]).getInt() - + cast<IntegerAttr>(values[prevIdx]).getInt(); + if (diff != colStride) + return rewriter.notifyMatchFailure( + op, "Non-constant column stride in constant op."); + } + // Check row stride + if (r > 0 && rows > 1) { + int64_t prevIdx = (r - 1) * cols + c; + int64_t diff = cast<IntegerAttr>(values[idx]).getInt() - + cast<IntegerAttr>(values[prevIdx]).getInt(); + if (diff != rowStride) + return rewriter.notifyMatchFailure( + op, "Non-constant row stride in constant op."); + } + } + } + + // Create a constant for the base tile. + // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix. + // For 1D case, extract the first sgShape[0] elements. + SmallVector<Attribute> baseTileValues; + int baseTileCols = sgShape[sgShape.size() - 1]; + int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0]; + for (int64_t r = 0; r < baseTileRows; ++r) { + for (int64_t c = 0; c < baseTileCols; ++c) { + baseTileValues.push_back(values[r * cols + c]); + } + } + + auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType), + baseTileValues); + auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr); + + // Get subgroup id + Value sgId = + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); + + auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + if (failed(sgOffsets)) + return failure(); + + SmallVector<Value, 2> strideConsts; + strideConsts.push_back( + rewriter.create<arith::ConstantIndexOp>(loc, colStride)); + if (rows > 1) + strideConsts.insert( + strideConsts.begin(), + rewriter.create<arith::ConstantIndexOp>(loc, rowStride)); + + SmallVector<Value> newConstOps; + for (auto offsets : *sgOffsets) { + // Multiply offset with stride, broadcast it and add to baseConstVec + Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0); + for (size_t i = 0; i < strideConsts.size(); ++i) { + Value mul = rewriter.create<arith::MulIOp>( + loc, rewriter.getIndexType(), offsets[i], strideConsts[i]); + mulOffset = rewriter.create<arith::AddIOp>( + loc, rewriter.getIndexType(), mulOffset, mul); + } + // Broadcast to baseConstVec size + auto bcastOffset = rewriter.create<vector::BroadcastOp>( + loc, baseConstVec.getType(), mulOffset); + auto finalConst = + arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); + setLayoutIfNeeded(baseConstVec); + setLayoutIfNeeded(bcastOffset); + setLayoutIfNeeded(finalConst); + newConstOps.push_back(finalConst); + } + rewriter.replaceOpWithMultiple(op, {newConstOps}); + return success(); + } } }; |