summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorLi Haoyi <haoyi.sg@gmail.com>2017-11-02 08:28:36 -0700
committerLi Haoyi <haoyi.sg@gmail.com>2017-11-02 08:28:36 -0700
commit66f1c5c2438aeb8f2496575f52c25b09cf5793a6 (patch)
tree26abe57db2e2176e9f7a974431790235c5385fed /src
parentbfbbe450d4ac330f83fb28334e57789f3130a51c (diff)
downloadmill-66f1c5c2438aeb8f2496575f52c25b09cf5793a6.tar.gz
mill-66f1c5c2438aeb8f2496575f52c25b09cf5793a6.tar.bz2
mill-66f1c5c2438aeb8f2496575f52c25b09cf5793a6.zip
`T.raw` macro now works without needing `c.untypecheck`
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/forge/Target.scala46
-rw-r--r--src/main/scala/forge/package.scala15
-rw-r--r--src/test/scala/forge/MetacircularTests.scala6
3 files changed, 35 insertions, 32 deletions
diff --git a/src/main/scala/forge/Target.scala b/src/main/scala/forge/Target.scala
index a6a8eda7..6a465a92 100644
--- a/src/main/scala/forge/Target.scala
+++ b/src/main/scala/forge/Target.scala
@@ -47,39 +47,33 @@ object Target{
def raw[T](t: T): Target[T] = macro impl[T]
def impl[T: c.WeakTypeTag](c: Context)(t: c.Expr[T]): c.Expr[Target[T]] = {
import c.universe._
- val bound = collection.mutable.Buffer.empty[(c.Tree, c.TermName)]
-
- object transformer extends c.universe.Transformer{
- override def transform(tree: c.Tree): c.Tree = tree match{
- case q"$fun.apply()" if fun.tpe <:< weakTypeOf[Target[_]] =>
- val newTerm = TermName(c.freshName())
- bound.append((fun, newTerm))
- val ident = Ident(newTerm)
- ident
+ val bound = collection.mutable.Buffer.empty[(c.Tree, Symbol)]
+ val OptionGet = c.universe.typeOf[Target[_]].member(TermName("apply"))
+ object transformer extends c.universe.Transformer {
+ // Derived from @olafurpg's
+ // https://gist.github.com/olafurpg/596d62f87bf3360a29488b725fbc7608
+ override def transform(tree: c.Tree): c.Tree = tree match {
+ case t @ q"$fun.apply()" if t.symbol == OptionGet =>
+ val tempName = c.freshName(TermName("tmp"))
+ val tempSym = c.internal.newTermSymbol(c.internal.enclosingOwner, tempName)
+ c.internal.setInfo(tempSym, t.tpe)
+ val tempIdent = Ident(tempSym)
+ c.internal.setType(tempIdent, t.tpe)
+ bound.append((fun, tempSym))
+ tempIdent
case _ => super.transform(tree)
}
}
-
val transformed = transformer.transform(t.tree)
- val (exprs, names) = bound.unzip
- val embedded = bound.length match{
- case 0 => transformed
- case 1 => q"zip(..$exprs).map{ case ${names(0)} => $transformed }"
- case n =>
-
- // For some reason, pq"(..$names)" doesn't work...
- val pq = n match{
- case 2 => pq"(${names(0)}, ${names(1)})"
- case 3 => pq"(${names(0)}, ${names(1)}, ${names(2)})"
- case 4 => pq"(${names(0)}, ${names(1)}, ${names(2)}, ${names(3)})"
- case 5 => pq"(${names(0)}, ${names(1)}, ${names(2)}, ${names(3)}, ${names(4)})"
- }
- q"zip(..$exprs).map{ case $pq => $transformed }"
- }
+ val (exprs, symbols) = bound.unzip
+ val bindings = symbols.map(c.internal.valDef(_))
- c.Expr[Target[T]](c.untypecheck(embedded))
+ val embedded = q"forge.zipMap(..$exprs){ (..$bindings) => $transformed }"
+
+ c.Expr[Target[T]](embedded)
}
+
abstract class Ops[T]{ this: Target[T] =>
def map[V](f: T => V) = new Target.Mapped(this, f)
diff --git a/src/main/scala/forge/package.scala b/src/main/scala/forge/package.scala
index f7635b9b..6a52d1a1 100644
--- a/src/main/scala/forge/package.scala
+++ b/src/main/scala/forge/package.scala
@@ -6,17 +6,24 @@ package object forge {
val T = Target
type T[T] = Target[T]
- def zip[A](a: T[A]) = a
+ def zipMap[R]()(f: () => R) = T(f())
+ def zipMap[A, R](a: T[A])(f: A => R) = a.map(f)
+ def zipMap[A, B, R](a: T[A], b: T[B])(f: (A, B) => R) = zip(a, b).map(f.tupled)
+ def zipMap[A, B, C, R](a: T[A], b: T[B], c: T[C])(f: (A, B, C) => R) = zip(a, b, c).map(f.tupled)
+ def zipMap[A, B, C, D, R](a: T[A], b: T[B], c: T[C], d: T[D])(f: (A, B, C, D) => R) = zip(a, b, c, d).map(f.tupled)
+ def zipMap[A, B, C, D, E, R](a: T[A], b: T[B], c: T[C], d: T[D], e: T[E])(f: (A, B, C, D, E) => R) = zip(a, b, c, d, e).map(f.tupled)
+ def zip() = T(())
+ def zip[A](a: T[A]) = a.map(Tuple1(_))
def zip[A, B](a: T[A], b: T[B]) = a.zip(b)
- def zip[A, B, C](a: T[A], b: T[B], c: T[C]) = new Target[(A, B, C)]{
+ def zip[A, B, C](a: T[A], b: T[B], c: T[C]) = new T[(A, B, C)]{
val inputs = Seq(a, b, c)
def evaluate(args: Args) = (args[A](0), args[B](1), args[C](2))
}
- def zip[A, B, C, D](a: T[A], b: T[B], c: T[C], d: T[D]) = new Target[(A, B, C, D)]{
+ def zip[A, B, C, D](a: T[A], b: T[B], c: T[C], d: T[D]) = new T[(A, B, C, D)]{
val inputs = Seq(a, b, c, d)
def evaluate(args: Args) = (args[A](0), args[B](1), args[C](2), args[D](3))
}
- def zip[A, B, C, D, E](a: T[A], b: T[B], c: T[C], d: T[D], e: T[E]) = new Target[(A, B, C, D, E)]{
+ def zip[A, B, C, D, E](a: T[A], b: T[B], c: T[C], d: T[D], e: T[E]) = new T[(A, B, C, D, E)]{
val inputs = Seq(a, b, c, d, e)
def evaluate(args: Args) = (args[A](0), args[B](1), args[C](2), args[D](3), args[E](4))
}
diff --git a/src/test/scala/forge/MetacircularTests.scala b/src/test/scala/forge/MetacircularTests.scala
index f11b4af5..7d0a4c1a 100644
--- a/src/test/scala/forge/MetacircularTests.scala
+++ b/src/test/scala/forge/MetacircularTests.scala
@@ -9,10 +9,11 @@ object MetacircularTests extends TestSuite{
object Self extends scalaplugin.Subproject {
val scalaVersion = T{ "2.12.4" }
override val compileDeps = T.raw{
- Seq(Dep(Mod("org.scala-lang", "scala-reflect"), scalaVersion(), configuration = "provided"))
+ Seq(
+ Dep(Mod("org.scala-lang", "scala-reflect"), scalaVersion(), configuration = "provided")
+ )
}
-
override val deps = T.raw{
Seq(
Dep(Mod("com.lihaoyi", "sourcecode_" + scalaBinaryVersion()), "0.1.4"),
@@ -42,3 +43,4 @@ object MetacircularTests extends TestSuite{
}
}
}
+