aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorktomskikh <ktomskih@datamonsters.co>2017-09-21 12:00:50 +0700
committerGitHub <noreply@github.com>2017-09-21 12:00:50 +0700
commitc466ce359ec923d8c1f9a8188191ecee8085312c (patch)
tree4eed12a4ebb2829e336a3da673c7c8462e7ab845
parentd4b18efda238f506103dddbf3b400ae17c797276 (diff)
parent9968eaefa2a97ebe495fa51b640e31c78db61ac6 (diff)
downloadrest-query-0.3.15.tar.gz
rest-query-0.3.15.tar.bz2
rest-query-0.3.15.zip
Merge pull request #20 from drivergroup/slick-query-builderv0.3.15
Refactoting QueryBuilder for using with Slick
-rw-r--r--src/main/scala/xyz/driver/pdsuicommon/db/SlickPostgresQueryBuilder.scala116
-rw-r--r--src/main/scala/xyz/driver/pdsuicommon/db/SlickQueryBuilder.scala357
2 files changed, 473 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/SlickPostgresQueryBuilder.scala b/src/main/scala/xyz/driver/pdsuicommon/db/SlickPostgresQueryBuilder.scala
new file mode 100644
index 0000000..f882441
--- /dev/null
+++ b/src/main/scala/xyz/driver/pdsuicommon/db/SlickPostgresQueryBuilder.scala
@@ -0,0 +1,116 @@
+package xyz.driver.pdsuicommon.db
+
+import java.time.{LocalDateTime, ZoneOffset}
+
+import slick.driver.JdbcProfile
+import slick.jdbc.GetResult
+import xyz.driver.core.database.SlickDal
+import xyz.driver.pdsuicommon.logging._
+
+import scala.collection.breakOut
+import scala.concurrent.ExecutionContext
+
+object SlickPostgresQueryBuilder extends PhiLogging {
+
+ import xyz.driver.pdsuicommon.db.SlickQueryBuilder._
+
+ def apply[T](databaseName: String,
+ tableName: String,
+ lastUpdateFieldName: Option[String],
+ nullableFields: Set[String],
+ links: Set[SlickTableLink],
+ runner: Runner[T],
+ countRunner: CountRunner)(implicit sqlContext: SlickDal,
+ profile: JdbcProfile,
+ getResult: GetResult[T],
+ ec: ExecutionContext): SlickPostgresQueryBuilder[T] = {
+ val parameters = SlickPostgresQueryBuilderParameters(
+ databaseName = databaseName,
+ tableData = TableData(tableName, lastUpdateFieldName, nullableFields),
+ links = links.map(x => x.foreignTableName -> x)(breakOut)
+ )
+ new SlickPostgresQueryBuilder[T](parameters)(runner, countRunner)
+ }
+
+ def apply[T](databaseName: String,
+ tableName: String,
+ lastUpdateFieldName: Option[String],
+ nullableFields: Set[String],
+ links: Set[SlickTableLink])(implicit sqlContext: SlickDal,
+ profile: JdbcProfile,
+ getResult: GetResult[T],
+ ec: ExecutionContext): SlickPostgresQueryBuilder[T] = {
+ apply[T](databaseName,
+ tableName,
+ SlickQueryBuilderParameters.AllFields,
+ lastUpdateFieldName,
+ nullableFields,
+ links)
+ }
+
+ def apply[T](databaseName: String,
+ tableName: String,
+ fields: Set[String],
+ lastUpdateFieldName: Option[String],
+ nullableFields: Set[String],
+ links: Set[SlickTableLink])(implicit sqlContext: SlickDal,
+ profile: JdbcProfile,
+ getResult: GetResult[T],
+ ec: ExecutionContext): SlickPostgresQueryBuilder[T] = {
+
+ val runner: Runner[T] = { parameters =>
+ val sql = parameters.toSql(countQuery = false, fields = fields).as[T]
+ logger.debug(phi"${Unsafe(sql)}")
+ sqlContext.execute(sql)
+ }
+
+ val countRunner: CountRunner = { parameters =>
+ implicit val getCountResult: GetResult[(Int, Option[LocalDateTime])] = GetResult({ r =>
+ val count = r.rs.getInt(1)
+ val lastUpdate = if (parameters.tableData.lastUpdateFieldName.isDefined) {
+ Option(r.rs.getTimestamp(2)).map(timestampToLocalDateTime)
+ } else None
+ (count, lastUpdate)
+ })
+ val sql = parameters.toSql(countQuery = true).as[(Int, Option[LocalDateTime])]
+ logger.debug(phi"${Unsafe(sql)}")
+ sqlContext.execute(sql).map(_.head)
+ }
+
+ apply[T](
+ databaseName = databaseName,
+ tableName = tableName,
+ lastUpdateFieldName = lastUpdateFieldName,
+ nullableFields = nullableFields,
+ links = links,
+ runner = runner,
+ countRunner = countRunner
+ )
+ }
+
+ def timestampToLocalDateTime(timestamp: java.sql.Timestamp): LocalDateTime = {
+ LocalDateTime.ofInstant(timestamp.toInstant, ZoneOffset.UTC)
+ }
+}
+
+class SlickPostgresQueryBuilder[T](parameters: SlickPostgresQueryBuilderParameters)(
+ implicit runner: SlickQueryBuilder.Runner[T],
+ countRunner: SlickQueryBuilder.CountRunner)
+ extends SlickQueryBuilder[T](parameters) {
+
+ def withFilter(newFilter: SearchFilterExpr): SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(filter = newFilter))
+ }
+
+ def withSorting(newSorting: Sorting): SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(sorting = newSorting))
+ }
+
+ def withPagination(newPagination: Pagination): SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(pagination = Some(newPagination)))
+ }
+
+ def resetPagination: SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(pagination = None))
+ }
+}
diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/SlickQueryBuilder.scala b/src/main/scala/xyz/driver/pdsuicommon/db/SlickQueryBuilder.scala
new file mode 100644
index 0000000..79cb114
--- /dev/null
+++ b/src/main/scala/xyz/driver/pdsuicommon/db/SlickQueryBuilder.scala
@@ -0,0 +1,357 @@
+package xyz.driver.pdsuicommon.db
+
+import java.sql.PreparedStatement
+import java.time.LocalDateTime
+
+import slick.driver.JdbcProfile
+import slick.jdbc.{PositionedParameters, SQLActionBuilder, SetParameter}
+import xyz.driver.pdsuicommon.db.Sorting.{Dimension, Sequential}
+import xyz.driver.pdsuicommon.db.SortingOrder.{Ascending, Descending}
+
+import scala.concurrent.{ExecutionContext, Future}
+
+object SlickQueryBuilder {
+
+ type Runner[T] = SlickQueryBuilderParameters => Future[Seq[T]]
+
+ type CountResult = Future[(Int, Option[LocalDateTime])]
+
+ type CountRunner = SlickQueryBuilderParameters => CountResult
+
+ /**
+ * Binder for PreparedStatement
+ */
+ type Binder = PreparedStatement => PreparedStatement
+
+ final case class TableData(tableName: String,
+ lastUpdateFieldName: Option[String] = None,
+ nullableFields: Set[String] = Set.empty)
+
+ val AllFields = Set("*")
+
+ implicit class SQLActionBuilderConcat(a: SQLActionBuilder) {
+ def concat(b: SQLActionBuilder): SQLActionBuilder = {
+ SQLActionBuilder(a.queryParts ++ b.queryParts, new SetParameter[Unit] {
+ def apply(p: Unit, pp: PositionedParameters): Unit = {
+ a.unitPConv.apply(p, pp)
+ b.unitPConv.apply(p, pp)
+ }
+ })
+ }
+ }
+}
+
+final case class SlickTableLink(keyColumnName: String, foreignTableName: String, foreignKeyColumnName: String)
+
+object SlickQueryBuilderParameters {
+ val AllFields = Set("*")
+}
+
+sealed trait SlickQueryBuilderParameters {
+ import SlickQueryBuilder._
+
+ def databaseName: String
+ def tableData: SlickQueryBuilder.TableData
+ def links: Map[String, SlickTableLink]
+ def filter: SearchFilterExpr
+ def sorting: Sorting
+ def pagination: Option[Pagination]
+
+ def qs: String
+
+ def findLink(tableName: String): SlickTableLink = links.get(tableName) match {
+ case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`")
+ case Some(link) => link
+ }
+
+ def toSql(countQuery: Boolean = false)(implicit profile: JdbcProfile): SQLActionBuilder = {
+ toSql(countQuery, QueryBuilderParameters.AllFields)
+ }
+
+ def toSql(countQuery: Boolean, fields: Set[String])(implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ val escapedTableName = s"""$qs$databaseName$qs.$qs${tableData.tableName}$qs"""
+ val fieldsSql: String = if (countQuery) {
+ val suffix: String = tableData.lastUpdateFieldName match {
+ case Some(lastUpdateField) => s", max($escapedTableName.$qs$lastUpdateField$qs)"
+ case None => ""
+ }
+ s"count(*) $suffix"
+ } else {
+ if (fields == SlickQueryBuilderParameters.AllFields) {
+ s"$escapedTableName.*"
+ } else {
+ fields
+ .map { field =>
+ s"$escapedTableName.$qs$field$qs"
+ }
+ .mkString(", ")
+ }
+ }
+ val where = filterToSql(escapedTableName, filter)
+ val orderBy = sortingToSql(escapedTableName, sorting)
+
+ val limitSql = limitToSql()
+
+ val sql = sql"""select #$fieldsSql from #$escapedTableName"""
+
+ val filtersTableLinks: Seq[SlickTableLink] = {
+ import SearchFilterExpr._
+ def aux(expr: SearchFilterExpr): Seq[SlickTableLink] = expr match {
+ case Atom.TableName(tableName) => List(findLink(tableName))
+ case Intersection(xs) => xs.flatMap(aux)
+ case Union(xs) => xs.flatMap(aux)
+ case _ => Nil
+ }
+ aux(filter)
+ }
+
+ val sortingTableLinks: Seq[SlickTableLink] = Sorting.collect(sorting) {
+ case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
+ }
+
+ // Combine links from sorting and filter without duplicates
+ val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct
+
+ def fkSql(fkLinksSql: SQLActionBuilder, tableLinks: Seq[SlickTableLink]): SQLActionBuilder = {
+ if (tableLinks.nonEmpty) {
+ tableLinks.head match {
+ case SlickTableLink(keyColumnName, foreignTableName, foreignKeyColumnName) =>
+ val escapedForeignTableName = s"$qs$databaseName$qs.$qs$foreignTableName$qs"
+ val join = sql""" inner join #$escapedForeignTableName
+ on #$escapedTableName.#$qs#$keyColumnName#$qs=#$escapedForeignTableName.#$qs#$foreignKeyColumnName#$qs"""
+ fkSql(fkLinksSql concat join, tableLinks.tail)
+ }
+ } else fkLinksSql
+ }
+ val foreignTableLinksSql = fkSql(sql"", foreignTableLinks)
+
+ val whereSql = if (where.queryParts.size > 1) {
+ sql" where " concat where
+ } else sql""
+
+ val orderSql = if (orderBy.nonEmpty && !countQuery) {
+ sql" order by #$orderBy"
+ } else sql""
+
+ val limSql = if (limitSql.queryParts.size > 1 && !countQuery) {
+ sql" " concat limitSql
+ } else sql""
+
+ sql concat foreignTableLinksSql concat whereSql concat orderSql concat limSql
+ }
+
+ /**
+ * Converts filter expression to SQL expression.
+ *
+ * @return Returns SQL string and list of values for binding in prepared statement.
+ */
+ protected def filterToSql(escapedTableName: String, filter: SearchFilterExpr)(
+ implicit profile: JdbcProfile): SQLActionBuilder = {
+ import SearchFilterBinaryOperation._
+ import SearchFilterExpr._
+ import profile.api._
+
+ def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null"
+
+ def escapeDimension(dimension: SearchFilterExpr.Dimension) = {
+ s"$escapedTableName.$qs${dimension.name}$qs"
+ }
+
+ def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect {
+ case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x)
+ }
+
+ def multipleSqlToAction(first: Boolean,
+ op: String,
+ conditions: Seq[SQLActionBuilder],
+ sql: SQLActionBuilder): SQLActionBuilder = {
+ if (conditions.nonEmpty) {
+ val condition = conditions.head
+ if (first) {
+ multipleSqlToAction(false, op, conditions.tail, condition)
+ } else {
+ multipleSqlToAction(false, op, conditions.tail, sql concat sql" #${op} " concat condition)
+ }
+ } else sql
+ }
+
+ filter match {
+ case x if isEmpty(x) =>
+ sql""
+
+ case AllowAll =>
+ sql"1"
+
+ case DenyAll =>
+ sql"0"
+
+ case Atom.Binary(dimension, Eq, value) if isNull(value) =>
+ sql"#${escapeDimension(dimension)} is NULL"
+
+ case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
+ sql"#${escapeDimension(dimension)} is not NULL"
+
+ case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
+ // In MySQL NULL <> Any === NULL
+ // So, to handle NotEq for nullable fields we need to use more complex SQL expression.
+ // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
+ val escapedColumn = escapeDimension(dimension)
+ sql"(#${escapedColumn} is null or #${escapedColumn} != ${value.toString})"
+
+ case Atom.Binary(dimension, op, value) =>
+ val operator = op match {
+ case Eq => sql"="
+ case NotEq => sql"!="
+ case Like => sql"like"
+ case Gt => sql">"
+ case GtEq => sql">="
+ case Lt => sql"<"
+ case LtEq => sql"<="
+ }
+ sql"#${escapeDimension(dimension)}" concat operator concat sql"""${value.toString}"""
+
+ case Atom.NAry(dimension, op, values) =>
+ val sqlOp = op match {
+ case SearchFilterNAryOperation.In => sql"in"
+ case SearchFilterNAryOperation.NotIn => sql"not in"
+ }
+
+ val formattedValues = if (values.nonEmpty) {
+ sql"${values.mkString(",")}"
+ } else sql"NULL"
+ sql"#${escapeDimension(dimension)}" concat sqlOp concat formattedValues
+
+ case Intersection(operands) =>
+ val filter = multipleSqlToAction(true, "and", filterToSqlMultiple(operands), sql"")
+ sql"(" concat filter concat sql")"
+
+ case Union(operands) =>
+ val filter = multipleSqlToAction(true, "or", filterToSqlMultiple(operands), sql"")
+ sql"(" concat filter concat sql")"
+ }
+ }
+
+ protected def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder
+
+ /**
+ * @param escapedMainTableName Should be escaped
+ */
+ protected def sortingToSql(escapedMainTableName: String, sorting: Sorting)(implicit profile: JdbcProfile): String = {
+ sorting match {
+ case Dimension(optSortingTableName, field, order) =>
+ val sortingTableName =
+ optSortingTableName.map(x => s"$qs$databaseName$qs.$qs$x$qs").getOrElse(escapedMainTableName)
+ val fullName = s"$sortingTableName.$qs$field$qs"
+
+ s"$fullName ${orderToSql(order)}"
+
+ case Sequential(xs) =>
+ xs.map(sortingToSql(escapedMainTableName, _)).mkString(", ")
+ }
+ }
+
+ protected def orderToSql(x: SortingOrder): String = x match {
+ case Ascending => "asc"
+ case Descending => "desc"
+ }
+
+ protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = {
+ bindings.zipWithIndex.foreach {
+ case (binding, index) =>
+ bind.setObject(index + 1, binding)
+ }
+
+ bind
+ }
+
+}
+
+final case class SlickPostgresQueryBuilderParameters(databaseName: String,
+ tableData: SlickQueryBuilder.TableData,
+ links: Map[String, SlickTableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None)
+ extends SlickQueryBuilderParameters {
+
+ def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ pagination.map { pagination =>
+ val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
+ sql"limit #${pagination.pageSize} OFFSET #$startFrom"
+ } getOrElse (sql"")
+ }
+
+ val qs = """""""
+
+}
+
+/**
+ * @param links Links to another tables grouped by foreignTableName
+ */
+final case class SlickMysqlQueryBuilderParameters(databaseName: String,
+ tableData: SlickQueryBuilder.TableData,
+ links: Map[String, SlickTableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None)
+ extends SlickQueryBuilderParameters {
+
+ def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ pagination
+ .map { pagination =>
+ val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
+ sql"limit #$startFrom, #${pagination.pageSize}"
+ }
+ .getOrElse(sql"")
+ }
+
+ val qs = """`"""
+
+}
+
+abstract class SlickQueryBuilder[T](val parameters: SlickQueryBuilderParameters)(
+ implicit runner: SlickQueryBuilder.Runner[T],
+ countRunner: SlickQueryBuilder.CountRunner) {
+
+ def run()(implicit ec: ExecutionContext): Future[Seq[T]] = runner(parameters)
+
+ def runCount()(implicit ec: ExecutionContext): SlickQueryBuilder.CountResult = countRunner(parameters)
+
+ /**
+ * Runs the query and returns total found rows without considering of pagination.
+ */
+ def runWithCount()(implicit ec: ExecutionContext): Future[(Seq[T], Int, Option[LocalDateTime])] = {
+ for {
+ all <- run
+ (total, lastUpdate) <- runCount
+ } yield (all, total, lastUpdate)
+ }
+
+ def withFilter(newFilter: SearchFilterExpr): SlickQueryBuilder[T]
+
+ def withFilter(filter: Option[SearchFilterExpr]): SlickQueryBuilder[T] = {
+ filter.fold(this)(withFilter)
+ }
+
+ def resetFilter: SlickQueryBuilder[T] = withFilter(SearchFilterExpr.Empty)
+
+ def withSorting(newSorting: Sorting): SlickQueryBuilder[T]
+
+ def withSorting(sorting: Option[Sorting]): SlickQueryBuilder[T] = {
+ sorting.fold(this)(withSorting)
+ }
+
+ def resetSorting: SlickQueryBuilder[T] = withSorting(Sorting.Empty)
+
+ def withPagination(newPagination: Pagination): SlickQueryBuilder[T]
+
+ def withPagination(pagination: Option[Pagination]): SlickQueryBuilder[T] = {
+ pagination.fold(this)(withPagination)
+ }
+
+ def resetPagination: SlickQueryBuilder[T]
+
+}