aboutsummaryrefslogtreecommitdiff
path: root/gcc
diff options
context:
space:
mode:
authorTamar Christina <tamar.christina@arm.com>2021-01-14 20:58:12 +0000
committerTamar Christina <tamar.christina@arm.com>2021-01-14 20:58:12 +0000
commit31fac31800b5671d17c46108013d6fc709370ef3 (patch)
treed98e36d4c883b57648f26c48e77234096a8624aa /gcc
parente09173d84dabd186ad2463f47350c1bcabfeab8f (diff)
downloadgcc-31fac31800b5671d17c46108013d6fc709370ef3.zip
gcc-31fac31800b5671d17c46108013d6fc709370ef3.tar.gz
gcc-31fac31800b5671d17c46108013d6fc709370ef3.tar.bz2
slp: support complex FMA and complex FMA conjugate
This adds support for FMA and FMA conjugated to the slp pattern matcher. Example of instructions matched: #include <stdio.h> #include <complex.h> #define N 200 #define ROT #define TYPE float #define TYPE2 float void g (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N]) { for (int i=0; i < N; i++) { c[i] += a[i] * (b[i] ROT); } } void g_f1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N]) { for (int i=0; i < N; i++) { c[i] += conjf (a[i]) * (b[i] ROT); } } void g_s1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N]) { for (int i=0; i < N; i++) { c[i] += a[i] * conjf (b[i] ROT); } } void caxpy_add(double complex * restrict y, double complex * restrict x, size_t N, double complex f) { for (size_t i = 0; i < N; ++i) y[i] += x[i]* f; } gcc/ChangeLog: * internal-fn.def (COMPLEX_FMA, COMPLEX_FMA_CONJ): New. * optabs.def (cmla_optab, cmla_conj_optab): New. * doc/md.texi: Document them. * tree-vect-slp-patterns.c (vect_match_call_p, class complex_fma_pattern, vect_slp_reset_pattern, complex_fma_pattern::matches, complex_fma_pattern::recognize, complex_fma_pattern::build): New.
Diffstat (limited to 'gcc')
-rw-r--r--gcc/doc/md.texi45
-rw-r--r--gcc/internal-fn.def2
-rw-r--r--gcc/optabs.def2
-rw-r--r--gcc/tree-vect-slp-patterns.c180
4 files changed, 229 insertions, 0 deletions
diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi
index 60e8c94..49a1ce0 100644
--- a/gcc/doc/md.texi
+++ b/gcc/doc/md.texi
@@ -6202,6 +6202,51 @@ The operation is only supported for vector modes @var{m}.
This pattern is not allowed to @code{FAIL}.
+@cindex @code{cmla@var{m}4} instruction pattern
+@item @samp{cmla@var{m}4}
+Perform a vector multiply and accumulate that is semantically the same as
+a multiply and accumulate of complex numbers.
+
+@smallexample
+ complex TYPE c[N];
+ complex TYPE a[N];
+ complex TYPE b[N];
+ for (int i = 0; i < N; i += 1)
+ @{
+ c[i] += a[i] * b[i];
+ @}
+@end smallexample
+
+In GCC lane ordering the real part of the number must be in the even lanes with
+the imaginary part in the odd lanes.
+
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmla_conj@var{m}4} instruction pattern
+@item @samp{cmla_conj@var{m}4}
+Perform a vector multiply by conjugate and accumulate that is semantically
+the same as a multiply and accumulate of complex numbers where the second
+multiply argument is conjugated.
+
+@smallexample
+ complex TYPE c[N];
+ complex TYPE a[N];
+ complex TYPE b[N];
+ for (int i = 0; i < N; i += 1)
+ @{
+ c[i] += a[i] * conj (b[i]);
+ @}
+@end smallexample
+
+In GCC lane ordering the real part of the number must be in the even lanes with
+the imaginary part in the odd lanes.
+
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
@cindex @code{cmul@var{m}4} instruction pattern
@item @samp{cmul@var{m}4}
Perform a vector multiply that is semantically the same as multiply of
diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def
index e3e4fe5..020b586 100644
--- a/gcc/internal-fn.def
+++ b/gcc/internal-fn.def
@@ -288,6 +288,8 @@ DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary)
/* Ternary math functions. */
DEF_INTERNAL_FLT_FLOATN_FN (FMA, ECF_CONST, fma, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA, ECF_CONST, cmla, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA_CONJ, ECF_CONST, cmla_conj, ternary)
/* Unary integer ops. */
DEF_INTERNAL_INT_FN (CLRSB, ECF_CONST | ECF_NOTHROW, clrsb, unary)
diff --git a/gcc/optabs.def b/gcc/optabs.def
index fcc27d0..cecd1b6 100644
--- a/gcc/optabs.def
+++ b/gcc/optabs.def
@@ -294,6 +294,8 @@ OPTAB_D (cadd90_optab, "cadd90$a3")
OPTAB_D (cadd270_optab, "cadd270$a3")
OPTAB_D (cmul_optab, "cmul$a3")
OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
+OPTAB_D (cmla_optab, "cmla$a4")
+OPTAB_D (cmla_conj_optab, "cmla_conj$a4")
OPTAB_D (cos_optab, "cos$a2")
OPTAB_D (cosh_optab, "cosh$a2")
OPTAB_D (exp10_optab, "exp10$a2")
diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index dc96be5..bd632e01 100644
--- a/gcc/tree-vect-slp-patterns.c
+++ b/gcc/tree-vect-slp-patterns.c
@@ -325,6 +325,24 @@ vect_match_expression_p (slp_tree node, tree_code code)
return true;
}
+/* Return true iff NODE is non-null and has a representative statement that is
+ a call to the internal function FN. */
+
+static inline bool
+vect_match_call_p (slp_tree node, internal_fn fn)
+{
+ if (!node
+ || !SLP_TREE_REPRESENTATIVE (node))
+ return false;
+
+ gimple* expr = STMT_VINFO_STMT (SLP_TREE_REPRESENTATIVE (node));
+ if (!expr
+ || !gimple_call_internal_p (expr, fn))
+ return false;
+
+ return true;
+}
+
/* Check if the given lane permute in PERMUTES matches an alternating sequence
of {even odd even odd ...}. This to account for unrolled loops. Further
mode there resulting permute must be linear. */
@@ -1086,6 +1104,168 @@ complex_mul_pattern::build (vec_info *vinfo)
}
/*******************************************************************************
+ * complex_fma_pattern class
+ ******************************************************************************/
+
+class complex_fma_pattern : public complex_pattern
+{
+ protected:
+ complex_fma_pattern (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
+ : complex_pattern (node, m_ops, ifn)
+ {
+ this->m_num_args = 3; /* matches () always collects three operands. */
+ }
+
+ public:
+ void build (vec_info *);
+ static internal_fn
+ matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *,
+ vec<slp_tree> *);
+
+ static vect_pattern*
+ recognize (slp_tree_to_load_perm_map_t *, slp_tree *);
+
+ static vect_pattern*
+ mkInstance (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
+ {
+ return new complex_fma_pattern (node, m_ops, ifn);
+ }
+};
+
+/* Helper function to "reset" a previously matched node, undoing enough of the
+ pattern marking that the node is treated as an irrelevant node. */
+
+static inline void
+vect_slp_reset_pattern (slp_tree node)
+{
+ stmt_vec_info stmt_info = vect_orig_stmt (SLP_TREE_REPRESENTATIVE (node));
+ STMT_VINFO_IN_PATTERN_P (stmt_info) = false;
+ STMT_SLP_TYPE (stmt_info) = pure_slp;
+ SLP_TREE_REPRESENTATIVE (node) = stmt_info;
+}
+
+/* Pattern matcher for trying to match complex multiply and accumulate
+ patterns, with or without a conjugated multiply, in SLP tree.
+ If the operation matches then the matched internal function is returned and
+ the three operands of the replacement statement are put in OPS.
+
+ If no match is found then IFN_LAST is returned and OPS is unchanged.
+
+ This function matches the patterns shaped as:
+
+ double ax = (b[i+1] * a[i]) + (b[i] * a[i]);
+ double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]);
+
+ c[i] = c[i] - ax;
+ c[i+1] = c[i+1] + bx;
+
+ The match is performed after COMPLEX_MUL which would have done the majority
+ of the work. This function merely matches an ADD whose second operand is a
+ COMPLEX_MUL or COMPLEX_MUL_CONJ IFN. OP must be CMPLX_NONE and REF_NODE
+ must be a PLUS_EXPR node without a lane permutation. */
+
+internal_fn
+complex_fma_pattern::matches (complex_operation_t op,
+ slp_tree_to_load_perm_map_t * /* perm_cache */,
+ slp_tree *ref_node, vec<slp_tree> *ops)
+{
+ internal_fn ifn = IFN_LAST;
+
+ /* Find the two components. We match Complex MUL first which reduces the
+ amount of work this pattern has to do. After that we just match the
+ head node and we're done:
+
+ * FMA: + +.
+
+ We need to ignore the two_operands nodes that may also match. For that
+ we first reject any node for which a pair operation was detected, and
+ below check that it's not a permute node as we're looking for a normal
+ PLUS_EXPR operation. */
+ if (op != CMPLX_NONE)
+ return IFN_LAST;
+
+ /* Match the head node. The Complex MUL was matched earlier, so all that
+ is left is checking for a normal PLUS_EXPR that has children and no lane
+ permutation, i.e.:
+
+ * FMA: + + on a non-two_operands node. */
+ slp_tree vnode = *ref_node;
+ if (SLP_TREE_LANE_PERMUTATION (vnode).exists ()
+ || !SLP_TREE_CHILDREN (vnode).exists ()
+ || !vect_match_expression_p (vnode, PLUS_EXPR))
+ return IFN_LAST;
+
+ slp_tree node = SLP_TREE_CHILDREN (vnode)[1];
+
+ if (vect_match_call_p (node, IFN_COMPLEX_MUL))
+ ifn = IFN_COMPLEX_FMA;
+ else if (vect_match_call_p (node, IFN_COMPLEX_MUL_CONJ))
+ ifn = IFN_COMPLEX_FMA_CONJ;
+ else
+ return IFN_LAST;
+
+ if (!vect_pattern_validate_optab (ifn, vnode))
+ return IFN_LAST;
+
+ /* FMA matched ADD + CMUL. During the matching of CMUL the
+ stmt that starts the pattern is marked as being in a pattern,
+ namely the CMUL. When replacing this with a CFMA we have to
+ unmark this statement as being in a pattern. This is because
+ vect_mark_pattern_stmts will only mark the current stmt as being
+ in a pattern. Later on when the scalar stmts are examined the
+ old statement which is supposed to be irrelevant will point to
+ CMUL unless we undo the pattern relationship here. */
+ vect_slp_reset_pattern (node);
+ ops->truncate (0);
+ ops->create (3);
+
+ if (ifn == IFN_COMPLEX_FMA)
+ {
+ ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
+ ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
+ ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
+ }
+ else
+ {
+ ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
+ ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
+ ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
+ }
+
+ return ifn;
+}
+
+/* Attempt to recognize a complex FMA pattern rooted at NODE; NULL if none. */
+
+vect_pattern*
+complex_fma_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
+ slp_tree *node)
+{
+ auto_vec<slp_tree> ops;
+ complex_operation_t op
+ = vect_detect_pair_op (*node, true, &ops);
+ internal_fn ifn
+ = complex_fma_pattern::matches (op, perm_cache, node, &ops);
+ if (ifn == IFN_LAST)
+ return NULL;
+
+ return new complex_fma_pattern (node, &ops, ifn);
+}
+
+/* Perform a replacement of the detected complex FMA pattern with the new
+ instruction sequences. */
+
+void
+complex_fma_pattern::build (vec_info *vinfo)
+{
+ SLP_TREE_CHILDREN (*this->m_node).release ();
+ SLP_TREE_CHILDREN (*this->m_node).create (3);
+ SLP_TREE_CHILDREN (*this->m_node).safe_splice (this->m_ops);
+
+ complex_pattern::build (vinfo);
+}
+
+/*******************************************************************************
* Pattern matching definitions
******************************************************************************/