aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAart Bik <39774503+aartbik@users.noreply.github.com>2023-12-12 12:44:46 -0800
committerGitHub <noreply@github.com>2023-12-12 12:44:46 -0800
commit047399c213a007f91b5d472cfe6742d5b7be70f3 (patch)
treec092cd4bb4863c0196376836c5b1f7e33f066b3d
parentc77cdbac9b121611121adf5806a99aff4812a40c (diff)
downloadllvm-047399c213a007f91b5d472cfe6742d5b7be70f3.zip
llvm-047399c213a007f91b5d472cfe6742d5b7be70f3.tar.gz
llvm-047399c213a007f91b5d472cfe6742d5b7be70f3.tar.bz2
[mlir][sparse] cleanup of CodegenEnv reduction API (#75243)
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp27
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h9
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp26
3 files changed, 36 insertions, 26 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 312aefc..4bd3af2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -115,10 +115,10 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
SmallVector<Value> params;
if (isReduc()) {
params.push_back(redVal);
- if (redValidLexInsert)
+ if (isValidLexInsert())
params.push_back(redValidLexInsert);
} else {
- assert(!redValidLexInsert);
+ assert(!isValidLexInsert());
}
if (isExpand())
params.push_back(expCount);
@@ -128,8 +128,8 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
unsigned i = 0;
if (isReduc()) {
updateReduc(params[i++]);
- if (redValidLexInsert)
- setValidLexInsert(params[i++]);
+ if (isValidLexInsert())
+ updateValidLexInsert(params[i++]);
}
if (isExpand())
updateExpandCount(params[i++]);
@@ -235,14 +235,14 @@ void CodegenEnv::endExpand() {
//===----------------------------------------------------------------------===//
void CodegenEnv::startReduc(ExprId exp, Value val) {
- assert(!isReduc() && exp != detail::kInvalidId);
+ assert(!isReduc() && exp != detail::kInvalidId && val);
redExp = exp;
redVal = val;
latticeMerger.setExprValue(exp, val);
}
void CodegenEnv::updateReduc(Value val) {
- assert(isReduc());
+ assert(isReduc() && val);
redVal = val;
latticeMerger.clearExprValue(redExp);
latticeMerger.setExprValue(redExp, val);
@@ -257,13 +257,18 @@ Value CodegenEnv::endReduc() {
return val;
}
-void CodegenEnv::setValidLexInsert(Value val) {
- assert(isReduc() && val);
+void CodegenEnv::startValidLexInsert(Value val) {
+ assert(!isValidLexInsert() && isReduc() && val);
+ redValidLexInsert = val;
+}
+
+void CodegenEnv::updateValidLexInsert(Value val) {
+ assert(redValidLexInsert && isReduc() && val);
redValidLexInsert = val;
}
-void CodegenEnv::clearValidLexInsert() {
- assert(!isReduc());
+void CodegenEnv::endValidLexInsert() {
+ assert(isValidLexInsert() && !isReduc());
redValidLexInsert = Value();
}
@@ -272,7 +277,7 @@ void CodegenEnv::startCustomReduc(ExprId exp) {
redCustom = exp;
}
-Value CodegenEnv::getCustomRedId() {
+Value CodegenEnv::getCustomRedId() const {
assert(isCustomReduc());
return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index a1947f4..cd626041 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -150,13 +150,16 @@ public:
void updateReduc(Value val);
Value getReduc() const { return redVal; }
Value endReduc();
- void setValidLexInsert(Value val);
- void clearValidLexInsert();
+
+ void startValidLexInsert(Value val);
+ bool isValidLexInsert() const { return redValidLexInsert != nullptr; }
+ void updateValidLexInsert(Value val);
Value getValidLexInsert() const { return redValidLexInsert; }
+ void endValidLexInsert();
void startCustomReduc(ExprId exp);
bool isCustomReduc() const { return redCustom != detail::kInvalidId; }
- Value getCustomRedId();
+ Value getCustomRedId() const;
void endCustomReduc();
private:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 992be43..2367d3b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -415,9 +415,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
Value chain = env.getInsertionChain();
- if (!env.getValidLexInsert()) {
- env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
- } else {
+ if (env.isValidLexInsert()) {
// Generates runtime check for a valid lex during reduction,
// to avoid inserting the identity value for empty reductions.
// if (validLexInsert) then
@@ -438,6 +436,9 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
// Value assignment.
builder.setInsertionPointAfter(ifValidLexInsert);
env.updateInsertionChain(ifValidLexInsert.getResult(0));
+ } else {
+ // Generates regular insertion chain.
+ env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
}
return;
}
@@ -688,12 +689,13 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
env.startReduc(exp, genTensorLoad(env, builder, exp));
}
if (env.hasSparseOutput())
- env.setValidLexInsert(constantI1(builder, env.op().getLoc(), false));
+ env.startValidLexInsert(
+ constantI1(builder, env.op().getLoc(), false));
} else {
if (!env.isCustomReduc() || env.isReduc())
genTensorStore(env, builder, exp, env.endReduc());
if (env.hasSparseOutput())
- env.clearValidLexInsert();
+ env.endValidLexInsert();
}
} else {
// Start or end loop invariant hoisting of a tensor load.
@@ -846,9 +848,9 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
if (env.isReduc()) {
yields.push_back(env.getReduc());
env.updateReduc(ifOp.getResult(y++));
- if (env.getValidLexInsert()) {
+ if (env.isValidLexInsert()) {
yields.push_back(env.getValidLexInsert());
- env.setValidLexInsert(ifOp.getResult(y++));
+ env.updateValidLexInsert(ifOp.getResult(y++));
}
}
if (env.isExpand()) {
@@ -904,7 +906,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
});
if (env.isReduc()) {
types.push_back(env.getReduc().getType());
- if (env.getValidLexInsert())
+ if (env.isValidLexInsert())
types.push_back(env.getValidLexInsert().getType());
}
if (env.isExpand())
@@ -924,10 +926,10 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
if (env.isReduc()) {
operands.push_back(env.getReduc());
env.updateReduc(redInput);
- if (env.getValidLexInsert()) {
+ if (env.isValidLexInsert()) {
// Any overlapping indices during a reduction creates a valid lex insert.
operands.push_back(constantI1(builder, env.op().getLoc(), true));
- env.setValidLexInsert(validIns);
+ env.updateValidLexInsert(validIns);
}
}
if (env.isExpand()) {
@@ -1174,8 +1176,8 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
// Either a for-loop or a while-loop that iterates over a slice.
if (isSingleCond) {
// Any iteration creates a valid lex insert.
- if (env.isReduc() && env.getValidLexInsert())
- env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
+ if (env.isReduc() && env.isValidLexInsert())
+ env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
} else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
// End a while-loop.
finalizeWhileOp(env, rewriter, needsUniv);