summaryrefslogtreecommitdiff
path: root/contrib/twirllib/src/TwirlWorker.scala
blob: 19eb47251e737843563da129994b1d784807598c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
package mill
package twirllib

import java.io.File
import java.lang.reflect.Method
import java.net.URLClassLoader
import java.nio.charset.Charset

import mill.api.PathRef
import mill.scalalib.api.CompilationResult

import scala.io.Codec
/**
 * Compiles Twirl templates (`*.scala.html`, `*.scala.xml`, `*.scala.js`, `*.scala.txt`).
 *
 * The Twirl compiler is loaded into its own classloader and driven purely through reflection
 * over its Java API (`play.japi.twirl.compiler.TwirlCompiler`), so this class has no
 * compile-time dependency on Twirl itself.
 */
class TwirlWorker {

  // Most recently built compiler wrapper, tagged with the signature of the classpath it was
  // loaded from. Replaced whenever a differently-signed classpath is requested.
  private var twirlInstanceCache = Option.empty[(Long, TwirlWorkerApi)]

  /**
   * Returns a [[TwirlWorkerApi]] backed by the Twirl compiler on `twirlClasspath`.
   *
   * The classpath signature sums each entry's path hash and modification time, so the cached
   * instance is reused only while the same, unchanged classpath entries are requested.
   */
  private def twirl(twirlClasspath: Agg[os.Path]): TwirlWorkerApi = {
    val classloaderSig = twirlClasspath.map(p => p.toString().hashCode + os.mtime(p)).sum
    twirlInstanceCache match {
      case Some((sig, instance)) if sig == classloaderSig => instance
      case _ =>
        // Parent is `null` (bootstrap only): Twirl and the Scala library it ships with are not
        // mixed with the classes of the running mill process.
        val cl = new URLClassLoader(twirlClasspath.map(_.toIO.toURI.toURL).toArray, null)

        // Why the Java API instead of the Scala one:
        // * we need to build a collection of additional imports (the defaults plus the
        //   user-provided ones) and a collection of constructor annotations;
        // * in the Scala API the default collection is a Seq[String], but that Seq type is
        //   defined in a different classloader (namely in `cl`);
        // * a Seq constructed here would come from our classloader and be incompatible;
        // * the Java API uses java.util collections, which are much simpler to manipulate
        //   through reflection.
        //
        // NOTE: When creating `cl` with the current classloader as its parent:
        //   val cl = new URLClassLoader(twirlClasspath.map(_.toIO.toURI.toURL).toArray, getClass.getClassLoader)
        // it is possible to cast the default to a Seq[String], construct our own Seq[String] and
        // pass it to `invoke` - the classes are then compatible (the tests passed).
        // But in an actual mill project with this module enabled, that failed with exceptions like:
        // scala.reflect.internal.MissingRequirementError: object scala in compiler mirror not found.

        val twirlCompilerClass = cl.loadClass("play.japi.twirl.compiler.TwirlCompiler")

        val codecClass = cl.loadClass("scala.io.Codec")
        val charsetClass = cl.loadClass("java.nio.charset.Charset")
        val collectionClass = cl.loadClass("java.util.Collection")
        val listClass = cl.loadClass("java.util.List")
        val arrayListClass = cl.loadClass("java.util.ArrayList")
        val hashSetClass = cl.loadClass("java.util.HashSet")

        // Reflective handles are resolved once per classloader and shared by every compile call.
        val codecApplyMethod = codecClass.getMethod("apply", charsetClass)
        val charsetForNameMethod = charsetClass.getMethod("forName", classOf[java.lang.String])
        val hashSetConstructor = hashSetClass.getConstructor(collectionClass)
        val hashSetAddMethod = hashSetClass.getMethod("add", classOf[Object])
        val arrayListConstructor = arrayListClass.getConstructor()
        val arrayListAddMethod = arrayListClass.getMethod("add", classOf[Object])
        val defaultImportsField = twirlCompilerClass.getField("DEFAULT_IMPORTS")

        // JavaAPI:
        //   public static Optional<File> compile(
        //     File source, File sourceDirectory, File generatedDirectory, String formatterType,
        //     Collection<String> additionalImports, List<String> constructorAnnotations,
        //     Codec codec, boolean inclusiveDot)
        val compileMethod = twirlCompilerClass.getMethod("compile",
          classOf[java.io.File],
          classOf[java.io.File],
          classOf[java.io.File],
          classOf[java.lang.String],
          collectionClass,
          listClass,
          codecClass,
          classOf[Boolean]) // classOf[scala.Boolean] is the primitive `boolean` class

        val instance = new TwirlWorkerApi {
          override def compileTwirl(source: File,
                                    sourceDirectory: File,
                                    generatedDirectory: File,
                                    formatterType: String,
                                    additionalImports: Seq[String],
                                    constructorAnnotations: Seq[String],
                                    codec: Codec,
                                    inclusiveDot: Boolean): Unit = {
            // new HashSet(TwirlCompiler.DEFAULT_IMPORTS) plus the user's imports;
            // DEFAULT_IMPORTS is an unmodifiable collection, hence the copy.
            val twirlAdditionalImports =
              hashSetConstructor.newInstance(defaultImportsField.get(null)).asInstanceOf[AnyRef]
            additionalImports.foreach(imp => hashSetAddMethod.invoke(twirlAdditionalImports, imp))

            // new ArrayList() filled with the constructor annotations.
            val twirlConstructorAnnotations = arrayListConstructor.newInstance().asInstanceOf[AnyRef]
            constructorAnnotations.foreach(a => arrayListAddMethod.invoke(twirlConstructorAnnotations, a))

            // Codec.apply(Charset.forName(codec.charSet.name())), built through `cl` so the
            // compiler receives a scala.io.Codec from its own classloader.
            val twirlCodec =
              codecApplyMethod.invoke(null, charsetForNameMethod.invoke(null, codec.charSet.name()))

            try {
              // The returned Optional<File> is currently ignored.
              compileMethod.invoke(null,
                source,
                sourceDirectory,
                generatedDirectory,
                formatterType,
                twirlAdditionalImports,
                twirlConstructorAnnotations,
                twirlCodec,
                Boolean.box(inclusiveDot)
              )
            } catch {
              // Reflection wraps anything the compiler throws (e.g. template errors); rethrow the
              // underlying cause so its message reaches the user instead of an opaque wrapper.
              case e: java.lang.reflect.InvocationTargetException if e.getCause != null =>
                throw e.getCause
            }
          }
        }
        twirlInstanceCache = Some((classloaderSig, instance))
        instance
    }
  }

