diff options
author | Paul Phillips <paulp@improving.org> | 2013-08-18 09:29:44 -0700 |
---|---|---|
committer | Paul Phillips <paulp@improving.org> | 2013-08-18 13:23:21 -0700 |
commit | 6d77da374e94ea8b80fc0bf9e544e11f4e9d5cc8 (patch) | |
tree | 672e9dfa0343526c77f68d45c7a4698a9dc832dc /src | |
parent | a905d0e7e49bf92f119b2fdcd2b9d15b71d64ca2 (diff) | |
download | scala-6d77da374e94ea8b80fc0bf9e544e11f4e9d5cc8.tar.gz scala-6d77da374e94ea8b80fc0bf9e544e11f4e9d5cc8.tar.bz2 scala-6d77da374e94ea8b80fc0bf9e544e11f4e9d5cc8.zip |
Refined name-based patmat methods.
This fleshes out some of the slightly unfinished corners
of the adventure, especially for unapplySeq. There's still
an unhealthy amount of duplication and a paucity of
specification, but I think it's in eminently good shape
for a milestone.
Diffstat (limited to 'src')
5 files changed, 102 insertions, 105 deletions
diff --git a/src/compiler/scala/tools/nsc/transform/patmat/MatchCodeGen.scala b/src/compiler/scala/tools/nsc/transform/patmat/MatchCodeGen.scala index 52055dea85..2bd14d923a 100644 --- a/src/compiler/scala/tools/nsc/transform/patmat/MatchCodeGen.scala +++ b/src/compiler/scala/tools/nsc/transform/patmat/MatchCodeGen.scala @@ -87,10 +87,8 @@ trait MatchCodeGen extends Interface { def _match(n: Name): SelectStart = matchStrategy DOT n // TODO: error message - private lazy val oneType = typer.typedOperator(_match(vpmName.one)).tpe - private def oneApplied(tp: Type): Type = appliedType(oneType, tp :: Nil) - override def pureType(tp: Type): Type = firstParamType(oneApplied(tp)) - override def mapResultType(prev: Type, elem: Type): Type = oneApplied(elem).finalResultType + private lazy val oneType = typer.typedOperator(_match(vpmName.one)).tpe + override def pureType(tp: Type): Type = firstParamType(appliedType(oneType, tp :: Nil)) } trait PureCodegen extends CodegenCore with PureMatchMonadInterface { diff --git a/src/compiler/scala/tools/nsc/transform/patmat/MatchTranslation.scala b/src/compiler/scala/tools/nsc/transform/patmat/MatchTranslation.scala index 282a78492a..d4bbef740c 100644 --- a/src/compiler/scala/tools/nsc/transform/patmat/MatchTranslation.scala +++ b/src/compiler/scala/tools/nsc/transform/patmat/MatchTranslation.scala @@ -20,6 +20,7 @@ trait MatchTranslation { import definitions._ import global.analyzer.{ErrorUtils, formalTypes} import treeInfo.{ WildcardStarArg, Unapplied, isStar, unbind } + import CODE._ // Always map repeated params to sequences private def setVarInfo(sym: Symbol, info: Type) = @@ -252,7 +253,7 @@ trait MatchTranslation { CaseDef( Bind(exSym, Ident(nme.WILDCARD)), // TODO: does this need fixing upping? EmptyTree, - combineCasesNoSubstOnly(CODE.REF(exSym), scrutSym, casesNoSubstOnly, pt, matchOwner, Some(scrut => Throw(CODE.REF(exSym)))) + combineCasesNoSubstOnly(REF(exSym), scrutSym, casesNoSubstOnly, pt, matchOwner, Some(scrut => Throw(REF(exSym)))) ) }) } @@ -364,8 +365,6 @@ trait MatchTranslation { } abstract class ExtractorCall { - import CODE._ - def fun: Tree def args: List[Tree] @@ -380,12 +379,16 @@ trait MatchTranslation { private def hasStar = nbSubPats > 0 && isStar(args.last) private def isNonEmptySeq = nbSubPats > 0 && isSeq - def isSingle = nbSubPats == 0 && !isSeq + /** This is special cased so that a single pattern will accept any extractor + * result, even if it's a tuple (SI-6675) + */ + def isSingle = nbSubPats == 1 && !isSeq // to which type should the previous binder be casted? def paramType : Type protected def rawSubPatTypes: List[Type] + protected def resultType: Type /** Create the TreeMaker that embodies this extractor call * @@ -412,16 +415,16 @@ trait MatchTranslation { def subPatTypes: List[Type] = ( if (rawSubPatTypes.isEmpty || !isSeq) rawSubPatTypes - else if (hasStar) nonStarSubPatTypes :+ rawLast + else if (hasStar) nonStarSubPatTypes :+ sequenceType else nonStarSubPatTypes ) + private def rawGet = typeOfMemberNamedGetOrSelf(resultType) private def emptySub = rawSubPatTypes.isEmpty - private def rawLast = if (emptySub) NothingTpe else rawSubPatTypes.last private def rawInit = rawSubPatTypes dropRight 1 - protected def sequenceType = if (emptySub) NothingTpe else rawLast - protected def elementType = if (emptySub) NothingTpe else unapplySeqElementType(rawLast) - protected def repeatedType = if (emptySub) NothingTpe else scalaRepeatedType(elementType) + protected def sequenceType = typeOfLastSelectorOrSelf(rawGet) + protected def elementType = elementTypeOfLastSelectorOrSelf(rawGet) + protected def repeatedType = scalaRepeatedType(elementType) // rawSubPatTypes.last is the Seq, thus there are `rawSubPatTypes.length - 1` non-seq elements in the tuple protected def firstIndexingBinder = rawSubPatTypes.length - 1 @@ -508,6 +511,7 @@ trait MatchTranslation { // to which type should the previous binder be casted? def paramType = constructorTp.finalResultType + def resultType = fun.tpe.finalResultType def isSeq = isVarArgTypes(rawSubPatTypes) @@ -536,7 +540,7 @@ trait MatchTranslation { } // reference the (i-1)th case accessor if it exists, otherwise the (i-1)th tuple component - override protected def tupleSel(binder: Symbol)(i: Int): Tree = { import CODE._ + override protected def tupleSel(binder: Symbol)(i: Int): Tree = { val accessors = binder.caseFieldAccessors if (accessors isDefinedAt (i-1)) REF(binder) DOT accessors(i-1) else codegen.tupleSel(binder)(i) // this won't type check for case classes, as they do not inherit ProductN @@ -550,7 +554,7 @@ trait MatchTranslation { def tpe = fun.tpe def paramType = firstParamType(tpe) - def resultType = fun.tpe.finalResultType + def resultType = tpe.finalResultType def isTyped = (tpe ne NoType) && fun.isTyped && (resultInMonad ne ErrorType) def isSeq = fun.symbol.name == nme.unapplySeq def isBool = resultType =:= BooleanTpe @@ -585,20 +589,20 @@ trait MatchTranslation { } override protected def seqTree(binder: Symbol): Tree = - if (firstIndexingBinder == 0) CODE.REF(binder) + if (firstIndexingBinder == 0) REF(binder) else super.seqTree(binder) // the trees that select the subpatterns on the extractor's result, referenced by `binder` // require (nbSubPats > 0 && (!lastIsStar || isSeq)) override protected def subPatRefs(binder: Symbol): List[Tree] = - if (!isSeq && nbSubPats == 1) List(CODE.REF(binder)) // special case for extractors + if (isSingle) REF(binder) :: Nil // special case for extractors else super.subPatRefs(binder) protected def spliceApply(binder: Symbol): Tree = { object splice extends Transformer { override def transform(t: Tree) = t match { case Apply(x, List(i @ Ident(nme.SELECTOR_DUMMY))) => - treeCopy.Apply(t, x, List(CODE.REF(binder) setPos i.pos)) + treeCopy.Apply(t, x, (REF(binder) setPos i.pos) :: Nil) case _ => super.transform(t) } @@ -606,20 +610,16 @@ trait MatchTranslation { splice transform extractorCallIncludingDummy } - // what's the extractor's result type in the monad? - // turn an extractor's result type into something `monadTypeToSubPatTypesAndRefs` understands - protected lazy val resultInMonad: Type = if (isBool) UnitTpe else matchMonadResult(resultType) // the type of "get" + // what's the extractor's result type in the monad? It is the type of its nullary member `get`. + protected lazy val resultInMonad: Type = if (isBool) UnitTpe else typeOfMemberNamedGet(resultType) protected lazy val rawSubPatTypes = ( if (isBool) Nil - else if (!isSeq && nbSubPats == 1) resultInMonad :: Nil - else getNameBasedProductSelectorTypes(resultInMonad) match { - case Nil => resultInMonad :: Nil - case x => x - } + else if (isSingle) resultInMonad :: Nil // don't go looking for selectors if we only expect one pattern + else typesOfSelectorsOrSelf(resultInMonad) ) - override def toString() = s"ExtractorCallRegular($fun:${fun.tpe} / ${fun.symbol})" + override def toString() = s"ExtractorCallRegular($fun: $tpe / ${fun.symbol})" } /** A conservative approximation of which patterns do not discern anything. diff --git a/src/compiler/scala/tools/nsc/transform/patmat/PatternMatching.scala b/src/compiler/scala/tools/nsc/transform/patmat/PatternMatching.scala index 21666ed0ec..a4944caa2b 100644 --- a/src/compiler/scala/tools/nsc/transform/patmat/PatternMatching.scala +++ b/src/compiler/scala/tools/nsc/transform/patmat/PatternMatching.scala @@ -175,40 +175,8 @@ trait Interface extends ast.TreeDSL { val matchOwner = typer.context.owner def pureType(tp: Type): Type = tp - // Extracting from the monad: tp == Option[T], result == T - def matchMonadResult(tp: Type) = definitions typeOfMemberNamedGet tp - - // prev == CC[T] - // elem == U - // result == CC[U] - // where "CC" here is Option or any other single-type-parameter container - // - // TODO - what if it has multiple type parameters? - // If we have access to the zero, maybe we can infer the - // type parameter by contrasting with the zero's application. - def mapResultType(prev: Type, elem: Type): Type = { - // default to Option[U] if we can't reliably infer the types - def fallback(elem: Type): Type = elem match { - case TypeRef(_, sym, _) if sym.isTypeParameterOrSkolem => fallback(sym.info.bounds.hi) - case _ => optionType(elem) - } - - // optionType(elem) //pack(elem)) - // The type of "get" in CC[T] is what settles what was wrapped. - val prevElem = matchMonadResult(prev) - if (prevElem =:= elem) prev - else prev.typeArgs match { - case targ :: Nil if targ =:= prevElem => - // the type of "get" in the result should be elem. - // If not, the type arguments are doing something nonobvious - // so fall back on Option. - val result = appliedType(prev.typeConstructor, elem :: Nil) - val newElem = matchMonadResult(result) - if (elem =:= newElem) result else fallback(newElem) - case _ => - fallback(AnyTpe) - } - } + // Extracting from the monad: tp == { def get: T }, result == T + def matchMonadResult(tp: Type) = typeOfMemberNamedGet(tp) def reportUnreachable(pos: Position) = typer.context.unit.warning(pos, "unreachable code") def reportMissingCases(pos: Position, counterExamples: List[String]) = { diff --git a/src/compiler/scala/tools/nsc/typechecker/PatternTypers.scala b/src/compiler/scala/tools/nsc/typechecker/PatternTypers.scala index cc0ffe2ac2..7120aeaaa6 100644 --- a/src/compiler/scala/tools/nsc/typechecker/PatternTypers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/PatternTypers.scala @@ -190,13 +190,12 @@ trait PatternTypers { def resultType = tpe.finalResultType def method = unapplyMember(tpe) def paramType = firstParamType(unapplyType) - def rawGet = if (isBool) UnitTpe else resultOfMatchingMethod(resultType, "get")() - def rawTypes = if (isBool) Nil else if (rawProduct.isEmpty) rawGet :: Nil else rawProduct + def rawGet = if (isBool) UnitTpe else typeOfMemberNamedGetOrSelf(resultType) + def rawTypes = if (isBool) Nil else typesOfSelectorsOrSelf(rawGet) def rawArity = rawTypes.size def isBool = resultType =:= BooleanTpe // aka "Tuple0" or "Option[Unit]" + def isNothing = rawGet =:= NothingTpe def isCase = method.isCase - - private def rawProduct = getNameBasedProductSelectorTypes(rawGet) } object NoUnapplyMethodInfo extends UnapplyMethodInfo(NoSymbol, NoType) { @@ -213,7 +212,7 @@ trait PatternTypers { case _ => NoCaseClassInfo } val exInfo = UnapplyMethodInfo(symbol, tpe) - import exInfo.{ rawTypes, isUnapplySeq, rawGet } + import exInfo.{ rawGet, rawTypes, isUnapplySeq } override def toString = s"ExtractorShape($fun, $args)" @@ -223,15 +222,23 @@ trait PatternTypers { def caseClass = ccInfo.clazz def enclClass = symbol.enclClass + // TODO - merge these. The difference between these two methods is that expectedPatternTypes + // expands the list of types so it is the same length as the number of patterns, whereas formals + // leaves the varargs type unexpanded. def formals = ( if (isUnapplySeq) productTypes :+ varargsType else if (elementArity == 0) productTypes - else if (patternFixedArity == 1) squishIntoOne() + else if (isSingle) squishIntoOne() else wrongArity(patternFixedArity) ) + def expectedPatternTypes = elementArity match { + case 0 => productTypes + case _ if elementArity > 0 && isUnapplySeq => productTypes ::: elementTypes + case _ if productArity > 1 && patternFixedArity == 1 => squishIntoOne() + case _ => wrongArity(patternFixedArity) + } - def rawLast = if (rawTypes.isEmpty) rawGet else rawTypes.last - def elementType = unapplySeqElementType(rawLast) + def elementType = elementTypeOfLastSelectorOrSelf(rawGet) private def hasBogusExtractor = directUnapplyMember(tpe).exists && !unapplyMethod.exists private def expectedArity = "" + productArity + ( if (isUnapplySeq) "+" else "") @@ -254,20 +261,21 @@ trait PatternTypers { rawGet :: Nil } + // elementArity is the number of non-sequence patterns minus the + // the number of non-sequence product elements returned by the extractor. + // If it is zero, there is a perfect match between those parts, and + // if there is a wildcard star it will match any sequence. + // If it is positive, there are more patterns than products, + // so a sequence will have to fill in the elements. If it is negative, + // there are more products than patterns, which is a compile time error. + def elementArity = patternFixedArity - productArity def patternFixedArity = treeInfo effectivePatternArity args def productArity = productTypes.size - def elementArity = patternFixedArity - productArity + def isSingle = !isUnapplySeq && (patternFixedArity == 1) def productTypes = if (isUnapplySeq) rawTypes dropRight 1 else rawTypes def elementTypes = List.fill(elementArity)(elementType) def varargsType = scalaRepeatedType(elementType) - - def expectedPatternTypes = elementArity match { - case 0 => productTypes - case _ if elementArity > 0 && exInfo.isUnapplySeq => productTypes ::: elementTypes - case _ if productArity > 1 && patternFixedArity == 1 => squishIntoOne() - case _ => wrongArity(patternFixedArity) - } } private class VariantToSkolemMap extends TypeMap(trackVariance = true) { diff --git a/src/reflect/scala/reflect/internal/Definitions.scala b/src/reflect/scala/reflect/internal/Definitions.scala index f1480c6cbd..19458361e1 100644 --- a/src/reflect/scala/reflect/internal/Definitions.scala +++ b/src/reflect/scala/reflect/internal/Definitions.scala @@ -681,41 +681,23 @@ trait Definitions extends api.StandardDefinitions { def isExactProductType(tp: Type): Boolean = isProductNSymbol(tp.typeSymbol) /** if tpe <: ProductN[T1,...,TN], returns List(T1,...,TN) else Nil */ - def getProductArgs(tpe: Type): List[Type] = tpe.baseClasses find isProductNSymbol match { + @deprecated("No longer used", "2.11.0") def getProductArgs(tpe: Type): List[Type] = tpe.baseClasses find isProductNSymbol match { case Some(x) => tpe.baseType(x).typeArgs case _ => Nil } - def getNameBasedProductSelectors(tpe: Type): List[Symbol] = { - def loop(n: Int): List[Symbol] = tpe member TermName("_" + n) match { - case NoSymbol => Nil - case m if m.paramss.nonEmpty => Nil - case m => m :: loop(n + 1) - } - loop(1) - } - def getNameBasedProductSelectorTypes(tpe: Type): List[Type] = getProductArgs(tpe) match { - case xs if xs.nonEmpty => xs - case _ => getterMemberTypes(tpe, getNameBasedProductSelectors(tpe)) + + @deprecated("No longer used", "2.11.0") def unapplyUnwrap(tpe:Type) = tpe.finalResultType.dealiasWiden match { + case RefinedType(p :: _, _) => p.dealiasWiden + case tp => tp } def getterMemberTypes(tpe: Type, getters: List[Symbol]): List[Type] = getters map (m => dropNullaryMethod(tpe memberType m)) - def getNameBasedProductSeqElementType(tpe: Type) = getNameBasedProductSelectorTypes(tpe) match { - case _ :+ elem => unapplySeqElementType(elem) - case _ => NoType - } - def dropNullaryMethod(tp: Type) = tp match { case NullaryMethodType(restpe) => restpe case _ => tp } - - def unapplyUnwrap(tpe:Type) = tpe.finalResultType.dealiasWiden match { - case RefinedType(p :: _, _) => p.dealiasWiden - case tp => tp - } - def abstractFunctionForFunctionType(tp: Type) = { assert(isFunctionType(tp), tp) abstractFunctionType(tp.typeArgs.init, tp.typeArgs.last) @@ -738,13 +720,54 @@ trait Definitions extends api.StandardDefinitions { def scalaRepeatedType(arg: Type) = appliedType(RepeatedParamClass, arg) def seqType(arg: Type) = appliedType(SeqClass, arg) - def typeOfMemberNamedGet(tp: Type) = resultOfMatchingMethod(tp, nme.get)() - - def unapplySeqElementType(seqType: Type) = ( - resultOfMatchingMethod(seqType, nme.apply)(IntTpe) - orElse resultOfMatchingMethod(seqType, nme.head)() + // FYI the long clunky name is because it's really hard to put "get" into the + // name of a method without it sounding like the method "get"s something, whereas + // this method is about a type member which just happens to be named get. + def typeOfMemberNamedGet(tp: Type) = resultOfMatchingMethod(tp, nme.get)() + def typeOfMemberNamedHead(tp: Type) = resultOfMatchingMethod(tp, nme.head)() + def typeOfMemberNamedApply(tp: Type) = resultOfMatchingMethod(tp, nme.apply)() + def typeOfMemberNamedGetOrSelf(tp: Type) = typeOfMemberNamedGet(tp) orElse tp + def typesOfSelectors(tp: Type) = getterMemberTypes(tp, productSelectors(tp)) + def typesOfCaseAccessors(tp: Type) = getterMemberTypes(tp, tp.typeSymbol.caseFieldAccessors) + + /** If this is a case class, the case field accessors (which may be an empty list.) + * Otherwise, if there are any product selectors, that list. + * Otherwise, a list containing only the type itself. + */ + def typesOfSelectorsOrSelf(tp: Type): List[Type] = ( + if (tp.typeSymbol.isCase) + typesOfCaseAccessors(tp) + else typesOfSelectors(tp) match { + case Nil => tp :: Nil + case tps => tps + } ) + /** If the given type has one or more product selectors, the type of the last one. + * Otherwise, the type itself. + */ + def typeOfLastSelectorOrSelf(tp: Type) = typesOfSelectorsOrSelf(tp).last + + def elementTypeOfLastSelectorOrSelf(tp: Type) = { + val last = typeOfLastSelectorOrSelf(tp) + ( typeOfMemberNamedHead(last) + orElse typeOfMemberNamedApply(last) + orElse elementType(ArrayClass, last) + ) + } + + /** Returns the method symbols for members _1, _2, ..., _N + * which exist in the given type. + */ + def productSelectors(tpe: Type): List[Symbol] = { + def loop(n: Int): List[Symbol] = tpe member TermName("_" + n) match { + case NoSymbol => Nil + case m if m.paramss.nonEmpty => Nil + case m => m :: loop(n + 1) + } + loop(1) + } + /** If `tp` has a term member `name`, the first parameter list of which * matches `paramTypes`, and which either has no further parameter * lists or only an implicit one, then the result type of the matching |