aboutsummaryrefslogblamecommitdiff
path: root/libraries/eval/Eval.scala
blob: b9d225a543d2d9e5c88eb73bfce9f13903df39c9 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16















                                                                           

                        

                                                                          
                           
                              
                                  
                       

                               
                      
                                         
                                                          
                                                          

                                                       
 
                                                         
 
   
                                                   
   




                                                                




























                                                                           
                   
 
                                       





                                                                                                                       
                                  




                                                                                                                      
 
     




                                                                                     
     
                                                       





                                                                           



                                                                                       



         







                                                                              








                                                                            

     



                               
                                                               
                                               
                                              




                               
                                   


                                                                                               



                                                                                                              
       










                                                                                                       




                                                                                    

   
     
                                                   
     
                                          
                                                                              











                                                                 
   
 











                                                                                      



                                                     


                                                 
                                        
                                    

   



                                                                
                                              









                                                                                             

                                
                                                

                           
                                                       






                                                       
                                                







                                                                                
                                                




                                                           



                                                                                                               




                                                                           
 








                                                                  

   

                                                                                            
     
                                                       



                                                                    
                                                       






















                                                                                                     

   
                      
                                                                




                                         
                                        





                                                            


                                           







                                            










                                                               

                                            














                                                                                                

                                                                         
     
                                                                            
                                   





                                                                                   
                                                  





                                                                 
                                       
                                                                    



                                                                                  


                                                                                                    
               
             





                                                                                         



                                          

   
     

                                                                                               
     




                                                                          
 


                                                       

                                                                             
                                               
 
                                           






















                                                                                           
 


                         
 
                          


                        
     
 
                                               
 



                                                                                         
                                                                                       

                 












                                                                                            
                    
                                                                                     

     














                                                                














                                                          



                                                                        
                        
                                        
 
                                                                         
                                   
                                                                   
                              







                                                             
                                                                
       
                                                                                        
                    



                                        
         

       
   
 

                                                                              
 
/*
 * Copyright 2010 Twitter, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License. You may obtain
 * a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.twitter.util

import com.twitter.io.StreamIO
import java.io.{File, InputStream, FileInputStream, FileNotFoundException}
import java.math.BigInteger
import java.net.URLClassLoader
import java.security.MessageDigest
import java.util.Random
import java.util.jar.JarFile
import scala.collection.mutable
import scala.io.Source
import scala.tools.nsc.{Global, Settings}
import scala.tools.nsc.interpreter.AbstractFileClassLoader
import scala.tools.nsc.io.{AbstractFile, VirtualDirectory}
import scala.tools.nsc.reporters.AbstractReporter
import scala.tools.nsc.util.{BatchSourceFile, Position}

/**
 * Output of preprocessing: the expanded source text, together with the newest
 * last-modified timestamp known for the inputs that produced it (if any).
 *
 * @param timestamp newest last-modified time of any contributing input, if known
 * @param code      the (possibly preprocessed) source text
 */
case class LastMod(
  timestamp: Option[Long],
  code: String
)

/**
 * Evaluate a file or string and return the result.
 *
 * Shared evaluator instance. It also holds the per-JVM random identifier that
 * `Eval#uniqueId` appends to generated class names.
 */
@deprecated("use a throw-away instance of Eval instead")
object Eval extends Eval {
  // Must be non-negative: it becomes part of a generated class name, and a '-' there is not a
  // legal identifier character. Math.abs(nextInt()) is negative when nextInt() returns
  // Int.MinValue; nextInt(bound) always returns a value in [0, bound).
  private val jvmId = new Random().nextInt(Int.MaxValue)
}

