aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKseniya Tomskikh <ktomskih@datamonsters.co>2017-08-17 16:34:06 +0700
committerKseniya Tomskikh <ktomskih@datamonsters.co>2017-08-17 19:10:21 +0700
commit0fc769e8e2141451da34247e7733fcf0a3396b9c (patch)
tree1f197c0ae06e0d8d8f349410f8266082f5fbc0fd
parent322ea28ecf5ad5f65d3376f3e97e004d229d4736 (diff)
downloadrest-query-0fc769e8e2141451da34247e7733fcf0a3396b9c.tar.gz
rest-query-0fc769e8e2141451da34247e7733fcf0a3396b9c.tar.bz2
rest-query-0fc769e8e2141451da34247e7733fcf0a3396b9c.zip
Created SlickQueryBuilder
-rw-r--r--build.sbt3
-rw-r--r--src/main/scala/xyz/driver/pdsuicommon/db/SlickPostgresQueryBuilder.scala103
-rw-r--r--src/main/scala/xyz/driver/pdsuicommon/db/SlickQueryBuilder.scala341
3 files changed, 446 insertions, 1 deletions
diff --git a/build.sbt b/build.sbt
index 8801de7..7778e02 100644
--- a/build.sbt
+++ b/build.sbt
@@ -28,5 +28,6 @@ lazy val core = (project in file("."))
"org.asynchttpclient" % "async-http-client" % "2.0.24",
"org.slf4j" % "slf4j-api" % "1.7.21",
"ai.x" %% "diff" % "1.2.0-get-simple-name-fix" % "test",
- "org.scalatest" %% "scalatest" % "3.0.0" % "test"
+ "org.scalatest" %% "scalatest" % "3.0.0" % "test",
+ "xyz.driver" %% "core" % "0.16.3" excludeAll (ExclusionRule(organization = "io.netty"))
))
diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/SlickPostgresQueryBuilder.scala b/src/main/scala/xyz/driver/pdsuicommon/db/SlickPostgresQueryBuilder.scala
new file mode 100644
index 0000000..66434f0
--- /dev/null
+++ b/src/main/scala/xyz/driver/pdsuicommon/db/SlickPostgresQueryBuilder.scala
@@ -0,0 +1,103 @@
+package xyz.driver.pdsuicommon.db
+
+import java.time.{LocalDateTime, ZoneOffset}
+
+import slick.driver.JdbcProfile
+import slick.jdbc.GetResult
+import xyz.driver.core.database.SlickDal
+
+import scala.collection.breakOut
+import scala.concurrent.ExecutionContext
+
+object SlickPostgresQueryBuilder {
+
+ import xyz.driver.pdsuicommon.db.SlickQueryBuilder._
+
+ def apply[T](tableName: String,
+ lastUpdateFieldName: Option[String],
+ nullableFields: Set[String],
+ links: Set[SlickTableLink],
+ runner: Runner[T],
+ countRunner: CountRunner)(implicit sqlContext: SlickDal,
+ profile: JdbcProfile,
+ getResult: GetResult[T],
+ ec: ExecutionContext): SlickPostgresQueryBuilder[T] = {
+ val parameters = SlickPostgresQueryBuilderParameters(
+ tableData = TableData(tableName, lastUpdateFieldName, nullableFields),
+ links = links.map(x => x.foreignTableName -> x)(breakOut)
+ )
+ new SlickPostgresQueryBuilder[T](parameters)(runner, countRunner)
+ }
+
+ def apply[T](tableName: String,
+ lastUpdateFieldName: Option[String],
+ nullableFields: Set[String],
+ links: Set[SlickTableLink])(implicit sqlContext: SlickDal,
+ profile: JdbcProfile,
+ getResult: GetResult[T],
+ ec: ExecutionContext): SlickPostgresQueryBuilder[T] = {
+ apply(tableName, SlickQueryBuilderParameters.AllFields, lastUpdateFieldName, nullableFields, links)
+ }
+
+ def apply[T](tableName: String,
+ fields: Set[String],
+ lastUpdateFieldName: Option[String],
+ nullableFields: Set[String],
+ links: Set[SlickTableLink])(implicit sqlContext: SlickDal,
+ profile: JdbcProfile,
+ getResult: GetResult[T],
+ ec: ExecutionContext): SlickPostgresQueryBuilder[T] = {
+
+ val runner: Runner[T] = { parameters =>
+ val sql = parameters.toSql(countQuery = false, fields = fields).as[T]
+ sqlContext.execute(sql)
+ }
+
+ val countRunner: CountRunner = { parameters =>
+ implicit val getCountResult: GetResult[(Int, Option[LocalDateTime])] = GetResult({ r =>
+ val count = r.rs.getInt(1)
+ val lastUpdate = if (parameters.tableData.lastUpdateFieldName.isDefined) {
+ Option(r.rs.getTimestamp(2)).map(timestampToLocalDateTime)
+ } else None
+ (count, lastUpdate)
+ })
+ val sql = parameters.toSql(countQuery = true).as[(Int, Option[LocalDateTime])]
+ sqlContext.execute(sql).map(_.head)
+ }
+
+ apply[T](
+ tableName = tableName,
+ lastUpdateFieldName = lastUpdateFieldName,
+ nullableFields = nullableFields,
+ links = links,
+ runner = runner,
+ countRunner = countRunner
+ )
+ }
+
+ def timestampToLocalDateTime(timestamp: java.sql.Timestamp): LocalDateTime = {
+ LocalDateTime.ofInstant(timestamp.toInstant, ZoneOffset.UTC)
+ }
+}
+
+class SlickPostgresQueryBuilder[T](parameters: SlickPostgresQueryBuilderParameters)(
+ implicit runner: SlickQueryBuilder.Runner[T],
+ countRunner: SlickQueryBuilder.CountRunner)
+ extends SlickQueryBuilder[T](parameters) {
+
+ def withFilter(newFilter: SearchFilterExpr): SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(filter = newFilter))
+ }
+
+ def withSorting(newSorting: Sorting): SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(sorting = newSorting))
+ }
+
+ def withPagination(newPagination: Pagination): SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(pagination = Some(newPagination)))
+ }
+
+ def resetPagination: SlickQueryBuilder[T] = {
+ new SlickPostgresQueryBuilder[T](parameters.copy(pagination = None))
+ }
+}
diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/SlickQueryBuilder.scala b/src/main/scala/xyz/driver/pdsuicommon/db/SlickQueryBuilder.scala
new file mode 100644
index 0000000..e45ff87
--- /dev/null
+++ b/src/main/scala/xyz/driver/pdsuicommon/db/SlickQueryBuilder.scala
@@ -0,0 +1,341 @@
+package xyz.driver.pdsuicommon.db
+
+import java.sql.PreparedStatement
+import java.time.LocalDateTime
+
+import slick.driver.JdbcProfile
+import slick.jdbc.{PositionedParameters, SQLActionBuilder, SetParameter}
+import xyz.driver.pdsuicommon.db.Sorting.{Dimension, Sequential}
+import xyz.driver.pdsuicommon.db.SortingOrder.{Ascending, Descending}
+
+import scala.concurrent.{ExecutionContext, Future}
+
+object SlickQueryBuilder {
+
+ type Runner[T] = SlickQueryBuilderParameters => Future[Seq[T]]
+
+ type CountResult = Future[(Int, Option[LocalDateTime])]
+
+ type CountRunner = SlickQueryBuilderParameters => CountResult
+
+ /**
+ * Binder for PreparedStatement
+ */
+ type Binder = PreparedStatement => PreparedStatement
+
+ final case class TableData(tableName: String,
+ lastUpdateFieldName: Option[String] = None,
+ nullableFields: Set[String] = Set.empty)
+
+ val AllFields = Set("*")
+
+ implicit class SQLActionBuilderConcat(a: SQLActionBuilder) {
+ def concat(b: SQLActionBuilder): SQLActionBuilder = {
+ SQLActionBuilder(a.queryParts ++ b.queryParts, new SetParameter[Unit] {
+ def apply(p: Unit, pp: PositionedParameters): Unit = {
+ a.unitPConv.apply(p, pp)
+ b.unitPConv.apply(p, pp)
+ }
+ })
+ }
+ }
+}
+
+final case class SlickTableLink(keyColumnName: String, foreignTableName: String, foreignKeyColumnName: String)
+
+object SlickQueryBuilderParameters {
+ val AllFields = Set("*")
+}
+
+sealed trait SlickQueryBuilderParameters {
+ import SlickQueryBuilder._
+
+ def tableData: SlickQueryBuilder.TableData
+ def links: Map[String, SlickTableLink]
+ def filter: SearchFilterExpr
+ def sorting: Sorting
+ def pagination: Option[Pagination]
+
+ def findLink(tableName: String): SlickTableLink = links.get(tableName) match {
+ case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`")
+ case Some(link) => link
+ }
+
+ def toSql(countQuery: Boolean = false)(implicit profile: JdbcProfile): SQLActionBuilder = {
+ toSql(countQuery, QueryBuilderParameters.AllFields)
+ }
+
+ def toSql(countQuery: Boolean, fields: Set[String])(implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ val escapedTableName = tableData.tableName
+ val fieldsSql: String = if (countQuery) {
+ val suffix: String = tableData.lastUpdateFieldName match {
+ case Some(lastUpdateField) => s", max($escapedTableName.$lastUpdateField)"
+ case None => ""
+ }
+ "count(*)" + suffix
+ } else {
+ if (fields == SlickQueryBuilderParameters.AllFields) {
+ s"$escapedTableName.*"
+ } else {
+ fields
+ .map { field =>
+ s"$escapedTableName.$field"
+ }
+ .mkString(", ")
+ }
+ }
+ val where = filterToSql(escapedTableName, filter)
+ val orderBy = sortingToSql(escapedTableName, sorting)
+
+ val limitSql = limitToSql()
+
+ val sql = sql"select #$fieldsSql from #$escapedTableName"
+
+ val filtersTableLinks: Seq[SlickTableLink] = {
+ import SearchFilterExpr._
+ def aux(expr: SearchFilterExpr): Seq[SlickTableLink] = expr match {
+ case Atom.TableName(tableName) => List(findLink(tableName))
+ case Intersection(xs) => xs.flatMap(aux)
+ case Union(xs) => xs.flatMap(aux)
+ case _ => Nil
+ }
+ aux(filter)
+ }
+
+ val sortingTableLinks: Seq[SlickTableLink] = Sorting.collect(sorting) {
+ case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
+ }
+
+ // Combine links from sorting and filter without duplicates
+ val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct
+
+ foreignTableLinks.foreach {
+ case SlickTableLink(keyColumnName, foreignTableName, foreignKeyColumnName) =>
+ sql = sql concat sql"""inner join #$foreignTableName
+ on #$escapedTableName.#$keyColumnName = #$foreignTableName.#$foreignKeyColumnName"""
+ }
+
+ if (where.toString.nonEmpty) {
+ sql = sql concat sql"where #$where"
+ }
+
+ if (orderBy.toString.nonEmpty && !countQuery) {
+ sql = sql concat sql"order by #$orderBy"
+ }
+
+ if (limitSql.toString.nonEmpty && !countQuery) {
+ sql = sql concat sql"#$limitSql"
+ }
+
+ sql
+ }
+
+ /**
+ * Converts filter expression to SQL expression.
+ *
+ * @return Returns SQL string and list of values for binding in prepared statement.
+ */
+ protected def filterToSql(escapedTableName: String, filter: SearchFilterExpr)(
+ implicit profile: JdbcProfile): SQLActionBuilder = {
+ import SearchFilterBinaryOperation._
+ import SearchFilterExpr._
+ import profile.api._
+
+ def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null"
+
+ def escapeDimension(dimension: SearchFilterExpr.Dimension) = {
+ val tableName = escapedTableName
+ s"$tableName.$dimension.name"
+ }
+
+ def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect {
+ case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x)
+ }
+
+ def multipleSqlToAction(op: String, conditions: Seq[SQLActionBuilder]): SQLActionBuilder = {
+ var first = true
+ var filterSql = sql"("
+ for (condition <- conditions) {
+ if (first) {
+ filterSql = filterSql concat condition
+ first = false
+ } else {
+ filterSql = filterSql concat sql" #$op " concat condition
+ }
+ }
+ filterSql concat sql")"
+ }
+
+ filter match {
+ case x if isEmpty(x) =>
+ sql""
+
+ case AllowAll =>
+ sql"1"
+
+ case DenyAll =>
+ sql"0"
+
+ case Atom.Binary(dimension, Eq, value) if isNull(value) =>
+ sql"#${escapeDimension(dimension)} is NULL"
+
+ case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
+ sql"#${escapeDimension(dimension)} is not NULL"
+
+ case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
+ // In MySQL NULL <> Any === NULL
+ // So, to handle NotEq for nullable fields we need to use more complex SQL expression.
+ // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
+ val escapedColumn = escapeDimension(dimension)
+ sql"(#$escapedColumn is null or #$escapedColumn != #$value)"
+
+ case Atom.Binary(dimension, op, value) =>
+ val operator = op match {
+ case Eq => sql"="
+ case NotEq => sql"!="
+ case Like => sql"like"
+ case Gt => sql">"
+ case GtEq => sql">="
+ case Lt => sql"<"
+ case LtEq => sql"<="
+ }
+ sql"#${escapeDimension(dimension)}" concat operator concat sql"#$value"
+
+ case Atom.NAry(dimension, op, values) =>
+ val sqlOp = op match {
+ case SearchFilterNAryOperation.In => sql"in"
+ case SearchFilterNAryOperation.NotIn => sql"not in"
+ }
+
+ val formattedValues = if (values.nonEmpty) {
+ sql"#$values"
+ } else sql"NULL"
+ sql"#${escapeDimension(dimension)}" concat sqlOp concat formattedValues
+
+ case Intersection(operands) =>
+ multipleSqlToAction("and", filterToSqlMultiple(operands))
+
+ case Union(operands) =>
+ multipleSqlToAction("or", filterToSqlMultiple(operands))
+ }
+ }
+
+ protected def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder
+
+ /**
+ * @param escapedMainTableName Should be escaped
+ */
+ protected def sortingToSql(escapedMainTableName: String, sorting: Sorting)(
+ implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ sorting match {
+ case Dimension(optSortingTableName, field, order) =>
+ val sortingTableName = optSortingTableName.getOrElse(escapedMainTableName)
+ val fullName = s"$sortingTableName.$field"
+
+ sql"#$fullName #${orderToSql(order)}"
+
+ case Sequential(xs) =>
+ sql"#${xs.map(sortingToSql(escapedMainTableName, _)).mkString(", ")}"
+ }
+ }
+
+ protected def orderToSql(x: SortingOrder): String = x match {
+ case Ascending => "asc"
+ case Descending => "desc"
+ }
+
+ protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = {
+ bindings.zipWithIndex.foreach {
+ case (binding, index) =>
+ bind.setObject(index + 1, binding)
+ }
+
+ bind
+ }
+
+}
+
+final case class SlickPostgresQueryBuilderParameters(tableData: SlickQueryBuilder.TableData,
+ links: Map[String, SlickTableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None)
+ extends SlickQueryBuilderParameters {
+
+ def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ pagination.map { pagination =>
+ val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
+ sql"limit #${pagination.pageSize} OFFSET #$startFrom"
+ } getOrElse (sql"")
+ }
+
+}
+
+/**
+ * @param links Links to another tables grouped by foreignTableName
+ */
+final case class SlickMysqlQueryBuilderParameters(tableData: SlickQueryBuilder.TableData,
+ links: Map[String, SlickTableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None)
+ extends SlickQueryBuilderParameters {
+
+ def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ pagination
+ .map { pagination =>
+ val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
+ sql"limit #$startFrom, #${pagination.pageSize}"
+ }
+ .getOrElse(sql"")
+ }
+
+}
+
+abstract class SlickQueryBuilder[T](val parameters: SlickQueryBuilderParameters)(
+ implicit runner: SlickQueryBuilder.Runner[T],
+ countRunner: SlickQueryBuilder.CountRunner) {
+
+ def run()(implicit ec: ExecutionContext): Future[Seq[T]] = runner(parameters)
+
+ def runCount()(implicit ec: ExecutionContext): SlickQueryBuilder.CountResult = countRunner(parameters)
+
+ /**
+ * Runs the query and returns total found rows without considering of pagination.
+ */
+ def runWithCount()(implicit ec: ExecutionContext): Future[(Seq[T], Int, Option[LocalDateTime])] = {
+ for {
+ all <- run
+ (total, lastUpdate) <- runCount
+ } yield (all, total, lastUpdate)
+ }
+
+ def withFilter(newFilter: SearchFilterExpr): SlickQueryBuilder[T]
+
+ def withFilter(filter: Option[SearchFilterExpr]): SlickQueryBuilder[T] = {
+ filter.fold(this)(withFilter)
+ }
+
+ def resetFilter: SlickQueryBuilder[T] = withFilter(SearchFilterExpr.Empty)
+
+ def withSorting(newSorting: Sorting): SlickQueryBuilder[T]
+
+ def withSorting(sorting: Option[Sorting]): SlickQueryBuilder[T] = {
+ sorting.fold(this)(withSorting)
+ }
+
+ def resetSorting: SlickQueryBuilder[T] = withSorting(Sorting.Empty)
+
+ def withPagination(newPagination: Pagination): SlickQueryBuilder[T]
+
+ def withPagination(pagination: Option[Pagination]): SlickQueryBuilder[T] = {
+ pagination.fold(this)(withPagination)
+ }
+
+ def resetPagination: SlickQueryBuilder[T]
+
+}