aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp')
-rw-r--r--mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp64
1 files changed, 58 insertions, 6 deletions
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 7d0a236..76a822b 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/SCFToGPU/SCFToGPU.h"
+#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -27,6 +28,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/DebugLog.h"
#include <optional>
@@ -625,18 +627,49 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
bool seenSideeffects = false;
// Whether we have left a nesting scope (and hence are no longer innermost).
bool leftNestingScope = false;
+ LocalAliasAnalysis aliasAnalysis;
+ llvm::DenseSet<Value> writtenBuffer;
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
// Now walk over the body and clone it.
// TODO: This is only correct if there either is no further scf.parallel
- // nested or this code is side-effect free. Otherwise we might need
- // predication. We are overly conservative for now and only allow
- // side-effects in the innermost scope.
+ // nested or this code has side-effect but the memory buffer is not
+ // alias to inner loop access buffer. Otherwise we might need
+ // predication.
if (auto nestedParallel = dyn_cast<ParallelOp>(op)) {
// Before entering a nested scope, make sure there have been no
- // sideeffects until now.
- if (seenSideeffects)
- return failure();
+ // sideeffects until now or the nested operations do not access the
+ // buffer written by outer scope.
+ if (seenSideeffects) {
+ WalkResult walkRes = nestedParallel.walk([&](Operation *nestedOp) {
+ if (isMemoryEffectFree(nestedOp))
+ return WalkResult::advance();
+
+ auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp);
+ if (!memEffectInterface)
+ return WalkResult::advance();
+
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ memEffectInterface.getEffects(effects);
+ for (const MemoryEffects::EffectInstance &effect : effects) {
+ if (isa<MemoryEffects::Read>(effect.getEffect()) ||
+ isa<MemoryEffects::Write>(effect.getEffect())) {
+ Value baseBuffer = effect.getValue();
+ if (!baseBuffer)
+ return WalkResult::interrupt();
+ for (Value val : writtenBuffer) {
+ if (aliasAnalysis.alias(baseBuffer, val) !=
+ AliasResult::NoAlias) {
+ return WalkResult::interrupt();
+ }
+ }
+ }
+ }
+ return WalkResult::advance();
+ });
+ if (walkRes.wasInterrupted())
+ return failure();
+ }
// A nested scf.parallel needs insertion of code to compute indices.
// Insert that now. This will also update the worklist with the loops
// body.
@@ -650,6 +683,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
rewriter.setInsertionPointAfter(parent);
leftNestingScope = true;
seenSideeffects = false;
+ writtenBuffer.clear();
} else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) {
// Convert scf.reduction op
auto parentLoop = op->getParentOfType<ParallelOp>();
@@ -682,6 +716,24 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
Operation *clone = rewriter.clone(*op, cloningMap);
cloningMap.map(op->getResults(), clone->getResults());
// Check for side effects.
+ if (!isMemoryEffectFree(clone)) {
+ // Record the buffer accessed by the operations with write effects.
+ if (auto memEffectInterface =
+ dyn_cast<MemoryEffectOpInterface>(clone)) {
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ memEffectInterface.getEffects(effects);
+ for (const MemoryEffects::EffectInstance &effect : effects) {
+ if (isa<MemoryEffects::Write>(effect.getEffect())) {
+ Value writtenBase = effect.getValue();
+ // Conservatively return failure if we cannot find the written
+ // address.
+ if (!writtenBase)
+ return failure();
+ writtenBuffer.insert(writtenBase);
+ }
+ }
+ }
+ }
// TODO: Handle region side effects properly.
seenSideeffects |=
!isMemoryEffectFree(clone) || clone->getNumRegions() != 0;