aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td68
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h109
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp32
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp107
-rw-r--r--mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir42
-rw-r--r--utils/bazel/llvm-project-overlay/mlir/BUILD.bazel1
8 files changed, 325 insertions, 36 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index bccdeaa..7caae2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -167,6 +167,74 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
}
//===----------------------------------------------------------------------===//
+// EliminateLinalgOpAnchoredEmptyTensorsOp
+//===----------------------------------------------------------------------===//
+
+def EliminateLinalgOpAnchoredEmptyTensorsOp
+ : Op<Transform_Dialect, "structured.eliminate_empty_tensors",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let description = [{
+ Try to eliminate all `tensor.empty` op uses that are anchored on a LinalgOp
+ within the targeted op.
+
+ This op is similar to `bufferization.eliminate_empty_tensors`, but specific
+ to LinalgOps.
+
+ `tensor.empty` ops cannot be bufferized. They can either be converted to
+ `bufferization.alloc_tensor` or replaced with another tensor (via this
+ transform). `tensor.empty` does not specify the contents of the returned
+ tensor so their results can be replaced with arbitrary tensor values as long
+ as the dimensions match.
+
+ This transform looks for `tensor.empty` ops where the SSA use-def chain of
+ the result ends in a supported LinalgOp (always following the aliasing
+ OpOperand/OpResult chain). The following LinalgOps are supported:
+ - Only parallel iterator types.
+ - The use-def chain ends in an input operand of the LinalgOp.
+ - The LinalgOp has an unused output operand with the same shape and
+ indexing map.
+
+ Example:
+
+ ```
+ %0 = tensor.empty()
+ %1 = linalg.matmul ins(...) outs(%0)
+ %2 = linalg.generic ins(%1) outs(%dest) {
+ ^bb0(%in: f32, %out: f32):
+ // out not used
+ }
+ ```
+
+ Is rewritten with:
+ ```
+ %0 = tensor.empty()
+ %1 = linalg.matmul ins(...) outs(%dest)
+ %2 = linalg.generic ins(%0) outs(%1) {
+ ^bb0(%in: f32, %out: f32):
+ // Use %out instead of %in
+ }
+ ```
+
+ After this transformation, the "ins" operand has no uses inside the body of
+ the LinalgOp and can be folded away with existing cleanup patterns.
+ Afterwards, the tensor::EmptyOp can also fold away, so that the example can
+ bufferize without an allocation (in the absence of other conflicts).
+
+ #### Return modes
+
+ This transform reads the target handle and modifies the payload. It does
+ not produce any handle.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+
+ let results = (outs);
+
+ let assemblyFormat = "$target attr-dict `:` type($target)";
+}
+
+//===----------------------------------------------------------------------===//
// FuseOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 99886f5..c441b79 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -27,6 +27,10 @@
#include "llvm/ADT/SmallSet.h"
namespace mlir {
+namespace bufferization {
+class OneShotAnalysisState;
+} // namespace bufferization
+
namespace linalg {
class LinalgOp;
@@ -39,6 +43,75 @@ class LinalgOp;
std::optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
//===----------------------------------------------------------------------===//
+// Bufferization-related transforms.
+//===----------------------------------------------------------------------===//
+
+/// Materialize a buffer allocation for the given tensor.pad op and lower the
+/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.:
+///
+/// %0 = tensor.pad low[%l] high[%h] %t ...
+///
+/// is lowered to:
+///
+/// %alloc = memref.alloc
+/// linalg.fill ... outs(%alloc)
+/// %subview = memref.subview %alloc [%l] [...] [1]
+/// memref.tensor_store %t, %subview
+/// %0 = bufferization.to_tensor %alloc restrict writable
+///
+/// In addition to rewriting the IR as shown above, this function returns the
+/// newly allocated buffer. Furthermore, the result of the
+/// bufferization.to_tensor op is optionally returned via `replacement`.
+Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp,
+ Attribute memorySpace = {},
+ Value *replacement = nullptr);
+
+/// Materialize a buffer allocation for the given tensor value. E.g.:
+///
+/// %alloc = memref.alloc
+/// memref.tensor_store %value, %alloc
+/// %0 = bufferization.to_tensor %alloc restrict writable
+///
+/// In case `value` is a tensor.pad result, the corresponding overload is used
+/// internally to produce a better bufferization.
+///
+/// In addition to rewriting the IR as shown above, this function returns the
+/// newly allocated buffer. Furthermore, the result of the
+/// bufferization.to_tensor op is optionally returned via `replacement`.
+Value bufferizeToAllocation(RewriterBase &rewriter, Value value,
+ Attribute memorySpace = {},
+ Value *replacement = nullptr);
+
+/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on a
+/// LinalgOp. This transforms looks for LinalgOps that have an unused output
+/// operand and an input operand that is rooted in a tensor::EmptyOp. The
+/// tensor::EmptyOp uses are replaced with the output operand and the two
+/// operands of the LinalgOp are swapped.
+///
+/// Example:
+/// %0 = tensor.empty()
+/// %1 = linalg.matmul ins(...) outs(%0)
+/// %2 = linalg.generic ins(%1) outs(%dest) {
+/// ^bb0(%in: f32, %out: f32):
+/// // out not used
+/// }
+///
+/// The IR is transformed as follows:
+/// %0 = tensor.empty()
+/// %1 = linalg.matmul ins(...) outs(%dest)
+/// %2 = linalg.generic ins(%0) outs(%1) {
+/// ^bb0(%in: f32, %out: f32):
+/// // Use %out instead of %in
+/// }
+///
+/// The "ins" operand has no uses inside the body of the LinalgOp and can be
+/// folded away with existing cleanup patterns. Afterwards, the tensor::EmptyOp
+/// can also fold away.
+LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(
+ RewriterBase &rewriter, Operation *op,
+ bufferization::OneShotAnalysisState &state);
+
+//===----------------------------------------------------------------------===//
// Structs that configure the behavior of various transformations.
//===----------------------------------------------------------------------===//
@@ -308,42 +381,6 @@ LogicalResult vectorizeOpPrecondition(Operation *op,
using LinalgLoops = SmallVector<Operation *, 4>;
-/// Materialize a buffer allocation for the given tensor.pad op and lower the
-/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.:
-///
-/// %0 = tensor.pad low[%l] high[%h] %t ...
-///
-/// is lowered to:
-///
-/// %alloc = memref.alloc
-/// linalg.fill ... outs(%alloc)
-/// %subview = memref.subview %alloc [%l] [...] [1]
-/// memref.tensor_store %t, %subview
-/// %0 = bufferization.to_tensor %alloc restrict writable
-///
-/// In addition to rewriting the IR as shown above, this function returns the
-/// newly allocated buffer. Furthermore, the result of the
-/// bufferization.to_tensor op is optionally returned via `replacement`.
-Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp,
- Attribute memorySpace = {},
- Value *replacement = nullptr);
-
-/// Materialize a buffer allocation for the given tensor value. E.g.:
-///
-/// %alloc = memref.alloc
-/// memref.tensor_store %value, %alloc
-/// %0 = bufferization.to_tensor %alloc restrict writable
-///
-/// In case `value` is a tensor.pad result, the corresponding overload is used
-/// internally to produce a better bufferization.
-///
-/// In addition to rewriting the IR as shown above, this function returns the
-/// newly allocated buffer. Furthermore, the result of the
-/// bufferization.to_tensor op is optionally returned via `replacement`.
-Value bufferizeToAllocation(RewriterBase &rewriter, Value value,
- Attribute memorySpace = {},
- Value *replacement = nullptr);
-
/// Fuse two `linalg.generic` operations that have a producer-consumer
/// relationship captured through `fusedOperand`. The method expects
/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
index ec1631a..1298636 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
LINK_LIBS PUBLIC
MLIRAffineDialect
MLIRArithDialect
+ MLIRBufferizationTransforms
MLIRFuncDialect
MLIRIR
MLIRLinalgDialect
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d702e6d..2f38347 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
@@ -234,6 +235,37 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
#undef DOWNSCALE
return emitDefaultSilenceableFailure(target);
}
+
+//===----------------------------------------------------------------------===//
+// EliminateLinalgOpAnchoredEmptyTensorsOp
+//===----------------------------------------------------------------------===//
+
+void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTarget(), effects);
+ modifiesPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
+ transform::TransformRewriter &rewriter, TransformResults &transformResults,
+ TransformState &state) {
+ bufferization::OneShotBufferizationOptions options;
+ options.allowReturnAllocs = true;
+
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ bufferization::OneShotAnalysisState state(target, options);
+ if (failed(analyzeOp(target, state)))
+ return mlir::emitSilenceableFailure(target->getLoc())
+ << "failed to analyze op";
+ if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
+ rewriter, target, state)))
+ return mlir::emitSilenceableFailure(target->getLoc())
+ << "failed to eliminate LinalgOp anchored tensor.empty ops";
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// FuseOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 82787a3..5ae9b7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
DropUnitDims.cpp
ElementwiseOpFusion.cpp
ElementwiseToLinalg.cpp
+ EliminateEmptyTensors.cpp
EraseUnusedOperandsAndResults.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
new file mode 100644
index 0000000..4b75406
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -0,0 +1,107 @@
+//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::linalg;
+
+/// Get an output operand that matches the given input operand and can be used
+/// to eliminate a tensor.empty op.
+static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
+ for (OpOperand *operand : op.getDpsInitOperands()) {
+ // Operand must be unused.
+ if (op.payloadUsesValueFromOperand(operand))
+ continue;
+ // Types must match.
+ if (operand->get().getType() != in->get().getType())
+ continue;
+ // Indexing maps must match.
+ if (op.getMatchingIndexingMap(operand) != op.getMatchingIndexingMap(in))
+ continue;
+ return operand;
+ }
+ return nullptr;
+}
+
+LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
+ OpBuilder::InsertionGuard g(rewriter);
+ DominanceInfo domInfo;
+
+ op->walk([&](LinalgOp op) {
+ // Only ops with all "parallel" iterator types are supported.
+ if (op.getNumParallelLoops() != op.getNumLoops())
+ return WalkResult::skip();
+
+ for (OpOperand *in : op.getDpsInputOperands()) {
+ // Skip non-tensor operands.
+ if (!in->get().getType().isa<RankedTensorType>())
+ continue;
+
+ // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
+ // equivalent tensors. I.e., stop when there are ops such as extract_slice
+ // on the path.
+ TraversalConfig config;
+ config.followEquivalentOnly = true;
+ config.alwaysIncludeLeaves = false;
+ SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
+ in->get(), /*condition=*/
+ [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+ config);
+ if (emptyTensors.empty())
+ continue;
+
+ // Find matching out operand.
+ OpOperand *out = getUnusedOutOperand(op, in);
+ if (!out)
+ continue;
+
+ // Check if this transform would violate dominance.
+ if (!llvm::all_of(emptyTensors, [&](Value v) {
+ return domInfo.properlyDominates(out->get(), v.getDefiningOp());
+ }))
+ continue;
+
+ // Replace all uses of the tensor.empty, but do not delete it yet. It will
+ // fold away later (to not invalidate DominanceInfo).
+ for (Value v : emptyTensors) {
+ assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
+ rewriter.replaceAllUsesWith(v, out->get());
+ }
+
+ // Turn the "in" into an "out".
+ rewriter.updateRootInPlace(op, [&]() {
+ out->set(in->get());
+ // The original "in" could be removed entirely here (because it will no
+ // longer have any uses in the payload), but we delegate this to
+ // existing cleanup patterns that remove unused operands.
+ in->set(emptyTensors.front());
+ BlockArgument outArg = op.getMatchingBlockArgument(out);
+ assert(outArg.getUses().empty() && "expected that out has no uses");
+ BlockArgument inArg = op.getMatchingBlockArgument(in);
+ rewriter.replaceAllUsesWith(inArg, outArg);
+ assert(!op.payloadUsesValueFromOperand(in) &&
+ "expected that the in operand is now unused");
+ });
+
+ state.resetCache();
+ }
+
+ return WalkResult::advance();
+ });
+ return success();
+}
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
new file mode 100644
index 0000000..939eea3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @eliminate_tensor_empty(
+// CHECK-SAME: %[[arg0:.*]]: tensor<50x91xf32>,
+// CHECK-NOT: tensor.empty
+// CHECK: %[[filled:.*]] = linalg.fill {{.*}} outs(%[[arg0]]
+// CHECK: %[[matmul:.*]] = linalg.matmul {{.*}} outs(%[[filled]]
+// CHECK: %[[generic:.*]] = linalg.generic {{.*}} outs(%[[matmul]]
+// CHECK: return %[[generic]]
+func.func @eliminate_tensor_empty(
+ %arg0: tensor<50x91xf32>, %arg1: tensor<91xf32>, %arg2: tensor<50x1280xf32>,
+ %arg3: tensor<1280x91xf32>) -> tensor<50x91xf32>
+{
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.empty() : tensor<50x91xf32>
+ %1 = linalg.fill ins(%cst : f32)
+ outs(%0 : tensor<50x91xf32>) -> tensor<50x91xf32>
+ %2 = linalg.matmul
+ ins(%arg2, %arg3 : tensor<50x1280xf32>, tensor<1280x91xf32>)
+ outs(%1 : tensor<50x91xf32>) -> tensor<50x91xf32>
+ %3 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg1, %2 : tensor<91xf32>, tensor<50x91xf32>)
+ outs(%arg0 : tensor<50x91xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %16 = arith.addf %in, %in_0 : f32
+ linalg.yield %16 : f32
+ } -> tensor<50x91xf32>
+ return %3 : tensor<50x91xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.linalg.erase_unnecessary_inputs
+ } : !transform.any_op
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e303df1..463b326 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9306,6 +9306,7 @@ cc_library(
":Analysis",
":ArithDialect",
":AsmParser",
+ ":BufferizationTransforms",
":DialectUtils",
":FuncDialect",
":GPUDialect",