diff options
Diffstat (limited to 'scalalib/src/TestRunner.scala')
-rw-r--r-- | scalalib/src/TestRunner.scala | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/scalalib/src/TestRunner.scala b/scalalib/src/TestRunner.scala new file mode 100644 index 00000000..42e65d63 --- /dev/null +++ b/scalalib/src/TestRunner.scala @@ -0,0 +1,153 @@ +package mill.scalalib +import ammonite.util.Colors +import mill.Agg +import mill.modules.Jvm +import mill.scalalib.Lib.discoverTests +import mill.util.{Ctx, PrintLogger} +import mill.util.JsonFormatters._ +import sbt.testing._ + +import scala.collection.mutable +object TestRunner { + + + def main(args: Array[String]): Unit = { + try{ + var i = 0 + def readArray() = { + val count = args(i).toInt + val slice = args.slice(i + 1, i + count + 1) + i = i + count + 1 + slice + } + val frameworks = readArray() + val classpath = readArray() + val arguments = readArray() + val outputPath = args(i + 0) + val colored = args(i + 1) + val testCp = args(i + 2) + val homeStr = args(i + 3) + val ctx = new Ctx.Log with Ctx.Home { + val log = PrintLogger( + colored == "true", + true, + if(colored == "true") Colors.Default + else Colors.BlackWhite, + System.out, + System.err, + System.err, + System.in, + debugEnabled = false + ) + val home = os.Path(homeStr) + } + val result = runTests( + frameworkInstances = TestRunner.frameworks(frameworks), + entireClasspath = Agg.from(classpath.map(os.Path(_))), + testClassfilePath = Agg(os.Path(testCp)), + args = arguments + )(ctx) + + // Clear interrupted state in case some badly-behaved test suite + // dirtied the thread-interrupted flag and forgot to clean up. Otherwise + // that flag causes writing the results to disk to fail + Thread.interrupted() + ammonite.ops.write(os.Path(outputPath), upickle.default.write(result)) + }catch{case e: Throwable => + println(e) + e.printStackTrace() + } + // Tests are over, kill the JVM whether or not anyone's threads are still running + // Always return 0, even if tests fail. The caller can pick up the detailed test + // results from the outputPath + System.exit(0) + } + + def runTests(frameworkInstances: ClassLoader => Seq[sbt.testing.Framework], + entireClasspath: Agg[os.Path], + testClassfilePath: Agg[os.Path], + args: Seq[String]) + (implicit ctx: Ctx.Log with Ctx.Home): (String, Seq[mill.scalalib.TestRunner.Result]) = { + //Leave the context class loader set and open so that shutdown hooks can access it + Jvm.inprocess(entireClasspath, classLoaderOverrideSbtTesting = true, isolated = true, closeContextClassLoaderWhenDone = false, cl => { + val frameworks = frameworkInstances(cl) + + val events = mutable.Buffer.empty[Event] + + val doneMessages = frameworks.map{ framework => + val runner = framework.runner(args.toArray, Array[String](), cl) + + val testClasses = discoverTests(cl, framework, testClassfilePath) + + val tasks = runner.tasks( + for ((cls, fingerprint) <- testClasses.toArray) + yield new TaskDef(cls.getName.stripSuffix("$"), fingerprint, true, Array(new SuiteSelector)) + ) + + val taskQueue = tasks.to[mutable.Queue] + while (taskQueue.nonEmpty){ + val next = taskQueue.dequeue().execute( + new EventHandler { + def handle(event: Event) = events.append(event) + }, + Array( + new Logger { + def debug(msg: String) = ctx.log.outputStream.println(msg) + + def error(msg: String) = ctx.log.outputStream.println(msg) + + def ansiCodesSupported() = true + + def warn(msg: String) = ctx.log.outputStream.println(msg) + + def trace(t: Throwable) = t.printStackTrace(ctx.log.outputStream) + + def info(msg: String) = ctx.log.outputStream.println(msg) + }) + ) + taskQueue.enqueue(next:_*) + } + runner.done() + } + + val results = for(e <- events) yield { + val ex = if (e.throwable().isDefined) Some(e.throwable().get) else None + mill.scalalib.TestRunner.Result( + e.fullyQualifiedName(), + e.selector() match{ + case s: NestedSuiteSelector => s.suiteId() + case s: NestedTestSelector => s.suiteId() + "." + s.testName() + case s: SuiteSelector => s.toString + case s: TestSelector => s.testName() + case s: TestWildcardSelector => s.testWildcard() + }, + e.duration(), + e.status().toString, + ex.map(_.getClass.getName), + ex.map(_.getMessage), + ex.map(_.getStackTrace) + ) + } + + (doneMessages.mkString("\n"), results) + }) + } + + def frameworks(frameworkNames: Seq[String])(cl: ClassLoader): Seq[sbt.testing.Framework] = { + frameworkNames.map { name => + cl.loadClass(name).newInstance().asInstanceOf[sbt.testing.Framework] + } + } + + case class Result(fullyQualifiedName: String, + selector: String, + duration: Long, + status: String, + exceptionName: Option[String] = None, + exceptionMsg: Option[String] = None, + exceptionTrace: Option[Seq[StackTraceElement]] = None) + + object Result{ + implicit def resultRW: upickle.default.ReadWriter[Result] = upickle.default.macroRW[Result] + } +} |