  /**
   * Compiles every Twirl template found under `sourceDirectories`.
   *
   * A template is any regular file whose name matches `*.scala.{html,xml,js,txt}`; the
   * formatter (`play.twirl.api.HtmlFormat`, ...) is chosen from the final extension.
   * Source directories that do not exist are skipped.
   *
   * @param twirlClasspath         classpath containing the Twirl compiler
   * @param sourceDirectories      roots searched recursively for templates
   * @param dest                   directory passed to the compiler as `generatedDirectory`
   * @param additionalImports      imports added on top of Twirl's DEFAULT_IMPORTS
   * @param constructorAnnotations annotations forwarded to the Twirl compiler
   * @param codec                  codec forwarded to the Twirl compiler
   * @param inclusiveDot           forwarded to the Twirl compiler's `inclusiveDot` flag
   * @return a successful [[CompilationResult]] rooted at `ctx.dest`
   */
  def compile(twirlClasspath: Agg[os.Path],
              sourceDirectories: Seq[os.Path],
              dest: os.Path,
              additionalImports: Seq[String],
              constructorAnnotations: Seq[String],
              codec: Codec,
              inclusiveDot: Boolean)
             (implicit ctx: mill.api.Ctx): mill.api.Result[CompilationResult] = {
    val compiler = twirl(twirlClasspath)

    // Dots are escaped: an unescaped `.` would also accept names such as `fooscalaxhtml`.
    val templateNamePattern = """.*\.scala\.(html|xml|js|txt)"""

    def compileTwirlDir(inputDir: os.Path): Unit =
      os.walk(inputDir)
        .filter(p => os.isFile(p) && p.last.matches(templateNamePattern))
        .foreach { template =>
          compiler.compileTwirl(template.toIO,
            inputDir.toIO,
            dest.toIO,
            s"play.twirl.api.${twirlExtensionFormat(template.last)}",
            additionalImports,
            constructorAnnotations,
            codec,
            inclusiveDot
          )
        }

    // Skip declared-but-absent roots instead of handing them to os.walk.
    sourceDirectories.filter(dir => os.exists(dir)).foreach(compileTwirlDir)

    // NOTE(review): templates are generated into `dest`, but the result points at `ctx.dest`;
    // confirm callers always pass the same directory.
    val zincFile = ctx.dest / "zinc"
    val classesDir = ctx.dest

    mill.api.Result.Success(CompilationResult(zincFile, PathRef(classesDir)))
  }

  /** Maps a template file name to the simple class name of its Twirl formatter. */
  private def twirlExtensionFormat(name: String): String =
    if (name.endsWith("html")) "HtmlFormat"
    else if (name.endsWith("xml")) "XmlFormat"
    else if (name.endsWith("js")) "JavaScriptFormat"
    else "TxtFormat"
}

/** Compiles a single Twirl template file. */
trait TwirlWorkerApi {

  /**
   * Compiles one template.
   *
   * @param source                 the template file to compile
   * @param sourceDirectory        source root that `source` was found under
   * @param generatedDirectory     directory handed to the compiler for its output
   * @param formatterType          fully-qualified formatter class name, e.g. `play.twirl.api.HtmlFormat`
   * @param additionalImports      extra imports to pass to the compiler
   * @param constructorAnnotations constructor annotations to pass to the compiler
   * @param codec                  codec to pass to the compiler
   * @param inclusiveDot           Twirl's `inclusiveDot` flag
   */
  def compileTwirl(source: File,
                   sourceDirectory: File,
                   generatedDirectory: File,
                   formatterType: String,
                   additionalImports: Seq[String],
                   constructorAnnotations: Seq[String],
                   codec: Codec,
                   inclusiveDot: Boolean): Unit
}

/** Companion offering a factory for the reflective Twirl worker. */
object TwirlWorkerApi {

  /** Builds a brand-new [[TwirlWorker]]; every call yields a separate instance. */
  def twirlWorker: TwirlWorker = new TwirlWorker
}