diff options
Diffstat (limited to 'src/main/scala/xyz/driver/core')
-rw-r--r-- | src/main/scala/xyz/driver/core/database/Dal.scala | 20 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/date.scala | 17 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/domain.scala | 18 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/generators.scala | 7 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/json.scala | 84 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/rest/PatchDirectives.scala | 104 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/time.scala | 87 |
7 files changed, 299 insertions, 38 deletions
diff --git a/src/main/scala/xyz/driver/core/database/Dal.scala b/src/main/scala/xyz/driver/core/database/Dal.scala index 581bd0f..bcde0de 100644 --- a/src/main/scala/xyz/driver/core/database/Dal.scala +++ b/src/main/scala/xyz/driver/core/database/Dal.scala @@ -1,10 +1,11 @@ package xyz.driver.core.database -import slick.lifted.AbstractTable +import scalaz.std.scalaFuture._ +import scalaz.{ListT, Monad, OptionT} +import slick.lifted.{AbstractTable, CanBeQueryCondition, RunnableCompiled} +import slick.{lifted => sl} import scala.concurrent.{ExecutionContext, Future} -import scalaz.{ListT, Monad, OptionT} -import scalaz.std.scalaFuture._ trait Dal { type T[D] @@ -34,16 +35,20 @@ class SlickDal(database: Database, executionContext: ExecutionContext) extends D override type T[D] = slick.dbio.DBIO[D] - implicit protected class QueryOps[U](query: Query[_, U, Seq]) { + implicit protected class QueryOps[+E, U](query: Query[E, U, Seq]) { def resultT: ListT[T, U] = ListT[T, U](query.result.map(_.toList)) + + def maybeFilter[V, R: CanBeQueryCondition](data: Option[V])(f: V => E => R): sl.Query[E, U, Seq] = + data.map(v => query.withFilter(f(v))).getOrElse(query) } - implicit protected class CompiledQueryOps[U](compiledQuery: slick.lifted.RunnableCompiled[_, Seq[U]]) { + implicit protected class CompiledQueryOps[U](compiledQuery: RunnableCompiled[_, Seq[U]]) { def resultT: ListT[T, U] = ListT.listT[T](compiledQuery.result.map(_.toList)) } private val dbioMonad = new Monad[T] { - override def point[A](a: => A): T[A] = DBIO.successful(a) + override def point[A](a: => A): T[A] = DBIO.successful(a) + override def bind[A, B](fa: T[A])(f: A => T[B]): T[B] = fa.flatMap(f) } @@ -53,7 +58,8 @@ class SlickDal(database: Database, executionContext: ExecutionContext) extends D database.database.run(readOperations.transactionally) } - override def noAction[V](v: V): T[V] = DBIO.successful(v) + override def noAction[V](v: V): T[V] = DBIO.successful(v) + override def customAction[R](action: => Future[R]): T[R] = DBIO.from(action) def affectsRows(updatesCount: Int): Option[Unit] = { diff --git a/src/main/scala/xyz/driver/core/date.scala b/src/main/scala/xyz/driver/core/date.scala index fe35c91..5454093 100644 --- a/src/main/scala/xyz/driver/core/date.scala +++ b/src/main/scala/xyz/driver/core/date.scala @@ -2,12 +2,13 @@ package xyz.driver.core import java.util.Calendar -import scala.util.Try - +import enumeratum._ import scalaz.std.anyVal._ -import scalaz.Scalaz.stringInstance import scalaz.syntax.equal._ +import scala.collection.immutable.IndexedSeq +import scala.util.Try + /** * Driver Date type and related validators/extractors. * Day, Month, and Year extractors are from ISO 8601 strings => driver...Date integers. @@ -15,8 +16,8 @@ import scalaz.syntax.equal._ */ object date { - sealed trait DayOfWeek - object DayOfWeek { + sealed trait DayOfWeek extends EnumEntry + object DayOfWeek extends Enum[DayOfWeek] { case object Monday extends DayOfWeek case object Tuesday extends DayOfWeek case object Wednesday extends DayOfWeek @@ -25,9 +26,11 @@ object date { case object Saturday extends DayOfWeek case object Sunday extends DayOfWeek - val All: Set[DayOfWeek] = Set(Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday) + val values: IndexedSeq[DayOfWeek] = findValues + + val All: Set[DayOfWeek] = values.toSet - def fromString(day: String): Option[DayOfWeek] = All.find(_.toString === day) + def fromString(day: String): Option[DayOfWeek] = withNameInsensitiveOption(day) } type Day = Int @@ Day.type diff --git a/src/main/scala/xyz/driver/core/domain.scala b/src/main/scala/xyz/driver/core/domain.scala index 48943a7..7731345 100644 --- a/src/main/scala/xyz/driver/core/domain.scala +++ b/src/main/scala/xyz/driver/core/domain.scala @@ -1,13 +1,14 @@ package xyz.driver.core +import com.google.i18n.phonenumbers.PhoneNumberUtil import scalaz.Equal -import scalaz.syntax.equal._ import scalaz.std.string._ +import scalaz.syntax.equal._ object domain { final case class Email(username: String, domain: String) { - override def toString = username + "@" + domain + override def toString: String = username + "@" + domain } object Email { @@ -27,16 +28,13 @@ object domain { } object PhoneNumber { - def parse(phoneNumberString: String): Option[PhoneNumber] = { - val onlyDigits = phoneNumberString.replaceAll("[^\\d.]", "") - if (onlyDigits.length < 10) None - else { - val tenDigitNumber = onlyDigits.takeRight(10) - val countryCode = Option(onlyDigits.dropRight(10)).filter(_.nonEmpty).getOrElse("1") + private val phoneUtil = PhoneNumberUtil.getInstance() - Some(PhoneNumber(countryCode, tenDigitNumber)) - } + def parse(phoneNumber: String): Option[PhoneNumber] = { + val phone = phoneUtil.parseAndKeepRawInput(phoneNumber, "US") + if (!phoneUtil.isValidNumber(phone)) None + else Some(PhoneNumber(phone.getCountryCode.toString, phone.getNationalNumber.toString)) } } } diff --git a/src/main/scala/xyz/driver/core/generators.scala b/src/main/scala/xyz/driver/core/generators.scala index e3ff326..3c85447 100644 --- a/src/main/scala/xyz/driver/core/generators.scala +++ b/src/main/scala/xyz/driver/core/generators.scala @@ -1,9 +1,10 @@ package xyz.driver.core +import enumeratum._ import java.math.MathContext import java.util.UUID -import xyz.driver.core.time.{Time, TimeRange} +import xyz.driver.core.time.{Time, TimeOfDay, TimeRange} import xyz.driver.core.date.{Date, DayOfWeek} import scala.reflect.ClassTag @@ -69,6 +70,8 @@ object generators { def nextTime(): Time = Time(math.abs(nextLong() % System.currentTimeMillis)) + def nextTimeOfDay: TimeOfDay = TimeOfDay(java.time.LocalTime.MIN.plusSeconds(nextLong), java.util.TimeZone.getDefault) + def nextTimeRange(): TimeRange = { val oneTime = nextTime() val anotherTime = nextTime() @@ -89,6 +92,8 @@ object generators { def oneOf[T](items: Set[T]): T = items.toSeq(nextInt(items.size)) + def oneOf[T <: EnumEntry](enum: Enum[T]): T = oneOf(enum.values: _*) + def arrayOf[T: ClassTag](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Array[T] = Array.fill(nextInt(maxLength, minLength))(generator) diff --git a/src/main/scala/xyz/driver/core/json.scala b/src/main/scala/xyz/driver/core/json.scala index 02a35fd..06a8837 100644 --- a/src/main/scala/xyz/driver/core/json.scala +++ b/src/main/scala/xyz/driver/core/json.scala @@ -1,23 +1,25 @@ package xyz.driver.core import java.net.InetAddress -import java.util.UUID +import java.util.{TimeZone, UUID} -import scala.reflect.runtime.universe._ -import scala.util.Try +import akka.http.scaladsl.marshalling.{Marshaller, Marshalling} import akka.http.scaladsl.model.Uri.Path -import akka.http.scaladsl.server._ import akka.http.scaladsl.server.PathMatcher.{Matched, Unmatched} -import akka.http.scaladsl.marshalling.{Marshaller, Marshalling} +import akka.http.scaladsl.server._ import akka.http.scaladsl.unmarshalling.Unmarshaller +import enumeratum._ +import eu.timepit.refined.api.{Refined, Validate} +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.refineV import spray.json._ import xyz.driver.core.auth.AuthCredentials import xyz.driver.core.date.{Date, DayOfWeek, Month} import xyz.driver.core.domain.{Email, PhoneNumber} -import xyz.driver.core.time.Time -import eu.timepit.refined.refineV -import eu.timepit.refined.api.{Refined, Validate} -import eu.timepit.refined.collection.NonEmpty +import xyz.driver.core.time.{Time, TimeOfDay} + +import scala.reflect.runtime.universe._ +import scala.util.Try object json { import DefaultJsonProtocol._ @@ -80,8 +82,34 @@ object json { } } - implicit val dayOfWeekFormat: JsonFormat[DayOfWeek] = - new EnumJsonFormat[DayOfWeek](DayOfWeek.All.map(w => w.toString -> w)(collection.breakOut): _*) + implicit object localTimeFormat extends JsonFormat[java.time.LocalTime] { + private val formatter = TimeOfDay.getFormatter + def read(json: JsValue): java.time.LocalTime = json match { + case JsString(chars) => + java.time.LocalTime.parse(chars) + case _ => deserializationError(s"Expected time string got ${json.toString}") + } + + def write(obj: java.time.LocalTime): JsValue = { + JsString(obj.format(formatter)) + } + } + + implicit object timeZoneFormat extends JsonFormat[java.util.TimeZone] { + override def write(obj: TimeZone): JsValue = { + JsString(obj.getID()) + } + + override def read(json: JsValue): TimeZone = json match { + case JsString(chars) => + java.util.TimeZone.getTimeZone(chars) + case _ => deserializationError(s"Expected time zone string got ${json.toString}") + } + } + + implicit val timeOfDayFormat: RootJsonFormat[TimeOfDay] = jsonFormat2(TimeOfDay.apply) + + implicit val dayOfWeekFormat: JsonFormat[DayOfWeek] = new enumeratum.EnumJsonFormat(DayOfWeek) implicit val dateFormat = new RootJsonFormat[Date] { def write(date: Date) = JsString(date.toString) @@ -109,9 +137,9 @@ object json { } implicit def revisionFromStringUnmarshaller[T]: Unmarshaller[String, Revision[T]] = - Unmarshaller.strict[String, Revision[T]](Revision[T](_)) + Unmarshaller.strict[String, Revision[T]](Revision[T]) - implicit def revisionFormat[T] = new RootJsonFormat[Revision[T]] { + implicit def revisionFormat[T]: RootJsonFormat[Revision[T]] = new RootJsonFormat[Revision[T]] { def write(revision: Revision[T]) = JsString(revision.id.toString) def read(value: JsValue): Revision[T] = value match { @@ -159,6 +187,36 @@ object json { JsString(obj.getHostAddress) } + object enumeratum { + + def enumUnmarshaller[T <: EnumEntry](enum: Enum[T]): Unmarshaller[String, T] = + Unmarshaller.strict { value => + enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) + } + + trait HasJsonFormat[T <: EnumEntry] { enum: Enum[T] => + + implicit val format: JsonFormat[T] = new EnumJsonFormat(enum) + + implicit val unmarshaller: Unmarshaller[String, T] = + Unmarshaller.strict { value => + enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) + } + } + + class EnumJsonFormat[T <: EnumEntry](enum: Enum[T]) extends JsonFormat[T] { + override def read(json: JsValue): T = json match { + case JsString(name) => enum.withNameOption(name).getOrElse(unrecognizedValue(name, enum.values)) + case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) + } + + override def write(obj: T): JsValue = JsString(obj.entryName) + } + + private def unrecognizedValue(value: String, possibleValues: Seq[Any]): Nothing = + deserializationError(s"Unexpected value $value. Expected one of: ${possibleValues.mkString("[", ", ", "]")}") + } + class EnumJsonFormat[T](mapping: (String, T)*) extends RootJsonFormat[T] { private val map = mapping.toMap diff --git a/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala new file mode 100644 index 0000000..256358c --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala @@ -0,0 +1,104 @@ +package xyz.driver.core.rest + +import akka.http.javadsl.server.Rejections +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType} +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller} +import spray.json._ + +import scala.concurrent.Future +import scala.util.{Failure, Success, Try} + +trait PatchDirectives extends Directives with SprayJsonSupport { + + /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + val `application/merge-patch+json`: MediaType.WithFixedCharset = + MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`) + + /** Wraps a JSON value that represents a patch. + * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + case class PatchValue(value: JsValue) { + + /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */ + def applyTo(original: JsValue): JsValue = mergeJsValues(original, value) + } + + /** Witness that the given patch may be applied to an original domain value. + * @tparam A type of the domain value + * @param patch the patch that may be applied to a domain value + * @param format a JSON format that enables serialization and deserialization of a domain value */ + case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) { + + /** Applies the patch to a given domain object. The result will be a combination + * of the original value, updates with the fields specified in this witness' patch. */ + def applyTo(original: A): A = { + val serialized = format.write(original) + val merged = patch.applyTo(serialized) + val deserialized = format.read(merged) + deserialized + } + } + + implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] = + Unmarshaller.byteStringUnmarshaller + .andThen(sprayJsValueByteStringUnmarshaller) + .forContentTypes(ContentTypeRange(`application/merge-patch+json`)) + .map(js => PatchValue(js)) + + implicit def patchableUnmarshaller[A]( + implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue], + format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = { + patchUnmarshaller.map(patch => Patchable[A](patch, format)) + } + + protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { + JsObject(oldObj.fields.map({ + case (key, oldValue) => + val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) + key -> newValue + })(collection.breakOut): _*) + } + + protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { + def mergeError(typ: String): Nothing = + deserializationError(s"Expected $typ value, got $newValue") + + if (maxLevels.exists(_ < 0)) oldValue + else { + (oldValue, newValue) match { + case (_: JsString, newString @ (JsString(_) | JsNull)) => newString + case (_: JsString, _) => mergeError("string") + case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber + case (_: JsNumber, _) => mergeError("number") + case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool + case (_: JsBoolean, _) => mergeError("boolean") + case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr + case (_: JsArray, _) => mergeError("array") + case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) + case (_: JsObject, JsNull) => JsNull + case (_: JsObject, _) => mergeError("object") + case (JsNull, _) => newValue + } + } + } + + def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] = + Directive { inner => requestCtx => + onSuccess(retrieve)({ + case Some(oldT) => + Try(patchable.applyTo(oldT)) + .transform[Route]( + mergedT => scala.util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => reject() + })(requestCtx) + } +} + +object PatchDirectives extends PatchDirectives diff --git a/src/main/scala/xyz/driver/core/time.scala b/src/main/scala/xyz/driver/core/time.scala index 3bcc7bc..bab304d 100644 --- a/src/main/scala/xyz/driver/core/time.scala +++ b/src/main/scala/xyz/driver/core/time.scala @@ -4,7 +4,10 @@ import java.text.SimpleDateFormat import java.util._ import java.util.concurrent.TimeUnit +import xyz.driver.core.date.Month + import scala.concurrent.duration._ +import scala.util.Try object time { @@ -39,6 +42,90 @@ object time { } } + /** + * Encapsulates a time and timezone without a specific date. + */ + final case class TimeOfDay(localTime: java.time.LocalTime, timeZone: TimeZone) { + + /** + * Is this time before another time on a specific day. Day light savings safe. These are zero-indexed + * for month/day. + */ + def isBefore(other: TimeOfDay, day: Int, month: Month, year: Int): Boolean = { + toCalendar(day, month, year).before(other.toCalendar(day, month, year)) + } + + /** + * Is this time after another time on a specific day. Day light savings safe. + */ + def isAfter(other: TimeOfDay, day: Int, month: Month, year: Int): Boolean = { + toCalendar(day, month, year).after(other.toCalendar(day, month, year)) + } + + def sameTimeAs(other: TimeOfDay, day: Int, month: Month, year: Int): Boolean = { + toCalendar(day, month, year).equals(other.toCalendar(day, month, year)) + } + + /** + * Enforces the same formatting as expected by [[java.sql.Time]] + * @return string formatted for `java.sql.Time` + */ + def timeString: String = { + localTime.format(TimeOfDay.getFormatter) + } + + /** + * @return a string parsable by [[java.util.TimeZone]] + */ + def timeZoneString: String = { + timeZone.getID + } + + /** + * @return this [[TimeOfDay]] as [[java.sql.Time]] object, [[java.sql.Time.valueOf]] will + * throw when the string is not valid, but this is protected by [[timeString]] method. + */ + def toTime: java.sql.Time = { + java.sql.Time.valueOf(timeString) + } + + private def toCalendar(day: Int, month: Int, year: Int): Calendar = { + val cal = Calendar.getInstance(timeZone) + cal.set(year, month, day, localTime.getHour, localTime.getMinute, localTime.getSecond) + cal + } + } + + object TimeOfDay { + def now(): TimeOfDay = { + TimeOfDay(java.time.LocalTime.now(), TimeZone.getDefault) + } + + /** + * Throws when [s] is not parsable by [[java.time.LocalTime.parse]], uses default [[java.util.TimeZone]] + */ + def parseTimeString(tz: TimeZone = TimeZone.getDefault)(s: String): TimeOfDay = { + TimeOfDay(java.time.LocalTime.parse(s), tz) + } + + def fromString(tz: TimeZone)(s: String): Option[TimeOfDay] = { + val op = Try(java.time.LocalTime.parse(s)).toOption + op.map(lt => TimeOfDay(lt, tz)) + } + + def fromStrings(zoneId: String)(s: String): Option[TimeOfDay] = { + val op = Try(TimeZone.getTimeZone(zoneId)).toOption + op.map(tz => TimeOfDay.parseTimeString(tz)(s)) + } + + /** + * Formatter that enforces `HH:mm:ss` which is expected by [[java.sql.Time]] + */ + def getFormatter: java.time.format.DateTimeFormatter = { + java.time.format.DateTimeFormatter.ofPattern("HH:mm:ss") + } + } + object Time { implicit def timeOrdering: Ordering[Time] = Ordering.by(_.millis) |