diff options
-rw-r--r-- | examples/fork-example/Main.scala | 29 | ||||
-rw-r--r-- | examples/fork-example/build/build.scala | 17 | ||||
-rw-r--r-- | nailgun_launcher/NailgunLauncher.java | 2 | ||||
-rw-r--r-- | stage1/Stage1Lib.scala | 111 | ||||
-rw-r--r-- | stage1/resolver.scala | 8 | ||||
-rw-r--r-- | test/test.scala | 3 |
6 files changed, 162 insertions, 8 deletions
diff --git a/examples/fork-example/Main.scala b/examples/fork-example/Main.scala new file mode 100644 index 0000000..2ce5af8 --- /dev/null +++ b/examples/fork-example/Main.scala @@ -0,0 +1,29 @@ +package cbt_examples.fork_example + +import akka.http.scaladsl.server._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.settings._ + +object Service extends HttpApp with App { + // should all appear in separate lines + System.out.println("HelloHello") + System.err.println("HelloHello") + System.out.println("HelloHello") + System.err.println("HelloHello") + System.out.println("HelloHello") + System.err.println("HelloHello") + System.out.println("HelloHello") + System.err.println("HelloHello") + System.out.println("HelloHello") + System.err.println("HelloHello") + + override protected def route = + path("test") { + get { + complete(HttpResponse()) + } + } + + startServer("localhost", 8080) +} diff --git a/examples/fork-example/build/build.scala b/examples/fork-example/build/build.scala new file mode 100644 index 0000000..0f033ee --- /dev/null +++ b/examples/fork-example/build/build.scala @@ -0,0 +1,17 @@ +package cbt_examples_build.akka_example + +import cbt._ +import java.net.URL + +class Build(val context: Context) extends BaseBuild { + override def defaultScalaVersion = "2.12.1" + + override def fork = true + + override def dependencies = + super.dependencies ++ + Resolver(mavenCentral).bind( + ScalaDependency("com.typesafe.akka", "akka-http", "10.0.5") + ) + +} diff --git a/nailgun_launcher/NailgunLauncher.java b/nailgun_launcher/NailgunLauncher.java index 1c6f3b5..0601919 100644 --- a/nailgun_launcher/NailgunLauncher.java +++ b/nailgun_launcher/NailgunLauncher.java @@ -56,6 +56,8 @@ public class NailgunLauncher{ ClassLoader.getSystemClassLoader().getParent() ); + public static boolean runningViaNailgun = System.out.getClass().getName().equals("com.martiansoftware.nailgun.ThreadLocalPrintStream"); + public static List<File> compatibilitySourceFiles; public static List<File> nailgunLauncherSourceFiles; public static List<File> stage1SourceFiles; diff --git a/stage1/Stage1Lib.scala b/stage1/Stage1Lib.scala index 55c9234..392b885 100644 --- a/stage1/Stage1Lib.scala +++ b/stage1/Stage1Lib.scala @@ -8,7 +8,7 @@ import java.nio.file._ import java.nio.file.attribute.FileTime import javax.tools._ import java.security._ -import java.util.{Set=>_,Map=>_,List=>_,_} +import java.util.{Set=>_,Map=>_,List=>_,Iterator=>_,_} import javax.xml.bind.annotation.adapters.HexBinaryAdapter class Stage1Lib( logger: Logger ) extends @@ -249,16 +249,20 @@ ${sourceFiles.sorted.mkString(" \\\n")} } } - def getOutErr: (ThreadLocal[PrintStream], ThreadLocal[PrintStream]) = + def getOutErrIn: (ThreadLocal[PrintStream], ThreadLocal[PrintStream], InputStream) = try{ // trying nailgun's System.our/err wrapper val field = System.out.getClass.getDeclaredField("streams") - assert(System.out.getClass.getName == "com.martiansoftware.nailgun.ThreadLocalPrintStream") - assert(System.err.getClass.getName == "com.martiansoftware.nailgun.ThreadLocalPrintStream") + val field2 = System.in.getClass.getDeclaredField("streams") + assert(System.out.getClass.getName == "com.martiansoftware.nailgun.ThreadLocalPrintStream", System.out.getClass.getName) + assert(System.err.getClass.getName == "com.martiansoftware.nailgun.ThreadLocalPrintStream", System.err.getClass.getName) + assert(System.in.getClass.getName == "com.martiansoftware.nailgun.ThreadLocalInputStream", System.in.getClass.getName) field.setAccessible(true) + field2.setAccessible(true) val out = field.get(System.out).asInstanceOf[ThreadLocal[PrintStream]] val err = field.get(System.err).asInstanceOf[ThreadLocal[PrintStream]] - ( out, err ) + val in = field2.get(System.in).asInstanceOf[ThreadLocal[InputStream]] + ( out, err, in.get ) } catch { case e: NoSuchFieldException => // trying cbt's System.our/err wrapper @@ -272,11 +276,11 @@ ${sourceFiles.sorted.mkString(" \\\n")} field2.setAccessible(true) val out = field2.get(outStream).asInstanceOf[ThreadLocal[PrintStream]] val err = field2.get(errStream).asInstanceOf[ThreadLocal[PrintStream]] - ( out, err ) + ( out, err, System.in ) } def redirectOutToErr[T](code: => T): T = { - val ( out, err ) = getOutErr + val ( out, err, _ ) = getOutErrIn val oldOut: PrintStream = out.get out.set( err.get: PrintStream ) val res = code @@ -427,6 +431,99 @@ ${sourceFiles.sorted.mkString(" \\\n")} outputLastModified ) } + + def asyncPipeCharacterStreamSyncLines( inputStream: InputStream, outputStream: OutputStream, lock: AnyRef ): Thread = { + new Thread( + new Runnable{ + def run = { + val b = new BufferedInputStream( inputStream ) + Iterator.continually{ + b.read // block until and read next character + }.takeWhile(_ != -1).map{ c => + lock.synchronized{ // synchronize with other invocations + outputStream.write(c) + Iterator + .continually( b.read ) + .takeWhile( _ != -1 ) + .map{ c => + try{ + outputStream.write(c) + outputStream.flush + ( + c != '\n' // release lock when new line was encountered, allowing other writers to slip in + && b.available > 0 // also release when nothing is available to not block other outputs + ) + } catch { + case e: IOException if e.getMessage == "Stream closed" => false + } + } + .takeWhile(identity) + .length // force entire iterator + } + }.length // force entire iterator + } + } + ) + } + + def asyncPipeCharacterStream( inputStream: InputStream, outputStream: OutputStream, continue: => Boolean ) = { + new Thread( + new Runnable{ + def run = { + Iterator + .continually{ inputStream.read } + .takeWhile(_ != -1) + .map{ c => + try{ + outputStream.write(c) + outputStream.flush + true + } catch { + case e: IOException if e.getMessage == "Stream closed" => false + } + } + .takeWhile( identity ) + .takeWhile( _ => continue ) + .length // force entire iterator + } + } + ) + } + + def runWithIO( commandLine: Seq[String], directory: Option[File] = None ): ExitCode = { + val (out,err,in) = lib.getOutErrIn match { case (l,r, in) => (l.get,r.get, in) } + val pb = new ProcessBuilder( commandLine: _* ) + val exitCode = + if( !NailgunLauncher.runningViaNailgun ){ + pb.inheritIO.start.waitFor + } else { + val process = directory.map( pb.directory( _ ) ).getOrElse( pb ) + .redirectInput(ProcessBuilder.Redirect.PIPE) + .redirectOutput(ProcessBuilder.Redirect.PIPE) + .redirectError(ProcessBuilder.Redirect.PIPE) + .start + + val lock = new AnyRef + + val t1 = lib.asyncPipeCharacterStreamSyncLines( process.getErrorStream, err, lock ) + val t2 = lib.asyncPipeCharacterStreamSyncLines( process.getInputStream, out, lock ) + val t3 = lib.asyncPipeCharacterStream( System.in, process.getOutputStream, process.isAlive ) + + t1.start + t2.start + t3.start + + t1.join + t2.join + + val e = process.waitFor + System.err.println( scala.Console.RESET + "Please press ENTER to continue..." ) + t3.join + e + } + + ExitCode( exitCode ) + } } import scala.reflect._ diff --git a/stage1/resolver.scala b/stage1/resolver.scala index be4d278..6134a16 100644 --- a/stage1/resolver.scala +++ b/stage1/resolver.scala @@ -77,9 +77,17 @@ trait DependencyImplementation extends Dependency{ ) } */ + def fork = false def runMain( className: String, args: Seq[String] ): ExitCode = { + if(fork){ + val java_exe = new File(System.getProperty("java.home")) / "bin" / "java" + lib.runWithIO( + java_exe.string +: "-cp" +: classpath.string +: className +: args + ) + } else { lib.getMain( classLoader.loadClass( className ) )( args ) + } } def runMain( args: Seq[String] ): ExitCode = { diff --git a/test/test.scala b/test/test.scala index f3fb7c5..f942375 100644 --- a/test/test.scala +++ b/test/test.scala @@ -89,7 +89,7 @@ object Main{ args = _args.drop(1), transientCache = new java.util.HashMap() ) - val ( outVar, errVar ) = lib.getOutErr + val ( outVar, errVar, _ ) = lib.getOutErrIn val oldOut = outVar.get val oldErr = errVar.get val out = new ByteArrayOutputStream @@ -258,6 +258,7 @@ object Main{ compile("../plugins/wartremover") compile("../plugins/uber-jar") compile("../plugins/scalafix-compiler-plugin") + compile("../examples/fork-example") compile("../examples/scalafmt-example") compile("../examples/scalariform-example") compile("../examples/scalatest-example") |