From 9cca64627d31f078c565c6865a50ae558f567d8f Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Fri, 12 Jan 2018 20:28:43 -0800 Subject: Allow implicit `crossModule()` syntax to automatically find a set of cross-coordinates which are compatible with a given implicit resolver --- core/src/main/scala/mill/define/Cross.scala | 29 +++++++++++++++++++--- core/src/main/scala/mill/discover/Discovered.scala | 2 +- scalalib/src/main/scala/mill/scalalib/Module.scala | 21 +++++++++++++++- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/mill/define/Cross.scala b/core/src/main/scala/mill/define/Cross.scala index 6226590d..2dd4d2fe 100644 --- a/core/src/main/scala/mill/define/Cross.scala +++ b/core/src/main/scala/mill/define/Cross.scala @@ -26,10 +26,15 @@ object Cross{ reify { mill.define.Cross.Factory[T](instance.splice) } } } + + trait Resolver[-T]{ + def resolve[V <: T](c: Cross[V]): V + } } + class Cross[T](cases: Any*) - (implicit ci: Cross.Factory[T], - val ctx: Module.Ctx){ + (implicit ci: Cross.Factory[T], + val ctx: Module.Ctx){ val items = for(c0 <- cases.toList) yield{ val c = c0 match{ @@ -47,5 +52,23 @@ class Cross[T](cases: Any*) (crossValues, sub) } val itemMap = items.toMap - def apply(args: Any*) = itemMap(args.toList) + + /** + * Fetch the cross module corresponding to the given cross values + */ + def get(args: Seq[Any]) = itemMap(args.toList) + + /** + * Fetch the cross module corresponding to the given cross values + */ + def apply(arg0: Any, args: Any*) = itemMap(arg0 :: args.toList) + + /** + * Fetch the relevant cross module given the implicit resolver you have in + * scope. This is often the first cross module whose cross-version is + * compatible with the current module. + */ + def apply[V >: T]()(implicit resolver: Cross.Resolver[V]): T = { + resolver.resolve(this.asInstanceOf[Cross[V]]).asInstanceOf[T] + } } \ No newline at end of file diff --git a/core/src/main/scala/mill/discover/Discovered.scala b/core/src/main/scala/mill/discover/Discovered.scala index 0728a027..c5ca4843 100644 --- a/core/src/main/scala/mill/discover/Discovered.scala +++ b/core/src/main/scala/mill/discover/Discovered.scala @@ -112,7 +112,7 @@ object Discovered { val base = q"${TermName(c.freshName())}" val ident = segments.reverse.zipWithIndex.foldLeft[Tree](base) { case (prefix, (Some(name), i)) => q"$prefix.${TermName(name)}" - case (prefix, (None, i)) => q"$prefix.apply($crossName($i):_*)" + case (prefix, (None, i)) => q"$prefix.get($crossName($i))" } q"($base: $baseType, $crossName: List[List[Any]]) => $ident.asInstanceOf[$t]" } diff --git a/scalalib/src/main/scala/mill/scalalib/Module.scala b/scalalib/src/main/scala/mill/scalalib/Module.scala index de962c9c..a6819435 100644 --- a/scalalib/src/main/scala/mill/scalalib/Module.scala +++ b/scalalib/src/main/scala/mill/scalalib/Module.scala @@ -3,12 +3,13 @@ package scalalib import ammonite.ops._ import coursier.{Cache, MavenRepository, Repository} -import mill.define.Task +import mill.define.{Cross, Task} import mill.define.Task.TaskModule import mill.eval.{PathRef, Result} import mill.modules.Jvm import mill.modules.Jvm.{createAssembly, createJar, interactiveSubprocess, subprocess} import Lib._ +import mill.define.Cross.Resolver import sbt.testing.Status object TestModule{ def handleResults(doneMsg: String, results: Seq[TestRunner.Result]) = { @@ -375,6 +376,24 @@ trait SbtModule extends Module { outer => } trait CrossSbtModule extends SbtModule { outer => + implicit def crossSbtModuleResolver: Resolver[CrossSbtModule] = new Resolver[CrossSbtModule]{ + def resolve[V <: CrossSbtModule](c: Cross[V]): V = { + crossScalaVersion.split('.') + .inits + .takeWhile(_.length > 1) + .flatMap( prefix => + c.items.map(_._2).find(_.crossScalaVersion.split('.').startsWith(prefix)) + ) + .collectFirst{case x => x} + .getOrElse( + throw new Exception( + s"Unable to find compatible cross version between $crossScalaVersion and "+ + c.items.map(_._2.crossScalaVersion).mkString(",") + ) + ) + } + } + def crossScalaVersion: String def scalaVersion = crossScalaVersion override def sources = T.input{ -- cgit v1.2.3