aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp')
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp54
1 files changed, 54 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5edcc40b..55fe0d9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f32x2 to f4x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2047,6 +2058,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getA()));
+ args.push_back(mt.lookupValue(op.getB()));
+
+ bool hasRelu = op.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+
+ return {intId, std::move(args)};
+}
+
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
@@ -2306,6 +2334,32 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
+/// Verify the range attribute satisfies LLVM ConstantRange constructor
+/// requirements for NVVM SpecialRangeableRegisterOp.
+static LogicalResult
+verifyConstantRangeAttr(Operation *op,
+ std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
+ if (!rangeAttr)
+ return success();
+
+ const llvm::APInt &lower = rangeAttr->getLower();
+ const llvm::APInt &upper = rangeAttr->getUpper();
+
+ // Check LLVM ConstantRange constructor condition
+ if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
+ unsigned bitWidth = lower.getBitWidth();
+ llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
+ llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
+ return op->emitOpError(
+ "invalid range attribute: Lower == Upper, but they aren't min (")
+ << llvm::toString(minVal, 10, false) << ") or max ("
+ << llvm::toString(maxVal, 10, false)
+ << ") value! This is an invalid constant range.";
+ }
+
+ return success();
+}
+
static llvm::Value *getAsPackedI32(llvm::Value *arg,
llvm::IRBuilderBase &builder) {
return builder.CreateBitCast(arg,