aboutsummaryrefslogblamecommitdiff
path: root/libraries/eval/Eval.scala
blob: cd91d20934540e3ddc1d37a9041f4bed6a2dde46 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16















                                                                           

                        
                                  
                           
                              
                                  
                       

                               
                      
                                         



                                                          

   
                                                   
   








                                                                





                                                                                                                       
                                  




                                                                                                                      
 
                                                   



                               
                                                               

                                      
                                                                                   
                                                                                      




                               

                                                                            

   
     
                                                   
     
                                          
                                                           
   
 














                                                                                             




                                                                           
 








                                                                  

   

                                                                                            
     
                                                       



                                                                    
                                                       






















                                                                                                     


     

                                                                                               
     
                                                 
                                                                 
 


                                                       

                                                                             

                                                         
                                           






















                                                                                           
 


                         
 
                          


                        
     
 
                                               
 



                                                                                         
                                                                                                 

                 

                            
                                                                                               















                                                                        
                                                                
       











                                                                                                    

       
   
 

                                                                              
 
/*
 * Copyright 2010 Twitter, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License. You may obtain
 * a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.twitter.util

import java.io.{File, InputStream}
import java.math.BigInteger
import java.net.URLClassLoader
import java.security.MessageDigest
import java.util.Random
import java.util.jar.JarFile
import scala.collection.mutable
import scala.io.Source
import scala.tools.nsc.{Global, Settings}
import scala.tools.nsc.interpreter.AbstractFileClassLoader
import scala.tools.nsc.io.VirtualDirectory
import scala.tools.nsc.reporters.AbstractReporter
import scala.tools.nsc.util.{BatchSourceFile, Position}

/**
 * Evaluate a file or string and return the result.
 *
 * This singleton is a shared, process-wide evaluator. It also owns `jvmId`, a per-JVM random
 * number that instances of the `Eval` class append to generated class names.
 */
@deprecated("use a throw-away instance of Eval instead")
object Eval extends Eval {
  // The id is embedded in generated class names, so it must never be negative (a '-' is not
  // legal in an identifier). Math.abs(nextInt()) is negative when nextInt() returns
  // Integer.MIN_VALUE; the bounded nextInt always yields a value in [0, Integer.MAX_VALUE).
  private val jvmId = new Random().nextInt(java.lang.Integer.MAX_VALUE)
}

/**
 * Compiles snippets of scala source at runtime and evaluates them.
 *
 * Each instance owns its own (lazily created) compiler, in-memory output directory and class
 * loader, so classes compiled through one instance are only visible to that instance.
 */
class Eval {
  import Eval.jvmId

  // Location of the scala-compiler jar, needed on the embedded compiler's classpath.
  private lazy val compilerPath = try {
    jarPathOfClass("scala.tools.nsc.Interpreter")
  } catch {
    case e: Exception =>
      throw new RuntimeException("Unable to load scala interpreter from classpath (scala-compiler jar is missing?)", e)
  }

  // Location of the scala-library jar, needed on the embedded compiler's classpath.
  private lazy val libPath = try {
    jarPathOfClass("scala.ScalaObject")
  } catch {
    case e: Exception =>
      throw new RuntimeException("Unable to load scala base object from classpath (scala-library jar is missing?)", e)
  }

  // The offset of 2 matches the two lines that wrapCodeInClass emits before the user's code,
  // so that reported line numbers are relative to the code the caller passed in.
  private lazy val compiler = new StringCompiler(2)

  /**
   * Eval[Int]("1 + 1") // => 2
   *
   * Wraps `code` in a generated `() => Any` class, compiles it, instantiates it and returns the
   * result of calling it, cast to `T`.
   *
   * @param code       scala source forming the body of the generated `apply()` method
   * @param resetState when true, previously compiled classes are discarded before compiling
   * @throws CompilerException if the compiler reports any error or warning
   */
  def apply[T](code: String, resetState: Boolean = true): T = {
    val id = uniqueId(code)
    val className = "Evaluator__" + id
    val cls = compiler(wrapCodeInClass(className, code), className, id, resetState)
    cls.getConstructor().newInstance().asInstanceOf[() => Any].apply().asInstanceOf[T]
  }

