aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorAmir Bishara <139038766+amirBish@users.noreply.github.com>2023-12-08 11:50:33 +0200
committerGitHub <noreply@github.com>2023-12-08 11:50:33 +0200
commitcf2d625a5d328ab4af6292be7b47c645ffef0e2b (patch)
treebaa9ae3c9e91db0f2e74a7bfea33df10049817f8 /mlir
parent52296e25277146bf2643156627971c11cc7f4a37 (diff)
downloadllvm-cf2d625a5d328ab4af6292be7b47c645ffef0e2b.zip
llvm-cf2d625a5d328ab4af6292be7b47c645ffef0e2b.tar.gz
llvm-cf2d625a5d328ab4af6292be7b47c645ffef0e2b.tar.bz2
[mlir][linalg] Expose getPreservedProducerResults method from ElementwiseOpFusion file (#73850)
Declare `getPreservedProducerResults` function which helps to get the preserved results of the producer linalg generic operation as a result of elementwise fusion.
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp33
2 files changed, 24 insertions, 11 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 3f4dfe4..a848d12 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -493,6 +493,8 @@ LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
struct ElementwiseOpFusionResult {
Operation *fusedOp;
llvm::DenseMap<Value, Value> replacements;
+ static llvm::SmallDenseSet<int>
+ getPreservedProducerResults(GenericOp producer, GenericOp consumer);
};
FailureOr<ElementwiseOpFusionResult>
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index dc5ea28..3eb9119 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -71,6 +71,25 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
return t1.compose(fusedConsumerArgIndexMap);
}
+/// Returns a set of indices of the producer's results which would
+/// be preserved after the fusion.
+llvm::SmallDenseSet<int>
+ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
+ GenericOp consumer) {
+ llvm::SmallDenseSet<int> preservedProducerResults;
+ for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
+ auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
+ if (producer.payloadUsesValueFromOperand(outputOperand) ||
+ !producer.canOpOperandsBeDropped(outputOperand) ||
+ llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
+ return user != consumer.getOperation();
+ })) {
+ preservedProducerResults.insert(producerResult.index());
+ }
+ }
+ return preservedProducerResults;
+}
+
/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!fusedOperand)
@@ -285,17 +304,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
/// Find the results of the producer that have uses outside of the consumer.
- llvm::SmallDenseSet<int> preservedProducerResults;
- for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
- auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
- if (producer.payloadUsesValueFromOperand(outputOperand) ||
- !producer.canOpOperandsBeDropped(outputOperand) ||
- llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
- return user != consumer.getOperation();
- })) {
- preservedProducerResults.insert(producerResult.index());
- }
- }
+ llvm::SmallDenseSet<int> preservedProducerResults =
+ ElementwiseOpFusionResult::getPreservedProducerResults(producer,
+ consumer);
// Compute the fused operands list and indexing maps.
SmallVector<Value> fusedInputOperands, fusedOutputOperands;