Diffstat (limited to 'mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp')
-rw-r--r--  mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp  64
1 file changed, 58 insertions, 6 deletions
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 7d0a236..76a822b 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/SCFToGPU/SCFToGPU.h"
+#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -27,6 +28,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/DebugLog.h"
 #include <optional>
@@ -625,18 +627,49 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
   bool seenSideeffects = false;
   // Whether we have left a nesting scope (and hence are no longer innermost).
   bool leftNestingScope = false;
+  LocalAliasAnalysis aliasAnalysis;
+  llvm::DenseSet<Value> writtenBuffer;
   while (!worklist.empty()) {
     Operation *op = worklist.pop_back_val();
     // Now walk over the body and clone it.
     // TODO: This is only correct if there either is no further scf.parallel
-    //       nested or this code is side-effect free. Otherwise we might need
-    //       predication. We are overly conservative for now and only allow
-    //       side-effects in the innermost scope.
+    //       nested, or this code has side effects but the buffers it accesses
+    //       do not alias the buffers accessed by the inner loop. Otherwise we
+    //       might need predication.
     if (auto nestedParallel = dyn_cast<ParallelOp>(op)) {
       // Before entering a nested scope, make sure there have been no
-      // sideeffects until now.
-      if (seenSideeffects)
-        return failure();
+      // side effects so far, or that the nested operations do not access any
+      // buffer written in the outer scope.
+      if (seenSideeffects) {
+        WalkResult walkRes = nestedParallel.walk([&](Operation *nestedOp) {
+          if (isMemoryEffectFree(nestedOp))
+            return WalkResult::advance();
+
+          auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp);
+          if (!memEffectInterface)
+            return WalkResult::advance();
+
+          SmallVector<MemoryEffects::EffectInstance> effects;
+          memEffectInterface.getEffects(effects);
+          for (const MemoryEffects::EffectInstance &effect : effects) {
+            if (isa<MemoryEffects::Read>(effect.getEffect()) ||
+                isa<MemoryEffects::Write>(effect.getEffect())) {
+              Value baseBuffer = effect.getValue();
+              if (!baseBuffer)
+                return WalkResult::interrupt();
+              for (Value val : writtenBuffer) {
+                if (aliasAnalysis.alias(baseBuffer, val) !=
+                    AliasResult::NoAlias) {
+                  return WalkResult::interrupt();
+                }
+              }
+            }
+          }
+          return WalkResult::advance();
+        });
+        if (walkRes.wasInterrupted())
+          return failure();
+      }
       // A nested scf.parallel needs insertion of code to compute indices.
       // Insert that now. This will also update the worklist with the loops
       // body.
@@ -650,6 +683,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
       rewriter.setInsertionPointAfter(parent);
       leftNestingScope = true;
       seenSideeffects = false;
+      writtenBuffer.clear();
     } else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) {
       // Convert scf.reduction op
       auto parentLoop = op->getParentOfType<ParallelOp>();
@@ -682,6 +716,24 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
       Operation *clone = rewriter.clone(*op, cloningMap);
       cloningMap.map(op->getResults(), clone->getResults());
       // Check for side effects.
+      if (!isMemoryEffectFree(clone)) {
+        // Record the buffer accessed by the operations with write effects.
+        if (auto memEffectInterface =
+                dyn_cast<MemoryEffectOpInterface>(clone)) {
+          SmallVector<MemoryEffects::EffectInstance> effects;
+          memEffectInterface.getEffects(effects);
+          for (const MemoryEffects::EffectInstance &effect : effects) {
+            if (isa<MemoryEffects::Write>(effect.getEffect())) {
+              Value writtenBase = effect.getValue();
+              // Conservatively return failure if we cannot find the written
+              // address.
+              if (!writtenBase)
+                return failure();
+              writtenBuffer.insert(writtenBase);
+            }
+          }
+        }
+      }
       // TODO: Handle region side effects properly.
       seenSideeffects |=
           !isMemoryEffectFree(clone) || clone->getNumRegions() != 0;
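
Taken together, the new code relaxes the old rule that any side effect before a nested scf.parallel aborts the lowering: every cloned operation with a write effect now records the buffer it writes in writtenBuffer, and when a nested scf.parallel is reached after side effects have been seen, its body is walked and each read or write is checked against that set with LocalAliasAnalysis; only a possible alias, or an effect whose value cannot be determined, still makes the pattern return failure(). The sketch below restates that logic as two standalone helpers. It is illustrative only: the helper names collectWrittenBuffers and mayConflictWithWrittenBuffers are hypothetical, not part of this commit, and the snippet assumes it sits in the same translation unit as SCFToGPU.cpp (or has equivalent MLIR headers available).

// Illustrative sketch only: hypothetical helpers restating the patch's logic,
// not part of the commit.
#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/DenseSet.h"

using namespace mlir;

// Record the base buffers written by `op` into `writtenBuffers`. Returns
// false if a write effect has no attributable value, in which case the caller
// should give up, mirroring the conservative bail-out in the patch.
static bool collectWrittenBuffers(Operation *op,
                                  llvm::DenseSet<Value> &writtenBuffers) {
  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!iface)
    return true;
  SmallVector<MemoryEffects::EffectInstance> effects;
  iface.getEffects(effects);
  for (const MemoryEffects::EffectInstance &effect : effects) {
    if (!isa<MemoryEffects::Write>(effect.getEffect()))
      continue;
    Value written = effect.getValue();
    if (!written)
      return false;
    writtenBuffers.insert(written);
  }
  return true;
}

// Returns true if any operation nested under `op` may read or write a buffer
// that aliases one of `writtenBuffers`. Effects that cannot be attributed to
// a value are conservatively treated as conflicting.
static bool mayConflictWithWrittenBuffers(
    Operation *op, const llvm::DenseSet<Value> &writtenBuffers,
    LocalAliasAnalysis &aliasAnalysis) {
  WalkResult result = op->walk([&](Operation *nestedOp) {
    if (isMemoryEffectFree(nestedOp))
      return WalkResult::advance();
    auto iface = dyn_cast<MemoryEffectOpInterface>(nestedOp);
    if (!iface)
      return WalkResult::advance();
    SmallVector<MemoryEffects::EffectInstance> effects;
    iface.getEffects(effects);
    for (const MemoryEffects::EffectInstance &effect : effects) {
      if (!isa<MemoryEffects::Read>(effect.getEffect()) &&
          !isa<MemoryEffects::Write>(effect.getEffect()))
        continue;
      Value buffer = effect.getValue();
      if (!buffer)
        return WalkResult::interrupt();
      for (Value written : writtenBuffers)
        if (aliasAnalysis.alias(buffer, written) != AliasResult::NoAlias)
          return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  return result.wasInterrupted();
}

A conflict reported by the walk keeps the pre-patch behavior: the rewrite still bails out with failure(), exactly as it previously did whenever seenSideeffects was set at a nested scf.parallel; only the provably non-aliasing cases are newly accepted.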
