From b1e622a42de5d48b82c108f2d7931b170a460f5e Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Mon, 1 Jan 2018 01:39:09 -0800 Subject: Properly set the context classloader inside `TestRunner`'s `ClassLoader` --- .../src/main/scala/mill/scalalib/TestRunner.scala | 101 +++++++++++---------- 1 file changed, 54 insertions(+), 47 deletions(-) (limited to 'scalalib/src') diff --git a/scalalib/src/main/scala/mill/scalalib/TestRunner.scala b/scalalib/src/main/scala/mill/scalalib/TestRunner.scala index 0a3df35e..7fd6079e 100644 --- a/scalalib/src/main/scala/mill/scalalib/TestRunner.scala +++ b/scalalib/src/main/scala/mill/scalalib/TestRunner.scala @@ -62,7 +62,7 @@ object TestRunner { args: Seq[String]) (implicit ctx: LogCtx): (String, Seq[Result]) = { val outerClassLoader = getClass.getClassLoader - pprint.log(entireClasspath.map(_.toIO.toURI.toURL).toArray, height=9999) + val cl = new URLClassLoader( entireClasspath.map(_.toIO.toURI.toURL).toArray, ClassLoader.getSystemClassLoader().getParent()){ @@ -75,61 +75,68 @@ object TestRunner { } } - val framework = cl.loadClass(frameworkName) - .newInstance() - .asInstanceOf[sbt.testing.Framework] + val oldCl = Thread.currentThread().getContextClassLoader + Thread.currentThread().setContextClassLoader(cl) + try { + val framework = cl.loadClass(frameworkName) + .newInstance() + .asInstanceOf[sbt.testing.Framework] - val testClasses = runTests(cl, framework, testClassfilePath) + val testClasses = runTests(cl, framework, testClassfilePath) - val runner = framework.runner(args.toArray, args.toArray, cl) + val runner = framework.runner(args.toArray, args.toArray, cl) - val tasks = runner.tasks( - for((cls, fingerprint) <- testClasses.toArray) - yield new TaskDef(cls.getName.stripSuffix("$"), fingerprint, true, Array()) - ) - val events = mutable.Buffer.empty[Event] - for(t <- tasks){ - t.execute( - new EventHandler { - def handle(event: Event) = events.append(event) - }, - Array( - new Logger { - def debug(msg: String) = ctx.log.info(msg) + val tasks = runner.tasks( + for ((cls, fingerprint) <- testClasses.toArray) + yield new TaskDef(cls.getName.stripSuffix("$"), fingerprint, true, Array()) + ) + val events = mutable.Buffer.empty[Event] + for (t <- tasks) { + t.execute( + new EventHandler { + def handle(event: Event) = events.append(event) + }, + Array( + new Logger { + def debug(msg: String) = ctx.log.info(msg) - def error(msg: String) = ctx.log.error(msg) + def error(msg: String) = ctx.log.error(msg) - def ansiCodesSupported() = true + def ansiCodesSupported() = true - def warn(msg: String) = ctx.log.info(msg) + def warn(msg: String) = ctx.log.info(msg) - def trace(t: Throwable) = t.printStackTrace(ctx.log.outputStream) + def trace(t: Throwable) = t.printStackTrace(ctx.log.outputStream) - def info(msg: String) = ctx.log.info(msg) - }) - ) - } - val doneMsg = runner.done() - - val results = for(e <- events) yield { - val ex = if (e.throwable().isDefined) Some(e.throwable().get) else None - Result( - e.fullyQualifiedName(), - e.selector() match{ - case s: NestedSuiteSelector => s.suiteId() - case s: NestedTestSelector => s.suiteId() + "." + s.testName() - case s: SuiteSelector => s.toString - case s: TestSelector => s.testName() - case s: TestWildcardSelector => s.testWildcard() - }, - e.duration(), - e.status(), - ex.map(_.getClass.getName), - ex.map(_.getMessage), - ex.map(_.getStackTrace) - ) + def info(msg: String) = ctx.log.info(msg) + }) + ) + } + val doneMsg = runner.done() + val results = for(e <- events) yield { + val ex = if (e.throwable().isDefined) Some(e.throwable().get) else None + Result( + e.fullyQualifiedName(), + e.selector() match{ + case s: NestedSuiteSelector => s.suiteId() + case s: NestedTestSelector => s.suiteId() + "." + s.testName() + case s: SuiteSelector => s.toString + case s: TestSelector => s.testName() + case s: TestWildcardSelector => s.testWildcard() + }, + e.duration(), + e.status(), + ex.map(_.getClass.getName), + ex.map(_.getMessage), + ex.map(_.getStackTrace) + ) + } + (doneMsg, results) + }finally{ + Thread.currentThread().setContextClassLoader(oldCl) } - (doneMsg, results) + + } case class Result(fullyQualifiedName: String, -- cgit v1.2.3