author     Stephan Herhut <herhut@google.com>   2020-06-09 17:20:53 +0200
committer  Stephan Herhut <herhut@google.com>   2020-06-09 19:33:45 +0200
commit     2c8afe1298e5f471a5736757b1cd2a708dd91ec9
tree       a45283036285205175d8f044df8483fd542bdcfc
parent     b7d369280ba6073a285811733f90cf7f2e0066be
[mlir][gpu] Add support for f16 when lowering to nvvm intrinsics
Summary:

The NVVM target only provides implementations of math functions such as
tanh for f32 and f64 operands. To also support f16, we now insert
operations that extend the f16 operands to f32 before the intrinsic call
and truncate the result back to f16 afterwards.

Differential Revision: https://reviews.llvm.org/D81473
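In effect, an f16 op such as std.tanh now lowers to an fpext, the f32
intrinsic call, and an fptrunc. A minimal sketch of the resulting IR,
with illustrative value names (the exact pattern is what the updated
test below checks for):

    %0 = llvm.fpext %arg_f16 : !llvm.half to !llvm.float
    %1 = llvm.call @__nv_tanhf(%0) : (!llvm.float) -> !llvm.float
    %2 = llvm.fptrunc %1 : !llvm.float to !llvm.half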
-rw-r--r--  mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h  42
-rw-r--r--  mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir        8
2 files changed, 42 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index c7bbb6d..58b5f1d 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -20,6 +20,9 @@ namespace mlir {
 /// depending on the element type that Op operates upon. The function
 /// declaration is added in case it was not added before.
 ///
+/// If the input values are of f16 type, the value is first cast to f32,
+/// the function is called, and the result is cast back to f16.
+///
 /// Example with NVVM:
 ///   %exp_f32 = std.exp %arg_f32 : f32
 ///
@@ -44,21 +47,48 @@ public:
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
 
-    LLVMType resultType = typeConverter.convertType(op->getResult(0).getType())
-                              .template cast<LLVM::LLVMType>();
-    LLVMType funcType = getFunctionType(resultType, operands);
-    StringRef funcName = getFunctionName(resultType);
+    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+                                  SourceOp>::value,
+                  "expected op with same operand and result types");
+
+    SmallVector<Value, 1> castedOperands;
+    for (Value operand : operands)
+      castedOperands.push_back(maybeCast(operand, rewriter));
+
+    LLVMType resultType =
+        castedOperands.front().getType().cast<LLVM::LLVMType>();
+    LLVMType funcType = getFunctionType(resultType, castedOperands);
+    StringRef funcName = getFunctionName(funcType.getFunctionResultType());
     if (funcName.empty())
       return failure();
 
     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
     auto callOp = rewriter.create<LLVM::CallOp>(
-        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
-    rewriter.replaceOp(op, {callOp.getResult(0)});
+        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
+        castedOperands);
+
+    if (resultType == operands.front().getType()) {
+      rewriter.replaceOp(op, {callOp.getResult(0)});
+      return success();
+    }
+
+    Value truncated = rewriter.create<LLVM::FPTruncOp>(
+        op->getLoc(), operands.front().getType(), callOp.getResult(0));
+    rewriter.replaceOp(op, {truncated});
     return success();
   }
 
 private:
+  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
+    LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
+    if (!type.isHalfTy())
+      return operand;
+
+    return rewriter.create<LLVM::FPExtOp>(
+        operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
+        operand);
+  }
+
   LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
                                  ArrayRef<Value> operands) const {
     using LLVM::LLVMType;
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index f05c9af..925615c 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -219,12 +219,16 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_tanhf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_tanh(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_tanh
-  func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = std.tanh %arg_f16 : f16
+    // CHECK: llvm.fpext %{{.*}} : !llvm.half to !llvm.float
+    // CHECK-NEXT: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
+    // CHECK-NEXT: llvm.fptrunc %{{.*}} : !llvm.float to !llvm.half
     %result32 = std.tanh %arg_f32 : f32
     // CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.tanh %arg_f64 : f64
     // CHECK: llvm.call @__nv_tanh(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return %result32, %result64 : f32, f64
+    std.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
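For local reproduction, this test file is run through mlir-opt with the
-convert-gpu-to-nvvm pass and verified with FileCheck. A minimal
standalone input, trimmed to just the new f16 case (module, function,
and value names are illustrative, mirroring the test above):

    gpu.module @test_module {
      func @gpu_tanh(%arg_f16 : f16) -> f16 {
        %result16 = std.tanh %arg_f16 : f16
        std.return %result16 : f16
      }
    }

Running mlir-opt -convert-gpu-to-nvvm on this input should produce the
fpext/call/fptrunc sequence shown in the CHECK lines above.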