1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
|
package mill.scalaplugin
import java.io.FileInputStream
import java.lang.annotation.Annotation
import java.net.URLClassLoader
import java.util.zip.ZipInputStream
import ammonite.ops.{Path, ls, pwd}
import mill.util.Ctx.LogCtx
import mill.util.PrintLogger
import sbt.testing._
import scala.collection.mutable
object TestRunner {

  /** Lists the `.class` files contained in one classpath element.
    *
    * @param base a directory (walked recursively); anything else is read as a zip/jar archive
    * @return paths of the class files relative to `base` (e.g. `foo/Bar.class`)
    */
  def listClassFiles(base: Path): Iterator[String] = {
    if (base.isDir) {
      ls.rec(base).toIterator.filter(_.ext == "class").map(_.relativeTo(base).toString)
    } else {
      val zip = new ZipInputStream(new FileInputStream(base.toIO))
      // Materialize the entry names before closing the archive: handing back a lazy
      // iterator over the open stream would leak the underlying file handle.
      try {
        Iterator
          .continually(zip.getNextEntry)
          .takeWhile(_ != null)
          .map(_.getName)
          .filter(_.endsWith(".class"))
          .toVector
          .iterator
      } finally zip.close()
    }
  }

  /** Discovers the test classes under `classpath` that match one of `framework`'s fingerprints.
    *
    * Despite its name, this only discovers tests; nothing is executed here.
    *
    * @param cl        class loader used to load candidate classes and fingerprint types
    * @param framework the test framework whose fingerprints select test classes
    * @param classpath directories/archives to scan for class files
    * @return each matching class paired with the first fingerprint it satisfies
    */
  def runTests(cl: ClassLoader,
               framework: Framework,
               classpath: Seq[Path]): Seq[(Class[_], Fingerprint)] = {
    val fingerprints = framework.fingerprints()
    classpath.flatMap { base =>
      listClassFiles(base).flatMap { path =>
        val cls = cl.loadClass(path.stripSuffix(".class").replace('/', '.'))
        // Scala `object`s compile to a module class whose name ends in `$`; a fingerprint
        // only applies when its `isModule` flag agrees with that.
        val isModuleClass = cls.getName.endsWith("$")
        fingerprints.find {
          case f: SubclassFingerprint =>
            f.isModule == isModuleClass &&
              cl.loadClass(f.superclassName()).isAssignableFrom(cls)
          case f: AnnotatedFingerprint =>
            f.isModule == isModuleClass &&
              cls.isAnnotationPresent(
                cl.loadClass(f.annotationName()).asInstanceOf[Class[Annotation]]
              )
          // Fingerprint is an open Java interface; kinds we don't understand never match
          // (previously this threw a MatchError).
          case _ => false
        }.map { f => (cls, f) }
      }
    }
  }

  /** Command-line entry point.
    *
    * Positional arguments:
    *  - `args(0)`: fully-qualified class name of the `sbt.testing.Framework` implementation
    *  - `args(1)`: space-separated classpath the tests run against
    *  - `args(2)`: space-separated locations scanned for test classes
    *  - `args(3)`: space-separated framework arguments (empty string for none)
    *  - `args(4)`: file that receives the upickle-serialized result of [[apply]]
    */
  def main(args: Array[String]): Unit = {
    val result = apply(
      frameworkName = args(0),
      entireClasspath = args(1).split(" ").map(Path(_)),
      testClassfilePath = args(2).split(" ").map(Path(_)),
      args = args(3) match { case "" => Nil case x => x.split(" ").toList }
    )(new LogCtx {
      // One shared logger, instead of constructing a new PrintLogger on every `log` access.
      val log = new PrintLogger(true)
    })
    val outputPath = args(4)
    ammonite.ops.write(Path(outputPath), upickle.default.write(result))
    // Tests are over, kill the JVM whether or not anyone's threads are still running
    // Always return 0, even if tests fail. The caller can pick up the detailed test
    // results from the outputPath
    System.exit(0)
  }

  /** Loads `frameworkName` in an isolated class loader, discovers and runs all tests.
    *
    * @param frameworkName     fully-qualified name of the `sbt.testing.Framework` class
    * @param entireClasspath   classpath of the isolated loader the tests run in
    * @param testClassfilePath locations scanned for test classes
    * @param args              arguments passed to the framework's runner
    * @return `None` if no test reported `Error` or `Failure`, otherwise `Some(summary)`
    */
  def apply(frameworkName: String,
            entireClasspath: Seq[Path],
            testClassfilePath: Seq[Path],
            args: Seq[String])
           (implicit ctx: LogCtx): Option[String] = {
    val outerClassLoader = getClass.getClassLoader
    // Parent the loader above the application classpath, but serve `sbt.testing.*` from the
    // loader that loaded this runner so both sides share the interface types and the cast
    // to `sbt.testing.Framework` below succeeds.
    val cl = new URLClassLoader(
      entireClasspath.map(_.toIO.toURI.toURL).toArray,
      ClassLoader.getSystemClassLoader().getParent()
    ) {
      override def findClass(name: String): Class[_] =
        if (name.startsWith("sbt.testing.")) outerClassLoader.loadClass(name)
        else super.findClass(name)
    }

    val framework = cl.loadClass(frameworkName)
      .newInstance()
      .asInstanceOf[sbt.testing.Framework]

    val testClasses = runTests(cl, framework, testClassfilePath)

    val runner = framework.runner(args.toArray, args.toArray, cl)
    val tasks = runner.tasks(
      for ((cls, fingerprint) <- testClasses.toArray)
        yield new TaskDef(cls.getName.stripSuffix("$"), fingerprint, true, Array())
    )

    val events = mutable.Buffer.empty[Status]
    val eventHandler = new EventHandler {
      def handle(event: Event): Unit = events.append(event.status())
    }
    val loggers: Array[Logger] = Array(
      new Logger {
        def debug(msg: String) = ctx.log.info(msg)
        def error(msg: String) = ctx.log.error(msg)
        def ansiCodesSupported() = true
        def warn(msg: String) = ctx.log.info(msg)
        def trace(t: Throwable) = t.printStackTrace(ctx.log.outputStream)
        def info(msg: String) = ctx.log.info(msg)
      }
    )

    // `Task.execute` may return nested tasks that must themselves be executed (sbt
    // test-interface contract); keep draining the queue until none remain.
    val pending = mutable.Queue(tasks: _*)
    while (pending.nonEmpty) {
      pending ++= pending.dequeue().execute(eventHandler, loggers)
    }

    val doneMsg = runner.done()
    val msg =
      if (doneMsg.trim.nonEmpty) doneMsg
      else {
        events
          .groupBy(identity)
          .map { case (status, occurrences) => (status, occurrences.length) }
          .toList
          .sorted
          .map { case (status, count) => s"$status: $count" }
          .mkString(",")
      }
    ctx.log.info(msg)
    if (events.exists(Set(Status.Error, Status.Failure))) Some(msg) else None
  }
}
|