diff options
Diffstat (limited to 'mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 110 |
1 files changed, 57 insertions, 53 deletions
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index b99ed26..1817861 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -169,11 +169,11 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite( Value vector = spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter); - Value dim = rewriter.create<spirv::CompositeExtractOp>( - op.getLoc(), builtinType, vector, + Value dim = spirv::CompositeExtractOp::create( + rewriter, op.getLoc(), builtinType, vector, rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())})); if (forShader && builtinType != indexType) - dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim); + dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim); rewriter.replaceOp(op, dim); return success(); } @@ -198,8 +198,8 @@ SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite( Value builtinValue = spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter); if (i32Type != indexType) - builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, - builtinValue); + builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, + builtinValue); rewriter.replaceOp(op, builtinValue); return success(); } @@ -257,8 +257,8 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, signatureConverter.addInputs(argType.index(), convertedType); } } - auto newFuncOp = rewriter.create<spirv::FuncOp>( - funcOp.getLoc(), funcOp.getName(), + auto newFuncOp = spirv::FuncOp::create( + rewriter, funcOp.getLoc(), funcOp.getName(), rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {})); for (const auto &namedAttr : funcOp->getAttrs()) { if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() || @@ -367,8 +367,8 @@ LogicalResult GPUModuleConversion::matchAndRewrite( // Add a keyword to the module name to avoid symbolic conflict. std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str(); - auto spvModule = rewriter.create<spirv::ModuleOp>( - moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt, + auto spvModule = spirv::ModuleOp::create( + rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt, StringRef(spvModuleName)); // Move the region from the module op into the SPIR-V module. @@ -452,42 +452,42 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( switch (shuffleOp.getMode()) { case gpu::ShuffleMode::XOR: { - result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleXorOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), shuffleOp.getLoc(), rewriter); break; } case gpu::ShuffleMode::IDX: { - result = rewriter.create<spirv::GroupNonUniformShuffleOp>( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), shuffleOp.getLoc(), rewriter); break; } case gpu::ShuffleMode::DOWN: { - result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleDownOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); - Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); Value resultLaneId = - rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset()); - validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, - resultLaneId, adaptor.getWidth()); + arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset()); + validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, + resultLaneId, adaptor.getWidth()); break; } case gpu::ShuffleMode::UP: { - result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleUpOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); - Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); Value resultLaneId = - rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset()); + arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset()); auto i32Type = rewriter.getIntegerType(32); - validVal = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, resultLaneId, - rewriter.create<arith::ConstantOp>( - loc, i32Type, rewriter.getIntegerAttr(i32Type, 0))); + validVal = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, resultLaneId, + arith::ConstantOp::create(rewriter, loc, i32Type, + rewriter.getIntegerAttr(i32Type, 0))); break; } } @@ -507,24 +507,27 @@ LogicalResult GPURotateConversion::matchAndRewrite( getTypeConverter<SPIRVTypeConverter>()->getTargetEnv(); unsigned subgroupSize = targetEnv.getAttr().getResourceLimits().getSubgroupSize(); - IntegerAttr widthAttr; - if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) || - widthAttr.getValue().getZExtValue() > subgroupSize) + unsigned width = rotateOp.getWidth(); + if (width > subgroupSize) return rewriter.notifyMatchFailure( - rotateOp, - "rotate width is not a constant or larger than target subgroup size"); + rotateOp, "rotate width is larger than target subgroup size"); Location loc = rotateOp.getLoc(); auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup); - Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>( - loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth()); + Value offsetVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr()); + Value widthVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr()); + Value rotateResult = spirv::GroupNonUniformRotateKHROp::create( + rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal); Value validVal; - if (widthAttr.getValue().getZExtValue() == subgroupSize) { + if (width == subgroupSize) { validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter); } else { - Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr); - validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, - laneId, adaptor.getWidth()); + IntegerAttr widthAttr = adaptor.getWidthAttr(); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); + validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, + laneId, widthVal); } rewriter.replaceOp(rotateOp, {rotateResult, validVal}); @@ -548,18 +551,18 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, ? spirv::GroupOperation::ClusteredReduce : spirv::GroupOperation::Reduce); if (isUniform) { - return builder.create<UniformOp>(loc, type, scope, groupOp, arg) + return UniformOp::create(builder, loc, type, scope, groupOp, arg) .getResult(); } Value clusterSizeValue; if (clusterSize.has_value()) - clusterSizeValue = builder.create<spirv::ConstantOp>( - loc, builder.getI32Type(), + clusterSizeValue = spirv::ConstantOp::create( + builder, loc, builder.getI32Type(), builder.getIntegerAttr(builder.getI32Type(), *clusterSize)); - return builder - .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue) + return NonUniformOp::create(builder, loc, type, scope, groupOp, arg, + clusterSizeValue) .getResult(); } @@ -740,8 +743,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( std::string specCstName = makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc"); - return rewriter.create<spirv::SpecConstantOp>( - loc, rewriter.getStringAttr(specCstName), attr); + return spirv::SpecConstantOp::create( + rewriter, loc, rewriter.getStringAttr(specCstName), attr); }; { Operation *parent = @@ -774,8 +777,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( std::string specCstCompositeName = (llvm::Twine(globalVarName) + "_scc").str(); - specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>( - loc, TypeAttr::get(globalType), + specCstComposite = spirv::SpecConstantCompositeOp::create( + rewriter, loc, TypeAttr::get(globalType), rewriter.getStringAttr(specCstCompositeName), rewriter.getArrayAttr(constituents)); @@ -785,23 +788,24 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( // Define a GlobalVarOp initialized using specialized constants // that is used to specify the printf format string // to be passed to the SPIRV CLPrintfOp. - globalVar = rewriter.create<spirv::GlobalVariableOp>( - loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite)); + globalVar = spirv::GlobalVariableOp::create( + rewriter, loc, ptrType, globalVarName, + FlatSymbolRefAttr::get(specCstComposite)); globalVar->setAttr("Constant", rewriter.getUnitAttr()); } // Get SSA value of Global variable and create pointer to i8 to point to // the format string. - Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar); - Value fmtStr = rewriter.create<spirv::BitcastOp>( - loc, + Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar); + Value fmtStr = spirv::BitcastOp::create( + rewriter, loc, spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant), globalPtr); // Get printf arguments. auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs()); - rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs); + spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs); // Need to erase the gpu.printf op as gpu.printf does not use result vs // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V |