diff options
author | Martin Odersky <odersky@gmail.com> | 2015-10-31 12:58:01 +0100 |
---|---|---|
committer | Martin Odersky <odersky@gmail.com> | 2016-02-16 15:23:43 +0100 |
commit | 29104c9755a9d6393959a416650422b84f0957f2 (patch) | |
tree | 19f844e71f9904c19714f4cab8c93fffc97db437 /src/dotty/tools | |
parent | 5e8023335e641c9c05c6517a82764571e7ef6386 (diff) | |
download | dotty-29104c9755a9d6393959a416650422b84f0957f2.tar.gz dotty-29104c9755a9d6393959a416650422b84f0957f2.tar.bz2 dotty-29104c9755a9d6393959a416650422b84f0957f2.zip |
Auto-uncurry n-ary functions.
Implements SIP #897.
Diffstat (limited to 'src/dotty/tools')
-rw-r--r-- | src/dotty/tools/dotc/ast/Desugar.scala | 19 | ||||
-rw-r--r-- | src/dotty/tools/dotc/typer/Typer.scala | 56 |
2 files changed, 56 insertions, 19 deletions
diff --git a/src/dotty/tools/dotc/ast/Desugar.scala b/src/dotty/tools/dotc/ast/Desugar.scala index 87694843a..c1083d26d 100644 --- a/src/dotty/tools/dotc/ast/Desugar.scala +++ b/src/dotty/tools/dotc/ast/Desugar.scala @@ -588,6 +588,25 @@ object desugar { Function(params, Match(selector, cases)) } + /** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows: + * + * x$1 => { + * val p1 = x$1._1 + * ... + * val pn = x$1._n + * body + * } + */ + def makeUnaryCaseLambda(params: List[ValDef], body: Tree)(implicit ctx: Context): Tree = { + val param = makeSyntheticParameter() + def selector(n: Int) = Select(refOfDef(param), nme.selectorName(n)) + val vdefs = + params.zipWithIndex.map{ + case(param, idx) => cpy.ValDef(param)(rhs = selector(idx)) + } + Function(param :: Nil, Block(vdefs, body)) + } + /** Add annotation with class `cls` to tree: * tree @cls */ diff --git a/src/dotty/tools/dotc/typer/Typer.scala b/src/dotty/tools/dotc/typer/Typer.scala index 6a2ff30fa..784702dba 100644 --- a/src/dotty/tools/dotc/typer/Typer.scala +++ b/src/dotty/tools/dotc/typer/Typer.scala @@ -611,26 +611,44 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit if (protoFormals.length == params.length) protoFormals(i) else errorType(i"wrong number of parameters, expected: ${protoFormals.length}", tree.pos) - val inferredParams: List[untpd.ValDef] = - for ((param, i) <- params.zipWithIndex) yield - if (!param.tpt.isEmpty) param - else cpy.ValDef(param)( - tpt = untpd.TypeTree( - inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false))) - - // Define result type of closure as the expected type, thereby pushing - // down any implicit searches. We do this even if the expected type is not fully - // defined, which is a bit of a hack. But it's needed to make the following work - // (see typers.scala and printers/PlainPrinter.scala for examples). - // - // def double(x: Char): String = s"$x$x" - // "abc" flatMap double - // - val resultTpt = protoResult match { - case WildcardType(_) => untpd.TypeTree() - case _ => untpd.TypeTree(protoResult) + /** Is `formal` a product type which is elementwise compatible with `params`? */ + def ptIsCorrectProduct(formal: Type) = { + val pclass = defn.ProductNClass(params.length) + isFullyDefined(formal, ForceDegree.noBottom) && + formal.derivesFrom(pclass) && + formal.baseArgTypes(pclass).corresponds(params) { + (argType, param) => + param.tpt.isEmpty || isCompatible(argType, typedAheadType(param.tpt).tpe) + } } - typed(desugar.makeClosure(inferredParams, fnBody, resultTpt), pt) + + val desugared = + if (protoFormals.length == 1 && params.length != 1 && ptIsCorrectProduct(protoFormals.head)) { + desugar.makeUnaryCaseLambda(params, fnBody) + } + else { + val inferredParams: List[untpd.ValDef] = + for ((param, i) <- params.zipWithIndex) yield + if (!param.tpt.isEmpty) param + else cpy.ValDef(param)( + tpt = untpd.TypeTree( + inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false))) + + // Define result type of closure as the expected type, thereby pushing + // down any implicit searches. We do this even if the expected type is not fully + // defined, which is a bit of a hack. But it's needed to make the following work + // (see typers.scala and printers/PlainPrinter.scala for examples). + // + // def double(x: Char): String = s"$x$x" + // "abc" flatMap double + // + val resultTpt = protoResult match { + case WildcardType(_) => untpd.TypeTree() + case _ => untpd.TypeTree(protoResult) + } + desugar.makeClosure(inferredParams, fnBody, resultTpt) + } + typed(desugared, pt) } } |