aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp')
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp61
1 files changed, 14 insertions, 47 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 3a02d56..cc05f1d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -28,21 +28,13 @@ static bool isMaterializing(Value val) {
val.getDefiningOp<bufferization::AllocTensorOp>();
}
-/// Makes target array's elements sorted according to the `order` array.
-static void sortArrayBasedOnOrder(std::vector<LoopCoeffPair> &target,
- ArrayRef<LoopId> order) {
+/// Sorts the dependent loops such that it is ordered in the same sequence in
+/// which loops will be generated.
+static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
std::sort(target.begin(), target.end(),
- [&order](const LoopCoeffPair &l, const LoopCoeffPair &r) {
+ [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
assert(std::addressof(l) == std::addressof(r) || l != r);
- int idxL = -1, idxR = -1;
- for (int i = 0, e = order.size(); i < e; i++) {
- if (order[i] == l.first)
- idxL = i;
- if (order[i] == r.first)
- idxR = i;
- }
- assert(idxL >= 0 && idxR >= 0);
- return idxL < idxR;
+ return l.first < r.first;
});
}
//===----------------------------------------------------------------------===//
@@ -50,18 +42,12 @@ static void sortArrayBasedOnOrder(std::vector<LoopCoeffPair> &target,
//===----------------------------------------------------------------------===//
CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
- unsigned numTensors, unsigned numLoops,
- unsigned numFilterLoops, unsigned maxRank)
+ unsigned numTensors, unsigned numLoops, unsigned maxRank)
: linalgOp(linop), sparseOptions(opts),
- latticeMerger(numTensors, numLoops, numFilterLoops, maxRank),
- loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u),
- insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(),
- redExp(detail::kInvalidId), redCustom(detail::kInvalidId),
- redValidLexInsert() {
- // TODO: remove topSort, loops should be already sorted by previous pass.
- for (unsigned l = 0; l < latticeMerger.getNumLoops(); l++)
- topSort.push_back(l);
-}
+ latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
+ sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
+ expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
+ redCustom(detail::kInvalidId), redValidLexInsert() {}
LogicalResult CodegenEnv::initTensorExp() {
// Builds the tensor expression for the Linalg operation in SSA form.
@@ -97,7 +83,7 @@ void CodegenEnv::startEmit() {
(void)enc;
assert(!enc || lvlRank == enc.getLvlRank());
for (Level lvl = 0; lvl < lvlRank; lvl++)
- sortArrayBasedOnOrder(latticeMerger.getDependentLoops(tid, lvl), topSort);
+ sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
}
loopEmitter.initialize(
@@ -105,7 +91,7 @@ void CodegenEnv::startEmit() {
StringAttr::get(linalgOp.getContext(),
linalg::GenericOp::getOperationName()),
/*hasOutput=*/true,
- /*isSparseOut=*/sparseOut != nullptr, topSort,
+ /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(),
// TODO: compute the map and pass it to loop emitter directly instead of
// passing in a callback.
/*dependentLvlGetter=*/
@@ -190,8 +176,7 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
// needed.
outerParNest = 0;
const auto iteratorTypes = linalgOp.getIteratorTypesArray();
- assert(topSortSize() == latticeMerger.getNumLoops());
- for (const LoopId i : topSort) {
+ for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
if (linalg::isReductionIterator(iteratorTypes[i]))
break; // terminate at first reduction
outerParNest++;
@@ -208,26 +193,8 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
// Code generation environment topological sort methods
//===----------------------------------------------------------------------===//
-ArrayRef<LoopId> CodegenEnv::getTopSortSlice(LoopOrd n, LoopOrd m) const {
- return ArrayRef<LoopId>(topSort).slice(n, m);
-}
-
-ArrayRef<LoopId> CodegenEnv::getLoopStackUpTo(LoopOrd n) const {
- return ArrayRef<LoopId>(topSort).take_front(n);
-}
-
-ArrayRef<LoopId> CodegenEnv::getCurrentLoopStack() const {
- return getLoopStackUpTo(loopEmitter.getCurrentDepth());
-}
-
Value CodegenEnv::getLoopVar(LoopId i) const {
- // TODO: this class should store the inverse of `topSort` so that
- // it can do this conversion directly, instead of searching through
- // `topSort` every time. (Or else, `LoopEmitter` should handle this.)
- for (LoopOrd n = 0, numLoops = topSortSize(); n < numLoops; n++)
- if (topSort[n] == i)
- return loopEmitter.getLoopIV(n);
- llvm_unreachable("invalid loop identifier");
+ return loopEmitter.getLoopIV(i);
}
//===----------------------------------------------------------------------===//