summaryrefslogtreecommitdiff
path: root/scalanativelib/src/mill/scalanativelib/ScalaNativeModule.scala
blob: c8d9abda31e894158adad3aa00fab1e1cf958563 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
package mill
package scalanativelib

import java.net.URLClassLoader

import ammonite.ops.Path
import coursier.Cache
import coursier.maven.MavenRepository
import mill.define.{Target, Task}
import mill.eval.Result
import mill.modules.Jvm
import mill.scalalib.{Dep, DepSyntax, Lib, SbtModule, ScalaModule, TestModule, TestRunner}
import mill.util.Loose.Agg
import sbt.testing.{AnnotatedFingerprint, SubclassFingerprint}
import sbt.testing.Fingerprint
import upickle.default.{ReadWriter => RW, macroRW}


/**
 * Log verbosity used by the Scala Native toolchain.
 *
 * Levels are ordered by their numeric `level`, so `Error < Warn < Info < Debug < Trace`.
 */
sealed abstract class NativeLogLevel(val level: Int) extends Ordered[NativeLogLevel] {
  // Difference of the numeric levels; the levels are small constants, so this cannot overflow.
  override def compare(that: NativeLogLevel): Int = level - that.level
}

/** The available [[NativeLogLevel]]s and their upickle serializer. */
object NativeLogLevel {
  case object Error extends NativeLogLevel(level = 200)
  case object Warn extends NativeLogLevel(level = 300)
  case object Info extends NativeLogLevel(level = 400)
  case object Debug extends NativeLogLevel(level = 500)
  case object Trace extends NativeLogLevel(level = 600)

  /** upickle ReadWriter derived for the whole sealed hierarchy. */
  implicit def rw: RW[NativeLogLevel] = macroRW[NativeLogLevel]
}

/** Build mode handed to the Scala Native worker configuration; `name` is its textual form. */
sealed abstract class ReleaseMode(val name: String)

/** The available [[ReleaseMode]]s and their upickle serializer. */
object ReleaseMode {
  case object Debug extends ReleaseMode(name = "debug")
  case object Release extends ReleaseMode(name = "release")

  /** upickle ReadWriter derived for the whole sealed hierarchy. */
  implicit def rw: RW[ReleaseMode] = macroRW[ReleaseMode]
}


/**
 * A Mill module compiled with the Scala Native compiler plugin and linked into a
 * native binary by the Scala Native toolchain.
 *
 * The toolchain is driven through a worker (see `scalaNativeWorker`) that is
 * instantiated on the classpath assembled by `bridgeFullClassPath`.
 */
trait ScalaNativeModule extends ScalaModule { outer =>
  /** The Scala Native version to build against. */
  def scalaNativeVersion: T[String]

  // Artifacts are suffixed with `_native<binary version>` ahead of the Scala version suffix.
  override def platformSuffix = s"_native${scalaNativeBinaryVersion()}"
  override def artifactSuffix: T[String] = s"${platformSuffix()}_${artifactScalaVersion()}"

  /**
   * Base trait for a test module nested in this module.
   *
   * It reuses the enclosing module's zinc worker, Scala organization and version,
   * Scala Native version, release mode and log level. It also depends on the
   * enclosing module.
   */
  trait Tests extends TestScalaNativeModule {
    override def zincWorker = outer.zincWorker
    override def scalaOrganization = outer.scalaOrganization()
    override def scalaVersion = outer.scalaVersion()
    override def scalaNativeVersion = outer.scalaNativeVersion()
    override def releaseMode = outer.releaseMode()
    override def logLevel = outer.logLevel()
    override def moduleDeps = Seq(outer)
  }

  /** The first two dot-separated components of `scalaNativeVersion`. */
  def scalaNativeBinaryVersion = T{ scalaNativeVersion().split('.').take(2).mkString(".") }

  /**
   * Version used in the `_native<version>` suffix of the Scala Native library artifacts.
   *
   * This allows compilation and testing against SNAPSHOT versions of scala-native.
   * A version ending in `-SNAPSHOT` is used verbatim; otherwise the binary version is used.
   */
  def scalaNativeToolsVersion = T{
    if (scalaNativeVersion().endsWith("-SNAPSHOT"))
      scalaNativeVersion()
    else
      scalaNativeBinaryVersion()
  }

  /** The Scala Native worker implementation, instantiated on `bridgeFullClassPath`. */
  def scalaNativeWorker = T.task{ ScalaNativeWorkerApi.scalaNativeWorker().impl(bridgeFullClassPath()) }

