diff options
Diffstat (limited to 'src/partest/scala/tools/partest/DirectTest.scala')
-rw-r--r-- | src/partest/scala/tools/partest/DirectTest.scala | 28 |
1 files changed, 21 insertions, 7 deletions
diff --git a/src/partest/scala/tools/partest/DirectTest.scala b/src/partest/scala/tools/partest/DirectTest.scala index 46e9621b31..7f9ca3a321 100644 --- a/src/partest/scala/tools/partest/DirectTest.scala +++ b/src/partest/scala/tools/partest/DirectTest.scala @@ -6,7 +6,7 @@ package scala.tools.partest import scala.tools.nsc._ -import util.{BatchSourceFile, CommandLineParser} +import util.{ SourceFile, BatchSourceFile, CommandLineParser } import reporters.{Reporter, ConsoleReporter} /** A class for testing code which is embedded as a string. @@ -45,18 +45,32 @@ abstract class DirectTest extends App { def reporter(settings: Settings): Reporter = new ConsoleReporter(settings) - def newSources(sourceCodes: String*) = sourceCodes.toList.zipWithIndex map { - case (src, idx) => new BatchSourceFile("newSource" + (idx + 1), src) - } + private def newSourcesWithExtension(ext: String)(codes: String*): List[BatchSourceFile] = + codes.toList.zipWithIndex map { + case (src, idx) => new BatchSourceFile(s"newSource${idx + 1}.$ext", src) + } + + def newJavaSources(codes: String*) = newSourcesWithExtension("java")(codes: _*) + def newSources(codes: String*) = newSourcesWithExtension("scala")(codes: _*) + def compileString(global: Global)(sourceCode: String): Boolean = { withRun(global)(_ compileSources newSources(sourceCode)) !global.reporter.hasErrors } - def compilationUnits(global: Global)(sourceCodes: String*): List[global.CompilationUnit] = { - val units = withRun(global) { run => - run compileSources newSources(sourceCodes: _*) + + def javaCompilationUnits(global: Global)(sourceCodes: String*) = { + sourceFilesToCompiledUnits(global)(newJavaSources(sourceCodes: _*)) + } + + def sourceFilesToCompiledUnits(global: Global)(files: List[SourceFile]) = { + withRun(global) { run => + run compileSources files run.units.toList } + } + + def compilationUnits(global: Global)(sourceCodes: String*): List[global.CompilationUnit] = { + val units = sourceFilesToCompiledUnits(global)(newSources(sourceCodes: _*)) if (global.reporter.hasErrors) { global.reporter.flush() sys.error("Compilation failure.") |