aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Target
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Target')
-rw-r--r--mlir/lib/Target/Cpp/TranslateRegistration.cpp2
-rw-r--r--mlir/lib/Target/Cpp/TranslateToCpp.cpp8
-rw-r--r--mlir/lib/Target/LLVM/CMakeLists.txt1
-rw-r--r--mlir/lib/Target/LLVMIR/CMakeLists.txt1
-rw-r--r--mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp1
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt1
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp50
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp1
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp36
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp3
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp59
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt21
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp103
-rw-r--r--mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp10
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleImport.cpp105
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp61
-rw-r--r--mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp32
-rw-r--r--mlir/lib/Target/SPIRV/Deserialization/Deserializer.h1
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.cpp61
19 files changed, 394 insertions, 163 deletions
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 2108ffd..7dae03e 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -9,8 +9,6 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Dialect.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/Support/CommandLine.h"
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index a393d88..dcd2e11 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -17,15 +17,12 @@
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
-#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
-#include <utility>
#define DEBUG_TYPE "translate-to-cpp"
@@ -903,8 +900,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
// inlined, and as such should be wrapped in parentheses in order to guarantee
// its precedence and associativity.
auto requiresParentheses = [&](Value value) {
- auto expressionOp =
- dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
+ auto expressionOp = value.getDefiningOp<ExpressionOp>();
if (!expressionOp)
return false;
return shouldBeInlined(expressionOp);
@@ -1545,7 +1541,7 @@ LogicalResult CppEmitter::emitOperand(Value value) {
return success();
}
- auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
+ auto expressionOp = value.getDefiningOp<ExpressionOp>();
if (expressionOp && shouldBeInlined(expressionOp))
return emitExpression(expressionOp);
diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt
index 7c6fc37..f6e44c6 100644
--- a/mlir/lib/Target/LLVM/CMakeLists.txt
+++ b/mlir/lib/Target/LLVM/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRTargetLLVM
intrinsics_gen
LINK_COMPONENTS
+ BitWriter
Core
IPO
IRReader
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index af22a7f..9ea5c683 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
MLIRROCDLToLLVMIRTranslation
MLIRSPIRVToLLVMIRTranslation
MLIRVCIXToLLVMIRTranslation
+ MLIRXeVMToLLVMIRTranslation
)
add_mlir_translation_library(MLIRTargetLLVMIRImport
diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
index 75170bf..8e6f5c7 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Tools/mlir-translate/Translation.h"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index f030fa7..86c731a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -10,3 +10,4 @@ add_subdirectory(OpenMP)
add_subdirectory(ROCDL)
add_subdirectory(SPIRV)
add_subdirectory(VCIX)
+add_subdirectory(XeVM)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index ff34a08..0f675a0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -13,6 +13,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
@@ -136,46 +137,6 @@ convertOperandBundles(OperandRangeRange bundleOperands,
return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
}
-static LogicalResult
-convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray,
- ArrayAttr resAttrsArray, llvm::CallBase *call,
- LLVM::ModuleTranslation &moduleTranslation) {
- if (argAttrsArray) {
- for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
- if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
- !argAttrs.empty()) {
- FailureOr<llvm::AttrBuilder> attrBuilder =
- moduleTranslation.convertParameterAttrs(loc, argAttrs);
- if (failed(attrBuilder))
- return failure();
- call->addParamAttrs(argIdx, *attrBuilder);
- }
- }
- }
-
- if (resAttrsArray && resAttrsArray.size() > 0) {
- if (resAttrsArray.size() != 1)
- return mlir::emitError(loc, "llvm.func cannot have multiple results");
- if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
- !resAttrs.empty()) {
- FailureOr<llvm::AttrBuilder> attrBuilder =
- moduleTranslation.convertParameterAttrs(loc, resAttrs);
- if (failed(attrBuilder))
- return failure();
- call->addRetAttrs(*attrBuilder);
- }
- }
- return success();
-}
-
-static LogicalResult
-convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
- LLVM::ModuleTranslation &moduleTranslation) {
- return convertParameterAndResultAttrs(
- callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call,
- moduleTranslation);
-}
-
/// Builder for LLVM_CallIntrinsicOp
static LogicalResult
convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
@@ -243,9 +204,7 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
moduleTranslation));
- if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(),
- op.getResAttrsAttr(), inst,
- moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(op, inst)))
return failure();
if (op.getNumResults() == 1)
@@ -455,7 +414,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
if (callOp.getInlineHintAttr())
call->addFnAttr(llvm::Attribute::InlineHint);
- if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(callOp, call)))
return failure();
if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
@@ -569,8 +528,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
operandsRef.drop_front(), opBundles);
}
result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
- if (failed(
- convertParameterAndResultAttrs(invOp, result, moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(invOp, result)))
return failure();
moduleTranslation.mapBranch(invOp, result);
// InvokeOp can only have 0 or 1 result
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
index ad01a64..55e73e8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
@@ -13,7 +13,6 @@
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Target/LLVMIR/ModuleImport.h"
-
#include "llvm/IR/ConstantRange.h"
using namespace mlir;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index b3577c6..90462d1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -164,6 +164,42 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
}
}
+/// Return the intrinsic ID associated with stmatrix for the given paramters.
+static llvm::Intrinsic::ID
+getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
+ NVVM::LdStMatrixShapeAttr shape,
+ NVVM::LdStMatrixEltType eltType) {
+ if (shape.getM() == 8 && shape.getN() == 8) {
+ switch (num) {
+ case 1:
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
+ : llvm::Intrinsic::
+ nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
+ case 2:
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
+ : llvm::Intrinsic::
+ nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
+ case 4:
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
+ : llvm::Intrinsic::
+ nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
+ }
+ } else if (shape.getM() == 16 && shape.getN() == 8) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
+ case 2:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
+ case 4:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
+ }
+ }
+ llvm_unreachable("unknown stmatrix kind");
+}
+
/// Return the intrinsic ID associated with st.bulk for the given address type.
static llvm::Intrinsic::ID
getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index d162afd..97c6b4e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -151,8 +151,7 @@ processDataOperands(llvm::IRBuilderBase &builder,
// Copyin operands are handled as `to` call.
llvm::SmallVector<mlir::Value> create, copyin;
for (mlir::Value dataOp : op.getDataClauseOperands()) {
- if (auto createOp =
- mlir::dyn_cast_or_null<acc::CreateOp>(dataOp.getDefiningOp())) {
+ if (auto createOp = dataOp.getDefiningOp<acc::CreateOp>()) {
create.push_back(createOp.getVarPtr());
} else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>(
dataOp.getDefiningOp())) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da39b19..49e1e55 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -16,15 +16,12 @@
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
-#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -39,7 +36,6 @@
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
-#include <any>
#include <cstdint>
#include <iterator>
#include <numeric>
@@ -3541,8 +3537,7 @@ getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
}
static bool isDeclareTargetLink(mlir::Value value) {
- if (auto addressOfOp =
- llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
+ if (auto addressOfOp = value.getDefiningOp<LLVM::AddressOfOp>()) {
auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
if (auto declareTargetGlobal =
@@ -3882,29 +3877,28 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
llvm::SmallVector<size_t> indices(indexAttr.size());
std::iota(indices.begin(), indices.end(), 0);
- llvm::sort(indices.begin(), indices.end(),
- [&](const size_t a, const size_t b) {
- auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
- auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
- for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
- int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
- int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
+ llvm::sort(indices, [&](const size_t a, const size_t b) {
+ auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
+ auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
+ for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
+ int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
+ int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
- if (aIndex == bIndex)
- continue;
+ if (aIndex == bIndex)
+ continue;
- if (aIndex < bIndex)
- return first;
+ if (aIndex < bIndex)
+ return first;
- if (aIndex > bIndex)
- return !first;
- }
+ if (aIndex > bIndex)
+ return !first;
+ }
- // Iterated the up until the end of the smallest member and
- // they were found to be equal up to that point, so select
- // the member with the lowest index count, so the "parent"
- return memberIndicesA.size() < memberIndicesB.size();
- });
+ // Iterated the up until the end of the smallest member and
+ // they were found to be equal up to that point, so select
+ // the member with the lowest index count, so the "parent"
+ return memberIndicesA.size() < memberIndicesB.size();
+ });
return llvm::cast<omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
@@ -4502,8 +4496,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = dataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
@@ -4520,8 +4513,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = enterDataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
RTLFn =
@@ -4540,8 +4532,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = exitDataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
@@ -4560,8 +4551,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = updateDataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
@@ -5202,8 +5192,7 @@ static std::optional<int64_t> extractConstInteger(Value value) {
if (!value)
return std::nullopt;
- if (auto constOp =
- dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
+ if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
return constAttr.getInt();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
new file mode 100644
index 0000000..6308d7e
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+set(LLVM_OPTIONAL_SOURCES
+ XeVMToLLVMIRTranslation.cpp
+)
+
+add_mlir_translation_library(MLIRXeVMToLLVMIRTranslation
+ XeVMToLLVMIRTranslation.cpp
+
+ DEPENDS
+ MLIRXeVMConversionsIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRXeVMDialect
+ MLIRSupport
+ MLIRTargetLLVMIRExport
+)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
new file mode 100644
index 0000000..73b166d
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -0,0 +1,103 @@
+//===-- XeVMToLLVMIRTranslation.cpp - Translate XeVM to LLVM IR -*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR XeVM dialect and
+// LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the XeVM dialect to LLVM IR.
+class XeVMDialectLLVMIRTranslationInterface
+ : public LLVMTranslationDialectInterface {
+public:
+ using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+ /// Attaches module-level metadata for functions marked as kernels.
+ LogicalResult
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ StringRef attrName = attribute.getName().getValue();
+ if (attrName == mlir::xevm::XeVMDialect::getCacheControlsAttrName()) {
+ auto cacheControlsArray = dyn_cast<ArrayAttr>(attribute.getValue());
+ if (cacheControlsArray.size() != 2) {
+ return op->emitOpError(
+ "Expected both L1 and L3 cache control attributes!");
+ }
+ if (instructions.size() != 1) {
+ return op->emitOpError("Expecting a single instruction");
+ }
+ return handleDecorationCacheControl(instructions.front(),
+ cacheControlsArray.getValue());
+ }
+ auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!func)
+ return failure();
+
+ return success();
+ }
+
+private:
+ static LogicalResult handleDecorationCacheControl(llvm::Instruction *inst,
+ ArrayRef<Attribute> attrs) {
+ SmallVector<llvm::Metadata *> decorations;
+ llvm::LLVMContext &ctx = inst->getContext();
+ llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx);
+ llvm::transform(
+ attrs, std::back_inserter(decorations),
+ [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * {
+ auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue();
+ std::array<llvm::Metadata *, 4> metadata;
+ llvm::transform(
+ valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) {
+ return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
+ i32Ty, cast<IntegerAttr>(valueAttr).getValue()));
+ });
+ return llvm::MDNode::get(ctx, metadata);
+ });
+ constexpr llvm::StringLiteral decorationCacheControlMDName =
+ "spirv.DecorationCacheControlINTEL";
+ inst->setMetadata(decorationCacheControlMDName,
+ llvm::MDNode::get(ctx, decorations));
+ return success();
+ }
+};
+} // namespace
+
+void mlir::registerXeVMDialectTranslation(::mlir::DialectRegistry &registry) {
+ registry.insert<xevm::XeVMDialect>();
+ registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) {
+ dialect->addInterfaces<XeVMDialectLLVMIRTranslationInterface>();
+ });
+}
+
+void mlir::registerXeVMDialectTranslation(::mlir::MLIRContext &context) {
+ DialectRegistry registry;
+ registerXeVMDialectTranslation(registry);
+ context.appendDialectRegistry(registry);
+}
diff --git a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
index 580afdd..cb1f234 100644
--- a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
+++ b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
@@ -33,7 +33,9 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(
SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
- llvmOperands, llvmOpBundles, false, {}, {}, mlirOperands, mlirAttrs)))
+ llvmOperands, llvmOpBundles, /*requiresOpBundles=*/false,
+ /*immArgPositions=*/{}, /*immArgAttrNames=*/{}, mlirOperands,
+ mlirAttrs)))
return failure();
Type resultType = moduleImport.convertType(inst->getType());
@@ -44,11 +46,7 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(
ValueRange{mlirOperands}, FastmathFlagsAttr{});
moduleImport.setFastmathFlagsAttr(inst, op);
-
- ArrayAttr argsAttr, resAttr;
- moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder);
- op.setArgAttrsAttr(argsAttr);
- op.setResAttrsAttr(resAttr);
+ moduleImport.convertArgAndResultAttrs(inst, op);
// Update importer tracking of results.
unsigned numRes = op.getNumResults();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 94db7f8..6325480 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -30,6 +30,7 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Comdat.h"
#include "llvm/IR/Constants.h"
@@ -142,6 +143,7 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
// TODO: Implement the `convertInstruction` hooks in the
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
+
return failure();
}
@@ -1062,6 +1064,18 @@ void ModuleImport::convertTargetTriple() {
builder.getStringAttr(llvmModule->getTargetTriple().str()));
}
+void ModuleImport::convertModuleLevelAsm() {
+ llvm::StringRef asmStr = llvmModule->getModuleInlineAsm();
+ llvm::SmallVector<mlir::Attribute> asmArrayAttr;
+
+ for (llvm::StringRef line : llvm::split(asmStr, '\n'))
+ if (!line.empty())
+ asmArrayAttr.push_back(builder.getStringAttr(line));
+
+ mlirModule->setAttr(LLVM::LLVMDialect::getModuleLevelAsmAttrName(),
+ builder.getArrayAttr(asmArrayAttr));
+}
+
LogicalResult ModuleImport::convertFunctions() {
for (llvm::Function &func : llvmModule->functions())
if (failed(processFunction(&func)))
@@ -1626,12 +1640,11 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
// Convert dso_local_equivalent.
if (auto *dsoLocalEquivalent = dyn_cast<llvm::DSOLocalEquivalent>(constant)) {
Type type = convertType(dsoLocalEquivalent->getType());
- return builder
- .create<DSOLocalEquivalentOp>(
- loc, type,
- FlatSymbolRefAttr::get(
- builder.getContext(),
- dsoLocalEquivalent->getGlobalValue()->getName()))
+ return DSOLocalEquivalentOp::create(
+ builder, loc, type,
+ FlatSymbolRefAttr::get(
+ builder.getContext(),
+ dsoLocalEquivalent->getGlobalValue()->getName()))
.getResult();
}
@@ -1736,9 +1749,9 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
FlatSymbolRefAttr::get(context, blockAddr->getFunction()->getName());
auto blockTag =
BlockTagAttr::get(context, blockAddr->getBasicBlock()->getNumber());
- return builder
- .create<BlockAddressOp>(loc, convertType(blockAddr->getType()),
- BlockAddressAttr::get(context, fnSym, blockTag))
+ return BlockAddressOp::create(
+ builder, loc, convertType(blockAddr->getType()),
+ BlockAddressAttr::get(context, fnSym, blockTag))
.getRes();
}
@@ -2228,17 +2241,16 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (!resultTy)
return failure();
ArrayAttr operandAttrs = convertAsmInlineOperandAttrs(*callInst);
- return builder
- .create<InlineAsmOp>(
- loc, resultTy, *operands,
- builder.getStringAttr(asmI->getAsmString()),
- builder.getStringAttr(asmI->getConstraintString()),
- asmI->hasSideEffects(), asmI->isAlignStack(),
- convertTailCallKindFromLLVM(callInst->getTailCallKind()),
- AsmDialectAttr::get(
- mlirModule.getContext(),
- convertAsmDialectFromLLVM(asmI->getDialect())),
- operandAttrs)
+ return InlineAsmOp::create(
+ builder, loc, resultTy, *operands,
+ builder.getStringAttr(asmI->getAsmString()),
+ builder.getStringAttr(asmI->getConstraintString()),
+ asmI->hasSideEffects(), asmI->isAlignStack(),
+ convertTailCallKindFromLLVM(callInst->getTailCallKind()),
+ AsmDialectAttr::get(
+ mlirModule.getContext(),
+ convertAsmDialectFromLLVM(asmI->getDialect())),
+ operandAttrs)
.getOperation();
}
bool isIncompatibleCall;
@@ -2268,7 +2280,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
// Handle parameter and result attributes unless it's an incompatible
// call.
if (!isIncompatibleCall)
- convertParameterAttributes(callInst, callOp, builder);
+ convertArgAndResultAttrs(callInst, callOp);
return callOp.getOperation();
}();
@@ -2365,7 +2377,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
// Handle parameter and result attributes unless it's an incompatible
// invoke.
if (!isIncompatibleInvoke)
- convertParameterAttributes(invokeInst, invokeOp, builder);
+ convertArgAndResultAttrs(invokeInst, invokeOp);
if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
@@ -2731,11 +2743,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
}
DictionaryAttr
-ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
- OpBuilder &builder) {
+ModuleImport::convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet) {
SmallVector<NamedAttribute> paramAttrs;
for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
- auto llvmAttr = llvmParamAttrs.getAttribute(llvmKind);
+ auto llvmAttr = llvmAttrSet.getAttribute(llvmKind);
// Skip attributes that are not attached.
if (!llvmAttr.isValid())
continue;
@@ -2770,13 +2781,12 @@ ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
return builder.getDictionaryAttr(paramAttrs);
}
-void ModuleImport::convertParameterAttributes(llvm::Function *func,
- LLVMFuncOp funcOp,
- OpBuilder &builder) {
+void ModuleImport::convertArgAndResultAttrs(llvm::Function *func,
+ LLVMFuncOp funcOp) {
auto llvmAttrs = func->getAttributes();
for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(i);
- funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder));
+ funcOp.setArgAttrs(i, convertArgOrResultAttrSet(llvmArgAttrs));
}
// Convert the result attributes and attach them wrapped in an ArrayAttribute
// to the funcOp.
@@ -2784,17 +2794,23 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
if (!llvmResAttr.hasAttributes())
return;
funcOp.setResAttrsAttr(
- builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
+ builder.getArrayAttr({convertArgOrResultAttrSet(llvmResAttr)}));
}
-void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
- ArrayAttr &argsAttr,
- ArrayAttr &resAttr,
- OpBuilder &builder) {
+void ModuleImport::convertArgAndResultAttrs(
+ llvm::CallBase *call, ArgAndResultAttrsOpInterface attrsOp,
+ ArrayRef<unsigned> immArgPositions) {
+ // Compute the set of immediate argument positions.
+ llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(),
+ immArgPositions.end());
+ // Convert the argument attributes and filter out immediate arguments.
llvm::AttributeList llvmAttrs = call->getAttributes();
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
bool anyArgAttrs = false;
for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
+ // Skip immediate arguments.
+ if (immArgPositionsSet.contains(i))
+ continue;
llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
if (llvmArgAttrsSet.back().hasAttributes())
anyArgAttrs = true;
@@ -2808,24 +2824,16 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
if (anyArgAttrs) {
SmallVector<DictionaryAttr> argAttrs;
for (auto &llvmArgAttrs : llvmArgAttrsSet)
- argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
- argsAttr = getArrayAttr(argAttrs);
+ argAttrs.emplace_back(convertArgOrResultAttrSet(llvmArgAttrs));
+ attrsOp.setArgAttrsAttr(getArrayAttr(argAttrs));
}
+ // Convert the result attributes.
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
if (!llvmResAttr.hasAttributes())
return;
- DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder);
- resAttr = getArrayAttr({resAttrs});
-}
-
-void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
- CallOpInterface callOp,
- OpBuilder &builder) {
- ArrayAttr argsAttr, resAttr;
- convertParameterAttributes(call, argsAttr, resAttr, builder);
- callOp.setArgAttrsAttr(argsAttr);
- callOp.setResAttrsAttr(resAttr);
+ DictionaryAttr resAttrs = convertArgOrResultAttrSet(llvmResAttr);
+ attrsOp.setResAttrsAttr(getArrayAttr({resAttrs}));
}
template <typename Op>
@@ -2893,7 +2901,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
builder, loc, func->getName(), functionType,
convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv);
- convertParameterAttributes(func, funcOp, builder);
+ convertArgAndResultAttrs(func, funcOp);
if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func))
funcOp.setPersonalityAttr(personality);
@@ -3200,5 +3208,6 @@ OwningOpRef<ModuleOp> mlir::translateLLVMIRToModule(
if (failed(moduleImport.convertIFuncs()))
return {};
moduleImport.convertTargetTriple();
+ moduleImport.convertModuleLevelAsm();
return module;
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b997e55..b3a06e2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1758,6 +1758,48 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
return attrBuilder;
}
+LogicalResult ModuleTranslation::convertArgAndResultAttrs(
+ ArgAndResultAttrsOpInterface attrsOp, llvm::CallBase *call,
+ ArrayRef<unsigned> immArgPositions) {
+ // Convert the argument attributes.
+ if (ArrayAttr argAttrsArray = attrsOp.getArgAttrsAttr()) {
+ unsigned argAttrIdx = 0;
+ llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(),
+ immArgPositions.end());
+ for (unsigned argIdx : llvm::seq<unsigned>(call->arg_size())) {
+ if (argAttrIdx >= argAttrsArray.size())
+ break;
+ // Skip immediate arguments (they have no entries in argAttrsArray).
+ if (immArgPositionsSet.contains(argIdx))
+ continue;
+ // Skip empty argument attributes.
+ auto argAttrs = cast<DictionaryAttr>(argAttrsArray[argAttrIdx++]);
+ if (argAttrs.empty())
+ continue;
+ // Convert and add attributes to the call instruction.
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(attrsOp->getLoc(), argAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addParamAttrs(argIdx, *attrBuilder);
+ }
+ }
+
+ // Convert the result attributes.
+ if (ArrayAttr resAttrsArray = attrsOp.getResAttrsAttr()) {
+ if (!resAttrsArray.empty()) {
+ auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(attrsOp->getLoc(), resAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addRetAttrs(*attrBuilder);
+ }
+ }
+
+ return success();
+}
+
FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(Location loc,
DictionaryAttr paramAttrs) {
@@ -2276,6 +2318,25 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
llvmModule->setTargetTriple(
llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue()));
+ if (auto asmAttr = m->getDiscardableAttr(
+ LLVM::LLVMDialect::getModuleLevelAsmAttrName())) {
+ auto asmArrayAttr = dyn_cast<ArrayAttr>(asmAttr);
+ if (!asmArrayAttr) {
+ m->emitError("expected an array attribute for a module level asm");
+ return nullptr;
+ }
+
+ for (Attribute elt : asmArrayAttr) {
+ auto asmStrAttr = dyn_cast<StringAttr>(elt);
+ if (!asmStrAttr) {
+ m->emitError(
+ "expected a string attribute for each entry of a module level asm");
+ return nullptr;
+ }
+ llvmModule->appendModuleInlineAsm(asmStrAttr.getValue());
+ }
+ }
+
return llvmModule;
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index e5934bb..88931b5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
}
- // Block decoration does not affect spirv.struct type, but is still stored
- // for verification.
- // TODO: Update StructType to contain this information since
- // it is needed for many validation rules.
decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
break;
case spirv::Decoration::Location:
@@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
if (failed(structType.trySetBody(
deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
- deferredStructIt->memberDecorationsInfo)))
+ deferredStructIt->memberDecorationsInfo,
+ deferredStructIt->structDecorationsInfo)))
return failure();
deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
@@ -1203,24 +1200,37 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
+ if (decorations.count(operands[0])) {
+ NamedAttrList &allDecorations = decorations[operands[0]];
+ for (NamedAttribute &decorationAttr : allDecorations) {
+ std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
+ llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
+ assert(decoration.has_value());
+ structDecorationsInfo.emplace_back(decoration.value(),
+ decorationAttr.getValue());
+ }
+ }
+
uint32_t structID = operands[0];
std::string structIdentifier = nameMap.lookup(structID).str();
if (structIdentifier.empty()) {
assert(unresolvedMemberTypes.empty() &&
"didn't expect unresolved member types");
- typeMap[structID] =
- spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
+ typeMap[structID] = spirv::StructType::get(
+ memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
} else {
auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
typeMap[structID] = structTy;
if (!unresolvedMemberTypes.empty())
- deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
- memberTypes, offsetInfo,
- memberDecorationsInfo});
+ deferredStructTypesInfos.push_back(
+ {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
+ memberDecorationsInfo, structDecorationsInfo});
else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationsInfo)))
+ memberDecorationsInfo,
+ structDecorationsInfo)))
return failure();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 20482bd..db1cc3f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -95,6 +95,7 @@ struct DeferredStructTypeInfo {
SmallVector<Type, 4> memberTypes;
SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
};
/// A struct that collects the info needed to materialize/emit a
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 6aab6bb..737f296 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -19,7 +19,6 @@
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
@@ -319,6 +318,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
case spirv::Decoration::Constant:
+ case spirv::Decoration::Block:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
@@ -447,6 +447,19 @@ LogicalResult Serializer::processType(Location loc, Type type,
LogicalResult
Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
SetVector<StringRef> &serializationCtx) {
+
+ // Map unsigned integer types to singless integer types.
+ // This is needed otherwise the generated spirv assembly will contain
+ // twice a type declaration (like OpTypeInt 32 0) which is no permitted and
+ // such module fails validation. Indeed at MLIR level the two types are
+ // different and lookup in the cache below misses.
+ // Note: This conversion needs to happen here before the type is looked up in
+ // the cache.
+ if (type.isUnsignedInteger()) {
+ type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(),
+ IntegerType::SignednessSemantics::Signless);
+ }
+
typeID = getTypeID(type);
if (typeID)
return success();
@@ -618,11 +631,16 @@ LogicalResult Serializer::prepareBasicType(
operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
operands.push_back(pointeeTypeID);
+ // TODO: Now struct decorations are supported this code may not be
+ // necessary. However, it is left to support backwards compatibility.
+ // Ideally, Block decorations should be inserted when converting to SPIR-V.
if (isInterfaceStructPtrType(ptrType)) {
- if (failed(emitDecoration(getTypeID(pointeeStruct),
- spirv::Decoration::Block)))
- return emitError(loc, "cannot decorate ")
- << pointeeStruct << " with Block decoration";
+ auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
+ if (!structType.hasDecoration(spirv::Decoration::Block))
+ if (failed(emitDecoration(getTypeID(pointeeStruct),
+ spirv::Decoration::Block)))
+ return emitError(loc, "cannot decorate ")
+ << pointeeStruct << " with Block decoration";
}
return success();
@@ -692,6 +710,20 @@ LogicalResult Serializer::prepareBasicType(
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
+ structType.getStructDecorations(structDecorations);
+
+ for (spirv::StructType::StructDecorationInfo &structDecoration :
+ structDecorations) {
+ if (failed(processDecorationAttr(loc, resultID,
+ structDecoration.decoration,
+ structDecoration.decorationValue))) {
+ return emitError(loc, "cannot decorate struct ")
+ << structType << " with "
+ << stringifyDecoration(structDecoration.decoration);
+ }
+ }
+
typeEnum = spirv::Opcode::OpTypeStruct;
if (structType.isIdentified())
@@ -926,6 +958,25 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
} else {
return 0;
}
+ } else if (isa<spirv::TensorArmType>(constType)) {
+ numberOfConstituents = shapedType.getNumElements();
+ operands.reserve(numberOfConstituents + 2);
+ for (int i = 0; i < numberOfConstituents; ++i) {
+ uint32_t elementID = 0;
+ if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
+ elementID =
+ elementType.isInteger(1)
+ ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
+ : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
+ }
+ if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
+ elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
+ }
+ if (!elementID) {
+ return 0;
+ }
+ operands.push_back(elementID);
+ }
} else {
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {