aboutsummaryrefslogtreecommitdiff
path: root/benchmarks/mm/gen.scala
blob: 982daa86d7375995b30c5686877a98e32ecd09a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import scala.sys.process._
/** Generator for the matrix-multiply benchmark kernel.
  *
  * Emits a C header (`rb.h`) holding block-size constants plus an unrolled
  * `kloop` kernel, builds and runs it with `make run`, and archives the
  * resulting `a.out` under a name that encodes the block sizes.
  * NOTE(review): RB* / CB* presumably mean register-block / cache-block
  * sizes; only RB* are referenced by the code generated here.
  */
object MMGen {
  import scala.language.implicitConversions

  /** Lets the emitters below take loop indices wherever a C expression
    * string is expected, e.g. `r("c", i, j)` or `ar(x, j)`.
    */
  implicit def i2s(i: Int): String = i.toString

  /** Writes `contents` to the file `name`, replacing any existing file.
    * The writer is closed even if the write fails.
    */
  def writeFile(name: String, contents: String): Unit = {
    val f = new java.io.FileWriter(name)
    try f.write(contents)
    finally f.close()
  }

  /** Current nesting depth of the generated C code (two spaces per level).
    * Mutated by `open_block` / `close_block`.
    */
  var indent = 0

  /** Leading whitespace for the current nesting depth. */
  def spacing: String = "  " * indent

  /** Emits the indented C statement `lhs = rhs;`. */
  def assign(lhs: String, rhs: String): String =
    spacing + lhs + " = " + rhs + ";\n"

  /** Emits the indented C declaration-with-initializer `t n = v;`. */
  def init(t: String, n: String, v: String): String =
    assign(t + " " + n, v)

  /** Emits an optional header line `s` (caller supplies its trailing
    * newline) followed by `{`, then increases the indent by one level.
    */
  def open_block(s: String = ""): String = {
    val result = (if (s != "") spacing + s else "") + spacing + "{\n"
    indent = indent + 1
    result
  }

  /** Decreases the indent by one level and emits the matching `}`. */
  def close_block: String = {
    indent = indent - 1
    spacing + "}\n"
  }

  /** C array subscript: `m[i]`. */
  def ar(m: String, i: String): String = m + "[" + i + "]"

  /** Underscore-joined identifier: `r("c", "1", "2") == "c_1_2"`. */
  def r(a: String, b: String*): String = (a +: b).mkString("_")

  /** Generates the C function `kloop` for an `m` x `n` tile of C.
    *
    * The emitted code loads `c[ldc*i + j]` into scalar accumulators `c_i_j`.
    * It then adds `a[lda*i + k] * b[ldb*k + j]` into them with `fma` for
    * every k below the kernel's runtime `p` argument, and finally stores
    * the accumulators back. The main k-loop is unrolled by this method's `p`
    * (referred to as `RBK` in the emitted C). A second loop handles the
    * remaining `p % RBK` steps one at a time.
    *
    * NOTE(review): assumes the including C file defines the element type
    * `t` and provides `fma`, and that `RBK == p` (test1 emits both
    * together) — confirm for any other caller.
    *
    * @param m rows of the C tile (fully unrolled)
    * @param n columns of the C tile (fully unrolled)
    * @param p unroll factor along k; must equal `RBK` in the header
    * @return the C source text; `indent` ends at its entry value
    */
  def rb(m: Int, n: Int, p: Int): String = {
    var s = open_block("static inline void kloop(size_t p, t* a0, size_t lda, t* b0, size_t ldb, t* c, size_t ldc)\n")

    // Load the m x n tile of C into scalar accumulators c_i_j.
    for (i <- 0 until m)
      s += init("t*", r("c", i), "&" + ar("c", "ldc*" + i))
    for (i <- 0 until m; j <- 0 until n)
      s += init("t", r("c", i, j), ar(r("c", i), j))

    // Emits one loop body: row pointers into A (rows) and B (depth rows)
    // at the current k offset, then depth*rows*cols fused multiply-adds.
    def doit(rows: Int, cols: Int, depth: Int): Unit = {
      for (i <- 0 until rows)
        s += init("t*", r("a", i), "&" + ar("a", "lda*" + i))
      for (k <- 0 until depth)
        s += init("t*", r("b", k), "&" + ar("b", "ldb*" + k))
      for (k <- 0 until depth; i <- 0 until rows; j <- 0 until cols)
        s += assign(r("c", i, j), "fma(" + ar(r("a", i), k) + ", " + ar(r("b", k), j) + ", " + r("c", i, j) + ")")
    }

    // Main loop: whole multiples of RBK along k, unrolled by p.
    s += open_block("for (t *a = a0, *b = b0; a < a0 + p/RBK*RBK; a += RBK, b += RBK*ldb)\n")
    doit(m, n, p)
    s += close_block

    // Remainder loop: the last p % RBK steps, one k at a time.
    s += open_block("for (t *a = a0 + p/RBK*RBK, *b = b0 + p/RBK*RBK*ldb; a < a0 + p; a++, b += ldb)\n")
    doit(m, n, 1)
    s += close_block

    // Store the accumulators back into C.
    for (i <- 0 until m; j <- 0 until n)
      s += assign(ar(r("c", i), j), r("c", i, j))
    s += close_block

    s
  }

  /** Greatest common divisor (Euclid). */
  @scala.annotation.tailrec
  def gcd(a: Int, b: Int): Int = if (b == 0) a else gcd(b, a % b)

  /** Least common multiple; 0 if either argument is 0.
    * Divides before multiplying so the intermediate value cannot overflow
    * when the result itself fits in an Int.
    */
  def lcm(a: Int, b: Int): Int =
    if (a == 0 || b == 0) 0 else a / gcd(a, b) * b

  /** Least common multiple of all elements; 1 (the identity) for an empty
    * sequence. Folds from the right, like `lcm(a.head, lcm(a.tail))`.
    */
  def lcm(a: Seq[Int]): Int =
    if (a.isEmpty) 1 else a.reduceRight((x, y) => lcm(x, y))

  /** Generates `rb.h` for the given block sizes, runs `make run`, and, if it
    * succeeds, copies `a.out` to `b.<m>.<n>.<p>.<m1>.<n1>.<p1>.run`.
    *
    * @return the exit status of `make run` if it failed (nothing is
    *         copied, so a stale binary is never archived under these
    *         sizes); otherwise the exit status of `cp`
    */
  def test1(m: Int, n: Int, p: Int, m1: Int, n1: Int, p1: Int): Int = {
    val decl = "static const int RBM = " + m + ", RBN = " + n + ", RBK = " + p + ";\n" +
               "static const int CBM = " + m1 + ", CBN = " + n1 + ", CBK = " + p1 + ";\n"
    writeFile("rb.h", decl + rb(m, n, p))

    val status = "make run".!
    if (status != 0) {
      Console.err.println(s"make run failed with status $status; not archiving a.out")
      status
    } else {
      val dest = s"b.$m.$n.$p.$m1.$n1.$p1.run"
      s"cp a.out $dest".!
    }
  }

  def main(args: Array[String]): Unit = {
    test1(4, 5, 6, 24, 25, 24)
    //for (i <- 4 to 6; j <- 4 to 6; k <- 4 to 6)
    //  test1(i, j, k, if (i == 5) 35 else 36, if (j == 5) 35 else 36, if (k == 5) 35 else 36)
  }
}