diff options
Diffstat (limited to 'mlir/lib/Dialect/SCF')
| -rw-r--r-- | mlir/lib/Dialect/SCF/IR/SCF.cpp | 58 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp | 1 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp | 1 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SCF/Utils/Utils.cpp | 145 |
4 files changed, 160 insertions, 45 deletions
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 744a595..2946b53 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -111,10 +111,8 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, return nullptr; } -/// Helper function to compute the difference between two values. This is used -/// by the loop implementations to compute the trip count. -static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub, - bool isSigned) { +std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub, + bool isSigned) { llvm::APSInt diff; auto addOp = ub.getDefiningOp<arith::AddIOp>(); if (!addOp) @@ -399,7 +397,7 @@ void ExecuteRegionOp::getSuccessorRegions( } // Otherwise, the region branches back to the parent operation. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } //===----------------------------------------------------------------------===// @@ -407,10 +405,11 @@ void ExecuteRegionOp::getSuccessorRegions( //===----------------------------------------------------------------------===// MutableOperandRange -ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { - assert((point.isParent() || point == getParentOp().getAfter()) && - "condition op can only exit the loop or branch to the after" - "region"); +ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) { + assert( + (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) && + "condition op can only exit the loop or branch to the after" + "region"); // Pass all operands except the condition to the successor region. 
return getArgsMutable(); } @@ -428,7 +427,7 @@ void ConditionOp::getSuccessorRegions( regions.emplace_back(&whileOp.getAfter(), whileOp.getAfter().getArguments()); if (!boolAttr || !boolAttr.getValue()) - regions.emplace_back(whileOp.getResults()); + regions.emplace_back(whileOp.getOperation(), whileOp.getResults()); } //===----------------------------------------------------------------------===// @@ -751,7 +750,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) { return dyn_cast_or_null<ForOp>(containingOp); } -OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) { return getInitArgs(); } @@ -761,7 +760,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, // back into the operation itself. It is possible for loop not to enter the // body. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } @@ -2055,9 +2054,10 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point, // parallel by multiple threads. We should not expect to branch back into // the forall body after the region's execution is complete. if (point.isParent()) - regions.push_back(RegionSuccessor(&getRegion())); + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); else - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); } //===----------------------------------------------------------------------===// @@ -2335,9 +2335,10 @@ void IfOp::print(OpAsmPrinter &p) { void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { - // The `then` and the `else` region branch back to the parent operation. 
+ // The `then` and the `else` region branch back to the parent operation or one + // of the recursive parent operations (early exit case). if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } @@ -2346,7 +2347,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty. Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); else regions.push_back(RegionSuccessor(elseRegion)); } @@ -2363,7 +2365,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); } } @@ -3387,7 +3389,8 @@ void ParallelOp::getSuccessorRegions( // back into the operation itself. It is possible for loop not to enter the // body. regions.push_back(RegionSuccessor(&getRegion())); - regions.push_back(RegionSuccessor()); + regions.push_back(RegionSuccessor( + getOperation(), ResultRange{getResults().end(), getResults().end()})); } //===----------------------------------------------------------------------===// @@ -3433,7 +3436,7 @@ LogicalResult ReduceOp::verifyRegions() { } MutableOperandRange -ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) { +ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) { // No operands are forwarded to the next iteration. 
return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0); } @@ -3516,8 +3519,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() { return getBeforeArguments(); } -OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBefore() && +OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBefore() && "WhileOp is expected to branch only to the first region"); return getInits(); } @@ -3530,15 +3533,18 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point, return; } - assert(llvm::is_contained({&getAfter(), &getBefore()}, point) && + assert(llvm::is_contained( + {&getAfter(), &getBefore()}, + point.getTerminatorPredecessorOrNull()->getParentRegion()) && "there are only two regions in a WhileOp"); // The body region always branches back to the condition region. - if (point == getAfter()) { + if (point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getAfter()) { regions.emplace_back(&getBefore(), getBefore().getArguments()); return; } - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); regions.emplace_back(&getAfter(), getAfter().getArguments()); } @@ -4447,7 +4453,7 @@ void IndexSwitchOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) { // All regions branch back to the parent op. 
if (!point.isParent()) { - successors.emplace_back(getResults()); + successors.emplace_back(getOperation(), getResults()); return; } diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index ae52af5..ddcbda8 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -23,7 +23,6 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir -using namespace llvm; using namespace mlir; using scf::ForOp; using scf::WhileOp; diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp index a2f03f1..00bef70 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp @@ -21,7 +21,6 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir -using namespace llvm; using namespace mlir; using scf::LoopNest; diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 10eae89..888dd44 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -291,47 +291,61 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, return arith::DivUIOp::create(builder, loc, sum, divisor); } -/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with -/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap -/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each -/// unrolled iteration using annotateFn. 
-static void generateUnrolledLoop( - Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, +void mlir::generateUnrolledLoop( + Block *loopBodyBlock, Value iv, uint64_t unrollFactor, function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn, - ValueRange iterArgs, ValueRange yieldedValues) { + ValueRange iterArgs, ValueRange yieldedValues, + IRMapping *clonedToSrcOpsMap) { + + // Check if the op was cloned from another source op, and return it if found + // (or the same op if not found) + auto findOriginalSrcOp = + [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * { + Operation *srcOp = op; + // If the source op derives from another op: traverse the chain to find the + // original source op + while (srcOp && clonedToSrcOpsMap.contains(srcOp)) + srcOp = clonedToSrcOpsMap.lookup(srcOp); + return srcOp; + }; + // Builder to insert unrolled bodies just before the terminator of the body of - // 'forOp'. + // the loop. auto builder = OpBuilder::atBlockTerminator(loopBodyBlock); - constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {}; + static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {}; if (!annotateFn) - annotateFn = defaultAnnotateFn; + annotateFn = noopAnnotateFn; // Keep a pointer to the last non-terminator operation in the original block // so that we know what to clone (since we are doing this in-place). Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2); - // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies). + // Unroll the contents of the loop body (append unrollFactor - 1 additional + // copies). SmallVector<Value, 4> lastYielded(yieldedValues); for (unsigned i = 1; i < unrollFactor; i++) { - IRMapping operandMap; - // Prepare operand map. 
+ IRMapping operandMap; operandMap.map(iterArgs, lastYielded); // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forOpIV.use_empty()) { - Value ivUnroll = ivRemapFn(i, forOpIV, builder); - operandMap.map(forOpIV, ivUnroll); + if (!iv.use_empty()) { + Value ivUnroll = ivRemapFn(i, iv, builder); + operandMap.map(iv, ivUnroll); } // Clone the original body of 'forOp'. for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) { - Operation *clonedOp = builder.clone(*it, operandMap); + Operation *srcOp = &(*it); + Operation *clonedOp = builder.clone(*srcOp, operandMap); annotateFn(i, clonedOp, builder); + if (clonedToSrcOpsMap) + clonedToSrcOpsMap->map(clonedOp, + findOriginalSrcOp(srcOp, *clonedToSrcOpsMap)); } // Update yielded values. @@ -1544,3 +1558,100 @@ bool mlir::isPerfectlyNestedForLoops( } return true; } + +llvm::SmallVector<int64_t> +mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) { + std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds(); + std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds(); + std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps(); + if (!loBnds || !upBnds || !steps) + return {}; + llvm::SmallVector<int64_t> tripCounts; + for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) { + std::optional<llvm::APInt> numIter = constantTripCount( + lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb); + if (!numIter) + return {}; + tripCounts.push_back(numIter->getSExtValue()); + } + return tripCounts; +} + +FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors( + scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors, + RewriterBase &rewriter, + function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn, + IRMapping *clonedToSrcOpsMap) { + const unsigned numLoops = op.getNumLoops(); + assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) && + "Expected 
positive unroll factors"); + assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) && + "Expected non-empty unroll factors of size <= to the number of loops"); + + // Bail out if no valid unroll factors were provided + if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; })) + return rewriter.notifyMatchFailure( + op, "Unrolling not applied if all factors are 1"); + + // Return if the loop body is empty. + if (llvm::hasSingleElement(op.getBody()->getOperations())) + return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body"); + + // If the provided unroll factors do not cover all the loop dims, they are + // applied to the inner loop dimensions. + const unsigned firstLoopDimIdx = numLoops - unrollFactors.size(); + + // Make sure that the unroll factors divide the iteration space evenly + // TODO: Support unrolling loops with dynamic iteration spaces. + const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op); + if (tripCounts.empty()) + return rewriter.notifyMatchFailure( + op, "Failed to compute constant trip counts for the loop. 
Note that " + "dynamic loop sizes are not supported."); + + for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) { + const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx]; + if (tripCounts[dimIdx] % unrollFactor) + return rewriter.notifyMatchFailure( + op, "Unroll factors don't divide the iteration space evenly"); + } + + std::optional<SmallVector<OpFoldResult>> maybeFoldSteps = op.getLoopSteps(); + if (!maybeFoldSteps) + return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps"); + llvm::SmallVector<size_t> steps{}; + for (auto step : *maybeFoldSteps) + steps.push_back(static_cast<size_t>(*getConstantIntValue(step))); + + for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) { + const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx]; + if (unrollFactor == 1) + continue; + const size_t origStep = steps[dimIdx]; + const int64_t newStep = origStep * unrollFactor; + IRMapping clonedToSrcOpsMap; + + ValueRange iterArgs = ValueRange(op.getRegionIterArgs()); + auto yieldedValues = op.getBody()->getTerminator()->getOperands(); + + generateUnrolledLoop( + op.getBody(), op.getInductionVars()[dimIdx], unrollFactor, + [&](unsigned i, Value iv, OpBuilder b) { + // iv' = iv + step * i; + const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i); + const auto map = + b.getDimIdentityMap().dropResult(0).insertResult(expr, 0); + return affine::AffineApplyOp::create(b, iv.getLoc(), map, + ValueRange{iv}); + }, + /*annotateFn*/ annotateFn, iterArgs, yieldedValues, &clonedToSrcOpsMap); + + // Update loop step + auto prevInsertPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + op.getStepMutable()[dimIdx].assign( + arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep)); + rewriter.restoreInsertionPoint(prevInsertPoint); + } + return op; +} |
