Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r-- | mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 116
1 file changed, 90 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3f0b0ba..dd9b4c2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -42,6 +42,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/LogicalResult.h"
@@ -273,32 +274,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
 
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
-                                               OperationState &result,
-                                               Value target,
-                                               Attribute memorySpace) {
-  SmallVector<Type> resultTypes;
-  resultTypes.push_back(b.getType<transform::AnyValueType>());
-  resultTypes.push_back(b.getType<transform::AnyOpType>());
-  return build(b, result,
-               /*resultTypes=*/resultTypes,
-               /*target=*/target,
-               /*memory_space=*/memorySpace);
-}
-
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
-                                               OperationState &result,
-                                               Value target,
-                                               int64_t memorySpace) {
-  SmallVector<Type> resultTypes;
-  resultTypes.push_back(b.getType<transform::AnyValueType>());
-  resultTypes.push_back(b.getType<transform::AnyOpType>());
-  return build(b, result,
-               /*resultTypes=*/resultTypes,
-               /*target=*/target,
-               /*memory_space=*/b.getI64IntegerAttr(memorySpace));
-}
-
 namespace {
 class NewOpsListener : public RewriterBase::ForwardingListener {
 public:
@@ -409,6 +384,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the operand may be read from by its owner. This is currently
+/// very conservative and only looks inside linalg operations to prevent
+/// unintentional data loss.
+static bool mayBeRead(OpOperand &operand) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+
+  // Be conservative about ops we cannot analyze deeper.
+  if (!linalgOp)
+    return true;
+
+  // Look inside linalg ops.
+  Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
+  return !blockArgument.use_empty();
+}
+
+/// Return true if the value may be read through any of its uses.
+static bool mayBeRead(Value value) {
+  // If the value has reference semantics, it
+  // may be read through any alias...
+  if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
+    return true;
+  return llvm::any_of(value.getUses(),
+                      static_cast<bool (&)(OpOperand &)>(mayBeRead));
+}
+
+DiagnosedSilenceableFailure
+transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  SmallVector<Value> promoted;
+  for (Value tensor : state.getPayloadValues(getTensor())) {
+    auto type = dyn_cast<RankedTensorType>(tensor.getType());
+    if (!type) {
+      return emitSilenceableError() << "non-tensor type: " << tensor;
+    }
+
+    Operation *definingOp = tensor.getDefiningOp();
+    if (definingOp)
+      rewriter.setInsertionPointAfter(definingOp);
+    else
+      rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
+
+    // Check this before we emit operations using this value.
+    bool needsMaterialization = mayBeRead(tensor);
+
+    SmallVector<Value> dynamicDims;
+    llvm::SmallPtrSet<Operation *, 4> preservedOps;
+    for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
+      if (!ShapedType::isDynamic(dim))
+        continue;
+      Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
+      auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+      preservedOps.insert(dimOp);
+      dynamicDims.push_back(dimOp);
+    }
+    auto allocation = rewriter.create<bufferization::AllocTensorOp>(
+        tensor.getLoc(), type, dynamicDims);
+    // Set memory space if provided.
+    if (getMemorySpaceAttr())
+      allocation.setMemorySpaceAttr(getMemorySpaceAttr());
+    Value allocated = allocation;
+
+    // Only insert a materialization (typically bufferizes to a copy) when the
+    // value may be read from.
+    if (needsMaterialization) {
+      auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
+          tensor.getLoc(), tensor, allocated);
+      preservedOps.insert(copy);
+      promoted.push_back(copy.getResult());
+    } else {
+      promoted.push_back(allocated);
+    }
+    rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
+  }
+  results.setValues(cast<OpResult>(getPromoted()), promoted);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PromoteTensorOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTensorMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//
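
For illustration, a rough transform script driving the new op might look as
follows. This sketch is not part of the patch: the mnemonic
transform.structured.promote_tensor and its generic-form spelling are
assumptions inferred from the C++ class name and getMemorySpaceAttr() above
(the tablegen definition is outside this diff); transform.structured.match
and transform.get_operand are pre-existing transform ops.

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(
        %root: !transform.any_op {transform.readonly}) {
      // Find a payload linalg.matmul and take its first operand as a
      // value handle.
      %matmul = transform.structured.match ops{["linalg.matmul"]} in %root
          : (!transform.any_op) -> !transform.any_op
      %lhs = transform.get_operand %matmul[0]
          : (!transform.any_op) -> !transform.any_value
      // Promote the tensor to a fresh allocation; written in generic op
      // form because the custom assembly format is an assumption here.
      %promoted = "transform.structured.promote_tensor"(%lhs)
          {memory_space = 1 : i64}
          : (!transform.any_value) -> !transform.any_value
      transform.yield
    }
  }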
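
On the payload side, the rewrite implemented in apply() amounts to the
following before/after sketch for a tensor with one dynamic dimension that
may be read (the value names and the "test.use" consumer are made up for
illustration):

  // Before: %t is the targeted tensor value.
  //   "test.use"(%t) : (tensor<?x42xf32>) -> ()

  // After: the consumer is redirected to the materialized copy. The
  // tensor.dim below still reads %t on purpose; it is in preservedOps,
  // so replaceAllUsesExcept leaves it alone.
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %t, %c0 : tensor<?x42xf32>
  %alloc = bufferization.alloc_tensor(%d0) {memory_space = 1 : i64}
      : tensor<?x42xf32>
  %copy = bufferization.materialize_in_destination %t in %alloc
      : (tensor<?x42xf32>, tensor<?x42xf32>) -> tensor<?x42xf32>
  "test.use"(%copy) : (tensor<?x42xf32>) -> ()

When mayBeRead() returns false instead (e.g. a linalg output operand whose
matching block argument is never used), the materialize_in_destination copy
is skipped and consumers are wired directly to the alloc_tensor result.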