From 62a18754179d74252668879c42e1c5d6a45fbdce Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Sun, 8 Apr 2018 12:55:30 -0700 Subject: Enable JUnit testing, via sbt-test-interface, for `JavaModule`s --- scalalib/src/mill/scalalib/JavaModule.scala | 71 +++++++++++++- scalalib/src/mill/scalalib/Lib.scala | 117 +++++++++++++++++++++++- scalalib/src/mill/scalalib/ScalaModule.scala | 63 +------------ scalalib/src/mill/scalalib/ScalaWorkerApi.scala | 6 -- 4 files changed, 185 insertions(+), 72 deletions(-) (limited to 'scalalib/src') diff --git a/scalalib/src/mill/scalalib/JavaModule.scala b/scalalib/src/mill/scalalib/JavaModule.scala index b8ae0fd4..ea56b22e 100644 --- a/scalalib/src/mill/scalalib/JavaModule.scala +++ b/scalalib/src/mill/scalalib/JavaModule.scala @@ -17,7 +17,11 @@ import mill.util.Loose.Agg * Core configuration required to compile a single Scala compilation target */ trait JavaModule extends mill.Module with TaskModule { outer => - + trait Tests extends TestModule{ + override def moduleDeps = Seq(outer) + override def repositories = outer.repositories + override def javacOptions = outer.javacOptions + } def defaultCommandName() = "run" def resolvePublishDependency: Task[Dep => publish.Dependency] = T.task{ @@ -281,4 +285,67 @@ trait JavaModule extends mill.Module with TaskModule { outer => def artifactName: T[String] = millModuleSegments.parts.mkString("-") def artifactId: T[String] = artifactName() -} \ No newline at end of file +} + +trait TestModule extends JavaModule with TaskModule { + override def defaultCommandName() = "test" + def testFrameworks: T[Seq[String]] + + def forkWorkingDir = ammonite.ops.pwd + + def test(args: String*) = T.command{ + val outputPath = T.ctx().dest/"out.json" + + Jvm.subprocess( + mainClass = "mill.scalaworker.ScalaWorker", + classPath = ScalaWorkerModule.classpath(), + jvmArgs = forkArgs(), + envArgs = forkEnv(), + mainArgs = + Seq(testFrameworks().length.toString) ++ + testFrameworks() ++ + Seq(runClasspath().length.toString) ++ + runClasspath().map(_.path.toString) ++ + Seq(args.length.toString) ++ + args ++ + Seq(outputPath.toString, T.ctx().log.colored.toString, compile().classes.path.toString, T.ctx().home.toString), + workingDir = forkWorkingDir + ) + + val jsonOutput = ujson.read(outputPath.toIO) + val (doneMsg, results) = upickle.default.readJs[(String, Seq[TestRunner.Result])](jsonOutput) + TestModule.handleResults(doneMsg, results) + + } + def testLocal(args: String*) = T.command{ + val outputPath = T.ctx().dest/"out.json" + + Lib.runTests( + TestRunner.frameworks(testFrameworks()), + runClasspath().map(_.path), + Agg(compile().classes.path), + args + ) + + val jsonOutput = ujson.read(outputPath.toIO) + val (doneMsg, results) = upickle.default.readJs[(String, Seq[TestRunner.Result])](jsonOutput) + TestModule.handleResults(doneMsg, results) + + } +} + +object TestModule{ + def handleResults(doneMsg: String, results: Seq[TestRunner.Result]) = { + + val badTests = results.filter(x => Set("Error", "Failure").contains(x.status)) + if (badTests.isEmpty) Result.Success((doneMsg, results)) + else { + val suffix = if (badTests.length == 1) "" else " and " + (badTests.length-1) + " more" + + Result.Failure( + badTests.head.fullyQualifiedName + " " + badTests.head.selector + suffix, + Some((doneMsg, results)) + ) + } + } +} diff --git a/scalalib/src/mill/scalalib/Lib.scala b/scalalib/src/mill/scalalib/Lib.scala index 7b4b5bdb..3eb2defd 100644 --- a/scalalib/src/mill/scalalib/Lib.scala +++ b/scalalib/src/mill/scalalib/Lib.scala @@ -1,14 +1,22 @@ package mill package scalalib -import java.io.File +import java.io.{File, FileInputStream} +import java.lang.annotation.Annotation +import java.util.zip.ZipInputStream import javax.tools.ToolProvider import ammonite.ops._ import ammonite.util.Util import coursier.{Cache, Dependency, Fetch, Repository, Resolution} +import mill.Agg import mill.eval.{PathRef, Result} -import mill.util.Loose.Agg +import mill.modules.Jvm + +import mill.util.Ctx +import sbt.testing._ + +import scala.collection.mutable object CompilationResult { implicit val jsonFormatter: upickle.default.ReadWriter[CompilationResult] = upickle.default.macroRW @@ -194,4 +202,109 @@ object Lib{ force = false ) + def runTests(frameworkInstances: ClassLoader => Seq[sbt.testing.Framework], + entireClasspath: Agg[Path], + testClassfilePath: Agg[Path], + args: Seq[String]) + (implicit ctx: Ctx.Log with Ctx.Home): (String, Seq[mill.scalalib.TestRunner.Result]) = { + Jvm.inprocess(entireClasspath, classLoaderOverrideSbtTesting = true, cl => { + val frameworks = frameworkInstances(cl) + + val events = mutable.Buffer.empty[Event] + + val doneMessages = frameworks.map { framework => + val runner = framework.runner(args.toArray, args.toArray, cl) + + val testClasses = discoverTests(cl, framework, testClassfilePath) + + val tasks = runner.tasks( + for ((cls, fingerprint) <- testClasses.toArray) + yield new TaskDef(cls.getName.stripSuffix("$"), fingerprint, true, Array(new SuiteSelector)) + ) + + for (t <- tasks) { + t.execute( + new EventHandler { + def handle(event: Event) = events.append(event) + }, + Array( + new Logger { + def debug(msg: String) = ctx.log.outputStream.println(msg) + + def error(msg: String) = ctx.log.outputStream.println(msg) + + def ansiCodesSupported() = true + + def warn(msg: String) = ctx.log.outputStream.println(msg) + + def trace(t: Throwable) = t.printStackTrace(ctx.log.outputStream) + + def info(msg: String) = ctx.log.outputStream.println(msg) + }) + ) + } + ctx.log.outputStream.println(runner.done()) + } + + val results = for(e <- events) yield { + val ex = if (e.throwable().isDefined) Some(e.throwable().get) else None + mill.scalalib.TestRunner.Result( + e.fullyQualifiedName(), + e.selector() match{ + case s: NestedSuiteSelector => s.suiteId() + case s: NestedTestSelector => s.suiteId() + "." + s.testName() + case s: SuiteSelector => s.toString + case s: TestSelector => s.testName() + case s: TestWildcardSelector => s.testWildcard() + }, + e.duration(), + e.status().toString, + ex.map(_.getClass.getName), + ex.map(_.getMessage), + ex.map(_.getStackTrace) + ) + } + + (doneMessages.mkString("\n"), results) + }) + } + + def listClassFiles(base: Path): Iterator[String] = { + if (base.isDir) ls.rec(base).toIterator.filter(_.ext == "class").map(_.relativeTo(base).toString) + else { + val zip = new ZipInputStream(new FileInputStream(base.toIO)) + Iterator.continually(zip.getNextEntry).takeWhile(_ != null).map(_.getName).filter(_.endsWith(".class")) + } + } + + def discoverTests(cl: ClassLoader, framework: Framework, classpath: Agg[Path]) = { + + val fingerprints = framework.fingerprints() + + val testClasses = classpath.flatMap { base => + // Don't blow up if there are no classfiles representing + // the tests to run Instead just don't run anything + if (!exists(base)) Nil + else listClassFiles(base).flatMap { path => + val cls = cl.loadClass(path.stripSuffix(".class").replace('/', '.')) + fingerprints.find { + case f: SubclassFingerprint => + !cls.isInterface && + (f.isModule == cls.getName.endsWith("$")) && + cl.loadClass(f.superclassName()).isAssignableFrom(cls) + case f: AnnotatedFingerprint => + val annotationCls = cl.loadClass(f.annotationName()).asInstanceOf[Class[Annotation]] + (f.isModule == cls.getName.endsWith("$")) && + ( + cls.isAnnotationPresent(annotationCls) || + cls.getDeclaredMethods.exists(_.isAnnotationPresent(annotationCls)) + ) + + }.map { f => (cls, f) } + } + } + + testClasses + } + } diff --git a/scalalib/src/mill/scalalib/ScalaModule.scala b/scalalib/src/mill/scalalib/ScalaModule.scala index b98f248e..a2ca09c5 100644 --- a/scalalib/src/mill/scalalib/ScalaModule.scala +++ b/scalalib/src/mill/scalalib/ScalaModule.scala @@ -18,7 +18,7 @@ import mill.util.DummyInputStream trait ScalaModule extends JavaModule { outer => def scalaWorker: ScalaWorkerModule = mill.scalalib.ScalaWorkerModule - trait Tests extends TestModule{ + trait Tests extends TestModule with ScalaModule{ def scalaVersion = outer.scalaVersion() override def repositories = outer.repositories override def scalacPluginIvyDeps = outer.scalacPluginIvyDeps @@ -195,64 +195,3 @@ trait ScalaModule extends JavaModule { outer => } -object TestModule{ - def handleResults(doneMsg: String, results: Seq[TestRunner.Result]) = { - - val badTests = results.filter(x => Set("Error", "Failure").contains(x.status)) - if (badTests.isEmpty) Result.Success((doneMsg, results)) - else { - val suffix = if (badTests.length == 1) "" else " and " + (badTests.length-1) + " more" - - Result.Failure( - badTests.head.fullyQualifiedName + " " + badTests.head.selector + suffix, - Some((doneMsg, results)) - ) - } - } -} -trait TestModule extends ScalaModule with TaskModule { - override def defaultCommandName() = "test" - def testFrameworks: T[Seq[String]] - - def forkWorkingDir = ammonite.ops.pwd - - def test(args: String*) = T.command{ - val outputPath = T.ctx().dest/"out.json" - - Jvm.subprocess( - mainClass = "mill.scalaworker.ScalaWorker", - classPath = scalaWorker.classpath(), - jvmArgs = forkArgs(), - envArgs = forkEnv(), - mainArgs = - Seq(testFrameworks().length.toString) ++ - testFrameworks() ++ - Seq(runClasspath().length.toString) ++ - runClasspath().map(_.path.toString) ++ - Seq(args.length.toString) ++ - args ++ - Seq(outputPath.toString, T.ctx().log.colored.toString, compile().classes.path.toString, T.ctx().home.toString), - workingDir = forkWorkingDir - ) - - val jsonOutput = ujson.read(outputPath.toIO) - val (doneMsg, results) = upickle.default.readJs[(String, Seq[TestRunner.Result])](jsonOutput) - TestModule.handleResults(doneMsg, results) - - } - def testLocal(args: String*) = T.command{ - val outputPath = T.ctx().dest/"out.json" - - scalaWorker.worker().runTests( - TestRunner.frameworks(testFrameworks()), - runClasspath().map(_.path), - Agg(compile().classes.path), - args - ) - - val jsonOutput = ujson.read(outputPath.toIO) - val (doneMsg, results) = upickle.default.readJs[(String, Seq[TestRunner.Result])](jsonOutput) - TestModule.handleResults(doneMsg, results) - - } -} diff --git a/scalalib/src/mill/scalalib/ScalaWorkerApi.scala b/scalalib/src/mill/scalalib/ScalaWorkerApi.scala index f6500ae8..0b6a931b 100644 --- a/scalalib/src/mill/scalalib/ScalaWorkerApi.scala +++ b/scalalib/src/mill/scalalib/ScalaWorkerApi.scala @@ -5,7 +5,6 @@ import ammonite.ops.Path import coursier.Cache import coursier.maven.MavenRepository import mill.Agg -import mill.scalalib.TestRunner.Result import mill.T import mill.define.{Discover, Worker} import mill.scalalib.Lib.resolveDependencies @@ -68,11 +67,6 @@ trait ScalaWorkerApi { upstreamCompileOutput: Seq[CompilationResult]) (implicit ctx: mill.util.Ctx): mill.eval.Result[CompilationResult] - def runTests(frameworkInstances: ClassLoader => Seq[sbt.testing.Framework], - entireClasspath: Agg[Path], - testClassfilePath: Agg[Path], - args: Seq[String]) - (implicit ctx: mill.util.Ctx.Log with mill.util.Ctx.Home): (String, Seq[Result]) def discoverMainClasses(compilationResult: CompilationResult) (implicit ctx: mill.util.Ctx): Seq[String] -- cgit v1.2.3