  /**
   * Classpath of the Mill Scala Native worker matching `scalaNativeBinaryVersion`.
   *
   * The system property `MILL_SCALANATIVE_WORKER_<binary version>` is checked first,
   * with dots and dashes in the version replaced by underscores. If it is set, it is
   * read as a comma-separated list of paths. Otherwise the worker artifact for the
   * `MILL_VERSION` system property is resolved from the local ivy2 cache and Maven
   * Central.
   */
  def scalaNativeWorkerClasspath = T {
    val workerKey = "MILL_SCALANATIVE_WORKER_" + scalaNativeBinaryVersion().replace('.', '_').replace('-', '_')
    sys.props.get(workerKey) match {
      case Some(workerPath) =>
        Result.Success(Agg(workerPath.split(',').map(p => PathRef(Path(p), quick = true)): _*))
      case None =>
        Lib.resolveDependencies(
          Seq(Cache.ivy2Local, MavenRepository("https://repo1.maven.org/maven2")),
          // The worker is resolved as a Scala 2.12 artifact, independent of this module's scalaVersion
          Lib.depToDependency(_, "2.12.4", ""),
          Seq(ivy"com.lihaoyi::mill-scalanativelib-worker-${scalaNativeBinaryVersion()}:${sys.props("MILL_VERSION")}")
        )
    }
  }

  /** Scala Native toolchain artifacts (tools, util, nir) loaded by the worker; built for Scala 2.12. */
  def toolsIvyDeps = T{
    Seq(
      ivy"org.scala-native:tools_2.12:${scalaNativeVersion()}",
      ivy"org.scala-native:util_2.12:${scalaNativeVersion()}",
      ivy"org.scala-native:nir_2.12:${scalaNativeVersion()}"
    )
  }

  /** This module's ivy deps, plus the Scala Native runtime libraries, plus those of all module deps. */
  override def transitiveIvyDeps: T[Agg[Dep]] = T{
    ivyDeps() ++ nativeIvyDeps() ++ Task.traverse(moduleDeps)(_.transitiveIvyDeps)().flatten
  }

  /** The Scala Native `nativelib` dependency. */
  def nativeLibIvy = T{ ivy"org.scala-native::nativelib_native${scalaNativeToolsVersion()}:${scalaNativeVersion()}" }

  /** The Scala Native runtime libraries every native module depends on: nativelib, javalib, auxlib and scalalib. */
  def nativeIvyDeps = T{
    Seq(nativeLibIvy()) ++
    Seq(
      ivy"org.scala-native::javalib_native${scalaNativeToolsVersion()}:${scalaNativeVersion()}",
      ivy"org.scala-native::auxlib_native${scalaNativeToolsVersion()}:${scalaNativeVersion()}",
      ivy"org.scala-native::scalalib_native${scalaNativeToolsVersion()}:${scalaNativeVersion()}"
    )
  }

  /** Full classpath of the worker: the worker jars followed by the resolved `toolsIvyDeps`. */
  def bridgeFullClassPath = T {
    Lib.resolveDependencies(
      Seq(Cache.ivy2Local, MavenRepository("https://repo1.maven.org/maven2")),
      Lib.depToDependency(_, scalaVersion(), platformSuffix()),
      toolsIvyDeps()
    ).map(t => (scalaNativeWorkerClasspath().toSeq ++ t.toSeq).map(_.path))
  }

  // The Scala Native compiler plugin is published per full Scala version (single `:`).
  override def scalacPluginIvyDeps = super.scalacPluginIvyDeps() ++
    Agg(ivy"org.scala-native:nscplugin_${scalaVersion()}:${scalaNativeVersion()}")

  /** Log level passed to the worker configuration; defaults to `Info`. */
  def logLevel: Target[NativeLogLevel] = T{ NativeLogLevel.Info }

  /** Release mode passed to the worker configuration; defaults to `Debug`. */
  def releaseMode: Target[ReleaseMode] = T { ReleaseMode.Debug }

  /** Working directory handed to the worker: this target's own `dest` folder. */
  def nativeWorkdir = T{ T.ctx().dest }

  // Location of the clang compiler
  def nativeClang = T{ scalaNativeWorker().discoverClang }

  // Location of the clang++ compiler
  def nativeClangPP = T{ scalaNativeWorker().discoverClangPP }

  // GC choice, either "none", "boehm" or "immix".
  // The SCALANATIVE_GC environment variable, when set, overrides the worker's default.
  // NOTE(review): the variable is read inside a cached target, so changing it may not
  // invalidate the cache — confirm.
  def nativeGC = T{
    sys.env.get("SCALANATIVE_GC")
      .getOrElse(scalaNativeWorker().defaultGarbageCollector)
  }

