aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/NVGPUToNVVM
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/NVGPUToNVVM')
-rw-r--r--mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp55
1 files changed, 12 insertions, 43 deletions
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index a9efada..64a7f56 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -846,13 +846,8 @@ struct NVGPUMBarrierInitLowering
Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value count = truncToI32(b, adaptor.getCount());
- if (isMbarrierShared(mbarrierType)) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
- op, barrier, count, adaptor.getPredicate());
- } else {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
- adaptor.getPredicate());
- }
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
+ adaptor.getPredicate());
return success();
}
};
@@ -870,13 +865,7 @@ struct NVGPUMBarrierArriveLowering
adaptor.getMbarId(), rewriter);
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
- barrier);
- } else {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
- barrier);
- }
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
return success();
}
};
@@ -897,13 +886,8 @@ struct NVGPUMBarrierArriveNoCompleteLowering
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
Value count = truncToI32(b, adaptor.getCount());
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
- op, tokenType, barrier, count);
- } else {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
- op, tokenType, barrier, count);
- }
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
+ op, tokenType, barrier, count);
return success();
}
};
@@ -920,13 +904,8 @@ struct NVGPUMBarrierTestWaitLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Type retType = rewriter.getI1Type();
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
- op, retType, barrier, adaptor.getToken());
- } else {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
- op, retType, barrier, adaptor.getToken());
- }
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier,
+ adaptor.getToken());
return success();
}
};
@@ -943,15 +922,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(b, adaptor.getTxcount());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
- op, barrier, txcount, adaptor.getPredicate());
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
- op, barrier, txcount, adaptor.getPredicate());
+ op, Type{}, // return-value is optional and is void by default
+ barrier, txcount, // barrier and txcount
+ NVVM::MemScopeKind::CTA, // default scope is CTA
+ false, // relaxed-semantics is false
+ adaptor.getPredicate());
return success();
}
};
@@ -970,13 +946,6 @@ struct NVGPUMBarrierTryWaitParityLowering
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
- op, barrier, phase, ticks);
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
phase, ticks);
return success();