  /**
   * Eval[Int](new File("..."))
   *
   * The contents of all files are joined with newlines and evaluated as one snippet.
   */
  def apply[T](files: File*): T = {
    apply[T](files.map(readFile).mkString("\n"))
  }

  /**
   * Eval[Int](getClass.getResourceAsStream("..."))
   *
   * The stream is read to its end but is not closed here; it stays owned by the caller.
   */
  def apply[T](stream: InputStream): T = {
    apply[T](Source.fromInputStream(stream).mkString)
  }

  /**
   * Compile an entire source file into the virtual classloader.
   *
   * The code is compiled by this instance's own compiler, so the resulting classes can be
   * referenced by a later `inPlace` call on the same instance.
   */
  def compile(code: String): Unit = {
    // Must be this instance's compiler, not the one belonging to the Eval singleton: inPlace
    // looks classes up in this instance's class loader.
    compiler(code)
  }

  /**
   * Like `Eval()`, but doesn't reset the virtual classloader before evaluating. So if you've
   * loaded classes with `compile`, they can be referenced/imported in code run by `inPlace`.
   */
  def inPlace[T](code: String) = {
    apply[T](code, false)
  }

  /*
   * Read a whole file into a string, always releasing the underlying file handle.
   */
  private def readFile(file: File): String = {
    val source = Source.fromFile(file)
    try {
      source.mkString
    } finally {
      source.close()
    }
  }

  /*
   * Build an identifier-safe id for a snippet: the hex SHA-1 of the code followed by the
   * per-JVM random id.
   */
  private def uniqueId(code: String): String = {
    val digest = MessageDigest.getInstance("SHA-1").digest(code.getBytes())
    val sha = new BigInteger(1, digest).toString(16)
    sha + "_" + jvmId
  }

  /*
   * Wrap source code in a new class with an apply method.
   * Exactly two lines precede the user's code; StringCompiler's lineOffset relies on that.
   */
  private def wrapCodeInClass(className: String, code: String) = {
    "class " + className + " extends (() => Any) {\n" +
    "  def apply() = {\n" +
    code + "\n" +
    "  }\n" +
    "}\n"
  }

  /*
   * For a given FQ classname, trick the resource finder into telling us the containing jar.
   * Returns a single-element list holding the path between "file:" and the '!' separator of the
   * class's resource URL.
   */
  private def jarPathOfClass(className: String): List[String] = {
    val resource = className.split('.').mkString("/", "/", ".class")
    val url = getClass.getResource(resource)
    if (url eq null) {
      // Fail with a descriptive message instead of a bare NullPointerException.
      throw new IllegalArgumentException("Class not found on classpath: " + className)
    }
    val path = url.getPath
    val indexOfFile = path.indexOf("file:") + 5
    val indexOfSeparator = path.lastIndexOf('!')
    List(path.substring(indexOfFile, indexOfSeparator))
  }

  /*
   * Try to guess our app's classpath.
   * This is probably fragile.
   */
  lazy val impliedClassPath: List[String] = {
    // Only a URLClassLoader can tell us its entries; any other loader contributes nothing
    // rather than failing with a ClassCastException.
    val currentClassPath: List[String] = this.getClass.getClassLoader match {
      case urlLoader: URLClassLoader =>
        urlLoader.getURLs.map(_.toString).filter(_.startsWith("file:")).map(_.substring(5)).toList
      case _ =>
        Nil
    }

    // if there's just one thing in the classpath, and it's a jar, assume an executable jar.
    currentClassPath match {
      case jarPath :: Nil if jarPath.endsWith(".jar") =>
        currentClassPath ::: manifestClassPath(jarPath)
      case _ =>
        currentClassPath
    }
  }

  /*
   * Resolve the "Class-Path" manifest attribute of a jar into absolute paths, relative to the
   * directory containing the jar. Yields Nil when the jar has no manifest or no such attribute.
   */
  private def manifestClassPath(jarPath: String): List[String] = {
    val relativeRoot = new File(jarPath).getParentFile()
    val jar = new JarFile(jarPath)
    try {
      val nested = for {
        manifest <- Option(jar.getManifest)
        entries <- Option(manifest.getMainAttributes.getValue("Class-Path"))
      } yield {
        // Skip empty tokens produced by repeated spaces between entries.
        entries.split(" ").filter(_.length > 0).map { f =>
          new File(relativeRoot, f).getAbsolutePath
        }.toList
      }
      nested.getOrElse(Nil)
    } finally {
      jar.close()
    }
  }

