aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
diff options
context:
space:
mode:
authorWang Pengcheng <wangpengcheng.pp@bytedance.com>2024-04-17 21:47:29 +0800
committerWang Pengcheng <wangpengcheng.pp@bytedance.com>2024-04-17 21:47:29 +0800
commit7d1b70a3960cdacfa4d7531531a9a921dadd3d88 (patch)
treeac207bdf9fcb3c3656deafb734acd8f03b40c8db /mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
parentc9d96c0d77b67c208aaf7f8f2554f972baa412d2 (diff)
parente77f6742143d71161f3f1161270648c9b95b2137 (diff)
downloadllvm-users/wangpc-pp/spr/riscv-dont-use-v0-directly-in-patterns.zip
llvm-users/wangpc-pp/spr/riscv-dont-use-v0-directly-in-patterns.tar.gz
llvm-users/wangpc-pp/spr/riscv-dont-use-v0-directly-in-patterns.tar.bz2
Created using spr 1.3.6-beta.1
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp')
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp20
1 files changed, 17 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 6150b5e..2024a2e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -157,7 +157,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
// Erase workgroup size.
entryPointAttr = spirv::EntryPointABIAttr::get(
entryPointAttr.getContext(), DenseI32ArrayAttr(),
- entryPointAttr.getSubgroupSize());
+ entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
}
}
if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
@@ -170,10 +170,24 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
// Erase subgroup size.
entryPointAttr = spirv::EntryPointABIAttr::get(
entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
- std::nullopt);
+ std::nullopt, entryPointAttr.getTargetWidth());
}
}
- if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize())
+ if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
+ std::optional<ArrayRef<spirv::Capability>> caps =
+ spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
+ if (!caps || targetEnv.allows(*caps)) {
+ builder.create<spirv::ExecutionModeOp>(
+ funcOp.getLoc(), funcOp,
+ spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
+ // Erase target width.
+ entryPointAttr = spirv::EntryPointABIAttr::get(
+ entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
+ entryPointAttr.getSubgroupSize(), std::nullopt);
+ }
+ }
+ if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
+ entryPointAttr.getTargetWidth())
funcOp->setAttr(entryPointAttrName, entryPointAttr);
else
funcOp->removeAttr(entryPointAttrName);