summaryrefslogtreecommitdiff
path: root/examples/scala-js/compiler/src/test/scala/scala/scalajs/compiler/test/util/JSASTTest.scala
diff options
context:
space:
mode:
Diffstat (limited to 'examples/scala-js/compiler/src/test/scala/scala/scalajs/compiler/test/util/JSASTTest.scala')
-rw-r--r--examples/scala-js/compiler/src/test/scala/scala/scalajs/compiler/test/util/JSASTTest.scala100
1 files changed, 100 insertions, 0 deletions
diff --git a/examples/scala-js/compiler/src/test/scala/scala/scalajs/compiler/test/util/JSASTTest.scala b/examples/scala-js/compiler/src/test/scala/scala/scalajs/compiler/test/util/JSASTTest.scala
new file mode 100644
index 0000000..d3dfd75
--- /dev/null
+++ b/examples/scala-js/compiler/src/test/scala/scala/scalajs/compiler/test/util/JSASTTest.scala
@@ -0,0 +1,100 @@
+package scala.scalajs.compiler.test.util
+
+import language.implicitConversions
+
+import scala.tools.nsc._
+import scala.reflect.internal.util.SourceFile
+
+import scala.util.control.ControlThrowable
+
+import org.junit.Assert._
+
+import scala.scalajs.compiler.{ScalaJSPlugin, JSTreeExtractors}
+import JSTreeExtractors.jse
+import scala.scalajs.ir
+import ir.{Trees => js}
+
+abstract class JSASTTest extends DirectTest {
+
+ private var lastAST: JSAST = _
+
+ class JSAST(val clDefs: List[js.Tree]) {
+ type Pat = PartialFunction[js.Tree, Unit]
+
+ class PFTraverser(pf: Pat) extends ir.Traversers.Traverser {
+ private case object Found extends ControlThrowable
+
+ private[this] var finding = false
+
+ def find: Boolean = {
+ finding = true
+ try {
+ clDefs.map(traverse)
+ false
+ } catch {
+ case Found => true
+ }
+ }
+
+ def traverse(): Unit = {
+ finding = false
+ clDefs.map(traverse)
+ }
+
+ override def traverse(tree: js.Tree): Unit = {
+ if (finding && pf.isDefinedAt(tree))
+ throw Found
+
+ if (!finding)
+ pf.lift(tree)
+
+ super.traverse(tree)
+ }
+ }
+
+ def has(trgName: String)(pf: Pat): this.type = {
+ val tr = new PFTraverser(pf)
+ assertTrue(s"AST should have $trgName", tr.find)
+ this
+ }
+
+ def hasNot(trgName: String)(pf: Pat): this.type = {
+ val tr = new PFTraverser(pf)
+ assertFalse(s"AST should not have $trgName", tr.find)
+ this
+ }
+
+ def traverse(pf: Pat): this.type = {
+ val tr = new PFTraverser(pf)
+ tr.traverse()
+ this
+ }
+
+ def show: this.type = {
+ clDefs foreach println _
+ this
+ }
+
+ }
+
+ implicit def string2ast(str: String): JSAST = stringAST(str)
+
+ override def newScalaJSPlugin(global: Global) = new ScalaJSPlugin(global) {
+ override def generatedJSAST(cld: List[js.Tree]): Unit = {
+ lastAST = new JSAST(cld)
+ }
+ }
+
+ def stringAST(code: String): JSAST = stringAST(defaultGlobal)(code)
+ def stringAST(global: Global)(code: String): JSAST = {
+ compileString(global)(code)
+ lastAST
+ }
+
+ def sourceAST(source: SourceFile): JSAST = sourceAST(defaultGlobal)(source)
+ def sourceAST(global: Global)(source: SourceFile): JSAST = {
+ compileSources(global)(source)
+ lastAST
+ }
+
+}