summaryrefslogtreecommitdiff
path: root/scalalib/src
diff options
context:
space:
mode:
authorLi Haoyi <haoyi.sg@gmail.com>2018-04-08 12:55:30 -0700
committerLi Haoyi <haoyi.sg@gmail.com>2018-04-08 12:55:30 -0700
commit62a18754179d74252668879c42e1c5d6a45fbdce (patch)
tree99a820224293a2367c116272283fd2a3426957ae /scalalib/src
parentdbcad35c05f1726f26b8033524e8fdd3d68b2de9 (diff)
downloadmill-62a18754179d74252668879c42e1c5d6a45fbdce.tar.gz
mill-62a18754179d74252668879c42e1c5d6a45fbdce.tar.bz2
mill-62a18754179d74252668879c42e1c5d6a45fbdce.zip
Enable JUnit testing, via sbt-test-interface, for `JavaModule`s
Diffstat (limited to 'scalalib/src')
-rw-r--r--scalalib/src/mill/scalalib/JavaModule.scala71
-rw-r--r--scalalib/src/mill/scalalib/Lib.scala117
-rw-r--r--scalalib/src/mill/scalalib/ScalaModule.scala63
-rw-r--r--scalalib/src/mill/scalalib/ScalaWorkerApi.scala6
4 files changed, 185 insertions, 72 deletions
diff --git a/scalalib/src/mill/scalalib/JavaModule.scala b/scalalib/src/mill/scalalib/JavaModule.scala
index b8ae0fd4..ea56b22e 100644
--- a/scalalib/src/mill/scalalib/JavaModule.scala
+++ b/scalalib/src/mill/scalalib/JavaModule.scala
@@ -17,7 +17,11 @@ import mill.util.Loose.Agg
* Core configuration required to compile a single Scala compilation target
*/
trait JavaModule extends mill.Module with TaskModule { outer =>
-
+ trait Tests extends TestModule{
+ override def moduleDeps = Seq(outer)
+ override def repositories = outer.repositories
+ override def javacOptions = outer.javacOptions
+ }
def defaultCommandName() = "run"
def resolvePublishDependency: Task[Dep => publish.Dependency] = T.task{
@@ -281,4 +285,67 @@ trait JavaModule extends mill.Module with TaskModule { outer =>
def artifactName: T[String] = millModuleSegments.parts.mkString("-")
def artifactId: T[String] = artifactName()
-} \ No newline at end of file
+}
+
+trait TestModule extends JavaModule with TaskModule {
+ override def defaultCommandName() = "test"
+ def testFrameworks: T[Seq[String]]
+
+ def forkWorkingDir = ammonite.ops.pwd
+
+ def test(args: String*) = T.command{
+ val outputPath = T.ctx().dest/"out.json"
+
+ Jvm.subprocess(
+ mainClass = "mill.scalaworker.ScalaWorker",
+ classPath = ScalaWorkerModule.classpath(),
+ jvmArgs = forkArgs(),
+ envArgs = forkEnv(),
+ mainArgs =
+ Seq(testFrameworks().length.toString) ++
+ testFrameworks() ++
+ Seq(runClasspath().length.toString) ++
+ runClasspath().map(_.path.toString) ++
+ Seq(args.length.toString) ++
+ args ++
+ Seq(outputPath.toString, T.ctx().log.colored.toString, compile().classes.path.toString, T.ctx().home.toString),
+ workingDir = forkWorkingDir
+ )
+
+ val jsonOutput = ujson.read(outputPath.toIO)
+ val (doneMsg, results) = upickle.default.readJs[(String, Seq[TestRunner.Result])](jsonOutput)
+ TestModule.handleResults(doneMsg, results)
+
+ }
+ def testLocal(args: String*) = T.command{
+ val outputPath = T.ctx().dest/"out.json"
+
+ Lib.runTests(
+ TestRunner.frameworks(testFrameworks()),
+ runClasspath().map(_.path),
+ Agg(compile().classes.path),
+ args
+ )
+
+ val jsonOutput = ujson.read(outputPath.toIO)
+ val (doneMsg, results) = upickle.default.readJs[(String, Seq[TestRunner.Result])](jsonOutput)
+ TestModule.handleResults(doneMsg, results)
+
+ }
+}
+
+object TestModule{
+ def handleResults(doneMsg: String, results: Seq[TestRunner.Result]) = {
+
+ val badTests = results.filter(x => Set("Error", "Failure").contains(x.status))
+ if (badTests.isEmpty) Result.Success((doneMsg, results))
+ else {
+ val suffix = if (badTests.length == 1) "" else " and " + (badTests.length-1) + " more"
+
+ Result.Failure(
+ badTests.head.fullyQualifiedName + " " + badTests.head.selector + suffix,
+ Some((doneMsg, results))
+ )
+ }
+ }
+}
diff --git a/scalalib/src/mill/scalalib/Lib.scala b/scalalib/src/mill/scalalib/Lib.scala
index 7b4b5bdb..3eb2defd 100644
--- a/scalalib/src/mill/scalalib/Lib.scala
+++ b/scalalib/src/mill/scalalib/Lib.scala
@@ -1,14 +1,22 @@
package mill
package scalalib
-import java.io.File
+import java.io.{File, FileInputStream}
+import java.lang.annotation.Annotation
+import java.util.zip.ZipInputStream
import javax.tools.ToolProvider
import ammonite.ops._
import ammonite.util.Util
import coursier.{Cache, Dependency, Fetch, Repository, Resolution}
+import mill.Agg
import mill.eval.{PathRef, Result}
-import mill.util.Loose.Agg
+import mill.modules.Jvm
+
+import mill.util.Ctx
+import sbt.testing._
+
+import scala.collection.mutable
object CompilationResult {
implicit val jsonFormatter: upickle.default.ReadWriter[CompilationResult] = upickle.default.macroRW
@@ -194,4 +202,109 @@ object Lib{
force = false
)
+ def runTests(frameworkInstances: ClassLoader => Seq[sbt.testing.Framework],
+ entireClasspath: Agg[Path],
+ testClassfilePath: Agg[Path],
+ args: Seq[String])
+ (implicit ctx: Ctx.Log with Ctx.Home): (String, Seq[mill.scalalib.TestRunner.Result]) = {
+ Jvm.inprocess(entireClasspath, classLoaderOverrideSbtTesting = true, cl => {
+ val frameworks = frameworkInstances(cl)
+
+ val events = mutable.Buffer.empty[Event]
+
+ val doneMessages = frameworks.map { framework =>
+ val runner = framework.runner(args.toArray, args.toArray, cl)
+
+ val testClasses = discoverTests(cl, framework, testClassfilePath)
+
+ val tasks = runner.tasks(
+ for ((cls, fingerprint) <- testClasses.toArray)
+ yield new TaskDef(cls.getName.stripSuffix("$"), fingerprint, true, Array(new SuiteSelector))
+ )
+
+ for (t <- tasks) {
+ t.execute(
+ new EventHandler {
+ def handle(event: Event) = events.append(event)
+ },
+ Array(
+ new Logger {
+ def debug(msg: String) = ctx.log.outputStream.println(msg)
+
+ def error(msg: String) = ctx.log.outputStream.println(msg)
+
+ def ansiCodesSupported() = true
+
+ def warn(msg: String) = ctx.log.outputStream.println(msg)
+
+ def trace(t: Throwable) = t.printStackTrace(ctx.log.outputStream)
+
+ def info(msg: String) = ctx.log.outputStream.println(msg)
+ })
+ )
+ }
+ ctx.log.outputStream.println(runner.done())
+ }
+
+ val results = for(e <- events) yield {
+ val ex = if (e.throwable().isDefined) Some(e.throwable().get) else None
+ mill.scalalib.TestRunner.Result(
+ e.fullyQualifiedName(),
+ e.selector() match{
+ case s: NestedSuiteSelector => s.suiteId()
+ case s: NestedTestSelector => s.suiteId() + "." + s.testName()
+ case s: SuiteSelector => s.toString
+ case s: TestSelector => s.testName()
+ case s: TestWildcardSelector => s.testWildcard()
+ },
+ e.duration(),
+ e.status().toString,
+ ex.map(_.getClass.getName),
+ ex.map(_.getMessage),
+ ex.map(_.getStackTrace)
+ )
+ }
+
+ (doneMessages.mkString("\n"), results)
+ })
+ }
+
+ def listClassFiles(base: Path): Iterator[String] = {
+ if (base.isDir) ls.rec(base).toIterator.filter(_.ext == "class").map(_.relativeTo(base).toString)
+ else {
+ val zip = new ZipInputStream(new FileInputStream(base.toIO))
+ Iterator.continually(zip.getNextEntry).takeWhile(_ != null).map(_.getName).filter(_.endsWith(".class"))
+ }
+ }
+
+ def discoverTests(cl: ClassLoader, framework: Framework, classpath: Agg[Path]) = {
+
+ val fingerprints = framework.fingerprints()
+
+ val testClasses = classpath.flatMap { base =>
+ // Don't blow up if there are no classfiles representing
+ // the tests to run Instead just don't run anything
+ if (!exists(base)) Nil
+ else listClassFiles(base).flatMap { path =>
+ val cls = cl.loadClass(path.stripSuffix(".class").replace('/', '.'))
+ fingerprints.find {
+ case f: SubclassFingerprint =>
+ !cls.isInterface &&
+ (f.isModule == cls.getName.endsWith("$")) &&
+ cl.loadClass(f.superclassName()).isAssignableFrom(cls)
+ case f: AnnotatedFingerprint =>
+ val annotationCls = cl.loadClass(f.annotationName()).asInstanceOf[Class[Annotation]]
+ (f.isModule == cls.getName.endsWith("$")) &&
+ (
+ cls.isAnnotationPresent(annotationCls) ||
+ cls.getDeclaredMethods.exists(_.isAnnotationPresent(annotationCls))
+ )
+
+ }.map { f => (cls, f) }
+ }
+ }
+
+ testClasses
+ }
+
}
diff --git a/scalalib/src/mill/scalalib/ScalaModule.scala b/scalalib/src/mill/scalalib/ScalaModule.scala
index b98f248e..a2ca09c5 100644
--- a/scalalib/src/mill/scalalib/ScalaModule.scala
+++ b/scalalib/src/mill/scalalib/ScalaModule.scala
@@ -18,7 +18,7 @@ import mill.util.DummyInputStream
trait ScalaModule extends JavaModule { outer =>
def scalaWorker: ScalaWorkerModule = mill.scalalib.ScalaWorkerModule
- trait Tests extends TestModule{
+ trait Tests extends TestModule with ScalaModule{
def scalaVersion = outer.scalaVersion()
override def repositories = outer.repositories
override def scalacPluginIvyDeps = outer.scalacPluginIvyDeps
@@ -195,64 +195,3 @@ trait ScalaModule extends JavaModule { outer =>
}
-object TestModule{
- def handleResults(doneMsg: String, results: Seq[TestRunner.Result]) = {
-
- val badTests = results.filter(x => Set("Error", "Failure").contains(x.status))
- if (badTests.isEmpty) Result.Success((doneMsg, results))
- else {
- val suffix = if (badTests.length == 1) "" else " and " + (badTests.length-1) + " more"
-
- Result.Failure(
- badTests.head.fullyQualifiedName + " " + badTests.head.selector + suffix,
- Some((doneMsg, results))
- )
- }
- }
-}
-trait TestModule extends ScalaModule with TaskModule {
- override def defaultCommandName() = "test"
- def testFrameworks: T[Seq[String]]
-
- def forkWorkingDir = ammonite.ops.pwd
-
- def test(args: String*) = T.command{
- val outputPath = T.ctx().dest/"out.json"
-
- Jvm.subprocess(
- mainClass = "mill.scalaworker.ScalaWorker",
- classPath = scalaWorker.classpath(),
- jvmArgs = forkArgs(),
- envArgs = forkEnv(),
- mainArgs =
- Seq(testFrameworks().length.toString) ++
- testFrameworks() ++
- Seq(runClasspath().length.toString) ++
- runClasspath().map(_.path.toString) ++
- Seq(args.length.toString) ++
- args ++
- Seq(outputPath.toString, T.ctx().log.colored.toString, compile().classes.path.toString, T.ctx().home.toString),
- workingDir = forkWorkingDir
- )
-
- val jsonOutput = ujson.read(outputPath.toIO)
- val (doneMsg, results) = upickle.default.readJs[(String, Seq[TestRunner.Result])](jsonOutput)
- TestModule.handleResults(doneMsg, results)
-
- }
- def testLocal(args: String*) = T.command{
- val outputPath = T.ctx().dest/"out.json"
-
- scalaWorker.worker().runTests(
- TestRunner.frameworks(testFrameworks()),
- runClasspath().map(_.path),
- Agg(compile().classes.path),
- args
- )
-
- val jsonOutput = ujson.read(outputPath.toIO)
- val (doneMsg, results) = upickle.default.readJs[(String, Seq[TestRunner.Result])](jsonOutput)
- TestModule.handleResults(doneMsg, results)
-
- }
-}
diff --git a/scalalib/src/mill/scalalib/ScalaWorkerApi.scala b/scalalib/src/mill/scalalib/ScalaWorkerApi.scala
index f6500ae8..0b6a931b 100644
--- a/scalalib/src/mill/scalalib/ScalaWorkerApi.scala
+++ b/scalalib/src/mill/scalalib/ScalaWorkerApi.scala
@@ -5,7 +5,6 @@ import ammonite.ops.Path
import coursier.Cache
import coursier.maven.MavenRepository
import mill.Agg
-import mill.scalalib.TestRunner.Result
import mill.T
import mill.define.{Discover, Worker}
import mill.scalalib.Lib.resolveDependencies
@@ -68,11 +67,6 @@ trait ScalaWorkerApi {
upstreamCompileOutput: Seq[CompilationResult])
(implicit ctx: mill.util.Ctx): mill.eval.Result[CompilationResult]
- def runTests(frameworkInstances: ClassLoader => Seq[sbt.testing.Framework],
- entireClasspath: Agg[Path],
- testClassfilePath: Agg[Path],
- args: Seq[String])
- (implicit ctx: mill.util.Ctx.Log with mill.util.Ctx.Home): (String, Seq[Result])
def discoverMainClasses(compilationResult: CompilationResult)
(implicit ctx: mill.util.Ctx): Seq[String]