diff options
author | Prashant Kumar <pk5561@gmail.com> | 2024-06-20 08:08:37 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-20 08:08:37 +0530 |
commit | fa0666876cdf11162af341911b99311a56be2274 (patch) | |
tree | 0d27af480ac0a1502ca14228a3d4c0442503ac40 /mlir | |
parent | ad79a14c9e5ec4a369eed4adf567c22cc029863f (diff) | |
download | llvm-fa0666876cdf11162af341911b99311a56be2274.zip llvm-fa0666876cdf11162af341911b99311a56be2274.tar.gz llvm-fa0666876cdf11162af341911b99311a56be2274.tar.bz2 |
[mlir][linalg] Fix numerical issue with softmax (#96090)
For more info:
https://github.com/iree-org/iree/issues/17670#issuecomment-2167591878
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/transform-op-decompose.mlir | 2 |
2 files changed, 3 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 7b62560..57d1266 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2714,8 +2714,8 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) { Value neutralForMaxFInit = b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce) .result(); - Value max = reduce<arith::MaximumFOp>(b, loc, input, neutralForMaxFInit, - reductionDim); + Value max = + reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim); // Step 2: Subtract max from input and exponentiate. Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim); diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir index ef0aca2..2e211d2 100644 --- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir @@ -215,7 +215,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten // CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", // CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) { // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32): -// CHECK: %[[D8:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32 +// CHECK: %[[D8:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32 // CHECK: linalg.yield %[[D8]] : f32 // CHECK: } -> tensor<2x16xf32> // CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types = |