diff options
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 68 |
1 files changed, 42 insertions, 26 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index e79a02f..6a005e6 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -26,9 +26,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, SmallVector<LLVM::GlobalOp, 3> workgroupBuffers; workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); - for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { - BlockArgument attribution = en.value(); - + for (const auto [idx, attribution] : + llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { auto type = dyn_cast<MemRefType>(attribution.getType()); assert(type && type.hasStaticShape() && "unexpected type in attribution"); @@ -37,12 +36,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, auto elementType = cast<Type>(typeConverter->convertType(type.getElementType())); auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); - std::string name = std::string( - llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); + std::string name = + std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx)); uint64_t alignment = 0; if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr( - en.index(), LLVM::LLVMDialect::getAlignAttrName()))) + idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); auto globalOp = rewriter.create<LLVM::GlobalOp>( gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, @@ -105,8 +104,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, rewriter.setInsertionPointToStart(&gpuFuncOp.front()); unsigned numProperArguments = gpuFuncOp.getNumArguments(); - for (const auto &en : llvm::enumerate(workgroupBuffers)) { - LLVM::GlobalOp global = en.value(); + for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) { auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()); Value address = rewriter.create<LLVM::AddressOfOp>( @@ -119,18 +117,18 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, // existing memref infrastructure. This may use more registers than // otherwise necessary given that memref sizes are fixed, but we can try // and canonicalize that away later. - Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; + Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx]; auto type = cast<MemRefType>(attribution.getType()); auto descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, memory); - signatureConversion.remapInput(numProperArguments + en.index(), descr); + signatureConversion.remapInput(numProperArguments + idx, descr); } // Rewrite private memory attributions to alloca'ed buffers. unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); auto int64Ty = IntegerType::get(rewriter.getContext(), 64); - for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { - Value attribution = en.value(); + for (const auto [idx, attribution] : + llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { auto type = cast<MemRefType>(attribution.getType()); assert(type && type.hasStaticShape() && "unexpected type in attribution"); @@ -145,14 +143,14 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, uint64_t alignment = 0; if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr( - en.index(), LLVM::LLVMDialect::getAlignAttrName()))) + idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); Value allocated = rewriter.create<LLVM::AllocaOp>( gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); auto descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, allocated); signatureConversion.remapInput( - numProperArguments + numWorkgroupAttributions + en.index(), descr); + numProperArguments + numWorkgroupAttributions + idx, descr); } } @@ -169,15 +167,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, if (getTypeConverter()->getOptions().useBarePtrCallConv) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front()); - for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) { - auto memrefTy = dyn_cast<MemRefType>(en.value()); + for (const auto [idx, argTy] : + llvm::enumerate(gpuFuncOp.getArgumentTypes())) { + auto memrefTy = dyn_cast<MemRefType>(argTy); if (!memrefTy) continue; assert(memrefTy.hasStaticShape() && "Bare pointer convertion used with dynamically-shaped memrefs"); // Use a placeholder when replacing uses of the memref argument to prevent // circular replacements. - auto remapping = signatureConversion.getInputMapping(en.index()); + auto remapping = signatureConversion.getInputMapping(idx); assert(remapping && remapping->size == 1 && "Type converter should produce 1-to-1 mapping for bare memrefs"); BlockArgument newArg = @@ -193,19 +192,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, // Get memref type from function arguments and set the noalias to // pointer arguments. - for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) { - auto memrefTy = en.value().dyn_cast<MemRefType>(); - NamedAttrList argAttr = argAttrs - ? argAttrs[en.index()].cast<DictionaryAttr>() - : NamedAttrList(); - + for (const auto [idx, argTy] : + llvm::enumerate(gpuFuncOp.getArgumentTypes())) { + auto remapping = signatureConversion.getInputMapping(idx); + NamedAttrList argAttr = + argAttrs ? argAttrs[idx].cast<DictionaryAttr>() : NamedAttrList(); + auto copyAttribute = [&](StringRef attrName) { + Attribute attr = argAttr.erase(attrName); + if (!attr) + return; + for (size_t i = 0, e = remapping->size; i < e; ++i) + llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); + }; auto copyPointerAttribute = [&](StringRef attrName) { Attribute attr = argAttr.erase(attrName); - // This is a proxy for the bare pointer calling convention. if (!attr) return; - auto remapping = signatureConversion.getInputMapping(en.index()); if (remapping->size > 1 && attrName == LLVM::LLVMDialect::getNoAliasAttrName()) { emitWarning(llvmFuncOp.getLoc(), @@ -224,10 +227,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, if (argAttr.empty()) continue; - if (memrefTy) { + copyAttribute(LLVM::LLVMDialect::getReturnedAttrName()); + copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName()); + copyAttribute(LLVM::LLVMDialect::getInRegAttrName()); + bool lowersToPointer = false; + for (size_t i = 0, e = remapping->size; i < e; ++i) { + lowersToPointer |= isa<LLVM::LLVMPointerType>( + llvmFuncOp.getArgument(remapping->inputNo + i).getType()); + } + + if (lowersToPointer) { copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName()); + copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName()); + copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName()); + copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName()); + copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName()); copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName()); copyPointerAttribute( |