aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/restquery
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/restquery')
-rw-r--r--src/main/scala/xyz/driver/restquery/db/SlickPostgresQueryBuilder.scala116
-rw-r--r--src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala387
-rw-r--r--src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala240
-rw-r--r--src/main/scala/xyz/driver/restquery/query/Pagination.scala12
-rw-r--r--src/main/scala/xyz/driver/restquery/query/SearchFilterBinaryOperation.scala5
-rw-r--r--src/main/scala/xyz/driver/restquery/query/SearchFilterExpr.scala196
-rw-r--r--src/main/scala/xyz/driver/restquery/query/SearchFilterNAryOperation.scala5
-rw-r--r--src/main/scala/xyz/driver/restquery/query/Sorting.scala56
-rw-r--r--src/main/scala/xyz/driver/restquery/rest/Directives.scala55
-rw-r--r--src/main/scala/xyz/driver/restquery/rest/parsers/DimensionsParser.scala30
-rw-r--r--src/main/scala/xyz/driver/restquery/rest/parsers/PaginationParser.scala23
-rw-r--r--src/main/scala/xyz/driver/restquery/rest/parsers/ParseQueryArgException.scala3
-rw-r--r--src/main/scala/xyz/driver/restquery/rest/parsers/SearchFilterParser.scala178
-rw-r--r--src/main/scala/xyz/driver/restquery/rest/parsers/SortingParser.scala64
-rw-r--r--src/main/scala/xyz/driver/restquery/utils/Utils.scala79
15 files changed, 1449 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/restquery/db/SlickPostgresQueryBuilder.scala b/src/main/scala/xyz/driver/restquery/db/SlickPostgresQueryBuilder.scala
new file mode 100644
index 0000000..7038aa0
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/db/SlickPostgresQueryBuilder.scala
@@ -0,0 +1,116 @@
+package xyz.driver.pdsuicommon.db
+
+import java.time.{LocalDateTime, ZoneOffset}
+
+import org.slf4j.LoggerFactory
+import slick.jdbc.{GetResult, JdbcProfile}
+import xyz.driver.core.database.SlickDal
+
+import scala.collection.breakOut
+import scala.concurrent.ExecutionContext
+
+object SlickPostgresQueryBuilder {
+ private val logger = LoggerFactory.getLogger(this.getClass)
+
+ import xyz.driver.pdsuicommon.db.SlickQueryBuilder._
+
+ def apply[T](databaseName: String,
+ tableName: String,
+ lastUpdateFieldName: Option[String],
+ nullableFields: Set[String],
+ links: Set[SlickTableLink],
+ runner: Runner[T],
+ countRunner: CountRunner)(implicit sqlContext: SlickDal,
+ profile: JdbcProfile,
+ getResult: GetResult[T],
+ ec: ExecutionContext): SlickPostgresQueryBuilder[T] = {
+ val parameters = SlickPostgresQueryBuilderParameters(
+ databaseName = databaseName,
+ tableData = TableData(tableName, lastUpdateFieldName, nullableFields),
+ links = links.map(x => x.foreignTableName -> x)(breakOut)
+ )
+ new SlickPostgresQueryBuilder[T](parameters)(runner, countRunner)
+ }
+
+ def apply[T](databaseName: String,
+ tableName: String,
+ lastUpdateFieldName: Option[String],
+ nullableFields: Set[String],
+ links: Set[SlickTableLink])(implicit sqlContext: SlickDal,
+ profile: JdbcProfile,
+ getResult: GetResult[T],
+ ec: ExecutionContext): SlickPostgresQueryBuilder[T] = {
+ apply[T](databaseName,
+ tableName,
+ SlickQueryBuilderParameters.AllFields,
+ lastUpdateFieldName,
+ nullableFields,
+ links)
+ }
+
+ def apply[T](databaseName: String,
+ tableName: String,
+ fields: Set[String],
+ lastUpdateFieldName: Option[String],
+ nullableFields: Set[String],
+ links: Set[SlickTableLink])(implicit sqlContext: SlickDal,
+ profile: JdbcProfile,
+ getResult: GetResult[T],
+ ec: ExecutionContext): SlickPostgresQueryBuilder[T] = {
+
+ val runner: Runner[T] = { parameters =>
+ val sql = parameters.toSql(countQuery = false, fields = fields).as[T]
+ logger.debug(s"Built an SQL query: $sql")
+ sqlContext.execute(sql)
+ }
+
+ val countRunner: CountRunner = { parameters =>
+ implicit val getCountResult: GetResult[(Int, Option[LocalDateTime])] = GetResult({ r =>
+ val count = r.rs.getInt(1)
+ val lastUpdate = if (parameters.tableData.lastUpdateFieldName.isDefined) {
+ Option(r.rs.getTimestamp(2)).map(timestampToLocalDateTime)
+ } else None
+ (count, lastUpdate)
+ })
+ val sql = parameters.toSql(countQuery = true).as[(Int, Option[LocalDateTime])]
+ logger.debug(s"Built an SQL query returning count: $sql")
+ sqlContext.execute(sql).map(_.head)
+ }
+
+ apply[T](
+ databaseName = databaseName,
+ tableName = tableName,
+ lastUpdateFieldName = lastUpdateFieldName,
+ nullableFields = nullableFields,
+ links = links,
+ runner = runner,
+ countRunner = countRunner
+ )
+ }
+
+ def timestampToLocalDateTime(timestamp: java.sql.Timestamp): LocalDateTime = {
+ LocalDateTime.ofInstant(timestamp.toInstant, ZoneOffset.UTC)
+ }
+}
+
+class SlickPostgresQueryBuilder[T](parameters: SlickPostgresQueryBuilderParameters)(
+ implicit runner: SlickQueryBuilder.Runner[T],
+ countRunner: SlickQueryBuilder.CountRunner)
+ extends SlickQueryBuilder[T](parameters) {
+
+ def withFilter(newFilter: SearchFilterExpr): SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(filter = newFilter))
+ }
+
+ def withSorting(newSorting: Sorting): SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(sorting = newSorting))
+ }
+
+ def withPagination(newPagination: Pagination): SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(pagination = Some(newPagination)))
+ }
+
+ def resetPagination: SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(pagination = None))
+ }
+}
diff --git a/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala
new file mode 100644
index 0000000..9962edf
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala
@@ -0,0 +1,387 @@
+package xyz.driver.pdsuicommon.db
+
+import java.sql.{JDBCType, PreparedStatement}
+import java.time.LocalDateTime
+
+import slick.jdbc.{JdbcProfile, PositionedParameters, SQLActionBuilder, SetParameter}
+import xyz.driver.pdsuicommon.db.Sorting.{Dimension, Sequential}
+import xyz.driver.pdsuicommon.db.SortingOrder.{Ascending, Descending}
+import xyz.driver.pdsuicommon.domain.{LongId, StringId, UuidId}
+
+import scala.concurrent.{ExecutionContext, Future}
+
+object SlickQueryBuilder {
+
+ type Runner[T] = SlickQueryBuilderParameters => Future[Seq[T]]
+
+ type CountResult = Future[(Int, Option[LocalDateTime])]
+
+ type CountRunner = SlickQueryBuilderParameters => CountResult
+
+ /**
+ * Binder for PreparedStatement
+ */
+ type Binder = PreparedStatement => PreparedStatement
+
+ final case class TableData(tableName: String,
+ lastUpdateFieldName: Option[String] = None,
+ nullableFields: Set[String] = Set.empty)
+
+ val AllFields = Set("*")
+
+ implicit class SQLActionBuilderConcat(a: SQLActionBuilder) {
+ def concat(b: SQLActionBuilder): SQLActionBuilder = {
+ SQLActionBuilder(a.queryParts ++ b.queryParts, new SetParameter[Unit] {
+ def apply(p: Unit, pp: PositionedParameters): Unit = {
+ a.unitPConv.apply(p, pp)
+ b.unitPConv.apply(p, pp)
+ }
+ })
+ }
+ }
+
+ implicit object SetQueryParameter extends SetParameter[AnyRef] {
+ def apply(v: AnyRef, pp: PositionedParameters) = {
+ pp.setObject(v, JDBCType.BINARY.getVendorTypeNumber)
+ }
+ }
+
+ implicit def setLongIdQueryParameter[T]: SetParameter[LongId[T]] = SetParameter[LongId[T]] { (v, pp) =>
+ pp.setLong(v.id)
+ }
+
+ implicit def setStringIdQueryParameter[T]: SetParameter[StringId[T]] = SetParameter[StringId[T]] { (v, pp) =>
+ pp.setString(v.id)
+ }
+
+ implicit def setUuidIdQueryParameter[T]: SetParameter[UuidId[T]] = SetParameter[UuidId[T]] { (v, pp) =>
+ pp.setObject(v.id, JDBCType.BINARY.getVendorTypeNumber)
+ }
+}
+
+final case class SlickTableLink(keyColumnName: String, foreignTableName: String, foreignKeyColumnName: String)
+
+object SlickQueryBuilderParameters {
+ val AllFields = Set("*")
+}
+
+sealed trait SlickQueryBuilderParameters {
+ import SlickQueryBuilder._
+
+ def databaseName: String
+ def tableData: SlickQueryBuilder.TableData
+ def links: Map[String, SlickTableLink]
+ def filter: SearchFilterExpr
+ def sorting: Sorting
+ def pagination: Option[Pagination]
+
+ def qs: String
+
+ def findLink(tableName: String): SlickTableLink = links.get(tableName) match {
+ case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`")
+ case Some(link) => link
+ }
+
+ def toSql(countQuery: Boolean = false)(implicit profile: JdbcProfile): SQLActionBuilder = {
+ toSql(countQuery, SlickQueryBuilderParameters.AllFields)
+ }
+
+ def toSql(countQuery: Boolean, fields: Set[String])(implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ val escapedTableName = s"""$qs$databaseName$qs.$qs${tableData.tableName}$qs"""
+ val fieldsSql: String = if (countQuery) {
+ val suffix: String = tableData.lastUpdateFieldName match {
+ case Some(lastUpdateField) => s", max($escapedTableName.$qs$lastUpdateField$qs)"
+ case None => ""
+ }
+ s"count(*) $suffix"
+ } else {
+ if (fields == SlickQueryBuilderParameters.AllFields) {
+ s"$escapedTableName.*"
+ } else {
+ fields
+ .map { field =>
+ s"$escapedTableName.$qs$field$qs"
+ }
+ .mkString(", ")
+ }
+ }
+ val where = filterToSql(escapedTableName, filter)
+ val orderBy = sortingToSql(escapedTableName, sorting)
+
+ val limitSql = limitToSql()
+
+ val sql = sql"""select #$fieldsSql from #$escapedTableName"""
+
+ val filtersTableLinks: Seq[SlickTableLink] = {
+ import SearchFilterExpr._
+ def aux(expr: SearchFilterExpr): Seq[SlickTableLink] = expr match {
+ case Atom.TableName(tableName) => List(findLink(tableName))
+ case Intersection(xs) => xs.flatMap(aux)
+ case Union(xs) => xs.flatMap(aux)
+ case _ => Nil
+ }
+ aux(filter)
+ }
+
+ val sortingTableLinks: Seq[SlickTableLink] = Sorting.collect(sorting) {
+ case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
+ }
+
+ // Combine links from sorting and filter without duplicates
+ val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct
+
+ def fkSql(fkLinksSql: SQLActionBuilder, tableLinks: Seq[SlickTableLink]): SQLActionBuilder = {
+ if (tableLinks.nonEmpty) {
+ tableLinks.head match {
+ case SlickTableLink(keyColumnName, foreignTableName, foreignKeyColumnName) =>
+ val escapedForeignTableName = s"$qs$databaseName$qs.$qs$foreignTableName$qs"
+ val join = sql""" inner join #$escapedForeignTableName
+ on #$escapedTableName.#$qs#$keyColumnName#$qs=#$escapedForeignTableName.#$qs#$foreignKeyColumnName#$qs"""
+ fkSql(fkLinksSql concat join, tableLinks.tail)
+ }
+ } else fkLinksSql
+ }
+ val foreignTableLinksSql = fkSql(sql"", foreignTableLinks)
+
+ val whereSql = if (where.queryParts.size > 1) {
+ sql" where " concat where
+ } else sql""
+
+ val orderSql = if (orderBy.nonEmpty && !countQuery) {
+ sql" order by #$orderBy"
+ } else sql""
+
+ val limSql = if (limitSql.queryParts.size > 1 && !countQuery) {
+ sql" " concat limitSql
+ } else sql""
+
+ sql concat foreignTableLinksSql concat whereSql concat orderSql concat limSql
+ }
+
+ /**
+ * Converts filter expression to SQL expression.
+ *
+ * @return Returns SQL string and list of values for binding in prepared statement.
+ */
+ protected def filterToSql(escapedTableName: String, filter: SearchFilterExpr)(
+ implicit profile: JdbcProfile): SQLActionBuilder = {
+ import SearchFilterBinaryOperation._
+ import SearchFilterExpr._
+ import profile.api._
+
+ def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null"
+
+ def escapeDimension(dimension: SearchFilterExpr.Dimension) = {
+ s"${dimension.tableName.map(t => s"$qs$databaseName$qs.$qs$t$qs").getOrElse(escapedTableName)}.$qs${dimension.name}$qs"
+ }
+
+ def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect {
+ case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x)
+ }
+
+ def multipleSqlToAction(first: Boolean,
+ op: String,
+ conditions: Seq[SQLActionBuilder],
+ sql: SQLActionBuilder): SQLActionBuilder = {
+ if (conditions.nonEmpty) {
+ val condition = conditions.head
+ if (first) {
+ multipleSqlToAction(first = false, op, conditions.tail, condition)
+ } else {
+ multipleSqlToAction(first = false, op, conditions.tail, sql concat sql" #${op} " concat condition)
+ }
+ } else sql
+ }
+
+ def concatenateParameters(sql: SQLActionBuilder, first: Boolean, tail: Seq[AnyRef]): SQLActionBuilder = {
+ if (tail.nonEmpty) {
+ if (!first) {
+ concatenateParameters(sql concat sql""",${tail.head}""", first = false, tail.tail)
+ } else {
+ concatenateParameters(sql"""(${tail.head}""", first = false, tail.tail)
+ }
+ } else sql concat sql")"
+ }
+
+ filter match {
+ case x if isEmpty(x) =>
+ sql""
+
+ case AllowAll =>
+ sql"1=1"
+
+ case DenyAll =>
+ sql"1=0"
+
+ case Atom.Binary(dimension, Eq, value) if isNull(value) =>
+ sql"#${escapeDimension(dimension)} is NULL"
+
+ case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
+ sql"#${escapeDimension(dimension)} is not NULL"
+
+ case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
+ // In MySQL NULL <> Any === NULL
+ // So, to handle NotEq for nullable fields we need to use more complex SQL expression.
+ // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
+ val escapedColumn = escapeDimension(dimension)
+ sql"(#${escapedColumn} is null or #${escapedColumn} != $value)"
+
+ case Atom.Binary(dimension, op, value) =>
+ val operator = op match {
+ case Eq => sql"="
+ case NotEq => sql"!="
+ case Like => sql" like "
+ case Gt => sql">"
+ case GtEq => sql">="
+ case Lt => sql"<"
+ case LtEq => sql"<="
+ }
+ sql"#${escapeDimension(dimension)}" concat operator concat sql"""$value"""
+
+ case Atom.NAry(dimension, op, values) =>
+ val sqlOp = op match {
+ case SearchFilterNAryOperation.In => sql" in "
+ case SearchFilterNAryOperation.NotIn => sql" not in "
+ }
+
+ if (values.nonEmpty) {
+ val formattedValues = concatenateParameters(sql"", first = true, values)
+ sql"#${escapeDimension(dimension)}" concat sqlOp concat formattedValues
+ } else {
+ sql"1=0"
+ }
+
+ case Intersection(operands) =>
+ val filter = multipleSqlToAction(first = true, "and", filterToSqlMultiple(operands), sql"")
+ sql"(" concat filter concat sql")"
+
+ case Union(operands) =>
+ val filter = multipleSqlToAction(first = true, "or", filterToSqlMultiple(operands), sql"")
+ sql"(" concat filter concat sql")"
+ }
+ }
+
+ protected def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder
+
+ /**
+ * @param escapedMainTableName Should be escaped
+ */
+ protected def sortingToSql(escapedMainTableName: String, sorting: Sorting)(implicit profile: JdbcProfile): String = {
+ sorting match {
+ case Dimension(optSortingTableName, field, order) =>
+ val sortingTableName =
+ optSortingTableName.map(x => s"$qs$databaseName$qs.$qs$x$qs").getOrElse(escapedMainTableName)
+ val fullName = s"$sortingTableName.$qs$field$qs"
+
+ s"$fullName ${orderToSql(order)}"
+
+ case Sequential(xs) =>
+ xs.map(sortingToSql(escapedMainTableName, _)).mkString(", ")
+ }
+ }
+
+ protected def orderToSql(x: SortingOrder): String = x match {
+ case Ascending => "asc"
+ case Descending => "desc"
+ }
+
+ protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = {
+ bindings.zipWithIndex.foreach {
+ case (binding, index) =>
+ bind.setObject(index + 1, binding)
+ }
+
+ bind
+ }
+
+}
+
+final case class SlickPostgresQueryBuilderParameters(databaseName: String,
+ tableData: SlickQueryBuilder.TableData,
+ links: Map[String, SlickTableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None)
+ extends SlickQueryBuilderParameters {
+
+ def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ pagination.map { pagination =>
+ val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
+ sql"limit #${pagination.pageSize} OFFSET #$startFrom"
+ } getOrElse (sql"")
+ }
+
+ val qs = """""""
+
+}
+
+/**
+ * @param links Links to another tables grouped by foreignTableName
+ */
+final case class SlickMysqlQueryBuilderParameters(databaseName: String,
+ tableData: SlickQueryBuilder.TableData,
+ links: Map[String, SlickTableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None)
+ extends SlickQueryBuilderParameters {
+
+ def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ pagination
+ .map { pagination =>
+ val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
+ sql"limit #$startFrom, #${pagination.pageSize}"
+ }
+ .getOrElse(sql"")
+ }
+
+ val qs = """`"""
+
+}
+
+abstract class SlickQueryBuilder[T](val parameters: SlickQueryBuilderParameters)(
+ implicit runner: SlickQueryBuilder.Runner[T],
+ countRunner: SlickQueryBuilder.CountRunner) {
+
+ def run()(implicit ec: ExecutionContext): Future[Seq[T]] = runner(parameters)
+
+ def runCount()(implicit ec: ExecutionContext): SlickQueryBuilder.CountResult = countRunner(parameters)
+
+ /**
+ * Runs the query and returns total found rows without considering of pagination.
+ */
+ def runWithCount()(implicit ec: ExecutionContext): Future[(Seq[T], Int, Option[LocalDateTime])] = {
+ for {
+ all <- run
+ (total, lastUpdate) <- runCount
+ } yield (all, total, lastUpdate)
+ }
+
+ def withFilter(newFilter: SearchFilterExpr): SlickQueryBuilder[T]
+
+ def withFilter(filter: Option[SearchFilterExpr]): SlickQueryBuilder[T] = {
+ filter.fold(this)(withFilter)
+ }
+
+ def resetFilter: SlickQueryBuilder[T] = withFilter(SearchFilterExpr.Empty)
+
+ def withSorting(newSorting: Sorting): SlickQueryBuilder[T]
+
+ def withSorting(sorting: Option[Sorting]): SlickQueryBuilder[T] = {
+ sorting.fold(this)(withSorting)
+ }
+
+ def resetSorting: SlickQueryBuilder[T] = withSorting(Sorting.Empty)
+
+ def withPagination(newPagination: Pagination): SlickQueryBuilder[T]
+
+ def withPagination(pagination: Option[Pagination]): SlickQueryBuilder[T] = {
+ pagination.fold(this)(withPagination)
+ }
+
+ def resetPagination: SlickQueryBuilder[T]
+
+}
diff --git a/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala
new file mode 100644
index 0000000..6ab7eb4
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala
@@ -0,0 +1,240 @@
+package xyz.driver.restquery.query
+
+import java.sql.PreparedStatement
+
+import slick.jdbc.{JdbcProfile, SQLActionBuilder}
+import xyz.driver.restquery.db.{SlickQueryBuilder, SlickQueryBuilderParameters, SlickTableLink}
+import xyz.driver.restquery.query.Sorting.{Dimension, Sequential}
+import xyz.driver.restquery.query.SortingOrder.{Ascending, Descending}
+
+trait SlickQueryBuilderParameters {
+ import SlickQueryBuilder._
+
+ def databaseName: String
+ def tableData: SlickQueryBuilder.TableData
+ def links: Map[String, SlickTableLink]
+ def filter: SearchFilterExpr
+ def sorting: Sorting
+ def pagination: Option[Pagination]
+
+ def qs: String
+
+ def findLink(tableName: String): SlickTableLink = links.get(tableName) match {
+ case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`")
+ case Some(link) => link
+ }
+
+ def toSql(countQuery: Boolean = false)(implicit profile: JdbcProfile): SQLActionBuilder = {
+ toSql(countQuery, SlickQueryBuilderParameters.AllFields)
+ }
+
+ def toSql(countQuery: Boolean, fields: Set[String])(implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ val escapedTableName = s"""$qs$databaseName$qs.$qs${tableData.tableName}$qs"""
+ val fieldsSql: String = if (countQuery) {
+ val suffix: String = tableData.lastUpdateFieldName match {
+ case Some(lastUpdateField) => s", max($escapedTableName.$qs$lastUpdateField$qs)"
+ case None => ""
+ }
+ s"count(*) $suffix"
+ } else {
+ if (fields == SlickQueryBuilderParameters.AllFields) {
+ s"$escapedTableName.*"
+ } else {
+ fields
+ .map { field =>
+ s"$escapedTableName.$qs$field$qs"
+ }
+ .mkString(", ")
+ }
+ }
+ val where = filterToSql(escapedTableName, filter)
+ val orderBy = sortingToSql(escapedTableName, sorting)
+
+ val limitSql = limitToSql()
+
+ val sql = sql"""select #$fieldsSql from #$escapedTableName"""
+
+ val filtersTableLinks: Seq[SlickTableLink] = {
+ import SearchFilterExpr._
+ def aux(expr: SearchFilterExpr): Seq[SlickTableLink] = expr match {
+ case Atom.TableName(tableName) => List(findLink(tableName))
+ case Intersection(xs) => xs.flatMap(aux)
+ case Union(xs) => xs.flatMap(aux)
+ case _ => Nil
+ }
+ aux(filter)
+ }
+
+ val sortingTableLinks: Seq[SlickTableLink] = Sorting.collect(sorting) {
+ case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
+ }
+
+ // Combine links from sorting and filter without duplicates
+ val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct
+
+ def fkSql(fkLinksSql: SQLActionBuilder, tableLinks: Seq[SlickTableLink]): SQLActionBuilder = {
+ if (tableLinks.nonEmpty) {
+ tableLinks.head match {
+ case SlickTableLink(keyColumnName, foreignTableName, foreignKeyColumnName) =>
+ val escapedForeignTableName = s"$qs$databaseName$qs.$qs$foreignTableName$qs"
+ val join = sql""" inner join #$escapedForeignTableName
+ on #$escapedTableName.#$qs#$keyColumnName#$qs=#$escapedForeignTableName.#$qs#$foreignKeyColumnName#$qs"""
+ fkSql(fkLinksSql concat join, tableLinks.tail)
+ }
+ } else fkLinksSql
+ }
+ val foreignTableLinksSql = fkSql(sql"", foreignTableLinks)
+
+ val whereSql = if (where.queryParts.size > 1) {
+ sql" where " concat where
+ } else sql""
+
+ val orderSql = if (orderBy.nonEmpty && !countQuery) {
+ sql" order by #$orderBy"
+ } else sql""
+
+ val limSql = if (limitSql.queryParts.size > 1 && !countQuery) {
+ sql" " concat limitSql
+ } else sql""
+
+ sql concat foreignTableLinksSql concat whereSql concat orderSql concat limSql
+ }
+
+ /**
+ * Converts filter expression to SQL expression.
+ *
+ * @return Returns SQL string and list of values for binding in prepared statement.
+ */
+ protected def filterToSql(escapedTableName: String, filter: SearchFilterExpr)(
+ implicit profile: JdbcProfile): SQLActionBuilder = {
+ import SearchFilterBinaryOperation._
+ import SearchFilterExpr._
+ import profile.api._
+
+ def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null"
+
+ def escapeDimension(dimension: SearchFilterExpr.Dimension) = {
+ s"${dimension.tableName.map(t => s"$qs$databaseName$qs.$qs$t$qs").getOrElse(escapedTableName)}.$qs${dimension.name}$qs"
+ }
+
+ def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect {
+ case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x)
+ }
+
+ def multipleSqlToAction(first: Boolean,
+ op: String,
+ conditions: Seq[SQLActionBuilder],
+ sql: SQLActionBuilder): SQLActionBuilder = {
+ if (conditions.nonEmpty) {
+ val condition = conditions.head
+ if (first) {
+ multipleSqlToAction(first = false, op, conditions.tail, condition)
+ } else {
+ multipleSqlToAction(first = false, op, conditions.tail, sql concat sql" #${op} " concat condition)
+ }
+ } else sql
+ }
+
+ def concatenateParameters(sql: SQLActionBuilder, first: Boolean, tail: Seq[AnyRef]): SQLActionBuilder = {
+ if (tail.nonEmpty) {
+ if (!first) {
+ concatenateParameters(sql concat sql""",${tail.head}""", first = false, tail.tail)
+ } else {
+ concatenateParameters(sql"""(${tail.head}""", first = false, tail.tail)
+ }
+ } else sql concat sql")"
+ }
+
+ filter match {
+ case x if isEmpty(x) =>
+ sql""
+
+ case AllowAll =>
+ sql"1=1"
+
+ case DenyAll =>
+ sql"1=0"
+
+ case Atom.Binary(dimension, Eq, value) if isNull(value) =>
+ sql"#${escapeDimension(dimension)} is NULL"
+
+ case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
+ sql"#${escapeDimension(dimension)} is not NULL"
+
+ case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
+ // In MySQL NULL <> Any === NULL
+ // So, to handle NotEq for nullable fields we need to use more complex SQL expression.
+ // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
+ val escapedColumn = escapeDimension(dimension)
+ sql"(#${escapedColumn} is null or #${escapedColumn} != $value)"
+
+ case Atom.Binary(dimension, op, value) =>
+ val operator = op match {
+ case Eq => sql"="
+ case NotEq => sql"!="
+ case Like => sql" like "
+ case Gt => sql">"
+ case GtEq => sql">="
+ case Lt => sql"<"
+ case LtEq => sql"<="
+ }
+ sql"#${escapeDimension(dimension)}" concat operator concat sql"""$value"""
+
+ case Atom.NAry(dimension, op, values) =>
+ val sqlOp = op match {
+ case SearchFilterNAryOperation.In => sql" in "
+ case SearchFilterNAryOperation.NotIn => sql" not in "
+ }
+
+ if (values.nonEmpty) {
+ val formattedValues = concatenateParameters(sql"", first = true, values)
+ sql"#${escapeDimension(dimension)}" concat sqlOp concat formattedValues
+ } else {
+ sql"1=0"
+ }
+
+ case Intersection(operands) =>
+ val filter = multipleSqlToAction(first = true, "and", filterToSqlMultiple(operands), sql"")
+ sql"(" concat filter concat sql")"
+
+ case Union(operands) =>
+ val filter = multipleSqlToAction(first = true, "or", filterToSqlMultiple(operands), sql"")
+ sql"(" concat filter concat sql")"
+ }
+ }
+
+ protected def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder
+
+ /**
+ * @param escapedMainTableName Should be escaped
+ */
+ protected def sortingToSql(escapedMainTableName: String, sorting: Sorting)(implicit profile: JdbcProfile): String = {
+ sorting match {
+ case Dimension(optSortingTableName, field, order) =>
+ val sortingTableName =
+ optSortingTableName.map(x => s"$qs$databaseName$qs.$qs$x$qs").getOrElse(escapedMainTableName)
+ val fullName = s"$sortingTableName.$qs$field$qs"
+
+ s"$fullName ${orderToSql(order)}"
+
+ case Sequential(xs) =>
+ xs.map(sortingToSql(escapedMainTableName, _)).mkString(", ")
+ }
+ }
+
+ protected def orderToSql(x: SortingOrder): String = x match {
+ case Ascending => "asc"
+ case Descending => "desc"
+ }
+
+ protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = {
+ bindings.zipWithIndex.foreach {
+ case (binding, index) =>
+ bind.setObject(index + 1, binding)
+ }
+
+ bind
+ }
+
+}
diff --git a/src/main/scala/xyz/driver/restquery/query/Pagination.scala b/src/main/scala/xyz/driver/restquery/query/Pagination.scala
new file mode 100644
index 0000000..27b8f12
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/query/Pagination.scala
@@ -0,0 +1,12 @@
+package xyz.driver.restquery.domain
+
+/**
+ * @param pageNumber Starts with 1
+ */
+final case class Pagination(pageSize: Int, pageNumber: Int)
+
+object Pagination {
+
+ // @see https://driverinc.atlassian.net/wiki/display/RA/REST+API+Specification#RESTAPISpecification-CommonRequestQueryParametersForWebServices
+ val Default = Pagination(pageSize = 100, pageNumber = 1)
+}
diff --git a/src/main/scala/xyz/driver/restquery/query/SearchFilterBinaryOperation.scala b/src/main/scala/xyz/driver/restquery/query/SearchFilterBinaryOperation.scala
new file mode 100644
index 0000000..dab466b
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/query/SearchFilterBinaryOperation.scala
@@ -0,0 +1,5 @@
+package xyz.driver.restquery.query
+
+class SearchFilterBinaryOperation {
+
+}
diff --git a/src/main/scala/xyz/driver/restquery/query/SearchFilterExpr.scala b/src/main/scala/xyz/driver/restquery/query/SearchFilterExpr.scala
new file mode 100644
index 0000000..6d2cb9a
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/query/SearchFilterExpr.scala
@@ -0,0 +1,196 @@
+package xyz.driver.restquery.domain
+
+sealed trait SearchFilterExpr {
+ def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr]
+ def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr
+}
+
+object SearchFilterExpr {
+
+ val Empty: Intersection = Intersection.Empty
+ val Forbid = Atom.Binary(
+ dimension = Dimension(None, "true"),
+ op = SearchFilterBinaryOperation.Eq,
+ value = "false"
+ )
+
+ final case class Dimension(tableName: Option[String], name: String) {
+ def isForeign: Boolean = tableName.isDefined
+ }
+
+ sealed trait Atom extends SearchFilterExpr {
+ override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = {
+ Some(this).filter(p)
+ }
+
+ override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = {
+ if (f.isDefinedAt(this)) f(this)
+ else this
+ }
+ }
+
+ object Atom {
+ final case class Binary(dimension: Dimension, op: SearchFilterBinaryOperation, value: AnyRef) extends Atom
+ object Binary {
+ def apply(field: String, op: SearchFilterBinaryOperation, value: AnyRef): Binary =
+ Binary(Dimension(None, field), op, value)
+ }
+
+ final case class NAry(dimension: Dimension, op: SearchFilterNAryOperation, values: Seq[AnyRef]) extends Atom
+ object NAry {
+ def apply(field: String, op: SearchFilterNAryOperation, values: Seq[AnyRef]): NAry =
+ NAry(Dimension(None, field), op, values)
+ }
+
+ /** dimension.tableName extractor */
+ object TableName {
+ def unapply(value: Atom): Option[String] = value match {
+ case Binary(Dimension(tableNameOpt, _), _, _) => tableNameOpt
+ case NAry(Dimension(tableNameOpt, _), _, _) => tableNameOpt
+ }
+ }
+ }
+
+ final case class Intersection private (operands: Seq[SearchFilterExpr])
+ extends SearchFilterExpr with SearchFilterExprSeqOps {
+
+ override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = {
+ if (f.isDefinedAt(this)) f(this)
+ else {
+ this.copy(operands.map(_.replace(f)))
+ }
+ }
+
+ }
+
+ object Intersection {
+
+ val Empty = Intersection(Seq())
+
+ def create(operands: SearchFilterExpr*): SearchFilterExpr = {
+ val filtered = operands.filterNot(SearchFilterExpr.isEmpty)
+ filtered.size match {
+ case 0 => Empty
+ case 1 => filtered.head
+ case _ => Intersection(filtered)
+ }
+ }
+ }
+
+ final case class Union private (operands: Seq[SearchFilterExpr])
+ extends SearchFilterExpr with SearchFilterExprSeqOps {
+
+ override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = {
+ if (f.isDefinedAt(this)) f(this)
+ else {
+ this.copy(operands.map(_.replace(f)))
+ }
+ }
+
+ }
+
+ object Union {
+
+ val Empty = Union(Seq())
+
+ def create(operands: SearchFilterExpr*): SearchFilterExpr = {
+ val filtered = operands.filterNot(SearchFilterExpr.isEmpty)
+ filtered.size match {
+ case 0 => Empty
+ case 1 => filtered.head
+ case _ => Union(filtered)
+ }
+ }
+
+ def create(dimension: Dimension, values: String*): SearchFilterExpr = values.size match {
+ case 0 => SearchFilterExpr.Empty
+ case 1 => SearchFilterExpr.Atom.Binary(dimension, SearchFilterBinaryOperation.Eq, values.head)
+ case _ =>
+ val filters = values.map { value =>
+ SearchFilterExpr.Atom.Binary(dimension, SearchFilterBinaryOperation.Eq, value)
+ }
+
+ create(filters: _*)
+ }
+
+ def create(dimension: Dimension, values: Set[String]): SearchFilterExpr =
+ create(dimension, values.toSeq: _*)
+
+ // Backwards compatible API
+
+ /** Create SearchFilterExpr with empty tableName */
+ def create(field: String, values: String*): SearchFilterExpr =
+ create(Dimension(None, field), values: _*)
+
+ /** Create SearchFilterExpr with empty tableName */
+ def create(field: String, values: Set[String]): SearchFilterExpr =
+ create(Dimension(None, field), values)
+ }
+
+ case object AllowAll extends SearchFilterExpr {
+ override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = {
+ Some(this).filter(p)
+ }
+
+ override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = {
+ if (f.isDefinedAt(this)) f(this)
+ else this
+ }
+ }
+
+ case object DenyAll extends SearchFilterExpr {
+ override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = {
+ Some(this).filter(p)
+ }
+
+ override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = {
+ if (f.isDefinedAt(this)) f(this)
+ else this
+ }
+ }
+
+ def isEmpty(expr: SearchFilterExpr): Boolean = {
+ expr == Intersection.Empty || expr == Union.Empty
+ }
+
+ sealed trait SearchFilterExprSeqOps { this: SearchFilterExpr =>
+
+ val operands: Seq[SearchFilterExpr]
+
+ override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = {
+ if (p(this)) Some(this)
+ else {
+ // Search the first expr among operands, which satisfy p
+ // Is's ok to use foldLeft. If there will be performance issues, replace it by recursive loop
+ operands.foldLeft(Option.empty[SearchFilterExpr]) {
+ case (None, expr) => expr.find(p)
+ case (x, _) => x
+ }
+ }
+ }
+ }
+
+}
+
+sealed trait SearchFilterBinaryOperation
+
+object SearchFilterBinaryOperation {
+
+ case object Eq extends SearchFilterBinaryOperation
+ case object NotEq extends SearchFilterBinaryOperation
+ case object Like extends SearchFilterBinaryOperation
+ case object Gt extends SearchFilterBinaryOperation
+ case object GtEq extends SearchFilterBinaryOperation
+ case object Lt extends SearchFilterBinaryOperation
+ case object LtEq extends SearchFilterBinaryOperation
+
+}
+
+sealed trait SearchFilterNAryOperation
+
+object SearchFilterNAryOperation {
+
+ case object In extends SearchFilterNAryOperation
+ case object NotIn extends SearchFilterNAryOperation
+
+}
diff --git a/src/main/scala/xyz/driver/restquery/query/SearchFilterNAryOperation.scala b/src/main/scala/xyz/driver/restquery/query/SearchFilterNAryOperation.scala
new file mode 100644
index 0000000..0604c8f
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/query/SearchFilterNAryOperation.scala
@@ -0,0 +1,5 @@
+package xyz.driver.restquery.query
+
+class SearchFilterNAryOperation {
+
+}
diff --git a/src/main/scala/xyz/driver/restquery/query/Sorting.scala b/src/main/scala/xyz/driver/restquery/query/Sorting.scala
new file mode 100644
index 0000000..e2642ad
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/query/Sorting.scala
@@ -0,0 +1,56 @@
+package xyz.driver.restquery.domain
+
+import scala.collection.generic.CanBuildFrom
+
+sealed trait SortingOrder
+object SortingOrder {
+
+ case object Ascending extends SortingOrder
+ case object Descending extends SortingOrder
+
+}
+
+sealed trait Sorting
+
+object Sorting {
+
+ val Empty = Sequential(Seq.empty)
+
+ /**
+ * @param tableName None if the table is default (same)
+ * @param name Dimension name
+ * @param order Order
+ */
+ final case class Dimension(tableName: Option[String], name: String, order: SortingOrder) extends Sorting {
+ def isForeign: Boolean = tableName.isDefined
+ }
+
+ final case class Sequential(sorting: Seq[Dimension]) extends Sorting {
+ override def toString: String = if (isEmpty(this)) "Empty" else super.toString
+ }
+
+ def isEmpty(input: Sorting): Boolean = {
+ input match {
+ case Sequential(Seq()) => true
+ case _ => false
+ }
+ }
+
+ def filter(sorting: Sorting, p: Dimension => Boolean): Seq[Dimension] = sorting match {
+ case x: Dimension if p(x) => Seq(x)
+ case _: Dimension => Seq.empty
+ case Sequential(xs) => xs.filter(p)
+ }
+
+ def collect[B, That](sorting: Sorting)(f: PartialFunction[Dimension, B])(
+ implicit bf: CanBuildFrom[Seq[Dimension], B, That]): That = sorting match {
+ case x: Dimension if f.isDefinedAt(x) =>
+ val r = bf.apply()
+ r += f(x)
+ r.result()
+
+ case _: Dimension => bf.apply().result()
+ case Sequential(xs) => xs.collect(f)
+ }
+
+}
diff --git a/src/main/scala/xyz/driver/restquery/rest/Directives.scala b/src/main/scala/xyz/driver/restquery/rest/Directives.scala
new file mode 100644
index 0000000..2936f70
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/rest/Directives.scala
@@ -0,0 +1,55 @@
+package xyz.driver.restquery.http
+
+import akka.http.scaladsl.server.Directives._
+import akka.http.scaladsl.server._
+import xyz.driver.restquery.domain.{SearchFilterExpr, _}
+import xyz.driver.restquery.http.parsers._
+
+import scala.util._
+
+trait Directives {
+
+ val paginated: Directive1[Pagination] = parameterSeq.flatMap { params =>
+ PaginationParser.parse(params) match {
+ case Success(pagination) => provide(pagination)
+ case Failure(ex) =>
+ reject(ValidationRejection("invalid pagination parameter", Some(ex)))
+ }
+ }
+
+ def sorted(validDimensions: Set[String] = Set.empty): Directive1[Sorting] = parameterSeq.flatMap { params =>
+ SortingParser.parse(validDimensions, params) match {
+ case Success(sorting) => provide(sorting)
+ case Failure(ex) =>
+ reject(ValidationRejection("invalid sorting parameter", Some(ex)))
+ }
+ }
+
+ val dimensioned: Directive1[Dimensions] = parameterSeq.flatMap { params =>
+ DimensionsParser.tryParse(params) match {
+ case Success(dims) => provide(dims)
+ case Failure(ex) =>
+ reject(ValidationRejection("invalid dimension parameter", Some(ex)))
+ }
+ }
+
+ val searchFiltered: Directive1[SearchFilterExpr] = parameterSeq.flatMap { params =>
+ SearchFilterParser.parse(params) match {
+ case Success(sorting) => provide(sorting)
+ case Failure(ex) =>
+ reject(ValidationRejection("invalid filter parameter", Some(ex)))
+ }
+ }
+
+ def StringIdInPath[T]: PathMatcher1[StringId[T]] =
+ PathMatchers.Segment.map((id) => StringId(id.toString))
+
+ def LongIdInPath[T]: PathMatcher1[LongId[T]] =
+ PathMatchers.LongNumber.map((id) => LongId(id))
+
+ def UuidIdInPath[T]: PathMatcher1[UuidId[T]] =
+ PathMatchers.JavaUUID.map((id) => UuidId(id))
+
+}
+
+object Directives extends Directives
diff --git a/src/main/scala/xyz/driver/restquery/rest/parsers/DimensionsParser.scala b/src/main/scala/xyz/driver/restquery/rest/parsers/DimensionsParser.scala
new file mode 100644
index 0000000..7e139db
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/rest/parsers/DimensionsParser.scala
@@ -0,0 +1,30 @@
+package xyz.driver.restquery.http.parsers
+
+import scala.util.{Failure, Success, Try}
+
+class Dimensions(private val xs: Set[String] = Set.empty) {
+ def contains(x: String): Boolean = xs.isEmpty || xs.contains(x)
+}
+
+object DimensionsParser {
+
+ @deprecated("play-akka transition", "0")
+ def tryParse(query: Map[String, Seq[String]]): Try[Dimensions] =
+ tryParse(query.toSeq.flatMap {
+ case (key, values) =>
+ values.map(value => key -> value)
+ })
+
+ def tryParse(query: Seq[(String, String)]): Try[Dimensions] = {
+ query.collect { case ("dimensions", value) => value } match {
+ case Nil => Success(new Dimensions())
+
+ case x +: Nil =>
+ val raw: Set[String] = x.split(",").view.map(_.trim).filter(_.nonEmpty).to[Set]
+ Success(new Dimensions(raw))
+
+ case xs =>
+ Failure(new IllegalArgumentException(s"Dimensions are specified ${xs.size} times"))
+ }
+ }
+}
diff --git a/src/main/scala/xyz/driver/restquery/rest/parsers/PaginationParser.scala b/src/main/scala/xyz/driver/restquery/rest/parsers/PaginationParser.scala
new file mode 100644
index 0000000..2b4547b
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/rest/parsers/PaginationParser.scala
@@ -0,0 +1,23 @@
+package xyz.driver.restquery.http.parsers
+
+import xyz.driver.restquery.domain.Pagination
+
+import scala.util._
+
+object PaginationParser {
+
+ def parse(query: Seq[(String, String)]): Try[Pagination] = {
+ val IntString = """(\d+)""".r
+ def validate(field: String, default: Int) = query.collectFirst { case (`field`, size) => size } match {
+ case Some(IntString(x)) if x.toInt > 0 => x.toInt
+ case Some(IntString(x)) => throw new ParseQueryArgException((field, s"must greater than zero (found $x)"))
+ case Some(str) => throw new ParseQueryArgException((field, s"must be an integer (found $str)"))
+ case None => default
+ }
+
+ Try {
+ Pagination(validate("pageSize", Pagination.Default.pageSize),
+ validate("pageNumber", Pagination.Default.pageNumber))
+ }
+ }
+}
diff --git a/src/main/scala/xyz/driver/restquery/rest/parsers/ParseQueryArgException.scala b/src/main/scala/xyz/driver/restquery/rest/parsers/ParseQueryArgException.scala
new file mode 100644
index 0000000..096c28f
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/rest/parsers/ParseQueryArgException.scala
@@ -0,0 +1,3 @@
+package xyz.driver.restquery.http.parsers
+
+class ParseQueryArgException(val errors: (String, String)*) extends Exception(errors.mkString(","))
diff --git a/src/main/scala/xyz/driver/restquery/rest/parsers/SearchFilterParser.scala b/src/main/scala/xyz/driver/restquery/rest/parsers/SearchFilterParser.scala
new file mode 100644
index 0000000..ce3009b
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/rest/parsers/SearchFilterParser.scala
@@ -0,0 +1,178 @@
+package xyz.driver.restquery.http.parsers
+
+import java.util.UUID
+
+import fastparse.all._
+import fastparse.core.Parsed
+import xyz.driver.restquery.domain.{SearchFilterBinaryOperation, SearchFilterExpr, SearchFilterNAryOperation}
+import xyz.driver.restquery.utils.Utils
+import xyz.driver.restquery.utils.Utils._
+
+import scala.util.Try
+
+@SuppressWarnings(Array("org.wartremover.warts.Product", "org.wartremover.warts.Serializable"))
+object SearchFilterParser {
+
+ private object BinaryAtomFromTuple {
+ def unapply(input: (SearchFilterExpr.Dimension, (String, Any))): Option[SearchFilterExpr.Atom.Binary] = {
+ val (dimensionName, (strOperation, value)) = input
+ val updatedValue = trimIfString(value)
+
+ parseOperation(strOperation.toLowerCase).map { op =>
+ SearchFilterExpr.Atom.Binary(dimensionName, op, updatedValue.asInstanceOf[AnyRef])
+ }
+ }
+ }
+
+ private object NAryAtomFromTuple {
+ // Compiler warning: unchecked since it is eliminated by erasure, if we user Seq[String]
+ def unapply(input: (SearchFilterExpr.Dimension, (String, Seq[_]))): Option[SearchFilterExpr.Atom.NAry] = {
+ val (dimensionName, (strOperation, xs)) = input
+ val updatedValues = xs.map(trimIfString)
+
+ if (strOperation.toLowerCase == "in") {
+ Some(
+ SearchFilterExpr.Atom
+ .NAry(dimensionName, SearchFilterNAryOperation.In, updatedValues.map(_.asInstanceOf[AnyRef])))
+ } else {
+ None
+ }
+ }
+ }
+
+ private def trimIfString(value: Any) =
+ value match {
+ case s: String => Utils.safeTrim(s)
+ case a => a
+ }
+
+ private val operationsMapping = {
+ import xyz.driver.restquery.domain.SearchFilterBinaryOperation._
+
+ Map[String, SearchFilterBinaryOperation](
+ "eq" -> Eq,
+ "noteq" -> NotEq,
+ "like" -> Like,
+ "gt" -> Gt,
+ "gteq" -> GtEq,
+ "lt" -> Lt,
+ "lteq" -> LtEq
+ )
+ }
+
+ private def parseOperation(x: String): Option[SearchFilterBinaryOperation] = operationsMapping.get(x)
+
+ private val whitespaceParser = P(CharPred(Utils.isSafeWhitespace))
+
+ val dimensionParser: Parser[SearchFilterExpr.Dimension] = {
+ val identParser = P(
+ CharPred(c => c.isLetterOrDigit)
+ .rep(min = 1)).!.map(s => SearchFilterExpr.Dimension(None, toSnakeCase(s)))
+ val pathParser = P(identParser.! ~ "." ~ identParser.!) map {
+ case (left, right) =>
+ SearchFilterExpr.Dimension(Some(toSnakeCase(left)), toSnakeCase(right))
+ }
+ P(pathParser | identParser)
+ }
+
+ private val commonOperatorParser: Parser[String] = {
+ P(IgnoreCase("eq") | IgnoreCase("like") | IgnoreCase("noteq")).!
+ }
+
+ private val numericOperatorParser: Parser[String] = {
+ P(IgnoreCase("eq") | IgnoreCase("noteq") | ((IgnoreCase("gt") | IgnoreCase("lt")) ~ IgnoreCase("eq").?)).!
+ }
+
+ private val naryOperatorParser: Parser[String] = P(IgnoreCase("in")).!
+
+ private val isPositiveParser: Parser[Boolean] = P(CharIn("-+").!.?).map {
+ case Some("-") => false
+ case _ => true
+ }
+
+ private val digitsParser: Parser[String] = P(CharIn('0' to '9').rep(min = 1).!) // Exclude Unicode "digits"
+
+ private val numberParser: Parser[String] = P(isPositiveParser ~ digitsParser.! ~ ("." ~ digitsParser).!.?).map {
+ case (false, intPart, Some(fracPart)) => s"-$intPart.${fracPart.tail}"
+ case (false, intPart, None) => s"-$intPart"
+ case (_, intPart, Some(fracPart)) => s"$intPart.${fracPart.tail}"
+ case (_, intPart, None) => s"$intPart"
+ }
+
+ private val nAryValueParser: Parser[String] = P(CharPred(_ != ',').rep(min = 1).!)
+
+ private val longParser: Parser[Long] = P(CharIn('0' to '9').rep(min = 1).!.map(_.toLong))
+
+ private val booleanParser: Parser[Boolean] =
+ P((IgnoreCase("true") | IgnoreCase("false")).!.map(_.toBoolean))
+
+ private val hexDigit: Parser[String] = P((CharIn('a' to 'f') | CharIn('A' to 'F') | CharIn('0' to '9')).!)
+
+ private val uuidParser: Parser[UUID] =
+ P(
+ hexDigit.rep(8).! ~ "-" ~ hexDigit.rep(4).! ~ "-" ~ hexDigit.rep(4).! ~ "-" ~ hexDigit.rep(4).! ~ "-" ~ hexDigit
+ .rep(12)
+ .!).map {
+ case (group1, group2, group3, group4, group5) => UUID.fromString(s"$group1-$group2-$group3-$group4-$group5")
+ }
+
+ private val binaryAtomParser: Parser[SearchFilterExpr.Atom.Binary] = P(
+ dimensionParser ~ whitespaceParser ~
+ ((numericOperatorParser.! ~ whitespaceParser ~ (longParser | numberParser.!) ~ End) |
+ (commonOperatorParser.! ~ whitespaceParser ~ (uuidParser | booleanParser | AnyChar.rep(min = 1).!) ~ End))
+ ).map {
+ case BinaryAtomFromTuple(atom) => atom
+ }
+
+ private val nAryAtomParser: Parser[SearchFilterExpr.Atom.NAry] = P(
+ dimensionParser ~ whitespaceParser ~ (
+ naryOperatorParser ~ whitespaceParser ~
+ ((longParser.rep(min = 1, sep = ",") ~ End) | (booleanParser.rep(min = 1, sep = ",") ~ End) |
+ (nAryValueParser.!.rep(min = 1, sep = ",") ~ End))
+ )
+ ).map {
+ case NAryAtomFromTuple(atom) => atom
+ }
+
+ private val atomParser: Parser[SearchFilterExpr.Atom] = P(binaryAtomParser | nAryAtomParser)
+
+ @deprecated("play-akka transition", "0")
+ def parse(query: Map[String, Seq[String]]): Try[SearchFilterExpr] =
+ parse(query.toSeq.flatMap {
+ case (key, values) =>
+ values.map(value => key -> value)
+ })
+
+ def parse(query: Seq[(String, String)]): Try[SearchFilterExpr] = Try {
+ query.toList.collect { case ("filters", value) => value } match {
+ case Nil => SearchFilterExpr.Empty
+
+ case head :: Nil =>
+ atomParser.parse(head) match {
+ case Parsed.Success(x, _) => x
+ case e: Parsed.Failure[_, _] => throw new ParseQueryArgException("filters" -> formatFailure(1, e))
+ }
+
+ case xs =>
+ val parsed = xs.map(x => atomParser.parse(x))
+ val failures: Seq[String] = parsed.zipWithIndex.collect {
+ case (e: Parsed.Failure[_, _], index) => formatFailure(index, e)
+ }
+
+ if (failures.isEmpty) {
+ val filters = parsed.collect {
+ case Parsed.Success(x, _) => x
+ }
+
+ SearchFilterExpr.Intersection.create(filters: _*)
+ } else {
+ throw new ParseQueryArgException("filters" -> failures.mkString("; "))
+ }
+ }
+ }
+
+ private def formatFailure(sectionIndex: Int, e: Parsed.Failure[_, _]): String = {
+ s"section $sectionIndex: ${fastparse.core.ParseError.msg(e.extra.input, e.extra.traced.expected, e.index)}"
+ }
+
+}
diff --git a/src/main/scala/xyz/driver/restquery/rest/parsers/SortingParser.scala b/src/main/scala/xyz/driver/restquery/rest/parsers/SortingParser.scala
new file mode 100644
index 0000000..f2a3c04
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/rest/parsers/SortingParser.scala
@@ -0,0 +1,64 @@
+package xyz.driver.restquery.http.parsers
+
+import fastparse.all._
+import fastparse.core.Parsed
+import xyz.driver.restquery.domain.{Sorting, SortingOrder}
+import xyz.driver.restquery.utils.Utils._
+
+import scala.util.Try
+
+object SortingParser {
+
+ private val sortingOrderParser: Parser[SortingOrder] = P("-".!.?).map {
+ case Some(_) => SortingOrder.Descending
+ case None => SortingOrder.Ascending
+ }
+
+ private def dimensionSortingParser(validDimensions: Seq[String]): Parser[Sorting.Dimension] = {
+ P(sortingOrderParser ~ StringIn(validDimensions: _*).!).map {
+ case (sortingOrder, field) =>
+ val prefixedFields = field.split("\\.", 2)
+ prefixedFields.size match {
+ case 1 => Sorting.Dimension(None, toSnakeCase(field), sortingOrder)
+ case 2 =>
+ Sorting.Dimension(Some(prefixedFields.head).map(toSnakeCase),
+ toSnakeCase(prefixedFields.last),
+ sortingOrder)
+ }
+ }
+ }
+
+ private def sequentialSortingParser(validDimensions: Seq[String]): Parser[Sorting.Sequential] = {
+ P(dimensionSortingParser(validDimensions).rep(min = 1, sep = ",") ~ End).map { dimensions =>
+ Sorting.Sequential(dimensions)
+ }
+ }
+
+ @deprecated("play-akka transition", "0")
+ def parse(validDimensions: Set[String], query: Map[String, Seq[String]]): Try[Sorting] =
+ parse(validDimensions, query.toSeq.flatMap {
+ case (key, values) =>
+ values.map(value => key -> value)
+ })
+
+ def parse(validDimensions: Set[String], query: Seq[(String, String)]): Try[Sorting] = Try {
+ query.toList.collect { case ("sort", value) => value } match {
+ case Nil => Sorting.Sequential(Seq.empty)
+
+ case rawSorting :: Nil =>
+ val parser = sequentialSortingParser(validDimensions.toSeq)
+ parser.parse(rawSorting) match {
+ case Parsed.Success(x, _) => x
+ case e: Parsed.Failure[_, _] =>
+ throw new ParseQueryArgException("sort" -> formatFailure(e))
+ }
+
+ case _ => throw new ParseQueryArgException("sort" -> "multiple sections are not allowed")
+ }
+ }
+
+ private def formatFailure(e: Parsed.Failure[_, _]): String = {
+ fastparse.core.ParseError.msg(e.extra.input, e.extra.traced.expected, e.index)
+ }
+
+}
diff --git a/src/main/scala/xyz/driver/restquery/utils/Utils.scala b/src/main/scala/xyz/driver/restquery/utils/Utils.scala
new file mode 100644
index 0000000..86c65d7
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/utils/Utils.scala
@@ -0,0 +1,79 @@
+package xyz.driver.pdsuicommon.utils
+
+import java.time.LocalDateTime
+import java.util.regex.{Matcher, Pattern}
+
+object Utils {
+
+ implicit val localDateTimeOrdering: Ordering[LocalDateTime] = Ordering.fromLessThan(_ isBefore _)
+
+ /**
+ * Hack to avoid scala compiler bug with getSimpleName
+ * @see https://issues.scala-lang.org/browse/SI-2034
+ */
+ def getClassSimpleName(klass: Class[_]): String = {
+ try {
+ klass.getSimpleName
+ } catch {
+ case _: InternalError =>
+ val fullName = klass.getName.stripSuffix("$")
+ val fullClassName = fullName.substring(fullName.lastIndexOf(".") + 1)
+ fullClassName.substring(fullClassName.lastIndexOf("$") + 1)
+ }
+ }
+
+ def toSnakeCase(str: String): String =
+ str
+ .replaceAll("([A-Z]+)([A-Z][a-z])", "$1_$2")
+ .replaceAll("([a-z\\d])([A-Z])", "$1_$2")
+ .toLowerCase
+
+ def toCamelCase(str: String): String = {
+ val sb = new StringBuffer()
+ def loop(m: Matcher): Unit = if (m.find()) {
+ m.appendReplacement(sb, m.group(1).toUpperCase())
+ loop(m)
+ }
+ val m: Matcher = Pattern.compile("_(.)").matcher(str)
+ loop(m)
+ m.appendTail(sb)
+ sb.toString
+ }
+
+ object Whitespace {
+ private val Table: String =
+ "\u2002\u3000\r\u0085\u200A\u2005\u2000\u3000" +
+ "\u2029\u000B\u3000\u2008\u2003\u205F\u3000\u1680" +
+ "\u0009\u0020\u2006\u2001\u202F\u00A0\u000C\u2009" +
+ "\u3000\u2004\u3000\u3000\u2028\n\u2007\u3000"
+
+ private val Multiplier: Int = 1682554634
+ private val Shift: Int = Integer.numberOfLeadingZeros(Table.length - 1)
+
+ def matches(c: Char): Boolean = Table.charAt((Multiplier * c) >>> Shift) == c
+ }
+
+ def isSafeWhitespace(char: Char): Boolean = Whitespace.matches(char)
+
+ // From Guava
+ def isSafeControl(char: Char): Boolean =
+ char <= '\u001f' || (char >= '\u007f' && char <= '\u009f')
+
+
+ def safeTrim(string: String): String = {
+ def shouldKeep(c: Char): Boolean = !isSafeControl(c) && !isSafeWhitespace(c)
+
+ if (string.isEmpty) {
+ ""
+ } else {
+ val start = string.indexWhere(shouldKeep)
+ val end = string.lastIndexWhere(shouldKeep)
+
+ if (start >= 0 && end >= 0) {
+ string.substring(start, end + 1)
+ } else {
+ ""
+ }
+ }
+ }
+}