// NOTE(review): the package clause and the `tpd` import below were reconstructed; the
// original header read `package package transform` / `import import collection...`, which
// does not parse. Confirm both against the repository layout.
package dotty.tools.dotc
package transform

import TreeTransforms._
import core.DenotTransformers._
import core.Symbols._
import core.Contexts._
import core.Types._
import core.Flags._
import core.Decorators._
import core.SymDenotations._
import core.StdNames.nme
import core.Names._
import core.NameOps._
import ast.Trees._
import SymUtils._
import ast.tpd
import collection.{ mutable, immutable }
import collection.mutable.{ LinkedHashMap, LinkedHashSet, TreeSet }

/**
 *  Rewires closures to implement more specific types of Functions.
 *
 *  A closure without an explicit target type whose method type passes
 *  [[shouldSpecialize]] is given a target type: the class named
 *  `JFunction<arity>` specialized for the closure's result and parameter types,
 *  looked up under `scala.compat.java8.`. If no such class is defined the
 *  closure is left unchanged.
 */
class FunctionalInterfaces extends MiniPhaseTransform {
  import tpd._

  def phaseName: String = "functionalInterfaces"

  // Both sets are (re)assigned in `prepareForUnit`; they are unset before the
  // first compilation unit has been prepared.
  private var allowedReturnTypes: Set[Symbol] = _ // moved here to make it explicit what specializations are generated
  private var allowedArgumentTypes: Set[Symbol] = _

  /** Largest number of parameters a closure may have and still be specialized. */
  val maxArgsCount: Int = 2

  /**
   *  Tells whether a closure with method type `m` is a candidate for a
   *  specialized functional interface.
   *
   *  Only the parameter count and membership of the parameter/result type
   *  symbols in the allowed sets are checked here. The per-arity restrictions
   *  mentioned next to those sets are not checked by this method;
   *  `transformClosure` additionally requires the resulting interface to exist.
   *
   *  @param m the widened method type of the closure's method
   *  @return `true` if every check above passes
   */
  def shouldSpecialize(m: MethodType)(implicit ctx: Context): Boolean =
    (m.paramTypes.size <= maxArgsCount) &&
      m.paramTypes.forall(x => allowedArgumentTypes.contains(x.typeSymbol)) &&
      allowedReturnTypes.contains(m.resultType.typeSymbol)

  /** Base name of the target interfaces; the arity and specialization suffix are appended. */
  val functionName: TermName = "JFunction".toTermName

  /** Prefix (including the trailing dot) prepended to the interface name for lookup. */
  val functionPackage: TermName = "scala.compat.java8.".toTermName

  /**
   *  Initializes the sets of result and argument type symbols that may be
   *  specialized. Runs before each compilation unit is transformed.
   *
   *  @return this transform
   */
  override def prepareForUnit(tree: tpd.Tree)(implicit ctx: Context): TreeTransform = {
    allowedReturnTypes = Set(
      defn.UnitClass,
      defn.BooleanClass,
      defn.IntClass,
      defn.FloatClass,
      defn.LongClass,
      defn.DoubleClass,
      /* only for Function0: */ defn.ByteClass,
      defn.ShortClass,
      defn.CharClass)
    allowedArgumentTypes = Set(
      defn.IntClass,
      defn.LongClass,
      defn.DoubleClass,
      /* only for Function1: */ defn.FloatClass)
    this
  }

  /**
   *  Gives `tree` a specialized functional interface as its target type when
   *  one applies.
   *
   *  @param tree the closure to inspect
   *  @return a new closure with the same environment and method but targeting
   *          the specialized interface, or `tree` itself when it already has a
   *          target type, its method does not have a method type, it does not
   *          qualify for specialization, or the interface is not defined
   */
  override def transformClosure(tree: Closure)(implicit ctx: Context, info: TransformerInfo): Tree =
    tree.tpt match {
      case EmptyTree =>
        // Type pattern instead of an unchecked cast: anything that is not a
        // method type is left alone rather than failing with a ClassCastException.
        tree.meth.tpe.widen match {
          case m: MethodType if shouldSpecialize(m) =>
            val interfaceName =
              (functionName ++ m.paramTypes.length.toString).specializedFor(m.resultType, m.paramTypes)

            // symbols loaded from classpath aren't defined in periods earlier than when they were loaded
            val interface =
              ctx.withPhase(ctx.typerPhase).getClassIfDefined(functionPackage ++ interfaceName)

            if (interface.exists) {
              val tpt = tpd.TypeTree(interface.asType.typeRef)
              tpd.Closure(tree.env, tree.meth, tpt)
            } else tree

          case _ => tree
        }

      // The closure already has an explicit target type; keep it.
      case _ => tree
    }
}