From 4334f4c4a47f7e2dc0c382ada7d1a683bdfbf215 Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Thu, 15 Aug 2013 15:02:18 -0700 Subject: Some general purpose methods. Motivated by pattern matcher work, also useful elsewhere. --- src/compiler/scala/tools/nsc/ast/TreeDSL.scala | 1 + .../scala/reflect/internal/Definitions.scala | 67 +++++++++++++++++++++- src/reflect/scala/reflect/internal/TreeInfo.scala | 9 +-- src/reflect/scala/reflect/internal/Types.scala | 5 +- 4 files changed, 76 insertions(+), 6 deletions(-) diff --git a/src/compiler/scala/tools/nsc/ast/TreeDSL.scala b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala index 66ed0c8fae..d7a32c3be0 100644 --- a/src/compiler/scala/tools/nsc/ast/TreeDSL.scala +++ b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala @@ -83,6 +83,7 @@ trait TreeDSL { def INT_>= (other: Tree) = fn(target, getMember(IntClass, nme.GE), other) def INT_== (other: Tree) = fn(target, getMember(IntClass, nme.EQ), other) + def INT_- (other: Tree) = fn(target, getMember(IntClass, nme.MINUS), other) // generic operations on ByteClass, IntClass, LongClass def GEN_| (other: Tree, kind: ClassSymbol) = fn(target, getMember(kind, nme.OR), other) diff --git a/src/reflect/scala/reflect/internal/Definitions.scala b/src/reflect/scala/reflect/internal/Definitions.scala index 6b7aa2dddf..f1480c6cbd 100644 --- a/src/reflect/scala/reflect/internal/Definitions.scala +++ b/src/reflect/scala/reflect/internal/Definitions.scala @@ -253,6 +253,13 @@ trait Definitions extends api.StandardDefinitions { || tp =:= AnyRefTpe ) + def hasMultipleNonImplicitParamLists(member: Symbol): Boolean = hasMultipleNonImplicitParamLists(member.info) + def hasMultipleNonImplicitParamLists(info: Type): Boolean = info match { + case PolyType(_, restpe) => hasMultipleNonImplicitParamLists(restpe) + case MethodType(_, MethodType(p :: _, _)) if !p.isImplicit => true + case _ => false + } + private def fixupAsAnyTrait(tpe: Type): Type = tpe match { case ClassInfoType(parents, decls, clazz) => if (parents.head.typeSymbol == AnyClass) tpe @@ -384,6 +391,7 @@ trait Definitions extends api.StandardDefinitions { def arrayCloneMethod = getMemberMethod(ScalaRunTimeModule, nme.array_clone) def ensureAccessibleMethod = getMemberMethod(ScalaRunTimeModule, nme.ensureAccessible) def arrayClassMethod = getMemberMethod(ScalaRunTimeModule, nme.arrayClass) + def traversableDropMethod = getMemberMethod(ScalaRunTimeModule, nme.drop) // classes with special meanings lazy val StringAddClass = requiredClass[scala.runtime.StringAdd] @@ -423,6 +431,15 @@ trait Definitions extends api.StandardDefinitions { def isVarArgsList(params: Seq[Symbol]) = params.nonEmpty && isRepeatedParamType(params.last.tpe) def isVarArgTypes(formals: Seq[Type]) = formals.nonEmpty && isRepeatedParamType(formals.last) + def firstParamType(tpe: Type): Type = tpe.paramTypes match { + case p :: _ => p + case _ => NoType + } + def isImplicitParamss(paramss: List[List[Symbol]]) = paramss match { + case (p :: _) :: _ => p.isImplicit + case _ => false + } + def hasRepeatedParam(tp: Type): Boolean = tp match { case MethodType(formals, restpe) => isScalaVarArgs(formals) || hasRepeatedParam(restpe) case PolyType(_, restpe) => hasRepeatedParam(restpe) @@ -430,7 +447,12 @@ trait Definitions extends api.StandardDefinitions { } // wrapping and unwrapping - def dropByName(tp: Type): Type = elementExtract(ByNameParamClass, tp) orElse tp + def dropByName(tp: Type): Type = elementExtract(ByNameParamClass, tp) orElse tp + def dropRepeated(tp: Type): Type = ( + if (isJavaRepeatedParamType(tp)) elementExtract(JavaRepeatedParamClass, tp) orElse tp + else if (isScalaRepeatedParamType(tp)) elementExtract(RepeatedParamClass, tp) orElse tp + else tp + ) def repeatedToSingle(tp: Type): Type = elementExtract(RepeatedParamClass, tp) orElse tp def repeatedToSeq(tp: Type): Type = elementTransform(RepeatedParamClass, tp)(seqType) orElse tp def seqToRepeated(tp: Type): Type = elementTransform(SeqClass, tp)(scalaRepeatedType) orElse tp @@ -663,6 +685,26 @@ trait Definitions extends api.StandardDefinitions { case Some(x) => tpe.baseType(x).typeArgs case _ => Nil } + def getNameBasedProductSelectors(tpe: Type): List[Symbol] = { + def loop(n: Int): List[Symbol] = tpe member TermName("_" + n) match { + case NoSymbol => Nil + case m if m.paramss.nonEmpty => Nil + case m => m :: loop(n + 1) + } + loop(1) + } + def getNameBasedProductSelectorTypes(tpe: Type): List[Type] = getProductArgs(tpe) match { + case xs if xs.nonEmpty => xs + case _ => getterMemberTypes(tpe, getNameBasedProductSelectors(tpe)) + } + + def getterMemberTypes(tpe: Type, getters: List[Symbol]): List[Type] = + getters map (m => dropNullaryMethod(tpe memberType m)) + + def getNameBasedProductSeqElementType(tpe: Type) = getNameBasedProductSelectorTypes(tpe) match { + case _ :+ elem => unapplySeqElementType(elem) + case _ => NoType + } def dropNullaryMethod(tp: Type) = tp match { case NullaryMethodType(restpe) => restpe @@ -696,6 +738,29 @@ trait Definitions extends api.StandardDefinitions { def scalaRepeatedType(arg: Type) = appliedType(RepeatedParamClass, arg) def seqType(arg: Type) = appliedType(SeqClass, arg) + def typeOfMemberNamedGet(tp: Type) = resultOfMatchingMethod(tp, nme.get)() + + def unapplySeqElementType(seqType: Type) = ( + resultOfMatchingMethod(seqType, nme.apply)(IntTpe) + orElse resultOfMatchingMethod(seqType, nme.head)() + ) + + /** If `tp` has a term member `name`, the first parameter list of which + * matches `paramTypes`, and which either has no further parameter + * lists or only an implicit one, then the result type of the matching + * method. Otherwise, NoType. + */ + def resultOfMatchingMethod(tp: Type, name: TermName)(paramTypes: Type*): Type = { + def matchesParams(member: Symbol) = member.paramss match { + case Nil => paramTypes.isEmpty + case ps :: rest => (rest.isEmpty || isImplicitParamss(rest)) && (ps corresponds paramTypes)(_.tpe =:= _) + } + tp member name filter matchesParams match { + case NoSymbol => NoType + case member => (tp memberType member).finalResultType + } + } + def ClassType(arg: Type) = if (phase.erasedTypes) ClassClass.tpe else appliedType(ClassClass, arg) /** Can we tell by inspecting the symbol that it will never diff --git a/src/reflect/scala/reflect/internal/TreeInfo.scala b/src/reflect/scala/reflect/internal/TreeInfo.scala index 5c92512193..d01f1ce681 100644 --- a/src/reflect/scala/reflect/internal/TreeInfo.scala +++ b/src/reflect/scala/reflect/internal/TreeInfo.scala @@ -488,7 +488,7 @@ abstract class TreeInfo { } object WildcardStarArg { - def unapply(tree: Typed): Option[Tree] = tree match { + def unapply(tree: Tree): Option[Tree] = tree match { case Typed(expr, Ident(tpnme.WILDCARD_STAR)) => Some(expr) case _ => None } @@ -628,11 +628,12 @@ abstract class TreeInfo { * case Extractor(a @ (b, c)) => 2 * }}} */ - def effectivePatternArity(args: List[Tree]): Int = (args.map(unbind) match { + def effectivePatternArity(args: List[Tree]): Int = flattenedPatternArgs(args).length + + def flattenedPatternArgs(args: List[Tree]): List[Tree] = args map unbind match { case Apply(fun, xs) :: Nil if isTupleSymbol(fun.symbol) => xs case xs => xs - }).length - + } // used in the symbols for labeldefs and valdefs emitted by the pattern matcher // tailcalls, cps,... use this flag combination to detect translated matches diff --git a/src/reflect/scala/reflect/internal/Types.scala b/src/reflect/scala/reflect/internal/Types.scala index 0639a8e3f0..94222565c4 100644 --- a/src/reflect/scala/reflect/internal/Types.scala +++ b/src/reflect/scala/reflect/internal/Types.scala @@ -4015,9 +4015,12 @@ trait Types def isErrorOrWildcard(tp: Type) = (tp eq ErrorType) || (tp eq WildcardType) + /** This appears to be equivalent to tp.isInstanceof[SingletonType], + * except it excludes ConstantTypes. + */ def isSingleType(tp: Type) = tp match { case ThisType(_) | SuperType(_, _) | SingleType(_, _) => true - case _ => false + case _ => false } def isConstantType(tp: Type) = tp match { -- cgit v1.2.3