summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc
diff options
context:
space:
mode:
authorIulian Dragos <jaguarul@gmail.com>2009-06-18 17:19:55 +0000
committerIulian Dragos <jaguarul@gmail.com>2009-06-18 17:19:55 +0000
commit3ee6b3653f8c25d7d6b19b9f5d4af7fa082146a8 (patch)
treee97b8c0dd8d61e82f825f528f98842f777621f7a /src/compiler/scala/tools/nsc
parentbe8e3c69114da5bc3020d5363b338b1c83aa22ef (diff)
downloadscala-3ee6b3653f8c25d7d6b19b9f5d4af7fa082146a8.tar.gz
scala-3ee6b3653f8c25d7d6b19b9f5d4af7fa082146a8.tar.bz2
scala-3ee6b3653f8c25d7d6b19b9f5d4af7fa082146a8.zip
Specialization landed in trunk.
Diffstat (limited to 'src/compiler/scala/tools/nsc')
-rw-r--r--src/compiler/scala/tools/nsc/Global.scala10
-rw-r--r--src/compiler/scala/tools/nsc/Settings.scala1
-rw-r--r--src/compiler/scala/tools/nsc/ast/TreeBrowsers.scala29
-rw-r--r--src/compiler/scala/tools/nsc/ast/TreePrinters.scala53
-rw-r--r--src/compiler/scala/tools/nsc/ast/Trees.scala4
-rwxr-xr-xsrc/compiler/scala/tools/nsc/ast/parser/Parsers.scala10
-rw-r--r--src/compiler/scala/tools/nsc/symtab/Definitions.scala2
-rw-r--r--src/compiler/scala/tools/nsc/symtab/Flags.scala3
-rw-r--r--src/compiler/scala/tools/nsc/symtab/Symbols.scala14
-rw-r--r--src/compiler/scala/tools/nsc/transform/Erasure.scala11
-rw-r--r--src/compiler/scala/tools/nsc/transform/InfoTransform.scala3
-rw-r--r--src/compiler/scala/tools/nsc/transform/SpecializeTypes.scala1190
-rw-r--r--src/compiler/scala/tools/nsc/transform/TypingTransformers.scala5
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Duplicators.scala242
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Typers.scala9
15 files changed, 1531 insertions, 55 deletions
diff --git a/src/compiler/scala/tools/nsc/Global.scala b/src/compiler/scala/tools/nsc/Global.scala
index 36080c64fa..9f6fe689cb 100644
--- a/src/compiler/scala/tools/nsc/Global.scala
+++ b/src/compiler/scala/tools/nsc/Global.scala
@@ -390,7 +390,13 @@ class Global(var settings: Settings, var reporter: Reporter) extends SymbolTable
val runsRightAfter = None
} with ExplicitOuter
- // phaseName = "erasure"
+ // phaseName = "specialize"
+ object specializeTypes extends {
+ val global: Global.this.type = Global.this
+ val runsAfter = List[String]("")
+ val runsRightAfter = Some("tailcalls")
+ } with SpecializeTypes
+
object erasure extends {
val global: Global.this.type = Global.this
val runsAfter = List[String]("explicitouter")
@@ -557,6 +563,8 @@ class Global(var settings: Settings, var reporter: Reporter) extends SymbolTable
phasesSet += uncurry // uncurry, translate function values to anonymous classes
phasesSet += tailCalls // replace tail calls by jumps
+ if (settings.specialize.value)
+ phasesSet += specializeTypes
phasesSet += explicitOuter // replace C.this by explicit outer pointers, eliminate pattern matching
phasesSet += erasure // erase generic types to Java 1.4 types, add interfaces for traits
phasesSet += lazyVals //
diff --git a/src/compiler/scala/tools/nsc/Settings.scala b/src/compiler/scala/tools/nsc/Settings.scala
index 3032663b54..20f16c785e 100644
--- a/src/compiler/scala/tools/nsc/Settings.scala
+++ b/src/compiler/scala/tools/nsc/Settings.scala
@@ -792,6 +792,7 @@ trait ScalacSettings {
List("no-cache", "mono-cache", "poly-cache", "invoke-dynamic"), "poly-cache") .
withHelpSyntax("-Ystruct-dispatch:<method>")
val Xwarndeadcode = BooleanSetting ("-Ywarn-dead-code", "Emit warnings for dead code")
+ val specialize = BooleanSetting ("-Yspecialize", "Specialize generic code on types.")
/**
* -P "Plugin" settings
diff --git a/src/compiler/scala/tools/nsc/ast/TreeBrowsers.scala b/src/compiler/scala/tools/nsc/ast/TreeBrowsers.scala
index 3995bc9dbd..35a22a4f2c 100644
--- a/src/compiler/scala/tools/nsc/ast/TreeBrowsers.scala
+++ b/src/compiler/scala/tools/nsc/ast/TreeBrowsers.scala
@@ -48,7 +48,7 @@ abstract class TreeBrowsers {
*/
class SwingBrowser {
- def browse(t: Tree): Unit = {
+ def browse(t: Tree): Tree = {
val tm = new ASTTreeModel(t)
val frame = new BrowserFrame()
@@ -59,6 +59,7 @@ abstract class TreeBrowsers {
// wait for the frame to be closed
lock.acquire
+ t
}
def browse(units: Iterator[CompilationUnit]): Unit =
@@ -138,7 +139,7 @@ abstract class TreeBrowsers {
var splitPane: JSplitPane = _
var treeModel: TreeModel = _
- val textArea: JTextArea = new JTextArea(20, 150)
+ val textArea: JTextArea = new JTextArea(30, 120)
val infoPanel = new TextInfoPanel()
/** Create a frame that displays the AST.
@@ -201,7 +202,7 @@ abstract class TreeBrowsers {
/**
* Present detailed information about the selected tree node.
*/
- class TextInfoPanel extends JTextArea(30, 40) {
+ class TextInfoPanel extends JTextArea(20, 50) {
setFont(new Font("monospaced", Font.PLAIN, 12))
@@ -216,9 +217,13 @@ abstract class TreeBrowsers {
case _ =>
str.append("tree.pos: ").append(t.pos)
str.append("\nSymbol: ").append(TreeInfo.symbolText(t))
- str.append("\nSymbol info: \n")
- TreeInfo.symbolTypeDoc(t).format(getWidth() / getColumnWidth(), buf)
- str.append(buf.toString())
+ str.append("\nSymbol owner: ").append(
+ if ((t.symbol ne null) && t.symbol != NoSymbol)
+ t.symbol.owner.toString
+ else
+ "NoSymbol has no owner")
+ if ((t.symbol ne null) && t.symbol.isType)
+ str.append("\ntermSymbol: " + t.symbol.tpe.termSymbol + "\ntypeSymbol: " + t.symbol.tpe.typeSymbol)
str.append("\nSymbol tpe: ")
if (t.symbol ne null) {
str.append(t.symbol.tpe).append("\n")
@@ -226,7 +231,10 @@ abstract class TreeBrowsers {
TypePrinter.toDocument(t.symbol.tpe).format(getWidth() / getColumnWidth(), buf)
str.append(buf.toString())
}
- str.append("\nSymbol Attributes: \n").append(TreeInfo.symbolAttributes(t))
+ str.append("\n\nSymbol info: \n")
+ TreeInfo.symbolTypeDoc(t).format(getWidth() / getColumnWidth(), buf)
+ str.append(buf.toString())
+ str.append("\n\nSymbol Attributes: \n").append(TreeInfo.symbolAttributes(t))
str.append("\ntree.tpe: ")
if (t.tpe ne null) {
str.append(t.tpe.toString()).append("\n")
@@ -239,7 +247,6 @@ abstract class TreeBrowsers {
}
}
-
/** Computes different information about a tree node. It
* is used as central place to do all pattern matching against
* Tree.
@@ -567,7 +574,7 @@ abstract class TreeBrowsers {
if ((s ne null) && (s != NoSymbol)) {
var str = flagsToString(s.flags)
if (s.isStaticMember) str = str + " isStatic ";
- str
+ str + " annotations: " + s.annotations.mkString("", " ", "")
}
else ""
}
@@ -620,7 +627,7 @@ abstract class TreeBrowsers {
Document.group(
Document.nest(4, "TypeRef(" :/:
toDocument(pre) :: ", " :/:
- sym.name.toString() :: ", " :/:
+ sym.name.toString() + sym.idString :: ", " :/:
"[ " :: toDocument(args) ::"]" :: ")")
)
@@ -641,7 +648,7 @@ abstract class TreeBrowsers {
Document.group(
Document.nest(4,"ClassInfoType(" :/:
toDocument(parents) :: ", " :/:
- clazz.name.toString() :: ")")
+ clazz.name.toString() + clazz.idString :: ")")
)
case MethodType(params, result) =>
diff --git a/src/compiler/scala/tools/nsc/ast/TreePrinters.scala b/src/compiler/scala/tools/nsc/ast/TreePrinters.scala
index dec92d7832..df19bdcf4b 100644
--- a/src/compiler/scala/tools/nsc/ast/TreePrinters.scala
+++ b/src/compiler/scala/tools/nsc/ast/TreePrinters.scala
@@ -57,7 +57,10 @@ abstract class TreePrinters {
def printTypeParams(ts: List[TypeDef]) {
if (!ts.isEmpty) {
- print("["); printSeq(ts){printParam}{print(", ")}; print("]")
+ print("["); printSeq(ts){ t =>
+ printAnnotations(t)
+ printParam(t)
+ }{print(", ")}; print("]")
}
}
@@ -87,7 +90,6 @@ abstract class TreePrinters {
printColumn(List(tree), "{", ";", "}")
}
}
-
def symName(tree: Tree, name: Name): String =
if (tree.symbol != null && tree.symbol != NoSymbol) {
((if (tree.symbol.isMixinConstructor) "/*"+tree.symbol.owner.name+"*/" else "") +
@@ -177,6 +179,7 @@ abstract class TreePrinters {
case TypeDef(mods, name, tparams, rhs) =>
if (mods hasFlag (PARAM | DEFERRED)) {
+ printAnnotations(tree)
printModifiers(tree, mods); print("type "); printParam(tree)
} else {
printAnnotations(tree)
@@ -194,19 +197,19 @@ abstract class TreePrinters {
if (isNotRemap(s)) s._1.toString else s._1.toString + "=>" + s._2.toString
print("import "); print(expr)
- print(".")
+ print(".")
selectors match {
case List(s) =>
// If there is just one selector and it is not remapping a name, no braces are needed
- if (isNotRemap(s)) {
- print(selectorToString(s))
- } else {
- print("{"); print(selectorToString(s)); print("}")
- }
+ if (isNotRemap(s)) {
+ print(selectorToString(s))
+ } else {
+ print("{"); print(selectorToString(s)); print("}")
+ }
// If there is more than one selector braces are always needed
- case many =>
+ case many =>
print(many.map(selectorToString).mkString("{", ", ", "}"))
- }
+ }
case DocDef(comment, definition) =>
print(comment); println; print(definition)
@@ -410,21 +413,21 @@ abstract class TreePrinters {
def create(writer: PrintWriter): TreePrinter = new TreePrinter(writer)
def create(stream: OutputStream): TreePrinter = create(new PrintWriter(stream))
def create(): TreePrinter = {
- /** A writer that writes to the current Console and
- * is sensitive to replacement of the Console's
- * output stream.
- */
- object ConsoleWriter extends Writer {
- override def write(str: String) { Console.print(str) }
-
- def write(cbuf: Array[Char], off: Int, len: Int) {
- val str = new String(cbuf, off, len)
- write(str)
- }
-
- def close = { /* do nothing */ }
- def flush = { /* do nothing */ }
- }
create(new PrintWriter(ConsoleWriter))
}
+ /** A writer that writes to the current Console and
+ * is sensitive to replacement of the Console's
+ * output stream.
+ */
+ object ConsoleWriter extends Writer {
+ override def write(str: String) { Console.print(str) }
+
+ def write(cbuf: Array[Char], off: Int, len: Int) {
+ val str = new String(cbuf, off, len)
+ write(str)
+ }
+
+ def close = { /* do nothing */ }
+ def flush = { /* do nothing */ }
+ }
}
diff --git a/src/compiler/scala/tools/nsc/ast/Trees.scala b/src/compiler/scala/tools/nsc/ast/Trees.scala
index d175bfa398..1b8777a642 100644
--- a/src/compiler/scala/tools/nsc/ast/Trees.scala
+++ b/src/compiler/scala/tools/nsc/ast/Trees.scala
@@ -477,6 +477,10 @@ trait Trees {
def DefDef(sym: Symbol, rhs: Tree): DefDef =
DefDef(sym, Modifiers(sym.flags), rhs)
+ def DefDef(sym: Symbol, rhs: List[List[Symbol]] => Tree): DefDef = {
+ DefDef(sym, rhs(sym.info.paramss))
+ }
+
/** Abstract type, type parameter, or type alias */
case class TypeDef(mods: Modifiers, name: Name, tparams: List[TypeDef], rhs: Tree)
extends MemberDef {
diff --git a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala
index c78886b823..fc85f7a196 100755
--- a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala
+++ b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala
@@ -1725,15 +1725,15 @@ self =>
/** TypeParamClauseOpt ::= [TypeParamClause]
* TypeParamClause ::= `[' VariantTypeParam {`,' VariantTypeParam} `]']
- * VariantTypeParam ::= [`+' | `-'] TypeParam
+ * VariantTypeParam ::= {Annotation} [`+' | `-'] TypeParam
* FunTypeParamClauseOpt ::= [FunTypeParamClause]
* FunTypeParamClause ::= `[' TypeParam {`,' TypeParam} `]']
* TypeParam ::= Id TypeParamClauseOpt TypeBounds [<% Type]
*/
def typeParamClauseOpt(owner: Name, implicitViewBuf: ListBuffer[Tree]): List[TypeDef] = {
- def typeParam(): TypeDef = {
+ def typeParam(ms: Modifiers): TypeDef = {
+ var mods = ms | Flags.PARAM
val start = in.offset
- var mods = Modifiers(Flags.PARAM)
if (owner.isTypeName && isIdent) {
if (in.name == PLUS) {
in.nextToken()
@@ -1763,10 +1763,10 @@ self =>
newLineOptWhenFollowedBy(LBRACKET)
if (in.token == LBRACKET) {
in.nextToken()
- params += typeParam()
+ params += typeParam(NoMods.withAnnotations(annotations(true, false)))
while (in.token == COMMA) {
in.nextToken()
- params += typeParam()
+ params += typeParam(NoMods.withAnnotations(annotations(true, false)))
}
accept(RBRACKET)
}
diff --git a/src/compiler/scala/tools/nsc/symtab/Definitions.scala b/src/compiler/scala/tools/nsc/symtab/Definitions.scala
index af158f27d4..d9041472a3 100644
--- a/src/compiler/scala/tools/nsc/symtab/Definitions.scala
+++ b/src/compiler/scala/tools/nsc/symtab/Definitions.scala
@@ -500,7 +500,7 @@ trait Definitions {
}
val refClass = new HashMap[Symbol, Symbol]
- private val abbrvTag = new HashMap[Symbol, Char]
+ val abbrvTag = new HashMap[Symbol, Char]
private def newValueClass(name: Name, tag: Char): Symbol = {
val boxedName = sn.Boxed(name)
diff --git a/src/compiler/scala/tools/nsc/symtab/Flags.scala b/src/compiler/scala/tools/nsc/symtab/Flags.scala
index 2c7462296a..e6edef09f9 100644
--- a/src/compiler/scala/tools/nsc/symtab/Flags.scala
+++ b/src/compiler/scala/tools/nsc/symtab/Flags.scala
@@ -38,7 +38,6 @@ object Flags {
final val BYNAMEPARAM = 0x00010000 // parameter is by name
final val CONTRAVARIANT = 0x00020000 // symbol is a contravariant type variable
final val LABEL = 0x00020000 // method symbol is a label. Set by TailCall
- final val DEFAULTINIT = 0x00020000 // field is initialized to the default value (used by checkinit)
final val INCONSTRUCTOR = 0x00020000 // class symbol is defined in this/superclass
// constructor.
final val ABSOVERRIDE = 0x00040000 // combination of abstract & override
@@ -81,6 +80,8 @@ object Flags {
// after each phase.
final val LOCKED = 0x8000000000L // temporary flag to catch cyclic dependencies
+ final val SPECIALIZED = 0x10000000000L// symbol is a generated specialized member
+ final val DEFAULTINIT = 0x20000000000L// symbol is a generated specialized member
final val InitialFlags = 0x0001FFFFFFFFFFFFL // flags that are enabled from phase 1.
final val LateFlags = 0x00FE000000000000L // flags that override flags in 0x1FC.
diff --git a/src/compiler/scala/tools/nsc/symtab/Symbols.scala b/src/compiler/scala/tools/nsc/symtab/Symbols.scala
index 227768acb0..2518e404f2 100644
--- a/src/compiler/scala/tools/nsc/symtab/Symbols.scala
+++ b/src/compiler/scala/tools/nsc/symtab/Symbols.scala
@@ -124,6 +124,10 @@ trait Symbols {
/** Does this symbol have an annotation of the given class? */
def hasAnnotation(cls: Symbol) = annotations exists { _.atp.typeSymbol == cls }
+ /** Remove all annotations matching the given class. */
+ def removeAnnotation(cls: Symbol): Unit =
+ setAnnotations(annotations.remove(_.atp.typeSymbol == cls))
+
/** set when symbol has a modifier of the form private[X], NoSymbol otherwise.
* Here's some explanation how privateWithin gets combined with access flags:
*
@@ -988,9 +992,11 @@ trait Symbols {
cloneSymbol(owner)
/** A clone of this symbol, but with given owner */
- final def cloneSymbol(owner: Symbol): Symbol =
- cloneSymbolImpl(owner).setInfo(info.cloneInfo(this))
+ final def cloneSymbol(owner: Symbol): Symbol = {
+ val newSym = cloneSymbolImpl(owner)
+ newSym.setInfo(info.cloneInfo(newSym))
.setFlag(this.rawflags).setAnnotations(this.annotations)
+ }
/** Internal method to clone a symbol's implementation without flags or type
*/
@@ -1537,14 +1543,14 @@ trait Symbols {
}
override def alias: Symbol =
- if (hasFlag(SUPERACCESSOR | PARAMACCESSOR | MIXEDIN)) initialize.referenced
+ if (hasFlag(SUPERACCESSOR | PARAMACCESSOR | MIXEDIN | SPECIALIZED)) initialize.referenced
else NoSymbol
def setAlias(alias: Symbol): TermSymbol = {
assert(alias != NoSymbol, this)
assert(!(alias hasFlag OVERLOADED), alias)
- assert(hasFlag(SUPERACCESSOR | PARAMACCESSOR | MIXEDIN), this)
+ assert(hasFlag(SUPERACCESSOR | PARAMACCESSOR | MIXEDIN | SPECIALIZED), this)
referenced = alias
this
}
diff --git a/src/compiler/scala/tools/nsc/transform/Erasure.scala b/src/compiler/scala/tools/nsc/transform/Erasure.scala
index 7330b9c9e8..84b5e14298 100644
--- a/src/compiler/scala/tools/nsc/transform/Erasure.scala
+++ b/src/compiler/scala/tools/nsc/transform/Erasure.scala
@@ -102,8 +102,8 @@ abstract class Erasure extends AddInterfaces with typechecker.Analyzer {
case RefinedType(parents, decls) =>
if (parents.isEmpty) erasedTypeRef(ObjectClass)
else apply(parents.head)
- case AnnotatedType(_, atp, _) =>
- apply(atp)
+ case AnnotatedType(_, atp, _) =>
+ apply(atp)
case ClassInfoType(parents, decls, clazz) =>
ClassInfoType(
if ((clazz == ObjectClass) || (isValueType(clazz))) List()
@@ -781,7 +781,10 @@ abstract class Erasure extends AddInterfaces with typechecker.Analyzer {
val opc = new overridingPairs.Cursor(root) {
override def exclude(sym: Symbol): Boolean =
- !sym.isTerm || sym.hasFlag(PRIVATE) || super.exclude(sym)
+ (!sym.isTerm || sym.hasFlag(PRIVATE) || super.exclude(sym)
+ // specialized members have no type history before 'specialize', causing duble def errors for curried defs
+ || !sym.hasTypeAt(currentRun.refchecksPhase.id))
+
override def matches(sym1: Symbol, sym2: Symbol): Boolean =
atPhase(phase.next)(sym1.tpe =:= sym2.tpe)
}
@@ -794,7 +797,7 @@ abstract class Erasure extends AddInterfaces with typechecker.Analyzer {
opc.overriding.infosString +
opc.overridden.locationString + " " +
opc.overridden.infosString)
- doubleDefError(opc.overriding, opc.overridden)
+ doubleDefError(opc.overriding, opc.overridden)
}
opc.next
}
diff --git a/src/compiler/scala/tools/nsc/transform/InfoTransform.scala b/src/compiler/scala/tools/nsc/transform/InfoTransform.scala
index 5eeec493c0..ba4ce36c33 100644
--- a/src/compiler/scala/tools/nsc/transform/InfoTransform.scala
+++ b/src/compiler/scala/tools/nsc/transform/InfoTransform.scala
@@ -22,8 +22,11 @@ trait InfoTransform extends Transform {
new Phase(prev)
protected def changesBaseClasses = true
+ protected def keepsTypeParams = false
class Phase(prev: scala.tools.nsc.Phase) extends super.Phase(prev) {
+ override val keepsTypeParams = InfoTransform.this.keepsTypeParams
+
if (infoTransformers.nextFrom(id).pid != id) {
// this phase is not yet in the infoTransformers
val infoTransformer = new InfoTransformer {
diff --git a/src/compiler/scala/tools/nsc/transform/SpecializeTypes.scala b/src/compiler/scala/tools/nsc/transform/SpecializeTypes.scala
new file mode 100644
index 0000000000..a450b36091
--- /dev/null
+++ b/src/compiler/scala/tools/nsc/transform/SpecializeTypes.scala
@@ -0,0 +1,1190 @@
+package scala.tools.nsc.transform
+
+import scala.tools.nsc.symtab.Flags
+import scala.tools.nsc.util.FreshNameCreator
+import scala.tools.nsc.util.Position
+
+import scala.collection.{mutable, immutable}
+
+/** Specialize code on types.
+ */
+abstract class SpecializeTypes extends InfoTransform with TypingTransformers {
+ import global._
+ import Flags._
+ /** the name of the phase: */
+ val phaseName: String = "specialize"
+
+ /** This phase changes base classes. */
+ override def changesBaseClasses = true
+ override def keepsTypeParams = true
+
+ /** Concrete types for specialization */
+ final lazy val concreteTypes = List(definitions.IntClass.tpe, definitions.DoubleClass.tpe)
+
+ type TypeEnv = immutable.Map[Symbol, Type]
+ def emptyEnv: TypeEnv = immutable.ListMap.empty[Symbol, Type]
+
+ object TypeEnv {
+ /** Return a new type environment binding specialized type parameters of sym to
+ * the given args. Expects the lists to have the same length.
+ */
+ def fromSpecialization(sym: Symbol, args: List[Type]): TypeEnv = {
+ assert(sym.info.typeParams.length == args.length, sym + " args: " + args)
+ var env = emptyEnv
+ for ((tvar, tpe) <- sym.info.typeParams.zip(args) if tvar.hasAnnotation(SpecializedClass))
+ env = env + ((tvar, tpe))
+ env
+ }
+
+ /** Is this typeenv included in `other'? All type variables in this environement
+ * are defined in `other' and bound to the same type.
+ */
+ def includes(t1: TypeEnv, t2: TypeEnv) = {
+ t1 forall { kv =>
+ t2.get(kv._1) match {
+ case Some(v2) => v2 == kv._2
+ case _ => false
+ }
+ }
+ }
+
+ /** Reduce the given environment to contain mappins only for type variables in tps. */
+ def reduce(env: TypeEnv, tps: immutable.Set[Symbol]): TypeEnv = {
+ env filter { kv => tps.contains(kv._1)}
+ }
+
+ /** Is the given environment a valid specialization for sym?
+ * It is valid if each binding is from a @specialized type parameter in sym (or its owner)
+ * to a type for which `sym' is specialized.
+ */
+ def isValid(env: TypeEnv, sym: Symbol): Boolean = {
+ env forall { binding =>
+ val (tvar, tpe) = binding
+// log("isValid: " + env + " sym: " + sym + " sym.tparams: " + sym.typeParams)
+// log("Flag " + tvar + ": " + tvar.hasAnnotation(SpecializedClass))
+// log("tparams contains: " + sym.typeParams.contains(tvar))
+// log("concreteTypes: " + concreteTypes.contains(tpe))
+ ((tvar.hasAnnotation(SpecializedClass)
+ && sym.typeParams.contains(tvar)
+ && concreteTypes.contains(tpe))
+ || (if (sym.owner != definitions.RootClass) isValid(env, sym.owner) else false))
+ // FIXME: it expects all type parameters to appear in the same owner!
+ }
+ }
+ }
+
+ /** For a given class and concrete type arguments, give its specialized class */
+ val specializedClass: mutable.Map[(Symbol, TypeEnv), Symbol] = new mutable.HashMap
+
+ /** Map a method symbol to a list of its specialized overloads in the same class. */
+ private val overloads: mutable.Map[Symbol, List[Overload]] = new mutable.HashMap[Symbol, List[Overload]] {
+ override def default(key: Symbol): List[Overload] = Nil
+ }
+
+ case class Overload(sym: Symbol, env: TypeEnv) {
+ override def toString: String =
+ "specalized overload " + sym + " in " + env
+ }
+
+ /** The annotation used to mark specialized type parameters. */
+ lazy val SpecializedClass = definitions.getClass("scala.specialized")
+
+ protected def newTransformer(unit: CompilationUnit): Transformer =
+ new SpecializationTransformer(unit)
+
+ abstract class SpecializedInfo {
+ def target: Symbol
+
+ /** Are type bounds of @specialized type parameters of 'target' now in 'env'? */
+ def typeBoundsIn(env: TypeEnv) = false
+
+ /** A degenerated method has @specialized type parameters that appear only in
+ * type bounds of other @specialized type parameters (and not in its result type).
+ */
+ def degenerate = false
+ }
+
+ /** Symbol is a special overloaded method of 'original', in the environment env. */
+ case class SpecialOverload(original: Symbol, env: TypeEnv) extends SpecializedInfo {
+ def target = original
+ }
+
+ /** Symbol is a method that should be forwarded to 't' */
+ case class Forward(t: Symbol) extends SpecializedInfo {
+ def target = t
+ }
+
+ /** Symbol is a specialized accessor for the `target' field. */
+ case class SpecializedAccessor(target: Symbol) extends SpecializedInfo
+
+ /** Symbol is a specialized method whose body should be the target's method body. */
+ case class Implementation(target: Symbol) extends SpecializedInfo
+
+ /** Symbol is a normalized member of 'target'. */
+ case class NormalizedMember(target: Symbol) extends SpecializedInfo {
+
+ /** Type bounds of a @specialized type var are now in the environment. */
+ override def typeBoundsIn(env: TypeEnv): Boolean = {
+ target.info.typeParams exists { tvar =>
+ (tvar.hasAnnotation(SpecializedClass)
+ && (specializedTypeVars(tvar.info.bounds) exists env.isDefinedAt))
+ }
+ }
+
+ override lazy val degenerate = {
+ log("degenerate: " + target +
+ " stv tparams: " + specializedTypeVars(target.info.typeParams map (_.info)) +
+ " stv info: " + specializedTypeVars(target.info.resultType))
+ !(specializedTypeVars(target.info.typeParams map (_.info))
+ -- specializedTypeVars(target.info.resultType)).isEmpty
+ }
+ }
+
+ /** Map a symbol to additional information on specialization. */
+ private val info: mutable.Map[Symbol, SpecializedInfo] = new mutable.HashMap[Symbol, SpecializedInfo]
+
+ /** Has `clazz' any type parameters that need be specialized? */
+ def hasSpecializedParams(clazz: Symbol): Boolean =
+ !specializedParams(clazz).isEmpty
+
+ /** Return specialized type paramters. */
+ def specializedParams(sym: Symbol): List[Symbol] =
+ splitParams(sym.info.typeParams)._1
+
+ def splitParams(tps: List[Symbol]) =
+ tps.partition(_.hasAnnotation(SpecializedClass))
+
+ def unspecializedArgs(sym: Symbol, args: List[Type]): List[Type] =
+ for ((tvar, tpe) <- sym.info.typeParams.zip(args) if !tvar.hasAnnotation(SpecializedClass))
+ yield tpe
+
+ val specializedType = new TypeMap {
+ override def apply(tp: Type): Type = tp match {
+ case TypeRef(pre, sym, args) if !args.isEmpty =>
+ val pre1 = this(pre)
+ val args1 = args map this
+ val unspecArgs = unspecializedArgs(sym, args)
+ specializedClass.get((sym, TypeEnv.fromSpecialization(sym, args1))) match {
+ case Some(sym1) =>
+ assert(sym1.info.typeParams.length == unspecArgs.length, sym1)
+ typeRef(pre1, sym1, unspecArgs)
+ case None =>
+ typeRef(pre1, sym, args1)
+ }
+ case _ => mapOver(tp)
+ }
+ }
+
+ /** Return the specialized overload of sym in the given env, if any. */
+ def overload(sym: Symbol, env: TypeEnv) =
+ overloads(sym).find(ov => TypeEnv.includes(ov.env, env))
+
+ /** Return the specialized name of 'sym' in the given environment. It
+ * guarantees the same result regardless of the map order by sorting
+ * type variables alphabetically.
+ */
+ private def specializedName(sym: Symbol, env: TypeEnv): Name = {
+ val tvars = if (sym.isClass) env.keySet
+ else specializedTypeVars(sym.info).intersect(env.keySet)
+ log("specName(" + sym + ") env " + env + " tvars: " + tvars + " stv: " + specializedTypeVars(sym.info) + " info: " + sym.info)
+ val (methparams, others) = tvars.toList.partition(_.owner.isMethod)
+ val tvars1 = methparams.sort(_.name.toString < _.name.toString)
+ val tvars2 = others.sort(_.name.toString < _.name.toString)
+ specializedName(sym.name, tvars1 map env, tvars2 map env)
+ }
+
+ /** Specialize name for the two list of types. The first one denotes
+ * specialization on method type parameters, the second on outer environment.
+ */
+ private def specializedName(name: Name, types1: List[Type], types2: List[Type]): Name = {
+ def split: (String, String, String) = {
+ if (name.endsWith("$sp")) {
+ val name1 = name.subName(0, name.length - 3)
+ val idxC = name1.lastPos('c')
+ val idxM = name1.lastPos('m', idxC)
+ (name1.subName(0, idxM - 1).toString,
+ name1.subName(idxC + 1, name1.length).toString,
+ name1.subName(idxM + 1, idxC).toString)
+ } else
+ (name.toString, "", "")
+ }
+
+ if (nme.INITIALIZER == name || (types1.isEmpty && types2.isEmpty))
+ name
+ else if (nme.isSetterName(name))
+ nme.getterToSetter(specializedName(nme.setterToGetter(name), types1, types2))
+ else if (nme.isLocalName(name))
+ nme.getterToLocal(specializedName(nme.localToGetter(name), types1, types2))
+ else {
+ val (base, cs, ms) = split
+ newTermName(base + "$"
+ + "m" + ms + types1.map(t => definitions.abbrvTag(t.typeSymbol)).mkString("", "", "")
+ + "c" + cs + types2.map(t => definitions.abbrvTag(t.typeSymbol)).mkString("", "", "$sp"))
+ }
+ }
+
+ /** Generate all arrangements with repetitions from the list of values,
+ * with 'pos' positions. For example, count(2, List(1, 2)) yields
+ * List(List(1, 1), List(1, 2), List(2, 1), List(2, 2))
+ */
+ private def count[T](pos: Int, values: List[T]): List[List[T]] = {
+ if (pos == 0) Nil
+ else if (pos == 1) values map (v => List(v))
+ else for (v <- values; vs <- count(pos - 1, values)) yield v :: vs
+ }
+
+ /** Does the given tpe need to be specialized in the environment 'env'? */
+ private def needsSpecialization(env: TypeEnv, sym: Symbol): Boolean = {
+ def needsIt(tpe: Type): Boolean = tpe match {
+ case TypeRef(pre, sym, args) =>
+ (env.keys.contains(sym)
+ || (args exists needsIt))
+ case PolyType(tparams, resTpe) => needsIt(resTpe)
+ case MethodType(argTpes, resTpe) =>
+ (argTpes exists (sym => needsIt(sym.tpe))) || needsIt(resTpe)
+ case ClassInfoType(parents, stats, sym) =>
+ stats.toList exists (s => needsIt(s.info))
+ case _ => false
+ }
+
+ (needsIt(sym.tpe)
+ || (isNormalizedMember(sym) && info(sym).typeBoundsIn(env)))
+
+ }
+
+ def isNormalizedMember(m: Symbol): Boolean =
+ (m.hasFlag(SPECIALIZED) && (info.get(m) match {
+ case Some(NormalizedMember(_)) => true
+ case _ => false
+ }))
+
+
+ private def specializedTypeVars(tpe: List[Type]): immutable.Set[Symbol] =
+ tpe.foldLeft(immutable.ListSet.empty[Symbol]: immutable.Set[Symbol]) {
+ (s, tp) => s ++ specializedTypeVars(tp)
+ }
+
+ /** Return the set of @specialized type variables mentioned by the given type. */
+ private def specializedTypeVars(tpe: Type): immutable.Set[Symbol] = tpe match {
+ case TypeRef(pre, sym, args) =>
+ if (sym.isTypeParameter && sym.hasAnnotation(SpecializedClass))
+ specializedTypeVars(args) + sym
+ else
+ specializedTypeVars(args)
+ case PolyType(tparams, resTpe) =>
+ specializedTypeVars(tparams map (_.info)) ++ specializedTypeVars(resTpe)
+ case MethodType(argSyms, resTpe) =>
+ specializedTypeVars(argSyms map (_.tpe)) ++ specializedTypeVars(resTpe)
+ case ExistentialType(_, res) => specializedTypeVars(res)
+ case AnnotatedType(_, tp, _) => specializedTypeVars(tp)
+ case TypeBounds(hi, lo) => specializedTypeVars(hi) ++ specializedTypeVars(lo)
+ case _ => immutable.ListSet.empty[Symbol]
+ }
+
+ /** Specialize 'clazz', in the environment `outerEnv`. The outer
+ * environment contains bindings for specialized types of enclosing
+ * classes.
+ *
+ * A class C is specialized w.r.t to its own specialized type params
+ * `stps`, by specializing its members, and creating a new class for
+ * each combination of `stps`.
+ */
+ def specializeClass(clazz: Symbol, outerEnv: TypeEnv): List[Symbol] = {
+ def specializedClass(env: TypeEnv, normMembers: List[Symbol]): Symbol = {
+ val cls = clazz.owner.newClass(clazz.pos, specializedName(clazz, env))
+ .setFlag(SPECIALIZED | clazz.flags)
+ .resetFlag(CASE)
+ cls.sourceFile = clazz.sourceFile
+ currentRun.symSource(cls) = clazz.sourceFile // needed later on by mixin
+
+ typeEnv(cls) = env
+ this.specializedClass((clazz, env)) = cls
+
+ val decls1 = newScope
+
+ val specializedInfoType: Type = {
+ val (_, unspecParams) = splitParams(clazz.info.typeParams)
+ val tparams1 = cloneSymbols(unspecParams, cls)
+ var parents = List(subst(env, clazz.tpe).subst(unspecParams, tparams1 map (_.tpe)))
+ if (parents.head.typeSymbol.isTrait)
+ parents = parents.head.parents.head :: parents
+ val infoType = ClassInfoType(parents, decls1, cls)
+ if (tparams1.isEmpty) infoType else PolyType(tparams1, infoType)
+ }
+
+ atPhase(phase.next)(cls.setInfo(specializedInfoType))
+
+ val fullEnv = outerEnv ++ env
+
+ /** Enter 'sym' in the scope of the current specialized class. It's type is
+ * mapped through the active environment, binding type variables to concrete
+ * types. The existing typeEnv for `sym' is composed with the current active
+ * environment
+ */
+ def enterMember(sym: Symbol): Symbol = {
+ typeEnv(sym) = fullEnv ++ typeEnv(sym) // append the full environment
+ sym.setInfo(sym.info.substThis(clazz, ThisType(cls)))
+ decls1.enter(subst(fullEnv)(sym))
+ }
+
+ /** Create and enter in scope an overriden symbol m1 for `m' that forwards
+ * to `om'. `om' is a fresh, special overload of m1 that is an implementation
+ * of `m'. For example, for a
+ *
+ * class Foo[@specialized A] {
+ * def m(x: A) = <body>
+ * }
+ * , for class Foo$I extends Foo[Int], this method enters two new symbols in
+ * the scope of Foo$I:
+ *
+ * def m(x: Int) = m$I(x)
+ * def m$I(x: Int) = <body>/adapted to env {A -> Int}
+ */
+ def forwardToOverload(m: Symbol): Symbol = {
+ val specMember = enterMember(m.cloneSymbol(cls)).setFlag(OVERRIDE | SPECIALIZED).resetFlag(DEFERRED | CASEACCESSOR)
+ val om = specializedOverload(cls, m, env).setFlag(OVERRIDE)
+
+ var original = info.get(m) match {
+ case Some(NormalizedMember(tg)) => tg
+ case _ => m
+ }
+
+ info(specMember) = Forward(om)
+ info(om) = Implementation(original)
+ typeEnv(om) = env ++ typeEnv(m) // add the environment for any method tparams
+
+ enterMember(om)
+ }
+
+ log("specializedClass: " + cls)
+ for (m <- normMembers if needsSpecialization(outerEnv ++ env, m) && satisfiable(fullEnv)) {
+ log(" * looking at: " + m)
+ if (!m.isDeferred) concreteSpecMethods += m
+
+ // specialized members have to be overridable. Fields should not
+ // have the field CASEACCESSOR (messes up patmatch)
+ if (m.hasFlag(PRIVATE))
+ m.resetFlag(PRIVATE | CASEACCESSOR).setFlag(PROTECTED)
+
+ if (m.isConstructor) {
+ val specCtor = enterMember(m.cloneSymbol(cls).setFlag(SPECIALIZED))
+ info(specCtor) = Forward(m)
+
+ } else if (isNormalizedMember(m)) { // methods added by normalization
+ val NormalizedMember(original) = info(m)
+ if (!conflicting(env ++ typeEnv(m))) {
+ if (info(m).degenerate) {
+ log("degenerate normalized member " + m + " info(m): " + info(m))
+ val specMember = enterMember(m.cloneSymbol(cls)).setFlag(SPECIALIZED).resetFlag(DEFERRED)
+ info(specMember) = Implementation(original)
+ typeEnv(specMember) = env ++ typeEnv(m)
+ } else {
+ val om = forwardToOverload(m)
+ log("normalizedMember " + m + " om: " + om + " typeEnv(om): " + typeEnv(om))
+ }
+ } else
+ log("conflicting env for " + m + " env: " + env)
+
+ } else if (m.isDeferred) { // abstract methods
+ val specMember = enterMember(m.cloneSymbol(cls)).setFlag(SPECIALIZED).resetFlag(DEFERRED)
+ log("deferred " + specMember.fullNameString + " is forwarded")
+
+ info(specMember) = new Forward(specMember) {
+ override def target = m.owner.info.member(specializedName(m, env))
+ }
+
+ } else if (m.isMethod && !m.hasFlag(ACCESSOR)) { // other concrete methods
+ forwardToOverload(m)
+
+ } else if (m.isValue && !m.isMethod) { // concrete value definition
+ def mkAccessor(field: Symbol, name: Name) = {
+ val sym = cls.newMethod(field.pos, name)
+ .setFlag(SPECIALIZED | m.getter(clazz).flags)
+ .resetFlag(LOCAL | PARAMACCESSOR | CASEACCESSOR) // we rely on the super class to initialize param accessors
+ info(sym) = SpecializedAccessor(field)
+ sym
+ }
+
+ def overrideIn(clazz: Symbol, sym: Symbol) = {
+ val sym1 = sym.cloneSymbol(clazz)
+ .setFlag(OVERRIDE | SPECIALIZED)
+ .resetFlag(DEFERRED | CASEACCESSOR | ACCESSOR)
+ sym1.setInfo(sym1.info.asSeenFrom(clazz.tpe, sym1.owner))
+ }
+
+ val specVal = specializedOverload(cls, m, env)
+
+ concreteSpecMethods += m
+ specVal.asInstanceOf[TermSymbol].setAlias(m)
+
+ enterMember(specVal)
+ // create accessors
+ log("m: " + m + " isLocal: " + nme.isLocalName(m.name) + " specVal: " + specVal.name + " isLocal: " + nme.isLocalName(specVal.name))
+ if (nme.isLocalName(m.name)) {
+ val specGetter = mkAccessor(specVal, nme.localToGetter(specVal.name)).setInfo(MethodType(List(), specVal.info))
+ val origGetter = overrideIn(cls, m.getter(clazz))
+ info(origGetter) = Forward(specGetter)
+ enterMember(specGetter)
+ enterMember(origGetter)
+ log("created accessors: " + specGetter + " orig: " + origGetter)
+
+ clazz.caseFieldAccessors.find(_.name.startsWith(m.name)) foreach { cfa =>
+ val cfaGetter = overrideIn(cls, cfa)
+ info(cfaGetter) = SpecializedAccessor(specVal)
+ enterMember(cfaGetter)
+ log("found case field accessor for " + m + " added override " + cfaGetter);
+ }
+
+ if (specVal.isVariable) {
+ val specSetter = mkAccessor(specVal, nme.getterToSetter(specGetter.name))
+ .resetFlag(STABLE)
+ specSetter.setInfo(MethodType(specSetter.newSyntheticValueParams(List(specVal.info)),
+ definitions.UnitClass.tpe))
+ val origSetter = overrideIn(cls, m.setter(clazz))
+ info(origSetter) = Forward(specSetter)
+ enterMember(specSetter)
+ enterMember(origSetter)
+ }
+ } else { // if there are no accessors, specialized methods will need to access this field in specialized subclasses
+ m.resetFlag(PRIVATE)
+ specVal.resetFlag(PRIVATE)
+ }
+ }
+ }
+ cls
+ }
+
+ val decls1 = (clazz.info.decls.toList flatMap { m: Symbol =>
+ normalizeMember(m.owner, m, outerEnv) flatMap { normalizedMember =>
+ val ms = specializeMember(m.owner, normalizedMember, outerEnv, clazz.info.typeParams)
+ if (normalizedMember.isMethod) {
+ val newTpe = subst(outerEnv, normalizedMember.info)
+ if (newTpe != normalizedMember.info) // only do it when necessary, otherwise the method type might be at a later phase already
+ normalizedMember.updateInfo(newTpe) :: ms
+ else
+ normalizedMember :: ms
+ } else
+ normalizedMember :: ms
+ }
+ })
+
+ var hasSubclasses = false
+ for (env <- specializations(clazz.info.typeParams) if satisfiable(env)) {
+ val spc = specializedClass(env, decls1)
+ log("entered " + spc + " in " + clazz.owner)
+ hasSubclasses = true
+ atPhase(phase.next)(clazz.owner.info.decls enter spc) //!! assumes fully specialized classes
+ }
+ if (hasSubclasses) clazz.resetFlag(FINAL)
+ decls1
+ }
+
+ /** Expand member `sym' to a set of normalized members. Normalized members
+ * are monomorphic or polymorphic only in non-specialized types.
+ *
+ * Given method m[@specialized T, U](x: T, y: U) it returns
+ * m[T, U](x: T, y: U),
+ * m$I[ U](x: Int, y: U),
+ * m$D[ U](x: Double, y: U)
+ */
+ private def normalizeMember(owner: Symbol, sym: Symbol, outerEnv: TypeEnv): List[Symbol] = {
+ if (sym.isMethod && !sym.info.typeParams.isEmpty) {
+ val (stps, tps) = splitParams(sym.info.typeParams)
+ val res = sym :: (for (env <- specializations(stps)) yield {
+ val keys = env.keys.toList;
+ val vals = env.values.toList
+ val specMember = sym.cloneSymbol(owner).setFlag(SPECIALIZED).resetFlag(DEFERRED)
+ specMember.name = specializedName(sym, env)
+ typeEnv(specMember) = outerEnv ++ env
+ val tps1 = cloneSymbols(tps)
+ for (tp <- tps1) tp.setInfo(tp.info.subst(keys, vals))
+ val methodType = sym.info.resultType.subst(keys ::: tps, vals ::: (tps1 map (_.tpe)))
+
+ specMember.setInfo(polyType(tps1, methodType))
+ log("expanded member: " + sym + " -> " + specMember + ": " + specMember.info + " env: " + env)
+ info(specMember) = NormalizedMember(sym)
+ overloads(sym) = Overload(specMember, env) :: overloads(sym)
+ specMember
+ })
+ //stps foreach (_.removeAttribute(SpecializedClass))
+ res
+ } else List(sym)
+ }
+
+ /** Specialize member `m' w.r.t. to the outer environment and the type parameters of
+ * the innermost enclosing class.
+ *
+ * Turns 'private' into 'protected' for members that need specialization.
+ *
+ * Return a list of symbols that are specializations of 'sym', owned by 'owner'.
+ */
+ private def specializeMember(owner: Symbol, sym: Symbol, outerEnv: TypeEnv, tps: List[Symbol]): List[Symbol] = {
+ def specializeOn(tparams: List[Symbol]): List[Symbol] =
+ for (spec <- specializations(tparams)) yield {
+ if (sym.hasFlag(PRIVATE)) sym.resetFlag(PRIVATE).setFlag(PROTECTED)
+ val specMember = subst(outerEnv)(specializedOverload(owner, sym, spec))
+ typeEnv(specMember) = outerEnv ++ spec
+ overloads(sym) = Overload(specMember, spec) :: overloads(sym)
+ specMember
+ }
+
+ if (sym.isMethod) {
+// log("specializeMember " + sym + " with own stps: " + specializedTypes(sym.info.typeParams))
+ val tps1 = if (sym.isConstructor) tps filter (tp => sym.info.paramTypes.contains(tp)) else tps
+ val tps2 = tps1 intersect specializedTypeVars(sym.info).toList
+ if (!sym.isDeferred) concreteSpecMethods += sym
+
+ specializeOn(tps2) map {m => info(m) = SpecialOverload(sym, typeEnv(m)); m}
+ } else
+ List()
+ }
+
+ /** Return the specialized overload of `m', in the given environment. */
+ private def specializedOverload(owner: Symbol, sym: Symbol, env: TypeEnv): Symbol = {
+ val specMember = sym.cloneSymbol(owner) // this method properly duplicates the symbol's info
+ specMember.name = specializedName(sym, env)
+
+ specMember.setInfo(subst(env, specMember.info))
+ .setFlag(SPECIALIZED)
+ .resetFlag(DEFERRED | CASEACCESSOR | ACCESSOR)
+ }
+
+ /** For each method m that overrides inherited method m', add a special
+ * overload method `om' that overrides the corresponding overload in the
+ * superclass. For the following example:
+ *
+ * class IntFun extends Function1[Int, Int] {
+ * def apply(x: Int): Int = ..
+ * }
+ *
+ * this method will return List('apply$spec$II')
+ */
+ private def specialOverrides(clazz: Symbol): List[Symbol] = {
+ log("specialOverrides(" + clazz + ")")
+ val opc = new overridingPairs.Cursor(clazz)
+ val oms = new mutable.ListBuffer[Symbol]
+ while (opc.hasNext) {
+// log("\toverriding pairs: " + opc.overridden.fullNameString + ": " + opc.overridden.info
+// + "> " + opc.overriding.fullNameString + ": " + opc.overriding.info)
+ if (!specializedTypeVars(opc.overridden.info).isEmpty) {
+// log("\t\tspecializedTVars: " + specializedTypeVars(opc.overridden.info))
+ val env = unify(opc.overridden.info, opc.overriding.info, emptyEnv)
+ log("\t\tenv: " + env)
+ if (!env.isEmpty
+ && TypeEnv.isValid(env, opc.overridden)
+ && opc.overridden.owner.info.decl(specializedName(opc.overridden, env)) != NoSymbol) {
+ log("Added specialized overload for " + opc.overriding.fullNameString + " in env: " + env)
+ val om = specializedOverload(clazz, opc.overridden, env)
+ if (!opc.overriding.isDeferred) {
+ concreteSpecMethods += opc.overriding
+ info(om) = Implementation(opc.overriding)
+ info(opc.overriding) = Forward(om)
+ }
+ overloads(opc.overriding) = Overload(om, env) :: overloads(opc.overriding)
+ oms += om
+ atPhase(phase.next)(
+ assert(opc.overridden.owner.info.decl(om.name) != NoSymbol,
+ "Could not find " + om.name + " in " + opc.overridden.owner.info.decls))
+ }
+ }
+ opc.next
+ }
+ oms.toList
+ }
+
+ /** Return the most general type environment that specializes tp1 to tp2.
+ * It only allows binding of type parameters annotated with @specialized.
+ * Fails if such an environment cannot be found.
+ */
+ private def unify(tp1: Type, tp2: Type, env: TypeEnv): TypeEnv = (tp1, tp2) match {
+ case (TypeRef(_, sym1, _), _) if sym1.hasAnnotation(SpecializedClass) =>
+ if (definitions.isValueType(tp2.typeSymbol))
+ env + ((sym1, tp2))
+ else
+ env
+ case (TypeRef(_, sym1, args1), TypeRef(_, sym2, args2)) =>
+ unify(args1, args2, env)
+ case (TypeRef(_, sym1, _), _) if sym1.isTypeParameterOrSkolem =>
+ env
+ case (MethodType(params1, res1), MethodType(params2, res2)) =>
+ unify(res1 :: (params1 map (_.tpe)), res2 :: (params2 map (_.tpe)), env)
+ case (PolyType(tparams1, res1), PolyType(tparams2, res2)) =>
+ unify(res1, res2, env)
+ case (PolyType(_, res), other) =>
+ unify(res, other, env)
+ case (ThisType(_), ThisType(_)) => env
+ case (_, SingleType(_, _)) => unify(tp1, tp2.underlying, env)
+ case (SingleType(_, _), _) => unify(tp1.underlying, tp2, env)
+ case (ThisType(_), _) => unify(tp1.widen, tp2, env)
+ case (_, ThisType(_)) => unify(tp1, tp2.widen, env)
+ case (RefinedType(_, _), RefinedType(_, _)) => env
+ case (AnnotatedType(_, tp1, _), tp2) => unify(tp2, tp1, env)
+ case (ExistentialType(_, res1), _) => unify(tp2, res1, env)
+ }
+
+ private def unify(tp1: List[Type], tp2: List[Type], env: TypeEnv): TypeEnv =
+ tp1.zip(tp2).foldLeft(env) { (env, args) =>
+ unify(args._1, args._2, env)
+ }
+
+ private def specializedTypes(tps: List[Symbol]) = tps.filter(_.hasAnnotation(SpecializedClass))
+
+ /** Map class symbols to the type environments where they were created. */
+ val typeEnv: mutable.Map[Symbol, TypeEnv] = new mutable.HashMap[Symbol, TypeEnv] {
+ override def default(key: Symbol) = emptyEnv
+ }
+
+ /** Apply type bindings in the given environement `env' to all declarations. */
+ private def subst(env: TypeEnv, decls: List[Symbol]): List[Symbol] =
+ decls map subst(env)
+
+ private def subst(env: TypeEnv, tpe: Type): Type = {
+ // disabled because of bugs in std. collections
+ //val (keys, values) = env.iterator.toList.unzip
+ val keys = env.keysIterator.toList
+ val values = env.valuesIterator.toList
+ tpe.subst(keys, values)
+ }
+
+ private def subst(env: TypeEnv)(decl: Symbol): Symbol = {
+ val tpe = subst(env, decl.info)
+ decl.setInfo(if (decl.isConstructor) tpe match {
+ case MethodType(args, resTpe) => MethodType(args, decl.owner.tpe)
+ } else tpe)
+ }
+
+ /** Return a list of all type environements for all specializations
+ * of @specialized types in `tps'.
+ */
+ private def specializations(tps: List[Symbol]): List[TypeEnv] = {
+ val stps = tps filter (_.hasAnnotation(SpecializedClass))
+ val env = immutable.HashMap.empty[Symbol, Type]
+ count(stps.length, concreteTypes) map { tps =>
+ immutable.HashMap.empty[Symbol, Type] ++ (stps zip tps)
+ }
+ }
+
+ /** Type transformation.
+ */
+ override def transformInfo(sym: Symbol, tpe: Type): Type = {
+ val res = tpe match {
+ case PolyType(targs, ClassInfoType(base, decls, clazz)) =>
+ val parents = base map specializedType
+ PolyType(targs, ClassInfoType(parents, newScope(specializeClass(clazz, typeEnv(clazz))), clazz))
+
+ case ClassInfoType(base, decls, clazz) =>
+// val parents = base map specializedType
+// log("set parents of " + clazz + " to: " + parents)
+ val res = ClassInfoType(base map specializedType, newScope(specializeClass(clazz, typeEnv(clazz))), clazz)
+ res
+
+ case _ =>
+ tpe
+ }
+ res
+
+ }
+
+ def conflicting(env: TypeEnv): Boolean = {
+ val silent = (pos: Position, str: String) => ()
+ conflicting(env, silent)
+ }
+
+ /** Is any type variable in `env' conflicting with any if its type bounds, when
+ * type bindings in `env' are taken into account?
+ *
+ * A conflicting type environment could still be satisfiable.
+ */
+ def conflicting(env: TypeEnv, warn: (Position, String) => Unit): Boolean =
+ env exists { case (tvar, tpe) =>
+ if (!(subst(env, tvar.info.bounds.lo) <:< tpe) && (tpe <:< subst(env, tvar.info.bounds.hi))) {
+ warn(tvar.pos, "Bounds prevent specialization for " + tvar)
+ true
+ } else false
+ }
+
+ /** The type environemnt is sound w.r.t. to all type bounds or only soft
+ * conflicts appear. An environment is sound if all bindings are within
+ * the bounds of the given type variable. A soft conflict is a binding
+ * that does not fall within the bounds, but whose bounds contain
+ * type variables that are @specialized, (that could become satisfiable).
+ */
+ def satisfiable(env: TypeEnv, warn: (Position, String) => Unit): Boolean = {
+ def matches(tpe1: Type, tpe2: Type): Boolean = {
+ val t1 = subst(env, tpe1)
+ val t2 = subst(env, tpe2)
+ ((t1 <:< t2)
+ || !specializedTypeVars(t1).isEmpty
+ || !specializedTypeVars(t2).isEmpty)
+ }
+
+ env forall { case (tvar, tpe) =>
+ ((matches(tvar.info.bounds.lo, tpe)
+ && matches(tpe, tvar.info.bounds.hi))
+ || { warn(tvar.pos, "Bounds prevent specialization of " + tvar);
+ log("specvars: "
+ + tvar.info.bounds.lo + ": " + specializedTypeVars(tvar.info.bounds.lo)
+ + " " + subst(env, tvar.info.bounds.hi) + ": " + specializedTypeVars(subst(env, tvar.info.bounds.hi)))
+ false })
+ }
+ }
+
+ def satisfiable(env: TypeEnv): Boolean = {
+ val silent = (pos: Position, str: String) => ()
+ satisfiable(env, silent)
+ }
+
+ import java.io.PrintWriter
+
+ /*************************** Term transformation ************************************/
+
+ class Duplicator extends {
+ val global: SpecializeTypes.this.global.type = SpecializeTypes.this.global
+ } with typechecker.Duplicators
+
+ import global.typer.typed
+
+ def specializeCalls(unit: CompilationUnit) = new TypingTransformer(unit) {
+ /** Map a specializable method to it's rhs, when not deferred. */
+ val body: mutable.Map[Symbol, Tree] = new mutable.HashMap
+
+ /** Map a specializable method to its value parameter symbols. */
+ val parameters: mutable.Map[Symbol, List[List[Symbol]]] = new mutable.HashMap
+
+ /** The current fresh name creator. */
+ implicit val fresh: FreshNameCreator = unit.fresh
+
+ /** Collect method bodies that are concrete specialized methods.
+ */
+ class CollectMethodBodies extends Traverser {
+ override def traverse(tree: Tree) = tree match {
+ case DefDef(mods, name, tparams, vparamss, tpt, rhs) if concreteSpecMethods(tree.symbol) || tree.symbol.isConstructor =>
+ log("adding body of " + tree.symbol)
+ body(tree.symbol) = rhs
+// body(tree.symbol) = tree // whole method
+ parameters(tree.symbol) = vparamss map (_ map (_.symbol))
+ super.traverse(tree)
+ case ValDef(mods, name, tpt, rhs) if concreteSpecMethods(tree.symbol) =>
+ body(tree.symbol) = rhs
+ super.traverse(tree)
+ case _ =>
+ super.traverse(tree)
+ }
+ }
+
+ import posAssigner._
+
+ override def transform(tree: Tree): Tree = {
+ val symbol = tree.symbol
+
+ /** The specialized symbol of 'tree.symbol' for tree.tpe, if there is one */
+ def specSym(qual: Tree): Option[Symbol] = {
+ val env = unify(symbol.tpe, tree.tpe, emptyEnv)
+ log("checking for rerouting: " + tree + " with sym.tpe: " + symbol.tpe + " tree.tpe: " + tree.tpe + " env: " + env)
+ if (!env.isEmpty) { // a method?
+ val specMember = overload(symbol, env)
+ if (specMember.isDefined) Some(specMember.get.sym)
+ else { // a field?
+ val specMember = qual.tpe.member(specializedName(symbol, env))
+ if (specMember ne NoSymbol) Some(specMember)
+ else None
+ }
+ } else None
+ }
+
+ def maybeTypeApply(fun: Tree, targs: List[Tree]) =
+ if (targs.isEmpty)fun else TypeApply(fun, targs)
+
+ curTree = tree
+ tree match {
+ case Apply(Select(New(tpt), nme.CONSTRUCTOR), args) =>
+ if (findSpec(tpt.tpe).typeSymbol ne tpt.tpe.typeSymbol) {
+ log("** instantiated specialized type: " + findSpec(tpt.tpe))
+ atPos(tree.pos)(
+ localTyper.typed(
+ Apply(
+ Select(New(TypeTree(findSpec(tpt.tpe))), nme.CONSTRUCTOR),
+ transformTrees(args))))
+ } else tree
+
+ case TypeApply(Select(qual, name), targs) if (!specializedTypeVars(symbol.info).isEmpty && name != nme.CONSTRUCTOR) =>
+ log("checking typeapp for rerouting: " + tree + " with sym.tpe: " + symbol.tpe + " tree.tpe: " + tree.tpe)
+ val qual1 = transform(qual)
+ specSym(qual1) match {
+ case Some(specMember) =>
+ assert(symbol.info.typeParams.length == targs.length)
+ val env = typeEnv(specMember)
+ val residualTargs =
+ for ((tvar, targ) <-symbol.info.typeParams.zip(targs) if !env.isDefinedAt(tvar))
+ yield targ
+ assert(residualTargs.length == specMember.info.typeParams.length)
+ val tree1 = maybeTypeApply(Select(qual1, specMember.name), residualTargs)
+ log("rewrote " + tree + " to " + tree1)
+ localTyper.typedOperator(atPos(tree.pos)(tree1)) // being polymorphic, it must be a method
+
+ case None => super.transform(tree)
+ }
+
+ case Select(qual, name) if (!specializedTypeVars(symbol.info).isEmpty && name != nme.CONSTRUCTOR) =>
+ val qual1 = transform(qual)
+ val env = unify(symbol.tpe, tree.tpe, emptyEnv)
+ log("checking for rerouting: " + tree + " with sym.tpe: " + symbol.tpe + " tree.tpe: " + tree.tpe + " env: " + env)
+ if (!env.isEmpty) {
+ val specMember = overload(symbol, env)
+ if (specMember.isDefined) {
+ log("** routing " + tree + " to " + specMember.get.sym.fullNameString + " tree: " + Select(qual1, specMember.get.sym.name))
+ localTyper.typedOperator(atPos(tree.pos)(Select(qual1, specMember.get.sym.name)))
+ } else {
+ val specMember = qual1.tpe.member(specializedName(symbol, env))
+ if (specMember ne NoSymbol) {
+ log("** using spec member " + specMember)
+ localTyper.typed(atPos(tree.pos)(Select(qual1, specMember.name)))
+ } else
+ super.transform(tree)
+ }
+ } else
+ super.transform(tree)
+
+ case PackageDef(name, stats) =>
+ tree.symbol.info // make sure specializations have been peformed
+ log("PackageDef owner: " + symbol)
+ atOwner(tree, symbol) {
+ val specMembers = implSpecClasses(stats) map localTyper.typed
+ treeCopy.PackageDef(tree, name, transformStats(stats ::: specMembers, symbol.moduleClass))
+ }
+
+ case Template(parents, self, body) =>
+ val specMembers = makeSpecializedMembers(tree.symbol.enclClass) ::: (implSpecClasses(body) map localTyper.typed)
+ if (!symbol.isPackageClass)
+ (new CollectMethodBodies)(tree)
+ treeCopy.Template(tree, parents, self, atOwner(currentOwner)(transformTrees(body ::: specMembers)))
+
+ case ddef @ DefDef(mods, name, tparams, vparamss, tpt, rhs) if info.isDefinedAt(symbol) =>
+ if (symbol.isConstructor) {
+ val t = atOwner(symbol) {
+ val superRef: Tree = Select(Super(nme.EMPTY.toTypeName, nme.EMPTY.toTypeName), nme.CONSTRUCTOR)
+ forwardCall(tree.pos, superRef, vparamss)
+ }
+ val tree1 = atPos(symbol.pos)(treeCopy.DefDef(tree, mods, name, tparams, vparamss, tpt, Block(List(t), Literal(()))))
+ log(tree1)
+ localTyper.typed(tree1)
+ } else info(symbol) match {
+
+ case Implementation(target) =>
+ assert(body.isDefinedAt(target), "sym: " + symbol.fullNameString + " target: " + target.fullNameString)
+ // we have an rhs, specialize it
+ val tree1 = duplicateBody(ddef, target)
+ log("implementation: " + tree1)
+ val DefDef(mods, name, tparams, vparamss, tpt, rhs) = tree1
+ treeCopy.DefDef(tree1, mods, name, tparams, vparamss, tpt, transform(rhs))
+
+ case NormalizedMember(target) =>
+ log("normalized member " + symbol + " of " + target)
+
+ if (conflicting(typeEnv(symbol))) {
+ val targs = makeTypeArguments(symbol, target)
+ log("targs: " + targs)
+ val call =
+ forwardCall(tree.pos,
+ TypeApply(
+ Select(This(symbol.owner), target),
+ targs map TypeTree),
+ vparamss)
+ log("call: " + call)
+ localTyper.typed(
+ treeCopy.DefDef(tree, mods, name, tparams, vparamss, tpt,
+ maybeCastTo(symbol.info.finalResultType,
+ target.info.subst(target.info.typeParams, targs).finalResultType,
+ call)))
+
+/* copy.DefDef(tree, mods, name, tparams, vparamss, tpt,
+ typed(Apply(gen.mkAttributedRef(definitions.Predef_error),
+ List(Literal("boom! you stepped on a bug. This method should never be called.")))))*/
+ } else {
+ // we have an rhs, specialize it
+ val tree1 = duplicateBody(ddef, target)
+ log("implementation: " + tree1)
+ val DefDef(mods, name, tparams, vparamss, tpt, rhs) = tree1
+ treeCopy.DefDef(tree1, mods, name, tparams, vparamss, tpt, transform(rhs))
+ }
+
+ case SpecialOverload(original, env) =>
+ log("completing specialized " + symbol.fullNameString + " calling " + original)
+ val t = DefDef(symbol, { vparamss =>
+ val fun = Apply(Select(This(symbol.owner), original),
+ makeArguments(original, vparamss.head))
+
+ maybeCastTo(symbol.owner.info.memberType(symbol).finalResultType,
+ symbol.owner.info.memberType(original).finalResultType,
+ fun)
+ })
+ log("created " + t)
+ localTyper.typed(t)
+
+ case fwd @ Forward(_) =>
+ val rhs1 = forwardCall(tree.pos, gen.mkAttributedRef(symbol.owner.thisType, fwd.target), vparamss)
+ log("completed forwarder to specialized overload: " + fwd.target + ": " + rhs1)
+ localTyper.typed(treeCopy.DefDef(tree, mods, name, tparams, vparamss, tpt, rhs1))
+
+ case SpecializedAccessor(target) =>
+ val rhs1 = if (symbol.isGetter)
+ gen.mkAttributedRef(target)
+ else
+ Assign(gen.mkAttributedRef(target), Ident(vparamss.head.head.symbol))
+ localTyper.typed(treeCopy.DefDef(tree, mods, name, tparams, vparamss, tpt, rhs1))
+ }
+
+ case ValDef(mods, name, tpt, rhs) if symbol.hasFlag(SPECIALIZED) =>
+ assert(body.isDefinedAt(symbol.alias))
+ val tree1 = treeCopy.ValDef(tree, mods, name, tpt, body(symbol.alias).duplicate)
+ log("now typing: " + tree1 + " in " + tree.symbol.owner.fullNameString)
+ val d = new Duplicator
+ d.retyped(localTyper.context1.asInstanceOf[d.Context],
+ tree1,
+ symbol.alias.enclClass,
+ symbol.enclClass,
+ typeEnv(symbol.alias) ++ typeEnv(tree.symbol))
+
+ case _ =>
+ super.transform(tree)
+ }
+ }
+
+ private def reskolemize(tparams: List[TypeDef]): (List[Symbol], List[Symbol]) = {
+ val tparams1 = tparams map (_.symbol)
+ localTyper.namer.skolemize(tparams)
+ (tparams1, tparams map (_.symbol))
+ }
+
+ private def duplicateBody(tree: DefDef, target: Symbol): Tree = {
+ val symbol = tree.symbol
+ log("specializing body of" + symbol.fullNameString + ": " + symbol.info)
+ val DefDef(mods, name, tparams, vparamss, tpt, _) = tree
+ val (_, origtparams) = splitParams(target.typeParams)
+ log("substituting " + origtparams + " for " + symbol.typeParams)
+
+ // skolemize type parameters
+ val (oldtparams, newtparams) = reskolemize(tparams)
+
+ // create fresh symbols for value parameters to hold the skolem types
+ val vparamss1 = List(for (vdef <- vparamss.head; param = vdef.symbol) yield {
+ ValDef(param.cloneSymbol(symbol).setInfo(param.info.substSym(oldtparams, newtparams)))
+ })
+
+ // replace value and type paremeters of the old method with the new ones
+ val symSubstituter = new ImplementationAdapter(
+ List.flatten(parameters(target)) ::: origtparams,
+ List.flatten(vparamss1).map(_.symbol) ::: newtparams)
+ val adapter = new AdaptSpecializedValues
+ val tmp = symSubstituter(adapter(body(target).duplicate))
+ tpt.tpe = tpt.tpe.substSym(oldtparams, newtparams)
+
+ val meth = treeCopy.DefDef(tree, mods, name, tparams, vparamss1, tpt, tmp)
+
+ log("now typing: " + meth + " in " + symbol.owner.fullNameString)
+ val d = new Duplicator
+ d.retyped(localTyper.context1.asInstanceOf[d.Context],
+ meth,
+ target.enclClass,
+ symbol.enclClass,
+ typeEnv(target) ++ typeEnv(symbol))
+ }
+
+ /** A tree symbol substituter that substitutes on type skolems.
+ * If a type parameter is a skolem, it looks for the original
+ * symbol in the 'from' and maps it to the corresponding new
+ * symbol. The new symbol should probably be a type skolem as
+ * well (not enforced).
+ *
+ * All private members are made protected in order to be accessible from
+ * specialized classes.
+ */
+ class ImplementationAdapter(from: List[Symbol], to: List[Symbol]) extends TreeSymSubstituter(from, to) {
+ override val symSubst = new SubstSymMap(from, to) {
+ override def matches(sym1: Symbol, sym2: Symbol) =
+ if (sym2.isTypeSkolem) sym2.deSkolemize eq sym1
+ else sym1 eq sym2
+ }
+
+ /** All private members that are referenced are made protected,
+ * in order to be accessible from specialized subclasses.
+ */
+ override def traverse(tree: Tree): Unit = tree match {
+ case Select(qual, name) =>
+ if (tree.symbol.hasFlag(PRIVATE)) {
+ log("changing private flag of " + tree.symbol)
+ tree.symbol.resetFlag(PRIVATE).setFlag(PROTECTED)
+ }
+ super.traverse(tree)
+
+ case _ =>
+ super.traverse(tree)
+ }
+ }
+
+ /** Does the given tree need a cast to a type parameter's upper bound?
+ * A cast is needed for values of type A, where A is a specialized type
+ * variable with a non-trivial upper bound. When A is specialized, its
+ * specialization may not satisfy the upper bound. We generate casts to
+ * be able to type check code. Such methods will never be called, as they
+ * are not visible to the user. The compiler will insert such calls only when
+ * the bounds are satisfied.
+ */
+ private class AdaptSpecializedValues extends Transformer {
+ private def needsCast(tree: Tree): Boolean = {
+ val sym = tree.tpe.typeSymbol
+ (sym.isTypeParameterOrSkolem
+ && sym.hasAnnotation(SpecializedClass)
+ && sym.info.bounds.hi != definitions.AnyClass.tpe
+ /*&& !(tree.tpe <:< sym.info.bounds.hi)*/)
+ }
+
+ override def transform(tree: Tree): Tree = {
+ val tree1 = super.transform(tree)
+ if (needsCast(tree1)) {
+ log("inserting cast for " + tree1 + " tpe: " + tree1.tpe)
+ val tree2 = gen.mkAsInstanceOf(tree1, tree1.tpe.typeSymbol.info.bounds.hi, false)
+ log(" casted to: " + tree2)
+ tree2
+ } else
+ tree1
+ }
+ def apply(t: Tree): Tree = transform(t)
+ }
+
+ def warn(clazz: Symbol)(pos: Position, err: String) =
+ if (!clazz.hasFlag(SPECIALIZED))
+ unit.warning(pos, err)
+
+ /** Create trees for specialized members of 'cls', based on the
+ * symbols that are already there.
+ */
+ private def makeSpecializedMembers(cls: Symbol): List[Tree] = {
+ // add special overrides first
+ if (!cls.hasFlag(SPECIALIZED))
+ for (m <- specialOverrides(cls)) cls.info.decls.enter(m)
+ val mbrs = new mutable.ListBuffer[Tree]
+
+ for (m <- cls.info.decls.toList
+ if m.hasFlag(SPECIALIZED)
+ && (m.sourceFile ne null)
+ && satisfiable(typeEnv(m), warn(cls))) {
+ log("creating tree for " + m.fullNameString)
+ if (m.isMethod) {
+ if (m.isClassConstructor) {
+ val origParamss = parameters(info(m).target)
+ assert(origParamss.length == 1) // we are after uncurry
+
+ val vparams =
+ for ((tp, sym) <- m.info.paramTypes zip origParamss(0))
+ yield m.newValue(sym.pos, specializedName(sym, typeEnv(cls)))
+ .setInfo(tp)
+ .setFlag(sym.flags)
+ // param accessors for private members (the others are inherited from the generic class)
+ for (param <- vparams if cls.info.nonPrivateMember(param.name) == NoSymbol;
+ val acc = param.cloneSymbol(cls).setFlag(PARAMACCESSOR | PRIVATE)) {
+ log("param accessor for " + acc.fullNameString)
+ cls.info.decls.enter(acc)
+ mbrs += ValDef(acc, EmptyTree).setType(NoType).setPos(m.pos)
+ }
+ // ctor
+ mbrs += DefDef(m, Modifiers(m.flags), List(vparams) map (_ map ValDef), EmptyTree)
+ } else
+ mbrs += DefDef(m, { paramss => EmptyTree })
+ } else {
+ assert(m.isValue)
+ mbrs += ValDef(m, EmptyTree).setType(NoType).setPos(m.pos)
+ }
+ }
+ mbrs.toList
+ }
+ }
+
+ private def forwardCall(pos: util.Position, receiver: Tree, paramss: List[List[ValDef]]): Tree = {
+ val argss = paramss map (_ map (x => Ident(x.symbol)))
+ atPos(pos) { (receiver /: argss) (Apply) }
+ }
+
+ /** Create specialized class definitions */
+ def implSpecClasses(trees: List[Tree]): List[Tree] = {
+ val buf = new mutable.ListBuffer[Tree]
+ for (val tree <- trees)
+ tree match {
+ case ClassDef(_, _, _, impl) =>
+ tree.symbol.info // force specialization
+ for (val ((sym1, env), specCls) <- specializedClass if sym1 == tree.symbol)
+ buf +=
+ ClassDef(specCls, Template(specCls.info.parents map TypeTree, emptyValDef, List())
+ .setSymbol(specCls.newLocalDummy(sym1.pos)))
+ case _ =>
+ }
+ log(buf)
+ buf.toList
+ }
+
+ /** Concrete methods that use a specialized type, or override such methods. */
+ private val concreteSpecMethods: mutable.Set[Symbol] = new mutable.HashSet
+
+ /** Instantiate polymorphic function `target' with type parameters from `from'.
+ * For each type parameter `tp' in `target', its argument is:
+ * - a corresponding type parameter of `from', if tp is not bound in
+ * typeEnv(from)
+ * - the upper bound of tp, if the binding conflicts with tp's bounds
+ * - typeEnv(from)(tp), if the binding is not conflicting in its bounds
+ */
+ private def makeTypeArguments(from: Symbol, target: Symbol): List[Type] = {
+ val owner = from.owner
+ val env = typeEnv(from)
+ for (tp <- owner.info.memberType(target).typeParams)
+ yield
+ if (!env.isDefinedAt(tp))
+ typeRef(NoPrefix, from.info.typeParams.find(_ == tp.name).get, Nil)
+ else if ((env(tp) <:< tp.info.bounds.hi) && (tp.info.bounds.lo <:< env(tp)))
+ env(tp)
+ else tp.info.bounds.hi
+ }
+
+ /** Cast `tree' to 'pt', unless tpe is a subtype of pt, or pt is Unit. */
+ def maybeCastTo(pt: Type, tpe: Type, tree: Tree): Tree =
+ if ((pt == definitions.UnitClass.tpe) || (tpe <:< pt)) {
+ log("no need to cast from " + tpe + " to " + pt)
+ tree
+ } else
+ gen.mkAsInstanceOf(tree, pt, false)
+
+
+ private def makeArguments(fun: Symbol, vparams: List[Symbol]): List[Tree] = {
+ def needsCast(tp1: Type, tp2: Type): Boolean =
+ !(tp1 <:< tp2)
+
+ //! TODO: make sure the param types are seen from the right prefix
+ for ((tp, arg) <- fun.info.paramTypes zip vparams) yield {
+ if (needsCast(arg.tpe, tp)) {
+ //log("tp: " + tp + " " + tp.typeSymbol.owner)
+ gen.mkAsInstanceOf(Ident(arg), tp, false)
+ } else Ident(arg)
+ }
+ }
+
+ private def findSpec(tp: Type): Type = tp match {
+ case TypeRef(pre, sym, args) =>
+ if (args.isEmpty) tp
+ else {
+ specializedType(tp)
+ /*log("looking for " + specializedName(sym.name, args) + " in " + pre)
+ val sym1 = pre.member(specializedName(sym.name, args))
+ assert(sym1 != NoSymbol, "pre: " + pre.typeSymbol + " ph: " + phase + " with: " + pre.members)
+ TypeRef(pre, sym1, Nil)*/
+ }
+ case _ => tp
+ }
+
+ class SpecializationTransformer(unit: CompilationUnit) extends Transformer {
+ override def transform(tree: Tree) =
+ atPhase(phase.next) {
+ val res = specializeCalls(unit).transform(tree)
+ res
+ }
+ }
+
+}
diff --git a/src/compiler/scala/tools/nsc/transform/TypingTransformers.scala b/src/compiler/scala/tools/nsc/transform/TypingTransformers.scala
index d1902a2983..90281047f4 100644
--- a/src/compiler/scala/tools/nsc/transform/TypingTransformers.scala
+++ b/src/compiler/scala/tools/nsc/transform/TypingTransformers.scala
@@ -28,7 +28,8 @@ trait TypingTransformers {
def atOwner[A](tree: Tree, owner: Symbol)(trans: => A): A = {
val savedLocalTyper = localTyper
- localTyper = localTyper.atOwner(tree, owner)
+// println("ttransformer atOwner: " + owner + " isPackage? " + owner.isPackage)
+ localTyper = localTyper.atOwner(tree, if (owner.isModule) owner.moduleClass else owner)
typers += Pair(owner, localTyper)
val result = super.atOwner(owner)(trans)
localTyper = savedLocalTyper
@@ -42,6 +43,8 @@ trait TypingTransformers {
case Template(_, _, _) =>
// enter template into context chain
atOwner(currentOwner) { super.transform(tree) }
+ case PackageDef(_, _) =>
+ atOwner(tree.symbol) { super.transform(tree) }
case _ =>
super.transform(tree)
}
diff --git a/src/compiler/scala/tools/nsc/typechecker/Duplicators.scala b/src/compiler/scala/tools/nsc/typechecker/Duplicators.scala
new file mode 100644
index 0000000000..b20b7c97bd
--- /dev/null
+++ b/src/compiler/scala/tools/nsc/typechecker/Duplicators.scala
@@ -0,0 +1,242 @@
+package scala.tools.nsc.typechecker
+
+import scala.tools.nsc.symtab.Flags
+
+import scala.collection.{mutable, immutable}
+
+/** Duplicate trees and re-type check them, taking care to replace
+ * and create fresh symbols for new local definitions.
+ */
+abstract class Duplicators extends Analyzer {
+ import global._
+
+ def retyped(context: Context, tree: Tree): Tree = {
+ resetClassOwners
+ (new BodyDuplicator(context)).typed(tree)
+ }
+
+ /** Retype the given tree in the given context. Use this method when retyping
+ * a method in a different class. The typer will replace references to the this of
+ * the old class with the new class, and map symbols through the given 'env'. The
+ * environment is a map from type skolems to concrete types (see SpecializedTypes).
+ */
+ def retyped(context: Context, tree: Tree, oldThis: Symbol, newThis: Symbol, env: collection.Map[Symbol, Type]): Tree = {
+ if (oldThis ne newThis) {
+ oldClassOwner = oldThis
+ newClassOwner = newThis
+ } else resetClassOwners
+
+ envSubstitution = new SubstSkolemsTypeMap(env.keysIterator.toList, env.valuesIterator.toList)
+ log("retyped with env: " + env)
+ (new BodyDuplicator(context)).typed(tree)
+ }
+
+ def retypedMethod(context: Context, tree: Tree, oldThis: Symbol, newThis: Symbol): Tree =
+ (new BodyDuplicator(context)).retypedMethod(tree.asInstanceOf[DefDef], oldThis, newThis)
+
+ /** Return the special typer for duplicate method bodies. */
+ override def newTyper(context: Context): Typer =
+ new BodyDuplicator(context)
+
+ private def resetClassOwners {
+ oldClassOwner = null
+ newClassOwner = null
+ }
+
+ private var oldClassOwner: Symbol = _
+ private var newClassOwner: Symbol = _
+ private var envSubstitution: SubstTypeMap = _
+
+ private class SubstSkolemsTypeMap(from: List[Symbol], to: List[Type]) extends SubstTypeMap(from, to) {
+ protected override def matches(sym1: Symbol, sym2: Symbol) =
+ if (sym2.isTypeSkolem) sym2.deSkolemize eq sym1
+ else sym1 eq sym2
+ }
+
+ private val invalidSyms: mutable.Map[Symbol, Tree] = mutable.HashMap.empty[Symbol, Tree]
+
+ /** A typer that creates new symbols for all definitions in the given tree
+ * and updates references to them while re-typechecking. All types in the
+ * tree, except for TypeTrees, are erased prior to type checking. TypeTrees
+ * are fixed by substituting invalid symbols for the new ones.
+ */
+ class BodyDuplicator(context: Context) extends Typer(context: Context) {
+
+ class FixInvalidSyms extends TypeMap {
+
+ def apply(tpe: Type): Type = {
+ mapOver(tpe)
+ }
+
+ override def mapOver(tpe: Type): Type = tpe match {
+ case TypeRef(NoPrefix, sym, args) if sym.isTypeParameterOrSkolem =>
+ val sym1 = context.scope.lookup(sym.name)
+// assert(sym1 ne NoSymbol, tpe)
+ if ((sym1 ne NoSymbol) && (sym1 ne sym)) {
+ log("fixing " + sym + " -> " + sym1)
+ typeRef(NoPrefix, sym1, mapOverArgs(args, sym1.typeParams))
+ } else super.mapOver(tpe)
+
+ case TypeRef(pre, sym, args) =>
+ val newsym = updateSym(sym)
+ if (newsym ne sym) {
+ log("fixing " + sym + " -> " + newsym)
+ typeRef(mapOver(pre), newsym, mapOverArgs(args, newsym.typeParams))
+ } else
+ super.mapOver(tpe)
+ case _ =>
+ super.mapOver(tpe)
+ }
+ }
+
+ /** Fix the given type by replacing invalid symbols with the new ones. */
+ def fixType(tpe: Type): Type = {
+ val tpe1 = envSubstitution(tpe)
+ log("tpe1: " + tpe1)
+ (new FixInvalidSyms)(tpe1)
+ }
+
+ /** Return the new symbol corresponding to `sym'. */
+ private def updateSym(sym: Symbol): Symbol =
+ if (invalidSyms.isDefinedAt(sym))
+ invalidSyms(sym).symbol
+ else
+ sym
+
+ private def invalidate(tree: Tree) {
+ if (tree.isDef && tree.symbol != NoSymbol) {
+ log("invalid " + tree.symbol)
+ invalidSyms(tree.symbol) = tree
+
+ tree match {
+ case ldef @ LabelDef(name, params, rhs) =>
+ log("LabelDef " + name + " sym.info: " + ldef.symbol.info)
+ invalidSyms(ldef.symbol) = ldef
+ // breakIf(true, this, ldef, context)
+ val newsym = ldef.symbol.cloneSymbol(context.owner)
+ newsym.setInfo(fixType(ldef.symbol.info))
+ ldef.symbol = newsym
+ log("newsym: " + newsym + " info: " + newsym.info)
+
+ case DefDef(_, _, _, _, _, rhs) =>
+ // invalidate parameters
+ invalidate(tree.asInstanceOf[DefDef].tparams)
+ invalidate(List.flatten(tree.asInstanceOf[DefDef].vparamss))
+ tree.symbol = NoSymbol
+
+ case _ =>
+ tree.symbol = NoSymbol
+ }
+ }
+ }
+
+ private def invalidate(stats: List[Tree]) {
+ stats foreach invalidate
+ }
+
+
+ def retypedMethod(ddef: DefDef, oldThis: Symbol, newThis: Symbol): Tree = {
+ oldClassOwner = oldThis
+ newClassOwner = newThis
+ invalidate(ddef.tparams)
+ for (vdef <- List.flatten(ddef.vparamss)) {
+ invalidate(vdef)
+ vdef.tpe = null
+ }
+ ddef.symbol = NoSymbol
+ enterSym(context, ddef)
+ log("remapping this of " + oldClassOwner + " to " + newClassOwner)
+ typed(ddef)
+ }
+
+ /** Special typer method allowing for re-type checking trees. It expects a typed tree.
+ * Returns a typed tree that has fresh symbols for all definitions in the original tree.
+ *
+ * Each definition tree is visited and its symbol added to the invalidSyms map (except LabelDefs),
+ * then cleared (forcing the namer to create fresh symbols).
+ * All invalid symbols found in trees are cleared (except for LabelDefs), forcing the
+ * typechecker to look for fresh ones in the context.
+ *
+ * Type trees are typed by substituting old symbols for new ones (@see fixType).
+ *
+ * LabelDefs are not typable from trees alone, unless they have the type ()Unit. Therefore,
+ * their symbols are recreated ad-hoc and their types are fixed inline, instead of letting the
+ * namer/typer handle them, or Idents that refer to them.
+ */
+ override def typed(tree: Tree, mode: Int, pt: Type): Tree = {
+ log("typing " + tree)
+ if (tree.hasSymbol && tree.symbol != NoSymbol
+ && !tree.symbol.isLabel // labels cannot be retyped by the type checker as LabelDef has no ValDef/return type trees
+ && invalidSyms.isDefinedAt(tree.symbol)) {
+ tree.symbol = NoSymbol
+ }
+
+ tree match {
+ case ttree @ TypeTree() =>
+ log("fixing tpe: " + tree.tpe + " with sym: " + tree.tpe.typeSymbol)
+ ttree.tpe = fixType(ttree.tpe)
+ ttree
+ case Block(stats, res) =>
+ log("invalidating block")
+ invalidate(stats)
+ invalidate(res)
+ tree.tpe = null
+ super.typed(tree, mode, pt)
+
+ case ClassDef(_, _, _, tmpl @ Template(parents, _, stats)) =>
+// log("invalidating classdef " + tree.tpe)
+ tmpl.symbol = tree.symbol.newLocalDummy(tree.pos)
+ invalidate(stats)
+ tree.tpe = null
+ super.typed(tree, mode, pt)
+
+ case ddef @ DefDef(_, _, _, _, tpt, rhs) =>
+ ddef.tpt.tpe = fixType(ddef.tpt.tpe)
+ ddef.tpe = null
+ super.typed(ddef, mode, pt)
+
+ case vdef @ ValDef(_, _, tpt, rhs) =>
+// log("vdef fixing tpe: " + tree.tpe + " with sym: " + tree.tpe.typeSymbol + " and " + invalidSyms)
+ vdef.tpt.tpe = fixType(vdef.tpt.tpe)
+ vdef.tpe = null
+ super.typed(vdef, mode, pt)
+
+ case ldef @ LabelDef(name, params, rhs) =>
+ ldef.tpe = null
+ val params1 = params map { p => Ident(updateSym(p.symbol)) }
+ super.typed(treeCopy.LabelDef(tree, name, params1, rhs), mode, pt)
+
+ case Bind(name, _) =>
+ invalidate(tree)
+ tree.tpe = null
+ super.typed(tree, mode, pt)
+
+ case Ident(_) if tree.symbol.isLabel =>
+ log("Ident to labeldef " + tree + " switched to ")
+ tree.symbol = updateSym(tree.symbol)
+ tree.tpe = null
+ super.typed(tree, mode, pt)
+
+ case Select(th @ This(_), sel) if (oldClassOwner ne null) && (th.symbol == oldClassOwner) =>
+ log("selection on this, no type ascription required")
+ super.typed(atPos(tree.pos)(Select(This(newClassOwner), sel)), mode, pt)
+
+ case This(_) if (oldClassOwner ne null) && (tree.symbol == oldClassOwner) =>
+// val tree1 = Typed(This(newClassOwner), TypeTree(fixType(tree.tpe.widen)))
+ val tree1 = This(newClassOwner)
+ log("mapped " + tree + " to " + tree1)
+ super.typed(atPos(tree.pos)(tree1), mode, pt)
+
+ case Super(qual, mix) if (oldClassOwner ne null) && (tree.symbol == oldClassOwner) =>
+ val tree1 = Super(qual, mix)
+ log("changed " + tree + " to " + tree1)
+ super.typed(atPos(tree.pos)(tree1))
+
+ case _ =>
+ tree.tpe = null
+ super.typed(tree, mode, pt)
+ }
+ }
+ }
+}
+
diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
index 159062dca0..e324e5793f 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
@@ -3046,8 +3046,13 @@ trait Typers { self: Analyzer =>
res.tpe = res.tpe.notNull
}
*/
- if (fun2.symbol == Array_apply) typed { atPos(tree.pos) { gen.mkCheckInit(res) } }
- else res
+ if (fun2.symbol == Array_apply) {
+ val checked = gen.mkCheckInit(res)
+ // this check is needed to avoid infinite recursion in Duplicators
+ // (calling typed1 more than once for the same tree
+ if (checked ne res) typed { atPos(tree.pos)(checked) }
+ else res
+ } else res
/* Would like to do the following instead, but curiously this fails; todo: investigate
if (fun2.symbol.name == nme.apply && fun2.symbol.owner == ArrayClass)
typed { atPos(tree.pos) { gen.mkCheckInit(res) } }