diff options
author | Steffi Stumpos <stumposs12@gmail.com> | 2024-07-09 15:04:33 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-09 14:04:33 -0700 |
commit | 4a010799317dfe19758477f693968fc594c1895d (patch) | |
tree | 6e37ecb47c54fdb7187b33c1074dbfd560bccfc0 | |
parent | 8d5ba7598ace4c195e673aaab2b2ccadff3ad48f (diff) | |
download | llvm-4a010799317dfe19758477f693968fc594c1895d.zip llvm-4a010799317dfe19758477f693968fc594c1895d.tar.gz llvm-4a010799317dfe19758477f693968fc594c1895d.tar.bz2 |
Expose Tail Kind Call to MLIR (#98080)
I would like to mark a call op in LLVM dialect as Musttail. The calling
convention attribute only exposes Tail, not Musttail. I noticed that the
CallInst of LLVM has an additional field to specify the flavor of tail
call kind. I bubbled this up to the LLVM dialect by adding another
attribute that maps to LLVM::CallInst::TailCallKind.
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 9 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h | 1 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 29 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 3 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 24 | ||||
-rw-r--r-- | mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 1 | ||||
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleImport.cpp | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/LLVMIR/roundtrip.mlir | 38 | ||||
-rw-r--r-- | mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir | 39 | ||||
-rw-r--r-- | mlir/test/Target/LLVMIR/Import/tail-kind.ll | 35 |
10 files changed, 174 insertions, 7 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index b05366d..25a6ee2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -1077,4 +1077,13 @@ def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">; /// Folded into from LLVM::ZeroOp. def LLVM_ZeroAttr : LLVM_Attr<"Zero", "zero">; +//===----------------------------------------------------------------------===// +// TailCallKindAttr +//===----------------------------------------------------------------------===// + +def TailCallKindAttr : LLVM_Attr<"TailCallKind", "tailcallkind"> { + let parameters = (ins "TailCallKind":$TailCallKind); + let assemblyFormat = "`<` $TailCallKind `>`"; +} + #endif // LLVMIR_ATTRDEFS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h index 3a93be2..3ede857 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h @@ -89,6 +89,7 @@ public: // TODO: this shouldn't be needed after we unify the attribute generation, i.e. // --gen-attr-* and --gen-attrdef-*. using cconv::CConv; +using tailcallkind::TailCallKind; using linkage::Linkage; } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td index f8e8500..f41a97f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td @@ -280,6 +280,35 @@ def CConv : DialectAttr< } //===----------------------------------------------------------------------===// +// TailCallKind +//===----------------------------------------------------------------------===// + +def TailCallKindNone : LLVM_EnumAttrCase<"None", "none", "TCK_None", 0>; +def TailCallKindTail : LLVM_EnumAttrCase<"Tail", "tail", "TCK_Tail", 1>; +def TailCallKindMustTail : LLVM_EnumAttrCase<"MustTail", "musttail", "TCK_MustTail", 2>; +def TailCallKindNoTailCall : LLVM_EnumAttrCase<"NoTail", "notail", "TCK_NoTail", 3>; + +def TailCallKindEnum : LLVM_EnumAttr< + "TailCallKind", + "::llvm::CallInst::TailCallKind", + "Tail Call Kind", + [TailCallKindNone, TailCallKindNoTailCall, + TailCallKindMustTail, TailCallKindTail]> { + let cppNamespace = "::mlir::LLVM::tailcallkind"; +} + +def TailCallKind : DialectAttr< + LLVM_Dialect, + CPred<"::llvm::isa<::mlir::LLVM::TailCallKindAttr>($_self)">, + "LLVM Calling Convention specification"> { + let storageType = "::mlir::LLVM::TailCallKindAttr"; + let returnType = "::mlir::LLVM::tailcallkind::TailCallKind"; + let convertFromStorage = "$_self.getTailCallKind()"; + let constBuilderCall = + "::mlir::LLVM::TailCallKindAttr::get($_builder.getContext(), $0)"; +} + +//===----------------------------------------------------------------------===// // DIEmissionKind //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 54f38c9..65dfcf9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -650,7 +650,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags, OptionalAttr<DenseI32ArrayAttr>:$branch_weights, - DefaultValuedAttr<CConv, "CConv::C">:$CConv); + DefaultValuedAttr<CConv, "CConv::C">:$CConv, + DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. let arguments = !con(args, aliasAttrs); let results = (outs Optional<LLVM_Type>:$result); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index a01c4ee4..9372caf 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -44,6 +44,7 @@ using namespace mlir; using namespace mlir::LLVM; using mlir::LLVM::cconv::getMaxEnumValForCConv; using mlir::LLVM::linkage::getMaxEnumValForLinkage; +using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind; #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" @@ -197,6 +198,7 @@ struct EnumTraits {}; REGISTER_ENUM_TYPE(Linkage); REGISTER_ENUM_TYPE(UnnamedAddr); REGISTER_ENUM_TYPE(CConv); +REGISTER_ENUM_TYPE(TailCallKind); REGISTER_ENUM_TYPE(Visibility); } // namespace @@ -974,7 +976,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, build(builder, state, results, TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)), callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, - /*CConv=*/nullptr, + /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -997,7 +999,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, build(builder, state, getCallOpResultTypes(calleeType), TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, /*CConv=*/nullptr, - /*access_groups=*/nullptr, + /*TailCallKind=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1006,7 +1008,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, build(builder, state, getCallOpResultTypes(calleeType), TypeAttr::get(calleeType), /*callee=*/nullptr, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, - /*CConv=*/nullptr, + /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1017,7 +1019,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, build(builder, state, getCallOpResultTypes(calleeType), TypeAttr::get(calleeType), SymbolRefAttr::get(func), args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, - /*CConv=*/nullptr, + /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1180,6 +1182,9 @@ void CallOp::print(OpAsmPrinter &p) { if (getCConv() != LLVM::CConv::C) p << stringifyCConv(getCConv()) << ' '; + if(getTailCallKind() != LLVM::TailCallKind::None) + p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' '; + // Print the direct callee if present as a function attribute, or an indirect // callee (first operand) otherwise. if (isDirect) @@ -1194,7 +1199,8 @@ void CallOp::print(OpAsmPrinter &p) { p << " vararg(" << calleeType << ")"; p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), - {getCConvAttrName(), "callee", "callee_type"}); + {getCConvAttrName(), "callee", "callee_type", + getTailCallKindAttrName()}); p << " : "; if (!isDirect) @@ -1262,7 +1268,7 @@ static ParseResult parseOptionalCallFuncPtr( return success(); } -// <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use) +// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use) // `(` ssa-use-list `)` // ( `vararg(` var-arg-func-type `)` )? // attribute-dict? `:` (type `,`)? function-type @@ -1277,6 +1283,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>( parser, result, LLVM::CConv::C))); + result.addAttribute( + getTailCallKindAttrName(result.name), + TailCallKindAttr::get(parser.getContext(), + parseOptionalLLVMKeyword<TailCallKind>( + parser, result, LLVM::TailCallKind::None))); + // Parse a function pointer for indirect calls. if (parseOptionalCallFuncPtr(parser, operands)) return failure(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index f144c71..3d6dd12 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -218,6 +218,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, operandsRef.drop_front()); } call->setCallingConv(convertCConvToLLVM(callOp.getCConv())); + call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind())); moduleTranslation.setAccessGroupsMetadata(callOp, call); moduleTranslation.setAliasScopeMetadata(callOp, call); moduleTranslation.setTBAAMetadata(callOp, call); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 0c8b329..9915576 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1468,6 +1468,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { callOp = builder.create<CallOp>(loc, funcTy, operands); } callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv())); + callOp.setTailCallKind( + convertTailCallKindFromLLVM(callInst->getTailCallKind())); setFastmathFlagsAttr(inst, callOp); if (!callInst->getType()->isVoidTy()) mapValue(inst, callOp.getResult()); diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 2386dde..ca9748a 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -673,3 +673,41 @@ llvm.func @experimental_constrained_fptrunc(%in: f64) { %4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32 llvm.return } + +// CHECK: llvm.func @tail_call_target() -> i32 +llvm.func @tail_call_target() -> i32 + +// CHECK-LABEL: @test_none +llvm.func @test_none() -> i32 { + // CHECK-NEXT: llvm.call @tail_call_target() : () -> i32 + %0 = llvm.call none @tail_call_target() : () -> i32 + llvm.return %0 : i32 +} + +// CHECK-LABEL: @test_default +llvm.func @test_default() -> i32 { + // CHECK-NEXT: llvm.call @tail_call_target() : () -> i32 + %0 = llvm.call @tail_call_target() : () -> i32 + llvm.return %0 : i32 +} + +// CHECK-LABEL: @test_musttail +llvm.func @test_musttail() -> i32 { + // CHECK-NEXT: llvm.call musttail @tail_call_target() : () -> i32 + %0 = llvm.call musttail @tail_call_target() : () -> i32 + llvm.return %0 : i32 +} + +// CHECK-LABEL: @test_tail +llvm.func @test_tail() -> i32 { + // CHECK-NEXT: llvm.call tail @tail_call_target() : () -> i32 + %0 = llvm.call tail @tail_call_target() : () -> i32 + llvm.return %0 : i32 +} + +// CHECK-LABEL: @test_notail +llvm.func @test_notail() -> i32 { + // CHECK-NEXT: llvm.call notail @tail_call_target() : () -> i32 + %0 = llvm.call notail @tail_call_target() : () -> i32 + llvm.return %0 : i32 +} diff --git a/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir b/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir new file mode 100644 index 0000000..73a6aa2 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/tail-call-kinds.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +// CHECK: declare i32 @foo() +llvm.func @foo() -> i32 + +// CHECK-LABEL: @test_none +llvm.func @test_none() -> i32 { + // CHECK-NEXT: call i32 @foo() + %0 = llvm.call none @foo() : () -> i32 + llvm.return %0 : i32 +} + +// CHECK-LABEL: @test_default +llvm.func @test_default() -> i32 { + // CHECK-NEXT: call i32 @foo() + %0 = llvm.call @foo() : () -> i32 + llvm.return %0 : i32 +} + +// CHECK-LABEL: @test_musttail +llvm.func @test_musttail() -> i32 { + // CHECK-NEXT: musttail call i32 @foo() + %0 = llvm.call musttail @foo() : () -> i32 + llvm.return %0 : i32 +} + +// CHECK-LABEL: @test_tail +llvm.func @test_tail() -> i32 { + // CHECK-NEXT: tail call i32 @foo() + %0 = llvm.call tail @foo() : () -> i32 + llvm.return %0 : i32 +} + +// CHECK-LABEL: @test_notail +llvm.func @test_notail() -> i32 { + // CHECK-NEXT: notail call i32 @foo() + %0 = llvm.call notail @foo() : () -> i32 + llvm.return %0 : i32 +} diff --git a/mlir/test/Target/LLVMIR/Import/tail-kind.ll b/mlir/test/Target/LLVMIR/Import/tail-kind.ll new file mode 100644 index 0000000..608ae40 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/tail-kind.ll @@ -0,0 +1,35 @@ +; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s + +; CHECK: llvm.func @tailkind() +declare void @tailkind() + +; CHECK-LABEL: @call_tailkind +define void @call_tailkind() { + ; CHECK: llvm.call musttail @tailkind() + musttail call void @tailkind() + ret void +} + +; // ----- + +; CHECK: llvm.func @tailkind() +declare void @tailkind() + +; CHECK-LABEL: @call_tailkind +define void @call_tailkind() { + ; CHECK: llvm.call tail @tailkind() + tail call void @tailkind() + ret void +} + +; // ----- + +; CHECK: llvm.func @tailkind() +declare void @tailkind() + +; CHECK-LABEL: @call_tailkind +define void @call_tailkind() { + ; CHECK: llvm.call notail @tailkind() + notail call void @tailkind() + ret void +} |