aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/LLVMIR/IR
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/LLVMIR/IR')
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp3
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp154
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h27
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp54
4 files changed, 238 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5d08ccc..7ca09d9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -29,6 +29,8 @@
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/Error.h"
+#include "LLVMDialectBytecode.h"
+
#include <numeric>
#include <optional>
@@ -4237,6 +4239,7 @@ void LLVMDialect::initialize() {
// Support unknown operations because not all LLVM operations are registered.
allowUnknownOperations();
declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
+ detail::addBytecodeInterface(this);
}
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
new file mode 100644
index 0000000..41d1f80
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
@@ -0,0 +1,154 @@
+//===- LLVMDialectBytecode.cpp - LLVM Bytecode Implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LLVMDialectBytecode.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <type_traits>
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+
+// Provide some forward declarations of the functions that will be generated by
+// the include below.
+static void write(DIExpressionElemAttr attribute,
+ DialectBytecodeWriter &writer);
+static LogicalResult writeAttribute(Attribute attribute,
+ DialectBytecodeWriter &writer);
+
+//===--------------------------------------------------------------------===//
+// Optional ArrayRefs
+//
+// Note that both the writer and reader functions consider attributes to be
+// optional. This is because the attribute may be present or empty.
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalArrayRef(DialectBytecodeWriter &writer,
+ ArrayRef<EntryTy> storage) {
+ if (storage.empty()) {
+ writer.writeOwnedBool(false);
+ return;
+ }
+
+ writer.writeOwnedBool(true);
+ writer.writeList(storage, [&](EntryTy val) {
+ if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+ (void)writer.writeOptionalAttribute(val);
+ } else if constexpr (std::is_integral_v<EntryTy>) {
+ (void)writer.writeVarInt(val);
+ } else {
+ static_assert(true, "EntryTy not supported");
+ }
+ });
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalArrayRef(DialectBytecodeReader &reader,
+ SmallVectorImpl<EntryTy> &storage) {
+ bool isPresent = false;
+ if (failed(reader.readBool(isPresent)))
+ return failure();
+ // Nothing to do here, the array is empty.
+ if (!isPresent)
+ return success();
+
+ auto readEntry = [&]() -> FailureOr<EntryTy> {
+ EntryTy temp;
+ if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+ if (succeeded(reader.readOptionalAttribute(temp)))
+ return temp;
+ } else if constexpr (std::is_integral_v<EntryTy>) {
+ if (succeeded(reader.readVarInt(temp)))
+ return temp;
+ } else {
+ static_assert(true, "EntryTy not supported");
+ }
+ return failure();
+ };
+
+ return reader.readList(storage, readEntry);
+}
+
+//===--------------------------------------------------------------------===//
+// Optional integral types
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalInt(DialectBytecodeWriter &writer,
+ std::optional<EntryTy> storage) {
+ static_assert(std::is_integral_v<EntryTy>,
+ "EntryTy must be an integral type");
+ EntryTy val = storage.value_or(0);
+ writer.writeVarIntWithFlag(val, storage.has_value());
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalInt(DialectBytecodeReader &reader,
+ std::optional<EntryTy> &storage) {
+ static_assert(std::is_integral_v<EntryTy>,
+ "EntryTy must be an integral type");
+ uint64_t result = 0;
+ bool flag = false;
+ if (failed(reader.readVarIntWithFlag(result, flag)))
+ return failure();
+ if (flag)
+ storage = static_cast<EntryTy>(result);
+ else
+ storage = std::nullopt;
+ return success();
+}
+
+//===--------------------------------------------------------------------===//
+// Tablegen generated bytecode functions
+//===--------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialectBytecode.cpp.inc"
+
+//===--------------------------------------------------------------------===//
+// LLVMDialectBytecodeInterface
+//===--------------------------------------------------------------------===//
+
+/// This class implements the bytecode interface for the LLVM dialect.
+struct LLVMDialectBytecodeInterface : public BytecodeDialectInterface {
+ LLVMDialectBytecodeInterface(Dialect *dialect)
+ : BytecodeDialectInterface(dialect) {}
+
+ // Attributes
+ Attribute readAttribute(DialectBytecodeReader &reader) const override {
+ return ::readAttribute(getContext(), reader);
+ }
+
+ LogicalResult writeAttribute(Attribute attr,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeAttribute(attr, writer);
+ }
+
+ // Types
+ Type readType(DialectBytecodeReader &reader) const override {
+ return ::readType(getContext(), reader);
+ }
+
+ LogicalResult writeType(Type type,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeType(type, writer);
+ }
+};
+} // namespace
+
+void LLVM::detail::addBytecodeInterface(LLVMDialect *dialect) {
+ dialect->addInterfaces<LLVMDialectBytecodeInterface>();
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
new file mode 100644
index 0000000..1a17cb4
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
@@ -0,0 +1,27 @@
+//===- LLVMDialectBytecode.h - LLVM Bytecode Implementation -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines hooks into the LLVM dialect bytecode
+// implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+#define LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+
+namespace mlir::LLVM {
+class LLVMDialect;
+
+namespace detail {
+/// Add the interfaces necessary for encoding the LLVM dialect components in
+/// bytecode.
+void addBytecodeInterface(LLVMDialect *dialect);
+} // namespace detail
+} // namespace mlir::LLVM
+
+#endif // LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5edcc40b..55fe0d9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f32x2 to f4x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2047,6 +2058,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getA()));
+ args.push_back(mt.lookupValue(op.getB()));
+
+ bool hasRelu = op.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+
+ return {intId, std::move(args)};
+}
+
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
@@ -2306,6 +2334,32 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
+/// Verify the range attribute satisfies LLVM ConstantRange constructor
+/// requirements for NVVM SpecialRangeableRegisterOp.
+static LogicalResult
+verifyConstantRangeAttr(Operation *op,
+ std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
+ if (!rangeAttr)
+ return success();
+
+ const llvm::APInt &lower = rangeAttr->getLower();
+ const llvm::APInt &upper = rangeAttr->getUpper();
+
+ // Check LLVM ConstantRange constructor condition
+ if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
+ unsigned bitWidth = lower.getBitWidth();
+ llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
+ llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
+ return op->emitOpError(
+ "invalid range attribute: Lower == Upper, but they aren't min (")
+ << llvm::toString(minVal, 10, false) << ") or max ("
+ << llvm::toString(maxVal, 10, false)
+ << ") value! This is an invalid constant range.";
+ }
+
+ return success();
+}
+
static llvm::Value *getAsPackedI32(llvm::Value *arg,
llvm::IRBuilderBase &builder) {
return builder.CreateBitCast(arg,