diff options
Diffstat (limited to 'benchmarks/vec-sgemm/vec-sgemm.S')
-rw-r--r-- | benchmarks/vec-sgemm/vec-sgemm.S | 223 |
1 files changed, 223 insertions, 0 deletions
diff --git a/benchmarks/vec-sgemm/vec-sgemm.S b/benchmarks/vec-sgemm/vec-sgemm.S new file mode 100644 index 0000000..290f83c --- /dev/null +++ b/benchmarks/vec-sgemm/vec-sgemm.S @@ -0,0 +1,223 @@ + .text + .balign 4 + .global vec_sgemm_nn +# RV64IDV system +# +# void +# vec_sgemm_nn(size_t n, +# size_t m, +# size_t k, +# const float*a, // m * k matrix +# size_t lda, +# const float*b, // k * n matrix +# size_t ldb, +# float*c, // m * n matrix +# size_t ldc) +# +# c += a*b (alpha=1, no transpose on input matrices) +# matrices stored in C row-major order + +#define n a0 +#define m a1 +#define k a2 +#define ap a3 +#define astride a4 +#define bp a5 +#define bstride a6 +#define cp a7 +#define cstride t0 +#define kt t1 +#define nt t2 +#define bnp t3 +#define cnp t4 +#define akp t5 +#define bkp s0 +#define nvl s1 +#define ccp s2 +#define amp s3 + +# Use args as additional temporaries +#define ft12 fa0 +#define ft13 fa1 +#define ft14 fa2 +#define ft15 fa3 + +#define FRAMESIZE 32 + +# This version holds a 16*VLMAX block of C matrix in vector registers +# in inner loop, but otherwise does not cache or TLB tiling. + +vec_sgemm_nn: + ld cstride, 0(sp) # Get arg from stack frame + addi sp, sp, -FRAMESIZE + sd s0, 0(sp) + sd s1, 8(sp) + sd s2, 16(sp) + + # Check for zero size matrices + beqz n, exit + beqz m, exit + beqz k, exit + + # Convert elements strides to byte strides. + slli astride, astride, 2 + slli bstride, bstride, 2 + slli cstride, cstride, 2 + + slti t6, m, 16 + bnez t6, end_rows + +c_row_loop: # Loop across rows of C blocks + + mv nt, n # Initialize n counter for next row of C blocks + + mv bnp, bp # Initialize B n-loop pointer to start + mv cnp, cp # Initialize C n-loop pointer + +c_col_loop: # Loop across one row of C blocks + vsetvli nvl, nt, e32, ta, ma # 32-bit vectors, LMUL=1 + + mv akp, ap # reset pointer into A to beginning + mv bkp, bnp # step to next column in B matrix + + # Initalize current C submatrix block from memory. + vle32.v v0, (cnp); add ccp, cnp, cstride; + vle32.v v1, (ccp); add ccp, ccp, cstride; + vle32.v v2, (ccp); add ccp, ccp, cstride; + vle32.v v3, (ccp); add ccp, ccp, cstride; + vle32.v v4, (ccp); add ccp, ccp, cstride; + vle32.v v5, (ccp); add ccp, ccp, cstride; + vle32.v v6, (ccp); add ccp, ccp, cstride; + vle32.v v7, (ccp); add ccp, ccp, cstride; + vle32.v v8, (ccp); add ccp, ccp, cstride; + vle32.v v9, (ccp); add ccp, ccp, cstride; + vle32.v v10, (ccp); add ccp, ccp, cstride; + vle32.v v11, (ccp); add ccp, ccp, cstride; + vle32.v v12, (ccp); add ccp, ccp, cstride; + vle32.v v13, (ccp); add ccp, ccp, cstride; + vle32.v v14, (ccp); add ccp, ccp, cstride; + vle32.v v15, (ccp) + + + mv kt, k # Initialize inner loop counter + + # Inner loop scheduled assuming 4-clock occupancy of vfmacc instruction and single-issue pipeline + # Software pipeline loads + flw ft0, (akp); add amp, akp, astride; + flw ft1, (amp); add amp, amp, astride; + flw ft2, (amp); add amp, amp, astride; + flw ft3, (amp); add amp, amp, astride; + # Get vector from B matrix + vle32.v v16, (bkp) + + # Loop on inner dimension for current C block + k_loop: + vfmacc.vf v0, ft0, v16 + add bkp, bkp, bstride + flw ft4, (amp) + add amp, amp, astride + vfmacc.vf v1, ft1, v16 + addi kt, kt, -1 # Decrement k counter + flw ft5, (amp) + add amp, amp, astride + vfmacc.vf v2, ft2, v16 + flw ft6, (amp) + add amp, amp, astride + flw ft7, (amp) + vfmacc.vf v3, ft3, v16 + add amp, amp, astride + flw ft8, (amp) + add amp, amp, astride + vfmacc.vf v4, ft4, v16 + flw ft9, (amp) + add amp, amp, astride + vfmacc.vf v5, ft5, v16 + flw ft10, (amp) + add amp, amp, astride + vfmacc.vf v6, ft6, v16 + flw ft11, (amp) + add amp, amp, astride + vfmacc.vf v7, ft7, v16 + flw ft12, (amp) + add amp, amp, astride + vfmacc.vf v8, ft8, v16 + flw ft13, (amp) + add amp, amp, astride + vfmacc.vf v9, ft9, v16 + flw ft14, (amp) + add amp, amp, astride + vfmacc.vf v10, ft10, v16 + flw ft15, (amp) + add amp, amp, astride + addi akp, akp, 4 # Move to next column of a + vfmacc.vf v11, ft11, v16 + beqz kt, 1f # Don't load past end of matrix + flw ft0, (akp) + add amp, akp, astride +1: vfmacc.vf v12, ft12, v16 + beqz kt, 1f + flw ft1, (amp) + add amp, amp, astride +1: vfmacc.vf v13, ft13, v16 + beqz kt, 1f + flw ft2, (amp) + add amp, amp, astride +1: vfmacc.vf v14, ft14, v16 + beqz kt, 1f # Exit out of loop + flw ft3, (amp) + add amp, amp, astride + vfmacc.vf v15, ft15, v16 + vle32.v v16, (bkp) # Get next vector from B matrix, overlap loads with jump stalls + j k_loop + +1: vfmacc.vf v15, ft15, v16 + + # Save C matrix block back to memory + vse32.v v0, (cnp); add ccp, cnp, cstride; + vse32.v v1, (ccp); add ccp, ccp, cstride; + vse32.v v2, (ccp); add ccp, ccp, cstride; + vse32.v v3, (ccp); add ccp, ccp, cstride; + vse32.v v4, (ccp); add ccp, ccp, cstride; + vse32.v v5, (ccp); add ccp, ccp, cstride; + vse32.v v6, (ccp); add ccp, ccp, cstride; + vse32.v v7, (ccp); add ccp, ccp, cstride; + vse32.v v8, (ccp); add ccp, ccp, cstride; + vse32.v v9, (ccp); add ccp, ccp, cstride; + vse32.v v10, (ccp); add ccp, ccp, cstride; + vse32.v v11, (ccp); add ccp, ccp, cstride; + vse32.v v12, (ccp); add ccp, ccp, cstride; + vse32.v v13, (ccp); add ccp, ccp, cstride; + vse32.v v14, (ccp); add ccp, ccp, cstride; + vse32.v v15, (ccp) + + # Following tail instructions should be scheduled earlier in free slots during C block save. + # Leaving here for clarity. + + # Bump pointers for loop across blocks in one row + slli t6, nvl, 2 + add cnp, cnp, t6 # Move C block pointer over + add bnp, bnp, t6 # Move B block pointer over + sub nt, nt, nvl # Decrement element count in n dimension + bnez nt, c_col_loop # Any more to do? + + # Move to next set of rows + addi m, m, -16 # Did 16 rows above + slli t6, astride, 4 # Multiply astride by 16 + add ap, ap, t6 # Move A matrix pointer down 16 rows + slli t6, cstride, 4 # Multiply cstride by 16 + add cp, cp, t6 # Move C matrix pointer down 16 rows + + slti t6, m, 16 + beqz t6, c_row_loop + + # Handle end of matrix with fewer than 16 rows. + # Can use smaller versions of above decreasing in powers-of-2 depending on code-size concerns. +end_rows: + # Not done. + +exit: + ld s0, 0(sp) + ld s1, 8(sp) + ld s2, 16(sp) + addi sp, sp, FRAMESIZE + ret |