/**
 * Evaluates files, strings or input streams, and returns the result.
 * In all cases, code to be evaled is wrapped in an apply method in a
 * generated class. An instance of the class is instantiated, and the
 * result of apply is returned.
 *
 * If target is None, the results are compiled to memory (and are therefore
 * ephemeral). If target is Some(path), path must point to a directory, and
 * eval emits class files to that directory.
 *
 * Eval also supports a limited set of preprocessors. Limited means
 * exactly one, which supports directives of the form #include <file>.
 *
 * The general flow of evaluation is
 * # convert arguments to a string
 * # run preprocessors on that string
 * # wrap processed code in a class
 * # compile the class
 * # create an instance of that class
 * # return the results of apply()
 *
 * @param target directory to emit class files into, or None to compile into memory
 */
class Eval(target: Option[File]) {
  /**
   * empty constructor for backwards compatibility
   */
  def this() {
    this(None)
  }

  import Eval.jvmId

  // Resolved lazily so constructing an Eval only fails once the compiler is actually needed.
  // Ordinary failures (resource not found -> NPE, unexpected URL shape -> index errors) are
  // wrapped with a helpful hint; JVM errors are left to propagate.
  private lazy val compilerPath = try {
    jarPathOfClass("scala.tools.nsc.Interpreter")
  } catch {
    case e: Exception =>
      throw new RuntimeException("Unable to load scala interpreter from classpath (scala-compiler jar is missing?)", e)
  }

  private lazy val libPath = try {
    jarPathOfClass("scala.ScalaObject")
  } catch {
    case e: Exception =>
      throw new RuntimeException("Unable to load scala base object from classpath (scala-library jar is missing?)", e)
  }

  /**
   * Preprocessors to run the code through before it is passed to the Scala compiler.
   * If you want to add new resolvers, you can do so with
   * new Eval(...) {
   *   override protected lazy val preprocessors = {...}
   * }
   */
  protected lazy val preprocessors: Seq[Preprocessor] =
    Seq(
      new IncludePreprocessor(
        Seq(
          new ClassScopedResolver(getClass),
          new FilesystemResolver(new File(".")),
          new FilesystemResolver(new File("." + File.separator + "config"))
        ) ++ (
          Option(System.getProperty("com.twitter.util.Eval.includePath")) map { path =>
            new FilesystemResolver(new File(path))
          }
        )
      )
    )

  // lineOffset = 2: wrapCodeInClass emits two lines ("class ... {" and "def apply() = {")
  // before the user's code, so reported line numbers are shifted back by that amount.
  private lazy val compiler = new StringCompiler(2, target)

  /**
   * Run preprocessors on our string, returning a LastMod
   * where timestamp is the last modified time of any file that contributed
   * to the text.
   * Last modified is computed here because we support includes.
   *
   * @param code         raw source text
   * @param lastModified timestamp of the source itself, if known
   * @return the processed code with the newest contributing timestamp
   */
  def sourceForString(code: String, lastModified: Option[Long]): LastMod = {
    preprocessors.foldLeft(LastMod(lastModified, code)) { (acc, p) =>
      val processed = p(acc.code, lastModified)

      // timestamp of the newest processed file. Option ordering ranks None below any Some,
      // so max takes the larger of two defined timestamps, otherwise whichever is defined.
      val newestProcessed = Seq(processed.timestamp, acc.timestamp).max
      LastMod(newestProcessed, processed.code)
    }
  }

  /**
   * Eval[Int]("1 + 1") // => 2
   *
   * @param resetState when true, previously compiled classes are discarded first
   */
  def apply[T](code: String, resetState: Boolean = true): T = {
    val processed = sourceForString(code, None)
    applyProcessed(processed.code, resetState)
  }

