author     Diego Caballero <diego.caballero@intel.com>  2020-01-31 15:18:58 -0800
committer  Diego Caballero <diego.caballero@intel.com>  2020-01-31 15:19:38 -0800
commit     e5aaf30cf1ab03417c38a3df2482f76e673511a0 (patch)
tree       0d8ba51d7e1957b724d6b793b592ebbd6563a0eb
parent     27684ae66d5545f211c0ac4393d0ba2bf3b5b47c (diff)
[mlir] Introduce bare ptr calling convention for MemRefs in LLVM dialect
Summary:
This patch introduces an alternative calling convention for
MemRef function arguments in LLVM dialect. It converts MemRef
function arguments to LLVM bare pointers to the MemRef element
type instead of creating a MemRef descriptor. Bare pointers are
then promoted to MemRef descriptors at the beginning of the
function. This calling convention is only enabled with a flag.
Reviewers: ftynse, bondhugula, nicolasvasilache, rriddle, mehdi_amini
Reviewed By: ftynse, rriddle, mehdi_amini
Subscribers: Joonsoo, flaub, merge_guards_bot, jholewinski, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, csigg, arpith-jacob, mgester, lucyrfox, herhut, aartbik, liufengdb, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72802
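For illustration, a sketch distilled from the tests added below (no functionality beyond the patch): under the default convention a MemRef argument lowers to a pointer to a descriptor struct, while the bare pointer convention passes only the element pointer.

  // Input.
  func @callee(%arg0: memref<2xf32>)

  // Default convention: pointer to the MemRef descriptor.
  llvm.func @callee(%arg0: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">)

  // Bare pointer convention (flag enabled): pointer to the element type.
  llvm.func @callee(%arg0: !llvm<"float*">)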
-rw-r--r--  mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h     |  35
-rw-r--r--  mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h |  17
-rw-r--r--  mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp            | 239
-rw-r--r--  mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir (renamed from mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir) | 173
-rw-r--r--  mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir      | 322
-rw-r--r--  mlir/test/mlir-cpu-runner/bare_ptr_call_conv.mlir                       | 183
6 files changed, 754 insertions(+), 215 deletions(-)
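As exercised by the RUN lines in the new tests, the convention is opted into through the new pass flag:

  mlir-opt -convert-std-to-llvm -convert-std-to-llvm-use-bare-ptr-memref-call-conv=1 input.mlir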
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 931466c..0b8ac9c 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -26,6 +26,7 @@ class Type;
 
 namespace mlir {
 
+class LLVMTypeConverter;
 class UnrankedMemRefType;
 
 namespace LLVM {
@@ -33,13 +34,43 @@ class LLVMDialect;
 class LLVMType;
 } // namespace LLVM
 
+/// Set of callbacks that allows the customization of LLVMTypeConverter.
+struct LLVMTypeConverterCustomization {
+  using CustomCallback =
+      std::function<LLVM::LLVMType(LLVMTypeConverter &, Type)>;
+
+  /// Customize the type conversion of function arguments.
+  CustomCallback funcArgConverter;
+
+  /// Initialize customization to default callbacks.
+  LLVMTypeConverterCustomization();
+};
+
+/// Callback to convert function argument types. It converts a MemRef function
+/// argument to a struct that contains the descriptor information. Converted
+/// types are promoted to a pointer to the converted type.
+LLVM::LLVMType structFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                          Type type);
+
+/// Callback to convert function argument types. It converts MemRef function
+/// arguments to bare pointers to the MemRef element type. Converted types are
+/// not promoted to pointers.
+LLVM::LLVMType barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                           Type type);
+
 /// Conversion from types in the Standard dialect to the LLVM IR dialect.
 class LLVMTypeConverter : public TypeConverter {
 public:
   using TypeConverter::convertType;
 
+  /// Create an LLVMTypeConverter using the default
+  /// LLVMTypeConverterCustomization.
   LLVMTypeConverter(MLIRContext *ctx);
 
+  /// Create an LLVMTypeConverter using 'custom' customizations.
+  LLVMTypeConverter(MLIRContext *ctx,
+                    const LLVMTypeConverterCustomization &custom);
+
   /// Convert types to LLVM IR. This calls `convertAdditionalType` to convert
   /// non-standard or non-builtin types.
   Type convertType(Type t) override;
@@ -121,8 +152,8 @@ private:
   // pointer as defined by the data layout of the module.
   LLVM::LLVMType getIndexType();
 
-  // Extract an LLVM IR dialect type.
-  LLVM::LLVMType unwrap(Type type);
+  /// Callbacks for customizing the type conversion.
+  LLVMTypeConverterCustomization customizations;
 };
 
 /// Helper class to produce LLVM dialect operations extracting or inserting
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index be19a89..f9e68c7 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -44,8 +44,8 @@ using LLVMTypeConverterMaker =
     std::function<std::unique_ptr<LLVMTypeConverter>(MLIRContext *)>;
 
 /// Collect a set of patterns to convert memory-related operations from the
-/// Standard dialect to the LLVM dialect, excluding the memory-related
-/// operations.
+/// Standard dialect to the LLVM dialect, excluding non-memory-related
+/// operations and FuncOp.
 void populateStdToLLVMMemoryConversionPatters(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
 
@@ -54,10 +54,21 @@ void populateStdToLLVMMemoryConversionPatters(
 void populateStdToLLVMNonMemoryConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
 
-/// Collect a set of patterns to convert from the Standard dialect to LLVM.
+/// Collect the default pattern to convert a FuncOp to the LLVM dialect.
+void populateStdToLLVMDefaultFuncOpConversionPattern(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Collect a set of default patterns to convert from the Standard dialect to
+/// LLVM.
 void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns);
 
+/// Collect a set of patterns to convert from the Standard dialect to
+/// LLVM using the bare pointer calling convention for MemRef function
+/// arguments.
+void populateStdToLLVMBarePtrConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
 /// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
 /// By default stdlib malloc/free are used for allocating MemRef payloads.
 /// Specifying `useAlloca=true` emits stack allocations instead. In the future
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 4fcd64f..e13b2a4 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -43,19 +43,14 @@ static llvm::cl::opt<bool>
                  llvm::cl::desc("Replace emission of malloc/free by alloca"),
                  llvm::cl::init(false));
 
-LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
-    : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
-  assert(llvmDialect && "LLVM IR dialect is not registered");
-  module = &llvmDialect->getLLVMModule();
-}
-
-// Get the LLVM context.
-llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
-  return module->getContext();
-}
+static llvm::cl::opt<bool> clUseBarePtrCallConv(
+    PASS_NAME "-use-bare-ptr-memref-call-conv",
+    llvm::cl::desc("Replace FuncOp's MemRef arguments with "
+                   "bare pointers to the MemRef element types"),
+    llvm::cl::init(false));
 
 // Extract an LLVM IR type from the LLVM IR dialect type.
-LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
+static LLVM::LLVMType unwrap(Type type) {
   if (!type)
     return nullptr;
   auto *mlirContext = type.getContext();
@@ -66,6 +61,70 @@ LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
   return wrappedLLVMType;
 }
 
+/// Initialize customization to default callbacks.
+LLVMTypeConverterCustomization::LLVMTypeConverterCustomization() {
+  funcArgConverter = structFuncArgTypeConverter;
+}
+
+// Callback to convert function argument types. It converts MemRef function
+// arguments to a struct that contains the descriptor information. Converted
+// types are promoted to a pointer to the converted type.
+LLVM::LLVMType mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                                Type type) {
+  auto converted =
+      converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+  if (!converted)
+    return {};
+  if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>())
+    converted = converted.getPointerTo();
+  return converted;
+}
+
+/// Convert a MemRef type to a bare pointer to the MemRef element type.
+static Type convertMemRefTypeToBarePtr(LLVMTypeConverter &converter,
+                                       MemRefType type) {
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(type, strides, offset)))
+    return {};
+
+  LLVM::LLVMType elementType =
+      unwrap(converter.convertType(type.getElementType()));
+  if (!elementType)
+    return {};
+  return elementType.getPointerTo(type.getMemorySpace());
+}
+
+/// Callback to convert function argument types. It converts MemRef function
+/// arguments to bare pointers to the MemRef element type. Converted types are
+/// not promoted to pointers.
+LLVM::LLVMType mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
+                                                 Type type) {
+  // TODO: Add support for unranked memref.
+  if (auto memrefTy = type.dyn_cast<MemRefType>())
+    return convertMemRefTypeToBarePtr(converter, memrefTy)
+        .dyn_cast_or_null<LLVM::LLVMType>();
+  return converter.convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+}
+
+/// Create an LLVMTypeConverter using default LLVMTypeConverterCustomization.
+LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
+    : LLVMTypeConverter(ctx, LLVMTypeConverterCustomization()) {}
+
+/// Create an LLVMTypeConverter using 'custom' customizations.
+LLVMTypeConverter::LLVMTypeConverter(
+    MLIRContext *ctx, const LLVMTypeConverterCustomization &customs)
+    : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
+      customizations(customs) {
+  assert(llvmDialect && "LLVM IR dialect is not registered");
+  module = &llvmDialect->getLLVMModule();
+}
+
+/// Get the LLVM context.
+llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
+  return module->getContext();
+}
+
 LLVM::LLVMType LLVMTypeConverter::getIndexType() {
   return LLVM::LLVMType::getIntNTy(
       llvmDialect, module->getDataLayout().getPointerSizeInBits());
@@ -116,11 +175,10 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
   // Convert argument types one by one and check for errors.
   for (auto &en : llvm::enumerate(type.getInputs())) {
     Type type = en.value();
-    auto converted = convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
+    auto converted = customizations.funcArgConverter(*this, type)
+                         .dyn_cast_or_null<LLVM::LLVMType>();
     if (!converted)
       return {};
-    if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>())
-      converted = converted.getPointerTo();
     result.addInputs(en.index(), converted);
   }
 
@@ -493,27 +551,29 @@ protected:
   LLVM::LLVMDialect &dialect;
 };
 
-struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
-  using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto funcOp = cast<FuncOp>(op);
-    FunctionType type = funcOp.getType();
-
-    // Store the positions of memref-typed arguments so that we can emit loads
-    // from them to follow the calling convention.
-    SmallVector<unsigned, 4> promotedArgIndices;
-    promotedArgIndices.reserve(type.getNumInputs());
+struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
+protected:
+  using LLVMLegalizationPattern::LLVMLegalizationPattern;
+  using UnsignedTypePair = std::pair<unsigned, Type>;
+
+  // Gather the positions and types of memref-typed arguments in a given
+  // FunctionType.
+  void getMemRefArgIndicesAndTypes(
+      FunctionType type, SmallVectorImpl<UnsignedTypePair> &argsInfo) const {
+    argsInfo.reserve(type.getNumInputs());
     for (auto en : llvm::enumerate(type.getInputs())) {
       if (en.value().isa<MemRefType>() || en.value().isa<UnrankedMemRefType>())
-        promotedArgIndices.push_back(en.index());
+        argsInfo.push_back({en.index(), en.value()});
     }
+  }
 
-    // Convert the original function arguments. Struct arguments are promoted to
-    // pointer to struct arguments to allow calling external functions with
-    // various ABIs (e.g. compiled from C/C++ on platform X).
+  // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
+  // to this legalization pattern.
+  LLVM::LLVMFuncOp
+  convertFuncOpToLLVMFuncOp(FuncOp funcOp,
+                            ConversionPatternRewriter &rewriter) const {
+    // Convert the original function arguments. They are converted using the
+    // LLVMTypeConverter provided to this legalization pattern.
    auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
    auto llvmType = lowering.convertFunctionSignature(
@@ -532,20 +592,41 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
     // Create an LLVM function, use external linkage by default until MLIR
     // functions have linkage.
     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
-        op->getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External,
+        funcOp.getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External,
         attributes);
     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                 newFuncOp.end());
-
     // Tell the rewriter to convert the region signature.
     rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
+    return newFuncOp;
+  }
+};
+
+/// FuncOp legalization pattern that converts MemRef arguments to pointers to
+/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
+/// information.
+struct FuncOpConversion : public FuncOpConversionBase {
+  using FuncOpConversionBase::FuncOpConversionBase;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto funcOp = cast<FuncOp>(op);
+
+    // Store the positions of memref-typed arguments so that we can emit loads
+    // from them to follow the calling convention.
+    SmallVector<UnsignedTypePair, 4> promotedArgsInfo;
+    getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo);
+
+    auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
+
     // Insert loads from memref descriptor pointers in function bodies.
     if (!newFuncOp.getBody().empty()) {
       Block *firstBlock = &newFuncOp.getBody().front();
       rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
-      for (unsigned idx : promotedArgIndices) {
-        BlockArgument arg = firstBlock->getArgument(idx);
+      for (const auto &argInfo : promotedArgsInfo) {
+        BlockArgument arg = firstBlock->getArgument(argInfo.first);
         Value loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
         rewriter.replaceUsesOfBlockArgument(arg, loaded);
       }
@@ -556,6 +637,56 @@
   }
 };
 
+/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
+/// to the MemRef element type. This will impact the calling convention and ABI.
+struct BarePtrFuncOpConversion : public FuncOpConversionBase {
+  using FuncOpConversionBase::FuncOpConversionBase;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto funcOp = cast<FuncOp>(op);
+
+    // Store the positions and type of memref-typed arguments so that we can
+    // promote them to MemRef descriptor structs at the beginning of the
+    // function.
+    SmallVector<UnsignedTypePair, 4> promotedArgsInfo;
+    getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo);
+
+    auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
+    if (newFuncOp.getBody().empty()) {
+      rewriter.eraseOp(op);
+      return matchSuccess();
+    }
+
+    // Promote bare pointers from MemRef arguments to a MemRef descriptor struct
+    // at the beginning of the function so that all the MemRefs in the function
+    // have a uniform representation.
+    Block *firstBlock = &newFuncOp.getBody().front();
+    rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
+    auto funcLoc = funcOp.getLoc();
+    for (const auto &argInfo : promotedArgsInfo) {
+      // TODO: Add support for unranked MemRefs.
+      if (auto memrefType = argInfo.second.dyn_cast<MemRefType>()) {
+        // Replace argument with a placeholder (undef), promote argument to a
+        // MemRef descriptor and replace placeholder with the last instruction
+        // of the MemRef descriptor. The placeholder is needed to avoid
+        // replacing argument uses in the MemRef descriptor instructions.
+        BlockArgument arg = firstBlock->getArgument(argInfo.first);
+        Value placeHolder =
+            rewriter.create<LLVM::UndefOp>(funcLoc, arg.getType());
+        rewriter.replaceUsesOfBlockArgument(arg, placeHolder);
+        auto desc = MemRefDescriptor::fromStaticShape(
+            rewriter, funcLoc, lowering, memrefType, arg);
+        rewriter.replaceOp(placeHolder.getDefiningOp(), {desc});
+      }
+    }
+
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+};
+
 //////////////// Support for Lowering operations on n-D vectors ////////////////
 
 namespace {
 // Helper struct to "unroll" operations on n-D vectors in terms of operations on
@@ -2128,7 +2259,6 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
   // clang-format off
   patterns.insert<
       DimOpLowering,
-      FuncOpConversion,
       LoadOpLowering,
       MemRefCastOpLowering,
       StoreOpLowering,
@@ -2141,8 +2271,26 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
   // clang-format on
 }
 
+void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  patterns.insert<FuncOpConversion>(*converter.getDialect(), converter);
+}
+
 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
   populateStdToLLVMMemoryConversionPatters(converter, patterns);
 }
 
+static void populateStdToLLVMBarePtrFuncOpConversionPattern(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  patterns.insert<BarePtrFuncOpConversion>(*converter.getDialect(), converter);
+}
+
+void mlir::populateStdToLLVMBarePtrConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
+  populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
+  populateStdToLLVMMemoryConversionPatters(converter, patterns);
+}
+
@@ -2209,7 +2357,17 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
 /// Create an instance of LLVMTypeConverter in the given context.
 static std::unique_ptr<LLVMTypeConverter>
 makeStandardToLLVMTypeConverter(MLIRContext *context) {
-  return std::make_unique<LLVMTypeConverter>(context);
+  LLVMTypeConverterCustomization customs;
+  customs.funcArgConverter = structFuncArgTypeConverter;
+  return std::make_unique<LLVMTypeConverter>(context, customs);
+}
+
+/// Create an instance of BarePtrTypeConverter in the given context.
+static std::unique_ptr<LLVMTypeConverter>
+makeStandardToLLVMBarePtrTypeConverter(MLIRContext *context) {
+  LLVMTypeConverterCustomization customs;
+  customs.funcArgConverter = barePtrFuncArgTypeConverter;
+  return std::make_unique<LLVMTypeConverter>(context, customs);
 }
 
 namespace {
@@ -2275,6 +2433,9 @@ static PassRegistration<LLVMLoweringPass>
     "Standard to the LLVM dialect", [] {
       return std::make_unique<LLVMLoweringPass>(
-          clUseAlloca.getValue(), populateStdToLLVMConversionPatterns,
-          makeStandardToLLVMTypeConverter);
+          clUseAlloca.getValue(),
+          clUseBarePtrCallConv ? populateStdToLLVMBarePtrConversionPatterns
+                               : populateStdToLLVMConversionPatterns,
+          clUseBarePtrCallConv ? makeStandardToLLVMBarePtrTypeConverter
+                               : makeStandardToLLVMTypeConverter);
     });
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index 358cf40..43cbc78 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -1,10 +1,4 @@
 // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
-// RUN: mlir-opt -convert-std-to-llvm -convert-std-to-llvm-use-alloca=1 %s | FileCheck %s --check-prefix=ALLOCA
-
-// CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg1: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg2: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
-func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %mixed : memref<10x?xf32>) {
-  return
-}
 
 // CHECK-LABEL: func @check_strided_memref_arguments(
 // CHECK-COUNT-3: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
@@ -14,74 +8,11 @@ func @check_strided_memref_arguments(%static: memref<10x20xf32, affine_map<(i,j)
   return
 }
 
-// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
-func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
-// CHECK: llvm.return %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-  return %static : memref<32x18xf32>
-}
-
-// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
-// ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
-func @zero_d_alloc() -> memref<f32> {
-// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
-// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
-// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
-// CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
-// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
-// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
-// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
-// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
-
-// ALLOCA-NOT: malloc
-// ALLOCA: alloca
-// ALLOCA-NOT: malloc
-  %0 = alloc() : memref<f32>
-  return %0 : memref<f32>
-}
-
-// CHECK-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) {
-func @zero_d_dealloc(%arg0: memref<f32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64 }">
-// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
-// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
-  dealloc %arg0 : memref<f32>
+// CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg1: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg2: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
+func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %mixed : memref<10x?xf32>) {
   return
 }
 
-// CHECK-LABEL: func @aligned_1d_alloc(
-func @aligned_1d_alloc() -> memref<42xf32> {
-// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
-// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
-// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
-// CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
-// CHECK-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64
-// CHECK-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64
-// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
-// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
-// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
-// CHECK-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm<"i8*"> to !llvm.i64
-// CHECK-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64
-// CHECK-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64
-// CHECK-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64
-// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr %9[%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
-// CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*">
-// CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
-// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
-  %0 = alloc() {alignment = 8} : memref<42xf32>
-  return %0 : memref<42xf32>
-}
-
 // CHECK-LABEL: func @mixed_alloc(
 // CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> {
 func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
@@ -162,61 +93,6 @@ func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
   return
 }
 
-// CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
-func @static_alloc() -> memref<32x18xf32> {
-// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
-// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
-// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
-// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
-// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
-// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
-// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
-// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
-  %0 = alloc() : memref<32x18xf32>
-  return %0 : memref<32x18xf32>
-}
-
-// CHECK-LABEL: func @static_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
-func @static_dealloc(%static: memref<10x8xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
-// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
-  dealloc %static : memref<10x8xf32>
-  return
-}
-
-// CHECK-LABEL: func @zero_d_load(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) -> !llvm.float {
-func @zero_d_load(%arg0: memref<f32>) -> f32 {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
-// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: %{{.*}} = llvm.load %[[addr]] : !llvm<"float*">
-  %0 = load %arg0[] : memref<f32>
-  return %0 : f32
-}
-
-// CHECK-LABEL: func @static_load(
-// CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
-func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: llvm.load %[[addr]] : !llvm<"float*">
-  %0 = load %static[%i, %j] : memref<10x42xf32>
-  return
-}
-
 // CHECK-LABEL: func @mixed_load(
 // CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
 func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
@@ -283,34 +159,6 @@ func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
   return
 }
 
-// CHECK-LABEL: func @zero_d_store(%arg0: !llvm<"{ float*, float*, i64 }*">, %arg1: !llvm.float) {
-func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: llvm.store %arg1, %[[addr]] : !llvm<"float*">
-  store %arg1, %arg0[] : memref<f32>
-  return
-}
-
-// CHECK-LABEL: func @static_store
-func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
-// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*">
-  store %val, %static[%i, %j] : memref<10x42xf32>
-  return
-}
-
 // CHECK-LABEL: func @dynamic_store
 func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f32) {
 // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
@@ -440,20 +288,3 @@ func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
   %4 = dim %mixed, 4 : memref<42x?x?x13x?xf32>
   return
 }
-
-// CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
-func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
-// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
-// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
-  %0 = dim %static, 0 : memref<42x32x15x13x27xf32>
-// CHECK-NEXT: llvm.mlir.constant(32 : index) : !llvm.i64
-  %1 = dim %static, 1 : memref<42x32x15x13x27xf32>
-// CHECK-NEXT: llvm.mlir.constant(15 : index) : !llvm.i64
-  %2 = dim %static, 2 : memref<42x32x15x13x27xf32>
-// CHECK-NEXT: llvm.mlir.constant(13 : index) : !llvm.i64
-  %3 = dim %static, 3 : memref<42x32x15x13x27xf32>
-// CHECK-NEXT: llvm.mlir.constant(27 : index) : !llvm.i64
-  %4 = dim %static, 4 : memref<42x32x15x13x27xf32>
-  return
-}
-
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
new file mode 100644
index 0000000..e44d2fc
--- /dev/null
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -0,0 +1,322 @@
+// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
+// RUN: mlir-opt -convert-std-to-llvm -convert-std-to-llvm-use-alloca=1 %s | FileCheck %s --check-prefix=ALLOCA
+// RUN: mlir-opt -convert-std-to-llvm -split-input-file -convert-std-to-llvm-use-bare-ptr-memref-call-conv=1 %s | FileCheck %s --check-prefix=BAREPTR
+
+// BAREPTR-LABEL: func @check_noalias
+// BAREPTR-SAME: %{{.*}}: !llvm<"float*"> {llvm.noalias = true}
+func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) {
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+// BAREPTR-LABEL: func @check_static_return
+// BAREPTR-SAME: (%[[arg:.*]]: !llvm<"float*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
+// CHECK: llvm.return %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+
+// BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  return %static : memref<32x18xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
+// ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
+// BAREPTR-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
+func @zero_d_alloc() -> memref<f32> {
+// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
+// CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*">
+// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
+// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
+// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+
+// ALLOCA-NOT: malloc
+// ALLOCA: alloca
+// ALLOCA-NOT: malloc
+
+// BAREPTR-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// BAREPTR-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
+// BAREPTR-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*">
+// BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
+// BAREPTR-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
+// BAREPTR-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+// BAREPTR-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// BAREPTR-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+  %0 = alloc() : memref<f32>
+  return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) {
+// BAREPTR-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"float*">) {
+func @zero_d_dealloc(%arg0: memref<f32>) {
+// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
+// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64 }">
+// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
+// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
+
+// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
+// BAREPTR-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
+// BAREPTR-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
+  dealloc %arg0 : memref<f32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @aligned_1d_alloc(
+// BAREPTR-LABEL: func @aligned_1d_alloc(
+func @aligned_1d_alloc() -> memref<42xf32> {
+// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
+// CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
+// CHECK-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64
+// CHECK-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64
+// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*">
+// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
+// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+// CHECK-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm<"i8*"> to !llvm.i64
+// CHECK-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64
+// CHECK-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64
+// CHECK-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64
+// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr %9[%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
+// CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*">
+// CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+
+// BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// BAREPTR-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
+// BAREPTR-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64
+// BAREPTR-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64
+// BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*">
+// BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
+// BAREPTR-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+// BAREPTR-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+// BAREPTR-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm<"i8*"> to !llvm.i64
+// BAREPTR-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64
+// BAREPTR-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64
+// BAREPTR-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64
+// BAREPTR-NEXT: %[[aligned:.*]] = llvm.getelementptr %9[%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
+// BAREPTR-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*">
+// BAREPTR-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// BAREPTR-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %0 = alloc() {alignment = 8} : memref<42xf32>
+  return %0 : memref<42xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+// BAREPTR-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+func @static_alloc() -> memref<32x18xf32> {
+// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
+// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
+// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
+// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
+
+// BAREPTR-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[num_elems:.*]] = llvm.mul %[[sz1]], %[[sz2]] : !llvm.i64
+// BAREPTR-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// BAREPTR-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// BAREPTR-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
+// BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
+// BAREPTR-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
+  %0 = alloc() : memref<32x18xf32>
+  return %0 : memref<32x18xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @static_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) {
+// BAREPTR-LABEL: func @static_dealloc(%{{.*}}: !llvm<"float*">) {
+func @static_dealloc(%static: memref<10x8xf32>) {
+// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
+// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
+
+// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
+// BAREPTR-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
+  dealloc %static : memref<10x8xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @zero_d_load(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) -> !llvm.float {
+// BAREPTR-LABEL: func @zero_d_load(%{{.*}}: !llvm<"float*">) -> !llvm.float
+func @zero_d_load(%arg0: memref<f32>) -> f32 {
+// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
+// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
+// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT: %{{.*}} = llvm.load %[[addr]] : !llvm<"float*">
+
+// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// BAREPTR-NEXT: llvm.load %[[addr:.*]] : !llvm<"float*">
+  %0 = load %arg0[] : memref<f32>
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @static_load(
+// CHECK-SAME: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64
+// BAREPTR-LABEL: func @static_load
+// BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64) {
+func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
+// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT: llvm.load %[[addr]] : !llvm<"float*">
+
+// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
+// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// BAREPTR-NEXT: llvm.load %[[addr]] : !llvm<"float*">
+  %0 = load %static[%i, %j] : memref<10x42xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @zero_d_store(%arg0: !llvm<"{ float*, float*, i64 }*">, %arg1: !llvm.float) {
+// BAREPTR-LABEL: func @zero_d_store
+// BAREPTR-SAME: (%[[A:.*]]: !llvm<"float*">, %[[val:.*]]: !llvm.float)
+func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
+// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*">
+// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }">
+// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT: llvm.store %arg1, %[[addr]] : !llvm<"float*">
+
+// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
+// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// BAREPTR-NEXT: llvm.store %[[val]], %[[addr]] : !llvm<"float*">
+  store %arg1, %arg0[] : memref<f32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @static_store
+// BAREPTR-LABEL: func @static_store
+// BAREPTR-SAME: %[[A:.*]]: !llvm<"float*">
+func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
+// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT: llvm.store %arg3, %[[addr]] : !llvm<"float*">
+
+// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
+// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
+// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
+// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// BAREPTR-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
+  store %val, %static[%i, %j] : memref<10x42xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) {
+// BAREPTR-LABEL: func @static_memref_dim(%{{.*}}: !llvm<"float*">) {
+func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
+// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">
+// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
+// BAREPTR: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
+  %0 = dim %static, 0 : memref<42x32x15x13x27xf32>
+// CHECK-NEXT: llvm.mlir.constant(32 : index) : !llvm.i64
+// BAREPTR-NEXT: llvm.mlir.constant(32 : index) : !llvm.i64
+  %1 = dim %static, 1 : memref<42x32x15x13x27xf32>
+// CHECK-NEXT: llvm.mlir.constant(15 : index) : !llvm.i64
+// BAREPTR-NEXT: llvm.mlir.constant(15 : index) : !llvm.i64
+  %2 = dim %static, 2 : memref<42x32x15x13x27xf32>
+// CHECK-NEXT: llvm.mlir.constant(13 : index) : !llvm.i64
+// BAREPTR-NEXT: llvm.mlir.constant(13 : index) : !llvm.i64
+  %3 = dim %static, 3 : memref<42x32x15x13x27xf32>
+// CHECK-NEXT: llvm.mlir.constant(27 : index) : !llvm.i64
+// BAREPTR-NEXT: llvm.mlir.constant(27 : index) : !llvm.i64
+  %4 = dim %static, 4 : memref<42x32x15x13x27xf32>
+  return
+}
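A note for reading the BAREPTR checks above: once the bare float* has been promoted, a memref<10x42xf32> is represented by the descriptor !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> holding the allocated and aligned pointers, offset 0, sizes [10, 42] and strides [42, 1], so the element address for load/store at [%i, %j] is

  addr = alignedPtr + (0 + %i * 42 + %j * 1)

in units of the element type, which is exactly the constant/mul/add chain checked in @static_load and @static_store.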
diff --git a/mlir/test/mlir-cpu-runner/bare_ptr_call_conv.mlir b/mlir/test/mlir-cpu-runner/bare_ptr_call_conv.mlir
new file mode 100644
index 0000000..59fd969
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/bare_ptr_call_conv.mlir
@@ -0,0 +1,183 @@
+// RUN: mlir-opt %s -convert-loop-to-std -convert-std-to-llvm -convert-std-to-llvm-use-bare-ptr-memref-call-conv | mlir-cpu-runner -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext -entry-point-result=void | FileCheck %s
+
+// Verify bare pointer memref calling convention. `simple_add1_add2_test`
+// gets two 2xf32 memrefs, adds 1.0f to the first one and 2.0f to the second
+// one. 'main' calls 'simple_add1_add2_test' with {1, 1} and {2, 2} so {2, 2}
+// and {4, 4} are the expected outputs.
+
+func @simple_add1_add2_test(%arg0: memref<2xf32>, %arg1: memref<2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cst = constant 1.000000e+00 : f32
+  %cst_0 = constant 2.000000e+00 : f32
+  loop.for %arg2 = %c0 to %c2 step %c1 {
+    %0 = load %arg0[%arg2] : memref<2xf32>
+    %1 = addf %0, %cst : f32
+    store %1, %arg0[%arg2] : memref<2xf32>
+    // CHECK: 2, 2
+
+    %2 = load %arg1[%arg2] : memref<2xf32>
+    %3 = addf %1, %cst_0 : f32
+    store %3, %arg1[%arg2] : memref<2xf32>
+    // CHECK-NEXT: 4, 4
+  }
+  return
+}
+
+// External declarations.
+llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
+llvm.func @free(!llvm<"i8*">)
+func @print_f32(%arg0: f32)
+func @print_comma()
+func @print_newline()
+
+// TODO: 'main' function currently has to be provided in LLVM dialect since
+// 'call' op is not yet supported by the bare pointer calling convention. The
+// LLVM dialect version was generated using the following loop/std dialect
+// version and minor changes around the 'simple_add1_add2_test' call.
+
+//func @main()
+//{
+//  %c2 = constant 2 : index
+//  %c0 = constant 0 : index
+//  %c1 = constant 1 : index
+//  %cst = constant 1.000000e+00 : f32
+//  %cst_0 = constant 2.000000e+00 : f32
+//  %a = alloc() : memref<2xf32>
+//  %b = alloc() : memref<2xf32>
+//  loop.for %i = %c0 to %c2 step %c1 {
+//    store %cst, %a[%i] : memref<2xf32>
+//    store %cst, %b[%i] : memref<2xf32>
+//  }
+//
+//  call @simple_add1_add2_test(%a, %b) : (memref<2xf32>, memref<2xf32>) -> ()
+//
+//  %l0 = load %a[%c0] : memref<2xf32>
+//  call @print_f32(%l0) : (f32) -> ()
+//  call @print_comma() : () -> ()
+//  %l1 = load %a[%c1] : memref<2xf32>
+//  call @print_f32(%l1) : (f32) -> ()
+//  call @print_newline() : () -> ()
+//
+//  %l2 = load %b[%c0] : memref<2xf32>
+//  call @print_f32(%l2) : (f32) -> ()
+//  call @print_comma() : () -> ()
+//  %l3 = load %b[%c1] : memref<2xf32>
+//  call @print_f32(%l3) : (f32) -> ()
+//  call @print_newline() : () -> ()
+//
+//  dealloc %a : memref<2xf32>
+//  dealloc %b : memref<2xf32>
+//  return
+//}
+
+llvm.func @main() {
+  %0 = llvm.mlir.constant(2 : index) : !llvm.i64
+  %1 = llvm.mlir.constant(0 : index) : !llvm.i64
+  %2 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %3 = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+  %4 = llvm.mlir.constant(2.000000e+00 : f32) : !llvm.float
+  %5 = llvm.mlir.constant(2 : index) : !llvm.i64
+  %6 = llvm.mlir.null : !llvm<"float*">
+  %7 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %8 = llvm.getelementptr %6[%7] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+  %9 = llvm.ptrtoint %8 : !llvm<"float*"> to !llvm.i64
+  %10 = llvm.mul %5, %9 : !llvm.i64
+  %11 = llvm.call @malloc(%10) : (!llvm.i64) -> !llvm<"i8*">
+  %12 = llvm.bitcast %11 : !llvm<"i8*"> to !llvm<"float*">
+  %13 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %14 = llvm.insertvalue %12, %13[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %15 = llvm.insertvalue %12, %14[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %16 = llvm.mlir.constant(0 : index) : !llvm.i64
+  %17 = llvm.insertvalue %16, %15[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %18 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %19 = llvm.insertvalue %5, %17[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %20 = llvm.insertvalue %18, %19[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %21 = llvm.mlir.constant(2 : index) : !llvm.i64
+  %22 = llvm.mlir.null : !llvm<"float*">
+  %23 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %24 = llvm.getelementptr %22[%23] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+  %25 = llvm.ptrtoint %24 : !llvm<"float*"> to !llvm.i64
+  %26 = llvm.mul %21, %25 : !llvm.i64
+  %27 = llvm.call @malloc(%26) : (!llvm.i64) -> !llvm<"i8*">
+  %28 = llvm.bitcast %27 : !llvm<"i8*"> to !llvm<"float*">
+  %29 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %30 = llvm.insertvalue %28, %29[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %31 = llvm.insertvalue %28, %30[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %32 = llvm.mlir.constant(0 : index) : !llvm.i64
+  %33 = llvm.insertvalue %32, %31[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %34 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %35 = llvm.insertvalue %21, %33[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %36 = llvm.insertvalue %34, %35[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  llvm.br ^bb1(%1 : !llvm.i64)
+^bb1(%37: !llvm.i64):	// 2 preds: ^bb0, ^bb2
+  %38 = llvm.icmp "slt" %37, %0 : !llvm.i64
+  llvm.cond_br %38, ^bb2, ^bb3
+^bb2:	// pred: ^bb1
+  %39 = llvm.extractvalue %20[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %40 = llvm.mlir.constant(0 : index) : !llvm.i64
+  %41 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %42 = llvm.mul %37, %41 : !llvm.i64
+  %43 = llvm.add %40, %42 : !llvm.i64
+  %44 = llvm.getelementptr %39[%43] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+  llvm.store %3, %44 : !llvm<"float*">
+  %45 = llvm.extractvalue %36[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %46 = llvm.mlir.constant(0 : index) : !llvm.i64
+  %47 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %48 = llvm.mul %37, %47 : !llvm.i64
+  %49 = llvm.add %46, %48 : !llvm.i64
+  %50 = llvm.getelementptr %45[%49] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+  llvm.store %3, %50 : !llvm<"float*">
+  %51 = llvm.add %37, %2 : !llvm.i64
+  llvm.br ^bb1(%51 : !llvm.i64)
+^bb3:	// pred: ^bb1
+  %52 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %53 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %54 = llvm.extractvalue %20[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %55 = llvm.extractvalue %36[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  llvm.call @simple_add1_add2_test(%54, %55) : (!llvm<"float*">, !llvm<"float*">) -> ()
+  %56 = llvm.extractvalue %20[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %57 = llvm.mlir.constant(0 : index) : !llvm.i64
+  %58 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %59 = llvm.mul %1, %58 : !llvm.i64
+  %60 = llvm.add %57, %59 : !llvm.i64
+  %61 = llvm.getelementptr %56[%60] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+  %62 = llvm.load %61 : !llvm<"float*">
+  llvm.call @print_f32(%62) : (!llvm.float) -> ()
+  llvm.call @print_comma() : () -> ()
+  %63 = llvm.extractvalue %20[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %64 = llvm.mlir.constant(0 : index) : !llvm.i64
+  %65 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %66 = llvm.mul %2, %65 : !llvm.i64
+  %67 = llvm.add %64, %66 : !llvm.i64
+  %68 = llvm.getelementptr %63[%67] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+  %69 = llvm.load %68 : !llvm<"float*">
+  llvm.call @print_f32(%69) : (!llvm.float) -> ()
+  llvm.call @print_newline() : () -> ()
+  %70 = llvm.extractvalue %36[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %71 = llvm.mlir.constant(0 : index) : !llvm.i64
+  %72 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %73 = llvm.mul %1, %72 : !llvm.i64
+  %74 = llvm.add %71, %73 : !llvm.i64
+  %75 = llvm.getelementptr %70[%74] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+  %76 = llvm.load %75 : !llvm<"float*">
+  llvm.call @print_f32(%76) : (!llvm.float) -> ()
+  llvm.call @print_comma() : () -> ()
+  %77 = llvm.extractvalue %36[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %78 = llvm.mlir.constant(0 : index) : !llvm.i64
+  %79 = llvm.mlir.constant(1 : index) : !llvm.i64
+  %80 = llvm.mul %2, %79 : !llvm.i64
+  %81 = llvm.add %78, %80 : !llvm.i64
+  %82 = llvm.getelementptr %77[%81] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+  %83 = llvm.load %82 : !llvm<"float*">
+  llvm.call @print_f32(%83) : (!llvm.float) -> ()
+  llvm.call @print_newline() : () -> ()
+  %84 = llvm.extractvalue %20[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %85 = llvm.bitcast %84 : !llvm<"float*"> to !llvm<"i8*">
+  llvm.call @free(%85) : (!llvm<"i8*">) -> ()
+  %86 = llvm.extractvalue %36[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+  %87 = llvm.bitcast %86 : !llvm<"float*"> to !llvm<"i8*">
+  llvm.call @free(%87) : (!llvm<"i8*">) -> ()
+  llvm.return
+}