aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJakub Kuderski <jakub@nod-labs.com>2024-03-28 14:13:04 -0400
committerGitHub <noreply@github.com>2024-03-28 14:13:04 -0400
commitd61ec513c42005bb071eb15386deb5de585ff267 (patch)
tree0d36d923cbf65f62c6a9e394e0d117d2f5514df6
parent599027857e1007ff402094a3a550b4832f3f5146 (diff)
downloadllvm-d61ec513c42005bb071eb15386deb5de585ff267.zip
llvm-d61ec513c42005bb071eb15386deb5de585ff267.tar.gz
llvm-d61ec513c42005bb071eb15386deb5de585ff267.tar.bz2
[mlir][spirv] Add IsInf/IsNan expansion for WebGPU (#86903)
These non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into `false`. Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in https://github.com/llvm/llvm-project/pull/86578. Also do some misc cleanups in the surrounding code.
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h12
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp54
-rw-r--r--mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir32
3 files changed, 82 insertions, 16 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
index ac4d38e..d0fc85c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h
@@ -18,12 +18,18 @@
namespace mlir {
namespace spirv {
-/// Appends to a pattern list additional patterns to expand extended
-/// multiplication ops into regular arithmetic ops. Extended multiplication ops
-/// are not supported by the WebGPU Shading Language (WGSL).
+/// Appends patterns to expand extended multiplication and adition ops into
+/// regular arithmetic ops. Extended arithmetic ops are not supported by the
+/// WebGPU Shading Language (WGSL).
void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns);
+/// Appends patterns to expand non-finite arithmetic ops `IsNan` and `IsInf`.
+/// These are not supported by the WebGPU Shading Language (WGSL). We follow
+/// fast math assumptions and assume that all floating point values are finite.
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+ RewritePatternSet &patterns);
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 21de1c9..5d4dd5b 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -39,7 +39,7 @@ namespace {
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
-Attribute getScalarOrSplatAttr(Type type, int64_t value) {
+static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
if (auto intTy = dyn_cast<IntegerType>(type))
return IntegerAttr::get(intTy, sizedValue);
@@ -47,9 +47,9 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
}
-Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
- Value lhs, Value rhs,
- bool signExtendArguments) {
+static Value lowerExtendedMultiplication(Operation *mulOp,
+ PatternRewriter &rewriter, Value lhs,
+ Value rhs, bool signExtendArguments) {
Location loc = mulOp->getLoc();
Type argTy = lhs.getType();
// Emulate 64-bit multiplication by splitting each input element of type i32
@@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
}
};
+struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IsInfOp op,
+ PatternRewriter &rewriter) const override {
+ // We assume values to be finite and turn `IsInf` info `false`.
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+ op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+ return success();
+ }
+};
+
+struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IsNanOp op,
+ PatternRewriter &rewriter) const override {
+ // We assume values to be finite and turn `IsNan` info `false`.
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+ op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
-class WebGPUPreparePass
- : public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
-public:
+struct WebGPUPreparePass final
+ : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
+ populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
@@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns) {
// WGSL currently does not support extended multiplication ops, see:
// https://github.com/gpuweb/gpuweb/issues/1565.
- patterns.add<
- // clang-format off
- ExpandSMulExtendedPattern,
- ExpandUMulExtendedPattern,
- ExpandAddCarryPattern
- >(patterns.getContext());
+ patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
+ ExpandAddCarryPattern>(patterns.getContext());
}
+
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+ RewritePatternSet &patterns) {
+ // WGSL currently does not support `isInf` and `isNan`, see:
+ // https://github.com/gpuweb/gpuweb/pull/2311.
+ patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
+}
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
index 1ec4e5e..45f188d 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
@@ -182,4 +182,36 @@ spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
}
+// CHECK-LABEL: func @is_inf_f32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_inf_f32(%a : f32) -> i1 "None" {
+ %0 = spirv.IsInf %a : f32
+ spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_inf_4xf32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_inf_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+ %0 = spirv.IsInf %a : vector<4xf32>
+ spirv.ReturnValue %0 : vector<4xi1>
+}
+
+// CHECK-LABEL: func @is_nan_f32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_nan_f32(%a : f32) -> i1 "None" {
+ %0 = spirv.IsNan %a : f32
+ spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_nan_4xf32
+// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_nan_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+ %0 = spirv.IsNan %a : vector<4xf32>
+ spirv.ReturnValue %0 : vector<4xi1>
+}
+
} // end module