diff options
author | Wang Pengcheng <wangpengcheng.pp@bytedance.com> | 2024-04-17 21:47:29 +0800 |
---|---|---|
committer | Wang Pengcheng <wangpengcheng.pp@bytedance.com> | 2024-04-17 21:47:29 +0800 |
commit | 7d1b70a3960cdacfa4d7531531a9a921dadd3d88 (patch) | |
tree | ac207bdf9fcb3c3656deafb734acd8f03b40c8db /mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp | |
parent | c9d96c0d77b67c208aaf7f8f2554f972baa412d2 (diff) | |
parent | e77f6742143d71161f3f1161270648c9b95b2137 (diff) | |
download | llvm-users/wangpc-pp/spr/riscv-dont-use-v0-directly-in-patterns.zip llvm-users/wangpc-pp/spr/riscv-dont-use-v0-directly-in-patterns.tar.gz llvm-users/wangpc-pp/spr/riscv-dont-use-v0-directly-in-patterns.tar.bz2 |
Rebase on #88868users/wangpc-pp/spr/riscv-dont-use-v0-directly-in-patterns
Created using spr 1.3.6-beta.1
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp')
-rw-r--r-- | mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 6150b5e..2024a2e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -157,7 +157,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, // Erase workgroup size. entryPointAttr = spirv::EntryPointABIAttr::get( entryPointAttr.getContext(), DenseI32ArrayAttr(), - entryPointAttr.getSubgroupSize()); + entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth()); } } if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) { @@ -170,10 +170,24 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, // Erase subgroup size. entryPointAttr = spirv::EntryPointABIAttr::get( entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(), - std::nullopt); + std::nullopt, entryPointAttr.getTargetWidth()); } } - if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize()) + if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) { + std::optional<ArrayRef<spirv::Capability>> caps = + spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve); + if (!caps || targetEnv.allows(*caps)) { + builder.create<spirv::ExecutionModeOp>( + funcOp.getLoc(), funcOp, + spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth); + // Erase target width. + entryPointAttr = spirv::EntryPointABIAttr::get( + entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(), + entryPointAttr.getSubgroupSize(), std::nullopt); + } + } + if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() || + entryPointAttr.getTargetWidth()) funcOp->setAttr(entryPointAttrName, entryPointAttr); else funcOp->removeAttr(entryPointAttrName); |