diff options
Diffstat (limited to 'main/core/src/util')
-rw-r--r-- | main/core/src/util/AggWrapper.scala | 119 | ||||
-rw-r--r-- | main/core/src/util/EitherOps.scala | 18 | ||||
-rw-r--r-- | main/core/src/util/EnclosingClass.scala | 15 | ||||
-rw-r--r-- | main/core/src/util/JsonFormatters.scala | 10 | ||||
-rw-r--r-- | main/core/src/util/Loggers.scala | 190 | ||||
-rw-r--r-- | main/core/src/util/MultiBiMap.scala | 57 | ||||
-rw-r--r-- | main/core/src/util/ParseArgs.scala | 137 | ||||
-rw-r--r-- | main/core/src/util/Router.scala | 451 | ||||
-rw-r--r-- | main/core/src/util/Scripts.scala | 330 | ||||
-rw-r--r-- | main/core/src/util/Watched.scala | 8 | ||||
-rw-r--r-- | main/core/src/util/package.scala | 7 |
11 files changed, 1342 insertions, 0 deletions
diff --git a/main/core/src/util/AggWrapper.scala b/main/core/src/util/AggWrapper.scala new file mode 100644 index 00000000..6c107875 --- /dev/null +++ b/main/core/src/util/AggWrapper.scala @@ -0,0 +1,119 @@ +package mill.util + + + +import scala.collection.mutable +object Strict extends AggWrapper(true) +object Loose extends AggWrapper(false) +sealed class AggWrapper(strictUniqueness: Boolean){ + /** + * A collection with enforced uniqueness, fast contains and deterministic + * ordering. Raises an exception if a duplicate is found; call + * `toSeq.distinct` if you explicitly want to make it swallow duplicates + */ + trait Agg[V] extends TraversableOnce[V]{ + def contains(v: V): Boolean + def items: Iterator[V] + def indexed: IndexedSeq[V] + def flatMap[T](f: V => TraversableOnce[T]): Agg[T] + def map[T](f: V => T): Agg[T] + def filter(f: V => Boolean): Agg[V] + def withFilter(f: V => Boolean): Agg[V] + def collect[T](f: PartialFunction[V, T]): Agg[T] + def zipWithIndex: Agg[(V, Int)] + def reverse: Agg[V] + def zip[T](other: Agg[T]): Agg[(V, T)] + def ++[T >: V](other: TraversableOnce[T]): Agg[T] + def length: Int + } + + object Agg{ + def empty[V]: Agg[V] = new Agg.Mutable[V] + implicit def jsonFormat[T: upickle.default.ReadWriter]: upickle.default.ReadWriter[Agg[T]] = + upickle.default.readwriter[Seq[T]].bimap[Agg[T]]( + _.toList, + Agg.from(_) + ) + + def apply[V](items: V*) = from(items) + + implicit def from[V](items: TraversableOnce[V]): Agg[V] = { + val set = new Agg.Mutable[V]() + items.foreach(set.append) + set + } + + + class Mutable[V]() extends Agg[V]{ + + private[this] val set0 = mutable.LinkedHashSet.empty[V] + def contains(v: V) = set0.contains(v) + def append(v: V) = if (!contains(v)){ + set0.add(v) + + }else if (strictUniqueness){ + throw new Exception("Duplicated item inserted into OrderedSet: " + v) + } + def appendAll(vs: Seq[V]) = vs.foreach(append) + def items = set0.iterator + def indexed: IndexedSeq[V] = items.toIndexedSeq + def set: collection.Set[V] = set0 + + def map[T](f: V => T): Agg[T] = { + val output = new Agg.Mutable[T] + for(i <- items) output.append(f(i)) + output + } + def flatMap[T](f: V => TraversableOnce[T]): Agg[T] = { + val output = new Agg.Mutable[T] + for(i <- items) for(i0 <- f(i)) output.append(i0) + output + } + def filter(f: V => Boolean): Agg[V] = { + val output = new Agg.Mutable[V] + for(i <- items) if (f(i)) output.append(i) + output + } + def withFilter(f: V => Boolean): Agg[V] = filter(f) + + def collect[T](f: PartialFunction[V, T]) = this.filter(f.isDefinedAt).map(x => f(x)) + + def zipWithIndex = { + var i = 0 + this.map{ x => + i += 1 + (x, i-1) + } + } + + def reverse = Agg.from(indexed.reverseIterator) + + def zip[T](other: Agg[T]) = Agg.from(items.zip(other.items)) + def ++[T >: V](other: TraversableOnce[T]) = Agg.from(items ++ other) + def length: Int = set0.size + + // Members declared in scala.collection.GenTraversableOnce + def isTraversableAgain: Boolean = items.isTraversableAgain + def toIterator: Iterator[V] = items.toIterator + def toStream: Stream[V] = items.toStream + + // Members declared in scala.collection.TraversableOnce + def copyToArray[B >: V](xs: Array[B], start: Int,len: Int): Unit = items.copyToArray(xs, start, len) + def exists(p: V => Boolean): Boolean = items.exists(p) + def find(p: V => Boolean): Option[V] = items.find(p) + def forall(p: V => Boolean): Boolean = items.forall(p) + def foreach[U](f: V => U): Unit = items.foreach(f) + def hasDefiniteSize: Boolean = items.hasDefiniteSize + def isEmpty: Boolean = items.isEmpty + def seq: scala.collection.TraversableOnce[V] = items + def toTraversable: Traversable[V] = items.toTraversable + + override def hashCode() = items.map(_.hashCode()).sum + override def equals(other: Any) = other match{ + case s: Agg[_] => items.sameElements(s.items) + case _ => super.equals(other) + } + override def toString = items.mkString("Agg(", ", ", ")") + } + } +} diff --git a/main/core/src/util/EitherOps.scala b/main/core/src/util/EitherOps.scala new file mode 100644 index 00000000..da2552c8 --- /dev/null +++ b/main/core/src/util/EitherOps.scala @@ -0,0 +1,18 @@ +package mill.util + +import scala.collection.generic.CanBuildFrom +import scala.collection.mutable +import scala.language.higherKinds + +object EitherOps { + + // implementation similar to scala.concurrent.Future#sequence + def sequence[A, B, M[X] <: TraversableOnce[X]](in: M[Either[A, B]])( + implicit cbf: CanBuildFrom[M[Either[A, B]], B, M[B]]): Either[A, M[B]] = { + in.foldLeft[Either[A, mutable.Builder[B, M[B]]]](Right(cbf(in))) { + case (acc, el) => + for (a <- acc; e <- el) yield a += e + } + .map(_.result()) + } +} diff --git a/main/core/src/util/EnclosingClass.scala b/main/core/src/util/EnclosingClass.scala new file mode 100644 index 00000000..a69cc525 --- /dev/null +++ b/main/core/src/util/EnclosingClass.scala @@ -0,0 +1,15 @@ +package mill.util + +import sourcecode.Compat.Context +import language.experimental.macros +case class EnclosingClass(value: Class[_]) +object EnclosingClass{ + def apply()(implicit c: EnclosingClass) = c.value + implicit def generate: EnclosingClass = macro impl + def impl(c: Context): c.Tree = { + import c.universe._ + val cls = c.internal.enclosingOwner.owner.asType.asClass + // q"new _root_.mill.define.EnclosingClass(classOf[$cls])" + q"new _root_.mill.util.EnclosingClass(this.getClass)" + } +} diff --git a/main/core/src/util/JsonFormatters.scala b/main/core/src/util/JsonFormatters.scala new file mode 100644 index 00000000..830782c6 --- /dev/null +++ b/main/core/src/util/JsonFormatters.scala @@ -0,0 +1,10 @@ +package mill.util + +import upickle.default.{ReadWriter => RW} + +trait JsonFormatters extends mill.api.JsonFormatters{ + implicit lazy val modFormat: RW[coursier.Module] = upickle.default.macroRW + implicit lazy val depFormat: RW[coursier.Dependency]= upickle.default.macroRW + implicit lazy val attrFormat: RW[coursier.Attributes] = upickle.default.macroRW +} +object JsonFormatters extends JsonFormatters diff --git a/main/core/src/util/Loggers.scala b/main/core/src/util/Loggers.scala new file mode 100644 index 00000000..aab1a324 --- /dev/null +++ b/main/core/src/util/Loggers.scala @@ -0,0 +1,190 @@ +package mill.util + +import java.io._ +import mill.api.Logger + +object DummyLogger extends Logger { + def colored = false + + object errorStream extends PrintStream(_ => ()) + object outputStream extends PrintStream(_ => ()) + val inStream = new ByteArrayInputStream(Array()) + + def info(s: String) = () + def error(s: String) = () + def ticker(s: String) = () + def debug(s: String) = () +} + +class CallbackStream(wrapped: OutputStream, + setPrintState0: PrintState => Unit) extends OutputStream{ + def setPrintState(c: Char) = { + setPrintState0( + c match{ + case '\n' => PrintState.Newline + case '\r' => PrintState.Newline + case _ => PrintState.Middle + } + ) + } + override def write(b: Array[Byte]): Unit = { + if (b.nonEmpty) setPrintState(b(b.length-1).toChar) + wrapped.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + if (len != 0) setPrintState(b(off+len-1).toChar) + wrapped.write(b, off, len) + } + + def write(b: Int) = { + setPrintState(b.toChar) + wrapped.write(b) + } +} +sealed trait PrintState +object PrintState{ + case object Ticker extends PrintState + case object Newline extends PrintState + case object Middle extends PrintState +} + +case class PrintLogger( + colored: Boolean, + disableTicker: Boolean, + colors: ammonite.util.Colors, + outStream: PrintStream, + infoStream: PrintStream, + errStream: PrintStream, + inStream: InputStream, + debugEnabled: Boolean + ) extends Logger { + + var printState: PrintState = PrintState.Newline + + override val errorStream = new PrintStream(new CallbackStream(errStream, printState = _)) + override val outputStream = new PrintStream(new CallbackStream(outStream, printState = _)) + + + def info(s: String) = { + printState = PrintState.Newline + infoStream.println(colors.info()(s)) + } + def error(s: String) = { + printState = PrintState.Newline + errStream.println(colors.error()(s)) + } + def ticker(s: String) = { + if(!disableTicker) { + printState match{ + case PrintState.Newline => + infoStream.println(colors.info()(s)) + case PrintState.Middle => + infoStream.println() + infoStream.println(colors.info()(s)) + case PrintState.Ticker => + val p = new PrintWriter(infoStream) + val nav = new ammonite.terminal.AnsiNav(p) + nav.up(1) + nav.clearLine(2) + nav.left(9999) + p.flush() + + infoStream.println(colors.info()(s)) + } + printState = PrintState.Ticker + } + } + + def debug(s: String) = if (debugEnabled) { + printState = PrintState.Newline + errStream.println(s) + } +} + +case class FileLogger(colored: Boolean, file: os.Path, debugEnabled: Boolean) extends Logger { + private[this] var outputStreamUsed: Boolean = false + + lazy val outputStream = { + if (!outputStreamUsed) os.remove.all(file) + outputStreamUsed = true + new PrintStream(new FileOutputStream(file.toIO.getAbsolutePath)) + } + + lazy val errorStream = { + if (!outputStreamUsed) os.remove.all(file) + outputStreamUsed = true + new PrintStream(new FileOutputStream(file.toIO.getAbsolutePath)) + } + + def info(s: String) = outputStream.println(s) + def error(s: String) = outputStream.println(s) + def ticker(s: String) = outputStream.println(s) + def debug(s: String) = if (debugEnabled) outputStream.println(s) + val inStream: InputStream = mill.api.DummyInputStream + override def close() = { + if (outputStreamUsed) + outputStream.close() + } +} + + + +class MultiStream(stream1: OutputStream, stream2: OutputStream) extends PrintStream(new OutputStream { + def write(b: Int): Unit = { + stream1.write(b) + stream2.write(b) + } + override def write(b: Array[Byte]): Unit = { + stream1.write(b) + stream2.write(b) + } + override def write(b: Array[Byte], off: Int, len: Int) = { + stream1.write(b, off, len) + stream2.write(b, off, len) + } + override def flush() = { + stream1.flush() + stream2.flush() + } + override def close() = { + stream1.close() + stream2.close() + } +}) + +case class MultiLogger(colored: Boolean, logger1: Logger, logger2: Logger) extends Logger { + + + lazy val outputStream: PrintStream = new MultiStream(logger1.outputStream, logger2.outputStream) + + lazy val errorStream: PrintStream = new MultiStream(logger1.errorStream, logger2.errorStream) + + lazy val inStream = Seq(logger1, logger2).collectFirst{case t: PrintLogger => t} match{ + case Some(x) => x.inStream + case None => new ByteArrayInputStream(Array()) + } + + def info(s: String) = { + logger1.info(s) + logger2.info(s) + } + def error(s: String) = { + logger1.error(s) + logger2.error(s) + } + def ticker(s: String) = { + logger1.ticker(s) + logger2.ticker(s) + } + + def debug(s: String) = { + logger1.debug(s) + logger2.debug(s) + } + + override def close() = { + logger1.close() + logger2.close() + } +} diff --git a/main/core/src/util/MultiBiMap.scala b/main/core/src/util/MultiBiMap.scala new file mode 100644 index 00000000..73bb42c4 --- /dev/null +++ b/main/core/src/util/MultiBiMap.scala @@ -0,0 +1,57 @@ +package mill.util + +import scala.collection.mutable +import Strict.Agg + +/** + * A map from keys to collections of values: you can assign multiple values + * to any particular key. Also allows lookups in both directions: what values + * are assigned to a key or what key a value is assigned to. + */ +trait MultiBiMap[K, V]{ + def containsValue(v: V): Boolean + def lookupKey(k: K): Agg[V] + def lookupValue(v: V): K + def lookupValueOpt(v: V): Option[K] + def add(k: K, v: V): Unit + def removeAll(k: K): Agg[V] + def addAll(k: K, vs: TraversableOnce[V]): Unit + def keys(): Iterator[K] + def items(): Iterator[(K, Agg[V])] + def values(): Iterator[Agg[V]] + def keyCount: Int +} + +object MultiBiMap{ + + class Mutable[K, V]() extends MultiBiMap[K, V]{ + private[this] val valueToKey = mutable.LinkedHashMap.empty[V, K] + private[this] val keyToValues = mutable.LinkedHashMap.empty[K, Agg.Mutable[V]] + def containsValue(v: V) = valueToKey.contains(v) + def lookupKey(k: K) = keyToValues(k) + def lookupKeyOpt(k: K) = keyToValues.get(k) + def lookupValue(v: V) = valueToKey(v) + def lookupValueOpt(v: V) = valueToKey.get(v) + def add(k: K, v: V): Unit = { + valueToKey(v) = k + keyToValues.getOrElseUpdate(k, new Agg.Mutable[V]()).append(v) + } + def removeAll(k: K): Agg[V] = keyToValues.get(k) match { + case None => Agg() + case Some(vs) => + vs.foreach(valueToKey.remove) + + keyToValues.remove(k) + vs + } + def addAll(k: K, vs: TraversableOnce[V]): Unit = vs.foreach(this.add(k, _)) + + def keys() = keyToValues.keysIterator + + def values() = keyToValues.valuesIterator + + def items() = keyToValues.iterator + + def keyCount = keyToValues.size + } +} diff --git a/main/core/src/util/ParseArgs.scala b/main/core/src/util/ParseArgs.scala new file mode 100644 index 00000000..fc1a8ab3 --- /dev/null +++ b/main/core/src/util/ParseArgs.scala @@ -0,0 +1,137 @@ +package mill.util + +import fastparse._, NoWhitespace._ +import mill.define.{Segment, Segments} + +object ParseArgs { + + def apply(scriptArgs: Seq[String], + multiSelect: Boolean): Either[String, (List[(Option[Segments], Segments)], Seq[String])] = { + val (selectors, args) = extractSelsAndArgs(scriptArgs, multiSelect) + for { + _ <- validateSelectors(selectors) + expandedSelectors <- EitherOps + .sequence(selectors.map(expandBraces)) + .map(_.flatten) + selectors <- EitherOps.sequence(expandedSelectors.map(extractSegments)) + } yield (selectors.toList, args) + } + + def extractSelsAndArgs(scriptArgs: Seq[String], + multiSelect: Boolean): (Seq[String], Seq[String]) = { + + if (multiSelect) { + val dd = scriptArgs.indexOf("--") + val selectors = if (dd == -1) scriptArgs else scriptArgs.take(dd) + val args = if (dd == -1) Seq.empty else scriptArgs.drop(dd + 1) + + (selectors, args) + } else { + (scriptArgs.take(1), scriptArgs.drop(1)) + } + } + + private def validateSelectors(selectors: Seq[String]): Either[String, Unit] = { + if (selectors.isEmpty || selectors.exists(_.isEmpty)) + Left("Selector cannot be empty") + else Right(()) + } + + def expandBraces(selectorString: String): Either[String, List[String]] = { + parseBraceExpansion(selectorString) match { + case f: Parsed.Failure => Left(s"Parsing exception ${f.msg}") + case Parsed.Success(expanded, _) => Right(expanded.toList) + } + } + + private sealed trait Fragment + private object Fragment { + case class Keep(value: String) extends Fragment + case class Expand(values: List[List[Fragment]]) extends Fragment + + def unfold(fragments: List[Fragment]): Seq[String] = { + fragments match { + case head :: rest => + val prefixes = head match { + case Keep(v) => Seq(v) + case Expand(Nil) => Seq("{}") + case Expand(List(vs)) => unfold(vs).map("{" + _ + "}") + case Expand(vss) => vss.flatMap(unfold) + } + for { + prefix <- prefixes + suffix <- unfold(rest) + } yield prefix + suffix + + case Nil => Seq("") + } + } + } + + private object BraceExpansionParser { + def plainChars[_: P] = + P(CharsWhile(c => c != ',' && c != '{' && c != '}')).!.map(Fragment.Keep) + + def toExpand[_: P]: P[Fragment] = + P("{" ~ braceParser.rep(1).rep(sep = ",") ~ "}").map( + x => Fragment.Expand(x.toList.map(_.toList)) + ) + + def braceParser[_: P] = P(toExpand | plainChars) + + def parser[_: P] = P(braceParser.rep(1).rep(sep = ",") ~ End).map { vss => + def unfold(vss: List[Seq[String]]): Seq[String] = { + vss match { + case Nil => Seq("") + case head :: rest => + for { + str <- head + r <- unfold(rest) + } yield + r match { + case "" => str + case _ => str + "," + r + } + } + } + + val stringss = vss.map(x => Fragment.unfold(x.toList)).toList + unfold(stringss) + } + } + + private def parseBraceExpansion(input: String) = { + + + parse( + input, + BraceExpansionParser.parser(_) + ) + } + + def extractSegments(selectorString: String): Either[String, (Option[Segments], Segments)] = + parseSelector(selectorString) match { + case f: Parsed.Failure => Left(s"Parsing exception ${f.msg}") + case Parsed.Success(selector, _) => Right(selector) + } + + private def ident[_: P] = P( CharsWhileIn("a-zA-Z0-9_\\-") ).! + + def standaloneIdent[_: P] = P(Start ~ ident ~ End ) + def isLegalIdentifier(identifier: String): Boolean = + parse(identifier, standaloneIdent(_)).isInstanceOf[Parsed.Success[_]] + + private def parseSelector(input: String) = { + def ident2[_: P] = P( CharsWhileIn("a-zA-Z0-9_\\-.") ).! + def segment[_: P] = P( ident ).map( Segment.Label) + def crossSegment[_: P] = P("[" ~ ident2.rep(1, sep = ",") ~ "]").map(Segment.Cross) + def simpleQuery[_: P] = P(segment ~ ("." ~ segment | crossSegment).rep).map { + case (h, rest) => Segments(h :: rest.toList:_*) + } + def query[_: P] = P( simpleQuery ~ ("/" ~/ simpleQuery).?).map{ + case (q, None) => (None, q) + case (q, Some(q2)) => (Some(q), q2) + } + parse(input, query(_)) + } +} diff --git a/main/core/src/util/Router.scala b/main/core/src/util/Router.scala new file mode 100644 index 00000000..5dd3c947 --- /dev/null +++ b/main/core/src/util/Router.scala @@ -0,0 +1,451 @@ +package mill.util + +import ammonite.main.Compat +import language.experimental.macros + +import scala.annotation.StaticAnnotation +import scala.collection.mutable +import scala.reflect.macros.blackbox.Context + +/** + * More or less a minimal version of Autowire's Server that lets you generate + * a set of "routes" from the methods defined in an object, and call them + * using passing in name/args/kwargs via Java reflection, without having to + * generate/compile code or use Scala reflection. This saves us spinning up + * the Scala compiler and greatly reduces the startup time of cached scripts. + */ +object Router{ + /** + * Allows you to query how many things are overriden by the enclosing owner. + */ + case class Overrides(value: Int) + object Overrides{ + def apply()(implicit c: Overrides) = c.value + implicit def generate: Overrides = macro impl + def impl(c: Context): c.Tree = { + import c.universe._ + q"new _root_.mill.util.Router.Overrides(${c.internal.enclosingOwner.overrides.length})" + } + } + + class doc(s: String) extends StaticAnnotation + class main extends StaticAnnotation + def generateRoutes[T]: Seq[Router.EntryPoint[T]] = macro generateRoutesImpl[T] + def generateRoutesImpl[T: c.WeakTypeTag](c: Context): c.Expr[Seq[EntryPoint[T]]] = { + import c.universe._ + val r = new Router(c) + val allRoutes = r.getAllRoutesForClass( + weakTypeOf[T].asInstanceOf[r.c.Type] + ).asInstanceOf[Iterable[c.Tree]] + + c.Expr[Seq[EntryPoint[T]]](q"_root_.scala.Seq(..$allRoutes)") + } + + /** + * Models what is known by the router about a single argument: that it has + * a [[name]], a human-readable [[typeString]] describing what the type is + * (just for logging and reading, not a replacement for a `TypeTag`) and + * possible a function that can compute its default value + */ + case class ArgSig[T, V](name: String, + typeString: String, + doc: Option[String], + default: Option[T => V]) + (implicit val reads: scopt.Read[V]) + + def stripDashes(s: String) = { + if (s.startsWith("--")) s.drop(2) + else if (s.startsWith("-")) s.drop(1) + else s + } + /** + * What is known about a single endpoint for our routes. It has a [[name]], + * [[argSignatures]] for each argument, and a macro-generated [[invoke0]] + * that performs all the necessary argument parsing and de-serialization. + * + * Realistically, you will probably spend most of your time calling [[invoke]] + * instead, which provides a nicer API to call it that mimmicks the API of + * calling a Scala method. + */ + case class EntryPoint[T](name: String, + argSignatures: Seq[ArgSig[T, _]], + doc: Option[String], + varargs: Boolean, + invoke0: (T, Map[String, String], Seq[String], Seq[ArgSig[T, _]]) => Result[Any], + overrides: Int){ + def invoke(target: T, groupedArgs: Seq[(String, Option[String])]): Result[Any] = { + var remainingArgSignatures = argSignatures.toList.filter(_.reads.arity > 0) + + val accumulatedKeywords = mutable.Map.empty[ArgSig[T, _], mutable.Buffer[String]] + val keywordableArgs = if (varargs) argSignatures.dropRight(1) else argSignatures + + for(arg <- keywordableArgs) accumulatedKeywords(arg) = mutable.Buffer.empty + + val leftoverArgs = mutable.Buffer.empty[String] + + val lookupArgSig = Map(argSignatures.map(x => (x.name, x)):_*) + + var incomplete: Option[ArgSig[T, _]] = None + + for(group <- groupedArgs){ + + group match{ + case (value, None) => + if (value(0) == '-' && !varargs){ + lookupArgSig.get(stripDashes(value)) match{ + case None => leftoverArgs.append(value) + case Some(sig) => incomplete = Some(sig) + } + + } else remainingArgSignatures match { + case Nil => leftoverArgs.append(value) + case last :: Nil if varargs => leftoverArgs.append(value) + case next :: rest => + accumulatedKeywords(next).append(value) + remainingArgSignatures = rest + } + case (rawKey, Some(value)) => + val key = stripDashes(rawKey) + lookupArgSig.get(key) match{ + case Some(x) if accumulatedKeywords.contains(x) => + if (accumulatedKeywords(x).nonEmpty && varargs){ + leftoverArgs.append(rawKey, value) + }else{ + accumulatedKeywords(x).append(value) + remainingArgSignatures = remainingArgSignatures.filter(_.name != key) + } + case _ => + leftoverArgs.append(rawKey, value) + } + } + } + + val missing0 = remainingArgSignatures + .filter(_.default.isEmpty) + + val missing = if(varargs) { + missing0.filter(_ != argSignatures.last) + } else { + missing0.filter(x => incomplete != Some(x)) + } + val duplicates = accumulatedKeywords.toSeq.filter(_._2.length > 1) + + if ( + incomplete.nonEmpty || + missing.nonEmpty || + duplicates.nonEmpty || + (leftoverArgs.nonEmpty && !varargs) + ){ + Result.Error.MismatchedArguments( + missing = missing, + unknown = leftoverArgs, + duplicate = duplicates, + incomplete = incomplete + + ) + } else { + val mapping = accumulatedKeywords + .iterator + .collect{case (k, Seq(single)) => (k.name, single)} + .toMap + + try invoke0(target, mapping, leftoverArgs, argSignatures) + catch{case e: Throwable => + Result.Error.Exception(e) + } + } + } + } + + def tryEither[T](t: => T, error: Throwable => Result.ParamError) = { + try Right(t) + catch{ case e: Throwable => Left(error(e))} + } + def readVarargs(arg: ArgSig[_, _], + values: Seq[String], + thunk: String => Any) = { + val attempts = + for(item <- values) + yield tryEither(thunk(item), Result.ParamError.Invalid(arg, item, _)) + + + val bad = attempts.collect{ case Left(x) => x} + if (bad.nonEmpty) Left(bad) + else Right(attempts.collect{case Right(x) => x}) + } + def read(dict: Map[String, String], + default: => Option[Any], + arg: ArgSig[_, _], + thunk: String => Any): FailMaybe = { + arg.reads.arity match{ + case 0 => + tryEither(thunk(null), Result.ParamError.DefaultFailed(arg, _)).left.map(Seq(_)) + case 1 => + dict.get(arg.name) match{ + case None => + tryEither(default.get, Result.ParamError.DefaultFailed(arg, _)).left.map(Seq(_)) + + case Some(x) => + tryEither(thunk(x), Result.ParamError.Invalid(arg, x, _)).left.map(Seq(_)) + } + } + + } + + /** + * Represents what comes out of an attempt to invoke an [[EntryPoint]]. + * Could succeed with a value, but could fail in many different ways. + */ + sealed trait Result[+T] + object Result{ + + /** + * Invoking the [[EntryPoint]] was totally successful, and returned a + * result + */ + case class Success[T](value: T) extends Result[T] + + /** + * Invoking the [[EntryPoint]] was not successful + */ + sealed trait Error extends Result[Nothing] + object Error{ + + /** + * Invoking the [[EntryPoint]] failed with an exception while executing + * code within it. + */ + case class Exception(t: Throwable) extends Error + + /** + * Invoking the [[EntryPoint]] failed because the arguments provided + * did not line up with the arguments expected + */ + case class MismatchedArguments(missing: Seq[ArgSig[_, _]], + unknown: Seq[String], + duplicate: Seq[(ArgSig[_, _], Seq[String])], + incomplete: Option[ArgSig[_, _]]) extends Error + /** + * Invoking the [[EntryPoint]] failed because there were problems + * deserializing/parsing individual arguments + */ + case class InvalidArguments(values: Seq[ParamError]) extends Error + } + + sealed trait ParamError + object ParamError{ + /** + * Something went wrong trying to de-serialize the input parameter; + * the thrown exception is stored in [[ex]] + */ + case class Invalid(arg: ArgSig[_, _], value: String, ex: Throwable) extends ParamError + /** + * Something went wrong trying to evaluate the default value + * for this input parameter + */ + case class DefaultFailed(arg: ArgSig[_, _], ex: Throwable) extends ParamError + } + } + + + type FailMaybe = Either[Seq[Result.ParamError], Any] + type FailAll = Either[Seq[Result.ParamError], Seq[Any]] + + def validate(args: Seq[FailMaybe]): Result[Seq[Any]] = { + val lefts = args.collect{case Left(x) => x}.flatten + + if (lefts.nonEmpty) Result.Error.InvalidArguments(lefts) + else { + val rights = args.collect{case Right(x) => x} + Result.Success(rights) + } + } + + def makeReadCall(dict: Map[String, String], + default: => Option[Any], + arg: ArgSig[_, _]) = { + read(dict, default, arg, arg.reads.reads(_)) + } + def makeReadVarargsCall(arg: ArgSig[_, _], values: Seq[String]) = { + readVarargs(arg, values, arg.reads.reads(_)) + } +} + + +class Router [C <: Context](val c: C) { + import c.universe._ + def getValsOrMeths(curCls: Type): Iterable[MethodSymbol] = { + def isAMemberOfAnyRef(member: Symbol) = { + // AnyRef is an alias symbol, we go to the real "owner" of these methods + val anyRefSym = c.mirror.universe.definitions.ObjectClass + member.owner == anyRefSym + } + val extractableMembers = for { + member <- curCls.members.toList.reverse + if !isAMemberOfAnyRef(member) + if !member.isSynthetic + if member.isPublic + if member.isTerm + memTerm = member.asTerm + if memTerm.isMethod + if !memTerm.isModule + } yield memTerm.asMethod + + extractableMembers flatMap { case memTerm => + if (memTerm.isSetter || memTerm.isConstructor || memTerm.isGetter) Nil + else Seq(memTerm) + + } + } + + + + def extractMethod(meth: MethodSymbol, curCls: c.universe.Type): c.universe.Tree = { + val baseArgSym = TermName(c.freshName()) + val flattenedArgLists = meth.paramss.flatten + def hasDefault(i: Int) = { + val defaultName = s"${meth.name}$$default$$${i + 1}" + if (curCls.members.exists(_.name.toString == defaultName)) Some(defaultName) + else None + } + val argListSymbol = q"${c.fresh[TermName]("argsList")}" + val extrasSymbol = q"${c.fresh[TermName]("extras")}" + val defaults = for ((arg, i) <- flattenedArgLists.zipWithIndex) yield { + val arg = TermName(c.freshName()) + hasDefault(i).map(defaultName => q"($arg: $curCls) => $arg.${newTermName(defaultName)}") + } + + def getDocAnnotation(annotations: List[Annotation]) = { + val (docTrees, remaining) = annotations.partition(_.tpe =:= typeOf[Router.doc]) + val docValues = for { + doc <- docTrees + if doc.scalaArgs.head.isInstanceOf[Literal] + l = doc.scalaArgs.head.asInstanceOf[Literal] + if l.value.value.isInstanceOf[String] + } yield l.value.value.asInstanceOf[String] + (remaining, docValues.headOption) + } + + def unwrapVarargType(arg: Symbol) = { + val vararg = arg.typeSignature.typeSymbol == definitions.RepeatedParamClass + val unwrappedType = + if (!vararg) arg.typeSignature + else arg.typeSignature.asInstanceOf[TypeRef].args(0) + + (vararg, unwrappedType) + } + + val argSigSymbol = q"${c.fresh[TermName]("argSigs")}" + + val (_, methodDoc) = getDocAnnotation(meth.annotations) + val readArgSigs = for( + ((arg, defaultOpt), i) <- flattenedArgLists.zip(defaults).zipWithIndex + ) yield { + + val (vararg, varargUnwrappedType) = unwrapVarargType(arg) + + val default = + if (vararg) q"scala.Some(scala.Nil)" + else defaultOpt match { + case Some(defaultExpr) => q"scala.Some($defaultExpr($baseArgSym))" + case None => q"scala.None" + } + + val (docUnwrappedType, docOpt) = varargUnwrappedType match{ + case t: AnnotatedType => + + val (remaining, docValue) = getDocAnnotation(t.annotations) + if (remaining.isEmpty) (t.underlying, docValue) + else (Compat.copyAnnotatedType(c)(t, remaining), docValue) + + case t => (t, None) + } + + val docTree = docOpt match{ + case None => q"scala.None" + case Some(s) => q"scala.Some($s)" + } + + + val argSig = q""" + mill.util.Router.ArgSig[$curCls, $docUnwrappedType]( + ${arg.name.toString}, + ${docUnwrappedType.toString + (if(vararg) "*" else "")}, + $docTree, + $defaultOpt + ) + """ + + val reader = + if(vararg) q""" + mill.util.Router.makeReadVarargsCall( + $argSigSymbol($i), + $extrasSymbol + ) + """ else q""" + mill.util.Router.makeReadCall( + $argListSymbol, + $default, + $argSigSymbol($i) + ) + """ + c.internal.setPos(reader, meth.pos) + (reader, argSig, vararg) + } + + val readArgs = readArgSigs.map(_._1) + val argSigs = readArgSigs.map(_._2) + val varargs = readArgSigs.map(_._3) + val (argNames, argNameCasts) = flattenedArgLists.map { arg => + val (vararg, unwrappedType) = unwrapVarargType(arg) + ( + pq"${arg.name.toTermName}", + if (!vararg) q"${arg.name.toTermName}.asInstanceOf[$unwrappedType]" + else q"${arg.name.toTermName}.asInstanceOf[Seq[$unwrappedType]]: _*" + + ) + }.unzip + + + val res = q""" + mill.util.Router.EntryPoint[$curCls]( + ${meth.name.toString}, + scala.Seq(..$argSigs), + ${methodDoc match{ + case None => q"scala.None" + case Some(s) => q"scala.Some($s)" + }}, + ${varargs.contains(true)}, + ( + $baseArgSym: $curCls, + $argListSymbol: Map[String, String], + $extrasSymbol: Seq[String], + $argSigSymbol: Seq[mill.util.Router.ArgSig[$curCls, _]] + ) => + mill.util.Router.validate(Seq(..$readArgs)) match{ + case mill.util.Router.Result.Success(List(..$argNames)) => + mill.util.Router.Result.Success( + $baseArgSym.${meth.name.toTermName}(..$argNameCasts) + ) + case x: mill.util.Router.Result.Error => x + }, + ammonite.main.Router.Overrides() + ) + """ + res + } + + def hasMainAnnotation(t: MethodSymbol) = { + t.annotations.exists(_.tpe =:= typeOf[Router.main]) + } + def getAllRoutesForClass(curCls: Type, + pred: MethodSymbol => Boolean = hasMainAnnotation) + : Iterable[c.universe.Tree] = { + for{ + t <- getValsOrMeths(curCls) + if pred(t) + } yield { + extractMethod(t, curCls) + } + } +} diff --git a/main/core/src/util/Scripts.scala b/main/core/src/util/Scripts.scala new file mode 100644 index 00000000..65eb6b2b --- /dev/null +++ b/main/core/src/util/Scripts.scala @@ -0,0 +1,330 @@ +package mill.util + +import java.nio.file.NoSuchFileException + + +import ammonite.runtime.Evaluator.AmmoniteExit +import ammonite.util.Name.backtickWrap +import ammonite.util.Util.CodeSource +import ammonite.util.{Name, Res, Util} +import fastparse.internal.Util.literalize +import mill.util.Router.{ArgSig, EntryPoint} + +/** + * Logic around using Ammonite as a script-runner; invoking scripts via the + * macro-generated [[Router]], and pretty-printing any output or error messages + */ +object Scripts { + def groupArgs(flatArgs: List[String]): Seq[(String, Option[String])] = { + var keywordTokens = flatArgs + var scriptArgs = Vector.empty[(String, Option[String])] + + while(keywordTokens.nonEmpty) keywordTokens match{ + case List(head, next, rest@_*) if head.startsWith("-") => + scriptArgs = scriptArgs :+ (head, Some(next)) + keywordTokens = rest.toList + case List(head, rest@_*) => + scriptArgs = scriptArgs :+ (head, None) + keywordTokens = rest.toList + + } + scriptArgs + } + + def runScript(wd: os.Path, + path: os.Path, + interp: ammonite.interp.Interpreter, + scriptArgs: Seq[(String, Option[String])] = Nil) = { + interp.watch(path) + val (pkg, wrapper) = Util.pathToPackageWrapper(Seq(), path relativeTo wd) + + for{ + scriptTxt <- try Res.Success(Util.normalizeNewlines(os.read(path))) catch{ + case e: NoSuchFileException => Res.Failure("Script file not found: " + path) + } + + processed <- interp.processModule( + scriptTxt, + CodeSource(wrapper, pkg, Seq(Name("ammonite"), Name("$file")), Some(path)), + autoImport = true, + // Not sure why we need to wrap this in a separate `$routes` object, + // but if we don't do it for some reason the `generateRoutes` macro + // does not see the annotations on the methods of the outer-wrapper. + // It can inspect the type and its methods fine, it's just the + // `methodsymbol.annotations` ends up being empty. + extraCode = Util.normalizeNewlines( + s""" + |val $$routesOuter = this + |object $$routes + |extends scala.Function0[scala.Seq[ammonite.main.Router.EntryPoint[$$routesOuter.type]]]{ + | def apply() = ammonite.main.Router.generateRoutes[$$routesOuter.type] + |} + """.stripMargin + ), + hardcoded = true + ) + + routeClsName <- processed.blockInfo.lastOption match{ + case Some(meta) => Res.Success(meta.id.wrapperPath) + case None => Res.Skip + } + + mainCls = + interp + .evalClassloader + .loadClass(processed.blockInfo.last.id.wrapperPath + "$") + + routesCls = + interp + .evalClassloader + .loadClass(routeClsName + "$$routes$") + + scriptMains = + routesCls + .getField("MODULE$") + .get(null) + .asInstanceOf[() => Seq[Router.EntryPoint[Any]]] + .apply() + + + mainObj = mainCls.getField("MODULE$").get(null) + + res <- Util.withContextClassloader(interp.evalClassloader){ + scriptMains match { + // If there are no @main methods, there's nothing to do + case Seq() => + if (scriptArgs.isEmpty) Res.Success(()) + else { + val scriptArgString = + scriptArgs.flatMap{case (a, b) => Seq(a) ++ b}.map(literalize(_)) + .mkString(" ") + + Res.Failure("Script " + path.last + " does not take arguments: " + scriptArgString) + } + + // If there's one @main method, we run it with all args + case Seq(main) => runMainMethod(mainObj, main, scriptArgs) + + // If there are multiple @main methods, we use the first arg to decide + // which method to run, and pass the rest to that main method + case mainMethods => + val suffix = formatMainMethods(mainObj, mainMethods) + scriptArgs match{ + case Seq() => + Res.Failure( + s"Need to specify a subcommand to call when running " + path.last + suffix + ) + case Seq((head, Some(_)), tail @ _*) => + Res.Failure( + "To select a subcommand to run, you don't need --s." + Util.newLine + + s"Did you mean `${head.drop(2)}` instead of `$head`?" + ) + case Seq((head, None), tail @ _*) => + mainMethods.find(_.name == head) match{ + case None => + Res.Failure( + s"Unable to find subcommand: " + backtickWrap(head) + suffix + ) + case Some(main) => + runMainMethod(mainObj, main, tail) + } + } + } + } + } yield res + } + def formatMainMethods[T](base: T, mainMethods: Seq[Router.EntryPoint[T]]) = { + if (mainMethods.isEmpty) "" + else{ + val leftColWidth = getLeftColWidth(mainMethods.flatMap(_.argSignatures)) + + val methods = + for(main <- mainMethods) + yield formatMainMethodSignature(base, main, 2, leftColWidth) + + Util.normalizeNewlines( + s""" + | + |Available subcommands: + | + |${methods.mkString(Util.newLine)}""".stripMargin + ) + } + } + def getLeftColWidth[T](items: Seq[ArgSig[T, _]]) = { + items.map(_.name.length + 2) match{ + case Nil => 0 + case x => x.max + } + } + def formatMainMethodSignature[T](base: T, + main: Router.EntryPoint[T], + leftIndent: Int, + leftColWidth: Int) = { + // +2 for space on right of left col + val args = main.argSignatures.map(renderArg(base, _, leftColWidth + leftIndent + 2 + 2, 80)) + + val leftIndentStr = " " * leftIndent + val argStrings = + for((lhs, rhs) <- args) + yield { + val lhsPadded = lhs.padTo(leftColWidth, ' ') + val rhsPadded = rhs.linesIterator.mkString(Util.newLine) + s"$leftIndentStr $lhsPadded $rhsPadded" + } + val mainDocSuffix = main.doc match{ + case Some(d) => Util.newLine + leftIndentStr + softWrap(d, leftIndent, 80) + case None => "" + } + + s"""$leftIndentStr${main.name}$mainDocSuffix + |${argStrings.map(_ + Util.newLine).mkString}""".stripMargin + } + def runMainMethod[T](base: T, + mainMethod: Router.EntryPoint[T], + scriptArgs: Seq[(String, Option[String])]): Res[Any] = { + val leftColWidth = getLeftColWidth(mainMethod.argSignatures) + + def expectedMsg = formatMainMethodSignature(base: T, mainMethod, 0, leftColWidth) + + def pluralize(s: String, n: Int) = { + if (n == 1) s else s + "s" + } + + mainMethod.invoke(base, scriptArgs) match{ + case Router.Result.Success(x) => Res.Success(x) + case Router.Result.Error.Exception(x: AmmoniteExit) => Res.Success(x.value) + case Router.Result.Error.Exception(x) => Res.Exception(x, "") + case Router.Result.Error.MismatchedArguments(missing, unknown, duplicate, incomplete) => + val missingStr = + if (missing.isEmpty) "" + else { + val chunks = + for (x <- missing) + yield "--" + x.name + ": " + x.typeString + + val argumentsStr = pluralize("argument", chunks.length) + s"Missing $argumentsStr: (${chunks.mkString(", ")})" + Util.newLine + } + + + val unknownStr = + if (unknown.isEmpty) "" + else { + val argumentsStr = pluralize("argument", unknown.length) + s"Unknown $argumentsStr: " + unknown.map(literalize(_)).mkString(" ") + Util.newLine + } + + val duplicateStr = + if (duplicate.isEmpty) "" + else { + val lines = + for ((sig, options) <- duplicate) + yield { + s"Duplicate arguments for (--${sig.name}: ${sig.typeString}): " + + options.map(literalize(_)).mkString(" ") + Util.newLine + } + + lines.mkString + + } + val incompleteStr = incomplete match{ + case None => "" + case Some(sig) => + s"Option (--${sig.name}: ${sig.typeString}) is missing a corresponding value" + + Util.newLine + + } + + Res.Failure( + Util.normalizeNewlines( + s"""$missingStr$unknownStr$duplicateStr$incompleteStr + |Arguments provided did not match expected signature: + | + |$expectedMsg + |""".stripMargin + ) + ) + + case Router.Result.Error.InvalidArguments(x) => + val argumentsStr = pluralize("argument", x.length) + val thingies = x.map{ + case Router.Result.ParamError.Invalid(p, v, ex) => + val literalV = literalize(v) + val rendered = {renderArgShort(p)} + s"$rendered: ${p.typeString} = $literalV failed to parse with $ex" + case Router.Result.ParamError.DefaultFailed(p, ex) => + s"${renderArgShort(p)}'s default value failed to evaluate with $ex" + } + + Res.Failure( + Util.normalizeNewlines( + s"""The following $argumentsStr failed to parse: + | + |${thingies.mkString(Util.newLine)} + | + |expected signature: + | + |$expectedMsg + """.stripMargin + ) + ) + } + } + + def softWrap(s: String, leftOffset: Int, maxWidth: Int) = { + val oneLine = s.linesIterator.mkString(" ").split(' ') + + lazy val indent = " " * leftOffset + + val output = new StringBuilder(oneLine.head) + var currentLineWidth = oneLine.head.length + for(chunk <- oneLine.tail){ + val addedWidth = currentLineWidth + chunk.length + 1 + if (addedWidth > maxWidth){ + output.append(Util.newLine + indent) + output.append(chunk) + currentLineWidth = chunk.length + } else{ + currentLineWidth = addedWidth + output.append(' ') + output.append(chunk) + } + } + output.mkString + } + def renderArgShort[T](arg: ArgSig[T, _]) = "--" + backtickWrap(arg.name) + def renderArg[T](base: T, + arg: ArgSig[T, _], + leftOffset: Int, + wrappedWidth: Int): (String, String) = { + val suffix = arg.default match{ + case Some(f) => " (default " + f(base) + ")" + case None => "" + } + val docSuffix = arg.doc match{ + case Some(d) => ": " + d + case None => "" + } + val wrapped = softWrap( + arg.typeString + suffix + docSuffix, + leftOffset, + wrappedWidth - leftOffset + ) + (renderArgShort(arg), wrapped) + } + + + def mainMethodDetails[T](ep: EntryPoint[T]) = { + ep.argSignatures.collect{ + case ArgSig(name, tpe, Some(doc), default) => + Util.newLine + name + " // " + doc + }.mkString + } + + /** + * Additional [[scopt.Read]] instance to teach it how to read Ammonite paths + */ + implicit def pathScoptRead: scopt.Read[os.Path] = scopt.Read.stringRead.map(os.Path(_, os.pwd)) + +} diff --git a/main/core/src/util/Watched.scala b/main/core/src/util/Watched.scala new file mode 100644 index 00000000..29be53c3 --- /dev/null +++ b/main/core/src/util/Watched.scala @@ -0,0 +1,8 @@ +package mill.util + +import mill.api.PathRef + +case class Watched[T](value: T, watched: Seq[PathRef]) +object Watched{ + implicit def readWrite[T: upickle.default.ReadWriter] = upickle.default.macroRW[Watched[T]] +} diff --git a/main/core/src/util/package.scala b/main/core/src/util/package.scala new file mode 100644 index 00000000..ec5d2efc --- /dev/null +++ b/main/core/src/util/package.scala @@ -0,0 +1,7 @@ +package mill + +package object util { + // Backwards compat stubs + val Ctx = mill.api.Ctx + type Ctx = mill.api.Ctx +} |