diff options
Diffstat (limited to 'mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp')
-rw-r--r-- | mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 92 |
1 files changed, 80 insertions, 12 deletions
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 0b7ffa4..6bd0e2d 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -19,10 +19,18 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include <cstdint> using namespace mlir; +static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) { + return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() && + memRefType.getRank() != 0 && + !llvm::is_contained(memRefType.getShape(), 0); +} + namespace { /// Implement the interface to convert MemRef to EmitC. struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { @@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = allocOp.getLoc(); + MemRefType memrefType = allocOp.getType(); + if (!isMemRefTypeLegalForEmitC(memrefType)) { + return rewriter.notifyMatchFailure( + loc, "incompatible memref type for EmitC conversion"); + } + + Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); + Type elementType = memrefType.getElementType(); + IndexType indexType = rewriter.getIndexType(); + emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>( + loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); + + int64_t numElements = 1; + for (int64_t dimSize : memrefType.getShape()) { + numElements *= dimSize; + } + Value numElementsValue = rewriter.create<emitc::ConstantOp>( + loc, indexType, rewriter.getIndexAttr(numElements)); + + Value totalSizeBytes = rewriter.create<emitc::MulOp>( + loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); + + emitc::CallOpaqueOp allocCall; + StringAttr allocFunctionName; + Value alignmentValue; + SmallVector<Value, 2> argsVec; + if (allocOp.getAlignment()) { + allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); + alignmentValue = rewriter.create<emitc::ConstantOp>( + loc, sizeTType, + rewriter.getIntegerAttr(indexType, + allocOp.getAlignment().value_or(0))); + argsVec.push_back(alignmentValue); + } else { + allocFunctionName = rewriter.getStringAttr(mallocFunctionName); + } + + argsVec.push_back(totalSizeBytes); + ValueRange args(argsVec); + + allocCall = rewriter.create<emitc::CallOpaqueOp>( + loc, + emitc::PointerType::get( + emitc::OpaqueType::get(rewriter.getContext(), "void")), + allocFunctionName, args); + + emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); + emitc::CastOp castOp = rewriter.create<emitc::CastOp>( + loc, targetPointerType, allocCall.getResult(0)); + + rewriter.replaceOp(allocOp, castOp); + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; @@ -160,8 +230,8 @@ struct ConvertGetGlobal final if (opTy.getRank() == 0) { emitc::LValueType lvalueType = emitc::LValueType::get(resultTy); - emitc::GetGlobalOp globalLValue = rewriter.create<emitc::GetGlobalOp>( - op.getLoc(), lvalueType, operands.getNameAttr()); + emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create( + rewriter, op.getLoc(), lvalueType, operands.getNameAttr()); emitc::PointerType pointerType = emitc::PointerType::get(resultTy); rewriter.replaceOpWithNewOp<emitc::ApplyOp>( op, pointerType, rewriter.getStringAttr("&"), globalLValue); @@ -191,8 +261,8 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } - auto subscript = rewriter.create<emitc::SubscriptOp>( - op.getLoc(), arrayValue, operands.getIndices()); + auto subscript = emitc::SubscriptOp::create( + rewriter, op.getLoc(), arrayValue, operands.getIndices()); rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript); return success(); @@ -211,8 +281,8 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } - auto subscript = rewriter.create<emitc::SubscriptOp>( - op.getLoc(), arrayValue, operands.getIndices()); + auto subscript = emitc::SubscriptOp::create( + rewriter, op.getLoc(), arrayValue, operands.getIndices()); rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript, operands.getValue()); return success(); @@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { typeConverter.addConversion( [&](MemRefType memRefType) -> std::optional<Type> { - if (!memRefType.hasStaticShape() || - !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || - llvm::is_contained(memRefType.getShape(), 0)) { + if (!isMemRefTypeLegalForEmitC(memRefType)) { return {}; } Type convertedElementType = @@ -242,7 +310,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { if (inputs.size() != 1) return Value(); - return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }; @@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, - ConvertStore>(converter, patterns.getContext()); + patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, + ConvertLoad, ConvertStore>(converter, patterns.getContext()); } |