diff options
author | Li Haoyi <haoyi.sg@gmail.com> | 2019-09-16 08:58:31 +0800 |
---|---|---|
committer | Li Haoyi <haoyi.sg@gmail.com> | 2019-09-16 08:58:31 +0800 |
commit | 583893aff7e39e085dbf5bec27c9b0b24e5e8d2e (patch) | |
tree | b2cf94c9bef2824159d8c9bd4f4867163a6ef36c /cask/src/cask | |
parent | 4851f249c8124ce725576f4f87f097f16e2f3843 (diff) | |
download | cask-583893aff7e39e085dbf5bec27c9b0b24e5e8d2e.tar.gz cask-583893aff7e39e085dbf5bec27c9b0b24e5e8d2e.tar.bz2 cask-583893aff7e39e085dbf5bec27c9b0b24e5e8d2e.zip |
Break up `Router.scala` into a `router/` folder with multiple files
Diffstat (limited to 'cask/src/cask')
19 files changed, 426 insertions, 416 deletions
diff --git a/cask/src/cask/decorators/compress.scala b/cask/src/cask/decorators/compress.scala index 7a7eefb..0ffab28 100644 --- a/cask/src/cask/decorators/compress.scala +++ b/cask/src/cask/decorators/compress.scala @@ -2,7 +2,6 @@ package cask.decorators import java.io.{ByteArrayOutputStream, OutputStream} import java.util.zip.{DeflaterOutputStream, GZIPOutputStream} -import cask.internal.Router import cask.model.{Request, Response} import collection.JavaConverters._ diff --git a/cask/src/cask/endpoints/FormEndpoint.scala b/cask/src/cask/endpoints/FormEndpoint.scala index 6f65786..84ba618 100644 --- a/cask/src/cask/endpoints/FormEndpoint.scala +++ b/cask/src/cask/endpoints/FormEndpoint.scala @@ -1,13 +1,14 @@ package cask.endpoints -import cask.internal.{Router, Util} -import cask.main.HttpEndpoint +import cask.internal.Util +import cask.router.HttpEndpoint import cask.model._ +import cask.router.{ArgReader, Result} import io.undertow.server.handlers.form.FormParserFactory import collection.JavaConverters._ -sealed trait FormReader[T] extends Router.ArgReader[Seq[FormEntry], T, Request] +sealed trait FormReader[T] extends ArgReader[Seq[FormEntry], T, Request] object FormReader{ implicit def paramFormReader[T: QueryParamReader] = new FormReader[T]{ def arity = implicitly[QueryParamReader[T]].arity @@ -49,7 +50,7 @@ class postForm(val path: String, override val subpath: Boolean = false) val methods = Seq("post") type InputParser[T] = FormReader[T] def wrapFunction(ctx: Request, - delegate: Delegate): Router.Result[Response.Raw] = { + delegate: Delegate): Result[Response.Raw] = { try { val formData = FormParserFactory.builder().build().createParser(ctx.exchange).parseBlocking() delegate( @@ -60,7 +61,7 @@ class postForm(val path: String, override val subpath: Boolean = false) .toMap ) } catch{case e: Exception => - Router.Result.Success(cask.model.Response( + Result.Success(cask.model.Response( "Unable to parse form data: " + e + "\n" + Util.stackTraceString(e), statusCode = 400 )) diff --git a/cask/src/cask/endpoints/JsonEndpoint.scala b/cask/src/cask/endpoints/JsonEndpoint.scala index 6d3db82..842eae6 100644 --- a/cask/src/cask/endpoints/JsonEndpoint.scala +++ b/cask/src/cask/endpoints/JsonEndpoint.scala @@ -2,14 +2,15 @@ package cask.endpoints import java.io.{ByteArrayOutputStream, InputStream, OutputStream, OutputStreamWriter} -import cask.internal.{Router, Util} -import cask.main.HttpEndpoint +import cask.internal.Util +import cask.router.HttpEndpoint import cask.model.Response.DataCompanion import cask.model.{Request, Response} +import cask.router.{ArgReader, Result} import collection.JavaConverters._ -sealed trait JsReader[T] extends Router.ArgReader[ujson.Value, T, cask.model.Request] +sealed trait JsReader[T] extends ArgReader[ujson.Value, T, cask.model.Request] object JsReader{ implicit def defaultJsReader[T: upickle.default.Reader] = new JsReader[T]{ def arity = 1 @@ -43,9 +44,9 @@ class postJson(val path: String, override val subpath: Boolean = false) extends HttpEndpoint[Response[JsonData], ujson.Value]{ val methods = Seq("post") type InputParser[T] = JsReader[T] - override type OuterReturned = Router.Result[Response.Raw] + override type OuterReturned = Result[Response.Raw] def wrapFunction(ctx: Request, - delegate: Delegate): Router.Result[Response.Raw] = { + delegate: Delegate): Result[Response.Raw] = { val obj = for{ str <- try { @@ -71,7 +72,7 @@ class postJson(val path: String, override val subpath: Boolean = false) ))} } yield obj.toMap obj match{ - case Left(r) => Router.Result.Success(r.map(Response.Data.StringData)) + case Left(r) => Result.Success(r.map(Response.Data.StringData)) case Right(params) => delegate(params) } } @@ -82,8 +83,8 @@ class getJson(val path: String, override val subpath: Boolean = false) extends HttpEndpoint[Response[JsonData], Seq[String]]{ val methods = Seq("get") type InputParser[T] = QueryParamReader[T] - override type OuterReturned = Router.Result[Response.Raw] - def wrapFunction(ctx: Request, delegate: Delegate): Router.Result[Response.Raw] = { + override type OuterReturned = Result[Response.Raw] + def wrapFunction(ctx: Request, delegate: Delegate): Result[Response.Raw] = { delegate(WebEndpoint.buildMapFromQueryParams(ctx)) } diff --git a/cask/src/cask/endpoints/ParamReader.scala b/cask/src/cask/endpoints/ParamReader.scala index e43f482..4ac34f0 100644 --- a/cask/src/cask/endpoints/ParamReader.scala +++ b/cask/src/cask/endpoints/ParamReader.scala @@ -1,11 +1,11 @@ package cask.endpoints -import cask.internal.Router +import cask.router.ArgReader import cask.model.{Cookie, Request} import io.undertow.server.HttpServerExchange import io.undertow.server.handlers.form.{FormData, FormParserFactory} -abstract class ParamReader[T] extends Router.ArgReader[Unit, T, cask.model.Request]{ +abstract class ParamReader[T] extends ArgReader[Unit, T, cask.model.Request]{ def arity: Int def read(ctx: cask.model.Request, label: String, v: Unit): T } diff --git a/cask/src/cask/endpoints/StaticEndpoints.scala b/cask/src/cask/endpoints/StaticEndpoints.scala index 0abfcf5..1e11055 100644 --- a/cask/src/cask/endpoints/StaticEndpoints.scala +++ b/cask/src/cask/endpoints/StaticEndpoints.scala @@ -1,6 +1,6 @@ package cask.endpoints -import cask.main.HttpEndpoint +import cask.router.HttpEndpoint import cask.model.Request class staticFiles(val path: String) extends HttpEndpoint[String, Seq[String]]{ diff --git a/cask/src/cask/endpoints/WebEndpoints.scala b/cask/src/cask/endpoints/WebEndpoints.scala index 8bb3bae..89ca421 100644 --- a/cask/src/cask/endpoints/WebEndpoints.scala +++ b/cask/src/cask/endpoints/WebEndpoints.scala @@ -1,8 +1,8 @@ package cask.endpoints -import cask.internal.Router -import cask.main.HttpEndpoint +import cask.router.HttpEndpoint import cask.model.{Request, Response} +import cask.router.{ArgReader, Result} import collection.JavaConverters._ @@ -10,7 +10,7 @@ import collection.JavaConverters._ trait WebEndpoint extends HttpEndpoint[Response.Raw, Seq[String]]{ type InputParser[T] = QueryParamReader[T] def wrapFunction(ctx: Request, - delegate: Delegate): Router.Result[Response.Raw] = { + delegate: Delegate): Result[Response.Raw] = { delegate(WebEndpoint.buildMapFromQueryParams(ctx)) } def wrapPathSegment(s: String) = Seq(s) @@ -40,7 +40,7 @@ class put(val path: String, override val subpath: Boolean = false) extends WebEn class route(val path: String, val methods: Seq[String], override val subpath: Boolean = false) extends WebEndpoint abstract class QueryParamReader[T] - extends Router.ArgReader[Seq[String], T, cask.model.Request]{ + extends ArgReader[Seq[String], T, cask.model.Request]{ def arity: Int def read(ctx: cask.model.Request, label: String, v: Seq[String]): T } diff --git a/cask/src/cask/endpoints/WebSocketEndpoint.scala b/cask/src/cask/endpoints/WebSocketEndpoint.scala index fae7fde..6ca5def 100644 --- a/cask/src/cask/endpoints/WebSocketEndpoint.scala +++ b/cask/src/cask/endpoints/WebSocketEndpoint.scala @@ -2,8 +2,8 @@ package cask.endpoints import java.nio.ByteBuffer -import cask.internal.Router import cask.model.Request +import cask.router.Result import cask.util.Logger import io.undertow.websockets.WebSocketConnectionCallback import io.undertow.websockets.core.{AbstractReceiveListener, BufferedBinaryMessage, BufferedTextMessage, CloseMessage, WebSocketChannel, WebSockets} @@ -21,10 +21,10 @@ object WebsocketResult{ } class websocket(val path: String, override val subpath: Boolean = false) - extends cask.main.Endpoint[WebsocketResult, Seq[String]]{ + extends cask.router.Endpoint[WebsocketResult, Seq[String]]{ val methods = Seq("websocket") type InputParser[T] = QueryParamReader[T] - type OuterReturned = Router.Result[WebsocketResult] + type OuterReturned = Result[WebsocketResult] def wrapFunction(ctx: Request, delegate: Delegate): OuterReturned = { delegate(WebEndpoint.buildMapFromQueryParams(ctx)) } diff --git a/cask/src/cask/internal/Router.scala b/cask/src/cask/internal/Router.scala deleted file mode 100644 index 4b11811..0000000 --- a/cask/src/cask/internal/Router.scala +++ /dev/null @@ -1,351 +0,0 @@ -package cask.internal - -import language.experimental.macros -import scala.annotation.StaticAnnotation -import scala.collection.mutable -import scala.reflect.macros.blackbox.Context - -/** - * More or less a minimal version of Autowire's Server that lets you generate - * a set of "routes" from the methods defined in an object, and call them - * using passing in name/args/kwargs via Java reflection, without having to - * generate/compile code or use Scala reflection. This saves us spinning up - * the Scala compiler and greatly reduces the startup time of cached scripts. - */ -object Router{ - class doc(s: String) extends StaticAnnotation - - /** - * Models what is known by the router about a single argument: that it has - * a [[name]], a human-readable [[typeString]] describing what the type is - * (just for logging and reading, not a replacement for a `TypeTag`) and - * possible a function that can compute its default value - */ - case class ArgSig[I, -T, +V, -C](name: String, - typeString: String, - doc: Option[String], - default: Option[T => V]) - (implicit val reads: ArgReader[I, V, C]) - - trait ArgReader[I, +T, -C]{ - def arity: Int - def read(ctx: C, label: String, input: I): T - } - - /** - * What is known about a single endpoint for our routes. It has a [[name]], - * [[argSignatures]] for each argument, and a macro-generated [[invoke0]] - * that performs all the necessary argument parsing and de-serialization. - * - * Realistically, you will probably spend most of your time calling [[invoke]] - * instead, which provides a nicer API to call it that mimmicks the API of - * calling a Scala method. - */ - case class EntryPoint[T, C](name: String, - argSignatures: Seq[Seq[ArgSig[_, T, _, C]]], - doc: Option[String], - invoke0: (T, C, Seq[Map[String, Any]], Seq[Seq[ArgSig[Any, _, _, C]]]) => Result[Any]){ - - val firstArgs = argSignatures.head - .map(x => x.name -> x) - .toMap[String, Router.ArgSig[_, T, _, C]] - - def invoke(target: T, - ctx: C, - paramLists: Seq[Map[String, Any]]): Result[Any] = { - - val missing = mutable.Buffer.empty[Router.ArgSig[_, T, _, C]] - - val unknown = paramLists.head.keys.filter(!firstArgs.contains(_)) - - for(k <- firstArgs.keys) { - if (!paramLists.head.contains(k)) { - val as = firstArgs(k) - if (as.reads.arity != 0 && as.default.isEmpty) missing.append(as) - } - } - - if (missing.nonEmpty || unknown.nonEmpty) Result.Error.MismatchedArguments(missing.toSeq, unknown.toSeq) - else { - try invoke0( - target, - ctx, - paramLists, - argSignatures.asInstanceOf[Seq[Seq[ArgSig[Any, _, _, C]]]] - ) - catch{case e: Throwable => Result.Error.Exception(e)} - } - } - } - - def tryEither[T](t: => T, error: Throwable => Result.ParamError) = { - try Right(t) - catch{ case e: Throwable => Left(error(e))} - } - - /** - * Represents what comes out of an attempt to invoke an [[EntryPoint]]. - * Could succeed with a value, but could fail in many different ways. - */ - sealed trait Result[+T]{ - def map[V](f: T => V): Result[V] - } - object Result{ - - /** - * Invoking the [[EntryPoint]] was totally successful, and returned a - * result - */ - case class Success[T](value: T) extends Result[T]{ - def map[V](f: T => V) = Success(f(value)) - } - - /** - * Invoking the [[EntryPoint]] was not successful - */ - sealed trait Error extends Result[Nothing]{ - def map[V](f: Nothing => V) = this - } - - - object Error{ - - - /** - * Invoking the [[EntryPoint]] failed with an exception while executing - * code within it. - */ - case class Exception(t: Throwable) extends Error - - /** - * Invoking the [[EntryPoint]] failed because the arguments provided - * did not line up with the arguments expected - */ - case class MismatchedArguments(missing: Seq[ArgSig[_, _, _, _]], - unknown: Seq[String]) extends Error - /** - * Invoking the [[EntryPoint]] failed because there were problems - * deserializing/parsing individual arguments - */ - case class InvalidArguments(values: Seq[ParamError]) extends Error - } - - sealed trait ParamError - object ParamError{ - /** - * Something went wrong trying to de-serialize the input parameter; - * the thrown exception is stored in [[ex]] - */ - case class Invalid(arg: ArgSig[_, _, _, _], value: String, ex: Throwable) extends ParamError - /** - * Something went wrong trying to evaluate the default value - * for this input parameter - */ - case class DefaultFailed(arg: ArgSig[_, _, _, _], ex: Throwable) extends ParamError - } - } - - def validate(args: Seq[Either[Seq[Result.ParamError], Any]]): Result[Seq[Any]] = { - val lefts = args.collect{case Left(x) => x}.flatten - - if (lefts.nonEmpty) Result.Error.InvalidArguments(lefts) - else { - val rights = args.collect{case Right(x) => x} - Result.Success(rights) - } - } - - def makeReadCall[I, C](dict: Map[String, I], - ctx: C, - default: => Option[Any], - arg: ArgSig[I, _, _, C]) = { - arg.reads.arity match{ - case 0 => - tryEither( - arg.reads.read(ctx, arg.name, null.asInstanceOf[I]), Result.ParamError.DefaultFailed(arg, _)).left.map(Seq(_)) - case 1 => - dict.get(arg.name) match{ - case None => - tryEither(default.get, Result.ParamError.DefaultFailed(arg, _)).left.map(Seq(_)) - - case Some(x) => - tryEither(arg.reads.read(ctx, arg.name, x), Result.ParamError.Invalid(arg, x.toString, _)).left.map(Seq(_)) - } - } - } -} - - -class Router[C <: Context](val c: C) { - import c.universe._ - def getValsOrMeths(curCls: Type): Iterable[MethodSymbol] = { - def isAMemberOfAnyRef(member: Symbol) = { - // AnyRef is an alias symbol, we go to the real "owner" of these methods - val anyRefSym = c.mirror.universe.definitions.ObjectClass - member.owner == anyRefSym - } - val extractableMembers = for { - member <- curCls.members.toList.reverse - if !isAMemberOfAnyRef(member) - if !member.isSynthetic - if member.isPublic - if member.isTerm - memTerm = member.asTerm - if memTerm.isMethod - if !memTerm.isModule - } yield memTerm.asMethod - - extractableMembers flatMap { case memTerm => - if (memTerm.isSetter || memTerm.isConstructor || memTerm.isGetter) Nil - else Seq(memTerm) - - } - } - - - - def unwrapVarargType(arg: Symbol) = { - val vararg = arg.typeSignature.typeSymbol == definitions.RepeatedParamClass - val unwrappedType = - if (!vararg) arg.typeSignature - else arg.typeSignature.asInstanceOf[TypeRef].args(0) - - (vararg, unwrappedType) - } - - def extractMethod(method: MethodSymbol, - curCls: c.universe.Type, - convertToResultType: c.Tree, - ctx: c.Tree, - argReaders: Seq[c.Tree], - annotDeserializeTypes: Seq[c.Tree]): c.universe.Tree = { - val baseArgSym = TermName(c.freshName()) - - def getDocAnnotation(annotations: List[Annotation]) = { - val (docTrees, remaining) = annotations.partition(_.tpe =:= typeOf[Router.doc]) - val docValues = for { - doc <- docTrees - if doc.scalaArgs.head.isInstanceOf[Literal] - l = doc.scalaArgs.head.asInstanceOf[Literal] - if l.value.value.isInstanceOf[String] - } yield l.value.value.asInstanceOf[String] - (remaining, docValues.headOption) - } - val (_, methodDoc) = getDocAnnotation(method.annotations) - val argValuesSymbol = q"${c.fresh[TermName]("argValues")}" - val argSigsSymbol = q"${c.fresh[TermName]("argSigs")}" - val ctxSymbol = q"${c.fresh[TermName]("ctx")}" - val argData = for(argListIndex <- method.paramLists.indices) yield{ - val annotDeserializeType = annotDeserializeTypes.lift(argListIndex).getOrElse(tq"scala.Any") - val argReader = argReaders.lift(argListIndex).getOrElse(q"cask.main.NoOpParser.instanceAny") - val flattenedArgLists = method.paramss(argListIndex) - def hasDefault(i: Int) = { - val defaultName = s"${method.name}$$default$$${i + 1}" - if (curCls.members.exists(_.name.toString == defaultName)) Some(defaultName) - else None - } - - val defaults = for (i <- flattenedArgLists.indices) yield { - val arg = TermName(c.freshName()) - hasDefault(i).map(defaultName => q"($arg: $curCls) => $arg.${newTermName(defaultName)}") - } - - val readArgSigs = for ( - ((arg, defaultOpt), i) <- flattenedArgLists.zip(defaults).zipWithIndex - ) yield { - - if (arg.typeSignature.typeSymbol == definitions.RepeatedParamClass) c.abort(method.pos, "Varargs are not supported in cask routes") - - val default = defaultOpt match { - case Some(defaultExpr) => q"scala.Some($defaultExpr($baseArgSym))" - case None => q"scala.None" - } - - val (docUnwrappedType, docOpt) = arg.typeSignature match { - case t: AnnotatedType => - import compat._ - val (remaining, docValue) = getDocAnnotation(t.annotations) - if (remaining.isEmpty) (t.underlying, docValue) - else (c.universe.AnnotatedType(remaining, t.underlying), docValue) - - case t => (t, None) - } - - val docTree = docOpt match { - case None => q"scala.None" - case Some(s) => q"scala.Some($s)" - } - - val argSig = - q""" - cask.internal.Router.ArgSig[$annotDeserializeType, $curCls, $docUnwrappedType, $ctx]( - ${arg.name.toString}, - ${docUnwrappedType.toString}, - $docTree, - $defaultOpt - )($argReader[$docUnwrappedType]) - """ - - val reader = q""" - cask.internal.Router.makeReadCall( - $argValuesSymbol($argListIndex), - $ctxSymbol, - $default, - $argSigsSymbol($argListIndex)($i) - ) - """ - - c.internal.setPos(reader, method.pos) - (reader, argSig) - } - - val (readArgs, argSigs) = readArgSigs.unzip - val (argNames, argNameCasts) = flattenedArgLists.map { arg => - val (vararg, unwrappedType) = unwrapVarargType(arg) - ( - pq"${arg.name.toTermName}", - if (!vararg) q"${arg.name.toTermName}.asInstanceOf[$unwrappedType]" - else q"${arg.name.toTermName}.asInstanceOf[Seq[$unwrappedType]]: _*" - - ) - }.unzip - - (argNameCasts, argSigs, argNames, readArgs) - } - - val argNameCasts = argData.map(_._1) - val argSigs = argData.map(_._2) - val argNames = argData.map(_._3) - val readArgs = argData.map(_._4) - var methodCall: c.Tree = q"$baseArgSym.${method.name.toTermName}" - for(argNameCast <- argNameCasts) methodCall = q"$methodCall(..$argNameCast)" - - val res = q""" - cask.internal.Router.EntryPoint[$curCls, $ctx]( - ${method.name.toString}, - ${argSigs.toList}, - ${methodDoc match{ - case None => q"scala.None" - case Some(s) => q"scala.Some($s)" - }}, - ( - $baseArgSym: $curCls, - $ctxSymbol: $ctx, - $argValuesSymbol: Seq[Map[String, Any]], - $argSigsSymbol: scala.Seq[scala.Seq[cask.internal.Router.ArgSig[Any, _, _, $ctx]]] - ) => - cask.internal.Router.validate(Seq(..${readArgs.flatten.toList})).map{ - case Seq(..${argNames.flatten.toList}) => $convertToResultType($methodCall) - } - ) - """ - - c.internal.transform(res){(t, a) => - c.internal.setPos(t, method.pos) - a.default(t) - } - - res - } - -} diff --git a/cask/src/cask/main/ErrorMsgs.scala b/cask/src/cask/main/ErrorMsgs.scala index 254f4e0..a22bd89 100644 --- a/cask/src/cask/main/ErrorMsgs.scala +++ b/cask/src/cask/main/ErrorMsgs.scala @@ -1,11 +1,12 @@ package cask.main -import cask.internal.{Router, Util} +import cask.internal.Util import cask.internal.Util.literalize +import cask.router.{ArgSig, EntryPoint, Result} object ErrorMsgs { - def getLeftColWidth(items: Seq[Router.ArgSig[_, _, _,_]]) = { + def getLeftColWidth(items: Seq[ArgSig[_, _, _,_]]) = { items.map(_.name.length + 2) match{ case Nil => 0 case x => x.max @@ -13,7 +14,7 @@ object ErrorMsgs { } def renderArg[T](base: T, - arg: Router.ArgSig[_, T, _, _], + arg: ArgSig[_, T, _, _], leftOffset: Int, wrappedWidth: Int): (String, String) = { val suffix = arg.default match{ @@ -33,7 +34,7 @@ object ErrorMsgs { } def formatMainMethodSignature[T](base: T, - main: Router.EntryPoint[T, _], + main: EntryPoint[T, _], leftIndent: Int, leftColWidth: Int) = { // +2 for space on right of left col @@ -56,12 +57,12 @@ object ErrorMsgs { |${argStrings.map(_ + "\n").mkString}""".stripMargin } - def formatInvokeError[T](base: T, route: Router.EntryPoint[T, _], x: Router.Result.Error): String = { + def formatInvokeError[T](base: T, route: EntryPoint[T, _], x: Result.Error): String = { def expectedMsg = formatMainMethodSignature(base: T, route, 0, 0) x match{ - case Router.Result.Error.Exception(x) => Util.stackTraceString(x) - case Router.Result.Error.MismatchedArguments(missing, unknown) => + case Result.Error.Exception(x) => Util.stackTraceString(x) + case Result.Error.MismatchedArguments(missing, unknown) => val missingStr = if (missing.isEmpty) "" else { @@ -88,14 +89,14 @@ object ErrorMsgs { |$expectedMsg |""".stripMargin - case Router.Result.Error.InvalidArguments(x) => + case Result.Error.InvalidArguments(x) => val argumentsStr = Util.pluralize("argument", x.length) val thingies = x.map{ - case Router.Result.ParamError.Invalid(p, v, ex) => + case Result.ParamError.Invalid(p, v, ex) => val literalV = literalize(v) val trace = Util.stackTraceString(ex) s"${p.name}: ${p.typeString} = $literalV failed to parse with $ex\n$trace" - case Router.Result.ParamError.DefaultFailed(p, ex) => + case Result.ParamError.DefaultFailed(p, ex) => val trace = Util.stackTraceString(ex) s"${p.name}'s default value failed to evaluate with $ex\n$trace" } diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index fddd9b7..6d08c04 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -2,9 +2,9 @@ package cask.main import cask.endpoints.{WebsocketResult, WsHandler} import cask.model._ -import cask.internal.Router.EntryPoint -import cask.internal.{DispatchTrie, Router, Util} +import cask.internal.{DispatchTrie, Util} import cask.main +import cask.router.{Decorator, EndpointMetadata, EntryPoint, RawDecorator, Result} import cask.util.Logger import io.undertow.Undertow import io.undertow.server.{HttpHandler, HttpServerExchange} @@ -27,7 +27,7 @@ class MainRoutes extends Main with Routes{ * application-wide properties. */ abstract class Main{ - def mainDecorators: Seq[cask.main.RawDecorator] = Nil + def mainDecorators: Seq[RawDecorator] = Nil def allRoutes: Seq[Routes] def port: Int = 8080 def host: String = "localhost" @@ -45,7 +45,7 @@ abstract class Main{ def handleEndpointError(routes: Routes, metadata: EndpointMetadata[_], - e: Router.Result.Error) = { + e: cask.router.Result.Error) = { Main.defaultHandleError(routes, metadata, e, debugMode) } @@ -64,7 +64,7 @@ object Main{ mainDecorators: Seq[RawDecorator], debugMode: Boolean, handleNotFound: () => Response.Raw, - handleError: (Routes, EndpointMetadata[_], Router.Result.Error) => Response.Raw) + handleError: (Routes, EndpointMetadata[_], Result.Error) => Response.Raw) (implicit log: Logger) extends HttpHandler() { def handleRequest(exchange: HttpServerExchange): Unit = try { // println("Handling Request: " + exchange.getRequestPath) @@ -98,8 +98,8 @@ object Main{ (mainDecorators ++ routes.decorators ++ metadata.decorators).toList, Nil ) match{ - case Router.Result.Success(res) => runner(res) - case e: Router.Result.Error => + case Result.Success(res) => runner(res) + case e: Result.Error => Main.writeResponse( exchange, handleError(routes, metadata, e) @@ -145,17 +145,17 @@ object Main{ def defaultHandleError(routes: Routes, metadata: EndpointMetadata[_], - e: Router.Result.Error, + e: Result.Error, debugMode: Boolean) (implicit log: Logger) = { e match { - case e: Router.Result.Error.Exception => log.exception(e.t) + case e: Result.Error.Exception => log.exception(e.t) case _ => // do nothing } val statusCode = e match { - case _: Router.Result.Error.Exception => 500 - case _: Router.Result.Error.InvalidArguments => 400 - case _: Router.Result.Error.MismatchedArguments => 400 + case _: Result.Error.Exception => 500 + case _: Result.Error.InvalidArguments => 400 + case _: Result.Error.MismatchedArguments => 400 } val str = diff --git a/cask/src/cask/main/Routes.scala b/cask/src/cask/main/Routes.scala index 9be9f50..512860a 100644 --- a/cask/src/cask/main/Routes.scala +++ b/cask/src/cask/main/Routes.scala @@ -1,10 +1,12 @@ package cask.main +import cask.router.RoutesEndpointsMetadata + import language.experimental.macros trait Routes{ - def decorators = Seq.empty[cask.main.RawDecorator] + def decorators = Seq.empty[cask.router.RawDecorator] private[this] var metadata0: RoutesEndpointsMetadata[this.type] = null def caskMetadata = if (metadata0 != null) metadata0 diff --git a/cask/src/cask/package.scala b/cask/src/cask/package.scala index d9e29ba..d34fe26 100644 --- a/cask/src/cask/package.scala +++ b/cask/src/cask/package.scala @@ -39,8 +39,8 @@ package object cask { type Routes = main.Routes type Main = main.Main - type RawDecorator = main.RawDecorator - type HttpEndpoint[InnerReturned, Input] = main.HttpEndpoint[InnerReturned, Input] + type RawDecorator = router.RawDecorator + type HttpEndpoint[InnerReturned, Input] = router.HttpEndpoint[InnerReturned, Input] type WsHandler = cask.endpoints.WsHandler val WsHandler = cask.endpoints.WsHandler diff --git a/cask/src/cask/main/Decorators.scala b/cask/src/cask/router/Decorators.scala index fb795ba..cd2c1a3 100644 --- a/cask/src/cask/main/Decorators.scala +++ b/cask/src/cask/router/Decorators.scala @@ -1,8 +1,8 @@ -package cask.main +package cask.router -import cask.internal.{Conversion, Router} -import cask.internal.Router.{ArgReader, EntryPoint} +import cask.internal.Conversion import cask.model.{Request, Response} +import cask.router.{ArgReader, EntryPoint, Result} /** * A [[Decorator]] allows you to annotate a function to wrap it, via @@ -18,8 +18,8 @@ import cask.model.{Request, Response} trait Decorator[InnerReturned, Input]{ final type InputTypeAlias = Input type InputParser[T] <: ArgReader[Input, T, Request] - final type Delegate = Map[String, Input] => Router.Result[InnerReturned] - type OuterReturned <: Router.Result[Any] + final type Delegate = Map[String, Input] => Result[InnerReturned] + type OuterReturned <: Result[Any] def wrapFunction(ctx: Request, delegate: Delegate): OuterReturned def getParamParser[T](implicit p: InputParser[T]) = p } @@ -41,13 +41,13 @@ object Decorator{ routes: T, routeBindings: Map[String, String], remainingDecorators: List[RawDecorator], - bindings: List[Map[String, Any]]): Router.Result[Any] = try { + bindings: List[Map[String, Any]]): Result[Any] = try { remainingDecorators match { case head :: rest => head.wrapFunction( ctx, args => invoke(ctx, endpoint, entryPoint, routes, routeBindings, rest, args :: bindings) - .asInstanceOf[Router.Result[cask.model.Response.Raw]] + .asInstanceOf[Result[cask.model.Response.Raw]] ) case Nil => @@ -58,13 +58,13 @@ object Decorator{ entryPoint .asInstanceOf[EntryPoint[T, cask.model.Request]] .invoke(routes, ctx, finalBindings) - .asInstanceOf[Router.Result[Nothing]] + .asInstanceOf[Result[Nothing]] }) } // Make sure we wrap any exceptions that bubble up from decorator // bodies, so outer decorators do not need to worry about their // delegate throwing on them - }catch{case e: Throwable => Router.Result.Error.Exception(e) } + }catch{case e: Throwable => Result.Error.Exception(e) } } /** @@ -72,7 +72,7 @@ object Decorator{ * response stream, before and after the primary [[Endpoint]] does it's job. */ trait RawDecorator extends Decorator[Response.Raw, Any]{ - type OuterReturned = Router.Result[Response.Raw] + type OuterReturned = Result[Response.Raw] type InputParser[T] = NoOpParser[Any, T] } @@ -121,7 +121,7 @@ trait Endpoint[InnerReturned, Input] extends Decorator[InnerReturned, Input]{ * [[RawDecorator]] but with additional metadata and capabilities. */ trait HttpEndpoint[InnerReturned, Input] extends Endpoint[InnerReturned, Input] { - type OuterReturned = Router.Result[Response.Raw] + type OuterReturned = Result[Response.Raw] } diff --git a/cask/src/cask/router/EntryPoint.scala b/cask/src/cask/router/EntryPoint.scala new file mode 100644 index 0000000..6fe44fc --- /dev/null +++ b/cask/src/cask/router/EntryPoint.scala @@ -0,0 +1,51 @@ +package cask.router + + +import scala.collection.mutable + + +/** + * What is known about a single endpoint for our routes. It has a [[name]], + * [[argSignatures]] for each argument, and a macro-generated [[invoke0]] + * that performs all the necessary argument parsing and de-serialization. + * + * Realistically, you will probably spend most of your time calling [[invoke]] + * instead, which provides a nicer API to call it that mimmicks the API of + * calling a Scala method. + */ +case class EntryPoint[T, C](name: String, + argSignatures: Seq[Seq[ArgSig[_, T, _, C]]], + doc: Option[String], + invoke0: (T, C, Seq[Map[String, Any]], Seq[Seq[ArgSig[Any, _, _, C]]]) => Result[Any]){ + + val firstArgs = argSignatures.head + .map(x => x.name -> x) + .toMap[String, ArgSig[_, T, _, C]] + + def invoke(target: T, + ctx: C, + paramLists: Seq[Map[String, Any]]): Result[Any] = { + + val missing = mutable.Buffer.empty[ArgSig[_, T, _, C]] + + val unknown = paramLists.head.keys.filter(!firstArgs.contains(_)) + + for(k <- firstArgs.keys) { + if (!paramLists.head.contains(k)) { + val as = firstArgs(k) + if (as.reads.arity != 0 && as.default.isEmpty) missing.append(as) + } + } + + if (missing.nonEmpty || unknown.nonEmpty) Result.Error.MismatchedArguments(missing.toSeq, unknown.toSeq) + else { + try invoke0( + target, + ctx, + paramLists, + argSignatures.asInstanceOf[Seq[Seq[ArgSig[Any, _, _, C]]]] + ) + catch{case e: Throwable => Result.Error.Exception(e)} + } + } +} diff --git a/cask/src/cask/router/Macros.scala b/cask/src/cask/router/Macros.scala new file mode 100644 index 0000000..de27e5c --- /dev/null +++ b/cask/src/cask/router/Macros.scala @@ -0,0 +1,178 @@ +package cask.router + +import scala.reflect.macros.blackbox + + +class Macros[C <: blackbox.Context](val c: C) { + import c.universe._ + def getValsOrMeths(curCls: Type): Iterable[MethodSymbol] = { + def isAMemberOfAnyRef(member: Symbol) = { + // AnyRef is an alias symbol, we go to the real "owner" of these methods + val anyRefSym = c.mirror.universe.definitions.ObjectClass + member.owner == anyRefSym + } + val extractableMembers = for { + member <- curCls.members.toList.reverse + if !isAMemberOfAnyRef(member) + if !member.isSynthetic + if member.isPublic + if member.isTerm + memTerm = member.asTerm + if memTerm.isMethod + if !memTerm.isModule + } yield memTerm.asMethod + + extractableMembers flatMap { case memTerm => + if (memTerm.isSetter || memTerm.isConstructor || memTerm.isGetter) Nil + else Seq(memTerm) + + } + } + + + + def unwrapVarargType(arg: Symbol) = { + val vararg = arg.typeSignature.typeSymbol == definitions.RepeatedParamClass + val unwrappedType = + if (!vararg) arg.typeSignature + else arg.typeSignature.asInstanceOf[TypeRef].args(0) + + (vararg, unwrappedType) + } + + def extractMethod(method: MethodSymbol, + curCls: c.universe.Type, + convertToResultType: c.Tree, + ctx: c.Tree, + argReaders: Seq[c.Tree], + annotDeserializeTypes: Seq[c.Tree]): c.universe.Tree = { + val baseArgSym = TermName(c.freshName()) + + def getDocAnnotation(annotations: List[Annotation]) = { + val (docTrees, remaining) = annotations.partition(_.tpe =:= typeOf[doc]) + val docValues = for { + doc <- docTrees + if doc.scalaArgs.head.isInstanceOf[Literal] + l = doc.scalaArgs.head.asInstanceOf[Literal] + if l.value.value.isInstanceOf[String] + } yield l.value.value.asInstanceOf[String] + (remaining, docValues.headOption) + } + val (_, methodDoc) = getDocAnnotation(method.annotations) + val argValuesSymbol = q"${c.fresh[TermName]("argValues")}" + val argSigsSymbol = q"${c.fresh[TermName]("argSigs")}" + val ctxSymbol = q"${c.fresh[TermName]("ctx")}" + val argData = for(argListIndex <- method.paramLists.indices) yield{ + val annotDeserializeType = annotDeserializeTypes.lift(argListIndex).getOrElse(tq"scala.Any") + val argReader = argReaders.lift(argListIndex).getOrElse(q"cask.router.NoOpParser.instanceAny") + val flattenedArgLists = method.paramss(argListIndex) + def hasDefault(i: Int) = { + val defaultName = s"${method.name}$$default$$${i + 1}" + if (curCls.members.exists(_.name.toString == defaultName)) Some(defaultName) + else None + } + + val defaults = for (i <- flattenedArgLists.indices) yield { + val arg = TermName(c.freshName()) + hasDefault(i).map(defaultName => q"($arg: $curCls) => $arg.${newTermName(defaultName)}") + } + + val readArgSigs = for ( + ((arg, defaultOpt), i) <- flattenedArgLists.zip(defaults).zipWithIndex + ) yield { + + if (arg.typeSignature.typeSymbol == definitions.RepeatedParamClass) c.abort(method.pos, "Varargs are not supported in cask routes") + + val default = defaultOpt match { + case Some(defaultExpr) => q"scala.Some($defaultExpr($baseArgSym))" + case None => q"scala.None" + } + + val (docUnwrappedType, docOpt) = arg.typeSignature match { + case t: AnnotatedType => + import compat._ + val (remaining, docValue) = getDocAnnotation(t.annotations) + if (remaining.isEmpty) (t.underlying, docValue) + else (c.universe.AnnotatedType(remaining, t.underlying), docValue) + + case t => (t, None) + } + + val docTree = docOpt match { + case None => q"scala.None" + case Some(s) => q"scala.Some($s)" + } + + val argSig = + q""" + cask.router.ArgSig[$annotDeserializeType, $curCls, $docUnwrappedType, $ctx]( + ${arg.name.toString}, + ${docUnwrappedType.toString}, + $docTree, + $defaultOpt + )($argReader[$docUnwrappedType]) + """ + + val reader = q""" + cask.router.Runtime.makeReadCall( + $argValuesSymbol($argListIndex), + $ctxSymbol, + $default, + $argSigsSymbol($argListIndex)($i) + ) + """ + + c.internal.setPos(reader, method.pos) + (reader, argSig) + } + + val (readArgs, argSigs) = readArgSigs.unzip + val (argNames, argNameCasts) = flattenedArgLists.map { arg => + val (vararg, unwrappedType) = unwrapVarargType(arg) + ( + pq"${arg.name.toTermName}", + if (!vararg) q"${arg.name.toTermName}.asInstanceOf[$unwrappedType]" + else q"${arg.name.toTermName}.asInstanceOf[Seq[$unwrappedType]]: _*" + + ) + }.unzip + + (argNameCasts, argSigs, argNames, readArgs) + } + + val argNameCasts = argData.map(_._1) + val argSigs = argData.map(_._2) + val argNames = argData.map(_._3) + val readArgs = argData.map(_._4) + var methodCall: c.Tree = q"$baseArgSym.${method.name.toTermName}" + for(argNameCast <- argNameCasts) methodCall = q"$methodCall(..$argNameCast)" + + val res = q""" + cask.router.EntryPoint[$curCls, $ctx]( + ${method.name.toString}, + ${argSigs.toList}, + ${methodDoc match{ + case None => q"scala.None" + case Some(s) => q"scala.Some($s)" + }}, + ( + $baseArgSym: $curCls, + $ctxSymbol: $ctx, + $argValuesSymbol: Seq[Map[String, Any]], + $argSigsSymbol: scala.Seq[scala.Seq[cask.router.ArgSig[Any, _, _, $ctx]]] + ) => + cask.router.Runtime.validate(Seq(..${readArgs.flatten.toList})).map{ + case Seq(..${argNames.flatten.toList}) => $convertToResultType($methodCall) + } + ) + """ + + c.internal.transform(res){(t, a) => + c.internal.setPos(t, method.pos) + a.default(t) + } + + res + } + +} diff --git a/cask/src/cask/router/Misc.scala b/cask/src/cask/router/Misc.scala new file mode 100644 index 0000000..438ec43 --- /dev/null +++ b/cask/src/cask/router/Misc.scala @@ -0,0 +1,23 @@ +package cask.router + +import scala.annotation.StaticAnnotation + + +class doc(s: String) extends StaticAnnotation + +/** + * Models what is known by the router about a single argument: that it has + * a [[name]], a human-readable [[typeString]] describing what the type is + * (just for logging and reading, not a replacement for a `TypeTag`) and + * possible a function that can compute its default value + */ +case class ArgSig[I, -T, +V, -C](name: String, + typeString: String, + doc: Option[String], + default: Option[T => V]) + (implicit val reads: ArgReader[I, V, C]) + +trait ArgReader[I, +T, -C]{ + def arity: Int + def read(ctx: C, label: String, input: I): T +} diff --git a/cask/src/cask/router/Result.scala b/cask/src/cask/router/Result.scala new file mode 100644 index 0000000..52ef0f8 --- /dev/null +++ b/cask/src/cask/router/Result.scala @@ -0,0 +1,66 @@ +package cask.router + + + + +/** + * Represents what comes out of an attempt to invoke an [[EntryPoint]]. + * Could succeed with a value, but could fail in many different ways. + */ +sealed trait Result[+T]{ + def map[V](f: T => V): Result[V] +} +object Result{ + + /** + * Invoking the [[EntryPoint]] was totally successful, and returned a + * result + */ + case class Success[T](value: T) extends Result[T]{ + def map[V](f: T => V) = Success(f(value)) + } + + /** + * Invoking the [[EntryPoint]] was not successful + */ + sealed trait Error extends Result[Nothing]{ + def map[V](f: Nothing => V) = this + } + + + object Error{ + + + /** + * Invoking the [[EntryPoint]] failed with an exception while executing + * code within it. + */ + case class Exception(t: Throwable) extends Error + + /** + * Invoking the [[EntryPoint]] failed because the arguments provided + * did not line up with the arguments expected + */ + case class MismatchedArguments(missing: Seq[ArgSig[_, _, _, _]], + unknown: Seq[String]) extends Error + /** + * Invoking the [[EntryPoint]] failed because there were problems + * deserializing/parsing individual arguments + */ + case class InvalidArguments(values: Seq[ParamError]) extends Error + } + + sealed trait ParamError + object ParamError{ + /** + * Something went wrong trying to de-serialize the input parameter; + * the thrown exception is stored in [[ex]] + */ + case class Invalid(arg: ArgSig[_, _, _, _], value: String, ex: Throwable) extends ParamError + /** + * Something went wrong trying to evaluate the default value + * for this input parameter + */ + case class DefaultFailed(arg: ArgSig[_, _, _, _], ex: Throwable) extends ParamError + } +}
\ No newline at end of file diff --git a/cask/src/cask/main/RoutesEndpointMetadata.scala b/cask/src/cask/router/RoutesEndpointMetadata.scala index fa93a0c..7940641 100644 --- a/cask/src/cask/main/RoutesEndpointMetadata.scala +++ b/cask/src/cask/router/RoutesEndpointMetadata.scala @@ -1,6 +1,6 @@ -package cask.main +package cask.router -import cask.internal.Router.EntryPoint +import cask.router.EntryPoint import language.experimental.macros import scala.reflect.macros.blackbox @@ -12,7 +12,7 @@ object RoutesEndpointsMetadata{ implicit def initialize[T]: RoutesEndpointsMetadata[T] = macro initializeImpl[T] implicit def initializeImpl[T: c.WeakTypeTag](c: blackbox.Context): c.Expr[RoutesEndpointsMetadata[T]] = { import c.universe._ - val router = new cask.internal.Router[c.type](c) + val router = new cask.router.Macros[c.type](c) val routeParts = for{ m <- c.weakTypeOf[T].members @@ -54,7 +54,7 @@ object RoutesEndpointsMetadata{ val res = q"""{ ..$declarations - cask.main.EndpointMetadata( + cask.router.EndpointMetadata( Seq(..${annotObjectSyms.dropRight(1)}), ${annotObjectSyms.last}, $route @@ -63,6 +63,6 @@ object RoutesEndpointsMetadata{ res } - c.Expr[RoutesEndpointsMetadata[T]](q"""cask.main.RoutesEndpointsMetadata(..$routeParts)""") + c.Expr[RoutesEndpointsMetadata[T]](q"""cask.router.RoutesEndpointsMetadata(..$routeParts)""") } }
\ No newline at end of file diff --git a/cask/src/cask/router/Runtime.scala b/cask/src/cask/router/Runtime.scala new file mode 100644 index 0000000..4fbbb48 --- /dev/null +++ b/cask/src/cask/router/Runtime.scala @@ -0,0 +1,39 @@ +package cask.router + +object Runtime{ + + def tryEither[T](t: => T, error: Throwable => Result.ParamError) = { + try Right(t) + catch{ case e: Throwable => Left(error(e))} + } + + + def validate(args: Seq[Either[Seq[Result.ParamError], Any]]): Result[Seq[Any]] = { + val lefts = args.collect{case Left(x) => x}.flatten + + if (lefts.nonEmpty) Result.Error.InvalidArguments(lefts) + else { + val rights = args.collect{case Right(x) => x} + Result.Success(rights) + } + } + + def makeReadCall[I, C](dict: Map[String, I], + ctx: C, + default: => Option[Any], + arg: ArgSig[I, _, _, C]) = { + arg.reads.arity match{ + case 0 => + tryEither( + arg.reads.read(ctx, arg.name, null.asInstanceOf[I]), Result.ParamError.DefaultFailed(arg, _)).left.map(Seq(_)) + case 1 => + dict.get(arg.name) match{ + case None => + tryEither(default.get, Result.ParamError.DefaultFailed(arg, _)).left.map(Seq(_)) + + case Some(x) => + tryEither(arg.reads.read(ctx, arg.name, x), Result.ParamError.Invalid(arg, x.toString, _)).left.map(Seq(_)) + } + } + } +}
\ No newline at end of file |