From ee5721e864de6fff2d54b9fb5452123bcca82483 Mon Sep 17 00:00:00 2001 From: Adriaan Moors Date: Mon, 23 Jul 2012 14:14:24 +0200 Subject: SI-6111 accept single-subpattern unapply pattern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit An extractor pattern `X(p)` should type check for any `X.unapply`/`X.unapplySeq` that returns an `Option[_]` -- previously we were confused about the case where it was an `Option[(T1, ... , Tn)]`. In this case, the expected type for the pattern `p` is simply `(T1, ... , Tn)`. While I was at it, tried to clean up unapplyTypeList and friends (by replacing them by extractorFormalTypes). From the spec: 8.1.8 ExtractorPatterns An extractor pattern x(p1, ..., pn) where n ≥ 0 is of the same syntactic form as a constructor pattern. However, instead of a case class, the stable identifier x denotes an object which has a member method named unapply or unapplySeq that matches the pattern. An unapply method in an object x matches the pattern x(p1, ..., pn) if it takes exactly one argument and one of the following applies: n = 0 and unapply’s result type is Boolean. n = 1 and unapply’s result type is Option[T], for some type T. the (only) argument pattern p1 is typed in turn with expected type T n > 1 and unapply’s result type is Option[(T1, ..., Tn)], for some types T1, ..., Tn. the argument patterns p1, ..., pn are typed in turn with expected types T1, ..., Tn --- .../scala/tools/nsc/matching/Patterns.scala | 2 +- .../scala/tools/nsc/transform/UnCurry.scala | 2 +- .../scala/tools/nsc/typechecker/Infer.scala | 64 ++++++++++++++++++++++ .../tools/nsc/typechecker/PatternMatching.scala | 5 -- .../scala/tools/nsc/typechecker/Typers.scala | 12 ++-- .../scala/tools/nsc/typechecker/Unapplies.scala | 59 +++----------------- .../scala/reflect/internal/Definitions.scala | 10 ++++ 7 files changed, 92 insertions(+), 62 deletions(-) (limited to 'src') diff --git a/src/compiler/scala/tools/nsc/matching/Patterns.scala b/src/compiler/scala/tools/nsc/matching/Patterns.scala index bbe22ca314..28dfd3fc77 100644 --- a/src/compiler/scala/tools/nsc/matching/Patterns.scala +++ b/src/compiler/scala/tools/nsc/matching/Patterns.scala @@ -402,7 +402,7 @@ trait Patterns extends ast.TreeDSL { case _ => toPats(args) } - def resTypes = analyzer.unapplyTypeList(unfn.symbol, unfn.tpe) + def resTypes = analyzer.unapplyTypeList(unfn.symbol, unfn.tpe, args.length) def resTypesString = resTypes match { case Nil => "Boolean" case xs => xs.mkString(", ") diff --git a/src/compiler/scala/tools/nsc/transform/UnCurry.scala b/src/compiler/scala/tools/nsc/transform/UnCurry.scala index 2983c65e78..5c0207e5c7 100644 --- a/src/compiler/scala/tools/nsc/transform/UnCurry.scala +++ b/src/compiler/scala/tools/nsc/transform/UnCurry.scala @@ -618,7 +618,7 @@ abstract class UnCurry extends InfoTransform val fn1 = withInPattern(false)(transform(fn)) val args1 = transformTrees(fn.symbol.name match { case nme.unapply => args - case nme.unapplySeq => transformArgs(tree.pos, fn.symbol, args, analyzer.unapplyTypeListFromReturnTypeSeq(fn.tpe)) + case nme.unapplySeq => transformArgs(tree.pos, fn.symbol, args, analyzer.unapplyTypeList(fn.symbol, fn.tpe, args.length)) case _ => sys.error("internal error: UnApply node has wrong symbol") }) treeCopy.UnApply(tree, fn1, args1) diff --git a/src/compiler/scala/tools/nsc/typechecker/Infer.scala b/src/compiler/scala/tools/nsc/typechecker/Infer.scala index 960c210649..4340137335 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Infer.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Infer.scala @@ -49,6 +49,70 @@ trait Infer { } else formals1 } + /** Returns `(formals, formalsExpanded)` where `formalsExpanded` are the expected types + * for the `nbSubPats` sub-patterns of an extractor pattern, of which the corresponding + * unapply[Seq] call is assumed to have result type `resTp`. + * + * `formals` are the formal types before expanding a potential repeated parameter (must come last in `formals`, if at all) + * + * @throws TypeError when the unapply[Seq] definition is ill-typed + * @returns (null, null) when the expected number of sub-patterns cannot be satisfied by the given extractor + * + * From the spec: + * 8.1.8 ExtractorPatterns + * + * An extractor pattern x(p1, ..., pn) where n ≥ 0 is of the same syntactic form as a constructor pattern. + * However, instead of a case class, the stable identifier x denotes an object which has a member method named unapply or unapplySeq that matches the pattern. + * An unapply method in an object x matches the pattern x(p1, ..., pn) if it takes exactly one argument and one of the following applies: + * + * n = 0 and unapply’s result type is Boolean. + * + * n = 1 and unapply’s result type is Option[T], for some type T. + * the (only) argument pattern p1 is typed in turn with expected type T + * + * n > 1 and unapply’s result type is Option[(T1, ..., Tn)], for some types T1, ..., Tn. + * the argument patterns p1, ..., pn are typed in turn with expected types T1, ..., Tn + */ + def extractorFormalTypes(resTp: Type, nbSubPats: Int, unappSym: Symbol): (List[Type], List[Type]) = { + val isUnapplySeq = unappSym.name == nme.unapplySeq + val booleanExtractor = resTp.typeSymbolDirect == BooleanClass + + @inline def seqToRepeatedChecked(tp: Type) = { + val toRepeated = seqToRepeated(tp) + if (tp eq toRepeated) throw new TypeError("(the last tuple-component of) the result type of an unapplySeq must be a Seq[_]") + else toRepeated + } + + val formals = + if (nbSubPats == 0 && booleanExtractor && !isUnapplySeq) Nil + else resTp.baseType(OptionClass).typeArgs match { + case optionTArg :: Nil => + if (nbSubPats == 1) + if (isUnapplySeq) List(seqToRepeatedChecked(optionTArg)) + else List(optionTArg) + // in principle, the spec doesn't allow just any subtype of Product, it *must* be TupleN[...] -- see run/virtpatmat_extends_product.scala + // should check `isTupleType(optionTArg)` -- this breaks plenty of stuff, though... + else getProductArgs(optionTArg) match { + case Nil if isUnapplySeq => List(seqToRepeatedChecked(optionTArg)) + case tps if isUnapplySeq => tps.init :+ seqToRepeatedChecked(tps.last) + case tps => tps + } + case _ => + if (isUnapplySeq) + throw new TypeError(s"result type $resTp of unapplySeq defined in ${unappSym.owner+unappSym.owner.locationString} not in {Option[_], Some[_]}") + else + throw new TypeError(s"result type $resTp of unapply defined in ${unappSym.owner+unappSym.owner.locationString} not in {Boolean, Option[_], Some[_]}") + } + + // for unapplySeq, replace last vararg by as many instances as required by nbSubPats + val formalsExpanded = + if (isUnapplySeq && formals.nonEmpty) formalTypes(formals, nbSubPats) + else formals + + if (formalsExpanded.lengthCompare(nbSubPats) != 0) (null, null) + else (formals, formalsExpanded) + } + def actualTypes(actuals: List[Type], nformals: Int): List[Type] = if (nformals == 1 && !hasLength(actuals, 1)) List(if (actuals.isEmpty) UnitClass.tpe else tupleType(actuals)) diff --git a/src/compiler/scala/tools/nsc/typechecker/PatternMatching.scala b/src/compiler/scala/tools/nsc/typechecker/PatternMatching.scala index 43edad3576..8b7c70c048 100644 --- a/src/compiler/scala/tools/nsc/typechecker/PatternMatching.scala +++ b/src/compiler/scala/tools/nsc/typechecker/PatternMatching.scala @@ -218,11 +218,6 @@ trait PatternMatching extends Transform with TypingTransformers with ast.TreeDSL if(phase.id >= currentRun.uncurryPhase.id) debugwarn("running translateMatch at "+ phase +" on "+ selector +" match "+ cases) patmatDebug("translating "+ cases.mkString("{", "\n", "}")) - def repeatedToSeq(tp: Type): Type = (tp baseType RepeatedParamClass) match { - case TypeRef(_, RepeatedParamClass, arg :: Nil) => seqType(arg) - case _ => tp - } - val start = Statistics.startTimer(patmatNanos) val selectorTp = repeatedToSeq(elimAnonymousClass(selector.tpe.widen.withoutAnnotations)) diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index dbe65c16d8..8361ea9586 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -2749,6 +2749,7 @@ trait Typers extends Modes with Adaptations with Tags { def typedArgs(args: List[Tree], mode: Int) = args mapConserve (arg => typedArg(arg, mode, 0, WildcardType)) + // [adriaan] as far as I can tell, formals0 is only supplied to detect whether the last formal was originally a vararg def typedArgs(args0: List[Tree], mode: Int, formals0: List[Type], adapted0: List[Type]): List[Tree] = { val sticky = onlyStickyModes(mode) def loop(args: List[Tree], formals: List[Type], adapted: List[Type]): List[Tree] = { @@ -3157,12 +3158,13 @@ trait Typers extends Modes with Adaptations with Tags { if (fun1.tpe.isErroneous) duplErrTree else { - val formals0 = unapplyTypeList(fun1.symbol, fun1.tpe) - val formals1 = formalTypes(formals0, args.length) + val resTp = fun1.tpe.finalResultType.normalize + val nbSubPats = args.length - if (!sameLength(formals1, args)) duplErrorTree(WrongNumberArgsPatternError(tree, fun)) + val (formals, formalsExpanded) = extractorFormalTypes(resTp, nbSubPats, fun1.symbol) + if (formals == null) duplErrorTree(WrongNumberArgsPatternError(tree, fun)) else { - val args1 = typedArgs(args, mode, formals0, formals1) + val args1 = typedArgs(args, mode, formals, formalsExpanded) // This used to be the following (failing) assert: // assert(isFullyDefined(pt), tree+" ==> "+UnApply(fun1, args1)+", pt = "+pt) // I modified as follows. See SI-1048. @@ -4880,7 +4882,7 @@ trait Typers extends Modes with Adaptations with Tags { case UnApply(fun, args) => val fun1 = typed(fun) - val tpes = formalTypes(unapplyTypeList(fun.symbol, fun1.tpe), args.length) + val tpes = formalTypes(unapplyTypeList(fun.symbol, fun1.tpe, args.length), args.length) val args1 = map2(args, tpes)(typedPattern) treeCopy.UnApply(tree, fun1, args1) setType pt diff --git a/src/compiler/scala/tools/nsc/typechecker/Unapplies.scala b/src/compiler/scala/tools/nsc/typechecker/Unapplies.scala index ad936ac39d..d508e10813 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Unapplies.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Unapplies.scala @@ -31,59 +31,18 @@ trait Unapplies extends ast.TreeDSL // moduleClass symbol of the companion module. class ClassForCaseCompanionAttachment(val caseClass: ClassDef) - /** returns type list for return type of the extraction */ - def unapplyTypeList(ufn: Symbol, ufntpe: Type) = { + /** returns type list for return type of the extraction + * @see extractorFormalTypes + */ + def unapplyTypeList(ufn: Symbol, ufntpe: Type, nbSubPats: Int) = { assert(ufn.isMethod, ufn) //Console.println("utl "+ufntpe+" "+ufntpe.typeSymbol) ufn.name match { - case nme.unapply => unapplyTypeListFromReturnType(ufntpe) - case nme.unapplySeq => unapplyTypeListFromReturnTypeSeq(ufntpe) - case _ => throw new TypeError(ufn+" is not an unapply or unapplySeq") - } - } - /** (the inverse of unapplyReturnTypeSeq) - * for type Boolean, returns Nil - * for type Option[T] or Some[T]: - * - returns T0...Tn if n>0 and T <: Product[T0...Tn]] - * - returns T otherwise - */ - def unapplyTypeListFromReturnType(tp1: Type): List[Type] = { - val tp = unapplyUnwrap(tp1) - tp.typeSymbol match { // unapplySeqResultToMethodSig - case BooleanClass => Nil - case OptionClass | SomeClass => - val prod = tp.typeArgs.head -// the spec doesn't allow just any subtype of Product, it *must* be TupleN[...] -- see run/virtpatmat_extends_product.scala -// this breaks plenty of stuff, though... -// val targs = -// if (isTupleType(prod)) getProductArgs(prod) -// else List(prod) - val targs = getProductArgs(prod) - - if (targs.isEmpty || targs.tail.isEmpty) List(prod) // special n == 0 || n == 1 - else targs // n > 1 - case _ => - throw new TypeError("result type "+tp+" of unapply not in {Boolean, Option[_], Some[_]}") - } - } - - /** let type be the result type of the (possibly polymorphic) unapply method - * for type Option[T] or Some[T] - * -returns T0...Tn-1,Tn* if n>0 and T <: Product[T0...Tn-1,Seq[Tn]]], - * -returns R* if T = Seq[R] - */ - def unapplyTypeListFromReturnTypeSeq(tp1: Type): List[Type] = { - val tp = unapplyUnwrap(tp1) - tp.typeSymbol match { - case OptionClass | SomeClass => - val ts = unapplyTypeListFromReturnType(tp1) - val last1 = (ts.last baseType SeqClass) match { - case TypeRef(pre, SeqClass, args) => typeRef(pre, RepeatedParamClass, args) - case _ => throw new TypeError("last not seq") - } - ts.init :+ last1 - case _ => - throw new TypeError("result type "+tp+" of unapply not in {Option[_], Some[_]}") + case nme.unapply | nme.unapplySeq => + val (formals, _) = extractorFormalTypes(unapplyUnwrap(ufntpe), nbSubPats, ufn) + if (formals == null) throw new TypeError(s"$ufn of type $ufntpe cannot extract $nbSubPats sub-patterns") + else formals + case _ => throw new TypeError(ufn+" is not an unapply or unapplySeq") } } diff --git a/src/reflect/scala/reflect/internal/Definitions.scala b/src/reflect/scala/reflect/internal/Definitions.scala index d9b63529eb..90aa0b732c 100644 --- a/src/reflect/scala/reflect/internal/Definitions.scala +++ b/src/reflect/scala/reflect/internal/Definitions.scala @@ -397,6 +397,16 @@ trait Definitions extends api.StandardDefinitions { case _ => false } + def repeatedToSeq(tp: Type): Type = (tp baseType RepeatedParamClass) match { + case TypeRef(_, RepeatedParamClass, arg :: Nil) => seqType(arg) + case _ => tp + } + + def seqToRepeated(tp: Type): Type = (tp baseType SeqClass) match { + case TypeRef(_, SeqClass, arg :: Nil) => scalaRepeatedType(arg) + case _ => tp + } + def isPrimitiveArray(tp: Type) = tp match { case TypeRef(_, ArrayClass, arg :: Nil) => isPrimitiveValueClass(arg.typeSymbol) case _ => false -- cgit v1.2.3 From 98a5f06e3841ac988a819a4928ffd827efec221e Mon Sep 17 00:00:00 2001 From: Adriaan Moors Date: Tue, 24 Jul 2012 09:48:26 +0200 Subject: docs related to fix for SI-6111 --- src/compiler/scala/tools/nsc/typechecker/Infer.scala | 3 +-- src/compiler/scala/tools/nsc/typechecker/Typers.scala | 9 ++++++++- 2 files changed, 9 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/compiler/scala/tools/nsc/typechecker/Infer.scala b/src/compiler/scala/tools/nsc/typechecker/Infer.scala index 4340137335..291b7f1827 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Infer.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Infer.scala @@ -90,8 +90,7 @@ trait Infer { if (nbSubPats == 1) if (isUnapplySeq) List(seqToRepeatedChecked(optionTArg)) else List(optionTArg) - // in principle, the spec doesn't allow just any subtype of Product, it *must* be TupleN[...] -- see run/virtpatmat_extends_product.scala - // should check `isTupleType(optionTArg)` -- this breaks plenty of stuff, though... + // TODO: update spec to reflect we allow any ProductN, not just TupleN else getProductArgs(optionTArg) match { case Nil if isUnapplySeq => List(seqToRepeatedChecked(optionTArg)) case tps if isUnapplySeq => tps.init :+ seqToRepeatedChecked(tps.last) diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index 8361ea9586..80e7d0d474 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -2749,7 +2749,14 @@ trait Typers extends Modes with Adaptations with Tags { def typedArgs(args: List[Tree], mode: Int) = args mapConserve (arg => typedArg(arg, mode, 0, WildcardType)) - // [adriaan] as far as I can tell, formals0 is only supplied to detect whether the last formal was originally a vararg + /** Type trees in `args0` against corresponding expected type in `adapted0`. + * + * The mode in which each argument is typed is derived from `mode` and + * whether the arg was originally by-name or var-arg (need `formals0` for that) + * the default is by-val, of course. + * + * (docs reverse-engineered -- AM) + */ def typedArgs(args0: List[Tree], mode: Int, formals0: List[Type], adapted0: List[Type]): List[Tree] = { val sticky = onlyStickyModes(mode) def loop(args: List[Tree], formals: List[Type], adapted: List[Type]): List[Tree] = { -- cgit v1.2.3