aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp')
-rw-r--r--mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp63
1 files changed, 32 insertions, 31 deletions
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 1ef6ede..317bfc2 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -118,10 +118,10 @@ struct GPUSubgroupReduceOpLowering
Location loc = op->getLoc();
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
- Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
+ Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
- auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
- mode.value(), offset);
+ auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
+ op.getValue(), mode.value(), offset);
rewriter.replaceOp(op, reduxOp->getResult(0));
return success();
@@ -158,22 +158,22 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
auto predTy = IntegerType::get(rewriter.getContext(), 1);
- Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
- Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
- Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
- Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
- loc, int32Type, thirtyTwo, adaptor.getWidth());
+ Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
+ Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
+ Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
+ Value numLeadInactiveLane = LLVM::SubOp::create(
+ rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
// Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
- Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
- numLeadInactiveLane);
+ Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
+ numLeadInactiveLane);
Value maskAndClamp;
if (op.getMode() == gpu::ShuffleMode::UP) {
// Clamp lane: `32 - activeWidth`
maskAndClamp = numLeadInactiveLane;
} else {
// Clamp lane: `activeWidth - 1`
- maskAndClamp =
- rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
+ maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
+ adaptor.getWidth(), one);
}
bool predIsUsed = !op->getResult(1).use_empty();
@@ -184,13 +184,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
{valueTy, predTy});
}
- Value shfl = rewriter.create<NVVM::ShflOp>(
- loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
- maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
+ Value shfl = NVVM::ShflOp::create(
+ rewriter, loc, resultTy, activeMask, adaptor.getValue(),
+ adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
+ returnValueAndIsValidAttr);
if (predIsUsed) {
- Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
+ Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
Value isActiveSrcLane =
- rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
+ LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
} else {
rewriter.replaceOp(op, {shfl, nullptr});
@@ -215,16 +216,16 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
/*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
Value newOp =
- rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
+ NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
if (indexBitwidth > 32) {
- newOp = rewriter.create<LLVM::SExtOp>(
- loc, IntegerType::get(context, indexBitwidth), newOp);
+ newOp = LLVM::SExtOp::create(
+ rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
} else if (indexBitwidth < 32) {
- newOp = rewriter.create<LLVM::TruncOp>(
- loc, IntegerType::get(context, indexBitwidth), newOp);
+ newOp = LLVM::TruncOp::create(
+ rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
}
rewriter.replaceOp(op, {newOp});
return success();
@@ -271,10 +272,10 @@ struct AssertOpToAssertfailLowering
Block *afterBlock =
rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
rewriter.setInsertionPointToEnd(beforeBlock);
- rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
- assertBlock);
+ cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
+ assertBlock);
rewriter.setInsertionPointToEnd(assertBlock);
- rewriter.create<cf::BranchOp>(loc, afterBlock);
+ cf::BranchOp::create(rewriter, loc, afterBlock);
// Continue cf.assert lowering.
rewriter.setInsertionPoint(assertOp);
@@ -301,12 +302,12 @@ struct AssertOpToAssertfailLowering
// Create constants.
auto getGlobal = [&](LLVM::GlobalOp global) {
// Get a pointer to the format string's first element.
- Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
+ Value globalPtr = LLVM::AddressOfOp::create(
+ rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
global.getSymNameAttr());
Value start =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
- globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
return start;
};
Value assertMessage = getGlobal(getOrCreateStringConstant(
@@ -316,8 +317,8 @@ struct AssertOpToAssertfailLowering
Value assertFunc = getGlobal(getOrCreateStringConstant(
rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
Value assertLine =
- rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
- Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
+ LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
+ Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);
// Insert function call to __assertfail.
SmallVector<Value> arguments{assertMessage, assertFile, assertLine,