diff options
author | ian Bearman <ianb@microsoft.com> | 2024-02-08 07:59:37 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-08 16:59:37 +0100 |
commit | 067d2779fcfc62dd429177f350b8cefe49b65b51 (patch) | |
tree | cf9a64fbcef9a8fdde9c298d3957e0078f8e5171 /mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp | |
parent | b14731fe93d0db9a59984783051880795ae0992d (diff) | |
download | llvm-067d2779fcfc62dd429177f350b8cefe49b65b51.zip llvm-067d2779fcfc62dd429177f350b8cefe49b65b51.tar.gz llvm-067d2779fcfc62dd429177f350b8cefe49b65b51.tar.bz2 |
[MLIR] Setting MemorySpace During Bufferization (#78484)
Collection of changes with the goal of being able to convert `encoding`
to `memorySpace` during bufferization
- new API for encoder to allow implementation to select destination
memory space
- update existing bufferization implementations to support the new
interface
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp')
-rw-r--r-- | mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 6ca9702..8f0f6d1 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -682,11 +682,12 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options, return bufferizableOp.getBufferType(value, options, invocationStack); // Op is not bufferizable. - if (!options.defaultMemorySpace.has_value()) + auto memSpace = + options.defaultMemorySpaceFn(value.getType().cast<TensorType>()); + if (!memSpace.has_value()) return op->emitError("could not infer memory space"); - return getMemRefType(value, options, /*layout=*/{}, - *options.defaultMemorySpace); + return getMemRefType(value, options, /*layout=*/{}, *memSpace); } bool bufferization::hasTensorSemantics(Operation *op) { @@ -936,11 +937,12 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType( // If we do not know the memory space and there is no default memory space, // report a failure. - if (!options.defaultMemorySpace.has_value()) + auto memSpace = + options.defaultMemorySpaceFn(value.getType().cast<TensorType>()); + if (!memSpace.has_value()) return op->emitError("could not infer memory space"); - return getMemRefType(value, options, /*layout=*/{}, - *options.defaultMemorySpace); + return getMemRefType(value, options, /*layout=*/{}, *memSpace); } bool bufferization::detail::defaultIsRepetitiveRegion( |