From 3c8196300d65738d6779ba8703e2a86ee3390ec7 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Sat, 28 Dec 2013 21:46:05 +0100 Subject: New version of eta-expansion. This version expands a method ref p.m to the untyped tree p.m(_, ..., _) (after lifting impure expressions from p). Afterwards the usual application mechanisms kick in. This fixes problems also present in Scala 2.x, where an eta-expanded function was not as flexible as an explicitly expanded one (for instance, eta expansion did not honor default parameters). --- src/dotty/tools/dotc/core/Definitions.scala | 4 +- src/dotty/tools/dotc/typer/Applications.scala | 4 ++ src/dotty/tools/dotc/typer/EtaExpansion.scala | 36 ++++++----- src/dotty/tools/dotc/typer/Typer.scala | 87 +++++++++++++++++++++------ 4 files changed, 93 insertions(+), 38 deletions(-) (limited to 'src/dotty/tools') diff --git a/src/dotty/tools/dotc/core/Definitions.scala b/src/dotty/tools/dotc/core/Definitions.scala index e3915b861..7a30be9d3 100644 --- a/src/dotty/tools/dotc/core/Definitions.scala +++ b/src/dotty/tools/dotc/core/Definitions.scala @@ -322,10 +322,12 @@ class Definitions(implicit ctx: Context) { (tp derivesFrom ProductClass) && tp.baseClasses.exists(ProductClasses contains _) def isFunctionType(tp: Type) = { - val arity = tp.dealias.typeArgs.length - 1 + val arity = functionArity(tp) 0 <= arity && arity <= MaxFunctionArity && (tp isRef FunctionClass(arity)) } + def functionArity(tp: Type) = tp.dealias.typeArgs.length - 1 + // ----- Higher kinds machinery ------------------------------------------ private var _hkTraits: Set[Symbol] = Set() diff --git a/src/dotty/tools/dotc/typer/Applications.scala b/src/dotty/tools/dotc/typer/Applications.scala index 27e142e55..aa2227d63 100644 --- a/src/dotty/tools/dotc/typer/Applications.scala +++ b/src/dotty/tools/dotc/typer/Applications.scala @@ -26,8 +26,12 @@ import reflect.ClassTag import language.implicitConversions object Applications { + import tpd._ private val isNamedArg = (arg: Any) => arg.isInstanceOf[Trees.NamedArg[_]] def hasNamedArg(args: List[Any]) = args exists isNamedArg + + def wrapDefs(defs: mutable.ListBuffer[Tree], tree: Tree)(implicit ctx: Context): Tree = + if (defs != null && defs.nonEmpty) tpd.Block(defs.toList, tree) else tree } import Applications._ diff --git a/src/dotty/tools/dotc/typer/EtaExpansion.scala b/src/dotty/tools/dotc/typer/EtaExpansion.scala index 16a02ba5c..46a1d3583 100644 --- a/src/dotty/tools/dotc/typer/EtaExpansion.scala +++ b/src/dotty/tools/dotc/typer/EtaExpansion.scala @@ -94,33 +94,31 @@ object EtaExpansion { } /** Eta-expanding a tree means converting a method reference to a function value. - * @param tree The tree to expand - * @param wtp The widened type of the tree, which is always a MethodType - * Let `wtp` be the method type - * - * (x1: T1, ..., xn: Tn): R - * + * @param tree The tree to expand + * @param paramNames The names of the parameters to use in the expansion + * Let `paramNames` be x1, ..., xn * and assume the lifted application of `tree` (@see liftApp) is * * { val xs = es; expr } * * Then the eta-expansion is * - * { val xs = es; - * { def $anonfun(x1: T1, ..., xn: Tn): T = expr; Closure($anonfun) }} + * { val xs = es; (x1, ..., xn) => expr(xx1, ..., xn) } + * + * This is an untyped tree, with `es` and `expr` as typed splices. */ - def etaExpand(tree: Tree, tpe: MethodType)(implicit ctx: Context): Tree = { - def expand(lifted: Tree): Tree = { - val meth = ctx.newSymbol(ctx.owner, nme.ANON_FUN, Synthetic, tpe, coord = tree.pos) - Closure(meth, argss => (lifted /: argss)(Apply(_, _))) - } - val defs = new mutable.ListBuffer[Tree] - val lifted = liftApp(defs, tree) - wrapDefs(defs, expand(lifted)) + def etaExpand(tree: Tree, paramNames: List[TermName])(implicit ctx: Context): untpd.Tree = { + import untpd._ + val defs = new mutable.ListBuffer[tpd.Tree] + val lifted: Tree = TypedSplice(liftApp(defs, tree)) + val params = paramNames map (name => + ValDef(Modifiers(Synthetic | Param), name, TypeTree(), EmptyTree).withPos(tree.pos)) + val ids = paramNames map (name => + Ident(name).withPos(tree.pos)) + val body = Apply(lifted, ids) + val fn = untpd.Function(params, body) + if (defs.nonEmpty) untpd.Block(defs.toList map untpd.TypedSplice, fn) else fn } - - def wrapDefs(defs: mutable.ListBuffer[Tree], tree: Tree)(implicit ctx: Context): Tree = - if (defs != null && defs.nonEmpty) tpd.Block(defs.toList, tree) else tree } /**

not needed diff --git a/src/dotty/tools/dotc/typer/Typer.scala b/src/dotty/tools/dotc/typer/Typer.scala index 0b3b764b5..85724b1ee 100644 --- a/src/dotty/tools/dotc/typer/Typer.scala +++ b/src/dotty/tools/dotc/typer/Typer.scala @@ -502,34 +502,82 @@ class Typer extends Namer with Applications with Implicits { typed(cpy.AppliedTypeTree(tree, untpd.TypeTree(defn.FunctionClass(args.length).typeRef), args :+ body), pt) else { - val params = args.asInstanceOf[List[ValDef]] + val params = args.asInstanceOf[List[untpd.ValDef]] val protoFormals: List[Type] = pt match { case _ if pt isRef defn.FunctionClass(params.length) => pt.typeArgs take params.length case SAMType(meth) => - // println(s"SAMType $pt") val MethodType(_, paramTypes) = meth.info paramTypes case _ => - // println(s"Neither fucntion nor SAM type $pt") params map alwaysWildcardType } + + def refersTo(arg: untpd.Tree, param: untpd.ValDef): Boolean = arg match { + case Ident(name) => name == param.name + case _ => false + } + + /** The funcion body to be returned in the closure. Can become a TypedSplice + * of a typed expression if this is necessary to infer a parameter type. + */ + var fnBody = tree.body + + + /** If function is of the form + * (x1, ..., xN) => f(x1, ..., XN) + * the type of `f`, otherwise NoType. (updates `fnBody` as a side effect). + */ + def calleeType: Type = fnBody match { + case Apply(expr, args) if (args corresponds params)(refersTo) => + expr match { + case untpd.TypedSplice(expr1) => + expr1.tpe + case _ => + val protoArgs = args map (_ withType WildcardType) + val callProto = FunProto(protoArgs, WildcardType, this) + val expr1 = typedExpr(expr, callProto) + fnBody = cpy.Apply(fnBody, untpd.TypedSplice(expr1), args) + expr1.tpe + } + case _ => + NoType + } + + /** Two attempts: First, if expected type is fully defined pick this one. + * Second, if function is of the form + * (x1, ..., xN) => f(x1, ..., XN) + * and f has a method type MT, pick the corresponding parameter type in MT, + * if this one is fully defined. + * If both attempts fail, issue a "missing parameter type" error. + */ + def inferredParamType(param: untpd.ValDef, formal: Type): Type = { + if (isFullyDefined(formal, ForceDegree.noBottom)) return formal + calleeType.widen match { + case mtpe: MethodType => + val pos = params indexWhere (_.name == param.name) + if (pos < mtpe.paramTypes.length) { + val ptype = mtpe.paramTypes(pos) + if (isFullyDefined(ptype, ForceDegree.none)) return ptype + } + case _ => + } + val ofFun = + if (nme.syntheticParamNames(args.length + 1) contains param.name) + s" of expanded function ${tree.show}" + else + "" + errorType(s"missing parameter type for parameter ${param.name}$ofFun, expected = ${pt.show}", param.pos) + } + val inferredParams: List[untpd.ValDef] = for ((param, formal) <- params zip protoFormals) yield if (!param.tpt.isEmpty) param else { - val paramType = - if (isFullyDefined(formal, ForceDegree.noBottom)) formal - else { - val ofFun = - if (nme.syntheticParamNames(args.length + 1) contains param.name) - s" of expanded function ${tree.show}" - else "" - errorType(s"missing parameter type for parameter ${param.name}$ofFun, expected = ${pt.show}", param.pos) - } - cpy.ValDef(param, param.mods, param.name, untpd.TypeTree(paramType), param.rhs) + val paramTpt = untpd.TypeTree(inferredParamType(param, formal)) + cpy.ValDef(param, param.mods, param.name, paramTpt, param.rhs) } - typed(desugar.makeClosure(inferredParams, body), pt) + typed(desugar.makeClosure(inferredParams, fnBody), pt) } } @@ -1063,7 +1111,7 @@ class Typer extends Namer with Applications with Implicits { } } - def adaptNoArgs(wtp: Type) = wtp match { + def adaptNoArgs(wtp: Type): Tree = wtp match { case wtp: ExprType => adaptInterpolated(tree.withType(wtp.resultType), pt) case wtp: ImplicitMethodType => @@ -1084,9 +1132,12 @@ class Typer extends Namer with Applications with Implicits { } adapt(tpd.Apply(tree, args), pt) case wtp: MethodType if !pt.isInstanceOf[SingletonType] => - if ((defn.isFunctionType(pt) || (pt eq AnyFunctionProto)) && - !tree.symbol.isConstructor) - etaExpand(tree, wtp) + val arity = + if (defn.isFunctionType(pt)) defn.functionArity(pt) + else if (pt eq AnyFunctionProto) wtp.paramTypes.length + else -1 + if (arity >= 0 && !tree.symbol.isConstructor) + typed(etaExpand(tree, wtp.paramNames take arity), pt) else if (wtp.paramTypes.isEmpty) adaptInterpolated(tpd.Apply(tree, Nil), pt) else -- cgit v1.2.3