From 7f79ef0e30f088b41b14763f44338c240acf1a63 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Thu, 5 Apr 2012 17:04:55 -0700 Subject: Added overloading resolution to reflect.api.Symbol --- .../scala/reflect/internal/Definitions.scala | 10 +-- src/compiler/scala/reflect/internal/Symbols.scala | 73 +++++++++++++++++++++- 2 files changed, 75 insertions(+), 8 deletions(-) (limited to 'src/compiler') diff --git a/src/compiler/scala/reflect/internal/Definitions.scala b/src/compiler/scala/reflect/internal/Definitions.scala index 8ea3cd511a..cc274eee22 100644 --- a/src/compiler/scala/reflect/internal/Definitions.scala +++ b/src/compiler/scala/reflect/internal/Definitions.scala @@ -380,10 +380,10 @@ trait Definitions extends reflect.api.StandardDefinitions { def isCastSymbol(sym: Symbol) = sym == Any_asInstanceOf || sym == Object_asInstanceOf def isJavaVarArgsMethod(m: Symbol) = m.isMethod && isJavaVarArgs(m.info.params) - def isJavaVarArgs(params: List[Symbol]) = params.nonEmpty && isJavaRepeatedParamType(params.last.tpe) - def isScalaVarArgs(params: List[Symbol]) = params.nonEmpty && isScalaRepeatedParamType(params.last.tpe) - def isVarArgsList(params: List[Symbol]) = params.nonEmpty && isRepeatedParamType(params.last.tpe) - def isVarArgTypes(formals: List[Type]) = formals.nonEmpty && isRepeatedParamType(formals.last) + def isJavaVarArgs(params: Seq[Symbol]) = params.nonEmpty && isJavaRepeatedParamType(params.last.tpe) + def isScalaVarArgs(params: Seq[Symbol]) = params.nonEmpty && isScalaRepeatedParamType(params.last.tpe) + def isVarArgsList(params: Seq[Symbol]) = params.nonEmpty && isRepeatedParamType(params.last.tpe) + def isVarArgTypes(formals: Seq[Type]) = formals.nonEmpty && isRepeatedParamType(formals.last) def hasRepeatedParam(tp: Type): Boolean = tp match { case MethodType(formals, restpe) => isScalaVarArgs(formals) || hasRepeatedParam(restpe) @@ -990,7 +990,7 @@ trait Definitions extends reflect.api.StandardDefinitions { case result => result } } - + def packageExists(packageName: String): Boolean = getModuleIfDefined(packageName).isPackage diff --git a/src/compiler/scala/reflect/internal/Symbols.scala b/src/compiler/scala/reflect/internal/Symbols.scala index 4473d63f5f..4842d47e4d 100644 --- a/src/compiler/scala/reflect/internal/Symbols.scala +++ b/src/compiler/scala/reflect/internal/Symbols.scala @@ -74,6 +74,73 @@ trait Symbols extends api.Symbols { self: SymbolTable => def setInternalFlags(flag: Long): this.type = { setFlag(flag); this } def setTypeSignature(tpe: Type): this.type = { setInfo(tpe); this } def setAnnotations(annots: AnnotationInfo*): this.type = { setAnnotations(annots.toList); this } + + private def lastElemType(ts: Seq[Type]): Type = ts.last.normalize.typeArgs.head + + private def formalTypes(formals: List[Type], nargs: Int): List[Type] = { + val formals1 = formals mapConserve { + case TypeRef(_, ByNameParamClass, List(arg)) => arg + case formal => formal + } + if (isVarArgTypes(formals1)) { + val ft = lastElemType(formals) + formals1.init ::: List.fill(nargs - (formals1.length - 1))(ft) + } else formals1 + } + + def resolveOverloaded(pre: Type, targs: Seq[Type], actuals: Seq[Type]): Symbol = { + def firstParams(tpe: Type): (List[Symbol], List[Type]) = tpe match { + case PolyType(tparams, restpe) => + val (Nil, formals) = firstParams(restpe) + (tparams, formals) + case MethodType(params, _) => + (Nil, params map (_.tpe)) + case _ => + (Nil, Nil) + } + def isApplicable(alt: Symbol, targs: List[Type], actuals: Seq[Type]) = { + def isApplicableType(tparams: List[Symbol], tpe: Type): Boolean = { + val (tparams, formals) = firstParams(pre memberType alt) + val formals1 = formalTypes(formals, actuals.length) + val actuals1 = + if (isVarArgTypes(actuals)) { + if (!isVarArgTypes(formals)) return false + actuals.init :+ lastElemType(actuals) + } else actuals + if (formals1.length != actuals1.length) return false + + if (tparams.isEmpty) return (actuals1 corresponds formals1)(_ <:< _) + + if (targs.length == tparams.length) + isApplicableType(List(), tpe.instantiateTypeParams(tparams, targs)) + else if (targs.nonEmpty) + false + else { + val tvars = tparams map (TypeVar(_)) + (actuals1 corresponds formals1) { (actual, formal) => + val tp1 = actual.deconst.instantiateTypeParams(tparams, tvars) + val pt1 = actual.instantiateTypeParams(tparams, tvars) + tp1 <:< pt1 + } && + solve(tvars, tparams, List.fill(tparams.length)(COVARIANT), upper = false) + } + } + isApplicableType(List(), pre.memberType(alt)) + } + def isAsGood(alt1: Symbol, alt2: Symbol): Boolean = { + alt1 == alt2 || + alt2 == NoSymbol || { + val (tparams, formals) = firstParams(pre memberType alt1) + isApplicable(alt2, tparams map (_.tpe), formals) + } + } + assert(isOverloaded) + val applicables = alternatives filter (isApplicable(_, targs.toList, actuals)) + def winner(alts: List[Symbol]) = + ((NoSymbol: Symbol) /: alts)((best, alt) => if (isAsGood(alt, best)) alt else best) + val best = winner(applicables) + if (best == winner(applicables.reverse)) best else NoSymbol + } } /** The class for all symbols */ @@ -1188,7 +1255,7 @@ trait Symbols extends api.Symbols { self: SymbolTable => * which immediately follows any of parser, namer, typer, or erasure. * In effect that means this will return one of: * - * - packageobjects (follows namer) + * - packageobjects (follows namer) * - superaccessors (follows typer) * - lazyvals (follows erasure) * - null @@ -1879,7 +1946,7 @@ trait Symbols extends api.Symbols { self: SymbolTable => /** Remove private modifier from symbol `sym`s definition. If `sym` is a * is not a constructor nor a static module rename it by expanding its name to avoid name clashes - * @param base the fully qualified name of this class will be appended if name expansion is needed + * @param base the fully qualified name of this class will be appended if name expansion is needed */ final def makeNotPrivate(base: Symbol) { if (this.isPrivate) { @@ -2779,7 +2846,7 @@ trait Symbols extends api.Symbols { self: SymbolTable => assert(validFrom != NoPeriod) override def toString() = "TypeHistory(" + phaseOf(validFrom)+":"+runId(validFrom) + "," + info + "," + prev + ")" - + def toList: List[TypeHistory] = this :: ( if (prev eq null) Nil else prev.toList ) } } -- cgit v1.2.3