diff options
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/mill/modules/Jvm.scala | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/core/src/main/scala/mill/modules/Jvm.scala b/core/src/main/scala/mill/modules/Jvm.scala index 48f100b2..888a687b 100644 --- a/core/src/main/scala/mill/modules/Jvm.scala +++ b/core/src/main/scala/mill/modules/Jvm.scala @@ -1,6 +1,7 @@ package mill.modules import java.io.FileOutputStream +import java.lang.reflect.Modifier import java.net.URLClassLoader import java.nio.file.attribute.PosixFilePermission import java.util.jar.{JarEntry, JarFile, JarOutputStream} @@ -9,6 +10,7 @@ import ammonite.ops._ import mill.define.Task import mill.eval.PathRef import mill.util.Ctx +import mill.util.Ctx.LogCtx import mill.util.Loose.Agg import scala.annotation.tailrec @@ -36,6 +38,61 @@ object Jvm { %("java", "-cp", classPath.mkString(":"), mainClass, options) } + def inprocess(mainClass: String, + classPath: Agg[Path], + options: Seq[String] = Seq.empty) + (implicit ctx: Ctx): Unit = { + inprocess(classPath, classLoaderOverrideSbtTesting = false, cl => { + getMainMethod(mainClass, cl).invoke(null, options.toArray) + }) + } + + private def getMainMethod(mainClassName: String, cl: ClassLoader) = { + val mainClass = cl.loadClass(mainClassName) + val method = mainClass.getMethod("main", classOf[Array[String]]) + // jvm allows the actual main class to be non-public and to run a method in the non-public class, + // we need to make it accessible + method.setAccessible(true) + val modifiers = method.getModifiers + if (!Modifier.isPublic(modifiers)) + throw new NoSuchMethodException(mainClassName + ".main is not public") + if (!Modifier.isStatic(modifiers)) + throw new NoSuchMethodException(mainClassName + ".main is not static") + method + } + + + def inprocess[T](classPath: Agg[Path], + classLoaderOverrideSbtTesting: Boolean, + body: ClassLoader => T): T = { + val cl = if (classLoaderOverrideSbtTesting) { + val outerClassLoader = getClass.getClassLoader + new URLClassLoader( + classPath.map(_.toIO.toURI.toURL).toArray, + ClassLoader.getSystemClassLoader().getParent()){ + override def findClass(name: String) = { + if (name.startsWith("sbt.testing.")){ + outerClassLoader.loadClass(name) + }else{ + super.findClass(name) + } + } + } + } else { + new URLClassLoader( + classPath.map(_.toIO.toURI.toURL).toArray, + ClassLoader.getSystemClassLoader().getParent()) + } + val oldCl = Thread.currentThread().getContextClassLoader + Thread.currentThread().setContextClassLoader(cl) + try { + body(cl) + }finally{ + Thread.currentThread().setContextClassLoader(oldCl) + cl.close() + } + } + def subprocess(mainClass: String, classPath: Agg[Path], jvmOptions: Seq[String] = Seq.empty, |