summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/typechecker/Typers.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/compiler/scala/tools/nsc/typechecker/Typers.scala')
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Typers.scala79
1 files changed, 50 insertions, 29 deletions
diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
index 553583e6b7..239fc47867 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
@@ -2213,37 +2213,58 @@ trait Typers extends Modes with Adaptations with Tags {
*/
def checkMethodStructuralCompatible(ddef: DefDef): Unit = {
val meth = ddef.symbol
- def fail(pos: Position, msg: String) = unit.error(pos, msg)
- val tp: Type = meth.tpe match {
- case mt @ MethodType(_, _) => mt
- case NullaryMethodType(restpe) => restpe // TODO_NMT: drop NullaryMethodType from resultType?
- case PolyType(_, restpe) => restpe
- case _ => NoType
- }
- def nthParamPos(n: Int) = ddef.vparamss match {
- case xs :: _ if xs.length > n => xs(n).pos
- case _ => meth.pos
- }
- def failStruct(pos: Position, what: String, where: String = "Parameter") =
- fail(pos, s"$where type in structural refinement may not refer to $what")
-
- foreachWithIndex(tp.paramTypes) { (paramType, idx) =>
- val sym = paramType.typeSymbol
- def paramPos = nthParamPos(idx)
-
- if (sym.isAbstractType) {
- if (!sym.hasTransOwner(meth.owner))
- failStruct(paramPos, "an abstract type defined outside that refinement")
- else if (!sym.hasTransOwner(meth))
- failStruct(paramPos, "a type member of that refinement")
+ def parentString = meth.owner.parentSymbols filterNot (_ == ObjectClass) match {
+ case Nil => ""
+ case xs => xs.map(_.nameString).mkString(" (of ", " with ", ")")
+ }
+ def fail(pos: Position, msg: String): Boolean = {
+ unit.error(pos, msg)
+ false
+ }
+ /** Have to examine all parameters in all lists.
+ */
+ def paramssTypes(tp: Type): List[List[Type]] = tp match {
+ case mt @ MethodType(_, restpe) => mt.paramTypes :: paramssTypes(restpe)
+ case PolyType(_, restpe) => paramssTypes(restpe)
+ case _ => Nil
+ }
+ def resultType = meth.tpe.finalResultType
+ def nthParamPos(n1: Int, n2: Int) =
+ try ddef.vparamss(n1)(n2).pos catch { case _: IndexOutOfBoundsException => meth.pos }
+
+ def failStruct(pos: Position, what: String, where: String = "Parameter type") =
+ fail(pos, s"$where in structural refinement may not refer to $what")
+
+ foreachWithIndex(paramssTypes(meth.tpe)) { (paramList, listIdx) =>
+ foreachWithIndex(paramList) { (paramType, paramIdx) =>
+ val sym = paramType.typeSymbol
+ def paramPos = nthParamPos(listIdx, paramIdx)
+
+ /** Not enough to look for abstract types; have to recursively check the bounds
+ * of each abstract type for more abstract types. Almost certainly there are other
+ * exploitable type soundness bugs which can be seen by bounding a type parameter
+ * by an abstract type which itself is bounded by an abstract type.
+ */
+ def checkAbstract(tp0: Type, what: String): Boolean = {
+ def check(sym: Symbol): Boolean = !sym.isAbstractType || {
+ log(s"""checking $tp0 in refinement$parentString at ${meth.owner.owner.fullLocationString}""")
+ ( (!sym.hasTransOwner(meth.owner) && failStruct(paramPos, "an abstract type defined outside that refinement", what))
+ || (!sym.hasTransOwner(meth) && failStruct(paramPos, "a type member of that refinement", what))
+ || checkAbstract(sym.info.bounds.hi, "Type bound")
+ )
+ }
+ tp0.dealiasWidenChain forall (t => check(t.typeSymbol))
+ }
+ checkAbstract(paramType, "Parameter type")
+
+ if (sym.isDerivedValueClass)
+ failStruct(paramPos, "a user-defined value class")
+ if (paramType.isInstanceOf[ThisType] && sym == meth.owner)
+ failStruct(paramPos, "the type of that refinement (self type)")
}
- if (sym.isDerivedValueClass)
- failStruct(paramPos, "a user-defined value class")
- if (paramType.isInstanceOf[ThisType] && sym == meth.owner)
- failStruct(paramPos, "the type of that refinement (self type)")
}
- if (tp.resultType.typeSymbol.isDerivedValueClass)
- failStruct(ddef.tpt.pos, "a user-defined value class", where = "Result")
+ if (resultType.typeSymbol.isDerivedValueClass)
+ failStruct(ddef.tpt.pos, "a user-defined value class", where = "Result type")
}
def typedUseCase(useCase: UseCase) {