8 files changed, 50 insertions, 32 deletions
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 226a2fb..d8cfeee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -257,6 +257,9 @@ struct BufferizationOptions {
   /// Parameters: Value, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
       Value, Attribute memorySpace, const BufferizationOptions &)>;
+  // Produce a MemorySpace attribute from a tensor type
+  using DefaultMemorySpaceFn =
+      std::function<std::optional<Attribute>(TensorType t)>;
 
   BufferizationOptions();
 
@@ -296,11 +299,6 @@
   /// bufferized or not.
   bool bufferizeFunctionBoundaries = false;
 
-  /// The default memory space that should be used when it cannot be inferred
-  /// from the context. If case of std::nullopt, bufferization fails when the
-  /// memory space cannot be inferred at any point.
-  std::optional<Attribute> defaultMemorySpace = Attribute();
-
   /// Certain ops have aliasing OpOperand/OpResult invariants (e.g., scf.for).
   /// If this flag is set to `false`, those invariants are no longer enforced
   /// with buffer copies.
@@ -351,6 +349,13 @@
   /// used.
   UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
 
+  // Used during type conversion to determine the memory space for the memref
+  // based on the original tensor type if the memory space cannot be inferred.
+  // Returning std::nullopt will cause bufferization to fail (useful to
+  // indicate failure to determine memory space for a tensor type).
+  DefaultMemorySpaceFn defaultMemorySpaceFn =
+      [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+
   /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
   /// Should be used only with `testAnalysisOnly = true`.
   unsigned analysisFuzzerSeed = 0;
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f69b255..d7492c9 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,17 +26,18 @@ struct ConstantOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
+    auto type = constantOp.getType().dyn_cast<RankedTensorType>();
+
+    // Only ranked tensors are supported.
+    if (!type)
+      return failure();
 
     Attribute memorySpace;
-    if (options.defaultMemorySpace.has_value())
-      memorySpace = *options.defaultMemorySpace;
+    if (auto memSpace = options.defaultMemorySpaceFn(type))
+      memorySpace = *memSpace;
     else
       return constantOp->emitError("could not infer memory space");
 
-    // Only ranked tensors are supported.
-    if (!isa<RankedTensorType>(constantOp.getType()))
-      return failure();
-
     // Only constants inside a module are supported.
     auto moduleOp = constantOp->getParentOfType<ModuleOp>();
     if (!moduleOp)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 6ca9702..8f0f6d1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -682,11 +682,12 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
     return bufferizableOp.getBufferType(value, options, invocationStack);
 
   // Op is not bufferizable.
-  if (!options.defaultMemorySpace.has_value())
+  auto memSpace =
+      options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
+  if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{},
-                       *options.defaultMemorySpace);
+  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::hasTensorSemantics(Operation *op) {
@@ -936,11 +937,12 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
 
   // If we do not know the memory space and there is no default memory space,
   // report a failure.
-  if (!options.defaultMemorySpace.has_value())
+  auto memSpace =
+      options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
+  if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{},
-                       *options.defaultMemorySpace);
+  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index eb4a96f..34a0c59 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -234,8 +234,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     if (failed(copyBufferType))
      return failure();
     memorySpace = copyBufferType->getMemorySpace();
-  } else if (options.defaultMemorySpace.has_value()) {
-    memorySpace = *options.defaultMemorySpace;
+  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
+    memorySpace = *ms;
   } else {
     return getOperation()->emitError("could not infer memory space");
   }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index dc94b72..208cbda 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -210,8 +210,12 @@ struct OneShotBufferizePass
       opt.dumpAliasSets = dumpAliasSets;
       opt.setFunctionBoundaryTypeConversion(
           parseLayoutMapOption(functionBoundaryTypeConversion));
-      if (mustInferMemorySpace)
-        opt.defaultMemorySpace = std::nullopt;
+      if (mustInferMemorySpace) {
+        opt.defaultMemorySpaceFn =
+            [](TensorType t) -> std::optional<Attribute> {
+          return std::nullopt;
+        };
+      }
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 07cd1f9..4cdbbf3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -66,7 +66,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
   assert(tensorType && "expected TensorType");
 
   BaseMemRefType memrefType = options.functionArgTypeConverterFn(
-      tensorType, *options.defaultMemorySpace, funcOp, options);
+      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
 
   auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
       index, BufferizationDialect::kBufferLayoutAttrName);
@@ -443,7 +443,8 @@ struct FuncOpInterface
       // Note: If `inferFunctionResultLayout = true`, cast are later folded
       // away.
       BaseMemRefType resultType = options.functionArgTypeConverterFn(
-          tensorType, *options.defaultMemorySpace, funcOp, options);
+          tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+          options);
       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
           loc, resultType, returnVal);
       returnValues.push_back(toMemrefOp);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 678b7c0..957f631 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -473,14 +473,14 @@ struct FromElementsOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
+    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
 
     // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpace != Attribute())
+    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
       return op->emitError("memory space not implemented yet");
 
     // Allocate a buffer for the result.
     Location loc = op->getLoc();
-    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
     auto shape = tensorType.getShape();
     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
@@ -588,8 +588,10 @@ struct GenerateOpInterface
                           const BufferizationOptions &options) const {
     auto generateOp = cast<tensor::GenerateOp>(op);
 
+    auto type = generateOp.getResult().getType();
+
     // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpace != Attribute())
+    if (options.defaultMemorySpaceFn(type) != Attribute())
       return op->emitError("memory space not implemented yet");
 
     // Allocate memory.
@@ -1007,10 +1009,6 @@ struct SplatOpInterface
     OpBuilder::InsertionGuard g(rewriter);
     auto splatOp = cast<tensor::SplatOp>(op);
 
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpace != Attribute())
-      return op->emitError("memory space not implemented yet");
-
     // Allocate memory.
     Location loc = op->getLoc();
     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
@@ -1021,6 +1019,11 @@ struct SplatOpInterface
     // Create linalg::MapOp.
     auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+      return op->emitError("memory space not implemented yet");
+
     auto linalgOp =
         rewriter.create<linalg::MapOp>(loc, tensorType,
                                        /*inputs=*/ValueRange(),
                                        /*init=*/*tensorAlloc);
diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
index fedfbe3..2991a3c 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
@@ -44,8 +44,10 @@ struct TestTensorCopyInsertionPass
     bufferization::OneShotBufferizationOptions options;
     options.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
     options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
-    if (mustInferMemorySpace)
-      options.defaultMemorySpace = std::nullopt;
+    if (mustInferMemorySpace) {
+      options.defaultMemorySpaceFn =
+          [](TensorType t) -> std::optional<Attribute> { return std::nullopt; };
+    }
     if (failed(bufferization::insertTensorCopies(getOperation(), options)))
       signalPassFailure();
   }
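
For context, a minimal sketch of how a client of the new hook might look once this change lands. Only BufferizationOptions::defaultMemorySpaceFn is taken from the patch; the helper name makeOptions and the policy of reading the memory space from a ranked tensor's encoding attribute are illustrative assumptions, not part of the change.

// Sketch only: derive the buffer memory space from the tensor type.
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/IR/BuiltinTypes.h"
#include <optional>

using namespace mlir;

// Hypothetical helper, not part of the patch.
static bufferization::OneShotBufferizationOptions makeOptions() {
  bufferization::OneShotBufferizationOptions options;
  options.defaultMemorySpaceFn =
      [](TensorType t) -> std::optional<Attribute> {
    // Assumed convention: reuse a ranked tensor's encoding attribute as the
    // memory space when one is present.
    if (auto rankedType = dyn_cast<RankedTensorType>(t))
      if (Attribute encoding = rankedType.getEncoding())
        return encoding;
    // No way to choose a space for unranked tensors: returning std::nullopt
    // makes bufferization emit "could not infer memory space" and fail.
    if (isa<UnrankedTensorType>(t))
      return std::nullopt;
    // Fall back to the default (empty) memory space, matching the behavior
    // of the removed `defaultMemorySpace = Attribute()` field.
    return Attribute();
  };
  return options;
}

As the Bufferize.cpp and TestTensorCopyInsertion.cpp hunks show, the existing must-infer-memory-space option is now expressed the same way: by installing a function that always returns std::nullopt, reproducing the old defaultMemorySpace = std::nullopt failure behavior.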