diff options
-rw-r--r-- | flang/lib/Lower/ConvertCall.cpp | 6 | ||||
-rw-r--r-- | flang/test/Lower/CUDA/cuda-kernel-calls.cuf | 6 |
2 files changed, 8 insertions, 4 deletions
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp index 9909121..9556933 100644 --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -416,7 +416,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult( mlir::Type i32Ty = builder.getI32Type(); mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1); - mlir::Value grid_x, grid_y; + mlir::Value grid_x, grid_y, grid_z; if (caller.getCallDescription().chevrons()[0].GetType()->category() == Fortran::common::TypeCategory::Integer) { // If grid is an integer, it is converted to dim3(grid,1,1). Since z is @@ -426,11 +426,13 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult( fir::getBase(converter.genExprValue( caller.getCallDescription().chevrons()[0], stmtCtx))); grid_y = one; + grid_z = one; } else { auto dim3Addr = converter.genExprAddr( caller.getCallDescription().chevrons()[0], stmtCtx); grid_x = readDim3Value(builder, loc, fir::getBase(dim3Addr), "x"); grid_y = readDim3Value(builder, loc, fir::getBase(dim3Addr), "y"); + grid_z = readDim3Value(builder, loc, fir::getBase(dim3Addr), "z"); } mlir::Value block_x, block_y, block_z; @@ -466,7 +468,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult( caller.getCallDescription().chevrons()[3], stmtCtx))); builder.create<fir::CUDAKernelLaunch>( - loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, one, + loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, grid_z, block_x, block_y, block_z, bytes, stream, operands); callNumResults = 0; } else if (caller.requireDispatchCall()) { diff --git a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf index d5dabaa..55b5246e 100644 --- a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf +++ b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf @@ -20,13 +20,15 @@ contains call dev_kernel0<<<10, 20>>>() ! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}>>>() - call dev_kernel0<<< __builtin_dim3(1,1), __builtin_dim3(32,1,1) >>> + call dev_kernel0<<< __builtin_dim3(1,1,4), __builtin_dim3(32,1,1) >>> ! CHECK: %[[ADDR_DIM3_GRID:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>> ! CHECK: %[[DIM3_GRID:.*]]:2 = hlfir.declare %[[ADDR_DIM3_GRID]] {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQro._QM__fortran_builtinsT__builtin_dim3.0"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) ! CHECK: %[[GRID_X:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"x"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32> ! CHECK: %[[GRID_X_LOAD:.*]] = fir.load %[[GRID_X]] : !fir.ref<i32> ! CHECK: %[[GRID_Y:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"y"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32> ! CHECK: %[[GRID_Y_LOAD:.*]] = fir.load %[[GRID_Y]] : !fir.ref<i32> +! CHECK: %[[GRID_Z:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"z"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32> +! CHECK: %[[GRID_Z_LOAD:.*]] = fir.load %[[GRID_Z]] : !fir.ref<i32> ! CHECK: %[[ADDR_DIM3_BLOCK:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>> ! CHECK: %[[DIM3_BLOCK:.*]]:2 = hlfir.declare %[[ADDR_DIM3_BLOCK]] {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQro._QM__fortran_builtinsT__builtin_dim3.1"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) ! CHECK: %[[BLOCK_X:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"x"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32> @@ -35,7 +37,7 @@ contains ! CHECK: %[[BLOCK_Y_LOAD:.*]] = fir.load %[[BLOCK_Y]] : !fir.ref<i32> ! CHECK: %[[BLOCK_Z:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"z"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32> ! CHECK: %[[BLOCK_Z_LOAD:.*]] = fir.load %[[BLOCK_Z]] : !fir.ref<i32> -! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%[[GRID_X_LOAD]], %[[GRID_Y_LOAD]], %c1{{.*}}, %[[BLOCK_X_LOAD]], %[[BLOCK_Y_LOAD]], %[[BLOCK_Z_LOAD]]>>>() +! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%[[GRID_X_LOAD]], %[[GRID_Y_LOAD]], %[[GRID_Z_LOAD]], %[[BLOCK_X_LOAD]], %[[BLOCK_Y_LOAD]], %[[BLOCK_Z_LOAD]]>>>() call dev_kernel0<<<10, 20, 2>>>() ! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>() |