diff options
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp')
-rw-r--r-- | mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 6ca9702..8f0f6d1 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -682,11 +682,12 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options, return bufferizableOp.getBufferType(value, options, invocationStack); // Op is not bufferizable. - if (!options.defaultMemorySpace.has_value()) + auto memSpace = + options.defaultMemorySpaceFn(value.getType().cast<TensorType>()); + if (!memSpace.has_value()) return op->emitError("could not infer memory space"); - return getMemRefType(value, options, /*layout=*/{}, - *options.defaultMemorySpace); + return getMemRefType(value, options, /*layout=*/{}, *memSpace); } bool bufferization::hasTensorSemantics(Operation *op) { @@ -936,11 +937,12 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType( // If we do not know the memory space and there is no default memory space, // report a failure. - if (!options.defaultMemorySpace.has_value()) + auto memSpace = + options.defaultMemorySpaceFn(value.getType().cast<TensorType>()); + if (!memSpace.has_value()) return op->emitError("could not infer memory space"); - return getMemRefType(value, options, /*layout=*/{}, - *options.defaultMemorySpace); + return getMemRefType(value, options, /*layout=*/{}, *memSpace); } bool bufferization::detail::defaultIsRepetitiveRegion( |