From 98ce2debc6ff3f6d31d7b63eb54e10e88a84ee78 Mon Sep 17 00:00:00 2001
From: Aart Bik <39774503+aartbik@users.noreply.github.com>
Date: Wed, 6 Dec 2023 13:23:50 -0800
Subject: [mlir][sparse] cleanup ldx/idx/depth/at usage (#74654)

This adds consistent usage of `at` for everything that refers to the
current loop nesting. It also cleans up some redundant legacy code from
when we were still using topSort inside the sparsifier code.
---
 .../SparseTensor/Transforms/Sparsification.cpp     | 139 ++++++++++-----------
 1 file changed, 69 insertions(+), 70 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d03e961..6637a26 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -44,23 +44,23 @@ using namespace mlir::sparse_tensor;
 // Sparsifier analysis methods.
 //===----------------------------------------------------------------------===//
 
-/// Determines if affine expression is invariant.
-static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
-                              bool &isAtLoop) {
+/// Returns true iff affine expression is invariant. Sets the
+/// parameter `isAtLoop` when the expression just became invariant.
+static bool isInvariantAffine(AffineExpr a, LoopId at, bool &isAtLoop) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     const LoopId i = cast<AffineDimExpr>(a).getPosition();
-    if (i == ldx) {
+    if (i + 1 == at) {
       isAtLoop = true;
-      return true; // invariant at given loop
+      return true; // becomes invariant at current loop
     }
-    return i < loopDepth; // invariant when already generated
+    return i < at; // invariant when already generated
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
     auto binOp = cast<AffineBinaryOpExpr>(a);
-    return isInvariantAffine(binOp.getLHS(), loopDepth, ldx, isAtLoop) &&
-           isInvariantAffine(binOp.getRHS(), loopDepth, ldx, isAtLoop);
+    return isInvariantAffine(binOp.getLHS(), at, isAtLoop) &&
+           isInvariantAffine(binOp.getRHS(), at, isAtLoop);
   }
   default: {
     assert(isa<AffineConstantExpr>(a));
@@ -126,8 +126,8 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
       if (coefficient <= 0)
         return false;
 
-      const LoopId ldx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
-      if (!isUndefLT(merger.getLvlType(tensor, ldx)))
+      const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
+      if (!isUndefLT(merger.getLvlType(tensor, idx)))
         return false; // used more than once, e.g., A[i][i]
 
       // TODO: Generalize the following two cases. A[i] (with trivial index
@@ -135,14 +135,14 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
      // not necessarily need to differentiate them.
      if (!isSubExp) {
        assert(coefficient == 1);
-        merger.setLevelAndType(tensor, ldx, lvl, lt);
+        merger.setLevelAndType(tensor, idx, lvl, lt);
      }
 
      if (isSubExp) {
        // The current loop appears in more than one affine expression on the
        // same tensor. We cannot handle this case, e.g., A[i+j][i+k], `i` is
        // used twice.
-        if (merger.hasDependentLvl(ldx, tensor)) {
+        if (merger.hasDependentLvl(idx, tensor)) {
          // TODO: This can be supported by coiterate slices if the loop idx
          // appears on the affine index of a different tensor, or by taking
          // a slice on multiple dimensions when it is on the same tensor.
@@ -154,7 +154,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
          // else increase min(d0_1, d0_2).
          return false;
        }
-        merger.setLoopDependentTensorLevel(ldx, tensor, lvl, lt, coefficient);
+        merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
      }
      return true;
    }
@@ -613,9 +613,9 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
   if (kind == TensorExp::Kind::kReduce)
     env.startCustomReduc(e); // enter custom
 
-  Value v0, v1;
   // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
   // based on the type of the other operand.
+  Value v0, v1;
   if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
       env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
     v1 = genExp(env, rewriter, exp.children.e1);
@@ -655,21 +655,21 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
 
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                          LoopId ldx, bool atStart) {
+                          LoopId at, bool atStart) {
   if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
     return;
   if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
     // Inspect tensor indices.
-    bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId;
     linalg::GenericOp op = env.op();
     OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
     const auto map = op.getMatchingIndexingMap(&t);
     const auto stt = getSparseTensorType(t.get());
     const Level lvlRank = stt.getLvlRank();
     assert(static_cast<Level>(map.getNumResults()) == lvlRank);
+    bool isAtLoop = at == 0; // for scalar tensors
     for (Level l = 0; l < lvlRank; l++) {
       const AffineExpr a = map.getResult(l);
-      if (!isInvariantAffine(a, env.getLoopDepth(), ldx, isAtLoop))
+      if (!isInvariantAffine(a, at, /*out*/ isAtLoop))
         return; // still in play
     }
     // All exhausted at this level (isAtLoop denotes exactly at this LoopId).
@@ -705,8 +705,8 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
       env.startCustomReduc(exp); // enter custom
     const ExprId e0 = env.exp(exp).children.e0;
     const ExprId e1 = env.exp(exp).children.e1;
-    genInvariants(env, builder, e0, ldx, atStart);
-    genInvariants(env, builder, e1, ldx, atStart);
+    genInvariants(env, builder, e0, at, atStart);
+    genInvariants(env, builder, e1, at, atStart);
     if (env.exp(exp).kind == TensorExp::Kind::kReduce)
       env.endCustomReduc(); // exit custom
   }
@@ -782,29 +782,28 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
 
 /// Whether or not the current loop being generated should be parallelized
 /// (if possible) according to the configuration.
-static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
+static bool shouldTryParallize(CodegenEnv &env, LoopId at,
                                ArrayRef<TensorLevel> tidLvls) {
   linalg::GenericOp op = env.op();
   auto iteratorTypes = op.getIteratorTypesArray();
-  bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) {
-    // Queries the LT based on the tensor id and loop idx, as requested by
-    // `CodegenEnv::lt(TensorId, LoopIdx)`. The returned LT from CodegenEnv
+  bool isSparse = llvm::any_of(tidLvls, [at, &env](TensorLevel tidLvl) {
+    // Queries the LT based on the tensor and loop id, as requested by
+    // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
     // should be consistent with the LT indexed by <TensorId, Level>.
-    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
+    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, at);
     return isCompressedLT(lt) || isSingletonLT(lt);
   });
-  return isParallelFor(env, isOuter, isSparse);
+  return isParallelFor(env, /*isOuter=*/at == 0, isSparse);
 }
 
 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
 /// can either be a for loop or a while loop depending on whether there is at
 /// most one sparse level in the list.
 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
-                                 LoopId idx, ArrayRef<TensorLevel> tidLvls,
+                                 ArrayRef<TensorLevel> tidLvls,
                                  bool tryParallel, bool needsUniv) {
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
-    // Construct the while-loop with a parameter for each
-    // index.
+    // Construct while-loop with a parameter for each index.
     return env.emitter().enterCoIterationOverTensorsAtLvls(
         builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
         /*genDedup=*/true, needsUniv);
@@ -817,12 +816,12 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId at,
                           bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
-  bool tryParallel = shouldTryParallize(env, at, at == 0, tidLvls);
-  return genCoIteration(env, builder, at, tidLvls, tryParallel, needsUniv);
+  bool tryParallel = shouldTryParallize(env, at, tidLvls);
+  return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
 }
 
 /// Generates the induction structure for a while-loop.
-static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
+static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                             bool needsUniv) {
   Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
@@ -862,7 +861,7 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
 }
 
 /// Generates a single if-statement within a while-loop.
-static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
+static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId at,
                        LatPointId p) {
   Location loc = env.op().getLoc();
   SmallVector<Type> types;
@@ -880,13 +879,13 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
          auto stt = getSparseTensorType(env.op().getInputs()[tid]);
          lt = stt.getLvlType(*lvl);
        }
-        assert(ldx == env.merger().loop(b));
+        assert(at == env.merger().loop(b));
        Value clause;
        if (isCompressedLT(lt) || isSingletonLT(lt) ||
            isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
          assert(lvl.has_value());
          const Value crd = env.emitter().getCoords()[tid][*lvl];
-          const Value lvar = env.getLoopVar(ldx);
+          const Value lvar = env.getLoopVar(at);
          clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 crd, lvar);
        } else {
@@ -943,12 +942,12 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 /// Starts a loop sequence at given level. Returns true if
 /// the universal loop index must be maintained at this level.
 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                         LoopId idx, LoopId ldx, LatSetId lts) {
-  assert(!env.getLoopVar(idx));
+                         LoopId at, LatSetId lts) {
+  assert(!env.getLoopVar(at));
   // Emit invariants at this loop sequence level.
-  genInvariants(env, builder, exp, ldx, /*atStart=*/true);
+  genInvariants(env, builder, exp, at, /*atStart=*/true);
   // Emit access pattern expansion for sparse tensor output.
-  genExpand(env, builder, idx, /*atStart=*/true);
+  genExpand(env, builder, at, /*atStart=*/true);
   // Emit further initialization at this loop sequence level.
   const LatPointId l0 = env.set(lts)[0];
   bool needsUniv = false;
@@ -957,7 +956,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
                                            std::optional<Level> lvl,
                                            LevelType lt, bool isIdxReduc) {
-    assert(env.merger().loop(b) == idx);
+    assert(env.merger().loop(b) == at);
     if (isDenseLT(lt) || isUndefLT(lt)) {
       if (tid == env.merger().getSynTensorID()) {
         // Needs loop emitter to set up loop bounds for synthetic tensor too if
@@ -988,6 +987,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   return false;
 }
 
+// Generates dense affine address for encoding.
 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                              OpBuilder &builder, TensorId tid,
                                              Level startLvl) {
@@ -1013,30 +1013,30 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
   }
 }
 
+// We can generate addresses for constant affine expressions before any loops,
+// starting from the first level, as they do not depend on anything.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
 static void genInitConstantDenseAddress(CodegenEnv &env,
                                         RewriterBase &rewriter) {
-  // We can generate addresses for constant affine expressions before any loops,
-  // starting from the first level, as they do not depend on anything.
-  // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-  // levels can be determined before loops.
   for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
     genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
 }
 
 /// Return true if the lattice bit can be iterated by a for loop.
 static bool translateBitsToTidLvlPairs(
-    CodegenEnv &env, LatPointId li, LoopId ldx,
+    CodegenEnv &env, LatPointId li, LoopId at,
     SmallVectorImpl<TensorLevel> &tidLvls,
     SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
-  const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
+  const std::optional<Level> outLvl = env.merger().getLvl(outTid, at);
 
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
-  env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
-                                                std::optional<Level> lvl,
-                                                LevelType lt, bool isIdxReduc) {
+  env.merger().foreachTensorLoopId(li, [&, at](TensorLoopId b, TensorId tid,
+                                               std::optional<Level> lvl,
+                                               LevelType lt, bool isIdxReduc) {
     if (simple[b]) {
      if (isIdxReduc) {
        tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
@@ -1089,11 +1089,11 @@ static bool translateBitsToTidLvlPairs(
      if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
        continue;
 
-      // Constant affine expressions are handled in genLoop
+      // Constant affine expressions are handled in genLoop.
      if (!isa<AffineConstantExpr>(exp)) {
        bool isAtLoop = false;
-        if (isInvariantAffine(exp, env.getLoopDepth(), ldx, isAtLoop) &&
-            isAtLoop) {
+        assert(at == env.getLoopDepth());
+        if (isInvariantAffine(exp, at + 1, /*out*/ isAtLoop) && isAtLoop) {
          // If the compound affine is invariant and we are right at the
          // level, we need to generate the address according to the
          // affine expression. This is also the best place we can do it
@@ -1105,7 +1105,7 @@ static bool translateBitsToTidLvlPairs(
          // to avoid putting it inside inner loops.
          affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
        }
      }
   });
 
-  if (isDenseLT(env.lt(outTid, ldx))) {
+  if (isDenseLT(env.lt(outTid, at))) {
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
@@ -1131,9 +1131,9 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               LatPointId li, bool needsUniv) {
   // The set of tensors + lvls to generate loops on
   SmallVector<TensorLevel> tidLvls;
+
   // The set of dense tensors with non-trivial affine expression that just
-  // becomes invariant and the address shall now be generated at the current
-  // level.
+  // becomes invariant and the addresses are generated at the current level.
   SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
   bool isSingleCond =
       translateBitsToTidLvlPairs(env, li, at, tidLvls, affineTidLvls);
@@ -1161,38 +1161,34 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
 
 /// Ends a single loop in current sequence. Returns new values for needsUniv.
 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
-                    LoopId idx, LatPointId li, bool needsUniv,
-                    bool isSingleCond) {
-
+                    LatPointId li, bool needsUniv, bool isSingleCond) {
+  // Either a for-loop or a while-loop that iterates over a slice.
   if (isSingleCond) {
-    // Either a for-loop or a while-loop that iterates over a slice.
     // Any iteration creates a valid lex insert.
     if (env.isReduc() && env.getValidLexInsert())
       env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
   } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
     // End a while-loop.
-    finalizeWhileOp(env, rewriter, idx, needsUniv);
+    finalizeWhileOp(env, rewriter, needsUniv);
   } else {
     needsUniv = false;
   }
-
   env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
     return std::nullopt;
   });
-
   return needsUniv;
 }
 
 /// Ends a loop sequence at given level.
 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
-                       unsigned idx, unsigned ldx) {
-  assert(!env.getLoopVar(idx));
+                       unsigned at) {
+  assert(!env.getLoopVar(at));
   env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
 
   // Unmark bookkeeping of invariants and loop index.
-  genInvariants(env, builder, exp, ldx, /*atStart=*/false);
+  genInvariants(env, builder, exp, at, /*atStart=*/false);
 
   // Finalize access pattern expansion for sparse tensor output.
-  genExpand(env, builder, idx, /*atStart=*/false);
+  genExpand(env, builder, at, /*atStart=*/false);
 }
 
 /// Recursively generates code while computing iteration lattices in order
@@ -1200,6 +1196,8 @@ static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
 /// and intersections of sparse iteration spaces.
 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                     LoopId at) {
+  assert(at == env.getLoopDepth());
+
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
   if (at == env.getLoopNum()) {
     Value rhs = genExp(env, rewriter, exp);
@@ -1207,13 +1205,12 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     return;
   }
 
-  // Construct iteration lattices for current loop index, with L0 at top.
-  const LoopId ldx = at == 0 ? sparse_tensor::detail::kInvalidId : at - 1;
+  // Construct iteration lattices for current loop index.
   const LatSetId lts =
       env.merger().optimizeSet(env.merger().buildLattices(exp, at));
 
   // Start a loop sequence.
-  bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
+  bool needsUniv = startLoopSeq(env, rewriter, exp, at, lts);
 
   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
   // We cannot change this to `for (const LatPointId li : env.set(lts))`
@@ -1250,11 +1247,12 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     }
 
     // End a loop.
-    needsUniv = endLoop(env, rewriter, loop, at, li, needsUniv, isSingleCond);
+    needsUniv = endLoop(env, rewriter, loop, at, needsUniv, isSingleCond);
   }
 
   // End a loop sequence.
-  endLoopSeq(env, rewriter, exp, at, ldx);
+  endLoopSeq(env, rewriter, exp, at);
+  assert(at == env.getLoopDepth());
 }
 
 /// Converts the result computed by the sparse kernel into the required form.
@@ -1309,6 +1307,7 @@ public:
           op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
              "before sparsification.");
     }
+    // Must have been demapped as well if the generic op is sorted.
     assert(!hasAnyNonIdentityOperandsOrResults(op));
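
The key convention this patch settles on: a single `at` encodes the current
loop depth, loops `0 .. at-1` have already been generated, and a dimension
`d_i` "just became" invariant exactly when `i + 1 == at`, which is why the
separate `loopDepth`/`ldx` pair could be folded away (the asserts added above,
`at == env.getLoopDepth()`, record that the two were always redundant). The
following minimal standalone C++ sketch mirrors that check on a toy expression
tree; `Expr`, `isInvariant`, and the test in `main` are illustrative stand-ins,
not the MLIR `AffineExpr` API.

// Minimal standalone sketch (not the MLIR API): a toy expression tree and an
// invariance check that mirrors the `at` convention used by the patch above.
#include <cassert>

struct Expr {
  enum Kind { kDim, kAdd, kConst } kind;
  unsigned pos = 0;                          // for kDim: the loop index i
  const Expr *lhs = nullptr, *rhs = nullptr; // for kAdd: the two operands
};

// `at` is the current loop depth: loops 0 .. at-1 have already been generated.
// A dimension d_i is invariant once its loop is open (i < at), and it "just
// became" invariant when i + 1 == at, reported through the `isAtLoop` flag.
static bool isInvariant(const Expr &e, unsigned at, bool &isAtLoop) {
  switch (e.kind) {
  case Expr::kDim:
    if (e.pos + 1 == at) {
      isAtLoop = true; // becomes invariant exactly at the current loop
      return true;
    }
    return e.pos < at; // invariant when already generated
  case Expr::kAdd:
    return isInvariant(*e.lhs, at, isAtLoop) &&
           isInvariant(*e.rhs, at, isAtLoop);
  case Expr::kConst:
    return true; // constants are invariant everywhere
  }
  return false;
}

int main() {
  // The compound index d0 + d1, as in A[d0 + d1].
  Expr d0{Expr::kDim, 0}, d1{Expr::kDim, 1};
  Expr sum{Expr::kAdd, 0, &d0, &d1};

  bool isAtLoop = false;
  // With only loop 0 generated (at == 1), d1 is still in play.
  assert(!isInvariant(sum, 1, isAtLoop));

  isAtLoop = false;
  // Once loop 1 is generated (at == 2), the sum just became invariant,
  // which is the point where the sparsifier emits the address computation.
  assert(isInvariant(sum, 2, isAtLoop) && isAtLoop);
  return 0;
}

This matches the two call sites above: `genInvariants` probes plain `at` to
hoist loads whose indices are exhausted, while `translateBitsToTidLvlPairs`
probes `at + 1` to catch compound affine indices that become computable at the
loop just being opened.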