summaryrefslogblamecommitdiff
path: root/scalaworker/src/mill/scalaworker/ScalaWorker.scala
blob: 0411af925961a8807c258c33506db22f4bae56d3 (plain) (tree)
1
2
3
4
5
6
7
8
9







                                      
                                                     
                           



                         
                                                         


                                                                                        
                                   














                                                                                                
                   


                                         





                                                    
       





                                  
                               









                                              
                                
       
                                                        



                                                               
            
 









                                                                                     
 


                                                                                               


                                                                         
                                             
                                         




                                                      

                                                


                         
                                                                            
                                                                                                 


                                                                    
                                                                                            






                                                                             
 


                















                                                                                                             

                                        
                                               

                                                



                                                                 
                                                                                        
                                                                    
                                                            
 
                                                                                             
 
                                                              


                                                                                                         
                                                                                                           



                                                           
                                                                                             








                                                                           
                                                                                                    











































                                                                                               



                                                            
                                                

























                                                                                                          
          







                                
       
 

                                                                                

   
                                                                             


                                            
                                                                              
                                                                                
                                             
 
                                              
 

                                                                     
 





                                                                                                        
 






                                                             
                                                                          
 
                                                                          
 
                                               
 
                                                                         
 

                                                                                 
                                                                         


                
                                                   
       
 

















                                                                               

                                            









                                                                                                             
                                                                                    



                                                 



                                                              


                                                                            

                                                          










                                                                                
 
package mill.scalaworker

import java.io.{File, FileInputStream}
import java.lang.annotation.Annotation
import java.net.URLClassLoader
import java.util.Optional
import java.util.zip.ZipInputStream

import ammonite.ops.{Path, exists, ls, mkdir, rm, up}
import ammonite.util.Colors
import mill.Agg
import mill.define.Worker
import mill.eval.PathRef
import mill.modules.Jvm
import mill.scalalib.{CompilationResult, Lib, TestRunner}
import xsbti.compile.{CompilerCache => _, FileAnalysisStore => _, ScalaInstance => _, _}
import mill.scalalib.Lib.grepJar
import mill.scalalib.TestRunner.Result
import mill.util.{Ctx, PrintLogger}
import sbt.internal.inc._
import sbt.internal.util.{ConsoleOut, MainAppender}
import sbt.testing._
import sbt.util.LogExchange

import scala.collection.mutable

/**
  * A `PerClasspathEntryLookup` that answers analysis queries through the
  * supplied function and class-definition queries through `Locate`.
  *
  * @param am maps a classpath entry to its (possibly absent) compile analysis
  */
case class MockedLookup(am: File => Optional[CompileAnalysis]) extends PerClasspathEntryLookup {
  // The analysis of a classpath entry is whatever the supplied function returns
  override def analysis(classpathEntry: File): Optional[CompileAnalysis] = {
    am.apply(classpathEntry)
  }

  // Class-definition lookups are delegated unchanged to `Locate.definesClass`
  override def definesClass(classpathEntry: File): DefinesClass = {
    Locate.definesClass(classpathEntry)
  }
}

object ScalaWorker{

  /**
    * Command-line entry point that runs tests and writes their results to a file.
    *
    * `args` consists of three length-prefixed arrays (each one is a count
    * followed by that many elements): the test frameworks, the classpath and
    * the test arguments. They are followed by four single values, in order:
    * the output path, the "colored" flag, the test classfile path and the home
    * directory.
    *
    * The value returned by `runTests` is serialized with upickle and written
    * to the output path. Any throwable is printed rather than propagated, and
    * the JVM is always terminated with exit code 0.
    */
  def main(args: Array[String]): Unit = {
    try{
      // Reads the length-prefixed array beginning at `start`, returning its
      // elements together with the index of the first argument after it
      def readArray(start: Int): (Array[String], Int) = {
        val count = args(start).toInt
        val next = start + count + 1
        (args.slice(start + 1, next), next)
      }
      val (frameworks, classpathStart) = readArray(0)
      val (classpath, argumentsStart) = readArray(classpathStart)
      val (arguments, rest) = readArray(argumentsStart)

      // The four trailing scalar arguments
      val outputPath = args(rest + 0)
      val colored = args(rest + 1)
      val testCp = args(rest + 2)
      val homeStr = args(rest + 3)

      val useColors = colored == "true"
      val ctx = new Ctx.Log with Ctx.Home {
        val log = PrintLogger(
          useColors,
          if (useColors) Colors.Default else Colors.BlackWhite,
          System.out,
          System.err,
          System.err,
          System.in
        )
        val home = Path(homeStr)
      }

      // `runTests` does not need the constructor arguments, hence the nulls
      val worker = new ScalaWorker(null, null)
      val result = worker.runTests(
        frameworkInstances = TestRunner.frameworks(frameworks),
        entireClasspath = Agg.from(classpath.map(Path(_))),
        testClassfilePath = Agg(Path(testCp)),
        args = arguments
      )(ctx)

      val target = Path(outputPath)
      val serialized = upickle.default.write(result)
      ammonite.ops.write(target, serialized)
    }catch{case e: Throwable =>
      println(e)
      e.printStackTrace()
    }
    // Tests are over, kill the JVM whether or not anyone's threads are still running
    // Always return 0, even if tests fail. The caller can pick up the detailed test
    // results from the outputPath
    System.exit(0)
  }
}

