From bc6eb79f74a30aef2eb874eb7ba3c443c49e7554 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Sat, 4 Nov 2017 13:57:01 -0700 Subject: Add some basic compile-time checks to enforce usage of `T{...}` within traits --- core/src/main/scala/forge/Target.scala | 16 +++++++++++---- core/src/main/scala/forge/util/LocalDef.scala | 28 +++++++++++++++++++++++++++ core/src/test/scala/forge/CacherTests.scala | 10 ++++++++++ core/src/test/scala/forge/TestGraphs.scala | 1 + 4 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/forge/util/LocalDef.scala diff --git a/core/src/main/scala/forge/Target.scala b/core/src/main/scala/forge/Target.scala index a6e977ae..ae5b488b 100644 --- a/core/src/main/scala/forge/Target.scala +++ b/core/src/main/scala/forge/Target.scala @@ -2,7 +2,7 @@ package forge import ammonite.ops.{ls, mkdir} -import forge.util.{Args, PathRef} +import forge.util.{Args, LocalDef, PathRef} import play.api.libs.json.{Format, JsValue, Json} import scala.annotation.compileTimeOnly @@ -34,8 +34,10 @@ abstract class Target[T] extends Target.Ops[T]{ object Target{ trait Cacher{ private[this] val cacherLazyMap = mutable.Map.empty[sourcecode.Enclosing, Target[_]] - protected[this] def T[T](t: T): Target[T] = macro impl[T] - protected[this] def T[T](t: => Target[T])(implicit c: sourcecode.Enclosing): Target[T] = { + protected[this] def T[T](t: T) + (implicit l: LocalDef): Target[T] = macro localDefImpl[T] + protected[this] def T[T](t: => Target[T]) + (implicit c: sourcecode.Enclosing, l: LocalDef): Target[T] = { cacherLazyMap.getOrElseUpdate(c, t).asInstanceOf[Target[T]] } } @@ -46,7 +48,13 @@ object Target{ } def apply[T](t: Target[T]): Target[T] = t def apply[T](t: T): Target[T] = macro impl[T] - def impl[T: c.WeakTypeTag](c: Context)(t: c.Expr[T]): c.Expr[Target[T]] = { + def localDefImpl[T: c.WeakTypeTag](c: Context) + (t: c.Expr[T]) + (l: c.Expr[LocalDef]): c.Expr[Target[T]] = { + impl(c)(t) + } + def impl[T: c.WeakTypeTag](c: Context) + (t: c.Expr[T]): c.Expr[Target[T]] = { import c.universe._ val bound = collection.mutable.Buffer.empty[(c.Tree, Symbol)] val OptionGet = c.universe.typeOf[Target[_]].member(TermName("apply")) diff --git a/core/src/main/scala/forge/util/LocalDef.scala b/core/src/main/scala/forge/util/LocalDef.scala new file mode 100644 index 00000000..2a58bbd2 --- /dev/null +++ b/core/src/main/scala/forge/util/LocalDef.scala @@ -0,0 +1,28 @@ +package forge.util + +import scala.reflect.macros.blackbox +import language.experimental.macros +class LocalDef +object LocalDef { + implicit def default: LocalDef = macro enclosing + def enclosing(c: blackbox.Context): c.Expr[LocalDef] = { + + import c.universe._ + val current = c.internal.enclosingOwner + + if ( + !current.isMethod || + // We can't do this right now because it causes recursive method errors + // current.asMethod.paramLists.nonEmpty || + !(current.owner.isClass || current.owner.isModuleClass) + ) { + c.abort( + c.enclosingPosition, + "T{} can only be used directly within a zero-arg method defined in a class body" + ) + }else{ + + c.Expr[LocalDef](q"""new forge.util.LocalDef()""") + } + } +} diff --git a/core/src/test/scala/forge/CacherTests.scala b/core/src/test/scala/forge/CacherTests.scala index 4c346e5e..c6157a4a 100644 --- a/core/src/test/scala/forge/CacherTests.scala +++ b/core/src/test/scala/forge/CacherTests.scala @@ -46,5 +46,15 @@ object CacherTests extends TestSuite{ eval(Terminal, Terminal.value) == 7, eval(Terminal, Terminal.overriden) == 1 ) + 'errors{ + val expectedMsg = + "T{} can only be used directly within a zero-arg method defined in a class body" + + val err1 = compileError("object Foo extends Target.Cacher{ val x = T{1} }") + assert(err1.msg == expectedMsg) + + val err2 = compileError("object Foo extends Target.Cacher{ def x = {def y = T{1}} }") + assert(err2.msg == expectedMsg) + } } } diff --git a/core/src/test/scala/forge/TestGraphs.scala b/core/src/test/scala/forge/TestGraphs.scala index 5960ed4c..0d6040eb 100644 --- a/core/src/test/scala/forge/TestGraphs.scala +++ b/core/src/test/scala/forge/TestGraphs.scala @@ -100,3 +100,4 @@ class TestGraphs(){ val j = test(test(i), test(i, f), test(f)) } } + -- cgit v1.2.3