diff options
Diffstat (limited to 'mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp')
-rw-r--r-- | mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 76 |
1 files changed, 67 insertions, 9 deletions
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index df219f3..a2dfc12 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -10,6 +10,8 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -19,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/DebugLog.h" #include "../GPUCommon/GPUOpsLowering.h" #include "../GPUCommon/OpToFuncCallLowering.h" @@ -42,8 +45,46 @@ static void populateOpPatterns(const LLVMTypeConverter &converter, f32ApproxFunc, f16Func); } +struct ClampFOpConversion final + : public ConvertOpToLLVMPattern<math::ClampFOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only f16 and f32 types are supported by fmed3 + Type opTy = op.getType(); + Type resultType = getTypeConverter()->convertType(opTy); + + if (auto vectorType = dyn_cast<VectorType>(opTy)) + opTy = vectorType.getElementType(); + + if (!isa<Float16Type, Float32Type>(opTy)) + return rewriter.notifyMatchFailure( + op, "fmed3 only supports f16 and f32 types"); + + // Handle multi-dimensional vectors (converted to LLVM arrays) + if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType)) + return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) -> Value { + typename math::ClampFOp::Adaptor adaptor(operands); + return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getValue(), adaptor.getMin(), + adaptor.getMax()); + }, + rewriter); + + // Handle 1D vectors and scalars directly + rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(), + op.getMin(), op.getMax()); + return success(); + } +}; + void mlir::populateMathToROCDLConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns) { + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + std::optional<amdgpu::Chipset> chipset) { // Handled by mathToLLVM: math::AbsIOp // Handled by mathToLLVM: math::AbsFOp // Handled by mathToLLVM: math::CopySignOp @@ -118,15 +159,21 @@ void mlir::populateMathToROCDLConversionPatterns( // worth creating a separate pass for it. populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32", "__ocml_fmod_f64", "__ocml_fmod_f16"); + + if (chipset.has_value() && chipset->majorVersion >= 9) { + patterns.add<ClampFOpConversion>(converter); + } else { + LDBG() << "Chipset dependent patterns were not added"; + } } -namespace { -struct ConvertMathToROCDLPass - : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> { - ConvertMathToROCDLPass() = default; +struct ConvertMathToROCDLPass final + : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> { + using impl::ConvertMathToROCDLBase< + ConvertMathToROCDLPass>::ConvertMathToROCDLBase; + void runOnOperation() override; }; -} // namespace void ConvertMathToROCDLPass::runOnOperation() { auto m = getOperation(); @@ -135,10 +182,21 @@ void ConvertMathToROCDLPass::runOnOperation() { RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(ctx, DataLayout(m)); LLVMTypeConverter converter(ctx, options); - populateMathToROCDLConversionPatterns(converter, patterns); + + FailureOr<amdgpu::Chipset> maybeChipset; + if (!chipset.empty()) { + maybeChipset = amdgpu::Chipset::parse(chipset); + if (failed(maybeChipset)) + return signalPassFailure(); + } + populateMathToROCDLConversionPatterns( + converter, patterns, + succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt); + ConversionTarget target(getContext()); - target.addLegalDialect<BuiltinDialect, func::FuncDialect, - vector::VectorDialect, LLVM::LLVMDialect>(); + target + .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect, + LLVM::LLVMDialect, ROCDL::ROCDLDialect>(); target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp, |