diff options
Diffstat (limited to 'src/main/scala/xyz/driver/restquery')
15 files changed, 1449 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/restquery/db/SlickPostgresQueryBuilder.scala b/src/main/scala/xyz/driver/restquery/db/SlickPostgresQueryBuilder.scala new file mode 100644 index 0000000..7038aa0 --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/db/SlickPostgresQueryBuilder.scala @@ -0,0 +1,116 @@ +package xyz.driver.pdsuicommon.db + +import java.time.{LocalDateTime, ZoneOffset} + +import org.slf4j.LoggerFactory +import slick.jdbc.{GetResult, JdbcProfile} +import xyz.driver.core.database.SlickDal + +import scala.collection.breakOut +import scala.concurrent.ExecutionContext + +object SlickPostgresQueryBuilder { + private val logger = LoggerFactory.getLogger(this.getClass) + + import xyz.driver.pdsuicommon.db.SlickQueryBuilder._ + + def apply[T](databaseName: String, + tableName: String, + lastUpdateFieldName: Option[String], + nullableFields: Set[String], + links: Set[SlickTableLink], + runner: Runner[T], + countRunner: CountRunner)(implicit sqlContext: SlickDal, + profile: JdbcProfile, + getResult: GetResult[T], + ec: ExecutionContext): SlickPostgresQueryBuilder[T] = { + val parameters = SlickPostgresQueryBuilderParameters( + databaseName = databaseName, + tableData = TableData(tableName, lastUpdateFieldName, nullableFields), + links = links.map(x => x.foreignTableName -> x)(breakOut) + ) + new SlickPostgresQueryBuilder[T](parameters)(runner, countRunner) + } + + def apply[T](databaseName: String, + tableName: String, + lastUpdateFieldName: Option[String], + nullableFields: Set[String], + links: Set[SlickTableLink])(implicit sqlContext: SlickDal, + profile: JdbcProfile, + getResult: GetResult[T], + ec: ExecutionContext): SlickPostgresQueryBuilder[T] = { + apply[T](databaseName, + tableName, + SlickQueryBuilderParameters.AllFields, + lastUpdateFieldName, + nullableFields, + links) + } + + def apply[T](databaseName: String, + tableName: String, + fields: Set[String], + lastUpdateFieldName: Option[String], + nullableFields: Set[String], + links: Set[SlickTableLink])(implicit sqlContext: SlickDal, + profile: JdbcProfile, + getResult: GetResult[T], + ec: ExecutionContext): SlickPostgresQueryBuilder[T] = { + + val runner: Runner[T] = { parameters => + val sql = parameters.toSql(countQuery = false, fields = fields).as[T] + logger.debug(s"Built an SQL query: $sql") + sqlContext.execute(sql) + } + + val countRunner: CountRunner = { parameters => + implicit val getCountResult: GetResult[(Int, Option[LocalDateTime])] = GetResult({ r => + val count = r.rs.getInt(1) + val lastUpdate = if (parameters.tableData.lastUpdateFieldName.isDefined) { + Option(r.rs.getTimestamp(2)).map(timestampToLocalDateTime) + } else None + (count, lastUpdate) + }) + val sql = parameters.toSql(countQuery = true).as[(Int, Option[LocalDateTime])] + logger.debug(s"Built an SQL query returning count: $sql") + sqlContext.execute(sql).map(_.head) + } + + apply[T]( + databaseName = databaseName, + tableName = tableName, + lastUpdateFieldName = lastUpdateFieldName, + nullableFields = nullableFields, + links = links, + runner = runner, + countRunner = countRunner + ) + } + + def timestampToLocalDateTime(timestamp: java.sql.Timestamp): LocalDateTime = { + LocalDateTime.ofInstant(timestamp.toInstant, ZoneOffset.UTC) + } +} + +class SlickPostgresQueryBuilder[T](parameters: SlickPostgresQueryBuilderParameters)( + implicit runner: SlickQueryBuilder.Runner[T], + countRunner: SlickQueryBuilder.CountRunner) + extends SlickQueryBuilder[T](parameters) { + + def withFilter(newFilter: SearchFilterExpr): SlickQueryBuilder[T] = { + new SlickPostgresQueryBuilder[T](parameters.copy(filter = newFilter)) + } + + def withSorting(newSorting: Sorting): SlickQueryBuilder[T] = { + new SlickPostgresQueryBuilder[T](parameters.copy(sorting = newSorting)) + } + + def withPagination(newPagination: Pagination): SlickQueryBuilder[T] = { + new SlickPostgresQueryBuilder[T](parameters.copy(pagination = Some(newPagination))) + } + + def resetPagination: SlickQueryBuilder[T] = { + new SlickPostgresQueryBuilder[T](parameters.copy(pagination = None)) + } +} diff --git a/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala new file mode 100644 index 0000000..9962edf --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala @@ -0,0 +1,387 @@ +package xyz.driver.pdsuicommon.db + +import java.sql.{JDBCType, PreparedStatement} +import java.time.LocalDateTime + +import slick.jdbc.{JdbcProfile, PositionedParameters, SQLActionBuilder, SetParameter} +import xyz.driver.pdsuicommon.db.Sorting.{Dimension, Sequential} +import xyz.driver.pdsuicommon.db.SortingOrder.{Ascending, Descending} +import xyz.driver.pdsuicommon.domain.{LongId, StringId, UuidId} + +import scala.concurrent.{ExecutionContext, Future} + +object SlickQueryBuilder { + + type Runner[T] = SlickQueryBuilderParameters => Future[Seq[T]] + + type CountResult = Future[(Int, Option[LocalDateTime])] + + type CountRunner = SlickQueryBuilderParameters => CountResult + + /** + * Binder for PreparedStatement + */ + type Binder = PreparedStatement => PreparedStatement + + final case class TableData(tableName: String, + lastUpdateFieldName: Option[String] = None, + nullableFields: Set[String] = Set.empty) + + val AllFields = Set("*") + + implicit class SQLActionBuilderConcat(a: SQLActionBuilder) { + def concat(b: SQLActionBuilder): SQLActionBuilder = { + SQLActionBuilder(a.queryParts ++ b.queryParts, new SetParameter[Unit] { + def apply(p: Unit, pp: PositionedParameters): Unit = { + a.unitPConv.apply(p, pp) + b.unitPConv.apply(p, pp) + } + }) + } + } + + implicit object SetQueryParameter extends SetParameter[AnyRef] { + def apply(v: AnyRef, pp: PositionedParameters) = { + pp.setObject(v, JDBCType.BINARY.getVendorTypeNumber) + } + } + + implicit def setLongIdQueryParameter[T]: SetParameter[LongId[T]] = SetParameter[LongId[T]] { (v, pp) => + pp.setLong(v.id) + } + + implicit def setStringIdQueryParameter[T]: SetParameter[StringId[T]] = SetParameter[StringId[T]] { (v, pp) => + pp.setString(v.id) + } + + implicit def setUuidIdQueryParameter[T]: SetParameter[UuidId[T]] = SetParameter[UuidId[T]] { (v, pp) => + pp.setObject(v.id, JDBCType.BINARY.getVendorTypeNumber) + } +} + +final case class SlickTableLink(keyColumnName: String, foreignTableName: String, foreignKeyColumnName: String) + +object SlickQueryBuilderParameters { + val AllFields = Set("*") +} + +sealed trait SlickQueryBuilderParameters { + import SlickQueryBuilder._ + + def databaseName: String + def tableData: SlickQueryBuilder.TableData + def links: Map[String, SlickTableLink] + def filter: SearchFilterExpr + def sorting: Sorting + def pagination: Option[Pagination] + + def qs: String + + def findLink(tableName: String): SlickTableLink = links.get(tableName) match { + case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`") + case Some(link) => link + } + + def toSql(countQuery: Boolean = false)(implicit profile: JdbcProfile): SQLActionBuilder = { + toSql(countQuery, SlickQueryBuilderParameters.AllFields) + } + + def toSql(countQuery: Boolean, fields: Set[String])(implicit profile: JdbcProfile): SQLActionBuilder = { + import profile.api._ + val escapedTableName = s"""$qs$databaseName$qs.$qs${tableData.tableName}$qs""" + val fieldsSql: String = if (countQuery) { + val suffix: String = tableData.lastUpdateFieldName match { + case Some(lastUpdateField) => s", max($escapedTableName.$qs$lastUpdateField$qs)" + case None => "" + } + s"count(*) $suffix" + } else { + if (fields == SlickQueryBuilderParameters.AllFields) { + s"$escapedTableName.*" + } else { + fields + .map { field => + s"$escapedTableName.$qs$field$qs" + } + .mkString(", ") + } + } + val where = filterToSql(escapedTableName, filter) + val orderBy = sortingToSql(escapedTableName, sorting) + + val limitSql = limitToSql() + + val sql = sql"""select #$fieldsSql from #$escapedTableName""" + + val filtersTableLinks: Seq[SlickTableLink] = { + import SearchFilterExpr._ + def aux(expr: SearchFilterExpr): Seq[SlickTableLink] = expr match { + case Atom.TableName(tableName) => List(findLink(tableName)) + case Intersection(xs) => xs.flatMap(aux) + case Union(xs) => xs.flatMap(aux) + case _ => Nil + } + aux(filter) + } + + val sortingTableLinks: Seq[SlickTableLink] = Sorting.collect(sorting) { + case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName) + } + + // Combine links from sorting and filter without duplicates + val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct + + def fkSql(fkLinksSql: SQLActionBuilder, tableLinks: Seq[SlickTableLink]): SQLActionBuilder = { + if (tableLinks.nonEmpty) { + tableLinks.head match { + case SlickTableLink(keyColumnName, foreignTableName, foreignKeyColumnName) => + val escapedForeignTableName = s"$qs$databaseName$qs.$qs$foreignTableName$qs" + val join = sql""" inner join #$escapedForeignTableName + on #$escapedTableName.#$qs#$keyColumnName#$qs=#$escapedForeignTableName.#$qs#$foreignKeyColumnName#$qs""" + fkSql(fkLinksSql concat join, tableLinks.tail) + } + } else fkLinksSql + } + val foreignTableLinksSql = fkSql(sql"", foreignTableLinks) + + val whereSql = if (where.queryParts.size > 1) { + sql" where " concat where + } else sql"" + + val orderSql = if (orderBy.nonEmpty && !countQuery) { + sql" order by #$orderBy" + } else sql"" + + val limSql = if (limitSql.queryParts.size > 1 && !countQuery) { + sql" " concat limitSql + } else sql"" + + sql concat foreignTableLinksSql concat whereSql concat orderSql concat limSql + } + + /** + * Converts filter expression to SQL expression. + * + * @return Returns SQL string and list of values for binding in prepared statement. + */ + protected def filterToSql(escapedTableName: String, filter: SearchFilterExpr)( + implicit profile: JdbcProfile): SQLActionBuilder = { + import SearchFilterBinaryOperation._ + import SearchFilterExpr._ + import profile.api._ + + def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null" + + def escapeDimension(dimension: SearchFilterExpr.Dimension) = { + s"${dimension.tableName.map(t => s"$qs$databaseName$qs.$qs$t$qs").getOrElse(escapedTableName)}.$qs${dimension.name}$qs" + } + + def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect { + case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x) + } + + def multipleSqlToAction(first: Boolean, + op: String, + conditions: Seq[SQLActionBuilder], + sql: SQLActionBuilder): SQLActionBuilder = { + if (conditions.nonEmpty) { + val condition = conditions.head + if (first) { + multipleSqlToAction(first = false, op, conditions.tail, condition) + } else { + multipleSqlToAction(first = false, op, conditions.tail, sql concat sql" #${op} " concat condition) + } + } else sql + } + + def concatenateParameters(sql: SQLActionBuilder, first: Boolean, tail: Seq[AnyRef]): SQLActionBuilder = { + if (tail.nonEmpty) { + if (!first) { + concatenateParameters(sql concat sql""",${tail.head}""", first = false, tail.tail) + } else { + concatenateParameters(sql"""(${tail.head}""", first = false, tail.tail) + } + } else sql concat sql")" + } + + filter match { + case x if isEmpty(x) => + sql"" + + case AllowAll => + sql"1=1" + + case DenyAll => + sql"1=0" + + case Atom.Binary(dimension, Eq, value) if isNull(value) => + sql"#${escapeDimension(dimension)} is NULL" + + case Atom.Binary(dimension, NotEq, value) if isNull(value) => + sql"#${escapeDimension(dimension)} is not NULL" + + case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) => + // In MySQL NULL <> Any === NULL + // So, to handle NotEq for nullable fields we need to use more complex SQL expression. + // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html + val escapedColumn = escapeDimension(dimension) + sql"(#${escapedColumn} is null or #${escapedColumn} != $value)" + + case Atom.Binary(dimension, op, value) => + val operator = op match { + case Eq => sql"=" + case NotEq => sql"!=" + case Like => sql" like " + case Gt => sql">" + case GtEq => sql">=" + case Lt => sql"<" + case LtEq => sql"<=" + } + sql"#${escapeDimension(dimension)}" concat operator concat sql"""$value""" + + case Atom.NAry(dimension, op, values) => + val sqlOp = op match { + case SearchFilterNAryOperation.In => sql" in " + case SearchFilterNAryOperation.NotIn => sql" not in " + } + + if (values.nonEmpty) { + val formattedValues = concatenateParameters(sql"", first = true, values) + sql"#${escapeDimension(dimension)}" concat sqlOp concat formattedValues + } else { + sql"1=0" + } + + case Intersection(operands) => + val filter = multipleSqlToAction(first = true, "and", filterToSqlMultiple(operands), sql"") + sql"(" concat filter concat sql")" + + case Union(operands) => + val filter = multipleSqlToAction(first = true, "or", filterToSqlMultiple(operands), sql"") + sql"(" concat filter concat sql")" + } + } + + protected def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder + + /** + * @param escapedMainTableName Should be escaped + */ + protected def sortingToSql(escapedMainTableName: String, sorting: Sorting)(implicit profile: JdbcProfile): String = { + sorting match { + case Dimension(optSortingTableName, field, order) => + val sortingTableName = + optSortingTableName.map(x => s"$qs$databaseName$qs.$qs$x$qs").getOrElse(escapedMainTableName) + val fullName = s"$sortingTableName.$qs$field$qs" + + s"$fullName ${orderToSql(order)}" + + case Sequential(xs) => + xs.map(sortingToSql(escapedMainTableName, _)).mkString(", ") + } + } + + protected def orderToSql(x: SortingOrder): String = x match { + case Ascending => "asc" + case Descending => "desc" + } + + protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = { + bindings.zipWithIndex.foreach { + case (binding, index) => + bind.setObject(index + 1, binding) + } + + bind + } + +} + +final case class SlickPostgresQueryBuilderParameters(databaseName: String, + tableData: SlickQueryBuilder.TableData, + links: Map[String, SlickTableLink] = Map.empty, + filter: SearchFilterExpr = SearchFilterExpr.Empty, + sorting: Sorting = Sorting.Empty, + pagination: Option[Pagination] = None) + extends SlickQueryBuilderParameters { + + def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = { + import profile.api._ + pagination.map { pagination => + val startFrom = (pagination.pageNumber - 1) * pagination.pageSize + sql"limit #${pagination.pageSize} OFFSET #$startFrom" + } getOrElse (sql"") + } + + val qs = """"""" + +} + +/** + * @param links Links to another tables grouped by foreignTableName + */ +final case class SlickMysqlQueryBuilderParameters(databaseName: String, + tableData: SlickQueryBuilder.TableData, + links: Map[String, SlickTableLink] = Map.empty, + filter: SearchFilterExpr = SearchFilterExpr.Empty, + sorting: Sorting = Sorting.Empty, + pagination: Option[Pagination] = None) + extends SlickQueryBuilderParameters { + + def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = { + import profile.api._ + pagination + .map { pagination => + val startFrom = (pagination.pageNumber - 1) * pagination.pageSize + sql"limit #$startFrom, #${pagination.pageSize}" + } + .getOrElse(sql"") + } + + val qs = """`""" + +} + +abstract class SlickQueryBuilder[T](val parameters: SlickQueryBuilderParameters)( + implicit runner: SlickQueryBuilder.Runner[T], + countRunner: SlickQueryBuilder.CountRunner) { + + def run()(implicit ec: ExecutionContext): Future[Seq[T]] = runner(parameters) + + def runCount()(implicit ec: ExecutionContext): SlickQueryBuilder.CountResult = countRunner(parameters) + + /** + * Runs the query and returns total found rows without considering of pagination. + */ + def runWithCount()(implicit ec: ExecutionContext): Future[(Seq[T], Int, Option[LocalDateTime])] = { + for { + all <- run + (total, lastUpdate) <- runCount + } yield (all, total, lastUpdate) + } + + def withFilter(newFilter: SearchFilterExpr): SlickQueryBuilder[T] + + def withFilter(filter: Option[SearchFilterExpr]): SlickQueryBuilder[T] = { + filter.fold(this)(withFilter) + } + + def resetFilter: SlickQueryBuilder[T] = withFilter(SearchFilterExpr.Empty) + + def withSorting(newSorting: Sorting): SlickQueryBuilder[T] + + def withSorting(sorting: Option[Sorting]): SlickQueryBuilder[T] = { + sorting.fold(this)(withSorting) + } + + def resetSorting: SlickQueryBuilder[T] = withSorting(Sorting.Empty) + + def withPagination(newPagination: Pagination): SlickQueryBuilder[T] + + def withPagination(pagination: Option[Pagination]): SlickQueryBuilder[T] = { + pagination.fold(this)(withPagination) + } + + def resetPagination: SlickQueryBuilder[T] + +} diff --git a/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala new file mode 100644 index 0000000..6ab7eb4 --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala @@ -0,0 +1,240 @@ +package xyz.driver.restquery.query + +import java.sql.PreparedStatement + +import slick.jdbc.{JdbcProfile, SQLActionBuilder} +import xyz.driver.restquery.db.{SlickQueryBuilder, SlickQueryBuilderParameters, SlickTableLink} +import xyz.driver.restquery.query.Sorting.{Dimension, Sequential} +import xyz.driver.restquery.query.SortingOrder.{Ascending, Descending} + +trait SlickQueryBuilderParameters { + import SlickQueryBuilder._ + + def databaseName: String + def tableData: SlickQueryBuilder.TableData + def links: Map[String, SlickTableLink] + def filter: SearchFilterExpr + def sorting: Sorting + def pagination: Option[Pagination] + + def qs: String + + def findLink(tableName: String): SlickTableLink = links.get(tableName) match { + case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`") + case Some(link) => link + } + + def toSql(countQuery: Boolean = false)(implicit profile: JdbcProfile): SQLActionBuilder = { + toSql(countQuery, SlickQueryBuilderParameters.AllFields) + } + + def toSql(countQuery: Boolean, fields: Set[String])(implicit profile: JdbcProfile): SQLActionBuilder = { + import profile.api._ + val escapedTableName = s"""$qs$databaseName$qs.$qs${tableData.tableName}$qs""" + val fieldsSql: String = if (countQuery) { + val suffix: String = tableData.lastUpdateFieldName match { + case Some(lastUpdateField) => s", max($escapedTableName.$qs$lastUpdateField$qs)" + case None => "" + } + s"count(*) $suffix" + } else { + if (fields == SlickQueryBuilderParameters.AllFields) { + s"$escapedTableName.*" + } else { + fields + .map { field => + s"$escapedTableName.$qs$field$qs" + } + .mkString(", ") + } + } + val where = filterToSql(escapedTableName, filter) + val orderBy = sortingToSql(escapedTableName, sorting) + + val limitSql = limitToSql() + + val sql = sql"""select #$fieldsSql from #$escapedTableName""" + + val filtersTableLinks: Seq[SlickTableLink] = { + import SearchFilterExpr._ + def aux(expr: SearchFilterExpr): Seq[SlickTableLink] = expr match { + case Atom.TableName(tableName) => List(findLink(tableName)) + case Intersection(xs) => xs.flatMap(aux) + case Union(xs) => xs.flatMap(aux) + case _ => Nil + } + aux(filter) + } + + val sortingTableLinks: Seq[SlickTableLink] = Sorting.collect(sorting) { + case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName) + } + + // Combine links from sorting and filter without duplicates + val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct + + def fkSql(fkLinksSql: SQLActionBuilder, tableLinks: Seq[SlickTableLink]): SQLActionBuilder = { + if (tableLinks.nonEmpty) { + tableLinks.head match { + case SlickTableLink(keyColumnName, foreignTableName, foreignKeyColumnName) => + val escapedForeignTableName = s"$qs$databaseName$qs.$qs$foreignTableName$qs" + val join = sql""" inner join #$escapedForeignTableName + on #$escapedTableName.#$qs#$keyColumnName#$qs=#$escapedForeignTableName.#$qs#$foreignKeyColumnName#$qs""" + fkSql(fkLinksSql concat join, tableLinks.tail) + } + } else fkLinksSql + } + val foreignTableLinksSql = fkSql(sql"", foreignTableLinks) + + val whereSql = if (where.queryParts.size > 1) { + sql" where " concat where + } else sql"" + + val orderSql = if (orderBy.nonEmpty && !countQuery) { + sql" order by #$orderBy" + } else sql"" + + val limSql = if (limitSql.queryParts.size > 1 && !countQuery) { + sql" " concat limitSql + } else sql"" + + sql concat foreignTableLinksSql concat whereSql concat orderSql concat limSql + } + + /** + * Converts filter expression to SQL expression. + * + * @return Returns SQL string and list of values for binding in prepared statement. + */ + protected def filterToSql(escapedTableName: String, filter: SearchFilterExpr)( + implicit profile: JdbcProfile): SQLActionBuilder = { + import SearchFilterBinaryOperation._ + import SearchFilterExpr._ + import profile.api._ + + def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null" + + def escapeDimension(dimension: SearchFilterExpr.Dimension) = { + s"${dimension.tableName.map(t => s"$qs$databaseName$qs.$qs$t$qs").getOrElse(escapedTableName)}.$qs${dimension.name}$qs" + } + + def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect { + case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x) + } + + def multipleSqlToAction(first: Boolean, + op: String, + conditions: Seq[SQLActionBuilder], + sql: SQLActionBuilder): SQLActionBuilder = { + if (conditions.nonEmpty) { + val condition = conditions.head + if (first) { + multipleSqlToAction(first = false, op, conditions.tail, condition) + } else { + multipleSqlToAction(first = false, op, conditions.tail, sql concat sql" #${op} " concat condition) + } + } else sql + } + + def concatenateParameters(sql: SQLActionBuilder, first: Boolean, tail: Seq[AnyRef]): SQLActionBuilder = { + if (tail.nonEmpty) { + if (!first) { + concatenateParameters(sql concat sql""",${tail.head}""", first = false, tail.tail) + } else { + concatenateParameters(sql"""(${tail.head}""", first = false, tail.tail) + } + } else sql concat sql")" + } + + filter match { + case x if isEmpty(x) => + sql"" + + case AllowAll => + sql"1=1" + + case DenyAll => + sql"1=0" + + case Atom.Binary(dimension, Eq, value) if isNull(value) => + sql"#${escapeDimension(dimension)} is NULL" + + case Atom.Binary(dimension, NotEq, value) if isNull(value) => + sql"#${escapeDimension(dimension)} is not NULL" + + case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) => + // In MySQL NULL <> Any === NULL + // So, to handle NotEq for nullable fields we need to use more complex SQL expression. + // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html + val escapedColumn = escapeDimension(dimension) + sql"(#${escapedColumn} is null or #${escapedColumn} != $value)" + + case Atom.Binary(dimension, op, value) => + val operator = op match { + case Eq => sql"=" + case NotEq => sql"!=" + case Like => sql" like " + case Gt => sql">" + case GtEq => sql">=" + case Lt => sql"<" + case LtEq => sql"<=" + } + sql"#${escapeDimension(dimension)}" concat operator concat sql"""$value""" + + case Atom.NAry(dimension, op, values) => + val sqlOp = op match { + case SearchFilterNAryOperation.In => sql" in " + case SearchFilterNAryOperation.NotIn => sql" not in " + } + + if (values.nonEmpty) { + val formattedValues = concatenateParameters(sql"", first = true, values) + sql"#${escapeDimension(dimension)}" concat sqlOp concat formattedValues + } else { + sql"1=0" + } + + case Intersection(operands) => + val filter = multipleSqlToAction(first = true, "and", filterToSqlMultiple(operands), sql"") + sql"(" concat filter concat sql")" + + case Union(operands) => + val filter = multipleSqlToAction(first = true, "or", filterToSqlMultiple(operands), sql"") + sql"(" concat filter concat sql")" + } + } + + protected def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder + + /** + * @param escapedMainTableName Should be escaped + */ + protected def sortingToSql(escapedMainTableName: String, sorting: Sorting)(implicit profile: JdbcProfile): String = { + sorting match { + case Dimension(optSortingTableName, field, order) => + val sortingTableName = + optSortingTableName.map(x => s"$qs$databaseName$qs.$qs$x$qs").getOrElse(escapedMainTableName) + val fullName = s"$sortingTableName.$qs$field$qs" + + s"$fullName ${orderToSql(order)}" + + case Sequential(xs) => + xs.map(sortingToSql(escapedMainTableName, _)).mkString(", ") + } + } + + protected def orderToSql(x: SortingOrder): String = x match { + case Ascending => "asc" + case Descending => "desc" + } + + protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = { + bindings.zipWithIndex.foreach { + case (binding, index) => + bind.setObject(index + 1, binding) + } + + bind + } + +} diff --git a/src/main/scala/xyz/driver/restquery/query/Pagination.scala b/src/main/scala/xyz/driver/restquery/query/Pagination.scala new file mode 100644 index 0000000..27b8f12 --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/query/Pagination.scala @@ -0,0 +1,12 @@ +package xyz.driver.restquery.domain + +/** + * @param pageNumber Starts with 1 + */ +final case class Pagination(pageSize: Int, pageNumber: Int) + +object Pagination { + + // @see https://driverinc.atlassian.net/wiki/display/RA/REST+API+Specification#RESTAPISpecification-CommonRequestQueryParametersForWebServices + val Default = Pagination(pageSize = 100, pageNumber = 1) +} diff --git a/src/main/scala/xyz/driver/restquery/query/SearchFilterBinaryOperation.scala b/src/main/scala/xyz/driver/restquery/query/SearchFilterBinaryOperation.scala new file mode 100644 index 0000000..dab466b --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/query/SearchFilterBinaryOperation.scala @@ -0,0 +1,5 @@ +package xyz.driver.restquery.query + +class SearchFilterBinaryOperation { + +} diff --git a/src/main/scala/xyz/driver/restquery/query/SearchFilterExpr.scala b/src/main/scala/xyz/driver/restquery/query/SearchFilterExpr.scala new file mode 100644 index 0000000..6d2cb9a --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/query/SearchFilterExpr.scala @@ -0,0 +1,196 @@ +package xyz.driver.restquery.domain + +sealed trait SearchFilterExpr { + def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] + def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr +} + +object SearchFilterExpr { + + val Empty: Intersection = Intersection.Empty + val Forbid = Atom.Binary( + dimension = Dimension(None, "true"), + op = SearchFilterBinaryOperation.Eq, + value = "false" + ) + + final case class Dimension(tableName: Option[String], name: String) { + def isForeign: Boolean = tableName.isDefined + } + + sealed trait Atom extends SearchFilterExpr { + override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = { + Some(this).filter(p) + } + + override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = { + if (f.isDefinedAt(this)) f(this) + else this + } + } + + object Atom { + final case class Binary(dimension: Dimension, op: SearchFilterBinaryOperation, value: AnyRef) extends Atom + object Binary { + def apply(field: String, op: SearchFilterBinaryOperation, value: AnyRef): Binary = + Binary(Dimension(None, field), op, value) + } + + final case class NAry(dimension: Dimension, op: SearchFilterNAryOperation, values: Seq[AnyRef]) extends Atom + object NAry { + def apply(field: String, op: SearchFilterNAryOperation, values: Seq[AnyRef]): NAry = + NAry(Dimension(None, field), op, values) + } + + /** dimension.tableName extractor */ + object TableName { + def unapply(value: Atom): Option[String] = value match { + case Binary(Dimension(tableNameOpt, _), _, _) => tableNameOpt + case NAry(Dimension(tableNameOpt, _), _, _) => tableNameOpt + } + } + } + + final case class Intersection private (operands: Seq[SearchFilterExpr]) + extends SearchFilterExpr with SearchFilterExprSeqOps { + + override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = { + if (f.isDefinedAt(this)) f(this) + else { + this.copy(operands.map(_.replace(f))) + } + } + + } + + object Intersection { + + val Empty = Intersection(Seq()) + + def create(operands: SearchFilterExpr*): SearchFilterExpr = { + val filtered = operands.filterNot(SearchFilterExpr.isEmpty) + filtered.size match { + case 0 => Empty + case 1 => filtered.head + case _ => Intersection(filtered) + } + } + } + + final case class Union private (operands: Seq[SearchFilterExpr]) + extends SearchFilterExpr with SearchFilterExprSeqOps { + + override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = { + if (f.isDefinedAt(this)) f(this) + else { + this.copy(operands.map(_.replace(f))) + } + } + + } + + object Union { + + val Empty = Union(Seq()) + + def create(operands: SearchFilterExpr*): SearchFilterExpr = { + val filtered = operands.filterNot(SearchFilterExpr.isEmpty) + filtered.size match { + case 0 => Empty + case 1 => filtered.head + case _ => Union(filtered) + } + } + + def create(dimension: Dimension, values: String*): SearchFilterExpr = values.size match { + case 0 => SearchFilterExpr.Empty + case 1 => SearchFilterExpr.Atom.Binary(dimension, SearchFilterBinaryOperation.Eq, values.head) + case _ => + val filters = values.map { value => + SearchFilterExpr.Atom.Binary(dimension, SearchFilterBinaryOperation.Eq, value) + } + + create(filters: _*) + } + + def create(dimension: Dimension, values: Set[String]): SearchFilterExpr = + create(dimension, values.toSeq: _*) + + // Backwards compatible API + + /** Create SearchFilterExpr with empty tableName */ + def create(field: String, values: String*): SearchFilterExpr = + create(Dimension(None, field), values: _*) + + /** Create SearchFilterExpr with empty tableName */ + def create(field: String, values: Set[String]): SearchFilterExpr = + create(Dimension(None, field), values) + } + + case object AllowAll extends SearchFilterExpr { + override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = { + Some(this).filter(p) + } + + override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = { + if (f.isDefinedAt(this)) f(this) + else this + } + } + + case object DenyAll extends SearchFilterExpr { + override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = { + Some(this).filter(p) + } + + override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = { + if (f.isDefinedAt(this)) f(this) + else this + } + } + + def isEmpty(expr: SearchFilterExpr): Boolean = { + expr == Intersection.Empty || expr == Union.Empty + } + + sealed trait SearchFilterExprSeqOps { this: SearchFilterExpr => + + val operands: Seq[SearchFilterExpr] + + override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = { + if (p(this)) Some(this) + else { + // Search the first expr among operands, which satisfy p + // Is's ok to use foldLeft. If there will be performance issues, replace it by recursive loop + operands.foldLeft(Option.empty[SearchFilterExpr]) { + case (None, expr) => expr.find(p) + case (x, _) => x + } + } + } + } + +} + +sealed trait SearchFilterBinaryOperation + +object SearchFilterBinaryOperation { + + case object Eq extends SearchFilterBinaryOperation + case object NotEq extends SearchFilterBinaryOperation + case object Like extends SearchFilterBinaryOperation + case object Gt extends SearchFilterBinaryOperation + case object GtEq extends SearchFilterBinaryOperation + case object Lt extends SearchFilterBinaryOperation + case object LtEq extends SearchFilterBinaryOperation + +} + +sealed trait SearchFilterNAryOperation + +object SearchFilterNAryOperation { + + case object In extends SearchFilterNAryOperation + case object NotIn extends SearchFilterNAryOperation + +} diff --git a/src/main/scala/xyz/driver/restquery/query/SearchFilterNAryOperation.scala b/src/main/scala/xyz/driver/restquery/query/SearchFilterNAryOperation.scala new file mode 100644 index 0000000..0604c8f --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/query/SearchFilterNAryOperation.scala @@ -0,0 +1,5 @@ +package xyz.driver.restquery.query + +class SearchFilterNAryOperation { + +} diff --git a/src/main/scala/xyz/driver/restquery/query/Sorting.scala b/src/main/scala/xyz/driver/restquery/query/Sorting.scala new file mode 100644 index 0000000..e2642ad --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/query/Sorting.scala @@ -0,0 +1,56 @@ +package xyz.driver.restquery.domain + +import scala.collection.generic.CanBuildFrom + +sealed trait SortingOrder +object SortingOrder { + + case object Ascending extends SortingOrder + case object Descending extends SortingOrder + +} + +sealed trait Sorting + +object Sorting { + + val Empty = Sequential(Seq.empty) + + /** + * @param tableName None if the table is default (same) + * @param name Dimension name + * @param order Order + */ + final case class Dimension(tableName: Option[String], name: String, order: SortingOrder) extends Sorting { + def isForeign: Boolean = tableName.isDefined + } + + final case class Sequential(sorting: Seq[Dimension]) extends Sorting { + override def toString: String = if (isEmpty(this)) "Empty" else super.toString + } + + def isEmpty(input: Sorting): Boolean = { + input match { + case Sequential(Seq()) => true + case _ => false + } + } + + def filter(sorting: Sorting, p: Dimension => Boolean): Seq[Dimension] = sorting match { + case x: Dimension if p(x) => Seq(x) + case _: Dimension => Seq.empty + case Sequential(xs) => xs.filter(p) + } + + def collect[B, That](sorting: Sorting)(f: PartialFunction[Dimension, B])( + implicit bf: CanBuildFrom[Seq[Dimension], B, That]): That = sorting match { + case x: Dimension if f.isDefinedAt(x) => + val r = bf.apply() + r += f(x) + r.result() + + case _: Dimension => bf.apply().result() + case Sequential(xs) => xs.collect(f) + } + +} diff --git a/src/main/scala/xyz/driver/restquery/rest/Directives.scala b/src/main/scala/xyz/driver/restquery/rest/Directives.scala new file mode 100644 index 0000000..2936f70 --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/rest/Directives.scala @@ -0,0 +1,55 @@ +package xyz.driver.restquery.http + +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import xyz.driver.restquery.domain.{SearchFilterExpr, _} +import xyz.driver.restquery.http.parsers._ + +import scala.util._ + +trait Directives { + + val paginated: Directive1[Pagination] = parameterSeq.flatMap { params => + PaginationParser.parse(params) match { + case Success(pagination) => provide(pagination) + case Failure(ex) => + reject(ValidationRejection("invalid pagination parameter", Some(ex))) + } + } + + def sorted(validDimensions: Set[String] = Set.empty): Directive1[Sorting] = parameterSeq.flatMap { params => + SortingParser.parse(validDimensions, params) match { + case Success(sorting) => provide(sorting) + case Failure(ex) => + reject(ValidationRejection("invalid sorting parameter", Some(ex))) + } + } + + val dimensioned: Directive1[Dimensions] = parameterSeq.flatMap { params => + DimensionsParser.tryParse(params) match { + case Success(dims) => provide(dims) + case Failure(ex) => + reject(ValidationRejection("invalid dimension parameter", Some(ex))) + } + } + + val searchFiltered: Directive1[SearchFilterExpr] = parameterSeq.flatMap { params => + SearchFilterParser.parse(params) match { + case Success(sorting) => provide(sorting) + case Failure(ex) => + reject(ValidationRejection("invalid filter parameter", Some(ex))) + } + } + + def StringIdInPath[T]: PathMatcher1[StringId[T]] = + PathMatchers.Segment.map((id) => StringId(id.toString)) + + def LongIdInPath[T]: PathMatcher1[LongId[T]] = + PathMatchers.LongNumber.map((id) => LongId(id)) + + def UuidIdInPath[T]: PathMatcher1[UuidId[T]] = + PathMatchers.JavaUUID.map((id) => UuidId(id)) + +} + +object Directives extends Directives diff --git a/src/main/scala/xyz/driver/restquery/rest/parsers/DimensionsParser.scala b/src/main/scala/xyz/driver/restquery/rest/parsers/DimensionsParser.scala new file mode 100644 index 0000000..7e139db --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/rest/parsers/DimensionsParser.scala @@ -0,0 +1,30 @@ +package xyz.driver.restquery.http.parsers + +import scala.util.{Failure, Success, Try} + +class Dimensions(private val xs: Set[String] = Set.empty) { + def contains(x: String): Boolean = xs.isEmpty || xs.contains(x) +} + +object DimensionsParser { + + @deprecated("play-akka transition", "0") + def tryParse(query: Map[String, Seq[String]]): Try[Dimensions] = + tryParse(query.toSeq.flatMap { + case (key, values) => + values.map(value => key -> value) + }) + + def tryParse(query: Seq[(String, String)]): Try[Dimensions] = { + query.collect { case ("dimensions", value) => value } match { + case Nil => Success(new Dimensions()) + + case x +: Nil => + val raw: Set[String] = x.split(",").view.map(_.trim).filter(_.nonEmpty).to[Set] + Success(new Dimensions(raw)) + + case xs => + Failure(new IllegalArgumentException(s"Dimensions are specified ${xs.size} times")) + } + } +} diff --git a/src/main/scala/xyz/driver/restquery/rest/parsers/PaginationParser.scala b/src/main/scala/xyz/driver/restquery/rest/parsers/PaginationParser.scala new file mode 100644 index 0000000..2b4547b --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/rest/parsers/PaginationParser.scala @@ -0,0 +1,23 @@ +package xyz.driver.restquery.http.parsers + +import xyz.driver.restquery.domain.Pagination + +import scala.util._ + +object PaginationParser { + + def parse(query: Seq[(String, String)]): Try[Pagination] = { + val IntString = """(\d+)""".r + def validate(field: String, default: Int) = query.collectFirst { case (`field`, size) => size } match { + case Some(IntString(x)) if x.toInt > 0 => x.toInt + case Some(IntString(x)) => throw new ParseQueryArgException((field, s"must greater than zero (found $x)")) + case Some(str) => throw new ParseQueryArgException((field, s"must be an integer (found $str)")) + case None => default + } + + Try { + Pagination(validate("pageSize", Pagination.Default.pageSize), + validate("pageNumber", Pagination.Default.pageNumber)) + } + } +} diff --git a/src/main/scala/xyz/driver/restquery/rest/parsers/ParseQueryArgException.scala b/src/main/scala/xyz/driver/restquery/rest/parsers/ParseQueryArgException.scala new file mode 100644 index 0000000..096c28f --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/rest/parsers/ParseQueryArgException.scala @@ -0,0 +1,3 @@ +package xyz.driver.restquery.http.parsers + +class ParseQueryArgException(val errors: (String, String)*) extends Exception(errors.mkString(",")) diff --git a/src/main/scala/xyz/driver/restquery/rest/parsers/SearchFilterParser.scala b/src/main/scala/xyz/driver/restquery/rest/parsers/SearchFilterParser.scala new file mode 100644 index 0000000..ce3009b --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/rest/parsers/SearchFilterParser.scala @@ -0,0 +1,178 @@ +package xyz.driver.restquery.http.parsers + +import java.util.UUID + +import fastparse.all._ +import fastparse.core.Parsed +import xyz.driver.restquery.domain.{SearchFilterBinaryOperation, SearchFilterExpr, SearchFilterNAryOperation} +import xyz.driver.restquery.utils.Utils +import xyz.driver.restquery.utils.Utils._ + +import scala.util.Try + +@SuppressWarnings(Array("org.wartremover.warts.Product", "org.wartremover.warts.Serializable")) +object SearchFilterParser { + + private object BinaryAtomFromTuple { + def unapply(input: (SearchFilterExpr.Dimension, (String, Any))): Option[SearchFilterExpr.Atom.Binary] = { + val (dimensionName, (strOperation, value)) = input + val updatedValue = trimIfString(value) + + parseOperation(strOperation.toLowerCase).map { op => + SearchFilterExpr.Atom.Binary(dimensionName, op, updatedValue.asInstanceOf[AnyRef]) + } + } + } + + private object NAryAtomFromTuple { + // Compiler warning: unchecked since it is eliminated by erasure, if we user Seq[String] + def unapply(input: (SearchFilterExpr.Dimension, (String, Seq[_]))): Option[SearchFilterExpr.Atom.NAry] = { + val (dimensionName, (strOperation, xs)) = input + val updatedValues = xs.map(trimIfString) + + if (strOperation.toLowerCase == "in") { + Some( + SearchFilterExpr.Atom + .NAry(dimensionName, SearchFilterNAryOperation.In, updatedValues.map(_.asInstanceOf[AnyRef]))) + } else { + None + } + } + } + + private def trimIfString(value: Any) = + value match { + case s: String => Utils.safeTrim(s) + case a => a + } + + private val operationsMapping = { + import xyz.driver.restquery.domain.SearchFilterBinaryOperation._ + + Map[String, SearchFilterBinaryOperation]( + "eq" -> Eq, + "noteq" -> NotEq, + "like" -> Like, + "gt" -> Gt, + "gteq" -> GtEq, + "lt" -> Lt, + "lteq" -> LtEq + ) + } + + private def parseOperation(x: String): Option[SearchFilterBinaryOperation] = operationsMapping.get(x) + + private val whitespaceParser = P(CharPred(Utils.isSafeWhitespace)) + + val dimensionParser: Parser[SearchFilterExpr.Dimension] = { + val identParser = P( + CharPred(c => c.isLetterOrDigit) + .rep(min = 1)).!.map(s => SearchFilterExpr.Dimension(None, toSnakeCase(s))) + val pathParser = P(identParser.! ~ "." ~ identParser.!) map { + case (left, right) => + SearchFilterExpr.Dimension(Some(toSnakeCase(left)), toSnakeCase(right)) + } + P(pathParser | identParser) + } + + private val commonOperatorParser: Parser[String] = { + P(IgnoreCase("eq") | IgnoreCase("like") | IgnoreCase("noteq")).! + } + + private val numericOperatorParser: Parser[String] = { + P(IgnoreCase("eq") | IgnoreCase("noteq") | ((IgnoreCase("gt") | IgnoreCase("lt")) ~ IgnoreCase("eq").?)).! + } + + private val naryOperatorParser: Parser[String] = P(IgnoreCase("in")).! + + private val isPositiveParser: Parser[Boolean] = P(CharIn("-+").!.?).map { + case Some("-") => false + case _ => true + } + + private val digitsParser: Parser[String] = P(CharIn('0' to '9').rep(min = 1).!) // Exclude Unicode "digits" + + private val numberParser: Parser[String] = P(isPositiveParser ~ digitsParser.! ~ ("." ~ digitsParser).!.?).map { + case (false, intPart, Some(fracPart)) => s"-$intPart.${fracPart.tail}" + case (false, intPart, None) => s"-$intPart" + case (_, intPart, Some(fracPart)) => s"$intPart.${fracPart.tail}" + case (_, intPart, None) => s"$intPart" + } + + private val nAryValueParser: Parser[String] = P(CharPred(_ != ',').rep(min = 1).!) + + private val longParser: Parser[Long] = P(CharIn('0' to '9').rep(min = 1).!.map(_.toLong)) + + private val booleanParser: Parser[Boolean] = + P((IgnoreCase("true") | IgnoreCase("false")).!.map(_.toBoolean)) + + private val hexDigit: Parser[String] = P((CharIn('a' to 'f') | CharIn('A' to 'F') | CharIn('0' to '9')).!) + + private val uuidParser: Parser[UUID] = + P( + hexDigit.rep(8).! ~ "-" ~ hexDigit.rep(4).! ~ "-" ~ hexDigit.rep(4).! ~ "-" ~ hexDigit.rep(4).! ~ "-" ~ hexDigit + .rep(12) + .!).map { + case (group1, group2, group3, group4, group5) => UUID.fromString(s"$group1-$group2-$group3-$group4-$group5") + } + + private val binaryAtomParser: Parser[SearchFilterExpr.Atom.Binary] = P( + dimensionParser ~ whitespaceParser ~ + ((numericOperatorParser.! ~ whitespaceParser ~ (longParser | numberParser.!) ~ End) | + (commonOperatorParser.! ~ whitespaceParser ~ (uuidParser | booleanParser | AnyChar.rep(min = 1).!) ~ End)) + ).map { + case BinaryAtomFromTuple(atom) => atom + } + + private val nAryAtomParser: Parser[SearchFilterExpr.Atom.NAry] = P( + dimensionParser ~ whitespaceParser ~ ( + naryOperatorParser ~ whitespaceParser ~ + ((longParser.rep(min = 1, sep = ",") ~ End) | (booleanParser.rep(min = 1, sep = ",") ~ End) | + (nAryValueParser.!.rep(min = 1, sep = ",") ~ End)) + ) + ).map { + case NAryAtomFromTuple(atom) => atom + } + + private val atomParser: Parser[SearchFilterExpr.Atom] = P(binaryAtomParser | nAryAtomParser) + + @deprecated("play-akka transition", "0") + def parse(query: Map[String, Seq[String]]): Try[SearchFilterExpr] = + parse(query.toSeq.flatMap { + case (key, values) => + values.map(value => key -> value) + }) + + def parse(query: Seq[(String, String)]): Try[SearchFilterExpr] = Try { + query.toList.collect { case ("filters", value) => value } match { + case Nil => SearchFilterExpr.Empty + + case head :: Nil => + atomParser.parse(head) match { + case Parsed.Success(x, _) => x + case e: Parsed.Failure[_, _] => throw new ParseQueryArgException("filters" -> formatFailure(1, e)) + } + + case xs => + val parsed = xs.map(x => atomParser.parse(x)) + val failures: Seq[String] = parsed.zipWithIndex.collect { + case (e: Parsed.Failure[_, _], index) => formatFailure(index, e) + } + + if (failures.isEmpty) { + val filters = parsed.collect { + case Parsed.Success(x, _) => x + } + + SearchFilterExpr.Intersection.create(filters: _*) + } else { + throw new ParseQueryArgException("filters" -> failures.mkString("; ")) + } + } + } + + private def formatFailure(sectionIndex: Int, e: Parsed.Failure[_, _]): String = { + s"section $sectionIndex: ${fastparse.core.ParseError.msg(e.extra.input, e.extra.traced.expected, e.index)}" + } + +} diff --git a/src/main/scala/xyz/driver/restquery/rest/parsers/SortingParser.scala b/src/main/scala/xyz/driver/restquery/rest/parsers/SortingParser.scala new file mode 100644 index 0000000..f2a3c04 --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/rest/parsers/SortingParser.scala @@ -0,0 +1,64 @@ +package xyz.driver.restquery.http.parsers + +import fastparse.all._ +import fastparse.core.Parsed +import xyz.driver.restquery.domain.{Sorting, SortingOrder} +import xyz.driver.restquery.utils.Utils._ + +import scala.util.Try + +object SortingParser { + + private val sortingOrderParser: Parser[SortingOrder] = P("-".!.?).map { + case Some(_) => SortingOrder.Descending + case None => SortingOrder.Ascending + } + + private def dimensionSortingParser(validDimensions: Seq[String]): Parser[Sorting.Dimension] = { + P(sortingOrderParser ~ StringIn(validDimensions: _*).!).map { + case (sortingOrder, field) => + val prefixedFields = field.split("\\.", 2) + prefixedFields.size match { + case 1 => Sorting.Dimension(None, toSnakeCase(field), sortingOrder) + case 2 => + Sorting.Dimension(Some(prefixedFields.head).map(toSnakeCase), + toSnakeCase(prefixedFields.last), + sortingOrder) + } + } + } + + private def sequentialSortingParser(validDimensions: Seq[String]): Parser[Sorting.Sequential] = { + P(dimensionSortingParser(validDimensions).rep(min = 1, sep = ",") ~ End).map { dimensions => + Sorting.Sequential(dimensions) + } + } + + @deprecated("play-akka transition", "0") + def parse(validDimensions: Set[String], query: Map[String, Seq[String]]): Try[Sorting] = + parse(validDimensions, query.toSeq.flatMap { + case (key, values) => + values.map(value => key -> value) + }) + + def parse(validDimensions: Set[String], query: Seq[(String, String)]): Try[Sorting] = Try { + query.toList.collect { case ("sort", value) => value } match { + case Nil => Sorting.Sequential(Seq.empty) + + case rawSorting :: Nil => + val parser = sequentialSortingParser(validDimensions.toSeq) + parser.parse(rawSorting) match { + case Parsed.Success(x, _) => x + case e: Parsed.Failure[_, _] => + throw new ParseQueryArgException("sort" -> formatFailure(e)) + } + + case _ => throw new ParseQueryArgException("sort" -> "multiple sections are not allowed") + } + } + + private def formatFailure(e: Parsed.Failure[_, _]): String = { + fastparse.core.ParseError.msg(e.extra.input, e.extra.traced.expected, e.index) + } + +} diff --git a/src/main/scala/xyz/driver/restquery/utils/Utils.scala b/src/main/scala/xyz/driver/restquery/utils/Utils.scala new file mode 100644 index 0000000..86c65d7 --- /dev/null +++ b/src/main/scala/xyz/driver/restquery/utils/Utils.scala @@ -0,0 +1,79 @@ +package xyz.driver.pdsuicommon.utils + +import java.time.LocalDateTime +import java.util.regex.{Matcher, Pattern} + +object Utils { + + implicit val localDateTimeOrdering: Ordering[LocalDateTime] = Ordering.fromLessThan(_ isBefore _) + + /** + * Hack to avoid scala compiler bug with getSimpleName + * @see https://issues.scala-lang.org/browse/SI-2034 + */ + def getClassSimpleName(klass: Class[_]): String = { + try { + klass.getSimpleName + } catch { + case _: InternalError => + val fullName = klass.getName.stripSuffix("$") + val fullClassName = fullName.substring(fullName.lastIndexOf(".") + 1) + fullClassName.substring(fullClassName.lastIndexOf("$") + 1) + } + } + + def toSnakeCase(str: String): String = + str + .replaceAll("([A-Z]+)([A-Z][a-z])", "$1_$2") + .replaceAll("([a-z\\d])([A-Z])", "$1_$2") + .toLowerCase + + def toCamelCase(str: String): String = { + val sb = new StringBuffer() + def loop(m: Matcher): Unit = if (m.find()) { + m.appendReplacement(sb, m.group(1).toUpperCase()) + loop(m) + } + val m: Matcher = Pattern.compile("_(.)").matcher(str) + loop(m) + m.appendTail(sb) + sb.toString + } + + object Whitespace { + private val Table: String = + "\u2002\u3000\r\u0085\u200A\u2005\u2000\u3000" + + "\u2029\u000B\u3000\u2008\u2003\u205F\u3000\u1680" + + "\u0009\u0020\u2006\u2001\u202F\u00A0\u000C\u2009" + + "\u3000\u2004\u3000\u3000\u2028\n\u2007\u3000" + + private val Multiplier: Int = 1682554634 + private val Shift: Int = Integer.numberOfLeadingZeros(Table.length - 1) + + def matches(c: Char): Boolean = Table.charAt((Multiplier * c) >>> Shift) == c + } + + def isSafeWhitespace(char: Char): Boolean = Whitespace.matches(char) + + // From Guava + def isSafeControl(char: Char): Boolean = + char <= '\u001f' || (char >= '\u007f' && char <= '\u009f') + + + def safeTrim(string: String): String = { + def shouldKeep(c: Char): Boolean = !isSafeControl(c) && !isSafeWhitespace(c) + + if (string.isEmpty) { + "" + } else { + val start = string.indexWhere(shouldKeep) + val end = string.lastIndexWhere(shouldKeep) + + if (start >= 0 && end >= 0) { + string.substring(start, end + 1) + } else { + "" + } + } + } +} |