summaryrefslogtreecommitdiff
path: root/scalatex/compilerPlugin/src/main/scala/scalatex/CompilerPlugin.scala
blob: e122de554a7df23a4b588f8546a568eadb939322 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package scalatex

import java.nio.file.Paths

import scala.reflect.internal.util.BatchSourceFile
import scala.reflect.io.VirtualFile
import scala.tools.nsc.{ Global, Phase }
import scala.tools.nsc.plugins.{ Plugin, PluginComponent }

/** Scala compiler plugin that turns every file found under the configured
  * scalatex root directories into a small synthetic compilation unit.
  *
  * Roots are taken from this plugin's options of the form `root:<dir>`
  * (presumably supplied on the command line as `-P:scalatex:root:<dir>`).
  */
class CompilerPlugin(val global: Global) extends Plugin {
  import global._

  /** Accepts any options; `root:` options are read when the phase is created. */
  override def init(options: List[String], error: String => Unit): Boolean = true

  val name = "scalatex"
  val description = "Compiles scalatex files into Scala compilation units"
  val components = List[PluginComponent](DemoComponent)

  /** Component running between `parser` and `namer` that schedules one shim
    * compilation unit per file under each scalatex root.
    */
  private object DemoComponent extends PluginComponent {

    val global = CompilerPlugin.this.global
    import global._

    override val runsAfter = List("parser")
    override val runsBefore = List("namer")

    val phaseName = "Demo"

    /** Every non-directory entry under `f`: entries of `f` itself first, then
      * those of its subdirectories. A missing or unreadable directory
      * (`listFiles()` returning null) contributes nothing.
      */
    private def recursiveListFiles(f: java.io.File): Iterator[java.io.File] = {
      val (dirs, files) =
        Option(f.listFiles())
          .toSeq
          .flatten
          .partition(_.isDirectory)
      files.iterator ++ dirs.iterator.flatMap(recursiveListFiles)
    }

    /** Wraps `ident` in backticks so names that are not plain Scala
      * identifiers (keywords, names containing `-`, ...) still parse.
      */
    private def quoteIdent(ident: String): String = "`" + ident + "`"

    /** Renders `s` as a Scala string literal. Backslashes (e.g. Windows
      * separators, which Scala 2 could otherwise read as `\u` unicode escapes)
      * and double quotes are escaped.
      */
    private def stringLiteral(s: String): String =
      "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\""

    /** File name without its last extension (`foo.scalatex` -> `foo`);
      * the whole name when it has no extension.
      */
    private def baseName(file: java.io.File): String = {
      val fileName = file.getName
      val dot = fileName.lastIndexOf('.')
      if (dot > 0) fileName.substring(0, dot) else fileName
    }

    /** One `package` clause per directory between `root` and `dir`, separated
      * by newlines; empty when `dir` is `root` itself. Uses `Path` name
      * elements, so it is independent of the platform separator.
      */
    private def packageClauses(root: java.nio.file.Path, dir: java.nio.file.Path): String = {
      val rel = root.relativize(dir)
      (0 until rel.getNameCount)
        .map(i => rel.getName(i).toString)
        .filter(_.nonEmpty) // relativizing a path to itself yields one empty name
        .map(segment => s"package ${quoteIdent(segment)}")
        .mkString("\n")
    }

    /** Reads the whole file at `path` (default codec), always closing it. */
    private def readText(path: String): String = {
      val source = io.Source.fromFile(path)
      try source.mkString
      finally source.close()
    }

    override def newPhase(prev: Phase) = new GlobalPhase(prev) {
      val splitOptions = options.map(o => o.splitAt(o.indexOf(":")+1))
      val scalatexRoots = splitOptions.collect{case ("root:", p) => p}

      /** Schedules a shim unit for every non-hidden file under every root. */
      override def run(): Unit = {
        for (scalatexRoot <- scalatexRoots) {
          val rootDir = new java.io.File(scalatexRoot)
          // Canonicalize once per root: file paths below are canonical (absolute,
          // symlinks resolved), and relativize() cannot relate a relative root
          // to an absolute path.
          val rootPath = Paths.get(rootDir.getCanonicalPath)
          // Dot-files (e.g. .DS_Store) would otherwise yield unusable object names.
          for (scalatexFile <- recursiveListFiles(rootDir) if !scalatexFile.getName.startsWith("."))
            compileShim(rootPath, scalatexFile)
        }
      }

      /** Builds the shim unit for `scalatexFile`. The unit's package mirrors
        * its directory relative to `rootPath`, and its object has an `apply()`
        * delegating to `scalatex.twf` with the file's canonical path. The unit
        * is handed to `currentRun.compileLate`.
        */
      private def compileShim(rootPath: java.nio.file.Path, scalatexFile: java.io.File): Unit = {
        val path = scalatexFile.getCanonicalPath
        val fakeJfile = new java.io.File(path)
        val txt = readText(path)
        val virtualFile = new VirtualFile(path) {
          override def file = fakeJfile
        }
        val unit = new CompilationUnit(new BatchSourceFile(virtualFile, txt))

        val shim = s"""
          ${packageClauses(rootPath, fakeJfile.getParentFile.toPath)}
          import scalatags.Text.all._

          object ${quoteIdent(baseName(fakeJfile))}{
            def apply() = scalatex.twf(${stringLiteral(path)})
          }
        """
        unit.body = global.newUnitParser(shim).parse()
        global.currentRun.compileLate(unit)
      }

      def name: String = phaseName

      /** No per-unit work; all work happens in `run()`. */
      def apply(unit: global.CompilationUnit): Unit = {}
    }
  }
}