diff options
Diffstat (limited to 'stage1/Stage1Lib.scala')
-rw-r--r-- | stage1/Stage1Lib.scala | 125 |
1 files changed, 28 insertions, 97 deletions
diff --git a/stage1/Stage1Lib.scala b/stage1/Stage1Lib.scala index 565a06c..8aaa6e6 100644 --- a/stage1/Stage1Lib.scala +++ b/stage1/Stage1Lib.scala @@ -11,31 +11,11 @@ import java.security._ import java.util.{Set=>_,Map=>_,List=>_,_} import javax.xml.bind.annotation.adapters.HexBinaryAdapter -// CLI interop -case class ExitCode(integer: Int){ - def ||( other: => ExitCode ) = if( this == ExitCode.Success ) this else other - def &&( other: => ExitCode ) = if( this != ExitCode.Success ) this else other -} -object ExitCode{ - val Success = ExitCode(0) - val Failure = ExitCode(1) -} - -object CatchTrappedExitCode{ - def unapply(e: Throwable): Option[ExitCode] = { - Option(e) flatMap { - case i: InvocationTargetException => unapply(i.getTargetException) - case e if TrapSecurityManager.isTrappedExit(e) => Some( ExitCode(TrapSecurityManager.exitCode(e)) ) - case _ => None - } - } -} - -class BaseLib{ - def realpath(name: File) = new File(java.nio.file.Paths.get(name.getAbsolutePath).normalize.toString) -} - -class Stage1Lib( logger: Logger ) extends BaseLib{ +class Stage1Lib( logger: Logger ) extends + _root_.cbt.common_1.Module with + _root_.cbt.reflect.Module with + _root_.cbt.file.Module +{ lib => implicit protected val implicitLogger: Logger = logger @@ -99,24 +79,28 @@ class Stage1Lib( logger: Logger ) extends BaseLib{ } } + /* // ========== compilation / execution ========== // TODO: move classLoader first - def runMain( cls: String, args: Seq[String], classLoader: ClassLoader, fakeInstance: Boolean = false ): ExitCode = { + def runMain( className: String, args: Seq[String], classLoader: ClassLoader ): ExitCode = { import java.lang.reflect.Modifier - logger.run(s"Running $cls.main($args) with classLoader: " ++ classLoader.toString) + logger.run(s"Running $className.main($args) with classLoader: " ++ classLoader.toString) trapExitCode{ - val c = classLoader.loadClass(cls) - val m = c.getMethod( "main", classOf[Array[String]] ) - val instance = - if(!fakeInstance) null else c.newInstance - assert( - fakeInstance || (m.getModifiers & java.lang.reflect.Modifier.STATIC) > 0, - "Cannot run non-static method " ++ cls+".main" - ) - m.invoke( instance, args.toArray.asInstanceOf[AnyRef] ) + /* + val cls = classLoader.loadClass(className) + discoverCbtMain( cls ) orElse discoverMain( cls ) getOrElse ( + throw new NoSuchMethodException( "No main method found in " ++ cbt ) + ).apply( arg.toVector )*/ ExitCode.Success } } + */ + + def discoverCbtMainForced( cls: Class[_] ): cbt.reflect.StaticMethod[Context, ExitCode] = + discoverStaticMethodForced[Context, ExitCode]( cls, "cbtMain" ) + + def discoverCbtMain( cls: Class[_] ): Option[cbt.reflect.StaticMethod[Context, ExitCode]] = + discoverStaticMethod[Context, ExitCode]( cls, "cbtMain" ) /** shows an interactive dialogue in the shell asking the user to pick one of many choices */ def pickOne[T]( msg: String, choices: Seq[T] )( show: T => String ): Option[T] = { @@ -149,51 +133,10 @@ class Stage1Lib( logger: Logger ) extends BaseLib{ } /** interactively pick one main class */ - def runClass( mainClasses: Seq[Class[_]] ): Option[Class[_]] = { + def pickClass( mainClasses: Seq[Class[_]] ): Option[Class[_]] = { pickOne( "Which one do you want to run?", mainClasses )( _.toString ) } - /** Given a directory corresponding to the root package, iterate - the names of all classes derived from the class files found */ - def iterateClassNames( classesRootDirectory: File ): Seq[String] = - classesRootDirectory - .listRecursive - .filter(_.isFile) - .map(_.getPath) - .collect{ - // no $ to avoid inner classes - case path if !path.contains("$") && path.endsWith(".class") => - path.stripSuffix(".class") - .stripPrefix(classesRootDirectory.getPath) - .stripPrefix(File.separator) // 1 for the slash - .replace(File.separator, ".") - } - - /** ignoreMissingClasses allows ignoring other classes root directories which are subdirectories of this one */ - def iterateClasses( classesRootDirectory: File, classLoader: ClassLoader, ignoreMissingClasses: Boolean ) = - iterateClassNames(classesRootDirectory).map{ name => - try{ - classLoader.loadClass(name) - } catch { - case e: ClassNotFoundException if ignoreMissingClasses => null - case e: NoClassDefFoundError if ignoreMissingClasses => null - } - }.filterNot(ignoreMissingClasses && _ == null) - - def mainClasses( classesRootDirectory: File, classLoader: ClassLoader ): Seq[Class[_]] = { - val arrayClass = classOf[Array[String]] - val unitClass = classOf[Unit] - - iterateClasses( classesRootDirectory, classLoader, true ).filter( c => - !c.isInterface && - c.getDeclaredMethods().exists( m => - m.getName == "main" - && m.getParameterTypes.toList == List(arrayClass) - && m.getReturnType == unitClass - ) - ) - } - implicit class ClassLoaderExtensions(classLoader: ClassLoader){ def canLoad(className: String) = { try{ @@ -322,8 +265,9 @@ ${sourceFiles.sorted.mkString(" \\\n")} } } } - def redirectOutToErr[T](code: => T): T = { - val ( out, err ) = try{ + + def getOutErr: (ThreadLocal[PrintStream], ThreadLocal[PrintStream]) = + try{ // trying nailgun's System.our/err wrapper val field = System.out.getClass.getDeclaredField("streams") assert(System.out.getClass.getName == "com.martiansoftware.nailgun.ThreadLocalPrintStream") @@ -339,8 +283,8 @@ ${sourceFiles.sorted.mkString(" \\\n")} field.setAccessible(true) val outStream = field.get(System.out) val errStream = field.get(System.err) - assert(outStream.getClass.getName == "cbt.ThreadLocalOutputStream") - assert(errStream.getClass.getName == "cbt.ThreadLocalOutputStream") + assert(outStream.getClass.getName == "cbt.ThreadLocalOutputStream", outStream.getClass.getName) + assert(errStream.getClass.getName == "cbt.ThreadLocalOutputStream", errStream.getClass.getName) val field2 = outStream.getClass.getDeclaredField("threadLocal") field2.setAccessible(true) val out = field2.get(outStream).asInstanceOf[ThreadLocal[PrintStream]] @@ -348,6 +292,8 @@ ${sourceFiles.sorted.mkString(" \\\n")} ( out, err ) } + def redirectOutToErr[T](code: => T): T = { + val ( out, err ) = getOutErr val oldOut: PrintStream = out.get out.set( err.get: PrintStream ) val res = code @@ -355,21 +301,6 @@ ${sourceFiles.sorted.mkString(" \\\n")} res } - def trapExitCodeOrValue[T]( result: => T ): Either[ExitCode,T] = { - val trapExitCodeBefore = TrapSecurityManager.trapExitCode().get - try{ - TrapSecurityManager.trapExitCode().set(true) - Right( result ) - } catch { - case CatchTrappedExitCode(exitCode) => - logger.stage1(s"caught exit code $exitCode") - Left( exitCode ) - } finally { - TrapSecurityManager.trapExitCode().set(trapExitCodeBefore) - } - } - - def trapExitCode( code: => ExitCode ): ExitCode = trapExitCodeOrValue(code).merge def ScalaDependency( groupId: String, artifactId: String, version: String, classifier: Classifier = Classifier.none, |