From 0981dca7779d4acfcbb92fbb29a7a1033e283b88 Mon Sep 17 00:00:00 2001 From: donald chen Date: Wed, 29 May 2024 22:20:49 +0800 Subject: [mlir][arith] Add neutral element support to arith.maxnumf/arith.minnumf (#93278) For maxnumf and minnumf, the result of calculations involving NaN will be another value, so their neutral element is set to NaN. --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 14 ++++ .../Linalg/transform-op-split-reduction.mlir | 92 ++++++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index a0b5025..5797c56 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2467,6 +2467,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, : APFloat::getInf(semantic, /*Negative=*/true); return builder.getFloatAttr(resultType, identity); } + case AtomicRMWKind::maxnumf: { + const llvm::fltSemantics &semantic = + llvm::cast(resultType).getFloatSemantics(); + APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true); + return builder.getFloatAttr(resultType, identity); + } case AtomicRMWKind::addf: case AtomicRMWKind::addi: case AtomicRMWKind::maxu: @@ -2489,6 +2495,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, return builder.getFloatAttr(resultType, identity); } + case AtomicRMWKind::minnumf: { + const llvm::fltSemantics &semantic = + llvm::cast(resultType).getFloatSemantics(); + APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false); + return builder.getFloatAttr(resultType, identity); + } case AtomicRMWKind::mins: return builder.getIntegerAttr( resultType, APInt::getSignedMaxValue( @@ -2518,6 +2530,8 @@ std::optional mlir::arith::getNeutralElement(Operation *op) { .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; }) .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; }) .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; }) + .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; }) + .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; }) // Integer operations. .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; }) .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; }) diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir index 31e9fd0..9849f36 100644 --- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir @@ -407,3 +407,95 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- +// Checks we use nan as the neutral element for maxnumf op. +func.func @generic_split_maxnumf(%in: tensor<32xf32>, %out: tensor) -> tensor { + %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()>], + iterator_types = ["reduction"]} + ins(%in : tensor<32xf32>) + outs(%out : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %y = arith.maxnumf %arg1, %arg2 : f32 + linalg.yield %y : f32 + } -> tensor + return %r : tensor +} + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()> +// CHECK-LABEL: func @generic_split_maxnumf +// The float value 0xFFC00000 that is filled into the init tensor represents negative NaN. +// CHECK-DAG: %[[ID:.*]] = arith.constant 0xFFC00000 : f32 +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32> +// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32> +// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32> +// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} +// CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) { +// CHECK: arith.maxnumf +// CHECK: linalg.yield +// CHECK: } -> tensor<4xf32> +// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]} +// CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor) { +// CHECK: arith.maxnumf {{.*}} +// CHECK: linalg.yield +// CHECK: } -> tensor +// CHECK: return %[[R]] : tensor + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel} + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- +// Checks we use nan as the neutral element for minnumf op. +func.func @generic_split_minnumf(%in: tensor<32xf32>, %out: tensor) -> tensor { + %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()>], + iterator_types = ["reduction"]} + ins(%in : tensor<32xf32>) + outs(%out : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %y = arith.minnumf %arg1, %arg2 : f32 + linalg.yield %y : f32 + } -> tensor + return %r : tensor +} + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()> +// CHECK-LABEL: func @generic_split_minnumf +// The float value 0x7FC00000 that is filled into the init tensor represents positive NaN. +// CHECK-DAG: %[[ID:.*]] = arith.constant 0x7FC00000 : f32 +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32> +// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32> +// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32> +// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} +// CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) { +// CHECK: arith.minnumf +// CHECK: linalg.yield +// CHECK: } -> tensor<4xf32> +// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]} +// CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor) { +// CHECK: arith.minnumf {{.*}} +// CHECK: linalg.yield +// CHECK: } -> tensor +// CHECK: return %[[R]] : tensor + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel} + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} -- cgit v1.1