aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPeiming Liu <peiming@google.com>2024-04-01 10:30:36 -0700
committerGitHub <noreply@github.com>2024-04-01 10:30:36 -0700
commita54930e696a275ac3947484f44d770cd587ce147 (patch)
tree4dfde57daa7d9cf10d6ff427e29a1dd12e3c7725
parent2cfd7d433be0831c6e2a248a4b828f7aedcaeaa0 (diff)
downloadllvm-a54930e696a275ac3947484f44d770cd587ce147.zip
llvm-a54930e696a275ac3947484f44d770cd587ce147.tar.gz
llvm-a54930e696a275ac3947484f44d770cd587ce147.tar.bz2
[mlir][sparse] allow YieldOp to yield multiple values. (#87261)
-rw-r--r--mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td25
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp18
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp5
-rw-r--r--mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp4
4 files changed, 29 insertions, 23 deletions
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 29cf8c3..5df8a17 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1278,8 +1278,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
let hasVerifier = 1;
}
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
- Arguments<(ins Optional<AnyType>:$result)> {
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
+ ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
+ "ForeachOp"]>]>,
+ Arguments<(ins Variadic<AnyType>:$results)> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1302,14 +1304,27 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
let builders = [
OpBuilder<(ins),
[{
- build($_builder, $_state, Value());
+ build($_builder, $_state, ValueRange());
+ }]>,
+ OpBuilder<(ins "Value":$yieldVal),
+ [{
+ build($_builder, $_state, ValueRange(yieldVal));
}]>
];
+ let extraClassDeclaration = [{
+ Value getSingleResult() {
+ assert(hasSingleResult());
+ return getResults().front();
+ }
+ bool hasSingleResult() {
+ return getResults().size() == 1;
+ }
+ }];
+
let assemblyFormat = [{
- $result attr-dict `:` type($result)
+ $results attr-dict `:` type($results)
}];
- let hasVerifier = 1;
}
def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6da51bb..e4d93c5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1591,7 +1591,8 @@ static LogicalResult verifyNumBlockArgs(T *op, Region &region,
if (!yield)
return op->emitError() << regionName
<< " region must end with sparse_tensor.yield";
- if (!yield.getResult() || yield.getResult().getType() != outputType)
+ if (!yield.hasSingleResult() ||
+ yield.getSingleResult().getType() != outputType)
return op->emitError() << regionName << " region yield type mismatch";
return success();
@@ -1654,7 +1655,8 @@ LogicalResult UnaryOp::verify() {
// Absent branch can only yield invariant values.
Block *absentBlock = &absent.front();
Block *parent = getOperation()->getBlock();
- Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
+ Value absentVal =
+ cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
if (arg.getOwner() == parent)
return emitError("absent region cannot yield linalg argument");
@@ -1907,18 +1909,6 @@ LogicalResult SortOp::verify() {
return success();
}
-LogicalResult YieldOp::verify() {
- // Check for compatible parent.
- auto *parentOp = (*this)->getParentOp();
- if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
- isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
- isa<ForeachOp>(parentOp))
- return success();
-
- return emitOpError("expected parent op to be sparse_tensor unary, binary, "
- "reduce, select or foreach");
-}
-
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 14ea07f..9c0fc60 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -764,9 +764,10 @@ struct ForeachOpDemapper
if (numInitArgs != 0) {
rewriter.setInsertionPointToEnd(body);
auto yield = llvm::cast<YieldOp>(body->getTerminator());
- if (auto stt = tryGetSparseTensorType(yield.getResult());
+ if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
stt && !stt->isIdentity()) {
- Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
+ Value y =
+ genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
rewriter.create<YieldOp>(loc, y);
rewriter.eraseOp(yield);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 72b722c..9c0aed3 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1031,7 +1031,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// invariant on the right.
Block &absentBlock = absentRegion.front();
YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
- const Value absentVal = absentYield.getResult();
+ const Value absentVal = absentYield.getSingleResult();
const ExprId rhs = addInvariantExp(absentVal);
return disjSet(e, child0, buildLattices(rhs, i), unop);
}
@@ -1500,7 +1500,7 @@ static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
// Merge cloned block and return yield value.
Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
- Value val = clonedYield.getResult();
+ Value val = clonedYield.getSingleResult();
rewriter.eraseOp(clonedYield);
rewriter.eraseOp(placeholder);
return val;