  /**
   * Eval[Int](new File("..."))
   *
   * The files are concatenated (newline-separated) and evaluated as one source. With a target
   * directory, the generated class is named after the first file, and compiled classes are
   * only discarded when some contributing source is newer than the oldest file in the target.
   */
  def apply[T](files: File*): T = {
    val unprocessedSource = files.map(readFile(_)).mkString("\n")
    target match {
      case Some(targetDir) =>
        val lastModified = files.foldLeft(None: Option[Long]) { (acc, f) => Seq(acc, Some(f.lastModified)).max }
        val processed = sourceForString(unprocessedSource, lastModified)
        // listFiles returns null when targetDir is not a readable directory.
        val existingTargets = Option(targetDir.listFiles).getOrElse(Array.empty[File])
        val oldestTarget = existingTargets.foldLeft(Long.MaxValue) { (oldest, f) =>
          f.lastModified min oldest
        }
        val stale = processed.timestamp match {
          // if we got a last-modified-source timestamp threaded through, use it to check compiler resets
          case Some(newestSource) => newestSource > oldestTarget
          // if there are no timestamps anywhere, just reset the compiler
          case None => true
        }
        if (stale) compiler.reset()

        val className = "Evaluator__" + files(0).getName.split("\\.")(0)
        applyProcessed(className, processed.code, false)
      case None =>
        apply[T](unprocessedSource, true)
    }
  }

  /**
   * Eval[Int](getClass.getResourceAsStream("..."))
   *
   * The stream is read fully but not closed; it remains owned by the caller.
   */
  def apply[T](stream: InputStream): T = {
    // apply(String) runs the preprocessors itself; preprocessing here too would run them twice.
    apply[T](Source.fromInputStream(stream).mkString)
  }

  /**
   * same as apply[T], but does not run preprocessors.
   * Will generate a classname of the form Evaluator__<unique>,
   * where unique is computed from the jvmID (a random number)
   * and a digest of code
   */
  def applyProcessed[T](code: String, resetState: Boolean): T = {
    val id = uniqueId(code)
    val className = "Evaluator__" + id
    applyProcessed(className, code, resetState)
  }

  /**
   * same as apply[T], but does not run preprocessors.
   */
  def applyProcessed[T](className: String, code: String, resetState: Boolean): T = {
    val cls = compiler(wrapCodeInClass(className, code), className, resetState)
    cls.getConstructor().newInstance().asInstanceOf[() => Any].apply().asInstanceOf[T]
  }

  /**
   * converts the given file to evaluable source.
   * delegates to toSource(code: String)
   */
  def toSource(file: File): String = {
    toSource(readFile(file))
  }

  /**
   * converts the given string to evaluable source (i.e. runs the preprocessors on it).
   */
  def toSource(code: String): String = {
    sourceForString(code, None).code
  }

  /**
   * Compile an entire source file into the virtual classloader.
   * @throws CompilerException if compilation reports errors or warnings.
   */
  def compile(code: String) {
    compiler(sourceForString(code, None).code)
  }

  /**
   * Like `Eval()`, but doesn't reset the virtual classloader before evaluating. So if you've
   * loaded classes with `compile`, they can be referenced/imported in code run by `inPlace`.
   */
  def inPlace[T](code: String): T = {
    apply[T](code, false)
  }

  /**
   * Check if code is Eval-able.
   * @throws CompilerException if not Eval-able.
   */
  def check(code: String) {
    val id = uniqueId(sourceForString(code, None).code)
    val className = "Evaluator__" + id
    val wrappedCode = wrapCodeInClass(className, code)
    compile(wrappedCode) // may throw CompilerException
  }

  /**
   * Check if files are Eval-able.
   * @throws CompilerException if not Eval-able.
   */
  def check(files: File*) {
    val code = files.map(readFile(_)).mkString("\n")
    check(code)
  }

  /**
   * Check if stream is Eval-able.
   * @throws CompilerException if not Eval-able.
   */
  def check(stream: InputStream) {
    check(scala.io.Source.fromInputStream(stream).mkString)
  }

  /**
   * Look up a class compiled by this Eval.
   * @throws ClassNotFoundException if the compiler's class loader cannot find it.
   */
  def findClass(className: String): Class[_] = {
    compiler.findClass(className).getOrElse { throw new ClassNotFoundException("no such class: " + className) }
  }

