aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala')
-rw-r--r--src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala299
1 files changed, 25 insertions, 274 deletions
diff --git a/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala
index 9962edf..67ce9f4 100644
--- a/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala
+++ b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilder.scala
@@ -1,12 +1,10 @@
-package xyz.driver.pdsuicommon.db
+package xyz.driver.restquery.db
import java.sql.{JDBCType, PreparedStatement}
import java.time.LocalDateTime
import slick.jdbc.{JdbcProfile, PositionedParameters, SQLActionBuilder, SetParameter}
-import xyz.driver.pdsuicommon.db.Sorting.{Dimension, Sequential}
-import xyz.driver.pdsuicommon.db.SortingOrder.{Ascending, Descending}
-import xyz.driver.pdsuicommon.domain.{LongId, StringId, UuidId}
+import xyz.driver.restquery.query._
import scala.concurrent.{ExecutionContext, Future}
@@ -23,19 +21,18 @@ object SlickQueryBuilder {
*/
type Binder = PreparedStatement => PreparedStatement
- final case class TableData(tableName: String,
- lastUpdateFieldName: Option[String] = None,
- nullableFields: Set[String] = Set.empty)
+ final case class TableData(
+ tableName: String,
+ lastUpdateFieldName: Option[String] = None,
+ nullableFields: Set[String] = Set.empty)
val AllFields = Set("*")
implicit class SQLActionBuilderConcat(a: SQLActionBuilder) {
def concat(b: SQLActionBuilder): SQLActionBuilder = {
- SQLActionBuilder(a.queryParts ++ b.queryParts, new SetParameter[Unit] {
- def apply(p: Unit, pp: PositionedParameters): Unit = {
- a.unitPConv.apply(p, pp)
- b.unitPConv.apply(p, pp)
- }
+ SQLActionBuilder(a.queryParts ++ b.queryParts, (p: Unit, pp: PositionedParameters) => {
+ a.unitPConv.apply(p, pp)
+ b.unitPConv.apply(p, pp)
})
}
}
@@ -45,264 +42,17 @@ object SlickQueryBuilder {
pp.setObject(v, JDBCType.BINARY.getVendorTypeNumber)
}
}
-
- implicit def setLongIdQueryParameter[T]: SetParameter[LongId[T]] = SetParameter[LongId[T]] { (v, pp) =>
- pp.setLong(v.id)
- }
-
- implicit def setStringIdQueryParameter[T]: SetParameter[StringId[T]] = SetParameter[StringId[T]] { (v, pp) =>
- pp.setString(v.id)
- }
-
- implicit def setUuidIdQueryParameter[T]: SetParameter[UuidId[T]] = SetParameter[UuidId[T]] { (v, pp) =>
- pp.setObject(v.id, JDBCType.BINARY.getVendorTypeNumber)
- }
}
final case class SlickTableLink(keyColumnName: String, foreignTableName: String, foreignKeyColumnName: String)
-object SlickQueryBuilderParameters {
- val AllFields = Set("*")
-}
-
-sealed trait SlickQueryBuilderParameters {
- import SlickQueryBuilder._
-
- def databaseName: String
- def tableData: SlickQueryBuilder.TableData
- def links: Map[String, SlickTableLink]
- def filter: SearchFilterExpr
- def sorting: Sorting
- def pagination: Option[Pagination]
-
- def qs: String
-
- def findLink(tableName: String): SlickTableLink = links.get(tableName) match {
- case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`")
- case Some(link) => link
- }
-
- def toSql(countQuery: Boolean = false)(implicit profile: JdbcProfile): SQLActionBuilder = {
- toSql(countQuery, SlickQueryBuilderParameters.AllFields)
- }
-
- def toSql(countQuery: Boolean, fields: Set[String])(implicit profile: JdbcProfile): SQLActionBuilder = {
- import profile.api._
- val escapedTableName = s"""$qs$databaseName$qs.$qs${tableData.tableName}$qs"""
- val fieldsSql: String = if (countQuery) {
- val suffix: String = tableData.lastUpdateFieldName match {
- case Some(lastUpdateField) => s", max($escapedTableName.$qs$lastUpdateField$qs)"
- case None => ""
- }
- s"count(*) $suffix"
- } else {
- if (fields == SlickQueryBuilderParameters.AllFields) {
- s"$escapedTableName.*"
- } else {
- fields
- .map { field =>
- s"$escapedTableName.$qs$field$qs"
- }
- .mkString(", ")
- }
- }
- val where = filterToSql(escapedTableName, filter)
- val orderBy = sortingToSql(escapedTableName, sorting)
-
- val limitSql = limitToSql()
-
- val sql = sql"""select #$fieldsSql from #$escapedTableName"""
-
- val filtersTableLinks: Seq[SlickTableLink] = {
- import SearchFilterExpr._
- def aux(expr: SearchFilterExpr): Seq[SlickTableLink] = expr match {
- case Atom.TableName(tableName) => List(findLink(tableName))
- case Intersection(xs) => xs.flatMap(aux)
- case Union(xs) => xs.flatMap(aux)
- case _ => Nil
- }
- aux(filter)
- }
-
- val sortingTableLinks: Seq[SlickTableLink] = Sorting.collect(sorting) {
- case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
- }
-
- // Combine links from sorting and filter without duplicates
- val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct
-
- def fkSql(fkLinksSql: SQLActionBuilder, tableLinks: Seq[SlickTableLink]): SQLActionBuilder = {
- if (tableLinks.nonEmpty) {
- tableLinks.head match {
- case SlickTableLink(keyColumnName, foreignTableName, foreignKeyColumnName) =>
- val escapedForeignTableName = s"$qs$databaseName$qs.$qs$foreignTableName$qs"
- val join = sql""" inner join #$escapedForeignTableName
- on #$escapedTableName.#$qs#$keyColumnName#$qs=#$escapedForeignTableName.#$qs#$foreignKeyColumnName#$qs"""
- fkSql(fkLinksSql concat join, tableLinks.tail)
- }
- } else fkLinksSql
- }
- val foreignTableLinksSql = fkSql(sql"", foreignTableLinks)
-
- val whereSql = if (where.queryParts.size > 1) {
- sql" where " concat where
- } else sql""
-
- val orderSql = if (orderBy.nonEmpty && !countQuery) {
- sql" order by #$orderBy"
- } else sql""
-
- val limSql = if (limitSql.queryParts.size > 1 && !countQuery) {
- sql" " concat limitSql
- } else sql""
-
- sql concat foreignTableLinksSql concat whereSql concat orderSql concat limSql
- }
-
- /**
- * Converts filter expression to SQL expression.
- *
- * @return Returns SQL string and list of values for binding in prepared statement.
- */
- protected def filterToSql(escapedTableName: String, filter: SearchFilterExpr)(
- implicit profile: JdbcProfile): SQLActionBuilder = {
- import SearchFilterBinaryOperation._
- import SearchFilterExpr._
- import profile.api._
-
- def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null"
-
- def escapeDimension(dimension: SearchFilterExpr.Dimension) = {
- s"${dimension.tableName.map(t => s"$qs$databaseName$qs.$qs$t$qs").getOrElse(escapedTableName)}.$qs${dimension.name}$qs"
- }
-
- def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect {
- case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x)
- }
-
- def multipleSqlToAction(first: Boolean,
- op: String,
- conditions: Seq[SQLActionBuilder],
- sql: SQLActionBuilder): SQLActionBuilder = {
- if (conditions.nonEmpty) {
- val condition = conditions.head
- if (first) {
- multipleSqlToAction(first = false, op, conditions.tail, condition)
- } else {
- multipleSqlToAction(first = false, op, conditions.tail, sql concat sql" #${op} " concat condition)
- }
- } else sql
- }
-
- def concatenateParameters(sql: SQLActionBuilder, first: Boolean, tail: Seq[AnyRef]): SQLActionBuilder = {
- if (tail.nonEmpty) {
- if (!first) {
- concatenateParameters(sql concat sql""",${tail.head}""", first = false, tail.tail)
- } else {
- concatenateParameters(sql"""(${tail.head}""", first = false, tail.tail)
- }
- } else sql concat sql")"
- }
-
- filter match {
- case x if isEmpty(x) =>
- sql""
-
- case AllowAll =>
- sql"1=1"
-
- case DenyAll =>
- sql"1=0"
-
- case Atom.Binary(dimension, Eq, value) if isNull(value) =>
- sql"#${escapeDimension(dimension)} is NULL"
-
- case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
- sql"#${escapeDimension(dimension)} is not NULL"
-
- case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
- // In MySQL NULL <> Any === NULL
- // So, to handle NotEq for nullable fields we need to use more complex SQL expression.
- // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
- val escapedColumn = escapeDimension(dimension)
- sql"(#${escapedColumn} is null or #${escapedColumn} != $value)"
-
- case Atom.Binary(dimension, op, value) =>
- val operator = op match {
- case Eq => sql"="
- case NotEq => sql"!="
- case Like => sql" like "
- case Gt => sql">"
- case GtEq => sql">="
- case Lt => sql"<"
- case LtEq => sql"<="
- }
- sql"#${escapeDimension(dimension)}" concat operator concat sql"""$value"""
-
- case Atom.NAry(dimension, op, values) =>
- val sqlOp = op match {
- case SearchFilterNAryOperation.In => sql" in "
- case SearchFilterNAryOperation.NotIn => sql" not in "
- }
-
- if (values.nonEmpty) {
- val formattedValues = concatenateParameters(sql"", first = true, values)
- sql"#${escapeDimension(dimension)}" concat sqlOp concat formattedValues
- } else {
- sql"1=0"
- }
-
- case Intersection(operands) =>
- val filter = multipleSqlToAction(first = true, "and", filterToSqlMultiple(operands), sql"")
- sql"(" concat filter concat sql")"
-
- case Union(operands) =>
- val filter = multipleSqlToAction(first = true, "or", filterToSqlMultiple(operands), sql"")
- sql"(" concat filter concat sql")"
- }
- }
-
- protected def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder
-
- /**
- * @param escapedMainTableName Should be escaped
- */
- protected def sortingToSql(escapedMainTableName: String, sorting: Sorting)(implicit profile: JdbcProfile): String = {
- sorting match {
- case Dimension(optSortingTableName, field, order) =>
- val sortingTableName =
- optSortingTableName.map(x => s"$qs$databaseName$qs.$qs$x$qs").getOrElse(escapedMainTableName)
- val fullName = s"$sortingTableName.$qs$field$qs"
-
- s"$fullName ${orderToSql(order)}"
-
- case Sequential(xs) =>
- xs.map(sortingToSql(escapedMainTableName, _)).mkString(", ")
- }
- }
-
- protected def orderToSql(x: SortingOrder): String = x match {
- case Ascending => "asc"
- case Descending => "desc"
- }
-
- protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = {
- bindings.zipWithIndex.foreach {
- case (binding, index) =>
- bind.setObject(index + 1, binding)
- }
-
- bind
- }
-
-}
-
-final case class SlickPostgresQueryBuilderParameters(databaseName: String,
- tableData: SlickQueryBuilder.TableData,
- links: Map[String, SlickTableLink] = Map.empty,
- filter: SearchFilterExpr = SearchFilterExpr.Empty,
- sorting: Sorting = Sorting.Empty,
- pagination: Option[Pagination] = None)
+final case class SlickPostgresQueryBuilderParameters(
+ databaseName: String,
+ tableData: SlickQueryBuilder.TableData,
+ links: Map[String, SlickTableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None)
extends SlickQueryBuilderParameters {
def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = {
@@ -320,12 +70,13 @@ final case class SlickPostgresQueryBuilderParameters(databaseName: String,
/**
* @param links Links to another tables grouped by foreignTableName
*/
-final case class SlickMysqlQueryBuilderParameters(databaseName: String,
- tableData: SlickQueryBuilder.TableData,
- links: Map[String, SlickTableLink] = Map.empty,
- filter: SearchFilterExpr = SearchFilterExpr.Empty,
- sorting: Sorting = Sorting.Empty,
- pagination: Option[Pagination] = None)
+final case class SlickMysqlQueryBuilderParameters(
+ databaseName: String,
+ tableData: SlickQueryBuilder.TableData,
+ links: Map[String, SlickTableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None)
extends SlickQueryBuilderParameters {
def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = {
@@ -343,8 +94,8 @@ final case class SlickMysqlQueryBuilderParameters(databaseName: String,
}
abstract class SlickQueryBuilder[T](val parameters: SlickQueryBuilderParameters)(
- implicit runner: SlickQueryBuilder.Runner[T],
- countRunner: SlickQueryBuilder.CountRunner) {
+ implicit runner: SlickQueryBuilder.Runner[T],
+ countRunner: SlickQueryBuilder.CountRunner) {
def run()(implicit ec: ExecutionContext): Future[Seq[T]] = runner(parameters)