aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Target
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Target')
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp52
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.cpp18
2 files changed, 43 insertions, 27 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 2acbd03..64e3c5f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -649,40 +649,38 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
auto *arrayType = llvm::ArrayType::get(elementType, numElements);
if (child->isZeroValue() && !elementType->isFPOrFPVectorTy()) {
return llvm::ConstantAggregateZero::get(arrayType);
- } else {
- if (llvm::ConstantDataSequential::isElementTypeCompatible(
- elementType)) {
- // TODO: Handle all compatible types. This code only handles integer.
- if (isa<llvm::IntegerType>(elementType)) {
- if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) {
- if (ci->getBitWidth() == 8) {
- SmallVector<int8_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
- if (ci->getBitWidth() == 16) {
- SmallVector<int16_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
- if (ci->getBitWidth() == 32) {
- SmallVector<int32_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
- if (ci->getBitWidth() == 64) {
- SmallVector<int64_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
+ }
+ if (llvm::ConstantDataSequential::isElementTypeCompatible(elementType)) {
+ // TODO: Handle all compatible types. This code only handles integer.
+ if (isa<llvm::IntegerType>(elementType)) {
+ if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) {
+ if (ci->getBitWidth() == 8) {
+ SmallVector<int8_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
+ }
+ if (ci->getBitWidth() == 16) {
+ SmallVector<int16_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
+ }
+ if (ci->getBitWidth() == 32) {
+ SmallVector<int32_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
+ }
+ if (ci->getBitWidth() == 64) {
+ SmallVector<int64_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
}
}
}
+ }
// std::vector is used here to accomodate large number of elements that
// exceed SmallVector capacity.
std::vector<llvm::Constant *> constants(numElements, child);
return llvm::ConstantArray::get(arrayType, constants);
- }
}
}
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index b88fbaa..29ed5a4 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -89,6 +89,22 @@ static bool isZeroValue(Attribute attr) {
return false;
}
+/// Move all functions declaration before functions definitions. In SPIR-V
+/// "declarations" are functions without a body and "definitions" functions
+/// with a body. This is stronger than necessary. It should be sufficient to
+/// ensure any declarations precede their uses and not all definitions, however
+/// this allows to avoid analysing every function in the module this way.
+static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp) {
+ Block::OpListType &ops = moduleOp.getBody()->getOperations();
+ if (ops.empty())
+ return;
+ Operation &firstOp = ops.front();
+ for (Operation &op : llvm::drop_begin(ops))
+ if (auto funcOp = dyn_cast<spirv::FuncOp>(op))
+ if (funcOp.getBody().empty())
+ funcOp->moveBefore(&firstOp);
+}
+
namespace mlir {
namespace spirv {
@@ -119,6 +135,8 @@ LogicalResult Serializer::serialize() {
processMemoryModel();
processDebugInfo();
+ moveFuncDeclarationsToTop(module);
+
// Iterate over the module body to serialize it. Assumptions are that there is
// only one basic block in the moduleOp
for (auto &op : *module.getBody()) {