summaryrefslogtreecommitdiff
path: root/main/core/src/define/Discover.scala
diff options
context:
space:
mode:
Diffstat (limited to 'main/core/src/define/Discover.scala')
-rw-r--r--main/core/src/define/Discover.scala89
1 files changed, 89 insertions, 0 deletions
diff --git a/main/core/src/define/Discover.scala b/main/core/src/define/Discover.scala
new file mode 100644
index 00000000..f0c668e6
--- /dev/null
+++ b/main/core/src/define/Discover.scala
@@ -0,0 +1,89 @@
+package mill.define
+import mill.util.Router.EntryPoint
+
+import language.experimental.macros
+import sourcecode.Compat.Context
+
+import scala.collection.mutable
+import scala.reflect.macros.blackbox
+
+
+
+case class Discover[T](value: Map[Class[_], Seq[(Int, EntryPoint[_])]])
+object Discover {
+ def apply[T]: Discover[T] = macro applyImpl[T]
+
+ def applyImpl[T: c.WeakTypeTag](c: blackbox.Context): c.Expr[Discover[T]] = {
+ import c.universe._
+ import compat._
+ val seen = mutable.Set.empty[Type]
+ def rec(tpe: Type): Unit = {
+ if (!seen(tpe)){
+ seen.add(tpe)
+ for{
+ m <- tpe.members
+ memberTpe = m.typeSignature
+ if memberTpe.resultType <:< typeOf[mill.define.Module] && memberTpe.paramLists.isEmpty
+ } rec(memberTpe.resultType)
+
+ if (tpe <:< typeOf[mill.define.Cross[_]]){
+ val inner = typeOf[Cross[_]]
+ .typeSymbol
+ .asClass
+ .typeParams
+ .head
+ .asType
+ .toType
+ .asSeenFrom(tpe, typeOf[Cross[_]].typeSymbol)
+
+ rec(inner)
+ }
+ }
+ }
+ rec(weakTypeOf[T])
+
+ def assertParamListCounts(methods: Iterable[router.c.universe.MethodSymbol],
+ cases: (c.Type, Int, String)*) = {
+ for (m <- methods.toList){
+ for ((tt, n, label) <- cases){
+ if (m.returnType <:< tt.asInstanceOf[router.c.Type] &&
+ m.paramLists.length != n){
+ c.abort(
+ m.pos.asInstanceOf[c.Position],
+ s"$label definitions must have $n parameter list" + (if (n == 1) "" else "s")
+ )
+ }
+ }
+ }
+ }
+ val router = new mill.util.Router(c)
+ val mapping = for{
+ discoveredModuleType <- seen
+ val curCls = discoveredModuleType.asInstanceOf[router.c.Type]
+ val methods = router.getValsOrMeths(curCls)
+ val overridesRoutes = {
+ assertParamListCounts(
+ methods,
+ (weakTypeOf[mill.define.Sources], 0, "`T.sources`"),
+ (weakTypeOf[mill.define.Input[_]], 0, "`T.input`"),
+ (weakTypeOf[mill.define.Persistent[_]], 0, "`T.persistent`"),
+ (weakTypeOf[mill.define.Target[_]], 0, "`T{...}`"),
+ (weakTypeOf[mill.define.Command[_]], 1, "`T.command`")
+ )
+
+ for{
+ m <- methods.toList
+ if m.returnType <:< weakTypeOf[mill.define.Command[_]].asInstanceOf[router.c.Type]
+ } yield (m.overrides.length, router.extractMethod(m, curCls).asInstanceOf[c.Tree])
+
+ }
+ if overridesRoutes.nonEmpty
+ } yield {
+ val lhs = q"classOf[${discoveredModuleType.typeSymbol.asClass}]"
+ val rhs = q"scala.Seq[(Int, mill.util.Router.EntryPoint[_])](..$overridesRoutes)"
+ q"$lhs -> $rhs"
+ }
+
+ c.Expr[Discover[T]](q"mill.define.Discover(scala.collection.immutable.Map(..$mapping))")
+ }
+}