summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/compiler/scala/reflect/runtime/ToolBoxes.scala83
-rw-r--r--src/partest/scala/tools/partest/utils/CodeTest.scala10
-rw-r--r--test/files/run/code.check29
-rw-r--r--test/files/run/code.scala (renamed from test/disabled/run/code.scala)8
4 files changed, 108 insertions, 22 deletions
diff --git a/src/compiler/scala/reflect/runtime/ToolBoxes.scala b/src/compiler/scala/reflect/runtime/ToolBoxes.scala
index 8713ef01b7..6d2b97d0d1 100644
--- a/src/compiler/scala/reflect/runtime/ToolBoxes.scala
+++ b/src/compiler/scala/reflect/runtime/ToolBoxes.scala
@@ -12,7 +12,7 @@ import scala.tools.nsc.interpreter.AbstractFileClassLoader
import reflect.{mirror => rm}
import scala.tools.nsc.util.FreshNameCreator
import scala.reflect.internal.Flags
-import scala.tools.nsc.util.NoSourceFile
+import scala.tools.nsc.util.{NoSourceFile, NoFile}
import java.lang.{Class => jClass}
import scala.tools.nsc.util.trace
@@ -24,22 +24,48 @@ trait ToolBoxes extends { self: Universe =>
extends ReflectGlobal(settings, reporter) {
import definitions._
+ private val trace = scala.tools.nsc.util.trace when settings.debug.value
+
private final val wrapperMethodName = "wrapper"
+ private var wrapCount = 0
+
+ private def nextWrapperModuleName() = {
+ wrapCount += 1
+ "__wrapper$" + wrapCount
+ }
+
+ private def moduleFileName(className: String) = className + "$"
+
private def isFree(t: Tree) = t.isInstanceOf[Ident] && t.symbol.isInstanceOf[FreeVar]
- def wrapInClass(expr: Tree, fvs: List[Symbol]): ClassDef = {
- val clazz = EmptyPackageClass.newAnonymousClass(NoPosition)
- clazz setInfo ClassInfoType(List(ObjectClass.tpe), new Scope, clazz)
- val meth = clazz.newMethod(NoPosition, wrapperMethodName)
+ def wrapInObject(expr: Tree, fvs: List[Symbol]): ModuleDef = {
+ val obj = EmptyPackageClass.newModule(NoPosition, nextWrapperModuleName())
+ val minfo = ClassInfoType(List(ObjectClass.tpe), new Scope, obj.moduleClass)
+ obj.moduleClass setInfo minfo
+ obj setInfo obj.moduleClass.tpe
+ val meth = obj.moduleClass.newMethod(NoPosition, wrapperMethodName)
meth setFlag Flags.STATIC
- meth setInfo MethodType(meth.owner.newSyntheticValueParams(fvs map (_.tpe)), expr.tpe)
- clazz.info.decls enter meth
+ def makeParam(fv: Symbol) = meth.newValueParameter(NoPosition, fv.name) setInfo fv.tpe
+ meth setInfo MethodType(fvs map makeParam, expr.tpe)
+ minfo.decls enter meth
val methdef = DefDef(meth, expr)
- val clazzdef = ClassDef(clazz, NoMods, List(List()), List(List()), List(methdef), NoPosition)
- clazzdef
+ val objdef = ModuleDef(
+ obj,
+ Template(
+ List(TypeTree(ObjectClass.tpe)),
+ emptyValDef,
+ NoMods,
+ List(),
+ List(List()),
+ List(methdef),
+ NoPosition))
+ resetAllAttrs(objdef)
}
+ def wrapInPackage(clazz: Tree): PackageDef =
+ PackageDef(Ident(nme.EMPTY_PACKAGE_NAME), List(clazz))
+
def wrapInCompilationUnit(tree: Tree): CompilationUnit = {
val unit = new CompilationUnit(NoSourceFile)
unit.body = tree
@@ -47,26 +73,43 @@ trait ToolBoxes extends { self: Universe =>
}
def compileExpr(expr: Tree, fvs: List[Symbol]): String = {
- val cdef = trace("wrapped: ")(wrapInClass(expr, fvs))
- val unit = wrapInCompilationUnit(cdef)
+ val mdef = wrapInObject(expr, fvs)
+ val pdef = trace("wrapped: ")(wrapInPackage(mdef))
+ val unit = wrapInCompilationUnit(pdef)
val run = new Run
run.compileUnits(List(unit), run.namerPhase)
- cdef.name.toString
+ mdef.symbol.fullName
}
+ private def getMethod(jclazz: jClass[_], name: String) =
+ jclazz.getDeclaredMethods.find(_.getName == name).get
+
def runExpr(expr: Tree): Any = {
+ val etpe = expr.tpe
val fvs = (expr filter isFree map (_.symbol)).distinct
val className = compileExpr(expr, fvs)
- val jclazz = jClass.forName(className, true, classLoader)
+ if (settings.debug.value) println("generated: "+className)
+ val jclazz = jClass.forName(moduleFileName(className), true, classLoader)
val jmeth = jclazz.getDeclaredMethods.find(_.getName == wrapperMethodName).get
- jmeth.invoke(null, fvs map (sym => sym.asInstanceOf[FreeVar].value.asInstanceOf[AnyRef]): _*)
+ val result = jmeth.invoke(null, fvs map (sym => sym.asInstanceOf[FreeVar].value.asInstanceOf[AnyRef]): _*)
+ if (etpe.typeSymbol != FunctionClass(0)) result
+ else {
+ val applyMeth = result.getClass.getMethod("apply")
+ applyMeth.invoke(result)
+ }
}
}
- lazy val virtualDirectory = new VirtualDirectory("(memory)", None)
+ lazy val arguments = options.split(" ")
+
+ lazy val virtualDirectory =
+ (arguments zip arguments.tail) collect { case ("-d", dir) => dir } lastOption match {
+ case Some(outDir) => scala.tools.nsc.io.AbstractFile.getDirectory(outDir)
+ case None => new VirtualDirectory("(memory)", None)
+ }
lazy val compiler: ToolBoxGlobal = {
- val command = new CompilerCommand(options.split(" ").toList, reporter.error(scala.tools.nsc.util.NoPosition, _))
+ val command = new CompilerCommand(arguments.toList, reporter.error(scala.tools.nsc.util.NoPosition, _))
command.settings.outputDirs setSingleOutput virtualDirectory
new ToolBoxGlobal(command.settings, reporter)
}
@@ -80,10 +123,12 @@ trait ToolBoxes extends { self: Universe =>
lazy val classLoader = new AbstractFileClassLoader(virtualDirectory, getClass.getClassLoader)
private def importAndTypeCheck(tree: rm.Tree, expectedType: rm.Type): compiler.Tree = {
- val ctree: compiler.Tree = importer.importTree(tree.asInstanceOf[Tree])
- val pt: compiler.Type = importer.importType(expectedType.asInstanceOf[Type])
+ // need to establish a run an phase because otherwise we run into an assertion in TypeHistory
+ // that states that the period must be different from NoPeriod
val run = new compiler.Run
compiler.phase = run.refchecksPhase
+ val ctree: compiler.Tree = importer.importTree(tree.asInstanceOf[Tree])
+ val pt: compiler.Type = importer.importType(expectedType.asInstanceOf[Type])
val ttree: compiler.Tree = compiler.typer.typed(ctree, compiler.analyzer.EXPRmode, pt)
ttree
}
@@ -113,4 +158,4 @@ trait ToolBoxes extends { self: Universe =>
def runExpr(tree: rm.Tree): Any = runExpr(tree, WildcardType.asInstanceOf[rm.Type])
}
-} \ No newline at end of file
+}
diff --git a/src/partest/scala/tools/partest/utils/CodeTest.scala b/src/partest/scala/tools/partest/utils/CodeTest.scala
index 544b95a00d..c90168a313 100644
--- a/src/partest/scala/tools/partest/utils/CodeTest.scala
+++ b/src/partest/scala/tools/partest/utils/CodeTest.scala
@@ -18,14 +18,18 @@ import scala.tools.nsc.Settings
/** Runner for testing code tree liftingg
*/
object CodeTest {
+ def static[T](code: () => T, args: Array[String] = Array()) = {
+ println("static: "+code())
+ }
+
def apply[T](code: Code[T], args: Array[String] = Array()) = {
println("testing: "+code.tree)
val reporter = new ConsoleReporter(new Settings)
val toolbox = new ToolBox(reporter, args mkString " ")
val ttree = toolbox.typeCheck(code.tree, code.manifest.tpe)
println("result = " + toolbox.showAttributed(ttree))
- //val evaluated = toolbox.runExpr(ttree)
- //println("evaluated = "+evaluated)
- //evaluated
+ val evaluated = toolbox.runExpr(ttree)
+ println("evaluated = "+evaluated)
+ evaluated
}
}
diff --git a/test/files/run/code.check b/test/files/run/code.check
new file mode 100644
index 0000000000..b946554fda
--- /dev/null
+++ b/test/files/run/code.check
@@ -0,0 +1,29 @@
+testing: ((x: Int) => x.$plus(ys.length))
+result = ((x: Int) => x.+{(x: <?>)Int}(ys.length{Int}){Int}){Int => Int}
+evaluated = <function1>
+testing: (() => {
+ val e: Element = new Element("someName");
+ e
+})
+result = (() => {
+ val e: Element = new Element{Element}{(name: <?>)Element}("someName"{String("someName")}){Element};
+ e{Element}
+}{Element}){() => Element}
+evaluated = Element(someName)
+testing: (() => truc.elem = 6)
+result = (() => truc.elem{Int} = 6{Int(6)}{Unit}){() => Unit}
+evaluated = null
+testing: (() => truc.elem = truc.elem.$plus(6))
+result = (() => truc.elem{Int} = truc.elem.+{(x: <?>)Int}(6{Int(6)}){Int}{Unit}){() => Unit}
+evaluated = null
+testing: (() => new baz.BazElement("someName"))
+result = (() => new baz.BazElement{baz.BazElement}{(name: <?>)baz.BazElement}("someName"{String("someName")}){baz.BazElement}){() => baz.BazElement}
+evaluated = BazElement(someName)
+testing: ((x: Int) => x.$plus(ys.length))
+result = ((x: Int) => x.+{(x: <?>)Int}(ys.length{Int}){Int}){Int => Int}
+evaluated = <function1>
+static: 2
+testing: (() => x.$plus(1))
+result = (() => x.+{(x: <?>)Int}(1{Int(1)}){Int}){() => Int}
+evaluated = 2
+1+1 = 2
diff --git a/test/disabled/run/code.scala b/test/files/run/code.scala
index 8881c2eda8..e26f97b2a4 100644
--- a/test/disabled/run/code.scala
+++ b/test/files/run/code.scala
@@ -36,8 +36,16 @@ object Test extends App {
}
show()
+
+ def evaltest(x: Int) = {
+ CodeTest.static(() => x + 1, args)
+ CodeTest(() => x + 1, args)
+ }
+
+ println("1+1 = "+evaltest(1))
}
+
package baz {
case class BazElement(name: String) { }