summaryrefslogtreecommitdiff
path: root/contrib/bsp/src/mill/contrib/bsp/MillBuildServer.scala
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/bsp/src/mill/contrib/bsp/MillBuildServer.scala')
-rw-r--r--contrib/bsp/src/mill/contrib/bsp/MillBuildServer.scala574
1 files changed, 574 insertions, 0 deletions
diff --git a/contrib/bsp/src/mill/contrib/bsp/MillBuildServer.scala b/contrib/bsp/src/mill/contrib/bsp/MillBuildServer.scala
new file mode 100644
index 00000000..6db85c2c
--- /dev/null
+++ b/contrib/bsp/src/mill/contrib/bsp/MillBuildServer.scala
@@ -0,0 +1,574 @@
+package mill.contrib.bsp
+
+import java.util.concurrent.CompletableFuture
+
+import ch.epfl.scala.bsp4j._
+import com.google.gson.JsonObject
+import mill.api.Result.{Skipped, Success}
+import mill.api.{BuildProblemReporter, DummyTestReporter, Result, Strict}
+import mill.contrib.bsp.ModuleUtils._
+import mill.define.Segment.Label
+import mill.define.{Discover, ExternalModule}
+import mill.eval.Evaluator
+import mill.main.{EvaluatorScopt, MainModule}
+import mill.modules.Jvm
+import mill.scalalib.Lib.discoverTests
+import mill.scalalib._
+import mill.scalalib.api.CompilationResult
+import mill.util.{Ctx, DummyLogger}
+import mill.{scalalib, _}
+import os.Path
+
+import scala.collection.JavaConverters._
+
+
+class MillBuildServer(evaluator: Evaluator,
+ _bspVersion: String,
+ serverVersion: String,
+ languages: List[String]) extends ExternalModule with BuildServer with ScalaBuildServer {
+
+ implicit def millScoptEvaluatorReads[T]: EvaluatorScopt[T] = new mill.main.EvaluatorScopt[T]()
+
+ lazy val millDiscover: Discover[MillBuildServer.this.type] = Discover[this.type]
+ val bspVersion: String = _bspVersion
+ val supportedLanguages: List[String] = languages
+ val millServerVersion: String = serverVersion
+ val millEvaluator: Evaluator = evaluator
+ val ctx: Ctx.Log with Ctx.Home = new Ctx.Log with Ctx.Home {
+ val log: DummyLogger.type = mill.util.DummyLogger
+ val home: Path = os.pwd
+ }
+ var cancelator: () => Unit = () => ()
+ var rootModule: JavaModule = ModuleUtils.getRootJavaModule(evaluator.rootModule)
+ var millModules: Seq[JavaModule] = getMillModules(millEvaluator)
+ var client: BuildClient = _
+ var moduleToTargetId: Predef.Map[JavaModule, BuildTargetIdentifier] = ModuleUtils.getModuleTargetIdMap(
+ millModules,
+ evaluator
+ )
+ var targetIdToModule: Predef.Map[BuildTargetIdentifier, JavaModule] = targetToModule(moduleToTargetId)
+ var moduleToTarget: Predef.Map[JavaModule, BuildTarget] =
+ ModuleUtils.millModulesToBspTargets(millModules, rootModule, evaluator, List("scala", "java"))
+ var moduleCodeToTargetId: Predef.Map[Int, BuildTargetIdentifier] =
+ for ((targetId, module) <- targetIdToModule) yield (module.hashCode(), targetId)
+ var initialized = false
+ var clientInitialized = false
+
+ override def onConnectWithClient(server: BuildClient): Unit =
+ client = server
+
+ override def buildInitialize(params: InitializeBuildParams): CompletableFuture[InitializeBuildResult] = {
+
+ val capabilities = new BuildServerCapabilities
+ capabilities.setCompileProvider(new CompileProvider(List("java", "scala").asJava))
+ capabilities.setRunProvider(new RunProvider(List("java", "scala").asJava))
+ capabilities.setTestProvider(new TestProvider(List("java", "scala").asJava))
+ capabilities.setDependencySourcesProvider(true)
+ capabilities.setInverseSourcesProvider(true)
+ capabilities.setResourcesProvider(true)
+ capabilities.setBuildTargetChangedProvider(false) //TODO: for now it's false, but will try to support this later
+ val future = new CompletableFuture[InitializeBuildResult]()
+ future.complete(new InitializeBuildResult("mill-bsp", millServerVersion, bspVersion, capabilities))
+ initialized = true
+ future
+ }
+
+ override def onBuildInitialized(): Unit = {
+ clientInitialized = true
+ }
+
+ override def buildShutdown(): CompletableFuture[Object] = {
+ handleExceptions[String, Object](_ => "shut down this server".asInstanceOf[Object], "")
+ }
+
+ override def onBuildExit(): Unit = {
+ cancelator()
+ }
+
+ override def workspaceBuildTargets(): CompletableFuture[WorkspaceBuildTargetsResult] = {
+ recomputeTargets()
+ handleExceptions[String, WorkspaceBuildTargetsResult](
+ _ => new WorkspaceBuildTargetsResult(moduleToTarget.values.toList.asJava),
+ "")
+ }
+
+ override def buildTargetSources(sourcesParams: SourcesParams): CompletableFuture[SourcesResult] = {
+ recomputeTargets()
+
+ def computeSourcesResult: SourcesResult = {
+ var items = List[SourcesItem]()
+
+ for (targetId <- sourcesParams.getTargets.asScala) {
+ var itemSources = List[SourceItem]()
+
+ val sources = evaluateInformativeTask(evaluator, targetIdToModule(targetId).sources, Agg.empty[PathRef]).
+ map(pathRef => pathRef.path).toSeq
+ val generatedSources = evaluateInformativeTask(evaluator,
+ targetIdToModule(targetId).generatedSources,
+ Agg.empty[PathRef]).
+ map(pathRef => pathRef.path).toSeq
+
+ for (source <- sources) {
+ itemSources ++= List(
+ new SourceItem(source.toIO.toURI.toString, SourceItemKind.DIRECTORY, false))
+ }
+
+ for (genSource <- generatedSources) {
+ itemSources ++= List(
+ new SourceItem(genSource.toIO.toURI.toString, SourceItemKind.DIRECTORY, true))
+ }
+
+ items ++= List(new SourcesItem(targetId, itemSources.asJava))
+ }
+
+ new SourcesResult(items.asJava)
+ }
+
+ handleExceptions[String, SourcesResult](_ => computeSourcesResult, "")
+ }
+
+ override def buildTargetInverseSources(inverseSourcesParams: InverseSourcesParams):
+ CompletableFuture[InverseSourcesResult] = {
+ recomputeTargets()
+
+ def getInverseSourcesResult: InverseSourcesResult = {
+ val textDocument = inverseSourcesParams.getTextDocument
+ val targets = millModules.filter(m => ModuleUtils.evaluateInformativeTask(
+ millEvaluator, m.allSourceFiles, Seq.empty[PathRef]).
+ map(pathRef => pathRef.path.toIO.toURI.toString).
+ contains(textDocument.getUri)).
+ map(m => moduleToTargetId(m))
+ new InverseSourcesResult(targets.asJava)
+ }
+
+ handleExceptions[String, InverseSourcesResult](_ => getInverseSourcesResult, "")
+ }
+
+ override def buildTargetDependencySources(dependencySourcesParams: DependencySourcesParams):
+ CompletableFuture[DependencySourcesResult] = {
+ recomputeTargets()
+
+ def getDependencySources: DependencySourcesResult = {
+ var items = List[DependencySourcesItem]()
+
+ for (targetId <- dependencySourcesParams.getTargets.asScala) {
+ val millModule = targetIdToModule(targetId)
+ var sources = evaluateInformativeTask(evaluator,
+ millModule.resolveDeps(millModule.transitiveIvyDeps),
+ Agg.empty[PathRef]) ++
+ evaluateInformativeTask(evaluator,
+ millModule.resolveDeps(millModule.compileIvyDeps),
+ Agg.empty[PathRef]) ++
+ evaluateInformativeTask(evaluator,
+ millModule.unmanagedClasspath,
+ Agg.empty[PathRef])
+ millModule match {
+ case _: ScalaModule => sources ++= evaluateInformativeTask(evaluator,
+ millModule.resolveDeps(millModule.asInstanceOf[ScalaModule].scalaLibraryIvyDeps),
+ Agg.empty[PathRef])
+ case _: JavaModule => sources ++= List()
+ }
+ items ++= List(new DependencySourcesItem(targetId, sources.
+ map(pathRef => pathRef.path.toIO.toURI.toString).
+ toList.asJava))
+ }
+
+ new DependencySourcesResult(items.asJava)
+ }
+
+ handleExceptions[String, DependencySourcesResult](_ => getDependencySources, "")
+ }
+
+ // Recompute the modules in the project in case any changes to the build took place
+ // and update all the mappings that depend on this info
+ private[this] def recomputeTargets(): Unit = {
+ rootModule = ModuleUtils.getRootJavaModule(millEvaluator.rootModule)
+ millModules = getMillModules(millEvaluator)
+ moduleToTargetId = ModuleUtils.getModuleTargetIdMap(millModules, millEvaluator)
+ targetIdToModule = targetToModule(moduleToTargetId)
+ moduleToTarget = ModuleUtils.millModulesToBspTargets(millModules, rootModule, evaluator, List("scala", "java"))
+ }
+
+ // Given the mapping from modules to targetIds, construct the mapping from targetIds to modules
+ private[this] def targetToModule(moduleToTargetId: Predef.Map[JavaModule, BuildTargetIdentifier]):
+ Predef.Map[BuildTargetIdentifier, JavaModule] = {
+ moduleToTargetId.keys.map(mod => (moduleToTargetId(mod), mod)).toMap
+
+ }
+
+ // Resolve all the mill modules contained in the project
+ private[this] def getMillModules(ev: Evaluator): Seq[JavaModule] = {
+ ev.rootModule.millInternal.segmentsToModules.values.
+ collect {
+ case m: scalalib.JavaModule => m
+ }.toSeq ++ Seq(rootModule)
+ }
+
+ // Given a function that take input of type T and return output of type V,
+ // apply the function on the given inputs and return a completable future of
+ // the result. If the execution of the function raises an Exception, complete
+ // the future exceptionally. Also complete exceptionally if the server was not
+ // yet initialized.
+ private[this] def handleExceptions[T, V](serverMethod: T => V, input: T): CompletableFuture[V] = {
+ val future = new CompletableFuture[V]()
+ if (initialized) {
+ try {
+ future.complete(serverMethod(input))
+ } catch {
+ case e: Exception => future.completeExceptionally(e)
+ }
+ } else {
+ future.completeExceptionally(
+ new Exception("Can not respond to any request before receiving the `initialize` request.")
+ )
+ }
+ future
+ }
+
+ override def buildTargetResources(resourcesParams: ResourcesParams): CompletableFuture[ResourcesResult] = {
+ recomputeTargets()
+
+ def getResources: ResourcesResult = {
+ var items = List[ResourcesItem]()
+
+ for (targetId <- resourcesParams.getTargets.asScala) {
+ val millModule = targetIdToModule(targetId)
+ val resources = evaluateInformativeTask(evaluator, millModule.resources, Agg.empty[PathRef]).
+ filter(pathRef => os.exists(pathRef.path)).
+ flatMap(pathRef => os.walk(pathRef.path)).
+ map(path => path.toIO.toURI.toString).
+ toList.asJava
+ items ++= List(new ResourcesItem(targetId, resources))
+ }
+ new ResourcesResult(items.asJava)
+ }
+
+ handleExceptions[String, ResourcesResult](_ => getResources, "")
+ }
+
+ //TODO: if the client wants to give compilation arguments and the module
+ // already has some from the build file, what to do?
+ override def buildTargetCompile(compileParams: CompileParams): CompletableFuture[CompileResult] = {
+ recomputeTargets()
+
+ def getCompileResult: CompileResult = {
+ val params = TaskParameters.fromCompileParams(compileParams)
+ val taskId = params.hashCode()
+ val compileTasks = Strict.Agg(params.getTargets.
+ filter(targetId => targetId != moduleToTarget(rootModule).getId).
+ map(targetId => targetIdToModule(targetId).compile): _*)
+ val result = millEvaluator.evaluate(compileTasks,
+ getBspLoggedReporterPool(params, t => s"Started compiling target: $t",
+ TaskDataKind.COMPILE_TASK, (targetId: BuildTargetIdentifier) => new CompileTask(targetId)),
+ DummyTestReporter,
+ new MillBspLogger(client, taskId, millEvaluator.log)
+ )
+ val compileResult = new CompileResult(getStatusCode(result))
+ compileResult.setOriginId(compileParams.getOriginId)
+ compileResult //TODO: See in what form IntelliJ expects data about products of compilation in order to set data field
+ }
+
+ handleExceptions[String, CompileResult](_ => getCompileResult, "")
+ }
+
+ override def buildTargetRun(runParams: RunParams): CompletableFuture[RunResult] = {
+ recomputeTargets()
+
+ def getRunResult: RunResult = {
+ val params = TaskParameters.fromRunParams(runParams)
+ val module = targetIdToModule(params.getTargets.head)
+ val args = params.getArguments.getOrElse(Seq.empty[String])
+ val runTask = module.run(args: _*)
+ val runResult = millEvaluator.evaluate(Strict.Agg(runTask),
+ getBspLoggedReporterPool(
+ params,
+ t => s"Started compiling target: $t",
+ TaskDataKind.COMPILE_TASK,
+ (targetId: BuildTargetIdentifier) => new CompileTask(targetId)),
+ logger = new MillBspLogger(client, runTask.hashCode(), millEvaluator.log))
+ val response = runResult.results(runTask) match {
+ case _: Result.Success[Any] => new RunResult(StatusCode.OK)
+ case _ => new RunResult(StatusCode.ERROR)
+ }
+ params.getOriginId match {
+ case Some(id) => response.setOriginId(id)
+ case None =>
+ }
+ response
+ }
+
+ handleExceptions[String, RunResult](_ => getRunResult, "")
+ }
+
+ override def buildTargetTest(testParams: TestParams): CompletableFuture[TestResult] = {
+ recomputeTargets()
+
+ def getTestResult: TestResult = {
+ val params = TaskParameters.fromTestParams(testParams)
+ val argsMap = try {
+ val scalaTestParams = testParams.getData.asInstanceOf[JsonObject]
+ (for (testItem <- scalaTestParams.get("testClasses").getAsJsonArray.asScala)
+ yield (
+ testItem.getAsJsonObject.get("target").getAsJsonObject.get("uri").getAsString,
+ testItem.getAsJsonObject.get("classes").getAsJsonArray
+ .asScala.map(elem => elem.getAsString).toSeq)).toMap
+ } catch {
+ case _: Exception => (for (targetId <- testParams.getTargets.asScala) yield
+ (targetId.getUri, Seq.empty[String])).toMap
+
+ }
+
+ var overallStatusCode = StatusCode.OK
+ for (targetId <- testParams.getTargets.asScala) {
+ val module = targetIdToModule(targetId)
+ module match {
+ case m: TestModule => val testModule = m.asInstanceOf[TestModule]
+ val testTask = testModule.testLocal(argsMap(targetId.getUri): _*)
+
+ // notifying the client that the testing of this build target started
+ val taskStartParams = new TaskStartParams(new TaskId(testTask.hashCode().toString))
+ taskStartParams.setEventTime(System.currentTimeMillis())
+ taskStartParams.setMessage("Testing target: " + targetId)
+ taskStartParams.setDataKind(TaskDataKind.TEST_TASK)
+ taskStartParams.setData(new TestTask(targetId))
+ client.onBuildTaskStart(taskStartParams)
+
+ val testReporter = new BspTestReporter(
+ client, targetId,
+ new TaskId(testTask.hashCode().toString),
+ Seq.empty[String])
+
+ val results = millEvaluator.evaluate(
+ Strict.Agg(testTask),
+ getBspLoggedReporterPool(params, t => s"Started compiling target: $t",
+ TaskDataKind.COMPILE_TASK, (targetId: BuildTargetIdentifier) => new CompileTask(targetId)),
+ testReporter,
+ new MillBspLogger(client, testTask.hashCode, millEvaluator.log))
+ val endTime = System.currentTimeMillis()
+ val statusCode = getStatusCode(results)
+ statusCode match {
+ case StatusCode.ERROR => overallStatusCode = StatusCode.ERROR
+ case StatusCode.CANCELLED => overallStatusCode =
+ if (overallStatusCode == StatusCode.ERROR) StatusCode.ERROR else StatusCode.CANCELLED
+ case StatusCode.OK =>
+ }
+ // notifying the client that the testing of this build target ended
+ val taskFinishParams = new TaskFinishParams(
+ new TaskId(testTask.hashCode().toString),
+ statusCode
+ )
+ taskFinishParams.setEventTime(endTime)
+ taskFinishParams.setMessage("Finished testing target: " +
+ moduleToTarget(targetIdToModule(targetId)).getDisplayName)
+ taskFinishParams.setDataKind(TaskDataKind.TEST_REPORT)
+ taskFinishParams.setData(testReporter.getTestReport)
+ client.onBuildTaskFinish(taskFinishParams)
+
+ case _ =>
+ }
+ }
+ val testResult = new TestResult(overallStatusCode)
+ params.getOriginId match {
+ case None => testResult
+ case Some(id) =>
+ //TODO: Add the messages from mill to the data field?
+ testResult.setOriginId(id)
+ testResult
+ }
+ }
+
+ handleExceptions[String, TestResult](_ => getTestResult, "")
+ }
+
+ // define the function that spawns compilation reporter for each module based on the
+ // module's hash code TODO: find something more reliable than the hash code
+ private[this] def getBspLoggedReporterPool(params: Parameters,
+ taskStartMessage: String => String,
+ taskStartDataKind: String,
+ taskStartData: BuildTargetIdentifier => Object):
+ Int => Option[BuildProblemReporter] = {
+ int: Int =>
+ if (moduleCodeToTargetId.contains(int)) {
+ val targetId = moduleCodeToTargetId(int)
+ val taskId = new TaskId(targetIdToModule(targetId).compile.hashCode.toString)
+ val taskStartParams = new TaskStartParams(taskId)
+ taskStartParams.setEventTime(System.currentTimeMillis())
+ taskStartParams.setData(taskStartData(targetId))
+ taskStartParams.setDataKind(taskStartDataKind)
+ taskStartParams.setMessage(taskStartMessage(moduleToTarget(targetIdToModule(targetId)).getDisplayName))
+ client.onBuildTaskStart(taskStartParams)
+ Option(new BspLoggedReporter(client,
+ targetId,
+ taskId,
+ params.getOriginId))
+ }
+ else None
+ }
+
+ // Get the execution status code given the results from Evaluator.evaluate
+ private[this] def getStatusCode(results: Evaluator.Results): StatusCode = {
+
+ val statusCodes = results.results.keys.map(task => getStatusCodePerTask(results, task)).toSeq
+ if (statusCodes.contains(StatusCode.ERROR))
+ StatusCode.ERROR
+ else if (statusCodes.contains(StatusCode.CANCELLED))
+ StatusCode.CANCELLED
+ else
+ StatusCode.OK
+ }
+
+ private[this] def getStatusCodePerTask(results: Evaluator.Results, task: mill.define.Task[_]): StatusCode = {
+ results.results(task) match {
+ case _: Success[_] => StatusCode.OK
+ case Skipped => StatusCode.CANCELLED
+ case _ => StatusCode.ERROR
+ }
+ }
+
+ override def buildTargetCleanCache(cleanCacheParams: CleanCacheParams): CompletableFuture[CleanCacheResult] = {
+ recomputeTargets()
+
+ def getCleanCacheResult: CleanCacheResult = {
+ var msg = ""
+ var cleaned = true
+ for (targetId <- cleanCacheParams.getTargets.asScala) {
+ val module = targetIdToModule(targetId)
+ val mainModule = new MainModule {
+ override implicit def millDiscover: Discover[_] = {
+ Discover[this.type]
+ }
+ }
+ val cleanTask = mainModule.clean(millEvaluator, List(s"${module.millModuleSegments.render}.compile"): _*)
+ val cleanResult = millEvaluator.evaluate(
+ Strict.Agg(cleanTask),
+ logger = new MillBspLogger(client, cleanTask.hashCode, millEvaluator.log)
+ )
+ if (cleanResult.failing.keyCount > 0) {
+ cleaned = false
+ msg += s" Target ${module.millModuleSegments.render} could not be cleaned. See message from mill: \n"
+ cleanResult.results(cleanTask) match {
+ case fail: Result.Failure[Any] => msg += fail.msg + "\n"
+ case _ => msg += "could not retrieve message"
+ }
+ } else {
+ msg += s"${module.millModuleSegments.render} cleaned \n"
+
+ val outDir = Evaluator.resolveDestPaths(os.pwd / "out", module.millModuleSegments ++
+ Seq(Label("compile"))).out
+ while (os.exists(outDir)) {
+ Thread.sleep(10)
+ }
+ }
+ }
+ new CleanCacheResult(msg, cleaned)
+ }
+
+ handleExceptions[String, CleanCacheResult](_ => getCleanCacheResult, "")
+ }
+
+ override def buildTargetScalacOptions(scalacOptionsParams: ScalacOptionsParams):
+ CompletableFuture[ScalacOptionsResult] = {
+ recomputeTargets()
+
+ def getScalacOptionsResult: ScalacOptionsResult = {
+ var targetScalacOptions = List.empty[ScalacOptionsItem]
+ for (targetId <- scalacOptionsParams.getTargets.asScala) {
+ val module = targetIdToModule(targetId)
+ module match {
+ case m: ScalaModule =>
+ val options = evaluateInformativeTask(evaluator, m.scalacOptions, Seq.empty[String]).toList
+ val classpath = evaluateInformativeTask(evaluator, m.runClasspath, Agg.empty[PathRef]).
+ map(pathRef => pathRef.path.toIO.toURI.toString).toList
+ val classDirectory = (Evaluator.resolveDestPaths(
+ os.pwd / "out",
+ m.millModuleSegments ++ Seq(Label("compile"))).dest / "classes"
+ ).toIO.toURI.toString
+
+ targetScalacOptions ++= List(new ScalacOptionsItem(targetId, options.asJava, classpath.asJava, classDirectory))
+ case _: JavaModule => targetScalacOptions ++= List()
+ }
+ }
+ new ScalacOptionsResult(targetScalacOptions.asJava)
+ }
+
+ handleExceptions[String, ScalacOptionsResult](_ => getScalacOptionsResult, "")
+ }
+
+ //TODO: In the case when mill fails to provide a main classes because multiple were
+ // defined for the same module, do something so that those can still be detected
+ // such that IntelliJ can run any of them
+ override def buildTargetScalaMainClasses(scalaMainClassesParams: ScalaMainClassesParams):
+ CompletableFuture[ScalaMainClassesResult] = {
+ recomputeTargets()
+
+ def getScalaMainClasses: ScalaMainClassesResult = {
+ var items = List.empty[ScalaMainClassesItem]
+ for (targetId <- scalaMainClassesParams.getTargets.asScala) {
+ val module = targetIdToModule(targetId)
+ val scalaMainClasses = getTaskResult(millEvaluator, module.finalMainClassOpt) match {
+ case result: Result.Success[Any] => result.asSuccess.get.value match {
+ case mainClass: Right[String, String] =>
+ List(new ScalaMainClass(
+ mainClass.value,
+ List.empty[String].asJava,
+ evaluateInformativeTask(evaluator, module.forkArgs, Seq.empty[String]).
+ toList.asJava))
+ case msg: Left[String, String] =>
+ val messageParams = new ShowMessageParams(MessageType.WARNING, msg.value)
+ messageParams.setOriginId(scalaMainClassesParams.getOriginId)
+ client.onBuildShowMessage(messageParams) // tell the client that no main class was found or specified
+ List.empty[ScalaMainClass]
+ }
+ case _ => List.empty[ScalaMainClass]
+ }
+ val item = new ScalaMainClassesItem(targetId, scalaMainClasses.asJava)
+ items ++= List(item)
+ }
+ new ScalaMainClassesResult(items.asJava)
+ }
+
+ handleExceptions[String, ScalaMainClassesResult](_ => getScalaMainClasses, "")
+ }
+
+ override def buildTargetScalaTestClasses(scalaTestClassesParams: ScalaTestClassesParams):
+ CompletableFuture[ScalaTestClassesResult] = {
+ recomputeTargets()
+
+ def getScalaTestClasses(implicit ctx: Ctx.Home): ScalaTestClassesResult = {
+ var items = List.empty[ScalaTestClassesItem]
+ for (targetId <- scalaTestClassesParams.getTargets.asScala) {
+ targetIdToModule(targetId) match {
+ case module: TestModule =>
+ items ++= List(new ScalaTestClassesItem(targetId, getTestClasses(module).toList.asJava))
+ case _: JavaModule => //TODO: maybe send a notification that this target has no test classes
+ }
+ }
+ new ScalaTestClassesResult(items.asJava)
+ }
+
+ handleExceptions[Ctx.Home, ScalaTestClassesResult](c => getScalaTestClasses(c), ctx)
+ }
+
+ // Detect and return the test classes contained in the given TestModule
+ private[this] def getTestClasses(module: TestModule)(implicit ctx: Ctx.Home): Seq[String] = {
+ val runClasspath = getTaskResult(millEvaluator, module.runClasspath)
+ val frameworks = getTaskResult(millEvaluator, module.testFrameworks)
+ val compilationResult = getTaskResult(millEvaluator, module.compile)
+
+ (runClasspath, frameworks, compilationResult) match {
+ case (Result.Success(classpath), Result.Success(testFrameworks), Result.Success(compResult)) =>
+ val classFingerprint = Jvm.inprocess(classpath.asInstanceOf[Seq[PathRef]].map(_.path),
+ classLoaderOverrideSbtTesting = true,
+ isolated = true,
+ closeContextClassLoaderWhenDone = false, cl => {
+ val fs = TestRunner.frameworks(testFrameworks.asInstanceOf[Seq[String]])(cl)
+ fs.flatMap(framework =>
+ discoverTests(cl, framework, Agg(compResult.asInstanceOf[CompilationResult].
+ classes.path)))
+ })
+ classFingerprint.map(classF => classF._1.getName.stripSuffix("$"))
+ case _ => Seq.empty[String] //TODO: or send notification that something went wrong
+ }
+ }
+
+}