diff options
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp')
-rw-r--r-- | mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 6150b5e..2024a2e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -157,7 +157,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, // Erase workgroup size. entryPointAttr = spirv::EntryPointABIAttr::get( entryPointAttr.getContext(), DenseI32ArrayAttr(), - entryPointAttr.getSubgroupSize()); + entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth()); } } if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) { @@ -170,10 +170,24 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, // Erase subgroup size. entryPointAttr = spirv::EntryPointABIAttr::get( entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(), - std::nullopt); + std::nullopt, entryPointAttr.getTargetWidth()); } } - if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize()) + if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) { + std::optional<ArrayRef<spirv::Capability>> caps = + spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve); + if (!caps || targetEnv.allows(*caps)) { + builder.create<spirv::ExecutionModeOp>( + funcOp.getLoc(), funcOp, + spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth); + // Erase target width. + entryPointAttr = spirv::EntryPointABIAttr::get( + entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(), + entryPointAttr.getSubgroupSize(), std::nullopt); + } + } + if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() || + entryPointAttr.getTargetWidth()) funcOp->setAttr(entryPointAttrName, entryPointAttr); else funcOp->removeAttr(entryPointAttrName); |