  /** Target triple, as discovered by the worker using `nativeClang` in `nativeWorkdir`. */
  def nativeTarget = T{ scalaNativeWorker().discoverTarget(nativeClang(), nativeWorkdir()) }

  // Options that are passed to clang during compilation
  def nativeCompileOptions = T{ scalaNativeWorker().discoverCompileOptions }

  // Options that are passed to clang during linking
  def nativeLinkingOptions = T{ scalaNativeWorker().discoverLinkingOptions }

  // Whether to link `@stub` methods, or ignore them
  def nativeLinkStubs = T { false }

  /**
   * The resolved Scala Native `nativelib` jar.
   *
   * It is the first artifact from resolving `nativeLibIvy` whose path contains both
   * "scala-native" and "nativelib". If none matches, the task fails with a descriptive
   * error rather than an opaque `NoSuchElementException`.
   */
  def nativeLibJar = T{
    val nativeLibJars = resolveDeps(T.task{Agg(nativeLibIvy())})()
      .filter{p => p.path.toString.contains("scala-native") && p.path.toString.contains("nativelib")}
      .toList

    nativeLibJars.headOption.getOrElse(
      throw new Exception("Could not find the scala-native nativelib jar among the resolved dependencies")
    )
  }

  /** Linker configuration built by the worker; classpath entries missing on disk are dropped. */
  def nativeConfig = T.task {
    val classpath = runClasspath().map(_.path).filter(_.toIO.exists).toList

    scalaNativeWorker().config(
      nativeLibJar().path,
      finalMainClass(),
      classpath,
      nativeWorkdir(),
      nativeClang(),
      nativeClangPP(),
      nativeTarget(),
      nativeCompileOptions(),
      nativeLinkingOptions(),
      nativeGC(),
      nativeLinkStubs(),
      releaseMode(),
      logLevel())
  }

