blob: 872562dd4d214a1914534a6a1ccf61f2634f929a (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
/*
* filter: inliner warning; re-run with -Yinline-warnings for details
*/
import scala.tools.partest._
import scala.tools.nsc._
import scala.reflect.runtime.{universe => ru}
import scala.language.implicitConversions
// Defining CompilerTest here is necessary to avoid binary-compatibility problems
// with a scala-partest build compiled against the old compiler.
/** A DirectTest whose `show` passes each source, paired with the compilation unit
 *  obtained for it, to `check`.
 *
 *  Subclasses implement `check` and override at least one of `code` / `sources`.
 */
abstract class CompilerTest extends DirectTest {
  /** Inspects one `source` together with the compilation unit produced for it. */
  def check(source: String, unit: global.CompilationUnit): Unit

  /** Compiler instance, created on first use via `newCompiler()`. */
  lazy val global: Global = newCompiler()

  /** Units obtained from `sources` through DirectTest's `compilationUnits`. */
  lazy val units: List[global.CompilationUnit] = compilationUnits(global)(sources: _*)

  import global._
  import definitions.{ compilerTypeFromTag }

  override def extraSettings: String = "-usejavacp -d " + testOutput.path

  /** Runs `check` on each (source, unit) pair; extra elements of the longer list are ignored. */
  def show(): Unit = sources.zip(units).foreach { case (src, unit) => check(src, unit) }

  // Override at least one of these...
  def code: String = ""
  def sources: List[String] = List(code)

  // Utility functions

  /** Applies `sym` to a type argument given as a runtime `TypeTag`. */
  class MkType(sym: Symbol) {
    /** `sym` applied to the compiler counterpart of `M`, or `NoType` when `sym` is `NoSymbol`. */
    def apply[M](implicit t: ru.TypeTag[M]): Type =
      if (sym eq NoSymbol) NoType
      else appliedType(sym, compilerTypeFromTag(t))
  }

  // Explicit result type: implicits without one are a known source of resolution surprises.
  implicit def mkMkType(sym: Symbol): MkType = new MkType(sym)

  /** Every symbol reachable from `root` by repeatedly taking `info.members`, sorted
   *  with `isLess`. `root` itself appears only if it is reached as someone's member.
   */
  def allMembers(root: Symbol): List[Symbol] = {
    @scala.annotation.tailrec
    def loop(seen: Set[Symbol], roots: List[Symbol]): List[Symbol] = {
      // `distinct`: a symbol reachable from several roots needs to be expanded only once.
      val latest = (roots flatMap (_.info.members) filterNot (seen contains _)).distinct
      if (latest.isEmpty) seen.toList.sortWith(_ isLess _)
      else loop(seen ++ latest, latest)
    }
    loop(Set(), List(root))
  }

  /** Views over the members reachable from the package named `pkgName`.
   *  Each accessor recomputes its result, so it reflects the phase it is called in.
   */
  class SymsInPackage(pkgName: String) {
    def pkg = rootMirror.getPackage(TermName(pkgName))
    def classes: List[Symbol] = allMembers(pkg) filter (_.isClass)
    def modules: List[Symbol] = allMembers(pkg) filter (_.isModule)
    def symbols: List[Symbol] = (classes ++ terms) filterNot (_ eq NoSymbol)
    def terms: List[Symbol] = allMembers(pkg) filter (s => s.isTerm && !s.isConstructor)
    def tparams: List[Symbol] = classes flatMap (_.info.typeParams)
    def tpes: List[Type] = (symbols map (_.tpe)).distinct
  }
}
/** For every `foo` term in package `ano`, checks that the first argument of the
 *  annotation on its final result type refers to the method's first parameter
 *  exactly once.
 */
object Test extends CompilerTest {
  import global._
  import definitions._

  override def code = """
package ano
class ann(x: Any) extends annotation.TypeConstraint
abstract class Base {
def foo(x: String): String @ann(x.trim())
}
class Sub extends Base {
def foo(x: String): String @ann(x.trim()) = x
}
"""

  object syms extends SymsInPackage("ano")
  import syms._

  def check(source: String, unit: global.CompilationUnit): Unit = {
    exitingTyper {
      val fooSyms = terms filter (_.name.toString == "foo")
      // Without this the loop below would check nothing and pass silently.
      assert(fooSyms.nonEmpty, "no `foo` symbols found in package ano")
      fooSyms foreach { sym =>
        val xParam = sym.tpe.paramss.flatten.head
        val annot  = sym.tpe.finalResultType.annotations.head
        // Subtrees of the annotation argument whose symbol is the parameter itself.
        val xRefs  = annot.args.head filter (t => t.symbol == xParam)
        println(s"testing symbol ${sym.ownerChain}, param $xParam, xRefs $xRefs")
        assert(xRefs.length == 1, xRefs)
      }
    }
  }
}
|