aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp')
-rw-r--r--mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp110
1 files changed, 57 insertions, 53 deletions
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index b99ed26..1817861 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -169,11 +169,11 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
Value vector =
spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
- Value dim = rewriter.create<spirv::CompositeExtractOp>(
- op.getLoc(), builtinType, vector,
+ Value dim = spirv::CompositeExtractOp::create(
+ rewriter, op.getLoc(), builtinType, vector,
rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
if (forShader && builtinType != indexType)
- dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
+ dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim);
rewriter.replaceOp(op, dim);
return success();
}
@@ -198,8 +198,8 @@ SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
Value builtinValue =
spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
if (i32Type != indexType)
- builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
- builtinValue);
+ builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType,
+ builtinValue);
rewriter.replaceOp(op, builtinValue);
return success();
}
@@ -257,8 +257,8 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
signatureConverter.addInputs(argType.index(), convertedType);
}
}
- auto newFuncOp = rewriter.create<spirv::FuncOp>(
- funcOp.getLoc(), funcOp.getName(),
+ auto newFuncOp = spirv::FuncOp::create(
+ rewriter, funcOp.getLoc(), funcOp.getName(),
rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
for (const auto &namedAttr : funcOp->getAttrs()) {
if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
@@ -367,8 +367,8 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
// Add a keyword to the module name to avoid symbolic conflict.
std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
- auto spvModule = rewriter.create<spirv::ModuleOp>(
- moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
+ auto spvModule = spirv::ModuleOp::create(
+ rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
StringRef(spvModuleName));
// Move the region from the module op into the SPIR-V module.
@@ -452,42 +452,42 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
switch (shuffleOp.getMode()) {
case gpu::ShuffleMode::XOR: {
- result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleXorOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
}
case gpu::ShuffleMode::IDX: {
- result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
}
case gpu::ShuffleMode::DOWN: {
- result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleDownOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
Value resultLaneId =
- rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
- validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
- resultLaneId, adaptor.getWidth());
+ arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset());
+ validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
+ resultLaneId, adaptor.getWidth());
break;
}
case gpu::ShuffleMode::UP: {
- result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset());
+ result = spirv::GroupNonUniformShuffleUpOp::create(
+ rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
Value resultLaneId =
- rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
+ arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset());
auto i32Type = rewriter.getIntegerType(32);
- validVal = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, resultLaneId,
- rewriter.create<arith::ConstantOp>(
- loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
+ validVal = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::sge, resultLaneId,
+ arith::ConstantOp::create(rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, 0)));
break;
}
}
@@ -507,24 +507,27 @@ LogicalResult GPURotateConversion::matchAndRewrite(
getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
unsigned subgroupSize =
targetEnv.getAttr().getResourceLimits().getSubgroupSize();
- IntegerAttr widthAttr;
- if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
- widthAttr.getValue().getZExtValue() > subgroupSize)
+ unsigned width = rotateOp.getWidth();
+ if (width > subgroupSize)
return rewriter.notifyMatchFailure(
- rotateOp,
- "rotate width is not a constant or larger than target subgroup size");
+ rotateOp, "rotate width is larger than target subgroup size");
Location loc = rotateOp.getLoc();
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
- Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
- loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
+ Value offsetVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
+ Value widthVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
+ Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
+ rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
Value validVal;
- if (widthAttr.getValue().getZExtValue() == subgroupSize) {
+ if (width == subgroupSize) {
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
} else {
- Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
- validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
- laneId, adaptor.getWidth());
+ IntegerAttr widthAttr = adaptor.getWidthAttr();
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
+ validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
+ laneId, widthVal);
}
rewriter.replaceOp(rotateOp, {rotateResult, validVal});
@@ -548,18 +551,18 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
? spirv::GroupOperation::ClusteredReduce
: spirv::GroupOperation::Reduce);
if (isUniform) {
- return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
+ return UniformOp::create(builder, loc, type, scope, groupOp, arg)
.getResult();
}
Value clusterSizeValue;
if (clusterSize.has_value())
- clusterSizeValue = builder.create<spirv::ConstantOp>(
- loc, builder.getI32Type(),
+ clusterSizeValue = spirv::ConstantOp::create(
+ builder, loc, builder.getI32Type(),
builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
- return builder
- .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
+ return NonUniformOp::create(builder, loc, type, scope, groupOp, arg,
+ clusterSizeValue)
.getResult();
}
@@ -740,8 +743,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
std::string specCstName =
makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");
- return rewriter.create<spirv::SpecConstantOp>(
- loc, rewriter.getStringAttr(specCstName), attr);
+ return spirv::SpecConstantOp::create(
+ rewriter, loc, rewriter.getStringAttr(specCstName), attr);
};
{
Operation *parent =
@@ -774,8 +777,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
std::string specCstCompositeName =
(llvm::Twine(globalVarName) + "_scc").str();
- specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
- loc, TypeAttr::get(globalType),
+ specCstComposite = spirv::SpecConstantCompositeOp::create(
+ rewriter, loc, TypeAttr::get(globalType),
rewriter.getStringAttr(specCstCompositeName),
rewriter.getArrayAttr(constituents));
@@ -785,23 +788,24 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
// Define a GlobalVarOp initialized using specialized constants
// that is used to specify the printf format string
// to be passed to the SPIRV CLPrintfOp.
- globalVar = rewriter.create<spirv::GlobalVariableOp>(
- loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
+ globalVar = spirv::GlobalVariableOp::create(
+ rewriter, loc, ptrType, globalVarName,
+ FlatSymbolRefAttr::get(specCstComposite));
globalVar->setAttr("Constant", rewriter.getUnitAttr());
}
// Get SSA value of Global variable and create pointer to i8 to point to
// the format string.
- Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
- Value fmtStr = rewriter.create<spirv::BitcastOp>(
- loc,
+ Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar);
+ Value fmtStr = spirv::BitcastOp::create(
+ rewriter, loc,
spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
globalPtr);
// Get printf arguments.
auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
- rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
+ spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs);
// Need to erase the gpu.printf op as gpu.printf does not use result vs
// spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V