aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp')
-rw-r--r--mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp66
1 files changed, 60 insertions, 6 deletions
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 90c3d63..1a73456 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -347,10 +347,14 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
if (c->isString())
return b.getStringAttr(c->getAsString());
if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
- if (c->getType()->isDoubleTy())
- return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
- if (c->getType()->isFloatingPointTy())
- return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
+ auto *type = c->getType();
+ FloatType floatTy;
+ if (type->isBFloatTy())
+ floatTy = FloatType::getBF16(context);
+ else
+ floatTy = getDLFloatType(*context, type->getScalarSizeInBits());
+ assert(floatTy && "unsupported floating point type");
+ return b.getFloatAttr(floatTy, c->getValueAPF());
}
if (auto *f = dyn_cast<llvm::Function>(value))
return SymbolRefAttr::get(b.getContext(), f->getName());
@@ -607,7 +611,7 @@ static StringRef lookupOperationNameFromOpcode(unsigned opcode) {
// FIXME: cleanuppad
// FIXME: catchpad
// ICmp is handled specially.
- // FIXME: fcmp
+ // FCmp is handled specially.
// PHI is handled specially.
INST(Freeze, Freeze), INST(Call, Call),
// FIXME: select
@@ -649,7 +653,47 @@ static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
case llvm::CmpInst::Predicate::ICMP_UGE:
return LLVM::ICmpPredicate::uge;
}
- llvm_unreachable("incorrect comparison predicate");
+ llvm_unreachable("incorrect integer comparison predicate");
+}
+
+static FCmpPredicate getFCmpPredicate(llvm::CmpInst::Predicate p) {
+ switch (p) {
+ default:
+ llvm_unreachable("incorrect comparison predicate");
+ case llvm::CmpInst::Predicate::FCMP_FALSE:
+ return LLVM::FCmpPredicate::_false;
+ case llvm::CmpInst::Predicate::FCMP_TRUE:
+ return LLVM::FCmpPredicate::_true;
+ case llvm::CmpInst::Predicate::FCMP_OEQ:
+ return LLVM::FCmpPredicate::oeq;
+ case llvm::CmpInst::Predicate::FCMP_ONE:
+ return LLVM::FCmpPredicate::one;
+ case llvm::CmpInst::Predicate::FCMP_OLT:
+ return LLVM::FCmpPredicate::olt;
+ case llvm::CmpInst::Predicate::FCMP_OLE:
+ return LLVM::FCmpPredicate::ole;
+ case llvm::CmpInst::Predicate::FCMP_OGT:
+ return LLVM::FCmpPredicate::ogt;
+ case llvm::CmpInst::Predicate::FCMP_OGE:
+ return LLVM::FCmpPredicate::oge;
+ case llvm::CmpInst::Predicate::FCMP_ORD:
+ return LLVM::FCmpPredicate::ord;
+ case llvm::CmpInst::Predicate::FCMP_ULT:
+ return LLVM::FCmpPredicate::ult;
+ case llvm::CmpInst::Predicate::FCMP_ULE:
+ return LLVM::FCmpPredicate::ule;
+ case llvm::CmpInst::Predicate::FCMP_UGT:
+ return LLVM::FCmpPredicate::ugt;
+ case llvm::CmpInst::Predicate::FCMP_UGE:
+ return LLVM::FCmpPredicate::uge;
+ case llvm::CmpInst::Predicate::FCMP_UNO:
+ return LLVM::FCmpPredicate::uno;
+ case llvm::CmpInst::Predicate::FCMP_UEQ:
+ return LLVM::FCmpPredicate::ueq;
+ case llvm::CmpInst::Predicate::FCMP_UNE:
+ return LLVM::FCmpPredicate::une;
+ }
+ llvm_unreachable("incorrect floating point comparison predicate");
}
static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
@@ -774,6 +818,16 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
rhs);
return success();
}
+ case llvm::Instruction::FCmp: {
+ Value lhs = processValue(inst->getOperand(0));
+ Value rhs = processValue(inst->getOperand(1));
+ if (!lhs || !rhs)
+ return failure();
+ instMap[inst] = b.create<FCmpOp>(
+ loc, b.getI1Type(),
+ getFCmpPredicate(cast<llvm::FCmpInst>(inst)->getPredicate()), lhs, rhs);
+ return success();
+ }
case llvm::Instruction::Br: {
auto *brInst = cast<llvm::BranchInst>(inst);
OperationState state(loc,