diff options
-rw-r--r-- | gcc/doc/invoke.texi | 6 | ||||
-rw-r--r-- | gcc/params.opt | 7 | ||||
-rw-r--r-- | gcc/testsuite/gcc.dg/pr110279-2.c | 41 | ||||
-rw-r--r-- | gcc/tree-ssa-reassoc.cc | 150 |
4 files changed, 177 insertions, 27 deletions
diff --git a/gcc/doc/invoke.texi b/gcc/doc/invoke.texi index f89f926..c671bec 100644 --- a/gcc/doc/invoke.texi +++ b/gcc/doc/invoke.texi @@ -16583,6 +16583,12 @@ Maximum number of basic blocks for VRP to use a basic cache vector. @item avoid-fma-max-bits Maximum number of bits for which we avoid creating FMAs. +@item fully-pipelined-fma +Whether the target fully pipelines FMA instructions. If non-zero, +reassociation considers the benefit of parallelizing FMA's multiplication +part and addition part, assuming FMUL and FMA use the same units that can +also do FADD. + @item sms-loop-average-count-threshold A threshold on the average loop count considered by the swing modulo scheduler. diff --git a/gcc/params.opt b/gcc/params.opt index f187659..63a4c30 100644 --- a/gcc/params.opt +++ b/gcc/params.opt @@ -134,6 +134,13 @@ Maximal estimated growth of function body caused by early inlining of single cal Common Joined UInteger Var(param_fsm_scale_path_stmts) Init(2) IntegerRange(1, 10) Param Optimization Scale factor to apply to the number of statements in a threading path crossing a loop backedge when comparing to max-jump-thread-duplication-stmts. +-param=fully-pipelined-fma= +Common Joined UInteger Var(param_fully_pipelined_fma) Init(0) IntegerRange(0, 1) Param Optimization +Whether the target fully pipelines FMA instructions. If non-zero, +reassociation considers the benefit of parallelizing FMA's multiplication +part and addition part, assuming FMUL and FMA use the same units that can +also do FADD. + -param=gcse-after-reload-critical-fraction= Common Joined UInteger Var(param_gcse_after_reload_critical_fraction) Init(10) Param Optimization The threshold ratio of critical edges execution count that permit performing redundancy elimination after reload. 
diff --git a/gcc/testsuite/gcc.dg/pr110279-2.c b/gcc/testsuite/gcc.dg/pr110279-2.c new file mode 100644 index 0000000..0304a77 --- /dev/null +++ b/gcc/testsuite/gcc.dg/pr110279-2.c @@ -0,0 +1,41 @@ +/* PR tree-optimization/110279 */ +/* { dg-do compile } */ +/* { dg-options "-Ofast --param tree-reassoc-width=4 --param fully-pipelined-fma=1 -fdump-tree-reassoc2-details -fdump-tree-optimized" } */ +/* { dg-additional-options "-march=armv8.2-a" { target aarch64-*-* } } */ + +#define LOOP_COUNT 800000000 +typedef double data_e; + +#include <stdio.h> + +__attribute_noinline__ data_e +foo (data_e in) +{ + data_e a1, a2, a3, a4; + data_e tmp, result = 0; + a1 = in + 0.1; + a2 = in * 0.1; + a3 = in + 0.01; + a4 = in * 0.59; + + data_e result2 = 0; + + for (int ic = 0; ic < LOOP_COUNT; ic++) + { + /* Test that a complete FMA chain with length=4 is not broken. */ + tmp = a1 + a2 * a2 + a3 * a3 + a4 * a4 ; + result += tmp - ic; + result2 = result2 / 2 - tmp; + + a1 += 0.91; + a2 += 0.1; + a3 -= 0.01; + a4 -= 0.89; + + } + + return result + result2; +} + +/* { dg-final { scan-tree-dump-not "was chosen for reassociation" "reassoc2"} } */ +/* { dg-final { scan-tree-dump-times {\.FMA } 3 "optimized"} } */
\ No newline at end of file diff --git a/gcc/tree-ssa-reassoc.cc b/gcc/tree-ssa-reassoc.cc index 07fc8e2..cdef9f7 100644 --- a/gcc/tree-ssa-reassoc.cc +++ b/gcc/tree-ssa-reassoc.cc @@ -5430,13 +5430,35 @@ get_required_cycles (int ops_num, int cpu_width) return res; } +/* Given that the target fully pipelines FMA instructions, return the latency + of MULT_EXPRs that can't be hidden by the FMAs. WIDTH is the number of + pipes. */ + +static inline int +get_mult_latency_consider_fma (int ops_num, int mult_num, int width) +{ + gcc_checking_assert (mult_num && mult_num <= ops_num); + + /* For each partition, if mult_num == ops_num, there's latency(MULT)*2. + e.g: + + A * B + C * D + => + _1 = A * B; + _2 = .FMA (C, D, _1); + + Otherwise there's latency(MULT)*1 in the first FMA. */ + return CEIL (ops_num, width) == CEIL (mult_num, width) ? 2 : 1; +} + /* Returns an optimal number of registers to use for computation of given statements. - LHS is the result ssa name of OPS. */ + LHS is the result ssa name of OPS. MULT_NUM is number of sub-expressions + that are MULT_EXPRs, when OPS are PLUS_EXPRs or MINUS_EXPRs. */ static int -get_reassociation_width (vec<operand_entry *> *ops, tree lhs, +get_reassociation_width (vec<operand_entry *> *ops, int mult_num, tree lhs, enum tree_code opc, machine_mode mode) { int param_width = param_tree_reassoc_width; @@ -5462,16 +5484,68 @@ get_reassociation_width (vec<operand_entry *> *ops, tree lhs, so we can perform a binary search for the minimal width that still results in the optimal cycle count. */ width_min = 1; - while (width > width_min) + + /* If the target fully pipelines FMA instruction, the multiply part can start + already if its operands are ready. Assuming symmetric pipes are used for + FMUL/FADD/FMA, then for a sequence of FMA like: + + _8 = .FMA (_2, _3, _1); + _9 = .FMA (_5, _4, _8); + _10 = .FMA (_7, _6, _9); + + , if width=1, the latency is latency(MULT) + latency(ADD)*3. 
+ While with width=2: + + _8 = _4 * _5; + _9 = .FMA (_2, _3, _1); + _10 = .FMA (_6, _7, _8); + _11 = _9 + _10; + + , it is latency(MULT)*2 + latency(ADD)*2. Assuming latency(MULT) >= + latency(ADD), the first variant is preferred. + + Find out if we can get a smaller width considering FMA. */ if (width > 1 && mult_num && param_fully_pipelined_fma) { - int width_mid = (width + width_min) / 2; + /* When param_fully_pipelined_fma is set, assume FMUL and FMA use the + same units that can also do FADD. For other scenarios, such as when + FMUL and FADD are using separated units, the following code may not + apply. */ + int width_mult = targetm.sched.reassociation_width (MULT_EXPR, mode); + gcc_checking_assert (width_mult <= width); + + /* Latency of MULT_EXPRs. */ + int lat_mul + = get_mult_latency_consider_fma (ops_num, mult_num, width_mult); + + /* Quick search might not apply. So start from 1. */ + for (int i = 1; i < width_mult; i++) + { + int lat_mul_new + = get_mult_latency_consider_fma (ops_num, mult_num, i); + int lat_add_new = get_required_cycles (ops_num, i); - if (get_required_cycles (ops_num, width_mid) == cycles_best) - width = width_mid; - else if (width_min < width_mid) - width_min = width_mid; - else - break; + /* Assume latency(MULT) >= latency(ADD). */ + if (lat_mul - lat_mul_new >= lat_add_new - cycles_best) + { + width = i; + break; + } + } + } + else + { + while (width > width_min) + { + int width_mid = (width + width_min) / 2; + + if (get_required_cycles (ops_num, width_mid) == cycles_best) + width = width_mid; + else if (width_min < width_mid) + width_min = width_mid; + else + break; + } } /* If there's loop dependent FMA result, return width=2 to avoid it. 
This is @@ -6841,8 +6915,10 @@ transform_stmt_to_multiply (gimple_stmt_iterator *gsi, gimple *stmt, Rearrange ops to -> e + a * b + c * d generates: _4 = .FMA (c_7(D), d_8(D), _3); - _11 = .FMA (a_5(D), b_6(D), _4); */ -static bool + _11 = .FMA (a_5(D), b_6(D), _4); + + Return the number of MULT_EXPRs in the chain. */ +static int rank_ops_for_fma (vec<operand_entry *> *ops) { operand_entry *oe; @@ -6856,9 +6932,26 @@ rank_ops_for_fma (vec<operand_entry *> *ops) if (TREE_CODE (oe->op) == SSA_NAME) { gimple *def_stmt = SSA_NAME_DEF_STMT (oe->op); - if (is_gimple_assign (def_stmt) - && gimple_assign_rhs_code (def_stmt) == MULT_EXPR) - ops_mult.safe_push (oe); + if (is_gimple_assign (def_stmt)) + { + if (gimple_assign_rhs_code (def_stmt) == MULT_EXPR) + ops_mult.safe_push (oe); + /* A negate on the multiplication leads to FNMA. */ + else if (gimple_assign_rhs_code (def_stmt) == NEGATE_EXPR + && TREE_CODE (gimple_assign_rhs1 (def_stmt)) == SSA_NAME) + { + gimple *neg_def_stmt + = SSA_NAME_DEF_STMT (gimple_assign_rhs1 (def_stmt)); + if (is_gimple_assign (neg_def_stmt) + && gimple_bb (neg_def_stmt) == gimple_bb (def_stmt) + && gimple_assign_rhs_code (neg_def_stmt) == MULT_EXPR) + ops_mult.safe_push (oe); + else + ops_others.safe_push (oe); + } + else + ops_others.safe_push (oe); + } else ops_others.safe_push (oe); } @@ -6874,7 +6967,8 @@ rank_ops_for_fma (vec<operand_entry *> *ops) Putting ops that not def from mult in front can generate more FMAs. 2. If all ops are defined with mult, we don't need to rearrange them. 
*/ - if (ops_mult.length () >= 2 && ops_mult.length () != ops_length) + unsigned mult_num = ops_mult.length (); + if (mult_num >= 2 && mult_num != ops_length) { /* Put no-mult ops and mult ops alternately at the end of the queue, which is conducive to generating more FMA and reducing the @@ -6890,9 +6984,8 @@ rank_ops_for_fma (vec<operand_entry *> *ops) if (opindex > 0) opindex--; } - return true; } - return false; + return mult_num; } /* Reassociate expressions in basic block BB and its post-dominator as children. @@ -7057,8 +7150,8 @@ reassociate_bb (basic_block bb) { machine_mode mode = TYPE_MODE (TREE_TYPE (lhs)); int ops_num = ops.length (); - int width; - bool has_fma = false; + int width = 0; + int mult_num = 0; /* For binary bit operations, if there are at least 3 operands and the last operand in OPS is a constant, @@ -7081,16 +7174,17 @@ reassociate_bb (basic_block bb) opt_type) && (rhs_code == PLUS_EXPR || rhs_code == MINUS_EXPR)) { - has_fma = rank_ops_for_fma (&ops); + mult_num = rank_ops_for_fma (&ops); } /* Only rewrite the expression tree to parallel in the last reassoc pass to avoid useless work back-and-forth with initial linearization. */ + bool has_fma = mult_num >= 2 && mult_num != ops_num; if (!reassoc_insert_powi_p && ops.length () > 3 - && (width - = get_reassociation_width (&ops, lhs, rhs_code, mode)) + && (width = get_reassociation_width (&ops, mult_num, lhs, + rhs_code, mode)) > 1) { if (dump_file && (dump_flags & TDF_DETAILS)) @@ -7111,10 +7205,12 @@ reassociate_bb (basic_block bb) if (len >= 3 && (!has_fma /* width > 1 means ranking ops results in better - parallelism. */ - || get_reassociation_width (&ops, lhs, rhs_code, - mode) - > 1)) + parallelism. Check current value to avoid + calling get_reassociation_width again. */ + || (width != 1 + && get_reassociation_width ( + &ops, mult_num, lhs, rhs_code, mode) + > 1))) swap_ops_for_binary_stmt (ops, len - 3); new_lhs = rewrite_expr_tree (stmt, rhs_code, 0, ops, |