diff options
Diffstat (limited to 'mlir/lib/Target')
| -rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 52 | ||||
| -rw-r--r-- | mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 18 |
2 files changed, 43 insertions, 27 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 2acbd03..64e3c5f 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -649,40 +649,38 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( auto *arrayType = llvm::ArrayType::get(elementType, numElements); if (child->isZeroValue() && !elementType->isFPOrFPVectorTy()) { return llvm::ConstantAggregateZero::get(arrayType); - } else { - if (llvm::ConstantDataSequential::isElementTypeCompatible( - elementType)) { - // TODO: Handle all compatible types. This code only handles integer. - if (isa<llvm::IntegerType>(elementType)) { - if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) { - if (ci->getBitWidth() == 8) { - SmallVector<int8_t> constants(numElements, ci->getZExtValue()); - return llvm::ConstantDataArray::get(elementType->getContext(), - constants); - } - if (ci->getBitWidth() == 16) { - SmallVector<int16_t> constants(numElements, ci->getZExtValue()); - return llvm::ConstantDataArray::get(elementType->getContext(), - constants); - } - if (ci->getBitWidth() == 32) { - SmallVector<int32_t> constants(numElements, ci->getZExtValue()); - return llvm::ConstantDataArray::get(elementType->getContext(), - constants); - } - if (ci->getBitWidth() == 64) { - SmallVector<int64_t> constants(numElements, ci->getZExtValue()); - return llvm::ConstantDataArray::get(elementType->getContext(), - constants); - } + } + if (llvm::ConstantDataSequential::isElementTypeCompatible(elementType)) { + // TODO: Handle all compatible types. This code only handles integer. + if (isa<llvm::IntegerType>(elementType)) { + if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) { + if (ci->getBitWidth() == 8) { + SmallVector<int8_t> constants(numElements, ci->getZExtValue()); + return llvm::ConstantDataArray::get(elementType->getContext(), + constants); + } + if (ci->getBitWidth() == 16) { + SmallVector<int16_t> constants(numElements, ci->getZExtValue()); + return llvm::ConstantDataArray::get(elementType->getContext(), + constants); + } + if (ci->getBitWidth() == 32) { + SmallVector<int32_t> constants(numElements, ci->getZExtValue()); + return llvm::ConstantDataArray::get(elementType->getContext(), + constants); + } + if (ci->getBitWidth() == 64) { + SmallVector<int64_t> constants(numElements, ci->getZExtValue()); + return llvm::ConstantDataArray::get(elementType->getContext(), + constants); } } } + } // std::vector is used here to accomodate large number of elements that // exceed SmallVector capacity. std::vector<llvm::Constant *> constants(numElements, child); return llvm::ConstantArray::get(arrayType, constants); - } } } diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index b88fbaa..29ed5a4 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -89,6 +89,22 @@ static bool isZeroValue(Attribute attr) { return false; } +/// Move all functions declaration before functions definitions. In SPIR-V +/// "declarations" are functions without a body and "definitions" functions +/// with a body. This is stronger than necessary. It should be sufficient to +/// ensure any declarations precede their uses and not all definitions, however +/// this allows to avoid analysing every function in the module this way. +static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp) { + Block::OpListType &ops = moduleOp.getBody()->getOperations(); + if (ops.empty()) + return; + Operation &firstOp = ops.front(); + for (Operation &op : llvm::drop_begin(ops)) + if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) + if (funcOp.getBody().empty()) + funcOp->moveBefore(&firstOp); +} + namespace mlir { namespace spirv { @@ -119,6 +135,8 @@ LogicalResult Serializer::serialize() { processMemoryModel(); processDebugInfo(); + moveFuncDeclarationsToTop(module); + // Iterate over the module body to serialize it. Assumptions are that there is // only one basic block in the moduleOp for (auto &op : *module.getBody()) { |
