diff options
-rw-r--r-- | examples/fork-example/Main.scala | 29 | ||||
-rw-r--r-- | examples/fork-example/build/build.scala | 17 | ||||
-rw-r--r-- | libraries/reflect/StaticMethod.scala | 3 | ||||
-rw-r--r-- | libraries/reflect/build/build.scala | 2 | ||||
-rw-r--r-- | libraries/reflect/reflect.scala | 47 | ||||
-rw-r--r-- | nailgun_launcher/NailgunLauncher.java | 2 | ||||
-rw-r--r-- | stage1/Stage1Lib.scala | 136 | ||||
-rw-r--r-- | stage1/resolver.scala | 30 | ||||
-rw-r--r-- | stage2/DirectoryDependency.scala | 2 | ||||
-rw-r--r-- | test/build/build.scala | 2 | ||||
-rw-r--r-- | test/test.scala | 3 |
11 files changed, 212 insertions, 61 deletions
diff --git a/examples/fork-example/Main.scala b/examples/fork-example/Main.scala new file mode 100644 index 0000000..2ce5af8 --- /dev/null +++ b/examples/fork-example/Main.scala @@ -0,0 +1,29 @@ +package cbt_examples.fork_example + +import akka.http.scaladsl.server._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.settings._ + +object Service extends HttpApp with App { + // should all appear in separate lines + System.out.println("HelloHello") + System.err.println("HelloHello") + System.out.println("HelloHello") + System.err.println("HelloHello") + System.out.println("HelloHello") + System.err.println("HelloHello") + System.out.println("HelloHello") + System.err.println("HelloHello") + System.out.println("HelloHello") + System.err.println("HelloHello") + + override protected def route = + path("test") { + get { + complete(HttpResponse()) + } + } + + startServer("localhost", 8080) +} diff --git a/examples/fork-example/build/build.scala b/examples/fork-example/build/build.scala new file mode 100644 index 0000000..0f033ee --- /dev/null +++ b/examples/fork-example/build/build.scala @@ -0,0 +1,17 @@ +package cbt_examples_build.akka_example + +import cbt._ +import java.net.URL + +class Build(val context: Context) extends BaseBuild { + override def defaultScalaVersion = "2.12.1" + + override def fork = true + + override def dependencies = + super.dependencies ++ + Resolver(mavenCentral).bind( + ScalaDependency("com.typesafe.akka", "akka-http", "10.0.5") + ) + +} diff --git a/libraries/reflect/StaticMethod.scala b/libraries/reflect/StaticMethod.scala index e2a0d07..d7ab3c5 100644 --- a/libraries/reflect/StaticMethod.scala +++ b/libraries/reflect/StaticMethod.scala @@ -1,4 +1,5 @@ package cbt.reflect -case class StaticMethod[Arg, Result]( function: Arg => Result, name: String ) extends ( Arg => Result ) { +import java.lang.reflect.Method +case class StaticMethod[Arg, Result]( function: Arg => Result, method: Method ) extends ( Arg => Result ) { def apply( arg: Arg ): Result = function( arg ) } diff --git a/libraries/reflect/build/build.scala b/libraries/reflect/build/build.scala index 5c27090..3b6658a 100644 --- a/libraries/reflect/build/build.scala +++ b/libraries/reflect/build/build.scala @@ -4,5 +4,5 @@ import cbt_internal._ class Build(val context: Context) extends Library{ override def inceptionYear = 2017 override def description = "discover classes on your classpath and invoke methods reflectively, preventing System.exit" - override def dependencies = super.dependencies :+ libraries.file + override def dependencies = super.dependencies :+ libraries.file :+ libraries.common_1 } diff --git a/libraries/reflect/reflect.scala b/libraries/reflect/reflect.scala index c18d926..bd7c245 100644 --- a/libraries/reflect/reflect.scala +++ b/libraries/reflect/reflect.scala @@ -7,6 +7,7 @@ import scala.reflect.ClassTag import cbt.ExitCode import cbt.file._ +import cbt.common_1._ object `package` extends Module { implicit class CbtClassOps( val c: Class[_] ) extends AnyVal with ops.CbtClassOps @@ -54,27 +55,37 @@ package ops { def isSynchronized = Modifier.isSynchronized( m.getModifiers ) def isTransient = Modifier.isTransient( m.getModifiers ) def isVolatile = Modifier.isVolatile( m.getModifiers ) + + def show = ( + m.name ~ "( " + ~ m.parameters.map( _.getType.name ).mkString( ", " ) + ~ " )" + ) } } trait Module { - def runMain( cls: Class[_], args: Seq[String] ): ExitCode = - discoverStaticExitMethodForced[Array[String]]( cls, "main" ).apply( args.to ) + def getMain( cls: Class[_] ): StaticMethod[Seq[String], ExitCode] = { + val f = findStaticExitMethodForced[Array[String]]( cls, "main" ) + f.copy( + function = ( args: Seq[String] ) => f.function( args.to ) + ) + } - def discoverMain( cls: Class[_] ): Option[StaticMethod[Seq[String], ExitCode]] = { - discoverStaticExitMethod[Array[String]]( cls, "main" ) + def findMain( cls: Class[_] ): Option[StaticMethod[Seq[String], ExitCode]] = { + findStaticExitMethod[Array[String]]( cls, "main" ) .map( f => f.copy( - function = ( arg: Seq[String] ) => f.function( arg.to ) + function = ( args: Seq[String] ) => f.function( args.to ) ) ) } /** ignoreMissingClasses allows ignoring other classes root directories which are subdirectories of this one */ - def iterateClasses( + def topLevelClasses( classesRootDirectory: File, classLoader: ClassLoader, ignoreMissingClasses: Boolean ): Seq[Class[_]] = - iterateClassNames( classesRootDirectory ) + topLevelClassNames( classesRootDirectory ) .map { name => try { classLoader.loadClass( name ) @@ -85,10 +96,10 @@ trait Module { } .filterNot( ignoreMissingClasses && _ == null ) - /** Given a directory corresponding to the root package, iterate - * the names of all classes derived from the class files found + /** Given a directory corresponding to the root package, return + * the names of all top-level classes based on the class files found */ - def iterateClassNames( classesRootDirectory: File ): Seq[String] = + def topLevelClassNames( classesRootDirectory: File ): Seq[String] = classesRootDirectory.listRecursive .filter( _.isFile ) .map( _.getPath ) @@ -102,16 +113,16 @@ trait Module { .replace( File.separator, "." ) } - def discoverStaticExitMethodForced[Arg: ClassTag]( + def findStaticExitMethodForced[Arg: ClassTag]( cls: Class[_], name: String ): StaticMethod[Arg, ExitCode] = { - val f = discoverStaticMethodForced[Arg, Unit]( cls, name ) + val f = findStaticMethodForced[Arg, Unit]( cls, name ) f.copy( function = arg => trapExitCode { f.function( arg ); ExitCode.Success } ) } - def discoverStaticMethodForced[Arg, Result]( + def findStaticMethodForced[Arg, Result]( cls: Class[_], name: String )( implicit @@ -122,15 +133,15 @@ trait Module { typeStaticMethod( m ) } - def discoverStaticExitMethod[Arg: ClassTag]( + def findStaticExitMethod[Arg: ClassTag]( cls: Class[_], name: String ): Option[StaticMethod[Arg, ExitCode]] = - discoverStaticMethod[Arg, Unit]( cls, name ).map( f => + findStaticMethod[Arg, Unit]( cls, name ).map( f => f.copy( function = arg => trapExitCode { f.function( arg ); ExitCode.Success } ) ) - def discoverStaticMethod[Arg, Result]( + def findStaticMethod[Arg, Result]( cls: Class[_], name: String )( implicit @@ -157,9 +168,7 @@ trait Module { else m.declaringClass.newInstance // Dottydoc needs this. It's main method is not static. StaticMethod( arg => m.invoke( instance, arg.asInstanceOf[AnyRef] ).asInstanceOf[Result], - m.getClass.name.stripSuffix( "$" ) ++ "." ++ m.name ++ "( " - ++ m.parameters.map( _.getType.name ).mkString( ", " ) - ++ " )" + m ) } diff --git a/nailgun_launcher/NailgunLauncher.java b/nailgun_launcher/NailgunLauncher.java index 1c6f3b5..0601919 100644 --- a/nailgun_launcher/NailgunLauncher.java +++ b/nailgun_launcher/NailgunLauncher.java @@ -56,6 +56,8 @@ public class NailgunLauncher{ ClassLoader.getSystemClassLoader().getParent() ); + public static boolean runningViaNailgun = System.out.getClass().getName().equals("com.martiansoftware.nailgun.ThreadLocalPrintStream"); + public static List<File> compatibilitySourceFiles; public static List<File> nailgunLauncherSourceFiles; public static List<File> stage1SourceFiles; diff --git a/stage1/Stage1Lib.scala b/stage1/Stage1Lib.scala index 3ffc878..392b885 100644 --- a/stage1/Stage1Lib.scala +++ b/stage1/Stage1Lib.scala @@ -8,7 +8,7 @@ import java.nio.file._ import java.nio.file.attribute.FileTime import javax.tools._ import java.security._ -import java.util.{Set=>_,Map=>_,List=>_,_} +import java.util.{Set=>_,Map=>_,List=>_,Iterator=>_,_} import javax.xml.bind.annotation.adapters.HexBinaryAdapter class Stage1Lib( logger: Logger ) extends @@ -79,28 +79,11 @@ class Stage1Lib( logger: Logger ) extends } } - /* - // ========== compilation / execution ========== - // TODO: move classLoader first - def runMain( className: String, args: Seq[String], classLoader: ClassLoader ): ExitCode = { - import java.lang.reflect.Modifier - logger.run(s"Running $className.main($args) with classLoader: " ++ classLoader.toString) - trapExitCode{ - /* - val cls = classLoader.loadClass(className) - discoverCbtMain( cls ) orElse discoverMain( cls ) getOrElse ( - throw new NoSuchMethodException( "No main method found in " ++ cbt ) - ).apply( arg.toVector )*/ - ExitCode.Success - } - } - */ + def getCbtMain( cls: Class[_] ): cbt.reflect.StaticMethod[Context, ExitCode] = + findStaticMethodForced[Context, ExitCode]( cls, "cbtMain" ) - def discoverCbtMainForced( cls: Class[_] ): cbt.reflect.StaticMethod[Context, ExitCode] = - discoverStaticMethodForced[Context, ExitCode]( cls, "cbtMain" ) - - def discoverCbtMain( cls: Class[_] ): Option[cbt.reflect.StaticMethod[Context, ExitCode]] = - discoverStaticMethod[Context, ExitCode]( cls, "cbtMain" ) + def findCbtMain( cls: Class[_] ): Option[cbt.reflect.StaticMethod[Context, ExitCode]] = + findStaticMethod[Context, ExitCode]( cls, "cbtMain" ) /** shows an interactive dialogue in the shell asking the user to pick one of many choices */ def pickOne[T]( msg: String, choices: Seq[T] )( show: T => String ): Option[T] = { @@ -266,16 +249,20 @@ ${sourceFiles.sorted.mkString(" \\\n")} } } - def getOutErr: (ThreadLocal[PrintStream], ThreadLocal[PrintStream]) = + def getOutErrIn: (ThreadLocal[PrintStream], ThreadLocal[PrintStream], InputStream) = try{ // trying nailgun's System.our/err wrapper val field = System.out.getClass.getDeclaredField("streams") - assert(System.out.getClass.getName == "com.martiansoftware.nailgun.ThreadLocalPrintStream") - assert(System.err.getClass.getName == "com.martiansoftware.nailgun.ThreadLocalPrintStream") + val field2 = System.in.getClass.getDeclaredField("streams") + assert(System.out.getClass.getName == "com.martiansoftware.nailgun.ThreadLocalPrintStream", System.out.getClass.getName) + assert(System.err.getClass.getName == "com.martiansoftware.nailgun.ThreadLocalPrintStream", System.err.getClass.getName) + assert(System.in.getClass.getName == "com.martiansoftware.nailgun.ThreadLocalInputStream", System.in.getClass.getName) field.setAccessible(true) + field2.setAccessible(true) val out = field.get(System.out).asInstanceOf[ThreadLocal[PrintStream]] val err = field.get(System.err).asInstanceOf[ThreadLocal[PrintStream]] - ( out, err ) + val in = field2.get(System.in).asInstanceOf[ThreadLocal[InputStream]] + ( out, err, in.get ) } catch { case e: NoSuchFieldException => // trying cbt's System.our/err wrapper @@ -289,11 +276,11 @@ ${sourceFiles.sorted.mkString(" \\\n")} field2.setAccessible(true) val out = field2.get(outStream).asInstanceOf[ThreadLocal[PrintStream]] val err = field2.get(errStream).asInstanceOf[ThreadLocal[PrintStream]] - ( out, err ) + ( out, err, System.in ) } def redirectOutToErr[T](code: => T): T = { - val ( out, err ) = getOutErr + val ( out, err, _ ) = getOutErrIn val oldOut: PrintStream = out.get out.set( err.get: PrintStream ) val res = code @@ -444,6 +431,99 @@ ${sourceFiles.sorted.mkString(" \\\n")} outputLastModified ) } + + def asyncPipeCharacterStreamSyncLines( inputStream: InputStream, outputStream: OutputStream, lock: AnyRef ): Thread = { + new Thread( + new Runnable{ + def run = { + val b = new BufferedInputStream( inputStream ) + Iterator.continually{ + b.read // block until and read next character + }.takeWhile(_ != -1).map{ c => + lock.synchronized{ // synchronize with other invocations + outputStream.write(c) + Iterator + .continually( b.read ) + .takeWhile( _ != -1 ) + .map{ c => + try{ + outputStream.write(c) + outputStream.flush + ( + c != '\n' // release lock when new line was encountered, allowing other writers to slip in + && b.available > 0 // also release when nothing is available to not block other outputs + ) + } catch { + case e: IOException if e.getMessage == "Stream closed" => false + } + } + .takeWhile(identity) + .length // force entire iterator + } + }.length // force entire iterator + } + } + ) + } + + def asyncPipeCharacterStream( inputStream: InputStream, outputStream: OutputStream, continue: => Boolean ) = { + new Thread( + new Runnable{ + def run = { + Iterator + .continually{ inputStream.read } + .takeWhile(_ != -1) + .map{ c => + try{ + outputStream.write(c) + outputStream.flush + true + } catch { + case e: IOException if e.getMessage == "Stream closed" => false + } + } + .takeWhile( identity ) + .takeWhile( _ => continue ) + .length // force entire iterator + } + } + ) + } + + def runWithIO( commandLine: Seq[String], directory: Option[File] = None ): ExitCode = { + val (out,err,in) = lib.getOutErrIn match { case (l,r, in) => (l.get,r.get, in) } + val pb = new ProcessBuilder( commandLine: _* ) + val exitCode = + if( !NailgunLauncher.runningViaNailgun ){ + pb.inheritIO.start.waitFor + } else { + val process = directory.map( pb.directory( _ ) ).getOrElse( pb ) + .redirectInput(ProcessBuilder.Redirect.PIPE) + .redirectOutput(ProcessBuilder.Redirect.PIPE) + .redirectError(ProcessBuilder.Redirect.PIPE) + .start + + val lock = new AnyRef + + val t1 = lib.asyncPipeCharacterStreamSyncLines( process.getErrorStream, err, lock ) + val t2 = lib.asyncPipeCharacterStreamSyncLines( process.getInputStream, out, lock ) + val t3 = lib.asyncPipeCharacterStream( System.in, process.getOutputStream, process.isAlive ) + + t1.start + t2.start + t3.start + + t1.join + t2.join + + val e = process.waitFor + System.err.println( scala.Console.RESET + "Please press ENTER to continue..." ) + t3.join + e + } + + ExitCode( exitCode ) + } } import scala.reflect._ diff --git a/stage1/resolver.scala b/stage1/resolver.scala index 0e5d221..6134a16 100644 --- a/stage1/resolver.scala +++ b/stage1/resolver.scala @@ -77,24 +77,36 @@ trait DependencyImplementation extends Dependency{ ) } */ - def flatClassLoader: Boolean = false + def fork = false - def runMain( className: String, args: Seq[String] ): ExitCode = lib.trapExitCode{ - lib.runMain( classLoader.loadClass( className ), args ) + def runMain( className: String, args: Seq[String] ): ExitCode = { + if(fork){ + val java_exe = new File(System.getProperty("java.home")) / "bin" / "java" + lib.runWithIO( + java_exe.string +: "-cp" +: classpath.string +: className +: args + ) + } else { + lib.getMain( classLoader.loadClass( className ) )( args ) + } } - def runMain( args: Seq[String] ): ExitCode = lib.trapExitCode{ - mainMethod.getOrElse( + def runMain( args: Seq[String] ): ExitCode = { + val c = mainClass.getOrElse( throw new RuntimeException( "No main class found in " + this ) - )( args ) + ) + runMain( c.getName, args ) } - def mainMethod = lib.pickOne( "Which one do you want to run?", mainMethods )( _.name ) + def mainClass = lib.pickOne( + "Which one do you want to run?", + classes.filter( lib.findMain(_).nonEmpty ) + )( _.name.stripSuffix( "$" ) ) def classes = exportedClasspath.files.flatMap( - lib.iterateClasses( _, classLoader, false ) + lib.topLevelClasses( _, classLoader, false ) ) - def mainMethods = classes.flatMap( lib.discoverMain ) + + def flatClassLoader: Boolean = false def classLoader: ClassLoader = taskCache[DependencyImplementation]( "classLoader" ).memoize{ if( flatClassLoader ){ diff --git a/stage2/DirectoryDependency.scala b/stage2/DirectoryDependency.scala index 9b07702..6ebb988 100644 --- a/stage2/DirectoryDependency.scala +++ b/stage2/DirectoryDependency.scala @@ -77,7 +77,7 @@ object DirectoryDependency { val buildClasses = buildBuild.exportedClasspath.files.flatMap( - lib.iterateClasses( _, classLoader, false ) + lib.topLevelClasses( _, classLoader, false ) .filter( _.getSimpleName === lib.buildClassName ) .filter( classOf[BaseBuild] isAssignableFrom _ ) ) diff --git a/test/build/build.scala b/test/build/build.scala index 792b34d..27de2bb 100644 --- a/test/build/build.scala +++ b/test/build/build.scala @@ -4,7 +4,7 @@ class Build(val context: cbt.Context) extends BaseBuild{ override def dependencies = super.dependencies :+ context.cbtDependency def apply = run override def run = { - classes.flatMap( lib.discoverCbtMain ).head( context ) + classes.flatMap( lib.findCbtMain ).head( context ) } def args = context.args } diff --git a/test/test.scala b/test/test.scala index f3fb7c5..f942375 100644 --- a/test/test.scala +++ b/test/test.scala @@ -89,7 +89,7 @@ object Main{ args = _args.drop(1), transientCache = new java.util.HashMap() ) - val ( outVar, errVar ) = lib.getOutErr + val ( outVar, errVar, _ ) = lib.getOutErrIn val oldOut = outVar.get val oldErr = errVar.get val out = new ByteArrayOutputStream @@ -258,6 +258,7 @@ object Main{ compile("../plugins/wartremover") compile("../plugins/uber-jar") compile("../plugins/scalafix-compiler-plugin") + compile("../examples/fork-example") compile("../examples/scalafmt-example") compile("../examples/scalariform-example") compile("../examples/scalatest-example") |