aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2024-06-22 14:54:21 +0200
committerMatthias Springer <mspringer@nvidia.com>2024-06-23 11:13:03 +0200
commit0ae7616116a9e31171de4b7fb98c18b4c0c92b68 (patch)
tree70435cc2618afcb8e962f0f640021e6cfe8ddc06
parent70c8b9c24a7cf2b7c6e65675cbdb42a65ff668ba (diff)
downloadllvm-0ae7616116a9e31171de4b7fb98c18b4c0c92b68.zip
llvm-0ae7616116a9e31171de4b7fb98c18b4c0c92b68.tar.gz
llvm-0ae7616116a9e31171de4b7fb98c18b4c0c92b68.tar.bz2
[mlir][Conversion] `FuncToLLVM`: Simplify bare-pointer handling
Before this commit, there used to be a workaround in the `func.func`/`gpu.func` op lowering when the bare-pointer calling convention was enabled. This workaround "patched up" the argument materializations for memref arguments. This can be done directly in the argument materialization functions (as the TODOs in the code base indicate).
-rw-r--r--mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp53
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp29
-rw-r--r--mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp22
3 files changed, 17 insertions, 87 deletions
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 7442366..efb8046 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -268,55 +268,6 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
}
-/// Modifies the body of the function to construct the `MemRefDescriptor` from
-/// the bare pointer calling convention lowering of `memref` types.
-static void modifyFuncOpToUseBarePtrCallingConv(
- ConversionPatternRewriter &rewriter, Location loc,
- const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
- TypeRange oldArgTypes) {
- if (funcOp.getBody().empty())
- return;
-
- // Promote bare pointers from memref arguments to memref descriptors at the
- // beginning of the function so that all the memrefs in the function have a
- // uniform representation.
- Block *entryBlock = &funcOp.getBody().front();
- auto blockArgs = entryBlock->getArguments();
- assert(blockArgs.size() == oldArgTypes.size() &&
- "The number of arguments and types doesn't match");
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(entryBlock);
- for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
- BlockArgument arg = std::get<0>(it);
- Type argTy = std::get<1>(it);
-
- // Unranked memrefs are not supported in the bare pointer calling
- // convention. We should have bailed out before in the presence of
- // unranked memrefs.
- assert(!isa<UnrankedMemRefType>(argTy) &&
- "Unranked memref is not supported");
- auto memrefTy = dyn_cast<MemRefType>(argTy);
- if (!memrefTy)
- continue;
-
- // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
- // or unranked memref descriptor and replace placeholder with the last
- // instruction of the memref descriptor.
- // TODO: The placeholder is needed to avoid replacing barePtr uses in the
- // MemRef descriptor instructions. We may want to have a utility in the
- // rewriter to properly handle this use case.
- Location loc = funcOp.getLoc();
- auto placeholder = rewriter.create<LLVM::UndefOp>(
- loc, typeConverter.convertType(memrefTy));
- rewriter.replaceUsesOfBlockArgument(arg, placeholder);
-
- Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
- memrefTy, arg);
- rewriter.replaceOp(placeholder, {desc});
- }
-}
-
FailureOr<LLVM::LLVMFuncOp>
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
@@ -462,10 +413,6 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
newFuncOp);
}
- } else {
- modifyFuncOpToUseBarePtrCallingConv(
- rewriter, funcOp->getLoc(), converter, newFuncOp,
- llvm::cast<FunctionType>(funcOp.getFunctionType()).getInputs());
}
return newFuncOp;
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 3e6fcc0..6053e34 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -182,35 +182,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
&signatureConversion)))
return failure();
- // If bare memref pointers are being used, remap them back to memref
- descriptors. This must be done after signature conversion to get rid of the
- // unrealized casts.
- if (getTypeConverter()->getOptions().useBarePtrCallConv) {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
- for (const auto [idx, argTy] :
- llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
- auto memrefTy = dyn_cast<MemRefType>(argTy);
- if (!memrefTy)
- continue;
- assert(memrefTy.hasStaticShape() &&
- "Bare pointer conversion used with dynamically-shaped memrefs");
- // Use a placeholder when replacing uses of the memref argument to prevent
- // circular replacements.
- auto remapping = signatureConversion.getInputMapping(idx);
- assert(remapping && remapping->size == 1 &&
- "Type converter should produce 1-to-1 mapping for bare memrefs");
- BlockArgument newArg =
- llvmFuncOp.getBody().getArgument(remapping->inputNo);
- auto placeholder = rewriter.create<LLVM::UndefOp>(
- loc, getTypeConverter()->convertType(memrefTy));
- rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
- Value desc = MemRefDescriptor::fromStaticShape(
- rewriter, loc, *getTypeConverter(), memrefTy, newArg);
- rewriter.replaceOp(placeholder, {desc});
- }
- }
-
// Get memref type from function arguments and set the noalias to
// pointer arguments.
for (const auto [idx, argTy] :
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 3a01795..f5620a6 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,18 +159,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization(
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
Location loc) -> std::optional<Value> {
- if (inputs.size() == 1)
+ if (inputs.size() == 1) {
+ // Bare pointers are not supported for unranked memrefs because a
+ // memref descriptor cannot be built just from a bare pointer.
return std::nullopt;
+ }
return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
inputs);
});
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
- // TODO: bare ptr conversion could be handled here but we would need a way
- // to distinguish between FuncOp and other regions.
- if (inputs.size() == 1)
- return std::nullopt;
+ if (inputs.size() == 1) {
+ // This is a bare pointer. We allow bare pointers only for function entry
+ // blocks.
+ BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
+ if (!barePtr)
+ return std::nullopt;
+ Block *block = barePtr.getOwner();
+ if (!block->isEntryBlock() ||
+ !isa<FunctionOpInterface>(block->getParentOp()))
+ return std::nullopt;
+ return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
+ inputs[0]);
+ }
return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
});
// Add generic source and target materializations to handle cases where