diff options
Diffstat (limited to 'mlir/lib/Conversion/SCFToGPU')
| -rw-r--r-- | mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp | 64 |
1 files changed, 58 insertions, 6 deletions
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index 7d0a236..76a822b 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/SCFToGPU/SCFToGPU.h" +#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -27,6 +28,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/Support/DebugLog.h" #include <optional> @@ -625,18 +627,49 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, bool seenSideeffects = false; // Whether we have left a nesting scope (and hence are no longer innermost). bool leftNestingScope = false; + LocalAliasAnalysis aliasAnalysis; + llvm::DenseSet<Value> writtenBuffer; while (!worklist.empty()) { Operation *op = worklist.pop_back_val(); // Now walk over the body and clone it. // TODO: This is only correct if there either is no further scf.parallel - // nested or this code is side-effect free. Otherwise we might need - // predication. We are overly conservative for now and only allow - // side-effects in the innermost scope. + // nested or this code has side-effect but the memory buffer is not + // alias to inner loop access buffer. Otherwise we might need + // predication. if (auto nestedParallel = dyn_cast<ParallelOp>(op)) { // Before entering a nested scope, make sure there have been no - // sideeffects until now. - if (seenSideeffects) - return failure(); + // sideeffects until now or the nested operations do not access the + // buffer written by outer scope. + if (seenSideeffects) { + WalkResult walkRes = nestedParallel.walk([&](Operation *nestedOp) { + if (isMemoryEffectFree(nestedOp)) + return WalkResult::advance(); + + auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp); + if (!memEffectInterface) + return WalkResult::advance(); + + SmallVector<MemoryEffects::EffectInstance> effects; + memEffectInterface.getEffects(effects); + for (const MemoryEffects::EffectInstance &effect : effects) { + if (isa<MemoryEffects::Read>(effect.getEffect()) || + isa<MemoryEffects::Write>(effect.getEffect())) { + Value baseBuffer = effect.getValue(); + if (!baseBuffer) + return WalkResult::interrupt(); + for (Value val : writtenBuffer) { + if (aliasAnalysis.alias(baseBuffer, val) != + AliasResult::NoAlias) { + return WalkResult::interrupt(); + } + } + } + } + return WalkResult::advance(); + }); + if (walkRes.wasInterrupted()) + return failure(); + } // A nested scf.parallel needs insertion of code to compute indices. // Insert that now. This will also update the worklist with the loops // body. @@ -650,6 +683,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, rewriter.setInsertionPointAfter(parent); leftNestingScope = true; seenSideeffects = false; + writtenBuffer.clear(); } else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) { // Convert scf.reduction op auto parentLoop = op->getParentOfType<ParallelOp>(); @@ -682,6 +716,24 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, Operation *clone = rewriter.clone(*op, cloningMap); cloningMap.map(op->getResults(), clone->getResults()); // Check for side effects. + if (!isMemoryEffectFree(clone)) { + // Record the buffer accessed by the operations with write effects. + if (auto memEffectInterface = + dyn_cast<MemoryEffectOpInterface>(clone)) { + SmallVector<MemoryEffects::EffectInstance> effects; + memEffectInterface.getEffects(effects); + for (const MemoryEffects::EffectInstance &effect : effects) { + if (isa<MemoryEffects::Write>(effect.getEffect())) { + Value writtenBase = effect.getValue(); + // Conservatively return failure if we cannot find the written + // address. + if (!writtenBase) + return failure(); + writtenBuffer.insert(writtenBase); + } + } + } + } // TODO: Handle region side effects properly. seenSideeffects |= !isMemoryEffectFree(clone) || clone->getNumRegions() != 0; |
