From 8175b657bd5d008db67fc5d4f28e8463acb8d1ad Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Tue, 24 Oct 2017 22:44:50 -0700 Subject: Macro-based anonymous `Target` naming now works: we use the `T{...}` macro look at any `Target`s defined lexically within the call, and as long as they can only be evaluated once (relative to the `T{...}`) we assign them labels with an incrementing integer suffixed onto the enclosing `T{...}`s label --- src/main/scala/forge/DefCtx.scala | 69 ++++++++++++---- src/main/scala/forge/Evaluator.scala | 24 +++--- src/main/scala/forge/Target.scala | 5 +- src/test/scala/forge/ForgeTests.scala | 143 ++++++++++++++++++++++++++++------ 4 files changed, 186 insertions(+), 55 deletions(-) (limited to 'src') diff --git a/src/main/scala/forge/DefCtx.scala b/src/main/scala/forge/DefCtx.scala index b9c02602..97f9bf3f 100644 --- a/src/main/scala/forge/DefCtx.scala +++ b/src/main/scala/forge/DefCtx.scala @@ -1,14 +1,15 @@ package forge +import scala.annotation.compileTimeOnly import scala.language.experimental.macros -import scala.reflect.macros._ +import scala.reflect.macros.blackbox._ -sealed abstract class DefCtx(val value: Option[String]) +final case class DefCtx(label: String) object DefCtx{ - implicit object Anonymous extends DefCtx(None) - case class Labeled(label: String) extends DefCtx(Some(label)) + @compileTimeOnly("A DefCtx can only be provided directly within a T{} macro") + implicit def dummy: DefCtx with Int = ??? } object T{ @@ -16,19 +17,61 @@ object T{ def applyImpl[T: c.WeakTypeTag](c: Context)(expr: c.Expr[T]): c.Expr[T] = { import c.universe._ - val transformed = expr.tree match{ + var count = 0 + object transformer extends c.universe.Transformer { + override def transform(tree: c.Tree): c.Tree = { + if (tree.toString.startsWith("forge.") && tree.toString.endsWith(".DefCtx.dummy")) { + count += 1 + c.typecheck(q"forge.DefCtx(sourcecode.Enclosing() + $count)") + }else tree match{ + case Apply(fun, args) => + val extendedParams = fun.tpe.paramLists.head.padTo( + args.length, + fun.tpe.paramLists.head.lastOption.getOrElse(null) + ) + val newArgs = + for((sym, tree) <- extendedParams.zip(args)) + yield { + if (sym.asTerm.isByNameParam) tree + else transform(tree) + } + treeCopy.Apply(tree, transform(fun), newArgs) + + case t: DefDef => t + case t: ClassDef => t + case t: Function => t + case t: LabelDef => t + case t => super.transform(t) + } + + } + } + + + def transformTerminal(tree: c.Tree): c.Tree = tree match{ + case Block(stats, returnExpr) => + treeCopy.Block( + tree, + stats.map(transformer.transform(_)), + transformTerminal(returnExpr) + ) + case Apply(fun, args) => - var transformed = false - val newArgs = args.map{ - case x if x.tpe == weakTypeOf[DefCtx.Anonymous.type] => - transformed = true - q"forge.DefCtx.Labeled(sourcecode.Enclosing())" - case x => x + var isTransformed = false + val newArgs = for(x <- args) yield { + if (x.toString.startsWith("forge.") && x.toString.endsWith(".DefCtx.dummy")) { + isTransformed = true + c.typecheck(q"forge.DefCtx(sourcecode.Enclosing())") + }else transformer.transform(x) } - assert(transformed) - Apply(fun, newArgs) + + assert(isTransformed) + treeCopy.Apply(tree, transformer.transform(fun), newArgs) + case _ => ??? } + + val transformed = transformTerminal(expr.tree) c.Expr[T](transformed) } } \ No newline at end of file diff --git a/src/main/scala/forge/Evaluator.scala b/src/main/scala/forge/Evaluator.scala index d6ff39a2..43b4f353 100644 --- a/src/main/scala/forge/Evaluator.scala +++ b/src/main/scala/forge/Evaluator.scala @@ -20,27 +20,25 @@ class Evaluator(workspacePath: jnio.Path, for (target <- sortedTargets){ val inputResults = target.inputs.map(results).toIndexedSeq - val targetDestPath = target.defCtx.value match{ - case Some(enclosingStr) => - val targetDestPath = workspacePath.resolve( - jnio.Paths.get(enclosingStr.stripSuffix(enclosingBase.value.getOrElse(""))) - ) - deleteRec(targetDestPath) - targetDestPath - - case None => jnio.Files.createTempDirectory(null) + val targetDestPath = { + val enclosingStr = target.defCtx.label + val targetDestPath = workspacePath.resolve( + jnio.Paths.get(enclosingStr.stripSuffix(enclosingBase.label)) + ) + deleteRec(targetDestPath) + targetDestPath + } val inputsHash = inputResults.hashCode - target.defCtx.value.flatMap(resultCache.get) match{ + resultCache.get(target.defCtx.label) match{ case Some((hash, res)) if hash == inputsHash && !target.dirty => results(target) = res case _ => evaluated.append(target) val res = target.evaluate(new Args(inputResults, targetDestPath)) - for(label <- target.defCtx.value) { - resultCache(label) = (inputsHash, res) - } + + resultCache(target.defCtx.label) = (inputsHash, res) results(target) = res } diff --git a/src/main/scala/forge/Target.scala b/src/main/scala/forge/Target.scala index d7edc188..4d356d29 100644 --- a/src/main/scala/forge/Target.scala +++ b/src/main/scala/forge/Target.scala @@ -14,10 +14,7 @@ trait TargetOps[T]{ this: Target[T] => this.zip(other).map(s.apply _ tupled) } - override def toString = defCtx.value match{ - case None => this.getClass.getSimpleName + "@" + Integer.toHexString(System.identityHashCode(this)) - case Some(s) => this.getClass.getName + "@" + s - } + override def toString = this.getClass.getName + "@" + defCtx.label } trait Target[T] extends TargetOps[T]{ /** diff --git a/src/test/scala/forge/ForgeTests.scala b/src/test/scala/forge/ForgeTests.scala index d5275d0c..cf903461 100644 --- a/src/test/scala/forge/ForgeTests.scala +++ b/src/test/scala/forge/ForgeTests.scala @@ -6,7 +6,8 @@ import java.nio.{file => jnio} object ForgeTests extends TestSuite{ val tests = Tests{ - val evaluator = new Evaluator(jnio.Paths.get("target/workspace"), implicitly) + val baseCtx = DefCtx("forge.ForgeTests.tests ") + val evaluator = new Evaluator(jnio.Paths.get("target/workspace"), baseCtx) object Singleton { val single = T{ test() } } @@ -30,11 +31,69 @@ object ForgeTests extends TestSuite{ val down = T{ test(test(up), test(up)) } } - 'neg - { - compileError("T{ 123 }") - compileError("T{ println() }") - () + + 'syntaxLimits - { + // Make sure that we properly prohibit cases where a `test()` target can + // be created more than once with the same `DefCtx`, while still allowing + // cases where the `test()` target is created exactly one time, or even + // zero-or-one times (since that's ok, as long as it's not more than once) + + 'neg - { + 'nakedTest - { + compileError("test()") + () + } + 'notFunctionCall - { + compileError("T{ 123 }") + () + } + 'functionCallWithoutImplicit - { + compileError("T{ println() }") + () + } + // Make sure the snippets without `test()`s compile, but the same snippets + // *with* the `test()` calls do not (presumably due to the `@compileTimeOnly` + // annotation) + // + // For some reason, `if(false)` isn't good enough because scalac constant + // folds the conditional, eliminates the entire code block, and makes any + // `@compileTimeOnly`s annotations disappear... + + + 'canEvaluateMoreThanOnce - { + if (math.random() > 10) T{ Seq(1, 2).map(_ => ???); test() } + compileError("T{ Seq(1, 2).map(_ => test()); test() }") + + if (math.random() > 10) T{ class Foo{ ??? }; test() } + compileError("T{ class Foo{ test() }; test() }") + + if (math.random() > 10) T{ test({while(true){ }; ???}) } + compileError("T{ test({while(true){ test() }; ???}) }") + + if (math.random() > 10) T{ do{ } while(true); test() } + compileError("T{ do{ test() } while(true); test() }") + + if (math.random() > 10) T{ def foo() = ???; test() } + compileError("T{ def foo() = test(); test() }") + + if (math.random() > 10) T{ None.getOrElse(???); test() } + if (math.random() > 10) T{ None.contains(test()); test() } + compileError("T{ None.getOrElse(test()); test() }") + + () + } + } + 'pos - { + T{ test({val x = test(); x}) } + T{ test({lazy val x = test(); x}) } + T { object foo {val x = test()}; test(foo.x) } + T{ test({val x = if (math.random() > 0.5) test() else test(); x}) } + + () + } } + + 'topoSortedTransitiveTargets - { def check(targets: Seq[Target[_]], expected: Seq[Target[_]]) = { val result = Evaluator.topoSortedTransitiveTargets(targets) @@ -67,55 +126,89 @@ object ForgeTests extends TestSuite{ ) ) } + 'labeling - { + + def check(t: Target[_], relPath: String) = { + val targetLabel = t.defCtx.label + val expectedLabel = baseCtx.label + relPath + assert(targetLabel == expectedLabel) + } + 'singleton - check(Singleton.single, "Singleton.single") + 'pair - { + check(Pair.up, "Pair.up") + check(Pair.down, "Pair.down") + } + + 'anonTriple - { + check(AnonTriple.up, "AnonTriple.up") + check(AnonTriple.down.inputs(0), "AnonTriple.down1") + check(AnonTriple.down, "AnonTriple.down") + } + + 'diamond - { + check(Diamond.up, "Diamond.up") + check(Diamond.left, "Diamond.left") + check(Diamond.right, "Diamond.right") + check(Diamond.down, "Diamond.down") + } + + 'anonDiamond - { + check(AnonDiamond.up, "AnonDiamond.up") + check(AnonDiamond.down.inputs(0), "AnonDiamond.down1") + check(AnonDiamond.down.inputs(1), "AnonDiamond.down2") + check(AnonDiamond.down, "AnonDiamond.down") + } + + } 'evaluate - { def check(targets: Seq[Target[_]], - values: Seq[Any], - evaluated: Seq[Target[_]]) = { + expectedValues: Seq[Any], + expectedEvaluated: Seq[Target[_]]) = { val Evaluator.Results(returnedValues, returnedEvaluated) = evaluator.evaluate(targets) assert( - returnedValues == values, - returnedEvaluated == evaluated + returnedValues == expectedValues, + returnedEvaluated == expectedEvaluated ) } 'singleton - { import Singleton._ // First time the target is evaluated - check(Seq(single), values = Seq(0), evaluated = Seq(single)) + check(Seq(single), expectedValues = Seq(0), expectedEvaluated = Seq(single)) // Second time the value is already cached, so no evaluation needed - check(Seq(single), values = Seq(0), evaluated = Seq()) + check(Seq(single), expectedValues = Seq(0), expectedEvaluated = Seq()) single.counter += 1 // After incrementing the counter, it forces re-evaluation - check(Seq(single), values = Seq(1), evaluated = Seq(single)) + check(Seq(single), expectedValues = Seq(1), expectedEvaluated = Seq(single)) // Then it's cached again - check(Seq(single), values = Seq(1), evaluated = Seq()) + check(Seq(single), expectedValues = Seq(1), expectedEvaluated = Seq()) } 'pair - { import Pair._ - check(Seq(down), values = Seq(0), evaluated = Seq(up, down)) - check(Seq(down), values = Seq(0), evaluated = Seq()) + check(Seq(down), expectedValues = Seq(0), expectedEvaluated = Seq(up, down)) + check(Seq(down), expectedValues = Seq(0), expectedEvaluated = Seq()) down.counter += 1 - check(Seq(down), values = Seq(1), evaluated = Seq(down)) - check(Seq(down), values = Seq(1), evaluated = Seq()) + check(Seq(down), expectedValues = Seq(1), expectedEvaluated = Seq(down)) + check(Seq(down), expectedValues = Seq(1), expectedEvaluated = Seq()) up.counter += 1 - check(Seq(down), values = Seq(2), evaluated = Seq(up, down)) - check(Seq(down), values = Seq(2), evaluated = Seq()) + check(Seq(down), expectedValues = Seq(2), expectedEvaluated = Seq(up, down)) + check(Seq(down), expectedValues = Seq(2), expectedEvaluated = Seq()) } 'anonTriple - { import AnonTriple._ val middle = down.inputs(0) - check(Seq(down), values = Seq(0), evaluated = Seq(up, middle, down)) - check(Seq(down), values = Seq(0), evaluated = Seq()) + check(Seq(down), expectedValues = Seq(0), expectedEvaluated = Seq(up, middle, down)) + check(Seq(down), expectedValues = Seq(0), expectedEvaluated = Seq()) down.counter += 1 - check(Seq(down), values = Seq(1), evaluated = Seq(middle, down)) - check(Seq(down), values = Seq(1), evaluated = Seq()) + check(Seq(down), expectedValues = Seq(1), expectedEvaluated = Seq(down)) + check(Seq(down), expectedValues = Seq(1), expectedEvaluated = Seq()) up.counter += 1 - check(Seq(down), values = Seq(2), evaluated = Seq(up, middle, down)) - check(Seq(down), values = Seq(2), evaluated = Seq()) + check(Seq(down), expectedValues = Seq(2), expectedEvaluated = Seq(up, middle, down)) + check(Seq(down), expectedValues = Seq(2), expectedEvaluated = Seq()) } } -- cgit v1.2.3