diff options
Diffstat (limited to 'mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp')
-rw-r--r-- | mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 65 |
1 files changed, 34 insertions, 31 deletions
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index 78d1327..dc92367 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -71,12 +71,12 @@ void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, auto pointerType = spirv::PointerType::get(convertedType, spirv::StorageClass::Function); rewriter.setInsertionPoint(newOp); - auto alloc = rewriter.create<spirv::VariableOp>( - loc, pointerType, spirv::StorageClass::Function, - /*initializer=*/nullptr); + auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType, + spirv::StorageClass::Function, + /*initializer=*/nullptr); allocas.push_back(alloc); rewriter.setInsertionPointAfter(newOp); - Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc); + Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc); resultValue.push_back(loadResult); } rewriter.replaceOp(scfOp, resultValue); @@ -135,7 +135,8 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> { // a single back edge from the continue to header block, and a single exit // from header to merge. auto loc = forOp.getLoc(); - auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None); + auto loopOp = + spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(rewriter); OpBuilder::InsertionGuard guard(rewriter); @@ -172,16 +173,17 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> { args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); // Branch into it from the entry. rewriter.setInsertionPointToEnd(&(loopOp.getBody().front())); - rewriter.create<spirv::BranchOp>(loc, header, args); + spirv::BranchOp::create(rewriter, loc, header, args); // Generate the rest of the loop header. rewriter.setInsertionPointToEnd(header); auto *mergeBlock = loopOp.getMergeBlock(); - auto cmpOp = rewriter.create<spirv::SLessThanOp>( - loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound()); + auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(), + newIndVar, adaptor.getUpperBound()); - rewriter.create<spirv::BranchConditionalOp>( - loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>()); + spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body, + ArrayRef<Value>(), mergeBlock, + ArrayRef<Value>()); // Generate instructions to increment the step of the induction variable and // branch to the header. @@ -189,9 +191,9 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> { rewriter.setInsertionPointToEnd(continueBlock); // Add the step to the induction variable and branch to the header. - Value updatedIndVar = rewriter.create<spirv::IAddOp>( - loc, newIndVar.getType(), newIndVar, adaptor.getStep()); - rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar); + Value updatedIndVar = spirv::IAddOp::create( + rewriter, loc, newIndVar.getType(), newIndVar, adaptor.getStep()); + spirv::BranchOp::create(rewriter, loc, header, updatedIndVar); // Infer the return types from the init operands. Vector type may get // converted to CooperativeMatrix or to Vector type, to avoid having complex @@ -237,11 +239,11 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> { // Create `spirv.selection` operation, selection header block and merge // block. - auto selectionOp = - rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); + auto selectionOp = spirv::SelectionOp::create( + rewriter, loc, spirv::SelectionControl::None); auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(), selectionOp.getBody().end()); - rewriter.create<spirv::MergeOp>(loc); + spirv::MergeOp::create(rewriter, loc); OpBuilder::InsertionGuard guard(rewriter); auto *selectionHeaderBlock = @@ -251,7 +253,7 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> { auto &thenRegion = ifOp.getThenRegion(); auto *thenBlock = &thenRegion.front(); rewriter.setInsertionPointToEnd(&thenRegion.back()); - rewriter.create<spirv::BranchOp>(loc, mergeBlock); + spirv::BranchOp::create(rewriter, loc, mergeBlock); rewriter.inlineRegionBefore(thenRegion, mergeBlock); auto *elseBlock = mergeBlock; @@ -261,15 +263,15 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> { auto &elseRegion = ifOp.getElseRegion(); elseBlock = &elseRegion.front(); rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create<spirv::BranchOp>(loc, mergeBlock); + spirv::BranchOp::create(rewriter, loc, mergeBlock); rewriter.inlineRegionBefore(elseRegion, mergeBlock); } // Create a `spirv.BranchConditional` operation for selection header block. rewriter.setInsertionPointToEnd(selectionHeaderBlock); - rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(), - thenBlock, ArrayRef<Value>(), - elseBlock, ArrayRef<Value>()); + spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(), + thenBlock, ArrayRef<Value>(), elseBlock, + ArrayRef<Value>()); replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, returnTypes); @@ -310,7 +312,7 @@ public: auto loc = terminatorOp.getLoc(); for (unsigned i = 0, e = operands.size(); i < e; i++) - rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]); + spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]); if (isa<spirv::LoopOp>(parent)) { // For loops we also need to update the branch jumping back to the // header. @@ -319,8 +321,8 @@ public: SmallVector<Value, 8> args(br.getBlockArguments()); args.append(operands.begin(), operands.end()); rewriter.setInsertionPoint(br); - rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(), - args); + spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(), + args); rewriter.eraseOp(br); } } @@ -340,7 +342,8 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> { matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = whileOp.getLoc(); - auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None); + auto loopOp = + spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(rewriter); Region &beforeRegion = whileOp.getBefore(); @@ -382,7 +385,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> { // Jump from the loop entry block to the loop header block. rewriter.setInsertionPointToEnd(&entryBlock); - rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits()); + spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits()); auto condLoc = cond.getLoc(); @@ -403,18 +406,18 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> { // Create local variables before the scf.while op. rewriter.setInsertionPoint(loopOp); - auto alloc = rewriter.create<spirv::VariableOp>( - condLoc, pointerType, spirv::StorageClass::Function, - /*initializer=*/nullptr); + auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType, + spirv::StorageClass::Function, + /*initializer=*/nullptr); // Load the final result values after the scf.while op. rewriter.setInsertionPointAfter(loopOp); - auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc); + auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc); resultValues[i] = loadResult; // Store the current iteration's result value. rewriter.setInsertionPointToEnd(&beforeBlock); - rewriter.create<spirv::StoreOp>(condLoc, alloc, res); + spirv::StoreOp::create(rewriter, condLoc, alloc, res); } rewriter.setInsertionPointToEnd(&beforeBlock); |