  // Generates native binary at `<dest>/out`
  def nativeLink = T{ scalaNativeWorker().nativeLink(nativeConfig(), (T.ctx().dest / 'out)) }

  // Runs the native binary with `args` and `forkEnv()` in the current working directory
  override def run(args: String*) = T.command{
    Jvm.baseInteractiveSubprocess(
      Vector(nativeLink().toString) ++ args,
      forkEnv(),
      workingDir = ammonite.ops.pwd)
  }
}


/**
 * A [[ScalaNativeModule]] whose tests are executed as a native binary.
 *
 * Test discovery and the sbt test-interface frameworks run on the JVM. The
 * discovered tests are compiled into a generated `TestMain` (see `makeTestMain`),
 * which is linked by `testRunnerNative`. The JVM-side frameworks then drive that
 * binary during `test`.
 */
trait TestScalaNativeModule extends ScalaNativeModule with TestModule { testOuter =>
  /** A discovered test: the framework's class name, the test class and the fingerprint it matched. */
  case class TestDefinition(framework: String, clazz: Class[_], fingerprint: Fingerprint) {
    // Class name with any trailing `$` (as on compiled Scala objects) removed.
    def name = clazz.getName.reverse.dropWhile(_ == '$').reverse
  }

  override def testLocal(args: String*) = T.command { test(args:_*) }

  override def test(args: String*) = T.command{
    // The test frameworks run under the JVM and communicate with the native binary over a socket,
    // therefore the test framework is loaded from a JVM classloader
    val testClassloader =
      new URLClassLoader(testClasspathJvm().map(_.path.toIO.toURI.toURL).toArray,
        this.getClass.getClassLoader)
    try {
      val frameworkInstances = TestRunner.frameworks(testFrameworks())(testClassloader)
      val testBinary = testRunnerNative.nativeLink().toIO
      val envVars = forkEnv()

      // The classloader argument is unused: the framework instances come from testClassloader
      val nativeFrameworks = (cl: ClassLoader) =>
        frameworkInstances.zipWithIndex.map { case (f, id) =>
          scalaNativeWorker().newScalaNativeFrameWork(f, id, testBinary, logLevel(), envVars)
        }

      val (doneMsg, results) = TestRunner.runTests(
        nativeFrameworks,
        testClasspathJvm().map(_.path),
        Agg(compile().classes.path),
        args
      )

      TestModule.handleResults(doneMsg, results)
    } finally {
      // Release the jar handles held by the loader; the tests have finished running at this point
      testClassloader.close()
    }
  }

  private val supportedTestFrameworks = Set("utest", "scalatest")

  /**
   * JVM classpath entries for the supported test frameworks.
   *
   * Selects the binary-cross-versioned transitive ivy deps whose module name is in
   * `supportedTestFrameworks`, and resolves them with no platform suffix.
   */
  def testFrameworksJvmClasspath = T{
    Lib.resolveDependencies(
      repositories,
      Lib.depToDependency(_, scalaVersion(), ""),
      transitiveIvyDeps().filter(d => d.cross.isBinary && supportedTestFrameworks(d.dep.module.name))
    )
  }

  /** Classpath used on the JVM for test discovery and for loading the test frameworks. */
  def testClasspathJvm = T{
    localClasspath() ++
      transitiveLocalClasspath() ++
      unmanagedClasspath() ++
      testFrameworksJvmClasspath()
  }

  // creates a specific binary used for running tests - has a different (generated) main class
  // which knows the names of all the tests and references to invoke them
  object testRunnerNative extends ScalaNativeModule {
    override def zincWorker = testOuter.zincWorker
    override def scalaOrganization = testOuter.scalaOrganization()
    override def scalaVersion = testOuter.scalaVersion()
    override def scalaNativeVersion = testOuter.scalaNativeVersion()
    override def moduleDeps = Seq(testOuter)
    override def releaseMode = testOuter.releaseMode()
    override def logLevel = testOuter.logLevel()
    override def nativeLinkStubs = true

    override def ivyDeps = testOuter.ivyDeps() ++ Agg(
      ivy"org.scala-native::test-interface_native${scalaNativeToolsVersion()}:${scalaNativeVersion()}"
    )

    override def mainClass = Some("scala.scalanative.testinterface.TestMain")

    override def generatedSources = T {
      val outDir = T.ctx().dest
      ammonite.ops.write.over(outDir / "TestMain.scala", makeTestMain())
      Seq(PathRef(outDir))
    }
  }

  /**
   * Generates the source of `scala.scalanative.testinterface.TestMain`.
   *
   * Tests are discovered in an isolated in-process classloader over `testClasspathJvm`.
   * The generated object lists the frameworks that found tests, and maps each test
   * class name to an instance of that test.
   *
   * @throws Exception if no tests are discovered
   */
  def makeTestMain = T{
    val frameworkInstances = TestRunner.frameworks(testFrameworks()) _

    val testClasses =
      Jvm.inprocess(testClasspathJvm().map(_.path), classLoaderOverrideSbtTesting = true, isolated = true, closeContextClassLoaderWhenDone = true,
        cl => {
          frameworkInstances(cl).flatMap { framework =>
            val df = Lib.discoverTests(cl, framework, Agg(compile().classes.path))
            df.map(d => TestDefinition(framework.getClass.getName, d._1, d._2))
          }
        }
      )

    val frameworks = testClasses.map(_.framework).distinct

    val frameworksList =
      if (frameworks.nonEmpty) frameworks.mkString("List(new _root_.", ", new _root_.", ")")
      else {
        throw new Exception(
          "Cannot find any tests; make sure you defined the test framework correctly, " +
          "and extend whatever trait or annotation necessary to mark your test suites"
        )
      }


    val testsMap = makeTestsMap(testClasses)

    s"""package scala.scalanative.testinterface
       |object TestMain extends TestMainBase {
       |  override val frameworks = $frameworksList
       |  override val tests = Map[String, AnyRef]($testsMap)
       |  def main(args: Array[String]): Unit =
       |    testMain(args)
       |}""".stripMargin
  }

  /**
   * Renders the `name -> instance` entries of the generated tests map.
   *
   * Modules are referenced directly; classes are instantiated with `new`. A fingerprint
   * that is neither annotated nor subclass-based is rejected with a descriptive error
   * instead of a bare `MatchError`.
   */
  private def makeTestsMap(tests: Seq[TestDefinition]): String = {
    tests
      .map { t =>
        val isModule = t.fingerprint match {
          case af: AnnotatedFingerprint => af.isModule
          case sf: SubclassFingerprint  => sf.isModule
          case other =>
            throw new IllegalArgumentException(
              s"Unsupported test fingerprint ${other.getClass.getName} for test class ${t.name}"
            )
        }

        val inst =
          if (isModule) s"_root_.${t.name}" else s"new _root_.${t.name}"
        s""""${t.name}" -> $inst"""
      }
      .mkString(", ")
  }
}


/**
 * A [[ScalaNativeModule]] with [[mill.scalalib.SbtModule]] mixed in.
 *
 * Presumably this is for projects that use an sbt-style source layout — confirm
 * against `SbtModule`.
 */
trait SbtNativeModule extends ScalaNativeModule with SbtModule {
}