  /*
   * Class-name-safe identifier: hex SHA-1 of the code plus this JVM's random id.
   */
  private def uniqueId(code: String): String = {
    // Hash an explicit encoding: under a lossy default charset (e.g. US-ASCII) distinct
    // non-ASCII sources could encode to identical bytes and collide on one class name.
    val digest = MessageDigest.getInstance("SHA-1").digest(code.getBytes("UTF-8"))
    val sha = new BigInteger(1, digest).toString(16)
    sha + "_" + jvmId
  }

  /*
   * Read a whole file into a string, closing the file handle afterwards.
   */
  private def readFile(file: File): String = {
    val source = Source.fromFile(file)
    try source.mkString finally source.close()
  }

  /*
   * Wrap source code in a new class with an apply method.
   */
  private def wrapCodeInClass(className: String, code: String) = {
    "class " + className + " extends (() => Any) {\n" +
    "  def apply() = {\n" +
    code + "\n" +
    "  }\n" +
    "}\n"
  }

  /*
   * For a given FQ classname, trick the resource finder into telling us the containing jar.
   * Throws (NPE / StringIndexOutOfBoundsException) if the class is not found inside a jar.
   */
  private def jarPathOfClass(className: String): List[String] = {
    val resource = className.split('.').mkString("/", "/", ".class")
    val path = getClass.getResource(resource).getPath
    val indexOfFile = path.indexOf("file:") + 5
    val indexOfSeparator = path.lastIndexOf('!')
    List(path.substring(indexOfFile, indexOfSeparator))
  }

  /*
   * Try to guess our app's classpath.
   * This is probably fragile.
   */
  lazy val impliedClassPath: List[String] = {
    // Only URL class loaders expose their classpath; any other loader contributes nothing
    // (rather than failing with a ClassCastException).
    val currentClassPath = this.getClass.getClassLoader match {
      case urlLoader: URLClassLoader =>
        urlLoader.getURLs.map(_.toString).filter(_.startsWith("file:")).map(_.substring(5)).toList
      case _ => Nil
    }

    // if there's just one thing in the classpath, and it's a jar, assume an executable jar.
    currentClassPath ::: (currentClassPath match {
      case List(jarPath) if jarPath.endsWith(".jar") => manifestClassPath(jarPath)
      case _ => Nil
    })
  }

  /*
   * Entries of the jar manifest's Class-Path attribute, resolved relative to the jar's directory.
   * Returns Nil when the jar has no manifest or no Class-Path attribute.
   */
  private def manifestClassPath(jarPath: String): List[String] = {
    val relativeRoot = new File(jarPath).getParentFile()
    val jar = new JarFile(jarPath)
    val nestedClassPath = try {
      Option(jar.getManifest) flatMap { manifest =>
        Option(manifest.getMainAttributes.getValue("Class-Path"))
      }
    } finally {
      jar.close()
    }
    nestedClassPath match {
      case Some(classPath) =>
        classPath.split(" ").filter(_.nonEmpty).map { f => new File(relativeRoot, f).getAbsolutePath }.toList
      case None => Nil
    }
  }

  /**
   * A source-to-source transformation applied before compilation.
   */
  trait Preprocessor {
    def apply(code: String, lastModified: Option[Long]): LastMod
  }

  /**
   * Locates include targets by path.
   */
  trait Resolver {
    def resolvable(path: String): Boolean
    def get(path: String): InputStream
    def lastModified(path: String): Long
  }

  /**
   * Resolves include paths relative to a filesystem directory.
   */
  class FilesystemResolver(root: File) extends Resolver {
    private[this] def file(path: String): File =
      new File(root.getAbsolutePath + File.separator + path)

    def resolvable(path: String): Boolean =
      file(path).exists

    def lastModified(path: String): Long = {
      if (resolvable(path)) {
        file(path).lastModified
      } else {
        0
      }
    }

