diff options
Diffstat (limited to 'mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp')
-rw-r--r-- | mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp | 66 |
1 files changed, 60 insertions, 6 deletions
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 90c3d63..1a73456 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -347,10 +347,14 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) { if (c->isString()) return b.getStringAttr(c->getAsString()); if (auto *c = dyn_cast<llvm::ConstantFP>(value)) { - if (c->getType()->isDoubleTy()) - return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF()); - if (c->getType()->isFloatingPointTy()) - return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF()); + auto *type = c->getType(); + FloatType floatTy; + if (type->isBFloatTy()) + floatTy = FloatType::getBF16(context); + else + floatTy = getDLFloatType(*context, type->getScalarSizeInBits()); + assert(floatTy && "unsupported floating point type"); + return b.getFloatAttr(floatTy, c->getValueAPF()); } if (auto *f = dyn_cast<llvm::Function>(value)) return SymbolRefAttr::get(b.getContext(), f->getName()); @@ -607,7 +611,7 @@ static StringRef lookupOperationNameFromOpcode(unsigned opcode) { // FIXME: cleanuppad // FIXME: catchpad // ICmp is handled specially. - // FIXME: fcmp + // FCmp is handled specially. // PHI is handled specially. INST(Freeze, Freeze), INST(Call, Call), // FIXME: select @@ -649,7 +653,47 @@ static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) { case llvm::CmpInst::Predicate::ICMP_UGE: return LLVM::ICmpPredicate::uge; } - llvm_unreachable("incorrect comparison predicate"); + llvm_unreachable("incorrect integer comparison predicate"); +} + +static FCmpPredicate getFCmpPredicate(llvm::CmpInst::Predicate p) { + switch (p) { + default: + llvm_unreachable("incorrect comparison predicate"); + case llvm::CmpInst::Predicate::FCMP_FALSE: + return LLVM::FCmpPredicate::_false; + case llvm::CmpInst::Predicate::FCMP_TRUE: + return LLVM::FCmpPredicate::_true; + case llvm::CmpInst::Predicate::FCMP_OEQ: + return LLVM::FCmpPredicate::oeq; + case llvm::CmpInst::Predicate::FCMP_ONE: + return LLVM::FCmpPredicate::one; + case llvm::CmpInst::Predicate::FCMP_OLT: + return LLVM::FCmpPredicate::olt; + case llvm::CmpInst::Predicate::FCMP_OLE: + return LLVM::FCmpPredicate::ole; + case llvm::CmpInst::Predicate::FCMP_OGT: + return LLVM::FCmpPredicate::ogt; + case llvm::CmpInst::Predicate::FCMP_OGE: + return LLVM::FCmpPredicate::oge; + case llvm::CmpInst::Predicate::FCMP_ORD: + return LLVM::FCmpPredicate::ord; + case llvm::CmpInst::Predicate::FCMP_ULT: + return LLVM::FCmpPredicate::ult; + case llvm::CmpInst::Predicate::FCMP_ULE: + return LLVM::FCmpPredicate::ule; + case llvm::CmpInst::Predicate::FCMP_UGT: + return LLVM::FCmpPredicate::ugt; + case llvm::CmpInst::Predicate::FCMP_UGE: + return LLVM::FCmpPredicate::uge; + case llvm::CmpInst::Predicate::FCMP_UNO: + return LLVM::FCmpPredicate::uno; + case llvm::CmpInst::Predicate::FCMP_UEQ: + return LLVM::FCmpPredicate::ueq; + case llvm::CmpInst::Predicate::FCMP_UNE: + return LLVM::FCmpPredicate::une; + } + llvm_unreachable("incorrect floating point comparison predicate"); } static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) { @@ -774,6 +818,16 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { rhs); return success(); } + case llvm::Instruction::FCmp: { + Value lhs = processValue(inst->getOperand(0)); + Value rhs = processValue(inst->getOperand(1)); + if (!lhs || !rhs) + return failure(); + instMap[inst] = b.create<FCmpOp>( + loc, b.getI1Type(), + getFCmpPredicate(cast<llvm::FCmpInst>(inst)->getPredicate()), lhs, rhs); + return success(); + } case llvm::Instruction::Br: { auto *brInst = cast<llvm::BranchInst>(inst); OperationState state(loc, |