diff options
author | Peiming Liu <peiming@google.com> | 2024-06-17 13:29:53 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-17 13:29:53 -0700 |
commit | d6cc35f7f67575f2d3534ea385c2f36f48f49aea (patch) | |
tree | 856eeac32bc7fd83ef4e45ee39e950e771c138b0 /mlir | |
parent | 5b04b6fe3fabba8f76d730da3c0d528e1dd0c184 (diff) | |
download | llvm-d6cc35f7f67575f2d3534ea385c2f36f48f49aea.zip llvm-d6cc35f7f67575f2d3534ea385c2f36f48f49aea.tar.gz llvm-d6cc35f7f67575f2d3534ea385c2f36f48f49aea.tar.bz2 |
Reapply "[mlir][sparse] implement lowering rules for IterateOp." (#95836)
Diffstat (limited to 'mlir')
4 files changed, 224 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp index 62887c7..4224925 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp @@ -34,6 +34,20 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) { return success(); } +static std::optional<LogicalResult> +convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) { + // The actual Iterator Values (that are updated every iteration). + auto idxTp = IndexType::get(itTp.getContext()); + // TODO: handle batch dimension. + assert(itTp.getEncoding().getBatchLvlRank() == 0); + if (!itTp.isUnique()) { + // Segment high for non-unique iterator. + fields.push_back(idxTp); + } + fields.push_back(idxTp); + return success(); +} + namespace { /// Sparse codegen rule for number of entries operator. @@ -57,10 +71,114 @@ public: } }; +class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + LogicalResult + matchAndRewrite(IterateOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + if (!op.getCrdUsedLvls().empty()) + return rewriter.notifyMatchFailure( + op, "non-empty coordinates list not implemented."); + + Location loc = op.getLoc(); + + auto iterSpace = SparseIterationSpace::fromValues( + op.getIterSpace().getType(), adaptor.getIterSpace(), 0); + + std::unique_ptr<SparseIterator> it = + iterSpace.extractIterator(rewriter, loc); + + if (it->iteratableByFor()) { + auto [lo, hi] = it->genForCond(rewriter, loc); + Value step = constantIndex(rewriter, loc, 1); + SmallVector<Value> ivs; + for (ValueRange inits : adaptor.getInitArgs()) + llvm::append_range(ivs, inits); + scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs); + + Block *loopBody = op.getBody(); + 
OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); + if (failed(typeConverter->convertSignatureArgs( + loopBody->getArgumentTypes(), bodyTypeMapping))) + return failure(); + rewriter.applySignatureConversion(loopBody, bodyTypeMapping); + + rewriter.eraseBlock(forOp.getBody()); + Region &dstRegion = forOp.getRegion(); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); + + auto yieldOp = + llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator()); + + rewriter.setInsertionPointToEnd(forOp.getBody()); + // replace sparse_tensor.yield with scf.yield. + rewriter.create<scf::YieldOp>(loc, yieldOp.getResults()); + rewriter.eraseOp(yieldOp); + + const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); + rewriter.replaceOp(op, forOp.getResults(), resultMapping); + } else { + SmallVector<Value> ivs; + llvm::append_range(ivs, it->getCursor()); + for (ValueRange inits : adaptor.getInitArgs()) + llvm::append_range(ivs, inits); + + assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; })); + + TypeRange types = ValueRange(ivs).getTypes(); + auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs); + SmallVector<Location> l(types.size(), op.getIterator().getLoc()); + + // Generates loop conditions. + Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l); + rewriter.setInsertionPointToStart(before); + ValueRange bArgs = before->getArguments(); + auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs); + assert(remArgs.size() == adaptor.getInitArgs().size()); + rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments()); + + // Generates loop body. 
+ Block *loopBody = op.getBody(); + OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); + if (failed(typeConverter->convertSignatureArgs( + loopBody->getArgumentTypes(), bodyTypeMapping))) + return failure(); + rewriter.applySignatureConversion(loopBody, bodyTypeMapping); + + Region &dstRegion = whileOp.getAfter(); + // TODO: handle uses of coordinate! + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); + ValueRange aArgs = whileOp.getAfterArguments(); + auto yieldOp = llvm::cast<sparse_tensor::YieldOp>( + whileOp.getAfterBody()->getTerminator()); + + rewriter.setInsertionPointToEnd(whileOp.getAfterBody()); + + aArgs = it->linkNewScope(aArgs); + ValueRange nx = it->forward(rewriter, loc); + SmallVector<Value> yields; + llvm::append_range(yields, nx); + llvm::append_range(yields, yieldOp.getResults()); + + // replace sparse_tensor.yield with scf.yield. + rewriter.eraseOp(yieldOp); + rewriter.create<scf::YieldOp>(loc, yields); + + const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); + rewriter.replaceOp( + op, whileOp.getResults().drop_front(it->getCursor().size()), + resultMapping); + } + return success(); + } +}; + } // namespace mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() { addConversion([](Type type) { return type; }); + addConversion(convertIteratorType); addConversion(convertIterSpaceType); addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp, @@ -74,5 +192,6 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() { void mlir::populateLowerSparseIterationToSCFPatterns( TypeConverter &converter, RewritePatternSet &patterns) { - patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext()); + patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>( + converter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp 
b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp index be8e15d..ef95fcc 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp @@ -331,6 +331,13 @@ public: TrivialIterator(const SparseTensorLevel &stl) : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {} + TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl, + Value posLo, Value posHi) + : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo), + posHi(posHi) { + seek(posLo); + } + std::string getDebugInterfacePrefix() const override { return std::string("trivial<") + stl.toString() + ">"; } @@ -420,6 +427,14 @@ public: : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) { assert(!stl.isUnique()); } + + DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl, + Value posLo, Value posHi) + : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) { + assert(!stl.isUnique()); + seek({posLo, genSegmentHigh(b, l, posLo)}); + } + // For LLVM-style RTTI. static bool classof(const SparseIterator *from) { return from->kind == IterKind::kDedup; @@ -1532,6 +1547,11 @@ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues( return space; } +std::unique_ptr<SparseIterator> +SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const { + return makeSimpleIterator(b, l, *this); +} + //===----------------------------------------------------------------------===// // SparseIterator factory functions. 
//===----------------------------------------------------------------------===// @@ -1591,6 +1611,26 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl, } std::unique_ptr<SparseIterator> +sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l, + const SparseIterationSpace &iterSpace) { + // assert(iterSpace.getSpaceDim() == 1); + std::unique_ptr<SparseIterator> ret; + if (!iterSpace.isUnique()) { + // We always deduplicate the non-unique level, but we should optimize it away + // if possible. + ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(), + iterSpace.getBoundLo(), + iterSpace.getBoundHi()); + } else { + ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(), + iterSpace.getBoundLo(), + iterSpace.getBoundHi()); + } + ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional); + return ret; +} + +std::unique_ptr<SparseIterator> sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl, SparseEmitStrategy strategy) { std::unique_ptr<SparseIterator> ret; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h index 17636af..91f363d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h @@ -132,6 +132,10 @@ public: Value getBoundLo() const { return bound.first; } Value getBoundHi() const { return bound.second; } + // Extract an iterator to iterate over the sparse iteration space. + std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b, + Location l) const; + private: SmallVector<std::unique_ptr<SparseTensorLevel>> lvls; std::pair<Value, Value> bound; @@ -192,6 +196,13 @@ public: crd = nullptr; } + + // Reconstructs an iteration space directly from the provided ValueRange. 
+ static std::unique_ptr<SparseIterator> + fromValues(IteratorType dstTp, ValueRange values, unsigned tid); + + // The inverse operation of `fromValues`. + SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); } + // // Iterator properties. // @@ -345,12 +356,21 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b, unsigned tid, Level lvl); -/// Helper function to create a TensorLevel object from given `tensor`. +/// Helper function to create a TensorLevel object from given ValueRange. std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz, ValueRange buffers, unsigned tid, Level l); -/// Helper function to create a simple SparseIterator object that iterates -/// over the SparseTensorLevel. + +/// Helper function to create a simple SparseIterator object that iterates +/// over the entire iteration space. +std::unique_ptr<SparseIterator> +makeSimpleIterator(OpBuilder &b, Location l, + const SparseIterationSpace &iterSpace); + +/// Helper function to create a simple SparseIterator object that iterates +/// over the sparse tensor level. +/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) when +/// feature complete. 
std::unique_ptr<SparseIterator> makeSimpleIterator( const SparseTensorLevel &stl, SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional); diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir index 5fcd661..77a0e89 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s +// RUN: mlir-opt %s --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s --check-prefix COLLAPSED #COO = #sparse_tensor.encoding<{ map = (i, j) -> ( @@ -7,17 +8,44 @@ ) }> -// CHECK-LABEL: func.func @sparse_1D_space( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> { -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}> -// CHECK: %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex> -// CHECK: %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex> -// CHECK: %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex> -// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex> -// CHECK: %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]] -func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> { - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0> - return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0> +// CHECK-LABEL: 
@sparse_iteration_to_scf +// // deduplication +// CHECK: scf.while {{.*}} { +// CHECK: } do { +// CHECK: } +// CHECK: scf.while {{.*}} { +// CHECK: } do { +// // actual computation +// CHECK: scf.for {{.*}} { +// CHECK: arith.addi +// CHECK: } +// // deduplication +// CHECK: scf.while {{.*}} { +// CHECK: } do { +// CHECK: } +// CHECK: scf.yield +// CHECK: } +// CHECK: return + +// COLLAPSED-LABEL: @sparse_iteration_to_scf +// COLLAPSED: %[[RET:.*]] = scf.for {{.*}} { +// COLLAPSED: %[[VAL:.*]] = arith.addi +// COLLAPSED: scf.yield %[[VAL]] : index +// COLLAPSED: } +// COLLAPSED: return %[[RET]] : index +func.func @sparse_iteration_to_scf(%sp : tensor<4x8xf32, #COO>) -> index { + %i = arith.constant 0 : index + %c1 = arith.constant 1 : index + %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 + : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0> + %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index { + %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 + : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1> + %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index { + %k = arith.addi %inner, %c1 : index + sparse_tensor.yield %k : index + } + sparse_tensor.yield %r2 : index + } + return %r1 : index } |