class ScalaWorker(ctx0: mill.util.Ctx,
                  compilerBridgeClasspath: Array[String]) extends mill.scalalib.ScalaWorkerApi{
  @volatile var scalaClassloaderCache = Option.empty[(Long, ClassLoader)]
  @volatile var scalaInstanceCache = Option.empty[(Long, ScalaInstance)]

  /**
    * Compiles the compiler bridge sources for `scalaVersion`, unless the
    * per-version working directory already exists, and returns the directory
    * that holds the compiled classes.
    *
    * The sources are unpacked into `ctx0.dest / scalaVersion` and compiled into
    * its `compiled` sub-directory by reflectively invoking
    * `scala.tools.nsc.Main.process` loaded from `compilerJars`.
    *
    * Because the mere existence of the working directory is what marks the
    * bridge as "already compiled", a failed compilation must not leave that
    * directory behind. If anything throws, or `process` reports failure, the
    * working directory is removed before the error is propagated so that the
    * next call retries from scratch.
    *
    * @param scalaVersion names the per-version working directory
    * @param sourcesJar   jar containing the compiler bridge sources
    * @param compilerJars jars used to load the Scala compiler; also placed on
    *                     the compilation classpath
    * @return the directory containing the compiled bridge classes
    */
  def compileZincBridge(scalaVersion: String,
                        sourcesJar: Path,
                        compilerJars: Array[File]) = {
    val workingDir = ctx0.dest / scalaVersion
    val compiledDest = workingDir / 'compiled
    if (!exists(workingDir)) {

      println("Compiling compiler interface...")

      try {
        mkdir(workingDir)
        mkdir(compiledDest)

        val sourceFolder = mill.modules.Util.unpackZip(sourcesJar)(workingDir)
        val classloader = mill.util.ClassLoader.create(compilerJars.map(_.toURI.toURL), null)(ctx0)
        val scalacMain = classloader.loadClass("scala.tools.nsc.Main")
        val argsArray = Array[String](
          "-d", compiledDest.toString,
          "-classpath", (compilerJars ++ compilerBridgeClasspath).mkString(File.pathSeparator)
        ) ++ ls.rec(sourceFolder.path).filter(_.ext == "scala").map(_.toString)

        val processMethod = scalacMain.getMethods
          .find(_.getName == "process")
          .getOrElse(throw new NoSuchElementException("scala.tools.nsc.Main has no `process` method"))

        // Compiler versions whose `process` returns a Boolean signal failure
        // with `false`; versions where it returns nothing yield `null` here,
        // which is not treated as a failure.
        val outcome = processMethod.invoke(null, argsArray)
        if (outcome == java.lang.Boolean.FALSE) {
          throw new RuntimeException(
            s"Failed to compile compiler interface for Scala $scalaVersion"
          )
        }
      } catch { case e: Throwable =>
        // Caught only to clean up; the original error is always rethrown
        rm(workingDir)
        throw e
      }
    }

    compiledDest
  }



  /**
    * Lists, in sorted order, the main classes recorded in the analysis file of
    * `compilationResult`.
    *
    * Yields an empty sequence when the analysis store has no contents or when
    * the stored analysis is not an `Analysis` instance.
    *
    * @param compilationResult the compilation whose analysis file is read
    * @return the sorted names of all main classes found in the analysis
    */
  def discoverMainClasses(compilationResult: CompilationResult)(implicit ctx: mill.util.Ctx): Seq[String] = {
    val stored = FileAnalysisStore.binary(compilationResult.analysisFile.toIO).get()

    if (!stored.isPresent) Seq.empty[String]
    else stored.get.getAnalysis match {
      case analysis: Analysis =>
        // Gather the main classes of every source info into one sorted list
        analysis.infos.allInfos.values.flatMap(_.getMainClasses.toSeq).toSeq.sorted
      case _ =>
        Seq.empty[String]
    }
  }

  /**
    * Compiles `sources` with the zinc incremental compiler
    * (`sbt.internal.inc.IncrementalCompilerImpl`).
    *
    * The analysis is stored at `ctx.dest / 'zinc` and the classes directory is
    * `ctx.dest / 'classes`. The compiler classloader and the `ScalaInstance`
    * are cached on this worker and reused while their signatures (derived
    * from the paths and modification times of the relevant jars) are unchanged.
    *
    * @param scalaVersion          Scala version; used for the compiler bridge, the
    *                              `ScalaInstance` and to locate the scala-library /
    *                              scala-compiler jars in `compilerClasspath`
    * @param sources               the source files handed to the compiler
    * @param compilerBridgeSources passed to `compileZincBridge` as the bridge sources jar
    * @param compileClasspath      classpath entries placed after the classes directory
    * @param compilerClasspath     jars from which the compiler classloader is created
    * @param scalacOptions         options passed to scalac after the `-Xplugin:` options
    * @param scalacPluginClasspath plugin jars; each is passed as `-Xplugin:<jar>` and
    *                              contributes to the `ScalaInstance` classloader
    * @param javacOptions          options passed to javac
    * @param upstreamCompileOutput earlier compilation results, consulted to find the
    *                              analysis of classpath entries that are not files
    * @return `Success` holding the analysis file and classes directory, or
    *         `Failure` carrying the string form of a `CompileFailed` exception
    */
  def compileScala(scalaVersion: String,
                   sources: Agg[Path],
                   compilerBridgeSources: Path,
                   compileClasspath: Agg[Path],
                   compilerClasspath: Agg[Path],
                   scalacOptions: Seq[String],
                   scalacPluginClasspath: Agg[Path],
                   javacOptions: Seq[String],
                   upstreamCompileOutput: Seq[CompilationResult])
                  (implicit ctx: mill.util.Ctx): mill.eval.Result[CompilationResult] = {
    val compileClasspathFiles = compileClasspath.map(_.toIO).toArray
    val compilerJars = compilerClasspath.toArray.map(_.toIO)

    // Make sure the compiler bridge for this Scala version is available
    val compilerBridge = compileZincBridge(scalaVersion, compilerBridgeSources, compilerJars)

    val pluginJars = scalacPluginClasspath.toArray.map(_.toIO)

    // Cache keys: the sum, over all jars, of the path string's hash code plus
    // the file's modification time in milliseconds
    val compilerClassloaderSig = compilerClasspath.map(p => p.toString().hashCode + p.mtime.toMillis).sum
    val scalaInstanceSig =
      compilerClassloaderSig + scalacPluginClasspath.map(p => p.toString().hashCode + p.mtime.toMillis).sum

    // Reuse the cached compiler classloader when its signature is unchanged,
    // otherwise create a new one and remember it
    val compilerClassLoader = scalaClassloaderCache match{
      case Some((k, v)) if k == compilerClassloaderSig => v
      case _ =>
        val classloader = mill.util.ClassLoader.create(compilerJars.map(_.toURI.toURL), null)
        scalaClassloaderCache = Some((compilerClassloaderSig, classloader))
        classloader
    }

    // Same caching scheme for the ScalaInstance; its loader is created from the
    // plugin jars with `compilerClassLoader` as the second argument
    // (presumably the parent classloader -- confirm against mill.util.ClassLoader)
    val scalaInstance = scalaInstanceCache match{
      case Some((k, v)) if k == scalaInstanceSig => v
      case _ =>
        val scalaInstance = new ScalaInstance(
          version = scalaVersion,
          loader = mill.util.ClassLoader.create(pluginJars.map(_.toURI.toURL), compilerClassLoader),
          libraryJar = grepJar(compilerClasspath, s"scala-library-$scalaVersion.jar"),
          compilerJar = grepJar(compilerClasspath, s"scala-compiler-$scalaVersion.jar"),
          allJars = compilerJars ++ pluginJars,
          explicitActual = None
        )
        scalaInstanceCache = Some((scalaInstanceSig, scalaInstance))
        scalaInstance
    }

    mkdir(ctx.dest)

    val ic = new sbt.internal.inc.IncrementalCompilerImpl()

    // Logger named "Hello": any existing appenders are unbound and replaced by
    // a single console appender at Info level writing to `ctx.log.outputStream`
    val logger = {
      val consoleAppender = MainAppender.defaultScreen(ConsoleOut.printStreamOut(
        ctx.log.outputStream
      ))
      val l = LogExchange.logger("Hello")
      LogExchange.unbindLoggerAppenders("Hello")
      LogExchange.bindLoggerAppenders("Hello", (consoleAppender -> sbt.util.Level.Info) :: Nil)
      l
    }

    // Resolves the analysis of a classpath entry: regular files get none;
    // anything else is matched against the classes path of the upstream
    // compilation results and, on a match, that result's stored analysis is read
    def analysisMap(f: File): Optional[CompileAnalysis] = {
      if (f.isFile) {
        Optional.empty[CompileAnalysis]
      } else {
        upstreamCompileOutput.collectFirst {
          case CompilationResult(zincPath, classFiles) if classFiles.path.toNIO == f.toPath =>
            FileAnalysisStore.binary(zincPath.toIO).get().map[CompileAnalysis](_.getAnalysis)
        }.getOrElse(Optional.empty[CompileAnalysis])
      }
    }

    val lookup = MockedLookup(analysisMap)

    val zincFile = ctx.dest / 'zinc
    val classesDir = ctx.dest / 'classes

    val zincIOFile = zincFile.toIO
    val classesIODir = classesDir.toIO

    // Analysis store backing this compilation; read below for the previous
    // result and written after a successful compile
    val store = FileAnalysisStore.binary(zincIOFile)

    try {
      val newResult = ic.compile(
        ic.inputs(
          // The output classes directory comes first on the classpath
          classpath = classesIODir +: compileClasspathFiles,
          sources = sources.toArray.map(_.toIO),
          classesDirectory = classesIODir,
          // Plugin options precede the user-supplied scalac options
          scalacOptions = (scalacPluginClasspath.map(jar => s"-Xplugin:${jar}") ++ scalacOptions).toArray,
          javacOptions = javacOptions.toArray,
          maxErrors = 10,
          sourcePositionMappers = Array(),
          order = CompileOrder.Mixed,
          compilers = ic.compilers(
            scalaInstance,
            ClasspathOptionsUtil.boot,
            None,
            ZincUtil.scalaCompiler(scalaInstance, compilerBridge.toIO)
          ),
          setup = ic.setup(
            lookup,
            skip = false,
            zincIOFile,
            new FreshCompilerCache,
            IncOptions.of(),
            new ManagedLoggedReporter(10, logger),
            None,
            Array()
          ),
          // Previous analysis and setup, if the store already holds any
          pr = {
            val prev = store.get()
            PreviousResult.of(prev.map(_.getAnalysis), prev.map(_.getMiniSetup))
          }
        ),
        logger = logger
      )

      // Persist the new analysis and setup for the next compilation
      store.set(
        AnalysisContents.create(
          newResult.analysis(),
          newResult.setup()
        )
      )

      mill.eval.Result.Success(CompilationResult(zincFile, PathRef(classesDir)))
    // Only CompileFailed becomes a Failure; any other exception propagates
    }catch{case e: CompileFailed => mill.eval.Result.Failure(e.toString)}
  }

  /**
    * Runs the tests found under `testClassfilePath` with every given framework,
    * inside a classloader built from `entireClasspath`.
    *
    * For each framework a runner is created, the test classes are discovered
    * via `discoverTests`, and all resulting tasks are executed, including any
    * nested tasks returned by `Task.execute`. Each framework's `done()` summary
    * is printed to `ctx.log.outputStream` and also collected into the result.
    *
    * @param frameworkInstances instantiates the test frameworks in the given classloader
    * @param entireClasspath    classpath handed to `Jvm.inprocess`
    * @param testClassfilePath  locations scanned for test classes
    * @param args               arguments passed to each framework's runner
    * @return the frameworks' done messages joined by newlines, and one `Result`
    *         per event reported while executing the tasks
    */
  def runTests(frameworkInstances: ClassLoader => Seq[sbt.testing.Framework],
               entireClasspath: Agg[Path],
               testClassfilePath: Agg[Path],
               args: Seq[String])
              (implicit ctx: Ctx.Log with Ctx.Home): (String, Seq[Result]) = {
    Jvm.inprocess(entireClasspath, classLoaderOverrideSbtTesting = true, cl => {
      val frameworks = frameworkInstances(cl)

      val events = mutable.Buffer.empty[Event]

      // Records every reported event so it can be turned into a Result below
      val eventHandler = new EventHandler {
        def handle(event: Event) = events.append(event)
      }

      // Forwards all framework logging to the context's output stream
      val loggers = Array[Logger](
        new Logger {
          def debug(msg: String) = ctx.log.outputStream.println(msg)

          def error(msg: String) = ctx.log.outputStream.println(msg)

          def ansiCodesSupported() = true

          def warn(msg: String) = ctx.log.outputStream.println(msg)

          def trace(t: Throwable) = t.printStackTrace(ctx.log.outputStream)

          def info(msg: String) = ctx.log.outputStream.println(msg)
        }
      )

      val doneMessages = frameworks.map { framework =>
        val runner = framework.runner(args.toArray, args.toArray, cl)

        val testClasses = discoverTests(cl, framework, testClassfilePath)

        val tasks = runner.tasks(
          for ((cls, fingerprint) <- testClasses.toArray)
            yield new TaskDef(cls.getName.stripSuffix("$"), fingerprint, true, Array(new SuiteSelector))
        )

        // `Task.execute` may hand back further tasks for the client to run;
        // keep executing until no task produces any more work.
        val pending = mutable.Queue(tasks: _*)
        while (pending.nonEmpty) {
          val nested = pending.dequeue().execute(eventHandler, loggers)
          pending ++= nested
        }

        // The summary must be the value of this block: ending on `println`
        // would make every collected "message" the unit value `()`.
        val doneMessage = runner.done()
        ctx.log.outputStream.println(doneMessage)
        doneMessage
      }

      val results = for(e <- events) yield {
        val ex = if (e.throwable().isDefined) Some(e.throwable().get) else None
        Result(
          e.fullyQualifiedName(),
          e.selector() match{
            case s: NestedSuiteSelector => s.suiteId()
            case s: NestedTestSelector => s.suiteId() + "." + s.testName()
            case s: SuiteSelector => s.toString
            case s: TestSelector => s.testName()
            case s: TestWildcardSelector => s.testWildcard()
          },
          e.duration(),
          e.status().toString,
          ex.map(_.getClass.getName),
          ex.map(_.getMessage),
          ex.map(_.getStackTrace)
        )
      }

      (doneMessages.mkString("\n"), results)
    })

  }
  /**
    * Lists the `.class` entries found under `base`.
    *
    * If `base` is a directory it is walked recursively and the paths of the
    * class files are returned relative to `base`. Otherwise `base` is read as
    * a zip archive and the names of its `.class` entries are returned.
    *
    * For archives the entry names are collected eagerly so that the underlying
    * stream can be closed before returning, instead of staying open for as
    * long as a lazy iterator is alive.
    *
    * @param base a directory or a zip/jar archive
    * @return an iterator over the class file paths / entry names
    */
  def listClassFiles(base: Path): Iterator[String] = {
    if (base.isDir) ls.rec(base).toIterator.filter(_.ext == "class").map(_.relativeTo(base).toString)
    else {
      val zip = new ZipInputStream(new FileInputStream(base.toIO))
      try {
        Iterator.continually(zip.getNextEntry)
          .takeWhile(_ != null)
          .map(_.getName)
          .filter(_.endsWith(".class"))
          .toVector // force the read while the stream is still open
          .iterator
      } finally zip.close()
    }
  }
  /**
    * Pairs every class found under `classpath` with the first fingerprint of
    * `framework` that it matches; classes matching no fingerprint are dropped.
    *
    * @param cl        classloader used to load the candidate classes, the
    *                  fingerprints' superclasses and their annotations
    * @param framework test framework providing the fingerprints
    * @param classpath locations scanned for class files
    * @return the matching classes, each with the fingerprint it matched
    */
  def discoverTests(cl: ClassLoader, framework: Framework, classpath: Agg[Path]) = {
    val fingerprints = framework.fingerprints()

    // Decides whether `cls` satisfies a single fingerprint
    def matchesFingerprint(cls: Class[_], fingerprint: Fingerprint): Boolean = {
      // Class names ending in `$` are treated as modules
      val isModuleClass = cls.getName.endsWith("$")
      fingerprint match {
        case f: SubclassFingerprint =>
          !cls.isInterface &&
            (f.isModule == isModuleClass) &&
            cl.loadClass(f.superclassName()).isAssignableFrom(cls)
        case f: AnnotatedFingerprint =>
          (f.isModule == isModuleClass) &&
            cls.isAnnotationPresent(
              cl.loadClass(f.annotationName()).asInstanceOf[Class[Annotation]]
            )
      }
    }

    classpath.flatMap { base =>
      if (exists(base)) {
        listClassFiles(base).flatMap { path =>
          val className = path.stripSuffix(".class").replace('/', '.')
          val cls = cl.loadClass(className)
          fingerprints
            .find(matchesFingerprint(cls, _))
            .map(fingerprint => (cls, fingerprint))
        }
      } else {
        // A missing location simply contributes no tests instead of failing
        Nil
      }
    }
  }
}