diff options
Diffstat (limited to 'mlir/lib/Target/LLVMIR/ModuleImport.cpp')
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleImport.cpp | 72 |
1 files changed, 55 insertions, 17 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 6f56a17..77094d4 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1993,8 +1993,8 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst, /// Checks if `callType` and `calleeType` are compatible and can be represented /// in MLIR. static LogicalResult -verifyFunctionTypeCompatibility(LLVMFunctionType callType, - LLVMFunctionType calleeType) { +checkFunctionTypeCompatibility(LLVMFunctionType callType, + LLVMFunctionType calleeType) { if (callType.getReturnType() != calleeType.getReturnType()) return failure(); @@ -2020,7 +2020,9 @@ verifyFunctionTypeCompatibility(LLVMFunctionType callType, } FailureOr<LLVMFunctionType> -ModuleImport::convertFunctionType(llvm::CallBase *callInst) { +ModuleImport::convertFunctionType(llvm::CallBase *callInst, + bool &isIncompatibleCall) { + isIncompatibleCall = false; auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> { auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType); if (!funcTy) @@ -2043,11 +2045,14 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst) { if (failed(calleeType)) return failure(); - // Compare the types to avoid constructing illegal call/invoke operations. - if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) { + // Compare the types and notify users via `isIncompatibleCall` if they are not + // compatible. + if (failed(checkFunctionTypeCompatibility(*callType, *calleeType))) { + isIncompatibleCall = true; Location loc = translateLoc(callInst->getDebugLoc()); - return emitError(loc) << "incompatible call and callee types: " << *callType - << " and " << *calleeType; + emitWarning(loc) << "incompatible call and callee types: " << *callType + << " and " << *calleeType; + return callType; } return calleeType; @@ -2164,16 +2169,34 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { /*operand_attrs=*/nullptr) .getOperation(); } - FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst); + bool isIncompatibleCall; + FailureOr<LLVMFunctionType> funcTy = + convertFunctionType(callInst, isIncompatibleCall); if (failed(funcTy)) return failure(); - FlatSymbolRefAttr callee = convertCalleeName(callInst); - auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands); + FlatSymbolRefAttr callee = nullptr; + if (isIncompatibleCall) { + // Use an indirect call (in order to represent valid and verifiable LLVM + // IR). Build the indirect call by passing an empty `callee` operand and + // insert into `operands` to include the indirect call target. + FlatSymbolRefAttr calleeSym = convertCalleeName(callInst); + Value indirectCallVal = builder.create<LLVM::AddressOfOp>( + loc, LLVM::LLVMPointerType::get(context), calleeSym); + operands->insert(operands->begin(), indirectCallVal); + } else { + // Regular direct call using callee name. + callee = convertCalleeName(callInst); + } + CallOp callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands); + if (failed(convertCallAttributes(callInst, callOp))) return failure(); - // Handle parameter and result attributes. - convertParameterAttributes(callInst, callOp, builder); + + // Handle parameter and result attributes unless it's an incompatible + // call. + if (!isIncompatibleCall) + convertParameterAttributes(callInst, callOp, builder); return callOp.getOperation(); }(); @@ -2238,12 +2261,25 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { unwindArgs))) return failure(); - FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst); + bool isIncompatibleInvoke; + FailureOr<LLVMFunctionType> funcTy = + convertFunctionType(invokeInst, isIncompatibleInvoke); if (failed(funcTy)) return failure(); - FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst); - + FlatSymbolRefAttr calleeName = nullptr; + if (isIncompatibleInvoke) { + // Use an indirect invoke (in order to represent valid and verifiable LLVM + // IR). Build the indirect invoke by passing an empty `callee` operand and + // insert into `operands` to include the indirect invoke target. + FlatSymbolRefAttr calleeSym = convertCalleeName(invokeInst); + Value indirectInvokeVal = builder.create<LLVM::AddressOfOp>( + loc, LLVM::LLVMPointerType::get(context), calleeSym); + operands->insert(operands->begin(), indirectInvokeVal); + } else { + // Regular direct invoke using callee name. + calleeName = convertCalleeName(invokeInst); + } // Create the invoke operation. Normal destination block arguments will be // added later on to handle the case in which the operation result is // included in this list. @@ -2254,8 +2290,10 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { if (failed(convertInvokeAttributes(invokeInst, invokeOp))) return failure(); - // Handle parameter and result attributes. - convertParameterAttributes(invokeInst, invokeOp, builder); + // Handle parameter and result attributes unless it's an incompatible + // invoke. + if (!isIncompatibleInvoke) + convertParameterAttributes(invokeInst, invokeOp, builder); if (!invokeInst->getType()->isVoidTy()) mapValue(inst, invokeOp.getResults().front()); |