summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLi Haoyi <haoyi.sg@gmail.com>2018-07-26 12:36:07 +0800
committerLi Haoyi <haoyi.sg@gmail.com>2018-07-26 12:36:07 +0800
commitd8ae2662b44cc70d296832e3dc54685de4008a2d (patch)
tree61235be47c85a16707a53861bee97ee3651bb487
parentdfaece8336adc803b9088621b45bdba1deb3213f (diff)
downloadcask-d8ae2662b44cc70d296832e3dc54685de4008a2d.tar.gz
cask-d8ae2662b44cc70d296832e3dc54685de4008a2d.tar.bz2
cask-d8ae2662b44cc70d296832e3dc54685de4008a2d.zip
Allow Decorators to short-circuit request processing by bailing out early
-rw-r--r--cask/src/cask/endpoints/FormEndpoint.scala25
-rw-r--r--cask/src/cask/endpoints/JsonEndpoint.scala21
-rw-r--r--cask/src/cask/endpoints/StaticEndpoints.scala2
-rw-r--r--cask/src/cask/endpoints/WebEndpoints.scala10
-rw-r--r--cask/src/cask/internal/Util.scala23
-rw-r--r--cask/src/cask/main/ErrorMsgs.scala6
-rw-r--r--cask/src/cask/main/Main.scala24
-rw-r--r--cask/src/cask/main/Routes.scala4
-rw-r--r--cask/test/src/test/cask/Decorated.scala2
-rw-r--r--cask/test/src/test/cask/FailureTests.scala2
10 files changed, 85 insertions, 34 deletions
diff --git a/cask/src/cask/endpoints/FormEndpoint.scala b/cask/src/cask/endpoints/FormEndpoint.scala
index 2d9fc1f..2a7fbd2 100644
--- a/cask/src/cask/endpoints/FormEndpoint.scala
+++ b/cask/src/cask/endpoints/FormEndpoint.scala
@@ -1,6 +1,6 @@
package cask.endpoints
-import cask.internal.Router
+import cask.internal.{Router, Util}
import cask.main.Routes
import cask.model.{FormValue, ParamContext, Response}
import io.undertow.server.handlers.form.FormParserFactory
@@ -40,14 +40,21 @@ class postForm(val path: String, override val subpath: Boolean = false) extends
type Input = Seq[FormValue]
type InputParser[T] = FormReader[T]
def getRawParams(ctx: ParamContext) = {
- val formData = FormParserFactory.builder().build().createParser(ctx.exchange).parseBlocking()
- val formDataBindings =
- formData
- .iterator()
- .asScala
- .map(k => (k, formData.get(k).asScala.map(FormValue.fromUndertow).toSeq))
- .toMap
- formDataBindings
+ for{
+ formData <-
+ try Right(FormParserFactory.builder().build().createParser(ctx.exchange).parseBlocking())
+ catch{case e: Exception => Left(cask.model.Response(
+ "Unable to parse form data: " + e + "\n" + Util.stackTraceString(e)
+ ))}
+ } yield {
+ val formDataBindings =
+ formData
+ .iterator()
+ .asScala
+ .map(k => (k, formData.get(k).asScala.map(FormValue.fromUndertow).toSeq))
+ .toMap
+ formDataBindings
+ }
}
def wrapPathSegment(s: String): Input = Seq(FormValue.Plain(s, new io.undertow.util.HeaderMap))
}
diff --git a/cask/src/cask/endpoints/JsonEndpoint.scala b/cask/src/cask/endpoints/JsonEndpoint.scala
index 80fac9a..d9d39c3 100644
--- a/cask/src/cask/endpoints/JsonEndpoint.scala
+++ b/cask/src/cask/endpoints/JsonEndpoint.scala
@@ -1,6 +1,6 @@
package cask.endpoints
-import cask.internal.Router
+import cask.internal.{Router, Util}
import cask.internal.Router.EntryPoint
import cask.main.Routes
import cask.model.{ParamContext, Response}
@@ -28,7 +28,22 @@ class postJson(val path: String, override val subpath: Boolean = false) extends
val methods = Seq("post")
type Input = ujson.Js.Value
type InputParser[T] = JsReader[T]
- def getRawParams(ctx: ParamContext) =
- ujson.read(new String(ctx.exchange.getInputStream.readAllBytes())).obj.toMap
+ def getRawParams(ctx: ParamContext) = {
+ for{
+ str <-
+ try Right(new String(ctx.exchange.getInputStream.readAllBytes()))
+ catch{case e: Throwable => Left(cask.model.Response(
+ "Unable to deserialize input JSON text: " + e + "\n" + Util.stackTraceString(e)
+ ))}
+ json <-
+ try Right(ujson.read(str))
+ catch{case e: Throwable => Left(cask.model.Response(
+ "Input text is invalid JSON: " + e + "\n" + Util.stackTraceString(e)
+ ))}
+ obj <-
+ try Right(json.obj)
+ catch {case e: Throwable => Left(cask.model.Response("Input JSON must be a dictionary"))}
+ } yield obj.toMap
+ }
def wrapPathSegment(s: String): Input = ujson.Js.Str(s)
}
diff --git a/cask/src/cask/endpoints/StaticEndpoints.scala b/cask/src/cask/endpoints/StaticEndpoints.scala
index 5d57144..31528aa 100644
--- a/cask/src/cask/endpoints/StaticEndpoints.scala
+++ b/cask/src/cask/endpoints/StaticEndpoints.scala
@@ -14,6 +14,6 @@ class static(val path: String) extends Routes.Endpoint[String] {
Router.Result.Success(cask.model.Static(t + "/" + ctx.remaining.mkString("/")))
}
- def getRawParams(ctx: ParamContext) = Map()
+ def getRawParams(ctx: ParamContext) = Right(Map())
def wrapPathSegment(s: String): Input = Seq(s)
}
diff --git a/cask/src/cask/endpoints/WebEndpoints.scala b/cask/src/cask/endpoints/WebEndpoints.scala
index 82009df..cd707b8 100644
--- a/cask/src/cask/endpoints/WebEndpoints.scala
+++ b/cask/src/cask/endpoints/WebEndpoints.scala
@@ -11,10 +11,12 @@ import collection.JavaConverters._
trait WebEndpoint extends Routes.Endpoint[BaseResponse]{
type Input = Seq[String]
type InputParser[T] = QueryParamReader[T]
- def getRawParams(ctx: ParamContext) = ctx.exchange.getQueryParameters
- .asScala
- .map{case (k, vs) => (k, vs.asScala.toArray.toSeq)}
- .toMap
+ def getRawParams(ctx: ParamContext) = Right(
+ ctx.exchange.getQueryParameters
+ .asScala
+ .map{case (k, vs) => (k, vs.asScala.toArray.toSeq)}
+ .toMap
+ )
def wrapPathSegment(s: String) = Seq(s)
}
class get(val path: String, override val subpath: Boolean = false) extends WebEndpoint{
diff --git a/cask/src/cask/internal/Util.scala b/cask/src/cask/internal/Util.scala
index 59bc2ce..98c30c4 100644
--- a/cask/src/cask/internal/Util.scala
+++ b/cask/src/cask/internal/Util.scala
@@ -1,12 +1,25 @@
package cask.internal
+import java.io.{PrintWriter, StringWriter}
+
+import scala.collection.generic.CanBuildFrom
+import scala.collection.mutable
+
object Util {
def pluralize(s: String, n: Int) = {
if (n == 1) s else s + "s"
}
- def splitPath(p: String) =
+ def splitPath(p: String) = {
p.dropWhile(_ == '/').reverse.dropWhile(_ == '/').reverse.split('/').filter(_.nonEmpty)
+ }
+ def stackTraceString(e: Throwable) = {
+ val trace = new StringWriter()
+ val pw = new PrintWriter(trace)
+ e.printStackTrace(pw)
+ pw.flush()
+ trace.toString
+ }
def softWrap(s: String, leftOffset: Int, maxWidth: Int) = {
val oneLine = s.lines.mkString(" ").split(' ')
@@ -28,4 +41,12 @@ object Util {
}
output.mkString
}
+ def sequenceEither[A, B, M[X] <: TraversableOnce[X]](in: M[Either[A, B]])(
+ implicit cbf: CanBuildFrom[M[Either[A, B]], B, M[B]]): Either[A, M[B]] = {
+ in.foldLeft[Either[A, mutable.Builder[B, M[B]]]](Right(cbf(in))) {
+ case (acc, el) =>
+ for (a <- acc; e <- el) yield a += e
+ }
+ .map(_.result())
+ }
}
diff --git a/cask/src/cask/main/ErrorMsgs.scala b/cask/src/cask/main/ErrorMsgs.scala
index c5ce978..e54ea88 100644
--- a/cask/src/cask/main/ErrorMsgs.scala
+++ b/cask/src/cask/main/ErrorMsgs.scala
@@ -97,12 +97,10 @@ object ErrorMsgs {
val thingies = x.map{
case Router.Result.ParamError.Invalid(p, v, ex) =>
val literalV = literalize(v)
- val trace = new StringWriter()
- ex.printStackTrace(new PrintWriter(trace))
+ val trace = Util.stackTraceString(ex)
s"${p.name}: ${p.typeString} = $literalV failed to parse with $ex\n$trace"
case Router.Result.ParamError.DefaultFailed(p, ex) =>
- val trace = new StringWriter()
- ex.printStackTrace(new PrintWriter(trace))
+ val trace = Util.stackTraceString(ex)
s"${p.name}'s default value failed to evaluate with $ex\n$trace"
}
diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala
index f28ae12..9ac0022 100644
--- a/cask/src/cask/main/Main.scala
+++ b/cask/src/cask/main/Main.scala
@@ -56,14 +56,22 @@ abstract class BaseMain{
routeTries(exchange.getRequestMethod.toString.toLowerCase()).lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match{
case None => writeResponse(exchange, handleError(404))
case Some(((routes, metadata), bindings, remaining)) =>
- val providers =
- Seq(metadata.endpoint.getRawParams(ParamContext(exchange, remaining)) ++
- bindings.mapValues(metadata.endpoint.wrapPathSegment)) ++
- metadata.decorators.map(e => e.getRawParams(ParamContext(exchange, remaining)))
-
- val result = metadata.entryPoint
- .asInstanceOf[EntryPoint[cask.main.Routes, cask.model.ParamContext]]
- .invoke(routes, ParamContext(exchange, remaining), providers)
+ val params = for{
+ endpointParams <- metadata.endpoint.getRawParams(ParamContext(exchange, remaining))
+ decoratorParams <- Util.sequenceEither(
+ metadata.decorators.map(e => e.getRawParams(ParamContext(exchange, remaining)))
+ )
+ } yield (endpointParams ++ bindings.mapValues(metadata.endpoint.wrapPathSegment)) +: decoratorParams
+
+ val result = params match{
+ case Left(resp) => resp
+ case Right(paramValues) =>
+ metadata.entryPoint
+ .asInstanceOf[EntryPoint[cask.main.Routes, cask.model.ParamContext]]
+ .invoke(routes, ParamContext(exchange, remaining), paramValues)
+ }
+
+
result match{
case Router.Result.Success(response: BaseResponse) => writeResponse(exchange, response)
case e: Router.Result.Error =>
diff --git a/cask/src/cask/main/Routes.scala b/cask/src/cask/main/Routes.scala
index 76a84d0..957a293 100644
--- a/cask/src/cask/main/Routes.scala
+++ b/cask/src/cask/main/Routes.scala
@@ -16,7 +16,7 @@ object Routes{
def wrapMethodOutput(ctx: ParamContext,t: R): cask.internal.Router.Result[Any] = {
cask.internal.Router.Result.Success(t)
}
- def getRawParams(ctx: ParamContext): Map[String, Input]
+
def wrapPathSegment(s: String): Input
}
@@ -37,7 +37,7 @@ object Routes{
trait BaseDecorator{
type Input
type InputParser[T] <: ArgReader[Input, T, ParamContext]
- def getRawParams(ctx: ParamContext): Map[String, Input]
+ def getRawParams(ctx: ParamContext): Either[cask.model.Response, Map[String, Input]]
def getParamParser[T](implicit p: InputParser[T]) = p
}
diff --git a/cask/test/src/test/cask/Decorated.scala b/cask/test/src/test/cask/Decorated.scala
index c6e048a..e20149a 100644
--- a/cask/test/src/test/cask/Decorated.scala
+++ b/cask/test/src/test/cask/Decorated.scala
@@ -4,7 +4,7 @@ import cask.model.ParamContext
object Decorated extends cask.MainRoutes{
class myDecorator extends cask.Routes.Decorator {
- def getRawParams(ctx: ParamContext) = Map("extra" -> 31337)
+ def getRawParams(ctx: ParamContext) = Right(Map("extra" -> 31337))
}
@myDecorator()
diff --git a/cask/test/src/test/cask/FailureTests.scala b/cask/test/src/test/cask/FailureTests.scala
index c2db0f4..03f3777 100644
--- a/cask/test/src/test/cask/FailureTests.scala
+++ b/cask/test/src/test/cask/FailureTests.scala
@@ -5,7 +5,7 @@ import utest._
object FailureTests extends TestSuite {
class myDecorator extends cask.Routes.Decorator {
- def getRawParams(ctx: ParamContext) = Map("extra" -> 31337)
+ def getRawParams(ctx: ParamContext) = Right(Map("extra" -> 31337))
}
val tests = Tests{