aboutsummaryrefslogtreecommitdiff
path: root/benchmarks/spmv/spmv_gendata.scala
blob: f777445f59cf6a395d95b0b2a4efc9778b5cafa6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/usr/bin/env scala
!#

// Usage: spmv_gendata.scala <rows> <cols> <approx_nnz>
// Generates a random m x n sparse matrix in CSR form plus a dense input vector.
if (args.length < 3) {
  System.err.println("usage: spmv_gendata.scala <rows> <cols> <approx_nnz>")
  sys.exit(1)
}

val m = args(0).toInt
val n = args(1).toInt
val approx_nnz = args(2).toInt

// Per-entry probability of being nonzero, chosen so the expected nonzero count
// is approx_nnz. The product is taken in Double: m*n in Int overflows for large
// matrices (e.g. 50000 x 50000) and would produce a negative or bogus density.
val pnnz = approx_nnz.toDouble / (m.toDouble * n)

// CSR arrays: idx holds the column index of every nonzero, in row-major order;
// p holds row start offsets into idx, so row i owns idx(p(i)) until idx(p(i+1)).
val idx = collection.mutable.ArrayBuffer[Int]()
val p = collection.mutable.ArrayBuffer(0)

for (_ <- 0 until m) {
  for (j <- 0 until n) {
    if (util.Random.nextDouble() < pnnz)
      idx += j
  }
  p += idx.length
}

val nnz = idx.length
// Dense input vector (length n) and nonzero values (length nnz): ints in [0, 1000).
val v = Array.tabulate(n)(_ => util.Random.nextInt(1000))
val d = Array.tabulate(nnz)(_ => util.Random.nextInt(1000))

/** Prints `data` as a C array definition, one element per line:
  * {{{
  * const <t> <name>[<len>] = {
  *   e0,
  *   e1
  * };
  * }}}
  *
  * Elements are joined with `mkString`. The previous pairwise
  * `reduceLeft(_ + sep + _)` copied the growing string on every step, which is
  * quadratic in the array length, and it threw on an empty sequence.
  * NOTE(review): an empty `data` still prints a `[0]`-sized array, which ISO C
  * does not accept. Callers should ensure the generated arrays are non-empty.
  *
  * @param t    C element type to emit (e.g. "int", "double")
  * @param name C identifier for the array
  * @param data element values, each printed via `toString`
  */
def printVec(t: String, name: String, data: Seq[Int]) = {
  println(s"const $t $name[${data.length}] = {")
  println("  " + data.mkString(",\n  "))
  println("};")
}

/** Reference sparse matrix-vector product y = A * v, with A given in CSR form.
  *
  * Row `r` of A owns the entries `k` in `p(r) until p(r + 1)`; entry `k` has
  * value `d(k)` in column `idx(k)`. The result has `p.length - 1` elements.
  * Accumulation uses `Int` arithmetic, so a row sum past `Int.MaxValue` wraps.
  * NOTE(review): the script emits this result as a `double` array, so a wrapped
  * row would disagree with a floating-point computation. This is fine for small
  * values and row lengths; confirm for larger inputs.
  *
  * @param p   row pointers (length rows + 1)
  * @param d   nonzero values
  * @param idx column index of each nonzero
  * @param v   dense input vector
  * @return    an ArrayBuffer holding one dot product per row
  */
def spmv(p: Seq[Int], d: Seq[Int], idx: Seq[Int], v: Seq[Int]) = {
  val rowCount = p.length - 1
  collection.mutable.ArrayBuffer.tabulate(rowCount) { row =>
    val entries = p(row) until p(row + 1)
    entries.foldLeft(0)((acc, k) => acc + d(k) * v(idx(k)))
  }
}

// Emit the C header: the matrix dimensions as macros, then the CSR arrays
// (values, column indices, row pointers), the dense input vector, and the
// reference result y = A * x used for verification.
for ((macroName, value) <- Seq("R" -> m, "C" -> n, "NNZ" -> nnz))
  println(s"#define $macroName $value")
printVec("double", "val", d)
printVec("int", "idx", idx)
printVec("double", "x", v)
printVec("int", "ptr", p)
val expectedY = spmv(p, d, idx, v)
printVec("double", "verify_data", expectedY)