    def get(path: String): InputStream =
      new FileInputStream(file(path))
  }

  /**
   * Resolves include paths as classpath resources rooted at "/".
   * Resources report a last-modified time of 0.
   */
  class ClassScopedResolver(clazz: Class[_]) extends Resolver {
    private[this] def quotePath(path: String) =
      "/" + path

    def resolvable(path: String): Boolean = {
      // The stream is opened only to test existence; close it so the handle is not leaked.
      Option(clazz.getResourceAsStream(quotePath(path))) match {
        case Some(stream) =>
          stream.close()
          true
        case None =>
          false
      }
    }

    def lastModified(path: String): Long = 0

    def get(path: String): InputStream =
      clazz.getResourceAsStream(quotePath(path))
  }

  class ResolutionFailedException(message: String) extends Exception(message)

  /*
   * This is a preprocessor that can include files by requesting them from the given resolvers
   * (tried in order; the first resolver that can resolve the path is used).
   *
   * Thusly, if you put FS directories on your classpath (e.g. config/ under your app root,) you
   * mix in configuration from the filesystem.
   *
   * @example #include file-name.scala
   *
   * This is the only directive supported by this preprocessor.
   *
   * Includes are processed recursively: included files may themselves contain includes, up to
   * maximumRecursionDepth levels deep, after which an IllegalStateException is thrown.
   */
  class IncludePreprocessor(resolvers: Seq[Resolver]) extends Preprocessor {
    def maximumRecursionDepth = 100

    def apply(code: String, lastModified: Option[Long]): LastMod =
      apply(code, lastModified, maximumRecursionDepth)

    def apply(code: String, lastModified: Option[Long], maxDepth: Int): LastMod = {
      var lastMod = lastModified
      val lines = code.lines map { line: String =>
        val tokens = line.trim.split(' ')
        if (tokens.length == 2 && tokens(0).equals("#include")) {
          val path = tokens(1)
          resolvers find { resolver: Resolver =>
            resolver.resolvable(path)
          } match {
            case Some(r: Resolver) => {
              lastMod = Seq(lastMod, Some(r.lastModified(path))).max
              // recursively process includes
              if (maxDepth == 0) {
                throw new IllegalStateException("Exceeded maximum recursion depth")
              } else {
                // close the included stream once its contents have been buffered
                val stream = r.get(path)
                val included = try StreamIO.buffer(stream).toString finally stream.close()
                val subLastMod = apply(included, lastMod, maxDepth - 1)
                lastMod = Seq(lastMod, subLastMod.timestamp).max
                subLastMod.code
              }
            }
            case _ =>
              throw new IllegalStateException("No resolver could find '%s'".format(path))
          }
        } else {
          line
        }
      }
      val processed = lines.mkString("\n")
      LastMod(lastMod, processed)
    }
  }

  /**
   * Dynamic scala compiler. Lots of (slow) state is created, so it may be advantageous to keep
   * around one of these and reuse it.
   *
   * @param lineOffset number of wrapper lines preceding user code; subtracted from reported lines
   * @param targetDir  directory to emit class files into, or None to compile into memory
   */
  private class StringCompiler(lineOffset: Int, targetDir: Option[File]) {
    val target = targetDir match {
      case Some(dir) => AbstractFile.getDirectory(dir)
      case None => new VirtualDirectory("(memory)", None)
    }

    val cache = new mutable.HashMap[String, Class[_]]()

    val settings = new Settings
    settings.deprecation.value = true // enable detailed deprecation warnings
    settings.unchecked.value = true // enable detailed unchecked warnings
    settings.outputDirs.setSingleOutput(target)

    val pathList = compilerPath ::: libPath
    settings.bootclasspath.value = pathList.mkString(File.pathSeparator)
    settings.classpath.value = (pathList ::: impliedClassPath).mkString(File.pathSeparator)

