//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" using namespace mlir; using namespace mlir::bufferization; using namespace mlir::ml_program; namespace mlir { namespace ml_program { namespace { template struct ExternalModelBase : public BufferizableOpInterface::ExternalModel { AliasingValueList getAliasingValues(Operation *, OpOperand &, const AnalysisState &) const { return {}; } BufferRelation bufferRelation(Operation *, OpResult, const AnalysisState &) const { return BufferRelation::Unknown; } }; /// Bufferization of ml_program.global into a memref.global struct GlobalOpInterface : public ExternalModelBase { bool bufferizesToMemoryRead(Operation *, OpOperand &, const AnalysisState &) const { return false; } bool bufferizesToMemoryWrite(Operation *, OpOperand &, const AnalysisState &) const { return false; } bool hasTensorSemantics(Operation *) const { return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &, BufferizationState &state) const { auto globalOp = cast(op); if (!globalOp.getValue().has_value()) return globalOp.emitError("global op must have a value"); bufferization::removeSymbol(globalOp, state); auto tensorType = cast(globalOp.getType()); auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); auto replacement = replaceOpWithNewBufferizedOp( rewriter, globalOp, globalOp.getSymName(), /*sym_visibility=*/globalOp.getSymVisibilityAttr(), /*type=*/cast(memrefType), /*initial_value=*/globalOp.getValue().value(), /*constant=*/!globalOp.getIsMutable(), /*alignment=*/nullptr); bufferization::insertSymbol(replacement, state); return success(); } }; /// Bufferization of ml_program.global_load into a memref.get_global struct GlobalLoadOpInterface : public ExternalModelBase { bool bufferizesToMemoryRead(Operation *, OpOperand &, const AnalysisState &) const { return false; } bool bufferizesToMemoryWrite(Operation *, OpOperand &, const AnalysisState &) const { return false; } bool isWritable(Operation *, Value, const AnalysisState &) const { return false; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &, BufferizationState &state) const { auto globalLoadOp = cast(op); auto tensorType = cast(globalLoadOp.getType()); auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); replaceOpWithNewBufferizedOp( rewriter, globalLoadOp, memrefType, globalLoadOp.getGlobalAttr().getLeafReference()); return success(); } }; /// Bufferization of ml_program.global_store into a memref.get_global and /// memcpy struct GlobalStoreOpInterface : public ExternalModelBase { bool bufferizesToMemoryRead(Operation *, OpOperand &, const AnalysisState &) const { return false; } bool bufferizesToMemoryWrite(Operation *, OpOperand &, const AnalysisState &) const { return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const { auto globalStoreOp = cast(op); auto tensorType = cast(globalStoreOp.getValue().getType()); auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); auto loc = globalStoreOp.getLoc(); auto targetMemref = memref::GetGlobalOp::create( rewriter, loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference()); auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options, state); if (failed(sourceMemref)) { return failure(); } auto memcpy = options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref); if (failed(memcpy)) { return failure(); } rewriter.eraseOp(globalStoreOp); return success(); } }; } // namespace void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) { GlobalOp::attachInterface(*ctx); GlobalLoadOp::attachInterface(*ctx); GlobalStoreOp::attachInterface(*ctx); }); } } // namespace ml_program } // namespace mlir