diff options
author | Li Haoyi <haoyi.sg@gmail.com> | 2017-11-02 08:28:36 -0700 |
---|---|---|
committer | Li Haoyi <haoyi.sg@gmail.com> | 2017-11-02 08:28:36 -0700 |
commit | 66f1c5c2438aeb8f2496575f52c25b09cf5793a6 (patch) | |
tree | 26abe57db2e2176e9f7a974431790235c5385fed /src | |
parent | bfbbe450d4ac330f83fb28334e57789f3130a51c (diff) | |
download | mill-66f1c5c2438aeb8f2496575f52c25b09cf5793a6.tar.gz mill-66f1c5c2438aeb8f2496575f52c25b09cf5793a6.tar.bz2 mill-66f1c5c2438aeb8f2496575f52c25b09cf5793a6.zip |
`T.raw` macro now works without needing `c.untypecheck`
Diffstat (limited to 'src')
-rw-r--r-- | src/main/scala/forge/Target.scala | 46 | ||||
-rw-r--r-- | src/main/scala/forge/package.scala | 15 | ||||
-rw-r--r-- | src/test/scala/forge/MetacircularTests.scala | 6 |
3 files changed, 35 insertions, 32 deletions
diff --git a/src/main/scala/forge/Target.scala b/src/main/scala/forge/Target.scala index a6a8eda7..6a465a92 100644 --- a/src/main/scala/forge/Target.scala +++ b/src/main/scala/forge/Target.scala @@ -47,39 +47,33 @@ object Target{ def raw[T](t: T): Target[T] = macro impl[T] def impl[T: c.WeakTypeTag](c: Context)(t: c.Expr[T]): c.Expr[Target[T]] = { import c.universe._ - val bound = collection.mutable.Buffer.empty[(c.Tree, c.TermName)] - - object transformer extends c.universe.Transformer{ - override def transform(tree: c.Tree): c.Tree = tree match{ - case q"$fun.apply()" if fun.tpe <:< weakTypeOf[Target[_]] => - val newTerm = TermName(c.freshName()) - bound.append((fun, newTerm)) - val ident = Ident(newTerm) - ident + val bound = collection.mutable.Buffer.empty[(c.Tree, Symbol)] + val OptionGet = c.universe.typeOf[Target[_]].member(TermName("apply")) + object transformer extends c.universe.Transformer { + // Derived from @olafurpg's + // https://gist.github.com/olafurpg/596d62f87bf3360a29488b725fbc7608 + override def transform(tree: c.Tree): c.Tree = tree match { + case t @ q"$fun.apply()" if t.symbol == OptionGet => + val tempName = c.freshName(TermName("tmp")) + val tempSym = c.internal.newTermSymbol(c.internal.enclosingOwner, tempName) + c.internal.setInfo(tempSym, t.tpe) + val tempIdent = Ident(tempSym) + c.internal.setType(tempIdent, t.tpe) + bound.append((fun, tempSym)) + tempIdent case _ => super.transform(tree) } } - val transformed = transformer.transform(t.tree) - val (exprs, names) = bound.unzip - val embedded = bound.length match{ - case 0 => transformed - case 1 => q"zip(..$exprs).map{ case ${names(0)} => $transformed }" - case n => - - // For some reason, pq"(..$names)" doesn't work... - val pq = n match{ - case 2 => pq"(${names(0)}, ${names(1)})" - case 3 => pq"(${names(0)}, ${names(1)}, ${names(2)})" - case 4 => pq"(${names(0)}, ${names(1)}, ${names(2)}, ${names(3)})" - case 5 => pq"(${names(0)}, ${names(1)}, ${names(2)}, ${names(3)}, ${names(4)})" - } - q"zip(..$exprs).map{ case $pq => $transformed }" - } + val (exprs, symbols) = bound.unzip + val bindings = symbols.map(c.internal.valDef(_)) - c.Expr[Target[T]](c.untypecheck(embedded)) + val embedded = q"forge.zipMap(..$exprs){ (..$bindings) => $transformed }" + + c.Expr[Target[T]](embedded) } + abstract class Ops[T]{ this: Target[T] => def map[V](f: T => V) = new Target.Mapped(this, f) diff --git a/src/main/scala/forge/package.scala b/src/main/scala/forge/package.scala index f7635b9b..6a52d1a1 100644 --- a/src/main/scala/forge/package.scala +++ b/src/main/scala/forge/package.scala @@ -6,17 +6,24 @@ package object forge { val T = Target type T[T] = Target[T] - def zip[A](a: T[A]) = a + def zipMap[R]()(f: () => R) = T(f()) + def zipMap[A, R](a: T[A])(f: A => R) = a.map(f) + def zipMap[A, B, R](a: T[A], b: T[B])(f: (A, B) => R) = zip(a, b).map(f.tupled) + def zipMap[A, B, C, R](a: T[A], b: T[B], c: T[C])(f: (A, B, C) => R) = zip(a, b, c).map(f.tupled) + def zipMap[A, B, C, D, R](a: T[A], b: T[B], c: T[C], d: T[D])(f: (A, B, C, D) => R) = zip(a, b, c, d).map(f.tupled) + def zipMap[A, B, C, D, E, R](a: T[A], b: T[B], c: T[C], d: T[D], e: T[E])(f: (A, B, C, D, E) => R) = zip(a, b, c, d, e).map(f.tupled) + def zip() = T(()) + def zip[A](a: T[A]) = a.map(Tuple1(_)) def zip[A, B](a: T[A], b: T[B]) = a.zip(b) - def zip[A, B, C](a: T[A], b: T[B], c: T[C]) = new Target[(A, B, C)]{ + def zip[A, B, C](a: T[A], b: T[B], c: T[C]) = new T[(A, B, C)]{ val inputs = Seq(a, b, c) def evaluate(args: Args) = (args[A](0), args[B](1), args[C](2)) } - def zip[A, B, C, D](a: T[A], b: T[B], c: T[C], d: T[D]) = new Target[(A, B, C, D)]{ + def zip[A, B, C, D](a: T[A], b: T[B], c: T[C], d: T[D]) = new T[(A, B, C, D)]{ val inputs = Seq(a, b, c, d) def evaluate(args: Args) = (args[A](0), args[B](1), args[C](2), args[D](3)) } - def zip[A, B, C, D, E](a: T[A], b: T[B], c: T[C], d: T[D], e: T[E]) = new Target[(A, B, C, D, E)]{ + def zip[A, B, C, D, E](a: T[A], b: T[B], c: T[C], d: T[D], e: T[E]) = new T[(A, B, C, D, E)]{ val inputs = Seq(a, b, c, d, e) def evaluate(args: Args) = (args[A](0), args[B](1), args[C](2), args[D](3), args[E](4)) } diff --git a/src/test/scala/forge/MetacircularTests.scala b/src/test/scala/forge/MetacircularTests.scala index f11b4af5..7d0a4c1a 100644 --- a/src/test/scala/forge/MetacircularTests.scala +++ b/src/test/scala/forge/MetacircularTests.scala @@ -9,10 +9,11 @@ object MetacircularTests extends TestSuite{ object Self extends scalaplugin.Subproject { val scalaVersion = T{ "2.12.4" } override val compileDeps = T.raw{ - Seq(Dep(Mod("org.scala-lang", "scala-reflect"), scalaVersion(), configuration = "provided")) + Seq( + Dep(Mod("org.scala-lang", "scala-reflect"), scalaVersion(), configuration = "provided") + ) } - override val deps = T.raw{ Seq( Dep(Mod("com.lihaoyi", "sourcecode_" + scalaBinaryVersion()), "0.1.4"), @@ -42,3 +43,4 @@ object MetacircularTests extends TestSuite{ } } } + |