aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp')
-rw-r--r--mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp65
1 files changed, 34 insertions, 31 deletions
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 78d1327..dc92367 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -71,12 +71,12 @@ void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
auto pointerType =
spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
rewriter.setInsertionPoint(newOp);
- auto alloc = rewriter.create<spirv::VariableOp>(
- loc, pointerType, spirv::StorageClass::Function,
- /*initializer=*/nullptr);
+ auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType,
+ spirv::StorageClass::Function,
+ /*initializer=*/nullptr);
allocas.push_back(alloc);
rewriter.setInsertionPointAfter(newOp);
- Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
+ Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc);
resultValue.push_back(loadResult);
}
rewriter.replaceOp(scfOp, resultValue);
@@ -135,7 +135,8 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
// a single back edge from the continue to header block, and a single exit
// from header to merge.
auto loc = forOp.getLoc();
- auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+ auto loopOp =
+ spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
loopOp.addEntryAndMergeBlock(rewriter);
OpBuilder::InsertionGuard guard(rewriter);
@@ -172,16 +173,17 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
// Branch into it from the entry.
rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
- rewriter.create<spirv::BranchOp>(loc, header, args);
+ spirv::BranchOp::create(rewriter, loc, header, args);
// Generate the rest of the loop header.
rewriter.setInsertionPointToEnd(header);
auto *mergeBlock = loopOp.getMergeBlock();
- auto cmpOp = rewriter.create<spirv::SLessThanOp>(
- loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
+ auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
+ newIndVar, adaptor.getUpperBound());
- rewriter.create<spirv::BranchConditionalOp>(
- loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
+ spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
+ ArrayRef<Value>(), mergeBlock,
+ ArrayRef<Value>());
// Generate instructions to increment the step of the induction variable and
// branch to the header.
@@ -189,9 +191,9 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
rewriter.setInsertionPointToEnd(continueBlock);
// Add the step to the induction variable and branch to the header.
- Value updatedIndVar = rewriter.create<spirv::IAddOp>(
- loc, newIndVar.getType(), newIndVar, adaptor.getStep());
- rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
+ Value updatedIndVar = spirv::IAddOp::create(
+ rewriter, loc, newIndVar.getType(), newIndVar, adaptor.getStep());
+ spirv::BranchOp::create(rewriter, loc, header, updatedIndVar);
// Infer the return types from the init operands. Vector type may get
// converted to CooperativeMatrix or to Vector type, to avoid having complex
@@ -237,11 +239,11 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
// Create `spirv.selection` operation, selection header block and merge
// block.
- auto selectionOp =
- rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+ auto selectionOp = spirv::SelectionOp::create(
+ rewriter, loc, spirv::SelectionControl::None);
auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
selectionOp.getBody().end());
- rewriter.create<spirv::MergeOp>(loc);
+ spirv::MergeOp::create(rewriter, loc);
OpBuilder::InsertionGuard guard(rewriter);
auto *selectionHeaderBlock =
@@ -251,7 +253,7 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
auto &thenRegion = ifOp.getThenRegion();
auto *thenBlock = &thenRegion.front();
rewriter.setInsertionPointToEnd(&thenRegion.back());
- rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+ spirv::BranchOp::create(rewriter, loc, mergeBlock);
rewriter.inlineRegionBefore(thenRegion, mergeBlock);
auto *elseBlock = mergeBlock;
@@ -261,15 +263,15 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
auto &elseRegion = ifOp.getElseRegion();
elseBlock = &elseRegion.front();
rewriter.setInsertionPointToEnd(&elseRegion.back());
- rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+ spirv::BranchOp::create(rewriter, loc, mergeBlock);
rewriter.inlineRegionBefore(elseRegion, mergeBlock);
}
// Create a `spirv.BranchConditional` operation for selection header block.
rewriter.setInsertionPointToEnd(selectionHeaderBlock);
- rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
- thenBlock, ArrayRef<Value>(),
- elseBlock, ArrayRef<Value>());
+ spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(),
+ thenBlock, ArrayRef<Value>(), elseBlock,
+ ArrayRef<Value>());
replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
returnTypes);
@@ -310,7 +312,7 @@ public:
auto loc = terminatorOp.getLoc();
for (unsigned i = 0, e = operands.size(); i < e; i++)
- rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
+ spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]);
if (isa<spirv::LoopOp>(parent)) {
// For loops we also need to update the branch jumping back to the
// header.
@@ -319,8 +321,8 @@ public:
SmallVector<Value, 8> args(br.getBlockArguments());
args.append(operands.begin(), operands.end());
rewriter.setInsertionPoint(br);
- rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
- args);
+ spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(),
+ args);
rewriter.eraseOp(br);
}
}
@@ -340,7 +342,8 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = whileOp.getLoc();
- auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+ auto loopOp =
+ spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
loopOp.addEntryAndMergeBlock(rewriter);
Region &beforeRegion = whileOp.getBefore();
@@ -382,7 +385,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
// Jump from the loop entry block to the loop header block.
rewriter.setInsertionPointToEnd(&entryBlock);
- rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
+ spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits());
auto condLoc = cond.getLoc();
@@ -403,18 +406,18 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
// Create local variables before the scf.while op.
rewriter.setInsertionPoint(loopOp);
- auto alloc = rewriter.create<spirv::VariableOp>(
- condLoc, pointerType, spirv::StorageClass::Function,
- /*initializer=*/nullptr);
+ auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType,
+ spirv::StorageClass::Function,
+ /*initializer=*/nullptr);
// Load the final result values after the scf.while op.
rewriter.setInsertionPointAfter(loopOp);
- auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
+ auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc);
resultValues[i] = loadResult;
// Store the current iteration's result value.
rewriter.setInsertionPointToEnd(&beforeBlock);
- rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
+ spirv::StoreOp::create(rewriter, condLoc, alloc, res);
}
rewriter.setInsertionPointToEnd(&beforeBlock);