diff options
author | Matthias Springer <me@m-sp.org> | 2024-11-19 09:27:51 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-19 09:27:51 +0900 |
commit | 204234a69c068032a1adac31f00b51f3b9efa778 (patch) | |
tree | 959e979fbb5a5534a8ffef39e8a970950c920847 | |
parent | 5ae4d505c38872b3faaeea5779f6c25a9138bbc5 (diff) | |
download | llvm-204234a69c068032a1adac31f00b51f3b9efa778.zip llvm-204234a69c068032a1adac31f00b51f3b9efa778.tar.gz llvm-204234a69c068032a1adac31f00b51f3b9efa778.tar.bz2 |
[mlir][SparseTensor][NFC] Pass tensor type to descriptor helper (#116468)
`getDescriptorFromTensorTuple` and `getMutDescriptorFromTensorTuple`
extract the tensor type from an `unrealized_conversion_cast` op that
serves as a workaround for missing 1:N dialect conversion support.
This commit changes these functions so that they explicitly receive the
tensor type as a function argument. This is in preparation for merging
the 1:1 and 1:N conversion drivers. The conversion patterns in this file
will soon start receiving multiple SSA values (`ValueRange`) from their
adaptors (instead of a single value that is the result of
`unrealized_conversion_cast`). It will no longer be possible to take the
tensor type from the `unrealized_conversion_cast` op. The
`unrealized_conversion_cast` workaround will disappear entirely.
4 files changed, 44 insertions, 34 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index bf7b3f9..25fca49 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -646,10 +646,11 @@ public: matchAndRewrite(LvlOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { std::optional<int64_t> lvl = op.getConstantLvlIndex(); - if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType())) + RankedTensorType srcType = op.getSource().getType(); + if (!lvl || !getSparseTensorEncoding(srcType)) return failure(); - auto desc = getDescriptorFromTensorTuple(adaptor.getSource()); + auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType); auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl); rewriter.replaceOp(op, sz); @@ -675,8 +676,9 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> { assert(dstStt.hasSameDimToLvl(srcStt)); // We don't need a mutable descriptor here as we perform sorting in-place. - auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo()); - auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo()); + auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(), + op.getInputCoo().getType()); + auto nnz = desc.getValMemSize(rewriter, op.getLoc()); auto crd = desc.getAOSMemRef(); auto val = desc.getValMemRef(); @@ -704,7 +706,8 @@ public: matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Simply lowers to specifer.get <field> operation. - auto desc = getDescriptorFromTensorTuple(adaptor.getSlice()); + auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(), + op.getSlice().getType()); auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind, op.getDim().getZExtValue()); @@ -762,7 +765,8 @@ public: Location loc = op.getLoc(); // Deal with copy. 
if (op.getCopy()) { - auto desc = getDescriptorFromTensorTuple(adaptor.getCopy()); + auto desc = getDescriptorFromTensorTuple( + adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType())); SmallVector<Value> fields; fields.reserve(desc.getNumFields()); // Memcpy on memref fields. @@ -868,7 +872,9 @@ public: if (createDeallocs) { // Replace the sparse tensor deallocation with field deallocations. Location loc = op.getLoc(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + auto desc = getDescriptorFromTensorTuple( + adaptor.getTensor(), + cast<RankedTensorType>(op.getTensor().getType())); for (auto input : desc.getMemRefFields()) // Deallocate every buffer used to store the sparse tensor handler. rewriter.create<memref::DeallocOp>(loc, input); @@ -889,7 +895,8 @@ public: matchAndRewrite(LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Prepare descriptor. - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), + op.getTensor().getType()); // Generate optional insertion finalization code. 
if (op.getHasInserts()) genEndInsert(rewriter, op.getLoc(), desc); @@ -909,7 +916,8 @@ public: if (!getSparseTensorEncoding(op.getTensor().getType())) return failure(); Location loc = op->getLoc(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), + op.getTensor().getType()); const auto srcType = getSparseTensorType(op.getTensor()); Type eltType = srcType.getElementType(); Type boolType = rewriter.getIntegerType(1); @@ -959,7 +967,8 @@ public: ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); SmallVector<Value> fields; - auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); + auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields, + op.getTensor().getType()); Value values = adaptor.getValues(); Value filled = adaptor.getFilled(); Value added = adaptor.getAdded(); @@ -1032,7 +1041,8 @@ public: assert(stt.isIdentity() && "Run reinterpret-map before conversion."); Location loc = op.getLoc(); - auto desc = getDescriptorFromTensorTuple(adaptor.getDest()); + auto desc = + getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType()); TypeRange flatSpTensorTps = desc.getFields().getTypes(); SmallVector<Value> params = llvm::to_vector(desc.getFields()); params.append(adaptor.getIndices().begin(), adaptor.getIndices().end()); @@ -1059,7 +1069,8 @@ public: // of this operation truly observe size, not capacity! Location loc = op.getLoc(); Level lvl = op.getLevel(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), + op.getTensor().getType()); auto mem = desc.getPosMemRef(lvl); auto size = desc.getPosMemSize(rewriter, loc, lvl); rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size)); @@ -1081,7 +1092,8 @@ public: // of this operation truly observe size, not capacity! 
Location loc = op.getLoc(); Level lvl = op.getLevel(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), + op.getTensor().getType()); auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl); if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) { auto size = desc.getCrdMemSize(rewriter, loc, lvl); @@ -1106,7 +1118,8 @@ public: // of this operation truly observe size, not capacity! Location loc = op.getLoc(); Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), + op.getTensor().getType()); auto mem = desc.getAOSMemRef(); auto size = desc.getCrdMemSize(rewriter, loc, lvl); rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size)); @@ -1126,7 +1139,8 @@ public: // The view is restricted to the actual size to ensure clients // of this operation truly observe size, not capacity! 
Location loc = op.getLoc(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), + op.getTensor().getType()); auto mem = desc.getValMemRef(); auto size = desc.getValMemSize(rewriter, loc); rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size)); @@ -1172,7 +1186,8 @@ public: // else: // dst = memref.copy(src) Location loc = op.getLoc(); - auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource()); + auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(), + op.getSource().getType()); SmallVector<Value> fields; foreachFieldAndTypeInSparseTensor( SparseTensorType(cast<RankedTensorType>(op.getResult().getType())), @@ -1236,7 +1251,8 @@ public: assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices()); SmallVector<Value> fields; - auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields); + auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields, + op.getSource().getType()); auto newSpec = rewriter.create<StorageSpecifierInitOp>( loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier()); @@ -1285,8 +1301,9 @@ public: // Query memSizes for the actually stored values. // FIXME: the nse value computed in this way might be wrong when there is // any "loose_compressed" level. 
- rewriter.replaceOp( - op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor())); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), + op.getTensor().getType()); + rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc())); return success(); } }; @@ -1415,7 +1432,8 @@ struct SparseDisassembleOpConverter LogicalResult matchAndRewrite(DisassembleOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), + op.getTensor().getType()); Location loc = op.getLoc(); SmallVector<Value> retMem; SmallVector<Value> retLen; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp index de553a5..f923824 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp @@ -554,11 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) { .getResult(); } -Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc, - Value tensor) { - return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc); -} - Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor, Dimension dim) { auto enc = getSparseTensorEncoding(tensor.getType()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h index d0ef8a6..dc017e6 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h @@ -270,9 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc, Value tensor); -/// Generates code to retrieve the values size for the 
sparse tensor. -Value genValMemSize(OpBuilder &builder, Location loc, Value tensor); - /// Generates code to retrieve the slice offset for the sparse tensor slice, /// return a constant if the offset is statically known. Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h index c2f6316..8985854 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h @@ -245,18 +245,18 @@ inline Value genTuple(OpBuilder &builder, Location loc, return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields()); } -inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { +inline SparseTensorDescriptor +getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) { auto tuple = getTuple(tensor); - SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0])); - return SparseTensorDescriptor(stt, tuple.getInputs()); + return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs()); } inline MutSparseTensorDescriptor -getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) { +getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields, + RankedTensorType type) { auto tuple = getTuple(tensor); fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); - SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0])); - return MutSparseTensorDescriptor(stt, fields); + return MutSparseTensorDescriptor(SparseTensorType(type), fields); } } // namespace sparse_tensor |