diff options
author | Peiming Liu <peiming@google.com> | 2024-06-17 11:35:23 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-17 11:35:23 -0700 |
commit | 996905d8152def16ca2fa1322367e493ac6eef5e (patch) | |
tree | 72e33248d3c1b40017eca8da5e9301fcbe4ed45b /mlir | |
parent | 15399890beb69f622ad0f04a544369fa7947d50b (diff) | |
download | llvm-996905d8152def16ca2fa1322367e493ac6eef5e.zip llvm-996905d8152def16ca2fa1322367e493ac6eef5e.tar.gz llvm-996905d8152def16ca2fa1322367e493ac6eef5e.tar.bz2 |
Revert "[mlir][sparse] implement lowering rules for IterateOp." (#95826)
Reverts llvm/llvm-project#95286
Diffstat (limited to 'mlir')
4 files changed, 17 insertions, 224 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp index f57be49..62887c7 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp @@ -34,20 +34,6 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) { return success(); } -static std::optional<LogicalResult> -convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) { - // The actually Iterator Values (that are updated every iteration). - auto idxTp = IndexType::get(itTp.getContext()); - // TODO: handle batch dimension. - assert(itTp.getEncoding().getBatchLvlRank() == 0); - if (!itTp.isUnique()) { - // Segment high for non-unique iterator. - fields.push_back(idxTp); - } - fields.push_back(idxTp); - return success(); -} - namespace { /// Sparse codegen rule for number of entries operator. @@ -71,114 +57,10 @@ public: } }; -class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> { -public: - using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite(IterateOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { - if (!op.getCrdUsedLvls().empty()) - return rewriter.notifyMatchFailure( - op, "non-empty coordinates list not implemented."); - - Location loc = op.getLoc(); - - auto iterSpace = SparseIterationSpace::fromValues( - op.getIterSpace().getType(), adaptor.getIterSpace(), 0); - - std::unique_ptr<SparseIterator> it = - iterSpace.extractIterator(rewriter, loc); - - if (it->iteratableByFor()) { - auto [lo, hi] = it->genForCond(rewriter, loc); - Value step = constantIndex(rewriter, loc, 1); - SmallVector<Value> ivs; - for (ValueRange inits : adaptor.getInitArgs()) - llvm::append_range(ivs, inits); - scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs); - - Block *loopBody = op.getBody(); - OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs( - loopBody->getArgumentTypes(), bodyTypeMapping))) - return failure(); - rewriter.applySignatureConversion(loopBody, bodyTypeMapping); - - forOp.getBody()->erase(); - Region &dstRegion = forOp.getRegion(); - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - - auto yieldOp = - llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator()); - - rewriter.setInsertionPointToEnd(forOp.getBody()); - // replace sparse_tensor.yield with scf.yield. - rewriter.create<scf::YieldOp>(loc, yieldOp.getResults()); - yieldOp.erase(); - - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - rewriter.replaceOp(op, forOp.getResults(), resultMapping); - } else { - SmallVector<Value> ivs; - llvm::append_range(ivs, it->getCursor()); - for (ValueRange inits : adaptor.getInitArgs()) - llvm::append_range(ivs, inits); - - assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; })); - - TypeRange types = ValueRange(ivs).getTypes(); - auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs); - SmallVector<Location> l(types.size(), op.getIterator().getLoc()); - - // Generates loop conditions. - Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l); - rewriter.setInsertionPointToStart(before); - ValueRange bArgs = before->getArguments(); - auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs); - assert(remArgs.size() == adaptor.getInitArgs().size()); - rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments()); - - // Generates loop body. - Block *loopBody = op.getBody(); - OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs( - loopBody->getArgumentTypes(), bodyTypeMapping))) - return failure(); - rewriter.applySignatureConversion(loopBody, bodyTypeMapping); - - Region &dstRegion = whileOp.getAfter(); - // TODO: handle uses of coordinate! - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - ValueRange aArgs = whileOp.getAfterArguments(); - auto yieldOp = llvm::cast<sparse_tensor::YieldOp>( - whileOp.getAfterBody()->getTerminator()); - - rewriter.setInsertionPointToEnd(whileOp.getAfterBody()); - - aArgs = it->linkNewScope(aArgs); - ValueRange nx = it->forward(rewriter, loc); - SmallVector<Value> yields; - llvm::append_range(yields, nx); - llvm::append_range(yields, yieldOp.getResults()); - - // replace sparse_tensor.yield with scf.yield. - yieldOp->erase(); - rewriter.create<scf::YieldOp>(loc, yields); - - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - rewriter.replaceOp( - op, whileOp.getResults().drop_front(it->getCursor().size()), - resultMapping); - } - return success(); - } -}; - } // namespace mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() { addConversion([](Type type) { return type; }); - addConversion(convertIteratorType); addConversion(convertIterSpaceType); addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp, @@ -192,6 +74,5 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() { void mlir::populateLowerSparseIterationToSCFPatterns( TypeConverter &converter, RewritePatternSet &patterns) { - patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>( - converter, patterns.getContext()); + patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp index ef95fcc..be8e15d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp @@ -331,13 +331,6 @@ public: TrivialIterator(const SparseTensorLevel &stl) : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {} - TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl, - Value posLo, Value posHi) - : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo), - posHi(posHi) { - seek(posLo); - } - std::string getDebugInterfacePrefix() const override { return std::string("trivial<") + stl.toString() + ">"; } @@ -427,14 +420,6 @@ public: : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) { assert(!stl.isUnique()); } - - DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl, - Value posLo, Value posHi) - : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) { - assert(!stl.isUnique()); - seek({posLo, genSegmentHigh(b, l, posLo)}); - } - // For LLVM-style RTTI. static bool classof(const SparseIterator *from) { return from->kind == IterKind::kDedup; @@ -1547,11 +1532,6 @@ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues( return space; } -std::unique_ptr<SparseIterator> -SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const { - return makeSimpleIterator(b, l, *this); -} - //===----------------------------------------------------------------------===// // SparseIterator factory functions. //===----------------------------------------------------------------------===// @@ -1611,26 +1591,6 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl, } std::unique_ptr<SparseIterator> -sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l, - const SparseIterationSpace &iterSpace) { - // assert(iterSpace.getSpaceDim() == 1); - std::unique_ptr<SparseIterator> ret; - if (!iterSpace.isUnique()) { - // We always dedupliate the non-unique level, but we should optimize it away - // if possible. - ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(), - iterSpace.getBoundLo(), - iterSpace.getBoundHi()); - } else { - ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(), - iterSpace.getBoundLo(), - iterSpace.getBoundHi()); - } - ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional); - return ret; -} - -std::unique_ptr<SparseIterator> sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl, SparseEmitStrategy strategy) { std::unique_ptr<SparseIterator> ret; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h index 91f363d..17636af 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h @@ -132,10 +132,6 @@ public: Value getBoundLo() const { return bound.first; } Value getBoundHi() const { return bound.second; } - // Extract an iterator to iterate over the sparse iteration space. - std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b, - Location l) const; - private: SmallVector<std::unique_ptr<SparseTensorLevel>> lvls; std::pair<Value, Value> bound; @@ -196,13 +192,6 @@ public: crd = nullptr; } - // Reconstructs a iteration space directly from the provided ValueRange. - static std::unique_ptr<SparseIterator> - fromValues(IteratorType dstTp, ValueRange values, unsigned tid); - - // The inverse operation of `fromValues`. - SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); } - // // Iterator properties. // @@ -356,21 +345,12 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b, unsigned tid, Level lvl); -/// Helper function to create a TensorLevel object from given ValueRange. +/// Helper function to create a TensorLevel object from given `tensor`. std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz, ValueRange buffers, unsigned tid, Level l); - -/// Helper function to create a simple SparseIterator object that iterate -/// over the entire iteration space. -std::unique_ptr<SparseIterator> -makeSimpleIterator(OpBuilder &b, Location l, - const SparseIterationSpace &iterSpace); - -/// Helper function to create a simple SparseIterator object that iterate -/// over the sparse tensor level. -/// TODO: switch to `SparseIterationSpace` (which support N-D iterator) when -/// feature complete. +/// Helper function to create a simple SparseIterator object that iterates +/// over the SparseTensorLevel. std::unique_ptr<SparseIterator> makeSimpleIterator( const SparseTensorLevel &stl, SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional); diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir index 77a0e89..5fcd661 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir @@ -1,5 +1,4 @@ // RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s -// RUN: mlir-opt %s --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s --check-prefix COLLAPSED #COO = #sparse_tensor.encoding<{ map = (i, j) -> ( @@ -8,44 +7,17 @@ ) }> -// CHECK-LABEL: @sparse_iteration_to_scf -// // deduplication -// CHECK: scf.while {{.*}} { -// CHECK: } do { -// CHECK: } -// CHECK: scf.while {{.*}} { -// CHECK: } do { -// // actual computation -// CHECK: scf.for {{.*}} { -// CHECK: arith.addi -// CHECK: } -// // deduplication -// CHECK: scf.while {{.*}} { -// CHECK: } do { -// CHECK: } -// CHECK: scf.yield -// CHECK: } -// CHECK: return - -// COLLAPSED-LABEL: @sparse_iteration_to_scf -// COLLAPSED: %[[RET:.*]] = scf.for {{.*}} { -// COLLAPSED: %[[VAL:.*]] = arith.addi -// COLLAPSED: scf.yield %[[VAL]] : index -// COLLAPSED: } -// COLLAPSED: return %[[RET]] : index -func.func @sparse_iteration_to_scf(%sp : tensor<4x8xf32, #COO>) -> index { - %i = arith.constant 0 : index - %c1 = arith.constant 1 : index - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 - : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0> - %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index { - %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 - : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1> - %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index { - %k = arith.addi %inner, %c1 : index - sparse_tensor.yield %k : index - } - sparse_tensor.yield %r2 : index - } - return %r1 : index +// CHECK-LABEL: func.func @sparse_1D_space( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}> +// CHECK: %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex> +// CHECK: %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex> +// CHECK: %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex> +// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex> +// CHECK: %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]] +func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> { + %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0> + return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0> } |