diff options
author | Matthias Gehre <matthias.gehre@amd.com> | 2024-03-21 14:27:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-21 14:27:37 +0100 |
commit | 0aa6d57e575dd920db81bef7ff509c4d3a9c6891 (patch) | |
tree | 067f0fc831557d550b0f823d830e83e0ff8c0e63 /mlir/lib | |
parent | 538257bf00960f6134a51a17c8477b298ff87c30 (diff) | |
download | llvm-0aa6d57e575dd920db81bef7ff509c4d3a9c6891.zip llvm-0aa6d57e575dd920db81bef7ff509c4d3a9c6891.tar.gz llvm-0aa6d57e575dd920db81bef7ff509c4d3a9c6891.tar.bz2 |
[MLIR] Add initial convert-memref-to-emitc pass (#85389)
This converts `memref.alloca`, `memref.load` & `memref.store` to
`emitc.variable`, `emitc.subscript` and `emitc.assign`.
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Conversion/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt | 18 | ||||
-rw-r--r-- | mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 114 | ||||
-rw-r--r-- | mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp | 55 |
4 files changed, 188 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 8219cf9..41ab704 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -35,6 +35,7 @@ add_subdirectory(MathToFuncs) add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) add_subdirectory(MathToSPIRV) +add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) add_subdirectory(NVGPUToNVVM) diff --git a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt new file mode 100644 index 0000000..8a72e74 --- /dev/null +++ b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_conversion_library(MLIRMemRefToEmitC + MemRefToEmitC.cpp + MemRefToEmitCPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIREmitCDialect + MLIRMemRefDialect + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp new file mode 100644 index 0000000..0e3b646 --- /dev/null +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -0,0 +1,114 @@ +//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert memref ops into emitc ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { +struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocaOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getType().hasStaticShape()) { + return rewriter.notifyMatchFailure( + op.getLoc(), "cannot transform alloca with dynamic shape"); + } + + if (op.getAlignment().value_or(1) > 1) { + // TODO: Allow alignment if it is not more than the natural alignment + // of the C array. + return rewriter.notifyMatchFailure( + op.getLoc(), "cannot transform alloca with alignment requirement"); + } + + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type"); + } + auto noInit = emitc::OpaqueAttr::get(getContext(), ""); + rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit); + return success(); + } +}; + +struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type"); + } + + auto subscript = rewriter.create<emitc::SubscriptOp>( + op.getLoc(), operands.getMemref(), operands.getIndices()); + + auto noInit = emitc::OpaqueAttr::get(getContext(), ""); + auto var = + rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit); + + rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript); + rewriter.replaceOp(op, var); + return success(); + } +}; + +struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::StoreOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + + auto subscript = rewriter.create<emitc::SubscriptOp>( + op.getLoc(), operands.getMemref(), operands.getIndices()); + rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript, + operands.getValue()); + return success(); + } +}; +} // namespace + +void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { + typeConverter.addConversion( + [&](MemRefType memRefType) -> std::optional<Type> { + if (!memRefType.hasStaticShape() || + !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) { + return {}; + } + Type convertedElementType = + typeConverter.convertType(memRefType.getElementType()); + if (!convertedElementType) + return {}; + return emitc::ArrayType::get(memRefType.getShape(), + convertedElementType); + }); +} + +void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, + TypeConverter &converter) { + patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter, + patterns.getContext()); +} diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp new file mode 100644 index 0000000..4e5d191 --- /dev/null +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -0,0 +1,55 @@ +//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert memref ops into emitc ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h" + +#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct ConvertMemRefToEmitCPass + : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { + void runOnOperation() override { + TypeConverter converter; + + // Fallback for other types. + converter.addConversion([](Type type) -> std::optional<Type> { + if (isa<MemRefType>(type)) + return {}; + return type; + }); + + populateMemRefToEmitCTypeConversion(converter); + + RewritePatternSet patterns(&getContext()); + populateMemRefToEmitCConversionPatterns(patterns, converter); + + ConversionTarget target(getContext()); + target.addIllegalDialect<memref::MemRefDialect>(); + target.addLegalDialect<emitc::EmitCDialect>(); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace |