blob: d3dfd7528d8617c1b7033ee782d8a5885db5dca2 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
package scala.scalajs.compiler.test.util
import language.implicitConversions
import scala.tools.nsc._
import scala.reflect.internal.util.SourceFile
import scala.util.control.ControlThrowable
import org.junit.Assert._
import scala.scalajs.compiler.{ScalaJSPlugin, JSTreeExtractors}
import JSTreeExtractors.jse
import scala.scalajs.ir
import ir.{Trees => js}
/** Base class for compiler tests that compile a Scala snippet with the
 *  Scala.js plugin and make assertions on the IR trees it generates.
 *
 *  Thanks to the implicit `string2ast` conversion, a source string can be
 *  used directly as a [[JSAST]]:
 *  {{{
 *  """class A""".has("some node") { case ... => }
 *  }}}
 */
abstract class JSASTTest extends DirectTest {

  /** Trees captured from the most recent compilation. Written by the plugin
   *  hook installed in `newScalaJSPlugin`; `null` until something has been
   *  captured.
   */
  private var lastAST: JSAST = _

  /** The IR trees produced by one compilation, plus assertion helpers.
   *
   *  Every helper returns `this`, so checks can be chained.
   */
  class JSAST(val clDefs: List[js.Tree]) {
    type Pat = PartialFunction[js.Tree, Unit]

    /** Walks `clDefs` with `pf` in one of two modes:
     *   - `find`: stops at the first tree on which `pf` is defined;
     *   - `traverse()`: applies `pf` to every tree on which it is defined.
     *
     *  Assumes `ir.Traversers.Traverser` recurses into children through
     *  `traverse(tree)`, so nested trees also reach the override below.
     *  TODO confirm.
     */
    class PFTraverser(pf: Pat) extends ir.Traversers.Traverser {
      // Thrown to abort the walk as soon as a match is found. It is a
      // ControlThrowable, so it carries no stack trace.
      private case object Found extends ControlThrowable

      // true while running `find`, false while running `traverse()`.
      private[this] var finding = false

      /** Returns whether `pf` is defined at any tree reachable from `clDefs`. */
      def find: Boolean = {
        finding = true
        try {
          clDefs.foreach(tree => traverse(tree))
          false
        } catch {
          case Found => true
        }
      }

      /** Applies `pf` to every tree reachable from `clDefs` where it is defined. */
      def traverse(): Unit = {
        finding = false
        clDefs.foreach(tree => traverse(tree))
      }

      override def traverse(tree: js.Tree): Unit = {
        if (finding && pf.isDefinedAt(tree))
          throw Found
        if (!finding)
          pf.lift(tree)
        super.traverse(tree)
      }
    }

    /** Asserts that some tree matches `pf`.
     *  `trgName` describes the expected tree in the failure message.
     */
    def has(trgName: String)(pf: Pat): this.type = {
      val tr = new PFTraverser(pf)
      assertTrue(s"AST should have $trgName", tr.find)
      this
    }

    /** Asserts that no tree matches `pf`.
     *  `trgName` describes the unwanted tree in the failure message.
     */
    def hasNot(trgName: String)(pf: Pat): this.type = {
      val tr = new PFTraverser(pf)
      assertFalse(s"AST should not have $trgName", tr.find)
      this
    }

    /** Runs `pf` on every tree where it is defined, for example to collect
     *  nodes or to make assertions on each of them.
     */
    def traverse(pf: Pat): this.type = {
      val tr = new PFTraverser(pf)
      tr.traverse()
      this
    }

    /** Prints each top-level tree to standard output (debugging aid). */
    def show: this.type = {
      clDefs foreach println _
      this
    }
  }

  /** Lets a snippet of Scala source be used wherever a [[JSAST]] is
   *  expected, by compiling it with `defaultGlobal`.
   */
  implicit def string2ast(str: String): JSAST = stringAST(str)

  /** Installs a Scala.js plugin whose `generatedJSAST` hook records the
   *  trees it receives in `lastAST`. If the hook runs several times during
   *  one compilation, only the last call's trees are kept. Presumably it
   *  runs once per compilation run; TODO confirm.
   */
  override def newScalaJSPlugin(global: Global) = new ScalaJSPlugin(global) {
    override def generatedJSAST(cld: List[js.Tree]): Unit = {
      lastAST = new JSAST(cld)
    }
  }

  /** Compiles `code` with `defaultGlobal` and returns the generated trees. */
  def stringAST(code: String): JSAST = stringAST(defaultGlobal)(code)

  /** Compiles `code` with `global` and returns the generated trees.
   *  Fails the test if the compilation produced no trees.
   */
  def stringAST(global: Global)(code: String): JSAST =
    compileAndCapture(compileString(global)(code))

  /** Compiles `source` with `defaultGlobal` and returns the generated trees. */
  def sourceAST(source: SourceFile): JSAST = sourceAST(defaultGlobal)(source)

  /** Compiles `source` with `global` and returns the generated trees.
   *  Fails the test if the compilation produced no trees.
   */
  def sourceAST(global: Global)(source: SourceFile): JSAST =
    compileAndCapture(compileSources(global)(source))

  /** Runs `compile` and returns the trees captured by the plugin hook. */
  private def compileAndCapture(compile: => Any): JSAST = {
    // Clear the previous result. Otherwise a compilation that produces no
    // trees would silently hand back the AST of an earlier snippet.
    lastAST = null
    compile
    assertNotNull("The compilation did not produce any JS AST", lastAST)
    lastAST
  }
}
|