From 5aa363015646644cc81afdf0120d8df441161e2d Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Sun, 7 Jan 2018 23:53:11 -0800 Subject: Swap over to a new, concise `CrossModule[T](..cases)` macro syntax that automatically propagates the `ctx` for you --- core/src/main/scala/mill/define/Cross.scala | 104 ++++++++++++--------- core/src/main/scala/mill/define/Module.scala | 2 +- core/src/main/scala/mill/package.scala | 4 +- .../scala/mill/discover/CrossModuleTests.scala | 8 +- core/src/test/scala/mill/util/TestGraphs.scala | 24 ++--- 5 files changed, 76 insertions(+), 66 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/mill/define/Cross.scala b/core/src/main/scala/mill/define/Cross.scala index 4d68ee62..4b0e1b8f 100644 --- a/core/src/main/scala/mill/define/Cross.scala +++ b/core/src/main/scala/mill/define/Cross.scala @@ -1,4 +1,7 @@ package mill.define +import language.experimental.macros +import scala.reflect.ClassTag +import scala.reflect.macros.{Context, blackbox} case class Cross[+T](items: List[(List[Any], T)])(implicit val e: sourcecode.Enclosing, val l: sourcecode.Line){ def flatMap[V](f: T => Cross[V]): Cross[V] = new Cross( @@ -26,53 +29,62 @@ object Cross{ def apply[T](t: T*) = new Cross(t.map(i => List(i) -> i).toList) } -class CrossModule[T, V](constructor: (T, Module.Ctx) => V, cases: T*) - (implicit ctx: Module.Ctx) -extends Cross[V]({ - cases.toList.map(x => - ( - List(x), - constructor( - x, - ctx.copy( - segments0 = Segments(ctx.segments0.value :+ ctx.segment), - segment = Segment.Cross(List(x)) - ) - ) - ) - ) -}) +object CrossModule{ + def autoCast[A](x: Any): A = x.asInstanceOf[A] + abstract class Implicit[T]{ + def make(v: Any, ctx: Module.Ctx): T + def crossValues(v: Any): List[Any] + } + object Implicit{ + implicit def make[T]: Implicit[T] = macro makeImpl[T] + def makeImpl[T: c.WeakTypeTag](c: blackbox.Context): c.Expr[Implicit[T]] = { + import c.universe._ + val tpe = weakTypeOf[T] -class CrossModule2[T1, T2, V](constructor: (T1, T2, Module.Ctx) => V, cases: (T1, T2)*) - (implicit ctx: Module.Ctx) -extends Cross[V]( - cases.toList.map(x => - ( - List(x._2, x._1), - constructor( - x._1, x._2, - ctx.copy( - segments0 = Segments(ctx.segments0.value :+ ctx.segment), - segment = Segment.Cross(List(x._2, x._1)) - ) - ) - ) - ) -) + val primaryConstructorArgs = + tpe.typeSymbol.asClass.primaryConstructor.typeSignature.paramLists.head -class CrossModule3[T1, T2, T3, V](constructor: (T1, T2, T3, Module.Ctx) => V, cases: (T1, T2, T3)*) - (implicit ctx: Module.Ctx) -extends Cross[V]( - cases.toList.map(x => - ( - List(x._3, x._2, x._1), - constructor( - x._1, x._2, x._3, - ctx.copy( - segments0 = Segments(ctx.segments0.value :+ ctx.segment), - segment = Segment.Cross(List(x._3, x._2, x._1)) - ) + val tree = primaryConstructorArgs match{ + case List(arg) => + q""" + new mill.define.CrossModule.Implicit[$tpe]{ + def make(v: Any, ctx0: mill.define.Module.Ctx) = new $tpe(v.asInstanceOf[${arg.info}]){ + override def ctx = ctx0 + } + def crossValues(v: Any) = List(v) + } + """ + case args => + val argTupleValues = for((a, n) <- args.zipWithIndex) yield{ + q"v.asInstanceOf[scala.Product].productElement($n).asInstanceOf[${a.info}]" + } + q""" + new mill.define.CrossModule.Implicit[$tpe]{ + def make(v: Any, ctx0: mill.define.Module.Ctx) = new $tpe(..$argTupleValues){ + override def ctx = ctx0 + } + def crossValues(v: Any) = List(..$argTupleValues) + } + """ + + } + c.Expr[Implicit[T]](tree) + } + } +} +class CrossModule[T](cases: Any*) + (implicit ci: CrossModule.Implicit[T], + ctx: Module.Ctx) +extends Cross[T]({ + for(c <- cases.toList) yield{ + val crossValues = ci.crossValues(c) + val sub = ci.make( + c, + ctx.copy( + segments0 = Segments(ctx.segments0.value :+ ctx.segment), + segment = Segment.Cross(crossValues.reverse) ) ) - ) -) \ No newline at end of file + (crossValues.reverse, sub) + } +}) \ No newline at end of file diff --git a/core/src/main/scala/mill/define/Module.scala b/core/src/main/scala/mill/define/Module.scala index c5837278..9c830c1a 100644 --- a/core/src/main/scala/mill/define/Module.scala +++ b/core/src/main/scala/mill/define/Module.scala @@ -20,7 +20,7 @@ case class Segments(value: Seq[Segment]) * the concrete instance. */ class Module(implicit ctx0: Module.Ctx) extends mill.moduledefs.Cacher{ - val ctx = ctx0 + def ctx = ctx0 // Ensure we do not propagate the implicit parameters as implicits within // the body of any inheriting class/trait/objects, as it would screw up any // one else trying to use sourcecode.{Enclosing,Line} to capture debug info diff --git a/core/src/main/scala/mill/package.scala b/core/src/main/scala/mill/package.scala index b5ab6429..b578ed3b 100644 --- a/core/src/main/scala/mill/package.scala +++ b/core/src/main/scala/mill/package.scala @@ -7,7 +7,5 @@ package object mill extends JsonFormatters{ type PathRef = mill.eval.PathRef type Module = define.Module val Module = define.Module - type CrossModule[T, V] = define.CrossModule[T, V] - type CrossModule2[T1, T2, V] = define.CrossModule2[T1, T2, V] - type CrossModule3[T1, T2, T3, V] = define.CrossModule3[T1, T2, T3, V] + type CrossModule[T] = define.CrossModule[T] } diff --git a/core/src/test/scala/mill/discover/CrossModuleTests.scala b/core/src/test/scala/mill/discover/CrossModuleTests.scala index 7f5cb89c..ff3b52b1 100644 --- a/core/src/test/scala/mill/discover/CrossModuleTests.scala +++ b/core/src/test/scala/mill/discover/CrossModuleTests.scala @@ -14,8 +14,8 @@ object CrossModuleTests extends TestSuite{ 'cross - { object outer extends TestUtil.BaseModule { - object crossed extends mill.CrossModule(CrossedModule, "2.10.6", "2.11.8", "2.12.4") - case class CrossedModule(n: String, ctx0: Module.Ctx) extends Module()(ctx0){ + object crossed extends mill.CrossModule[CrossedModule]("2.10.6", "2.11.8", "2.12.4") + class CrossedModule(n: String) extends mill.Module{ def scalaVersion = n } } @@ -42,8 +42,8 @@ object CrossModuleTests extends TestSuite{ scalaVersion <- Seq("2.10.6", "2.11.8", "2.12.4") if !(platform == "native0.3" && scalaVersion == "2.10.6") } yield (platform, scalaVersion) - object crossed extends mill.CrossModule2(CrossModule, crossMatrix:_*) - case class CrossModule(platform: String, scalaVersion: String, ctx0: Module.Ctx) extends mill.Module()(ctx0){ + object crossed extends mill.CrossModule[CrossModule](crossMatrix:_*) + case class CrossModule(platform: String, scalaVersion: String) extends mill.Module{ def suffix = Seq(scalaVersion, platform).filter(_.nonEmpty).map("_"+_).mkString } } diff --git a/core/src/test/scala/mill/util/TestGraphs.scala b/core/src/test/scala/mill/util/TestGraphs.scala index 95ed8bf0..3ddae02d 100644 --- a/core/src/test/scala/mill/util/TestGraphs.scala +++ b/core/src/test/scala/mill/util/TestGraphs.scala @@ -192,8 +192,8 @@ object TestGraphs{ object singleCross extends TestUtil.BaseModule { - object cross extends mill.CrossModule(CrossModule, "210", "211", "212") - case class CrossModule(scalaVersion: String, ctx0: Module.Ctx) extends Module()(ctx0){ + object cross extends mill.CrossModule[CrossModule]("210", "211", "212") + class CrossModule(scalaVersion: String) extends Module{ def suffix = T{ scalaVersion } } } @@ -203,27 +203,27 @@ object TestGraphs{ platform <- Seq("jvm", "js", "native") if !(platform == "native" && scalaVersion != "212") } yield (scalaVersion, platform) - object cross extends mill.CrossModule2(CrossModule, crossMatrix:_*) - case class CrossModule(scalaVersion: String, platform: String, ctx0: Module.Ctx) extends Module()(ctx0){ + object cross extends mill.CrossModule[CrossModule](crossMatrix:_*) + class CrossModule(scalaVersion: String, platform: String) extends Module{ def suffix = T{ scalaVersion + "_" + platform } } } object indirectNestedCrosses extends TestUtil.BaseModule { - object cross extends mill.CrossModule(CrossModule, "210", "211", "212") - case class CrossModule(scalaVersion: String, ctx0: Module.Ctx) extends mill.Module()(ctx0){ - object cross2 extends mill.CrossModule(CrossModule, "jvm", "js", "native") - case class CrossModule(platform: String, ctx0: Module.Ctx) extends mill.Module{ + object cross extends mill.CrossModule[CrossModule]("210", "211", "212") + class CrossModule(scalaVersion: String) extends mill.Module{ + object cross2 extends mill.CrossModule[CrossModule]("jvm", "js", "native") + class CrossModule(platform: String) extends mill.Module{ def suffix = T{ scalaVersion + "_" + platform } } } } object nestedCrosses extends TestUtil.BaseModule { - object cross extends mill.CrossModule(CrossModule, "210", "211", "212") - case class CrossModule(scalaVersion: String, ctx0: Module.Ctx) extends mill.Module()(ctx0){ - object cross2 extends mill.CrossModule(CrossModule, "jvm", "js", "native") - case class CrossModule(platform: String, ctx0: Module.Ctx) extends mill.Module()(ctx0){ + object cross extends mill.CrossModule[CrossModule]("210", "211", "212") + class CrossModule(scalaVersion: String) extends mill.Module{ + object cross2 extends mill.CrossModule[CrossModule]("jvm", "js", "native") + class CrossModule(platform: String) extends mill.Module{ def suffix = T{ scalaVersion + "_" + platform } } } -- cgit v1.2.3