diff options
author | Amir Bishara <139038766+amirBish@users.noreply.github.com> | 2023-12-08 11:50:33 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-08 11:50:33 +0200 |
commit | cf2d625a5d328ab4af6292be7b47c645ffef0e2b (patch) | |
tree | baa9ae3c9e91db0f2e74a7bfea33df10049817f8 | |
parent | 52296e25277146bf2643156627971c11cc7f4a37 (diff) | |
download | llvm-cf2d625a5d328ab4af6292be7b47c645ffef0e2b.zip llvm-cf2d625a5d328ab4af6292be7b47c645ffef0e2b.tar.gz llvm-cf2d625a5d328ab4af6292be7b47c645ffef0e2b.tar.bz2 |
[mlir][linalg] Expose getPreservedProducerResults method from ElementwiseOpFusion file (#73850)
Declare `getPreservedProducerResults` function which helps to get the
preserved results of the producer linalg generic operation as a result
of elementwise fusion.
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 2 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 33 |
2 files changed, 24 insertions, 11 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 3f4dfe4..a848d12 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -493,6 +493,8 @@ LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, struct ElementwiseOpFusionResult { Operation *fusedOp; llvm::DenseMap<Value, Value> replacements; + static llvm::SmallDenseSet<int> + getPreservedProducerResults(GenericOp producer, GenericOp consumer); }; FailureOr<ElementwiseOpFusionResult> fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index dc5ea28..3eb9119 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -71,6 +71,25 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( return t1.compose(fusedConsumerArgIndexMap); } +/// Returns a set of indices of the producer's results which would +/// be preserved after the fusion. +llvm::SmallDenseSet<int> +ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer, + GenericOp consumer) { + llvm::SmallDenseSet<int> preservedProducerResults; + for (const auto &producerResult : llvm::enumerate(producer->getResults())) { + auto *outputOperand = producer.getDpsInitOperand(producerResult.index()); + if (producer.payloadUsesValueFromOperand(outputOperand) || + !producer.canOpOperandsBeDropped(outputOperand) || + llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) { + return user != consumer.getOperation(); + })) { + preservedProducerResults.insert(producerResult.index()); + } + } + return preservedProducerResults; +} + /// Conditions for elementwise fusion of generic operations. bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) { if (!fusedOperand) @@ -285,17 +304,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, assert(consumer.isDpsInput(fusedOperand) && "expected producer of input operand"); /// Find the results of the producer that have uses outside of the consumer. - llvm::SmallDenseSet<int> preservedProducerResults; - for (const auto &producerResult : llvm::enumerate(producer->getResults())) { - auto *outputOperand = producer.getDpsInitOperand(producerResult.index()); - if (producer.payloadUsesValueFromOperand(outputOperand) || - !producer.canOpOperandsBeDropped(outputOperand) || - llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) { - return user != consumer.getOperation(); - })) { - preservedProducerResults.insert(producerResult.index()); - } - } + llvm::SmallDenseSet<int> preservedProducerResults = + ElementwiseOpFusionResult::getPreservedProducerResults(producer, + consumer); // Compute the fused operands list and indexing maps. SmallVector<Value> fusedInputOperands, fusedOutputOperands; |