aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp')
-rw-r--r--mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp37
1 files changed, 19 insertions, 18 deletions
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index b866afb..7a70533 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -79,7 +79,8 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
assert(indices.size() == 2);
indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
Type t = typeConverter.convertType(op.getComponentPtr().getType());
- return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
+ return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
+ indices);
}
/// Casts the given `srcBool` into an integer of `dstType`.
@@ -107,8 +108,8 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
value = castBoolToIntN(loc, value, dstType, builder);
} else {
if (valueBits < targetBits) {
- value = builder.create<spirv::UConvertOp>(
- loc, builder.getIntegerType(targetBits), value);
+ value = spirv::UConvertOp::create(
+ builder, loc, builder.getIntegerType(targetBits), value);
}
value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
@@ -372,8 +373,8 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
std::string varName =
std::string("__workgroup_mem__") +
std::to_string(std::distance(varOps.begin(), varOps.end()));
- varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
- /*initializer=*/nullptr);
+ varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
+ /*initializer=*/nullptr);
}
// Get pointer to global variable at the current scope.
@@ -572,8 +573,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
loadOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
- Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
- memoryAccess, alignment);
+ Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
+ memoryAccess, alignment);
if (isBool)
loadVal = castIntNToBool(loc, loadVal, rewriter);
rewriter.replaceOp(loadOp, loadVal);
@@ -601,8 +602,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
loadOp, "failed to determine memory requirements");
auto [memoryAccess, alignment] = *memoryRequirements;
- Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
- memoryAccess, alignment);
+ Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
+ memoryAccess, alignment);
// Shift the bits to the rightmost.
// ____XXXX________ -> ____________XXXX
@@ -770,12 +771,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
if (!scope)
return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
- Value result = rewriter.create<spirv::AtomicAndOp>(
- loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
- clearBitsMask);
- result = rewriter.create<spirv::AtomicOrOp>(
- loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
- storeVal);
+ Value result = spirv::AtomicAndOp::create(
+ rewriter, loc, dstType, adjustedPtr, *scope,
+ spirv::MemorySemantics::AcquireRelease, clearBitsMask);
+ result = spirv::AtomicOrOp::create(
+ rewriter, loc, dstType, adjustedPtr, *scope,
+ spirv::MemorySemantics::AcquireRelease, storeVal);
// The AtomicOrOp has no side effect. Since it is already inserted, we can
// just remove the original StoreOp. Note that rewriter.replaceOp()
@@ -850,12 +851,12 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
genericPtrType = typeConverter.convertType(intermediateType);
}
if (sourceSc != spirv::StorageClass::Generic) {
- result =
- rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
+ result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
+ result);
}
if (resultSc != spirv::StorageClass::Generic) {
result =
- rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
+ spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result);
}
rewriter.replaceOp(addrCastOp, result);
return success();