diff options
Diffstat (limited to 'mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp')
-rw-r--r-- | mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 37 |
1 files changed, 19 insertions, 18 deletions
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index b866afb..7a70533 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -79,7 +79,8 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, assert(indices.size() == 2); indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx); Type t = typeConverter.convertType(op.getComponentPtr().getType()); - return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices); + return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(), + indices); } /// Casts the given `srcBool` into an integer of `dstType`. @@ -107,8 +108,8 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask, value = castBoolToIntN(loc, value, dstType, builder); } else { if (valueBits < targetBits) { - value = builder.create<spirv::UConvertOp>( - loc, builder.getIntegerType(targetBits), value); + value = spirv::UConvertOp::create( + builder, loc, builder.getIntegerType(targetBits), value); } value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask); @@ -372,8 +373,8 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, std::string varName = std::string("__workgroup_mem__") + std::to_string(std::distance(varOps.begin(), varOps.end())); - varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName, - /*initializer=*/nullptr); + varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName, + /*initializer=*/nullptr); } // Get pointer to global variable at the current scope. @@ -572,8 +573,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, loadOp, "failed to determine memory requirements"); auto [memoryAccess, alignment] = *memoryRequirements; - Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain, - memoryAccess, alignment); + Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain, + memoryAccess, alignment); if (isBool) loadVal = castIntNToBool(loc, loadVal, rewriter); rewriter.replaceOp(loadOp, loadVal); @@ -601,8 +602,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, loadOp, "failed to determine memory requirements"); auto [memoryAccess, alignment] = *memoryRequirements; - Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr, - memoryAccess, alignment); + Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr, + memoryAccess, alignment); // Shift the bits to the rightmost. // ____XXXX________ -> ____________XXXX @@ -770,12 +771,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, if (!scope) return rewriter.notifyMatchFailure(storeOp, "atomic scope not available"); - Value result = rewriter.create<spirv::AtomicAndOp>( - loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, - clearBitsMask); - result = rewriter.create<spirv::AtomicOrOp>( - loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, - storeVal); + Value result = spirv::AtomicAndOp::create( + rewriter, loc, dstType, adjustedPtr, *scope, + spirv::MemorySemantics::AcquireRelease, clearBitsMask); + result = spirv::AtomicOrOp::create( + rewriter, loc, dstType, adjustedPtr, *scope, + spirv::MemorySemantics::AcquireRelease, storeVal); // The AtomicOrOp has no side effect. Since it is already inserted, we can // just remove the original StoreOp. Note that rewriter.replaceOp() @@ -850,12 +851,12 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite( genericPtrType = typeConverter.convertType(intermediateType); } if (sourceSc != spirv::StorageClass::Generic) { - result = - rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result); + result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType, + result); } if (resultSc != spirv::StorageClass::Generic) { result = - rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result); + spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result); } rewriter.replaceOp(addrCastOp, result); return success(); |