summaryrefslogtreecommitdiff
path: root/contrib/twirllib/src/mill/twirllib/TwirlWorker.scala
blob: ce8175a6834975564ba22055f8d86486471e4952 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package mill
package twirllib

import java.io.File
import java.lang.reflect.Method
import java.net.URLClassLoader

import mill.eval.PathRef
import mill.scalalib.CompilationResult

import scala.io.Codec

/**
  * Drives the Twirl template compiler through reflection.
  *
  * The compiler class is loaded from a caller-supplied classpath into its own
  * `URLClassLoader`; the resulting bridge is cached and reused for as long as
  * the classpath signature stays the same.
  */
class TwirlWorker {

  // Last bridge that was built, paired with the classpath signature it was built for.
  private var twirlInstanceCache = Option.empty[(Long, TwirlWorkerApi)]

  /**
    * Returns a [[TwirlWorkerApi]] backed by `play.twirl.compiler.TwirlCompiler`
    * loaded from `twirlClasspath`.
    *
    * The cached instance is reused when the classpath signature matches the
    * cached one; otherwise a new class loader and bridge are created and cached.
    *
    * @param twirlClasspath classpath entries from which the Twirl compiler is loaded
    */
  private def twirl(twirlClasspath: Agg[os.Path]): TwirlWorkerApi = {
    // Signature: sum over all entries of (hash of the path string + modification time).
    val classloaderSig = twirlClasspath.map(p => p.toString().hashCode + os.mtime(p)).sum
    twirlInstanceCache match {
      case Some((sig, instance)) if sig == classloaderSig => instance
      case _ =>
        // A null parent means the loader does not delegate to this worker's own class loader;
        // the Scala classes looked up below are therefore resolved through `cl` as well.
        val cl = new URLClassLoader(twirlClasspath.map(_.toIO.toURI.toURL).toArray, null)
        val twirlCompilerClass = cl.loadClass("play.twirl.compiler.TwirlCompiler")
        val compileMethod = twirlCompilerClass.getMethod("compile",
          classOf[java.io.File],
          classOf[java.io.File],
          classOf[java.io.File],
          classOf[java.lang.String],
          cl.loadClass("scala.collection.Seq"),
          cl.loadClass("scala.collection.Seq"),
          cl.loadClass("scala.io.Codec"),
          classOf[Boolean])

        // Accessors for the default values of parameters 5 to 8 of `compile`.
        val defaultAdditionalImportsMethod = twirlCompilerClass.getMethod("compile$default$5")
        val defaultConstructorAnnotationsMethod = twirlCompilerClass.getMethod("compile$default$6")
        val defaultCodecMethod = twirlCompilerClass.getMethod("compile$default$7")
        val defaultFlagMethod = twirlCompilerClass.getMethod("compile$default$8")

        val instance = new TwirlWorkerApi {
          override def compileTwirl(source: File,
                                    sourceDirectory: File,
                                    generatedDirectory: File,
                                    formatterType: String,
                                    additionalImports: Seq[String],
                                    constructorAnnotations: Seq[String],
                                    codec: Codec,
                                    inclusiveDot: Boolean): Unit = {
            // NOTE(review): `additionalImports`, `constructorAnnotations`, `codec` and
            // `inclusiveDot` are not forwarded; the compiler's own default values are passed
            // instead. Forwarding them presumably requires building the values inside `cl`,
            // because its `Seq` and `Codec` classes are loaded separately from the ones
            // used by this worker - confirm before changing.
            compileMethod.invoke(null, source,
              sourceDirectory,
              generatedDirectory,
              formatterType,
              defaultAdditionalImportsMethod.invoke(null),
              defaultConstructorAnnotationsMethod.invoke(null),
              defaultCodecMethod.invoke(null),
              defaultFlagMethod.invoke(null))
            ()
          }
        }
        twirlInstanceCache = Some((classloaderSig, instance))
        instance
    }
  }

  /**
    * Compiles every Twirl template found under the given source directories.
    *
    * Each directory is walked recursively; files named `<name>.scala.<ext>` with
    * `ext` one of `html`, `xml`, `js` or `txt` are handed to the Twirl compiler,
    * which is asked to generate its output under `dest`.
    *
    * @param twirlClasspath         classpath from which the Twirl compiler is loaded
    * @param sourceDirectories      directories to search for templates
    * @param dest                   directory passed to Twirl as the generated-sources directory
    * @param additionalImports      passed through to the compiler bridge
    * @param constructorAnnotations passed through to the compiler bridge
    * @param codec                  passed through to the compiler bridge
    * @param inclusiveDot           passed through to the compiler bridge
    * @return a successful result whose analysis file is `ctx.dest / "zinc"` and whose
    *         classes reference is `ctx.dest / "html"`
    */
  def compile(twirlClasspath: Agg[os.Path],
              sourceDirectories: Seq[os.Path],
              dest: os.Path,
              additionalImports: Seq[String],
              constructorAnnotations: Seq[String],
              codec: Codec,
              inclusiveDot: Boolean)
             (implicit ctx: mill.util.Ctx): mill.eval.Result[CompilationResult] = {
    val compiler = twirl(twirlClasspath)

    // The dots are escaped so that only names of the form `<name>.scala.<ext>` match;
    // an unescaped `.` would accept any character in those positions.
    val templateFileName = """.*\.scala\.(html|xml|js|txt)"""

    def compileTwirlDir(inputDir: os.Path): Unit =
      os.walk(inputDir)
        .filter(_.last.matches(templateFileName))
        .foreach { template =>
          val extFormat = twirlExtensionFormat(template.last)
          compiler.compileTwirl(template.toIO,
            inputDir.toIO,
            dest.toIO,
            s"play.twirl.api.$extFormat",
            additionalImports,
            constructorAnnotations,
            codec,
            inclusiveDot
          )
        }

    sourceDirectories.foreach(compileTwirlDir)

    val zincFile = ctx.dest / "zinc"
    val classesDir = ctx.dest / "html"

    mill.eval.Result.Success(CompilationResult(zincFile, PathRef(classesDir)))
  }

  /**
    * Maps a template file name to the simple name of the Twirl format class
    * selected by its suffix; anything that is not `html`, `xml` or `js` falls
    * back to `TxtFormat`.
    */
  private def twirlExtensionFormat(name: String): String =
    if (name.endsWith("html")) "HtmlFormat"
    else if (name.endsWith("xml")) "XmlFormat"
    else if (name.endsWith("js")) "JavaScriptFormat"
    else "TxtFormat"
}

/**
  * Minimal interface for compiling a single Twirl template.
  */
trait TwirlWorkerApi {

  /**
    * Compiles one template file.
    *
    * @param source                 the template file to compile
    * @param sourceDirectory        the directory the template was found under
    * @param generatedDirectory     the directory that receives the generated output
    * @param formatterType          fully qualified name of the format class to use
    * @param additionalImports      extra imports requested for the template
    * @param constructorAnnotations annotations requested for the template's constructor
    * @param codec                  character encoding to use
    * @param inclusiveDot           Twirl's inclusive-dot flag
    */
  def compileTwirl(source: File,
                   sourceDirectory: File,
                   generatedDirectory: File,
                   formatterType: String,
                   additionalImports: Seq[String],
                   constructorAnnotations: Seq[String],
                   codec: Codec,
                   inclusiveDot: Boolean): Unit
}

/** Factory for [[TwirlWorker]] instances. */
object TwirlWorkerApi {

  /** Builds a new [[TwirlWorker]]; every access yields a separate instance. */
  def twirlWorker: TwirlWorker = new TwirlWorker
}