aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp116
1 files changed, 90 insertions, 26 deletions
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3f0b0ba..dd9b4c2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -42,6 +42,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
@@ -273,32 +274,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
- OperationState &result,
- Value target,
- Attribute memorySpace) {
- SmallVector<Type> resultTypes;
- resultTypes.push_back(b.getType<transform::AnyValueType>());
- resultTypes.push_back(b.getType<transform::AnyOpType>());
- return build(b, result,
- /*resultTypes=*/resultTypes,
- /*target=*/target,
- /*memory_space=*/memorySpace);
-}
-
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
- OperationState &result,
- Value target,
- int64_t memorySpace) {
- SmallVector<Type> resultTypes;
- resultTypes.push_back(b.getType<transform::AnyValueType>());
- resultTypes.push_back(b.getType<transform::AnyOpType>());
- return build(b, result,
- /*resultTypes=*/resultTypes,
- /*target=*/target,
- /*memory_space=*/b.getI64IntegerAttr(memorySpace));
-}
-
namespace {
class NewOpsListener : public RewriterBase::ForwardingListener {
public:
@@ -409,6 +384,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
}
//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the operand may be read from by its owner. This is currently
+/// very conservative and only looks inside linalg operations to prevent
+/// unintentional data loss.
+static bool mayBeRead(OpOperand &operand) {
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+
+ // Be conservative about ops we cannot analyze deeper.
+ if (!linalgOp)
+ return true;
+
+ // Look inside linalg ops.
+ Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
+ return !blockArgument.use_empty();
+}
+
+/// Return true if the value may be read through any of its uses.
+static bool mayBeRead(Value value) {
+ // If the value has a reference semantics, it
+ // may be read through any alias...
+ if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
+ return true;
+ return llvm::any_of(value.getUses(),
+ static_cast<bool (&)(OpOperand &)>(mayBeRead));
+}
+
+DiagnosedSilenceableFailure
+transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Value> promoted;
+ for (Value tensor : state.getPayloadValues(getTensor())) {
+ auto type = dyn_cast<RankedTensorType>(tensor.getType());
+ if (!type) {
+ return emitSilenceableError() << "non-tensor type: " << tensor;
+ }
+
+ Operation *definingOp = tensor.getDefiningOp();
+ if (definingOp)
+ rewriter.setInsertionPointAfter(definingOp);
+ else
+ rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
+
+ // Check this before we emit operations using this value.
+ bool needsMaterialization = mayBeRead(tensor);
+
+ SmallVector<Value> dynamicDims;
+ llvm::SmallPtrSet<Operation *, 4> preservedOps;
+ for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
+ if (!ShapedType::isDynamic(dim))
+ continue;
+ Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
+ auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+ preservedOps.insert(dimOp);
+ dynamicDims.push_back(dimOp);
+ }
+ auto allocation = rewriter.create<bufferization::AllocTensorOp>(
+ tensor.getLoc(), type, dynamicDims);
+ // Set memory space if provided.
+ if (getMemorySpaceAttr())
+ allocation.setMemorySpaceAttr(getMemorySpaceAttr());
+ Value allocated = allocation;
+
+ // Only insert a materialization (typically bufferizes to a copy) when the
+ // value may be read from.
+ if (needsMaterialization) {
+ auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
+ tensor.getLoc(), tensor, allocated);
+ preservedOps.insert(copy);
+ promoted.push_back(copy.getResult());
+ } else {
+ promoted.push_back(allocated);
+ }
+ rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
+ }
+ results.setValues(cast<OpResult>(getPromoted()), promoted);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PromoteTensorOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getTensorMutable(), effects);
+ transform::producesHandle(getOperation()->getOpResults(), effects);
+ transform::modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//