aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaden Angella <ajaden@google.com>2025-07-18 10:15:05 -0700
committerGitHub <noreply@github.com>2025-07-18 10:15:05 -0700
commit7fd91bb6e89be39a130e04058a01d41ae5d600cb (patch)
tree05d102b198fe908286ada3210fcda144c87bc860
parentff225b5d88647448be8bbba54aaac3977a5485b5 (diff)
downloadllvm-7fd91bb6e89be39a130e04058a01d41ae5d600cb.zip
llvm-7fd91bb6e89be39a130e04058a01d41ae5d600cb.tar.gz
llvm-7fd91bb6e89be39a130e04058a01d41ae5d600cb.tar.bz2
[mlir][EmitC]Expand the MemRefToEmitC pass - Adding scalars (#148055)
This aims to expand the the MemRefToEmitC pass so that it can accept global scalars. From: ``` memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1> func.func @globals() { memref.get_global @__constant_xi32 : memref<i32> } ``` To: ``` emitc.global static const @__constant_xi32 : i32 = -1 emitc.func @globals() { %0 = get_global @__constant_xi32 : !emitc.lvalue<i32> %1 = apply "&"(%0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> return } ```
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp36
-rw-r--r--mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir5
2 files changed, 38 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index db244d1..0b7ffa4 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,7 +16,9 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -77,13 +79,23 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
};
+Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
+ Type resultTy;
+ if (opTy.getRank() == 0) {
+ resultTy = typeConverter->convertType(mlir::getElementTypeOrSelf(opTy));
+ } else {
+ resultTy = typeConverter->convertType(opTy);
+ }
+ return resultTy;
+}
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
-
+ MemRefType opTy = op.getType();
if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform global with dynamic shape");
@@ -95,7 +107,9 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
op.getLoc(), "global variable with alignment requirement is "
"currently not supported");
}
- auto resultTy = getTypeConverter()->convertType(op.getType());
+
+ Type resultTy = convertMemRefType(opTy, getTypeConverter());
+
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
@@ -114,6 +128,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
bool externSpecifier = !staticSpecifier;
Attribute initialValue = operands.getInitialValueAttr();
+ if (opTy.getRank() == 0) {
+ auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
+ initialValue = elementsAttr.getSplatValue<Attribute>();
+ }
if (isa_and_present<UnitAttr>(initialValue))
initialValue = {};
@@ -132,11 +150,23 @@ struct ConvertGetGlobal final
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
- auto resultTy = getTypeConverter()->convertType(op.getType());
+ MemRefType opTy = op.getType();
+ Type resultTy = convertMemRefType(opTy, getTypeConverter());
+
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}
+
+ if (opTy.getRank() == 0) {
+ emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
+ emitc::GetGlobalOp globalLValue = rewriter.create<emitc::GetGlobalOp>(
+ op.getLoc(), lvalueType, operands.getNameAttr());
+ emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
+ rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
+ op, pointerType, rewriter.getStringAttr("&"), globalLValue);
+ return success();
+ }
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
operands.getNameAttr());
return success();
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index d37fd1d..2b4eda3 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
module @globals {
memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
// CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
+ memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
+ // CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1
memref.global @public_global : memref<3x7xf32>
// CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32>
memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
@@ -50,6 +52,9 @@ module @globals {
func.func @use_global() {
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
%0 = memref.get_global @public_global : memref<3x7xf32>
+ // CHECK-NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
+ // CHECK-NEXT: emitc.apply "&"(%1) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+ %1 = memref.get_global @__constant_xi32 : memref<i32>
return
}
}