  /**
   * Dynamic scala compiler. Lots of (slow) state is created, so it may be advantageous to keep
   * around one of these and reuse it.
   *
   * @param lineOffset number of lines subtracted from reported positions, to account for
   *                   wrapper lines placed before the caller's code
   */
  private class StringCompiler(lineOffset: Int) {
    // Compiled class files are written here, in memory only.
    val virtualDirectory = new VirtualDirectory("(memory)", None)

    // Classes already compiled and loaded, keyed by the snippet's unique id.
    val cache = new mutable.HashMap[String, Class[_]]()

    val settings = new Settings
    settings.deprecation.value = true // enable detailed deprecation warnings
    settings.unchecked.value = true // enable detailed unchecked warnings
    settings.outputDirs.setSingleOutput(virtualDirectory)

    val pathList = compilerPath ::: libPath
    settings.bootclasspath.value = pathList.mkString(File.pathSeparator)
    settings.classpath.value = (pathList ::: impliedClassPath).mkString(File.pathSeparator)

    // Collects compiler diagnostics in memory instead of printing them.
    val reporter = new AbstractReporter {
      val settings = StringCompiler.this.settings
      val messages = new mutable.ListBuffer[List[String]]

      def display(pos: Position, message: String, severity: Severity): Unit = {
        severity.count += 1
        val severityName = severity match {
          case ERROR   => "error: "
          case WARNING => "warning: "
          case _ => ""
        }
        // One entry per diagnostic: the message, then (when a position is known) the offending
        // source line and a caret marking the column.
        messages += (severityName + "line " + (pos.line - lineOffset) + ": " + message) ::
          (if (pos.isDefined) {
            pos.inUltimateSource(pos.source).lineContent.stripLineEnd ::
              (" " * (pos.column - 1) + "^") ::
              Nil
          } else {
            Nil
          })
      }

      def displayPrompt: Unit = {
        // no.
      }

      override def reset {
        super.reset
        messages.clear()
      }
    }

    val global = new Global(settings, reporter)

    /*
     * Class loader for finding classes compiled by this StringCompiler.
     * After each reset, this class loader will not be able to find old compiled classes.
     */
    var classLoader = new AbstractFileClassLoader(virtualDirectory, this.getClass.getClassLoader)

    /**
     * Discard all compiled output and diagnostics, and start over with a fresh class loader.
     */
    def reset(): Unit = {
      virtualDirectory.clear
      reporter.reset
      classLoader = new AbstractFileClassLoader(virtualDirectory, this.getClass.getClassLoader)
    }

    /**
     * Compile scala code. It can be found using the above class loader.
     *
     * @throws CompilerException if the compiler reports any error or warning for this code
     */
    def apply(code: String): Unit = {
      // Start from a clean reporter: without this, a single earlier error or warning would make
      // every later compile fail when the compiler state is deliberately not reset.
      reporter.reset

      val compiler = new global.Run
      val sourceFiles = List(new BatchSourceFile("(inline)", code))
      compiler.compileSources(sourceFiles)

      if (reporter.hasErrors || reporter.WARNING.count > 0) {
        throw new CompilerException(reporter.messages.toList)
      }
    }

    /**
     * Compile a new class, load it, and return it. Thread-safe.
     *
     * A class that was already compiled under the same `id` is returned from the cache without
     * recompiling.
     */
    def apply(code: String, className: String, id: String, resetState: Boolean = true): Class[_] = {
      synchronized {
        cache.get(id) match {
          case Some(cls) =>
            cls
          case None =>
            if (resetState) reset()
            apply(code)
            val cls = classLoader.loadClass(className)
            cache(id) = cls
            cls
        }
      }
    }
  }

  /**
   * Thrown when compilation produces errors or warnings. Each element of `messages` holds the
   * lines of one diagnostic.
   */
  class CompilerException(val messages: List[List[String]]) extends Exception(
    "Compiler exception " + messages.map(_.mkString("\n")).mkString("\n"))
}