diff options
Diffstat (limited to 'mlir/lib/Conversion/NVGPUToNVVM')
| -rw-r--r-- | mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 55 |
1 files changed, 12 insertions, 43 deletions
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index a9efada..64a7f56 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -846,13 +846,8 @@ struct NVGPUMBarrierInitLowering Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(), adaptor.getMbarId(), rewriter); Value count = truncToI32(b, adaptor.getCount()); - if (isMbarrierShared(mbarrierType)) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>( - op, barrier, count, adaptor.getPredicate()); - } else { - rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count, - adaptor.getPredicate()); - } + rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count, + adaptor.getPredicate()); return success(); } }; @@ -870,13 +865,7 @@ struct NVGPUMBarrierArriveLowering adaptor.getMbarId(), rewriter); Type tokenType = getTypeConverter()->convertType( nvgpu::MBarrierTokenType::get(op->getContext())); - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType, - barrier); - } else { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, - barrier); - } + rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier); return success(); } }; @@ -897,13 +886,8 @@ struct NVGPUMBarrierArriveNoCompleteLowering Type tokenType = getTypeConverter()->convertType( nvgpu::MBarrierTokenType::get(op->getContext())); Value count = truncToI32(b, adaptor.getCount()); - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>( - op, tokenType, barrier, count); - } else { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>( - op, tokenType, barrier, count); - } + rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>( + op, tokenType, barrier, count); return success(); } }; @@ -920,13 +904,8 @@ struct NVGPUMBarrierTestWaitLowering getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), adaptor.getMbarId(), rewriter); Type retType = rewriter.getI1Type(); - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>( - op, retType, barrier, adaptor.getToken()); - } else { - rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>( - op, retType, barrier, adaptor.getToken()); - } + rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier, + adaptor.getToken()); return success(); } }; @@ -943,15 +922,12 @@ struct NVGPUMBarrierArriveExpectTxLowering getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), adaptor.getMbarId(), rewriter); Value txcount = truncToI32(b, adaptor.getTxcount()); - - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>( - op, barrier, txcount, adaptor.getPredicate()); - return success(); - } - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>( - op, barrier, txcount, adaptor.getPredicate()); + op, Type{}, // return-value is optional and is void by default + barrier, txcount, // barrier and txcount + NVVM::MemScopeKind::CTA, // default scope is CTA + false, // relaxed-semantics is false + adaptor.getPredicate()); return success(); } }; @@ -970,13 +946,6 @@ struct NVGPUMBarrierTryWaitParityLowering Value ticks = truncToI32(b, adaptor.getTicks()); Value phase = LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity()); - - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>( - op, barrier, phase, ticks); - return success(); - } - rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier, phase, ticks); return success(); |
