diff options
author | Li Haoyi <haoyi.sg@gmail.com> | 2018-07-26 12:36:07 +0800 |
---|---|---|
committer | Li Haoyi <haoyi.sg@gmail.com> | 2018-07-26 12:36:07 +0800 |
commit | d8ae2662b44cc70d296832e3dc54685de4008a2d (patch) | |
tree | 61235be47c85a16707a53861bee97ee3651bb487 /cask | |
parent | dfaece8336adc803b9088621b45bdba1deb3213f (diff) | |
download | cask-d8ae2662b44cc70d296832e3dc54685de4008a2d.tar.gz cask-d8ae2662b44cc70d296832e3dc54685de4008a2d.tar.bz2 cask-d8ae2662b44cc70d296832e3dc54685de4008a2d.zip |
Allow Decorators to short-circuit request processing by bailing out early
Diffstat (limited to 'cask')
-rw-r--r-- | cask/src/cask/endpoints/FormEndpoint.scala | 25 | ||||
-rw-r--r-- | cask/src/cask/endpoints/JsonEndpoint.scala | 21 | ||||
-rw-r--r-- | cask/src/cask/endpoints/StaticEndpoints.scala | 2 | ||||
-rw-r--r-- | cask/src/cask/endpoints/WebEndpoints.scala | 10 | ||||
-rw-r--r-- | cask/src/cask/internal/Util.scala | 23 | ||||
-rw-r--r-- | cask/src/cask/main/ErrorMsgs.scala | 6 | ||||
-rw-r--r-- | cask/src/cask/main/Main.scala | 24 | ||||
-rw-r--r-- | cask/src/cask/main/Routes.scala | 4 | ||||
-rw-r--r-- | cask/test/src/test/cask/Decorated.scala | 2 | ||||
-rw-r--r-- | cask/test/src/test/cask/FailureTests.scala | 2 |
10 files changed, 85 insertions, 34 deletions
diff --git a/cask/src/cask/endpoints/FormEndpoint.scala b/cask/src/cask/endpoints/FormEndpoint.scala index 2d9fc1f..2a7fbd2 100644 --- a/cask/src/cask/endpoints/FormEndpoint.scala +++ b/cask/src/cask/endpoints/FormEndpoint.scala @@ -1,6 +1,6 @@ package cask.endpoints -import cask.internal.Router +import cask.internal.{Router, Util} import cask.main.Routes import cask.model.{FormValue, ParamContext, Response} import io.undertow.server.handlers.form.FormParserFactory @@ -40,14 +40,21 @@ class postForm(val path: String, override val subpath: Boolean = false) extends type Input = Seq[FormValue] type InputParser[T] = FormReader[T] def getRawParams(ctx: ParamContext) = { - val formData = FormParserFactory.builder().build().createParser(ctx.exchange).parseBlocking() - val formDataBindings = - formData - .iterator() - .asScala - .map(k => (k, formData.get(k).asScala.map(FormValue.fromUndertow).toSeq)) - .toMap - formDataBindings + for{ + formData <- + try Right(FormParserFactory.builder().build().createParser(ctx.exchange).parseBlocking()) + catch{case e: Exception => Left(cask.model.Response( + "Unable to parse form data: " + e + "\n" + Util.stackTraceString(e) + ))} + } yield { + val formDataBindings = + formData + .iterator() + .asScala + .map(k => (k, formData.get(k).asScala.map(FormValue.fromUndertow).toSeq)) + .toMap + formDataBindings + } } def wrapPathSegment(s: String): Input = Seq(FormValue.Plain(s, new io.undertow.util.HeaderMap)) } diff --git a/cask/src/cask/endpoints/JsonEndpoint.scala b/cask/src/cask/endpoints/JsonEndpoint.scala index 80fac9a..d9d39c3 100644 --- a/cask/src/cask/endpoints/JsonEndpoint.scala +++ b/cask/src/cask/endpoints/JsonEndpoint.scala @@ -1,6 +1,6 @@ package cask.endpoints -import cask.internal.Router +import cask.internal.{Router, Util} import cask.internal.Router.EntryPoint import cask.main.Routes import cask.model.{ParamContext, Response} @@ -28,7 +28,22 @@ class postJson(val path: String, override val subpath: Boolean = false) extends val methods = Seq("post") type Input = ujson.Js.Value type InputParser[T] = JsReader[T] - def getRawParams(ctx: ParamContext) = - ujson.read(new String(ctx.exchange.getInputStream.readAllBytes())).obj.toMap + def getRawParams(ctx: ParamContext) = { + for{ + str <- + try Right(new String(ctx.exchange.getInputStream.readAllBytes())) + catch{case e: Throwable => Left(cask.model.Response( + "Unable to deserialize input JSON text: " + e + "\n" + Util.stackTraceString(e) + ))} + json <- + try Right(ujson.read(str)) + catch{case e: Throwable => Left(cask.model.Response( + "Input text is invalid JSON: " + e + "\n" + Util.stackTraceString(e) + ))} + obj <- + try Right(json.obj) + catch {case e: Throwable => Left(cask.model.Response("Input JSON must be a dictionary"))} + } yield obj.toMap + } def wrapPathSegment(s: String): Input = ujson.Js.Str(s) } diff --git a/cask/src/cask/endpoints/StaticEndpoints.scala b/cask/src/cask/endpoints/StaticEndpoints.scala index 5d57144..31528aa 100644 --- a/cask/src/cask/endpoints/StaticEndpoints.scala +++ b/cask/src/cask/endpoints/StaticEndpoints.scala @@ -14,6 +14,6 @@ class static(val path: String) extends Routes.Endpoint[String] { Router.Result.Success(cask.model.Static(t + "/" + ctx.remaining.mkString("/"))) } - def getRawParams(ctx: ParamContext) = Map() + def getRawParams(ctx: ParamContext) = Right(Map()) def wrapPathSegment(s: String): Input = Seq(s) } diff --git a/cask/src/cask/endpoints/WebEndpoints.scala b/cask/src/cask/endpoints/WebEndpoints.scala index 82009df..cd707b8 100644 --- a/cask/src/cask/endpoints/WebEndpoints.scala +++ b/cask/src/cask/endpoints/WebEndpoints.scala @@ -11,10 +11,12 @@ import collection.JavaConverters._ trait WebEndpoint extends Routes.Endpoint[BaseResponse]{ type Input = Seq[String] type InputParser[T] = QueryParamReader[T] - def getRawParams(ctx: ParamContext) = ctx.exchange.getQueryParameters - .asScala - .map{case (k, vs) => (k, vs.asScala.toArray.toSeq)} - .toMap + def getRawParams(ctx: ParamContext) = Right( + ctx.exchange.getQueryParameters + .asScala + .map{case (k, vs) => (k, vs.asScala.toArray.toSeq)} + .toMap + ) def wrapPathSegment(s: String) = Seq(s) } class get(val path: String, override val subpath: Boolean = false) extends WebEndpoint{ diff --git a/cask/src/cask/internal/Util.scala b/cask/src/cask/internal/Util.scala index 59bc2ce..98c30c4 100644 --- a/cask/src/cask/internal/Util.scala +++ b/cask/src/cask/internal/Util.scala @@ -1,12 +1,25 @@ package cask.internal +import java.io.{PrintWriter, StringWriter} + +import scala.collection.generic.CanBuildFrom +import scala.collection.mutable + object Util { def pluralize(s: String, n: Int) = { if (n == 1) s else s + "s" } - def splitPath(p: String) = + def splitPath(p: String) = { p.dropWhile(_ == '/').reverse.dropWhile(_ == '/').reverse.split('/').filter(_.nonEmpty) + } + def stackTraceString(e: Throwable) = { + val trace = new StringWriter() + val pw = new PrintWriter(trace) + e.printStackTrace(pw) + pw.flush() + trace.toString + } def softWrap(s: String, leftOffset: Int, maxWidth: Int) = { val oneLine = s.lines.mkString(" ").split(' ') @@ -28,4 +41,12 @@ object Util { } output.mkString } + def sequenceEither[A, B, M[X] <: TraversableOnce[X]](in: M[Either[A, B]])( + implicit cbf: CanBuildFrom[M[Either[A, B]], B, M[B]]): Either[A, M[B]] = { + in.foldLeft[Either[A, mutable.Builder[B, M[B]]]](Right(cbf(in))) { + case (acc, el) => + for (a <- acc; e <- el) yield a += e + } + .map(_.result()) + } } diff --git a/cask/src/cask/main/ErrorMsgs.scala b/cask/src/cask/main/ErrorMsgs.scala index c5ce978..e54ea88 100644 --- a/cask/src/cask/main/ErrorMsgs.scala +++ b/cask/src/cask/main/ErrorMsgs.scala @@ -97,12 +97,10 @@ object ErrorMsgs { val thingies = x.map{ case Router.Result.ParamError.Invalid(p, v, ex) => val literalV = literalize(v) - val trace = new StringWriter() - ex.printStackTrace(new PrintWriter(trace)) + val trace = Util.stackTraceString(ex) s"${p.name}: ${p.typeString} = $literalV failed to parse with $ex\n$trace" case Router.Result.ParamError.DefaultFailed(p, ex) => - val trace = new StringWriter() - ex.printStackTrace(new PrintWriter(trace)) + val trace = Util.stackTraceString(ex) s"${p.name}'s default value failed to evaluate with $ex\n$trace" } diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index f28ae12..9ac0022 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -56,14 +56,22 @@ abstract class BaseMain{ routeTries(exchange.getRequestMethod.toString.toLowerCase()).lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match{ case None => writeResponse(exchange, handleError(404)) case Some(((routes, metadata), bindings, remaining)) => - val providers = - Seq(metadata.endpoint.getRawParams(ParamContext(exchange, remaining)) ++ - bindings.mapValues(metadata.endpoint.wrapPathSegment)) ++ - metadata.decorators.map(e => e.getRawParams(ParamContext(exchange, remaining))) - - val result = metadata.entryPoint - .asInstanceOf[EntryPoint[cask.main.Routes, cask.model.ParamContext]] - .invoke(routes, ParamContext(exchange, remaining), providers) + val params = for{ + endpointParams <- metadata.endpoint.getRawParams(ParamContext(exchange, remaining)) + decoratorParams <- Util.sequenceEither( + metadata.decorators.map(e => e.getRawParams(ParamContext(exchange, remaining))) + ) + } yield (endpointParams ++ bindings.mapValues(metadata.endpoint.wrapPathSegment)) +: decoratorParams + + val result = params match{ + case Left(resp) => resp + case Right(paramValues) => + metadata.entryPoint + .asInstanceOf[EntryPoint[cask.main.Routes, cask.model.ParamContext]] + .invoke(routes, ParamContext(exchange, remaining), paramValues) + } + + result match{ case Router.Result.Success(response: BaseResponse) => writeResponse(exchange, response) case e: Router.Result.Error => diff --git a/cask/src/cask/main/Routes.scala b/cask/src/cask/main/Routes.scala index 76a84d0..957a293 100644 --- a/cask/src/cask/main/Routes.scala +++ b/cask/src/cask/main/Routes.scala @@ -16,7 +16,7 @@ object Routes{ def wrapMethodOutput(ctx: ParamContext,t: R): cask.internal.Router.Result[Any] = { cask.internal.Router.Result.Success(t) } - def getRawParams(ctx: ParamContext): Map[String, Input] + def wrapPathSegment(s: String): Input } @@ -37,7 +37,7 @@ object Routes{ trait BaseDecorator{ type Input type InputParser[T] <: ArgReader[Input, T, ParamContext] - def getRawParams(ctx: ParamContext): Map[String, Input] + def getRawParams(ctx: ParamContext): Either[cask.model.Response, Map[String, Input]] def getParamParser[T](implicit p: InputParser[T]) = p } diff --git a/cask/test/src/test/cask/Decorated.scala b/cask/test/src/test/cask/Decorated.scala index c6e048a..e20149a 100644 --- a/cask/test/src/test/cask/Decorated.scala +++ b/cask/test/src/test/cask/Decorated.scala @@ -4,7 +4,7 @@ import cask.model.ParamContext object Decorated extends cask.MainRoutes{ class myDecorator extends cask.Routes.Decorator { - def getRawParams(ctx: ParamContext) = Map("extra" -> 31337) + def getRawParams(ctx: ParamContext) = Right(Map("extra" -> 31337)) } @myDecorator() diff --git a/cask/test/src/test/cask/FailureTests.scala b/cask/test/src/test/cask/FailureTests.scala index c2db0f4..03f3777 100644 --- a/cask/test/src/test/cask/FailureTests.scala +++ b/cask/test/src/test/cask/FailureTests.scala @@ -5,7 +5,7 @@ import utest._ object FailureTests extends TestSuite { class myDecorator extends cask.Routes.Decorator { - def getRawParams(ctx: ParamContext) = Map("extra" -> 31337) + def getRawParams(ctx: ParamContext) = Right(Map("extra" -> 31337)) } val tests = Tests{ |