    val reporter = new AbstractReporter {
      val settings = StringCompiler.this.settings
      val messages = new mutable.ListBuffer[List[String]]

      def display(pos: Position, message: String, severity: Severity) {
        severity.count += 1
        val severityName = severity match {
          case ERROR   => "error: "
          case WARNING => "warning: "
          case _ => ""
        }
        // Only defined positions carry a line number; don't ask an undefined one for it.
        val lineInfo = if (pos.isDefined) "line " + (pos.line - lineOffset) + ": " else ""
        messages += (severityName + lineInfo + message) ::
          (if (pos.isDefined) {
            pos.inUltimateSource(pos.source).lineContent.stripLineEnd ::
              (" " * (pos.column - 1) + "^") ::
              Nil
          } else {
            Nil
          })
      }

      def displayPrompt {
        // no.
      }

      override def reset {
        super.reset
        messages.clear()
      }
    }

    val global = new Global(settings, reporter)

    /*
     * Class loader for finding classes compiled by this StringCompiler.
     * After each reset, this class loader will not be able to find old compiled classes.
     */
    var classLoader = new AbstractFileClassLoader(target, this.getClass.getClassLoader)

    /*
     * Discard compiled output (the virtual directory, or top-level .class files in the target
     * directory), the class cache and reporter state, and start a fresh class loader.
     */
    def reset() {
      targetDir match {
        case None => {
          target.asInstanceOf[VirtualDirectory].clear
        }
        case Some(t) => {
          target.foreach { abstractFile =>
            if (abstractFile.file == null || abstractFile.file.getName.endsWith(".class")) {
              abstractFile.delete
            }
          }
        }
      }
      cache.clear()
      reporter.reset
      classLoader = new AbstractFileClassLoader(target, this.getClass.getClassLoader)
    }

    object Debug {
      val enabled =
        System.getProperty("eval.debug") != null

      def printWithLineNumbers(code: String) {
        printf("Code follows (%d bytes)\n", code.length)

        var numLines = 0
        code.lines foreach { line: String =>
          numLines += 1
          println(numLines.toString.padTo(5, ' ') + "| " + line)
        }
      }
    }

    /*
     * Find a compiled class, memoizing successful lookups in `cache`.
     */
    def findClass(className: String): Option[Class[_]] = {
      synchronized {
        cache.get(className).orElse {
          try {
            val cls = classLoader.loadClass(className)
            cache(className) = cls
            Some(cls)
          } catch {
            case e: ClassNotFoundException => None
          }
        }
      }
    }


    /**
     * Compile scala code. It can be found using the above class loader.
     * Thread-safe (shares the same lock as the other compiler entry points).
     * @throws CompilerException if this compilation reports any error or warning.
     */
    def apply(code: String) {
      synchronized {
        if (Debug.enabled)
          Debug.printWithLineNumbers(code)

        // Start from a clean reporter: otherwise the errors/warnings of an earlier failed
        // compile (one not followed by reset()) make every later compile fail with stale messages.
        reporter.reset

        // if you're looking for the performance hit, it's 1/2 this line...
        val compiler = new global.Run
        val sourceFiles = List(new BatchSourceFile("(inline)", code))
        // ...and 1/2 this line:
        compiler.compileSources(sourceFiles)

        if (reporter.hasErrors || reporter.WARNING.count > 0) {
          throw new CompilerException(reporter.messages.toList)
        }
      }
    }

    /**
     * Compile a new class, load it, and return it. Thread-safe.
     */
    def apply(code: String, className: String, resetState: Boolean = true): Class[_] = {
      synchronized {
        if (resetState) reset()
        findClass(className).getOrElse {
          apply(code)
          findClass(className).get
        }
      }
    }
  }

  class CompilerException(val messages: List[List[String]]) extends Exception(
    "Compiler exception " + messages.map(_.mkString("\n")).mkString("\n"))
}