diff options
Diffstat (limited to 'mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 63 |
1 files changed, 32 insertions, 31 deletions
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 1ef6ede..317bfc2 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -118,10 +118,10 @@ struct GPUSubgroupReduceOpLowering Location loc = op->getLoc(); auto int32Type = IntegerType::get(rewriter.getContext(), 32); - Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1); + Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1); - auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(), - mode.value(), offset); + auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type, + op.getValue(), mode.value(), offset); rewriter.replaceOp(op, reduxOp->getResult(0)); return success(); @@ -158,22 +158,22 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { auto int32Type = IntegerType::get(rewriter.getContext(), 32); auto predTy = IntegerType::get(rewriter.getContext(), 1); - Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1); - Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1); - Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32); - Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>( - loc, int32Type, thirtyTwo, adaptor.getWidth()); + Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1); + Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1); + Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32); + Value numLeadInactiveLane = LLVM::SubOp::create( + rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth()); // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`. - Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne, - numLeadInactiveLane); + Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne, + numLeadInactiveLane); Value maskAndClamp; if (op.getMode() == gpu::ShuffleMode::UP) { // Clamp lane: `32 - activeWidth` maskAndClamp = numLeadInactiveLane; } else { // Clamp lane: `activeWidth - 1` - maskAndClamp = - rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one); + maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type, + adaptor.getWidth(), one); } bool predIsUsed = !op->getResult(1).use_empty(); @@ -184,13 +184,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), {valueTy, predTy}); } - Value shfl = rewriter.create<NVVM::ShflOp>( - loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(), - maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr); + Value shfl = NVVM::ShflOp::create( + rewriter, loc, resultTy, activeMask, adaptor.getValue(), + adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()), + returnValueAndIsValidAttr); if (predIsUsed) { - Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0); + Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0); Value isActiveSrcLane = - rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1); + LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1); rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); } else { rewriter.replaceOp(op, {shfl, nullptr}); @@ -215,16 +216,16 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> { bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>( /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize); Value newOp = - rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds); + NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds); // Truncate or extend the result depending on the index bitwidth specified // by the LLVMTypeConverter options. const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); if (indexBitwidth > 32) { - newOp = rewriter.create<LLVM::SExtOp>( - loc, IntegerType::get(context, indexBitwidth), newOp); + newOp = LLVM::SExtOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); } else if (indexBitwidth < 32) { - newOp = rewriter.create<LLVM::TruncOp>( - loc, IntegerType::get(context, indexBitwidth), newOp); + newOp = LLVM::TruncOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); } rewriter.replaceOp(op, {newOp}); return success(); @@ -271,10 +272,10 @@ struct AssertOpToAssertfailLowering Block *afterBlock = rewriter.splitBlock(assertBlock, ++assertOp->getIterator()); rewriter.setInsertionPointToEnd(beforeBlock); - rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock, - assertBlock); + cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock, + assertBlock); rewriter.setInsertionPointToEnd(assertBlock); - rewriter.create<cf::BranchOp>(loc, afterBlock); + cf::BranchOp::create(rewriter, loc, afterBlock); // Continue cf.assert lowering. rewriter.setInsertionPoint(assertOp); @@ -301,12 +302,12 @@ struct AssertOpToAssertfailLowering // Create constants. auto getGlobal = [&](LLVM::GlobalOp global) { // Get a pointer to the format string's first element. - Value globalPtr = rewriter.create<LLVM::AddressOfOp>( - loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()), + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()), global.getSymNameAttr()); Value start = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); return start; }; Value assertMessage = getGlobal(getOrCreateStringConstant( @@ -316,8 +317,8 @@ struct AssertOpToAssertfailLowering Value assertFunc = getGlobal(getOrCreateStringConstant( rewriter, loc, moduleOp, i8Type, "assert_func_", funcName)); Value assertLine = - rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine); - Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1); + LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine); + Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1); // Insert function call to __assertfail. SmallVector<Value> arguments{assertMessage, assertFile, assertLine, |