aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp133
1 files changed, 68 insertions, 65 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 01ca5e9..1037e29 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -28,7 +28,7 @@ LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
- ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
+ ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
}
return ret;
}
@@ -68,9 +68,9 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
- return b.create<LLVM::GlobalOp>(loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal,
- name, attr, alignment, addrSpace);
+ return LLVM::GlobalOp::create(b, loc, globalType,
+ /*isConstant=*/true, LLVM::Linkage::Internal,
+ name, attr, alignment, addrSpace);
}
LogicalResult
@@ -151,8 +151,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
gpuFuncOp.getWorkgroupAttributionAttr(
idx, LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
- auto globalOp = rewriter.create<LLVM::GlobalOp>(
- gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
+ auto globalOp = LLVM::GlobalOp::create(
+ rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
workgroupAddrSpace);
workgroupBuffers.push_back(globalOp);
@@ -220,8 +220,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
LLVM::CConv callingConvention = gpuFuncOp.isKernel()
? kernelCallingConvention
: nonKernelCallingConvention;
- auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
- gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
+ auto llvmFuncOp = LLVM::LLVMFuncOp::create(
+ rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
/*comdat=*/nullptr, attributes);
@@ -266,11 +266,11 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
global.getAddrSpace());
- Value address = rewriter.create<LLVM::AddressOfOp>(
- loc, ptrType, global.getSymNameAttr());
+ Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType,
+ global.getSymNameAttr());
Value memory =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
- address, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(),
+ address, ArrayRef<LLVM::GEPArg>{0, 0});
// Build a memref descriptor pointing to the buffer to plug with the
// existing memref infrastructure. This may use more registers than
@@ -298,15 +298,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
Type elementType = typeConverter->convertType(type.getElementType());
auto ptrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
- Value numElements = rewriter.create<LLVM::ConstantOp>(
- gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
+ Value numElements = LLVM::ConstantOp::create(
+ rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
uint64_t alignment = 0;
if (auto alignAttr =
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
idx, LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
- Value allocated = rewriter.create<LLVM::AllocaOp>(
- gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
+ Value allocated =
+ LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType,
+ elementType, numElements, alignment);
Value descr = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), type, allocated);
signatureConversion.remapInput(
@@ -418,8 +419,9 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
{llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
/// Start the printf hostcall
- Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
- auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
+ Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0);
+ auto printfBeginCall =
+ LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64);
Value printfDesc = printfBeginCall.getResult();
// Create the global op or find an existing one.
@@ -427,21 +429,21 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element and pass it to printf()
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
- loc,
+ Value globalPtr = LLVM::AddressOfOp::create(
+ rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
- Value stringLen = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ Value stringLen = LLVM::ConstantOp::create(
+ rewriter, loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
- Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
- Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
+ Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1);
+ Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0);
- auto appendFormatCall = rewriter.create<LLVM::CallOp>(
- loc, ocklAppendStringN,
+ auto appendFormatCall = LLVM::CallOp::create(
+ rewriter, loc, ocklAppendStringN,
ValueRange{printfDesc, stringStart, stringLen,
adaptor.getArgs().empty() ? oneI32 : zeroI32});
printfDesc = appendFormatCall.getResult();
@@ -456,17 +458,18 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
arguments.push_back(printfDesc);
arguments.push_back(
- rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
+ LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall));
for (size_t i = group; i < bound; ++i) {
Value arg = adaptor.getArgs()[i];
if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
if (!floatType.isF64())
- arg = rewriter.create<LLVM::FPExtOp>(
- loc, typeConverter->convertType(rewriter.getF64Type()), arg);
- arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
+ arg = LLVM::FPExtOp::create(
+ rewriter, loc, typeConverter->convertType(rewriter.getF64Type()),
+ arg);
+ arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg);
}
if (arg.getType().getIntOrFloatBitWidth() != 64)
- arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
+ arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg);
arguments.push_back(arg);
}
@@ -477,7 +480,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
arguments.push_back(isLast);
- auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
+ auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments);
printfDesc = call.getResult();
}
rewriter.eraseOp(gpuPrintfOp);
@@ -510,13 +513,13 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
/*alignment=*/0, addressSpace);
// Get a pointer to the format string's first element
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
- loc,
+ Value globalPtr = LLVM::AddressOfOp::create(
+ rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
// Construct arguments and function call
auto argsRange = adaptor.getArgs();
@@ -525,7 +528,7 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
printfArgs.push_back(stringStart);
printfArgs.append(argsRange.begin(), argsRange.end());
- rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
+ LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
rewriter.eraseOp(gpuPrintfOp);
return success();
}
@@ -559,10 +562,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
"printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
+ Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global);
Value stringStart =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
SmallVector<Type> types;
SmallVector<Value> args;
// Promote and pack the arguments into a stack allocation.
@@ -572,27 +575,27 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
assert(type.isIntOrFloat());
if (isa<FloatType>(type)) {
type = rewriter.getF64Type();
- promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
+ promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg);
}
types.push_back(type);
args.push_back(promotedArg);
}
Type structType =
LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
- Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
- rewriter.getIndexAttr(1));
+ Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+ rewriter.getIndexAttr(1));
Value tempAlloc =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
- /*alignment=*/0);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one,
+ /*alignment=*/0);
for (auto [index, arg] : llvm::enumerate(args)) {
- Value ptr = rewriter.create<LLVM::GEPOp>(
- loc, ptrType, structType, tempAlloc,
+ Value ptr = LLVM::GEPOp::create(
+ rewriter, loc, ptrType, structType, tempAlloc,
ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
- rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
+ LLVM::StoreOp::create(rewriter, loc, arg, ptr);
}
std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
- rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
+ LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs);
rewriter.eraseOp(gpuPrintfOp);
return success();
}
@@ -607,23 +610,23 @@ static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
TypeRange operandTypes(operands);
VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
Location loc = op->getLoc();
- Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
+ Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType);
Type indexType = converter.convertType(rewriter.getIndexType());
StringAttr name = op->getName().getIdentifier();
Type elementType = vectorType.getElementType();
for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
- Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
+ Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i);
auto extractElement = [&](Value operand) -> Value {
if (!isa<VectorType>(operand.getType()))
return operand;
- return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
+ return LLVM::ExtractElementOp::create(rewriter, loc, operand, index);
};
auto scalarOperands = llvm::map_to_vector(operands, extractElement);
Operation *scalarOp =
rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
- result = rewriter.create<LLVM::InsertElementOp>(
- loc, result, scalarOp->getResult(0), index);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result,
+ scalarOp->getResult(0), index);
}
return result;
}
@@ -705,10 +708,10 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
typeConverter->convertType(memrefType.getElementType()), 0);
- return rewriter.create<LLVM::GlobalOp>(
- op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
- LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
- addressSpace.value());
+ return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType,
+ /*isConstant=*/false, LLVM::Linkage::Internal,
+ symName, /*value=*/Attribute(), alignmentByte,
+ addressSpace.value());
}
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
@@ -732,13 +735,13 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
// Step 3. Get address of the global symbol
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
- auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
+ auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp);
Type baseType = basePtr->getResultTypes().front();
// Step 4. Generate GEP using offsets
SmallVector<LLVM::GEPArg> gepArgs = {0};
- Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
- basePtr, gepArgs);
+ Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType,
+ basePtr, gepArgs);
// Step 5. Create a memref descriptor
SmallVector<Value> shape, strides;
Value sizeBytes;
@@ -799,9 +802,9 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "could not convert result types");
}
- Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
+ Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
- packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
+ packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
op->getAttrs());