aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAart Bik <39774503+aartbik@users.noreply.github.com>2023-11-28 17:46:02 -0800
committerGitHub <noreply@github.com>2023-11-28 17:46:02 -0800
commit98f8b1afb408e488274e77533dbb8041c9513ca7 (patch)
tree4589dadb143efb3a4767d34ff99eacee436d3b22
parent83305faeb54cf254d133e39ddac7f2fee36972cc (diff)
downloadllvm-98f8b1afb408e488274e77533dbb8041c9513ca7.zip
llvm-98f8b1afb408e488274e77533dbb8041c9513ca7.tar.gz
llvm-98f8b1afb408e488274e77533dbb8041c9513ca7.tar.bz2
[mlir][sparse] remove COO test from trait and encoding (#73733)
This is a minor step towards moving ALL COO related tests into the SparseTensorType class rather than having it all over the place (with risk of becoming inconsistent). Next revision will move ALL COO related methods into this class.
-rw-r--r--mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td15
-rw-r--r--mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td15
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp12
3 files changed, 13 insertions, 29 deletions
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 1e62d99..12c1068 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -382,9 +382,6 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// the null encoding (since dense-tensors are always all-dense).
bool isAllDense() const;
- /// Returns true if it is a sparse tensor encoding in COO format.
- bool isCOO() const;
-
/// Returns true if every level is ordered. Also returns true for
/// the null encoding (since dense-tensors are always all-ordered).
bool isAllOrdered() const;
@@ -467,10 +464,6 @@ def SparseTensorStorageSpecifierKindAttr
def IsSparseTensorPred
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;
-def IsCOOPred
- : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && "
- " ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isCOO()">;
-
def IsSparseTensorSlicePred
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && "
" ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;
@@ -478,22 +471,14 @@ def IsSparseTensorSlicePred
class SparseTensorOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
-class COOSparseTensorOf<list<Type> allowedTypes>
- : TensorOf<allowedTypes, [IsCOOPred], "COO sparse tensor">;
-
class SparseTensorSliceOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
-class RankedSparseTensorOf<list<Type> allowedTypes>
- : RankedTensorOf<allowedTypes, [IsSparseTensorPred], "ranked sparse tensor">;
-
class ScalarLikeOf<list<Type> allowedTypes>
: AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>], "scalar like">;
def AnySparseTensor : SparseTensorOf<[AnyType]>;
-def AnyCOOSparseTensor : COOSparseTensorOf<[AnyType]>;
def AnySparseTensorSlice : SparseTensorSliceOf<[AnyType]>;
-def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
def AnyIndexingScalarLike : ScalarLikeOf<[AnySignlessIntegerOrIndex]>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 78031f2..3127cf1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -921,10 +921,9 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort">,
let summary = "Sorts the arrays in xs and ys lexicographically on the "
"integral values found in the xs list";
let description = [{
- Sparse_tensor.sort sort the `xs` values along with some `ys` values
- that are put in a single linear buffer `xy`.
- The affine map attribute `perm_map` specifies the permutation to be applied on
- the `xs` before comparison, the rank of the permutation map
+ Sorts the `xs` values along with some `ys` values that are put in a single linear
+ buffer `xy`. The affine map attribute `perm_map` specifies the permutation to be
+ applied on the `xs` before comparison, the rank of the permutation map
also specifies the number of `xs` values in `xy`.
The optional index attribute `ny` provides the number of `ys` values in `xy`.
When `ny` is not explicitly specified, its value is 0.
@@ -950,14 +949,14 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort">,
}
def SparseTensor_ReorderCOOOp : SparseTensor_Op<"reorder_coo", [Pure]>,
- Arguments<(ins AnyCOOSparseTensor: $input_coo,
+ Arguments<(ins AnySparseTensor: $input_coo,
SparseTensorSortKindAttr:$algorithm)>,
- Results<(outs AnyCOOSparseTensor: $result_coo)> {
+ Results<(outs AnySparseTensor: $result_coo)> {
let summary = "Reorder the input COO such that it has the the same order as "
"the output COO";
let description = [{
- sparse_tensor.reorder_coo reorder input COO to the same order as specified by
- the output format. E.g., reorder an unordered COO into an ordered one.
+ Reorders the input COO to the same order as specified by the output format.
+ E.g., reorder an unordered COO into an ordered one.
The input and result COO tensor must have the same element type, position type and
coordinate type. At the moment, the operation also only supports ordering
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 74d2fd5..d4f8afd 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -316,10 +316,6 @@ bool SparseTensorEncodingAttr::isAllDense() const {
return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}
-bool SparseTensorEncodingAttr::isCOO() const {
- return getImpl() && isCOOType(*this, 0, true);
-}
-
bool SparseTensorEncodingAttr::isAllOrdered() const {
return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}
@@ -1664,14 +1660,18 @@ LogicalResult ReorderCOOOp::verify() {
SparseTensorType srcStt = getSparseTensorType(getInputCoo());
SparseTensorType dstStt = getSparseTensorType(getResultCoo());
+ if (!isCOOType(srcStt.getEncoding(), 0, /*isUnique=*/true) ||
+ !isCOOType(dstStt.getEncoding(), 0, /*isUnique=*/true))
+ emitError("Unexpected non-COO sparse tensors");
+
if (!srcStt.hasSameDimToLvl(dstStt))
emitError("Unmatched dim2lvl map between input and result COO");
if (srcStt.getPosType() != dstStt.getPosType() ||
srcStt.getCrdType() != dstStt.getCrdType() ||
- srcStt.getElementType() != dstStt.getElementType()) {
+ srcStt.getElementType() != dstStt.getElementType())
emitError("Unmatched storage format between input and result COO");
- }
+
return success();
}