aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer/CodeGen/TargetRewrite.cpp')
-rw-r--r--flang/lib/Optimizer/CodeGen/TargetRewrite.cpp24
1 files changed, 20 insertions, 4 deletions
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index ac285b5..0776346 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -872,6 +872,14 @@ public:
}
}
+ // Count the arguments that must stay at the end of the argument list: for a
+ // gpu.func these are its workgroup and private attributions, otherwise none.
+ unsigned trailingArgs = 0;
+ if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) {
+ trailingArgs =
+ func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions();
+ }
+
// Convert return value(s)
for (auto ty : funcTy.getResults())
llvm::TypeSwitch<mlir::Type>(ty)
@@ -981,6 +989,16 @@ public:
}
}
+ // Append the argument when there are no trailing arguments; otherwise insert
+ // it just before the trailing arguments so that they remain last in the list.
+ auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) {
+ unsigned inputIndex = func.front().getArguments().size() - trailingArgs; // index of the first trailing argument
+ auto newArg = trailingArgs == 0
+ ? func.front().addArgument(ty, loc)
+ : func.front().insertArgument(inputIndex, ty, loc);
+ return newArg;
+ };
+
if (!func.empty()) {
// If the function has a body, then apply the fixups to the arguments and
// return ops as required. These fixups are done in place.
@@ -1117,8 +1135,7 @@ public:
// original arguments. (Boxchar arguments.)
auto newBufArg =
func.front().insertArgument(fixup.index, fixupType, loc);
- auto newLenArg =
- func.front().addArgument(trailingTys[fixup.second], loc);
+ auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
auto boxTy = oldArgTys[fixup.index - offset];
rewriter->setInsertionPointToStart(&func.front());
auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg,
@@ -1133,8 +1150,7 @@ public:
// appended after all the original arguments.
auto newProcPointerArg =
func.front().insertArgument(fixup.index, fixupType, loc);
- auto newLenArg =
- func.front().addArgument(trailingTys[fixup.second], loc);
+ auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
auto tupleType = oldArgTys[fixup.index - offset];
rewriter->setInsertionPointToStart(&func.front());
fir::FirOpBuilder builder(*rewriter, getModule());