aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scalam/plotting/Plot.scala
blob: 45d755e4bc6d643b7c047821b61a095daec906d0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package scalam.plotting

import scala.sys.process._
import scalam.m.ast._
import scalax.file.Path
import scalam.plotting.styles._
import scala.collection.mutable.Map
import scala.collection.mutable.ListBuffer

/**
 * A plot of several data sets, rendered by generating a script (`results.m`)
 * together with one data file per data set, and running the script through
 * the `matlab` command.
 *
 * @param dataSets data sets to draw; plot calls and legend entries follow this order
 * @param title    text passed to the script's `title` call
 * @param xLabel   text passed to the script's `xlabel` call
 * @param yLabel   text passed to the script's `ylabel` call
 * @param grid     whether the grid is switched on or off
 * @param legend   whether a legend listing the data set labels is emitted
 * @param fontSize font size set on the current axes
 * @param styles   styles applied to the data sets; each one yields setup
 *                 statements plus a mapping from data set to style element
 * @param name     name of the output directory; defaults to "plot" followed
 *                 by a running number
 */
class Plot(
  val dataSets: Seq[DataSet],
  title: String,
  xLabel: String,
  yLabel: String,
  grid: Boolean = true,
  legend: Boolean = true,
  fontSize: Int = 10,
  styles: Seq[Style[StyleElement]] = Seq(),
  name: String = "plot" + Plot.next) {

  /** Directory, named after the plot, below which all files are generated. */
  val directory = Path(name)

  /** Path of the generated script, relative to `directory`. */
  val localPlotFile = Path("results.m")

  /** Header comments of the script: the generator name and the current time. */
  def preamble = {
    val df = new java.text.SimpleDateFormat("EEE, d MMM yyyy HH:mm:ss")
    val now = new java.util.Date(System.currentTimeMillis())
    Seq(
      DoubleComment("Generated by scalam, v1.0-SNAPSHOT"),
      DoubleComment(df.format(now)))
  }

  /**
   * A data set together with the identifier it is bound to in the script and
   * the path of its data file relative to `directory`.
   */
  class RichDataSet(val id: Identifier, val localPath: Path, val underlying: DataSet)

  /**
   * Assigns every data set a unique identifier and a data file path.
   *
   * The identifier is derived from the data set's name. When that identifier
   * is already taken, the first free one of `name_1`, `name_2`, ... is used.
   * Every identifier handed out is remembered, so a generated identifier can
   * never clash with a data set that happens to carry the same name
   * (e.g. the names `a`, `a`, `a_1` yield `a`, `a_1`, `a_1_1`).
   */
  def richDataSets = {
    val taken = scala.collection.mutable.Set[Identifier]()

    // Returns `base` itself when it is free, otherwise the first free numbered variant.
    def uniqueId(base: Identifier) = {
      if (!taken.contains(base)) {
        base
      } else {
        val candidates = Iterator.from(1).map(n => Identifier(base.name + "_" + n))
        candidates.dropWhile(id => taken.contains(id)).next()
      }
    }

    def toRich(dataSet: DataSet) = {
      val id = uniqueId(Identifier(dataSet.name))
      taken += id
      new RichDataSet(id, Path("data") / id.name, dataSet)
    }

    dataSets.map(toRich(_))
  }

  /**
   * Applies every style to the data sets.
   *
   * @return the concatenated setup statements of all styles, and one
   *         data-set-to-style-element mapping per style
   */
  def resolveStyles: (Seq[Root], Seq[DataSet => StyleElement]) = {
    val setupAndStyles = styles.map(_.apply(dataSets))
    val setup = setupAndStyles.map(_._1).flatten
    val styleMaps = setupAndStyles.map(_._2)
    (setup, styleMaps)
  }

  /**
   * Builds the complete script as a sequence of root statements: preamble,
   * data loads, figure configuration, style setup, one plot call per data
   * set and, when enabled, the legend.
   */
  def roots: Seq[Root] = {
    import Plot._
    val richDataSets = this.richDataSets

    val (setup, styleMappings) = resolveStyles

    val loads = richDataSets.map(r =>
      m.load(r.id, r.localPath) withComment SimpleComment(r.underlying.label))

    val plots = richDataSets.map { r =>
      val styleElements = styleMappings.map(_.apply(r.underlying))
      m.plot(r.id, styleElements)
    }

    // `this.` tells the constructor parameters apart from the helpers of the same name in `m`.
    val script = new ListBuffer[Root]
    script ++= preamble
    script ++= loads
    script += m.newFigure
    script += m.hold(true)
    script += m.grid(this.grid)
    script += m.title(this.title)
    script += m.xLabel(this.xLabel)
    script += m.yLabel(this.yLabel)
    script += m.fontSize(this.fontSize)
    // Setup statements of the styles were previously computed but dropped;
    // they are emitted ahead of the plot calls that use the style elements.
    script ++= setup
    script ++= plots
    // Honour the `legend` flag (it used to be ignored).
    if (this.legend) script += m.legend(dataSets)

    script.toList
  }

  /**
   * Writes every data set and the generated script below `directory`,
   * creating missing parent directories of the script file.
   */
  def save(): Unit = {
    for (d <- richDataSets) d.underlying.save(directory / d.localPath)

    val plotFile = (directory / localPlotFile)
    plotFile.createFile(createParents = true, failIfExists = false)
    for (processor <- plotFile.outputProcessor; out = processor.asOutput) {
      // `roots` already starts with the preamble, so it is not written separately.
      for (r <- roots) out.write(r.line + "\n")
    }
  }

  /**
   * Starts `matlab` inside `directory` on the generated script (referenced
   * by its file name without extension) and redirects the process output to
   * `log.txt` in the same directory.
   *
   * The log file is obtained with `fileOption.get`, which throws a
   * `NoSuchElementException` when that option is empty.
   *
   * @return the started process
   */
  def run() = {
    val matlab = Process(
      "matlab -nodesktop -nosplash -r " + localPlotFile.path.takeWhile(_ != '.'),
      directory.fileOption,
      "" -> "")
    val logFile = (directory / "log.txt").fileOption.get
    (matlab #> logFile).run()
  }

}

/**
 * Companion of the `Plot` class: holds the counter behind default plot
 * names, some sample data, and the helpers that build script statements.
 */
object Plot {
  // Last number handed out by `next`; starts at -1 so the first result is 0.
  private[this] var counter = -1

  /** Advances the counter and returns its new value. */
  private def next = {
    counter += 1
    counter
  }

  /**
   * Creates a data set named "a" holding `length` points whose first
   * component is the index and whose second is a random value times 10.
   */
  private def randomDataSet(length: Int) = {
    import scala.util.Random
    val points = (0 until length).map(i => (i * 1.0, Random.nextDouble() * 10))
    DataSet(points, "a")
  }

  /** Sample data: two hand-written data sets followed by eleven random ones. */
  val ds = {
    val temperature =
      DataSet(Seq((0.0, 1.0), (1.0, 1.0), (2.0, 1.0), (3.0, 0.0), (4.0, 1.0), (5.0, 1.0)), "temperature")
    val alpha =
      DataSet(Seq((0.0, 0.0), (1.0, 1.0), (2.0, 4.0), (3.0, 9.0)), """\alpha""")
    val randoms = (0 to 10).map(_ => randomDataSet(10))
    Seq(temperature, alpha) ++ randoms
  }

  /** Sample plot over `ds`. */
  val test = new Plot(ds, title = "title", xLabel = "x", yLabel = "y")

  /** Builders for the statements that make up a generated script. */
  private object m {
    import scalam.m.ast._

    val On = StringLiteral("on")
    val Off = StringLiteral("off")

    // Picks the literal matching a boolean switch.
    private def onOrOff(enabled: Boolean) = if (enabled) On else Off

    /** Call of `figure` without arguments. */
    def newFigure = Function(Identifier("figure"))

    /** Call of `hold` with "on" or "off". */
    def hold(b: Boolean) = Function(Identifier("hold"), onOrOff(b))

    /** Call of `grid` with "on" or "off". */
    def grid(show: Boolean) = Function(Identifier("grid"), onOrOff(show))

    /** Call of `title` with the given text. */
    def title(s: String) = Function(Identifier("title"), StringLiteral(s))

    /** Call of `xlabel` with the given text. */
    def xLabel(s: String) = Function(Identifier("xlabel"), StringLiteral(s))

    /** Call of `ylabel` with the given text. */
    def yLabel(s: String) = Function(Identifier("ylabel"), StringLiteral(s))

    /** Call of `set` on `gca` that assigns the "fontsize" property. */
    def fontSize(size: Int) =
      Function(
        Identifier("set"),
        Variable(Identifier("gca")),
        StringLiteral("fontsize"),
        IntLiteral(size))

    /** Assignment of a `load` call on the path's string form to `id`. */
    def load(id: Identifier, path: Path) = {
      val loadCall = Function(Identifier("load"), StringLiteral(path.path))
      Assign(id, loadCall)
    }

    /**
     * Call of `plot` whose arguments are the data set indexed with a slice
     * and 1, the data set indexed with a slice and 2, and then the name and
     * expression of every style element in order.
     */
    def plot(dataSet: Identifier, styleElements: Seq[StyleElement]) = {
      val firstColumn = IndexMatrix(dataSet, SliceLiteral, IntLiteral(1))
      val secondColumn = IndexMatrix(dataSet, SliceLiteral, IntLiteral(2))
      val styleArguments = styleElements.flatMap(e => Seq(e.name, e.expression))
      val arguments = Seq(firstColumn, secondColumn) ++ styleArguments
      Function(Identifier("plot"), arguments: _*)
    }

    /** Call of `legend` with the label of every data set, in order. */
    def legend(dataSets: Seq[DataSet]) = {
      val labels = dataSets.map(d => StringLiteral(d.label))
      Function(Identifier("legend"), labels: _*)
    }
  }

}