diff options
author | Jaden Angella <ajaden@google.com> | 2025-07-18 10:15:05 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-18 10:15:05 -0700 |
commit | 7fd91bb6e89be39a130e04058a01d41ae5d600cb (patch) | |
tree | 05d102b198fe908286ada3210fcda144c87bc860 | |
parent | ff225b5d88647448be8bbba54aaac3977a5485b5 (diff) | |
download | llvm-7fd91bb6e89be39a130e04058a01d41ae5d600cb.zip llvm-7fd91bb6e89be39a130e04058a01d41ae5d600cb.tar.gz llvm-7fd91bb6e89be39a130e04058a01d41ae5d600cb.tar.bz2 |
[mlir][EmitC]Expand the MemRefToEmitC pass - Adding scalars (#148055)
This aims to expand the the MemRefToEmitC pass so that it can accept
global scalars.
From:
```
memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
func.func @globals() {
memref.get_global @__constant_xi32 : memref<i32>
}
```
To:
```
emitc.global static const @__constant_xi32 : i32 = -1
emitc.func @globals() {
%0 = get_global @__constant_xi32 : !emitc.lvalue<i32>
%1 = apply "&"(%0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
return
}
```
-rw-r--r-- | mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 36 | ||||
-rw-r--r-- | mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir | 5 |
2 files changed, 38 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index db244d1..0b7ffa4 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -16,7 +16,9 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -77,13 +79,23 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> { } }; +Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { + Type resultTy; + if (opTy.getRank() == 0) { + resultTy = typeConverter->convertType(mlir::getElementTypeOrSelf(opTy)); + } else { + resultTy = typeConverter->convertType(opTy); + } + return resultTy; +} + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - + MemRefType opTy = op.getType(); if (!op.getType().hasStaticShape()) { return rewriter.notifyMatchFailure( op.getLoc(), "cannot transform global with dynamic shape"); @@ -95,7 +107,9 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { op.getLoc(), "global variable with alignment requirement is " "currently not supported"); } - auto resultTy = getTypeConverter()->convertType(op.getType()); + + Type resultTy = convertMemRefType(opTy, getTypeConverter()); + if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert result type"); @@ -114,6 +128,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { bool externSpecifier = !staticSpecifier; Attribute initialValue = operands.getInitialValueAttr(); + if (opTy.getRank() == 0) { + auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue()); + initialValue = elementsAttr.getSplatValue<Attribute>(); + } if (isa_and_present<UnitAttr>(initialValue)) initialValue = {}; @@ -132,11 +150,23 @@ struct ConvertGetGlobal final matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - auto resultTy = getTypeConverter()->convertType(op.getType()); + MemRefType opTy = op.getType(); + Type resultTy = convertMemRefType(opTy, getTypeConverter()); + if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert result type"); } + + if (opTy.getRank() == 0) { + emitc::LValueType lvalueType = emitc::LValueType::get(resultTy); + emitc::GetGlobalOp globalLValue = rewriter.create<emitc::GetGlobalOp>( + op.getLoc(), lvalueType, operands.getNameAttr()); + emitc::PointerType pointerType = emitc::PointerType::get(resultTy); + rewriter.replaceOpWithNewOp<emitc::ApplyOp>( + op, pointerType, rewriter.getStringAttr("&"), globalLValue); + return success(); + } rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy, operands.getNameAttr()); return success(); diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index d37fd1d..2b4eda3 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 { module @globals { memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0> // CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00> + memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1> + // CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1 memref.global @public_global : memref<3x7xf32> // CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32> memref.global @uninitialized_global : memref<3x7xf32> = uninitialized @@ -50,6 +52,9 @@ module @globals { func.func @use_global() { // CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32> %0 = memref.get_global @public_global : memref<3x7xf32> + // CHECK-NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32> + // CHECK-NEXT: emitc.apply "&"(%1) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> + %1 = memref.get_global @__constant_xi32 : memref<i32> return } } |