Diffstat (limited to 'benchmarks/vec-sgemm/vec-sgemm.S')
-rw-r--r--  benchmarks/vec-sgemm/vec-sgemm.S  223
1 file changed, 223 insertions(+), 0 deletions(-)
diff --git a/benchmarks/vec-sgemm/vec-sgemm.S b/benchmarks/vec-sgemm/vec-sgemm.S
new file mode 100644
index 0000000..290f83c
--- /dev/null
+++ b/benchmarks/vec-sgemm/vec-sgemm.S
@@ -0,0 +1,223 @@
+ .text
+ .balign 4
+ .global vec_sgemm_nn
+# RV64IDV system
+#
+# void
+# vec_sgemm_nn(size_t n,
+# size_t m,
+# size_t k,
+# const float*a, // m * k matrix
+# size_t lda,
+# const float*b, // k * n matrix
+# size_t ldb,
+# float*c, // m * n matrix
+# size_t ldc)
+#
+# c += a*b (alpha=1, no transpose on input matrices)
+# matrices stored in C row-major order
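+#
+# Scalar reference for the computation performed below (illustrative only):
+#
+#   for (size_t i = 0; i < m; i++)
+#     for (size_t j = 0; j < n; j++)
+#       for (size_t kk = 0; kk < k; kk++)
+#         c[i*ldc + j] += a[i*lda + kk] * b[kk*ldb + j];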
+
+#define n a0
+#define m a1
+#define k a2
+#define ap a3
+#define astride a4
+#define bp a5
+#define bstride a6
+#define cp a7
+#define cstride t0
+#define kt t1
+#define nt t2
+#define bnp t3
+#define cnp t4
+#define akp t5
+#define bkp s0
+#define nvl s1
+#define ccp s2
+#define amp s3
+
+# Use FP argument registers as additional FP temporaries
+#define ft12 fa0
+#define ft13 fa1
+#define ft14 fa2
+#define ft15 fa3
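+
+# Register roles in the inner loop:
+# v0-v15 hold a 16-row by VL-column block of C
+# v16 holds one VL-element row slice of B
+# ft0-ft15 hold one 16-element column slice of the A block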
+
+#define FRAMESIZE 32
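+# 32 bytes holds the four saved registers and keeps the stack pointer
+# 16-byte aligned as the RISC-V psABI requires.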
+
+# This version holds a 16*VLMAX block of the C matrix in vector registers
+# in the inner loop, but otherwise does no cache blocking or TLB tiling.
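+#
+# Loop structure (as implemented below):
+# c_row_loop: step down C in 16-row strips
+# c_col_loop: step across a strip in VL-column blocks (vsetvli picks VL)
+# k_loop: accumulate one column of A times one row of B per pass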
+
+vec_sgemm_nn:
+ ld cstride, 0(sp) # Load ninth argument (ldc) from caller's stack frame
+ addi sp, sp, -FRAMESIZE
+ sd s0, 0(sp)
+ sd s1, 8(sp)
+ sd s2, 16(sp)
+ sd s3, 24(sp) # s3 (amp) is callee-saved too
+
+ # Check for zero size matrices
+ beqz n, exit
+ beqz m, exit
+ beqz k, exit
+
+ # Convert element strides to byte strides.
+ slli astride, astride, 2
+ slli bstride, bstride, 2
+ slli cstride, cstride, 2
+
+ slti t6, m, 16
+ bnez t6, end_rows
+
+c_row_loop: # Loop across rows of C blocks
+
+ mv nt, n # Initialize n counter for next row of C blocks
+
+ mv bnp, bp # Initialize B n-loop pointer to start
+ mv cnp, cp # Initialize C n-loop pointer
+
+c_col_loop: # Loop across one row of C blocks
+ vsetvli nvl, nt, e32, m1, ta, ma # 32-bit elements, LMUL=1
+
+ mv akp, ap # reset pointer into A to beginning
+ mv bkp, bnp # Reset B pointer to top of current column block
+
+ # Initialize current C submatrix block from memory.
+ vle32.v v0, (cnp); add ccp, cnp, cstride;
+ vle32.v v1, (ccp); add ccp, ccp, cstride;
+ vle32.v v2, (ccp); add ccp, ccp, cstride;
+ vle32.v v3, (ccp); add ccp, ccp, cstride;
+ vle32.v v4, (ccp); add ccp, ccp, cstride;
+ vle32.v v5, (ccp); add ccp, ccp, cstride;
+ vle32.v v6, (ccp); add ccp, ccp, cstride;
+ vle32.v v7, (ccp); add ccp, ccp, cstride;
+ vle32.v v8, (ccp); add ccp, ccp, cstride;
+ vle32.v v9, (ccp); add ccp, ccp, cstride;
+ vle32.v v10, (ccp); add ccp, ccp, cstride;
+ vle32.v v11, (ccp); add ccp, ccp, cstride;
+ vle32.v v12, (ccp); add ccp, ccp, cstride;
+ vle32.v v13, (ccp); add ccp, ccp, cstride;
+ vle32.v v14, (ccp); add ccp, ccp, cstride;
+ vle32.v v15, (ccp)
+
+
+ mv kt, k # Initialize inner loop counter
+
+ # Inner loop is scheduled assuming a 4-cycle occupancy for the vfmacc
+ # instruction and a single-issue pipeline.
+ # Software-pipeline the scalar loads from A ahead of their first use.
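+ # ft0-ft3 for the first k iteration are loaded here; each k_loop pass
+ # then loads ft4-ft15 for the current A column and pre-loads ft0-ft3
+ # for the next one.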
+ flw ft0, (akp); add amp, akp, astride;
+ flw ft1, (amp); add amp, amp, astride;
+ flw ft2, (amp); add amp, amp, astride;
+ flw ft3, (amp); add amp, amp, astride;
+ # Get vector from B matrix
+ vle32.v v16, (bkp)
+
+ # Loop on inner dimension for current C block
+ k_loop:
+ vfmacc.vf v0, ft0, v16
+ add bkp, bkp, bstride
+ flw ft4, (amp)
+ add amp, amp, astride
+ vfmacc.vf v1, ft1, v16
+ addi kt, kt, -1 # Decrement k counter
+ flw ft5, (amp)
+ add amp, amp, astride
+ vfmacc.vf v2, ft2, v16
+ flw ft6, (amp)
+ add amp, amp, astride
+ flw ft7, (amp)
+ vfmacc.vf v3, ft3, v16
+ add amp, amp, astride
+ flw ft8, (amp)
+ add amp, amp, astride
+ vfmacc.vf v4, ft4, v16
+ flw ft9, (amp)
+ add amp, amp, astride
+ vfmacc.vf v5, ft5, v16
+ flw ft10, (amp)
+ add amp, amp, astride
+ vfmacc.vf v6, ft6, v16
+ flw ft11, (amp)
+ add amp, amp, astride
+ vfmacc.vf v7, ft7, v16
+ flw ft12, (amp)
+ add amp, amp, astride
+ vfmacc.vf v8, ft8, v16
+ flw ft13, (amp)
+ add amp, amp, astride
+ vfmacc.vf v9, ft9, v16
+ flw ft14, (amp)
+ add amp, amp, astride
+ vfmacc.vf v10, ft10, v16
+ flw ft15, (amp)
+ add amp, amp, astride
+ addi akp, akp, 4 # Move to next column of a
+ vfmacc.vf v11, ft11, v16
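+ # When kt reaches zero, the beqz chain below skips the pipelined loads
+ # for the nonexistent next iteration while still issuing the remaining
+ # vfmacc instructions for v12-v15, then falls through to the store.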
+ beqz kt, 1f # Don't load past end of matrix
+ flw ft0, (akp)
+ add amp, akp, astride
+1: vfmacc.vf v12, ft12, v16
+ beqz kt, 1f
+ flw ft1, (amp)
+ add amp, amp, astride
+1: vfmacc.vf v13, ft13, v16
+ beqz kt, 1f
+ flw ft2, (amp)
+ add amp, amp, astride
+1: vfmacc.vf v14, ft14, v16
+ beqz kt, 1f # Last iteration: skip final loads and exit via label below
+ flw ft3, (amp)
+ add amp, amp, astride
+ vfmacc.vf v15, ft15, v16
+ vle32.v v16, (bkp) # Get next vector from B matrix, overlap loads with jump stalls
+ j k_loop
+
+1: vfmacc.vf v15, ft15, v16
+
+ # Save C matrix block back to memory
+ vse32.v v0, (cnp); add ccp, cnp, cstride;
+ vse32.v v1, (ccp); add ccp, ccp, cstride;
+ vse32.v v2, (ccp); add ccp, ccp, cstride;
+ vse32.v v3, (ccp); add ccp, ccp, cstride;
+ vse32.v v4, (ccp); add ccp, ccp, cstride;
+ vse32.v v5, (ccp); add ccp, ccp, cstride;
+ vse32.v v6, (ccp); add ccp, ccp, cstride;
+ vse32.v v7, (ccp); add ccp, ccp, cstride;
+ vse32.v v8, (ccp); add ccp, ccp, cstride;
+ vse32.v v9, (ccp); add ccp, ccp, cstride;
+ vse32.v v10, (ccp); add ccp, ccp, cstride;
+ vse32.v v11, (ccp); add ccp, ccp, cstride;
+ vse32.v v12, (ccp); add ccp, ccp, cstride;
+ vse32.v v13, (ccp); add ccp, ccp, cstride;
+ vse32.v v14, (ccp); add ccp, ccp, cstride;
+ vse32.v v15, (ccp)
+
+ # The following tail instructions could be scheduled into free slots
+ # during the C block save above; they are left here for clarity.
+
+ # Bump pointers for loop across blocks in one row
+ slli t6, nvl, 2
+ add cnp, cnp, t6 # Move C block pointer over
+ add bnp, bnp, t6 # Move B block pointer over
+ sub nt, nt, nvl # Decrement element count in n dimension
+ bnez nt, c_col_loop # Any more to do?
+
+ # Move to next set of rows
+ addi m, m, -16 # Did 16 rows above
+ slli t6, astride, 4 # Multiply astride by 16
+ add ap, ap, t6 # Move A matrix pointer down 16 rows
+ slli t6, cstride, 4 # Multiply cstride by 16
+ add cp, cp, t6 # Move C matrix pointer down 16 rows
+
+ slti t6, m, 16
+ beqz t6, c_row_loop
+
+ # Handle the final rows when m is not a multiple of 16.
+ # Smaller power-of-2 versions of the loop above could be used here,
+ # depending on code-size concerns.
+end_rows:
+ # Tail case not implemented in this example.
+
+exit:
+ ld s0, 0(sp)
+ ld s1, 8(sp)
+ ld s2, 16(sp)
+ ld s3, 24(sp)
+ addi sp, sp, FRAMESIZE
+ ret