diff options
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--  mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 70
-rw-r--r--  mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp    |  6
-rw-r--r--  mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h      | 21
-rw-r--r--  mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp   |  5
4 files changed, 65 insertions(+), 37 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 9b15435..478b6aa 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" @@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) { /// on the architecture you are compiling for. static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset) { - auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType()); - auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType()); - auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType()); - auto elemSourceType = sourceVectorType.getElementType(); - auto elemBSourceType = sourceBVectorType.getElementType(); - auto elemDestType = destVectorType.getElementType(); - - if (elemSourceType.isF16() && elemDestType.isF32()) - return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); - if (elemSourceType.isBF16() && elemDestType.isF32()) - return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); - if (elemSourceType.isF16() && elemDestType.isF16()) - return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); - if (elemSourceType.isBF16() && elemDestType.isBF16()) - return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); - if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) - return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); - if (chipset.majorVersion == 11) { - if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) - return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); + auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType()); + auto 
sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType()); + auto destVectorType = cast<VectorType>(wmma.getDestC().getType()); + Type elemSourceType = sourceVectorType.getElementType(); + Type elemBSourceType = sourceBVectorType.getElementType(); + Type elemDestType = destVectorType.getElementType(); + + const uint32_t k = wmma.getK(); + + if (k == 16) { + if (elemSourceType.isF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); + if (elemSourceType.isF16() && elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isBF16()) + return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); + if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); + if (chipset.majorVersion == 11) { + if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); + } } - if (chipset.majorVersion >= 12) { + if (chipset.majorVersion < 12) + return std::nullopt; + + // gfx12+ + if (k == 16) { if (isa<Float8E4M3FNType>(elemSourceType) && isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName(); @@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, if (isa<Float8E5M2Type>(elemSourceType) && isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName(); - if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) { - bool isWave64 = destVectorType.getNumElements() == 4; - // This is the ambiguous case. 8 inputs to the wave64 version means that - // we want the 16x16x32 version, but for wave32 they mean the short form. 
- bool has8Inputs = sourceVectorType.getNumElements() == 8; - if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs)) - return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); + if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); - } + + return std::nullopt; } - return std::nullopt; + if (k == 32) { + if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); + return std::nullopt; + } + + llvm_unreachable("unhandled WMMA case"); } namespace { diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 2285d26..eb662a1 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -507,7 +507,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType}, /*isVarArg=*/true); LLVM::LLVMFuncOp printfDecl = - getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); + getOrDefineFunction(moduleOp, loc, rewriter, funcName, printfType); + printfDecl.setCConv(callingConvention); // Create the global op or find an existing one. 
LLVM::GlobalOp global = getOrCreateStringConstant( @@ -530,7 +531,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( printfArgs.push_back(stringStart); printfArgs.append(argsRange.begin(), argsRange.end()); - LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs); + auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs); + call.setCConv(callingConvention); rewriter.eraseOp(gpuPrintfOp); return success(); } diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h index 66d3bb4..ec74787 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -10,6 +10,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" namespace mlir { @@ -142,13 +143,23 @@ struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> { /// This pass will add a declaration of printf() to the GPUModule if needed /// and separate out the format strings into global constants. For some /// runtimes, such as OpenCL on AMD, this is sufficient setup, as the compiler -/// will lower printf calls to appropriate device-side code +/// will lower printf calls to appropriate device-side code. +/// However not all backends use the same calling convention and function +/// naming. +/// For example, the LLVM SPIRV backend requires calling convention +/// LLVM::cconv::CConv::SPIR_FUNC and function name needs to be +/// mangled as "_Z6printfPU3AS2Kcz". +/// Default callingConvention is LLVM::cconv::CConv::C and +/// funcName is "printf" but they can be customized as needed. 
struct GPUPrintfOpToLLVMCallLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> { - GPUPrintfOpToLLVMCallLowering(const LLVMTypeConverter &converter, - int addressSpace = 0) + GPUPrintfOpToLLVMCallLowering( + const LLVMTypeConverter &converter, int addressSpace = 0, + LLVM::cconv::CConv callingConvention = LLVM::cconv::CConv::C, + StringRef funcName = "printf") : ConvertOpToLLVMPattern<gpu::PrintfOp>(converter), - addressSpace(addressSpace) {} + addressSpace(addressSpace), callingConvention(callingConvention), + funcName(funcName) {} LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, @@ -156,6 +167,8 @@ struct GPUPrintfOpToLLVMCallLowering private: int addressSpace; + LLVM::cconv::CConv callingConvention; + StringRef funcName; }; /// Lowering of gpu.printf to a vprintf standard library. diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index c2363a1..25f1e1b 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -470,10 +470,13 @@ struct GPUToLLVMSPVConversionPass final gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp, gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp, gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp, - gpu::ThreadIdOp>(); + gpu::ThreadIdOp, gpu::PrintfOp>(); populateGpuToLLVMSPVConversionPatterns(converter, patterns); populateGpuMemorySpaceAttributeConversions(converter); + patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/2, + LLVM::cconv::CConv::SPIR_FUNC, + "_Z6printfPU3AS2Kcz"); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) |
