package xyz.driver.pdsuicommon.db

import java.sql.PreparedStatement
import java.time.LocalDateTime

import slick.driver.JdbcProfile
import slick.jdbc.{PositionedParameters, SQLActionBuilder, SetParameter}
import xyz.driver.pdsuicommon.db.Sorting.{Dimension, Sequential}
import xyz.driver.pdsuicommon.db.SortingOrder.{Ascending, Descending}

import scala.concurrent.{ExecutionContext, Future}

object SlickQueryBuilder {

  /** Executes the query described by the parameters and returns the matching rows. */
  type Runner[T] = SlickQueryBuilderParameters => Future[Seq[T]]

  /** Row count paired with an optional timestamp (see [[TableData.lastUpdateFieldName]]). */
  type CountResult = Future[(Int, Option[LocalDateTime])]

  /** Executes the count variant of the query described by the parameters. */
  type CountRunner = SlickQueryBuilderParameters => CountResult

  /**
    * Binder for PreparedStatement
    */
  type Binder = PreparedStatement => PreparedStatement

  /**
    * Static description of the main table of a query.
    *
    * @param tableName           name of the table, spliced into the SQL as is
    * @param lastUpdateFieldName when defined, count queries additionally select `max(table.field)`
    * @param nullableFields      columns for which a "not equal" filter must also match NULL rows
    */
  final case class TableData(tableName: String,
                             lastUpdateFieldName: Option[String] = None,
                             nullableFields: Set[String] = Set.empty)

  val AllFields = Set("*")

  /** Adds `concat` to [[SQLActionBuilder]] so that SQL fragments can be glued together. */
  implicit class SQLActionBuilderConcat(a: SQLActionBuilder) {

    /**
      * Appends `b` to `a`: the query parts are concatenated without any separator and
      * the parameters of `a` are bound before the parameters of `b`.
      */
    def concat(b: SQLActionBuilder): SQLActionBuilder = {
      SQLActionBuilder(a.queryParts ++ b.queryParts, new SetParameter[Unit] {
        def apply(p: Unit, pp: PositionedParameters): Unit = {
          a.unitPConv.apply(p, pp)
          b.unitPConv.apply(p, pp)
        }
      })
    }
  }
}

/**
  * Describes how the main table is joined with a foreign table:
  * `mainTable.keyColumnName = foreignTableName.foreignKeyColumnName`.
  */
final case class SlickTableLink(keyColumnName: String, foreignTableName: String, foreignKeyColumnName: String)

object SlickQueryBuilderParameters {
  val AllFields = Set("*")
}

/**
  * Everything that is needed to render a select (or count) statement for one table:
  * the table description, the join links, the filter, the sorting and the pagination.
  * The only dialect specific part is [[limitToSql]].
  */
sealed trait SlickQueryBuilderParameters {
  import SlickQueryBuilder._

  def tableData: SlickQueryBuilder.TableData
  def links: Map[String, SlickTableLink]
  def filter: SearchFilterExpr
  def sorting: Sorting
  def pagination: Option[Pagination]

  /**
    * Looks up the join link registered for the given foreign table.
    *
    * @throws IllegalArgumentException when `links` has no entry for `tableName`
    */
  def findLink(tableName: String): SlickTableLink =
    links.getOrElse(tableName, throw new IllegalArgumentException(s"Cannot find a link for `$tableName`"))

  /** Renders the statement selecting all fields (or the count statement when `countQuery` is set). */
  def toSql(countQuery: Boolean = false)(implicit profile: JdbcProfile): SQLActionBuilder = {
    toSql(countQuery, SlickQueryBuilderParameters.AllFields)
  }

  /**
    * Renders the statement.
    *
    * @param countQuery when true, selects `count(*)` (plus the maximum of the last update field when
    *                   it is configured) and omits the `order by` and the limit clauses
    * @param fields     columns of the main table to select; ignored for count queries
    */
  def toSql(countQuery: Boolean, fields: Set[String])(implicit profile: JdbcProfile): SQLActionBuilder = {
    import profile.api._

    val escapedTableName = tableData.tableName

    val fieldsSql: String =
      if (countQuery) {
        val lastUpdateSql = tableData.lastUpdateFieldName.fold("") { lastUpdateField =>
          s", max($escapedTableName.$lastUpdateField)"
        }
        "count(*)" + lastUpdateSql
      } else if (fields == SlickQueryBuilderParameters.AllFields) {
        s"$escapedTableName.*"
      } else {
        fields.map(field => s"$escapedTableName.$field").mkString(", ")
      }

    val where    = filterToSql(escapedTableName, filter)
    val orderBy  = sortingToSql(escapedTableName, sorting)
    val limitSql = limitToSql()

    val filtersTableLinks: Seq[SlickTableLink] = {
      import SearchFilterExpr._
      def aux(expr: SearchFilterExpr): Seq[SlickTableLink] = expr match {
        case Atom.TableName(tableName) => List(findLink(tableName))
        case Intersection(xs)          => xs.flatMap(aux)
        case Union(xs)                 => xs.flatMap(aux)
        case _                         => Nil
      }
      aux(filter)
    }

    val sortingTableLinks: Seq[SlickTableLink] = Sorting.collect(sorting) {
      case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
    }

    // Combine links from sorting and filter without duplicates
    val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct

    val select = sql"select #$fieldsSql from #$escapedTableName"

    // `concat` inserts no separator, hence every appended fragment starts with a space.
    val joined = foreignTableLinks.foldLeft(select) {
      case (acc, SlickTableLink(keyColumnName, foreignTableName, foreignKeyColumnName)) =>
        acc concat
          sql" inner join #$foreignTableName on #$escapedTableName.#$keyColumnName = #$foreignTableName.#$foreignKeyColumnName"
    }

    // The fragments are appended as builders (not spliced as strings) to keep their bound parameters.
    val filtered = if (isBlank(where)) joined else joined concat sql" where " concat where

    if (countQuery) {
      filtered
    } else {
      val ordered = if (isBlank(orderBy)) filtered else filtered concat sql" order by " concat orderBy
      if (isBlank(limitSql)) ordered else ordered concat sql" " concat limitSql
    }
  }

  /** True when the fragment renders to nothing but whitespace. */
  private def isBlank(builder: SQLActionBuilder): Boolean =
    builder.queryParts.forall(part => String.valueOf(part).trim.isEmpty)

  /**
    * Converts filter expression to SQL expression.
    *
    * @return Returns the SQL fragment; filter values are rendered as `?` placeholders and are bound
    *         to the prepared statement by the returned builder. An empty filter yields an empty fragment.
    */
  protected def filterToSql(escapedTableName: String, filter: SearchFilterExpr)(
          implicit profile: JdbcProfile): SQLActionBuilder = {
    import SearchFilterBinaryOperation._
    import SearchFilterExpr._
    import profile.api._

    // A missing value and the string "null" (in any case) both mean SQL NULL.
    def isNull(value: AnyRef) = Option(value).isEmpty || value.toString.toLowerCase == "null"

    def escapeDimension(dimension: SearchFilterExpr.Dimension) =
      s"$escapedTableName.${dimension.name}"

    // A single `?` placeholder whose value is bound with PreparedStatement.setObject.
    // Binding (instead of splicing the value into the SQL text) keeps string values valid and
    // protects against SQL injection. The cast only boxes the value for the Java API.
    def bound(value: Any): SQLActionBuilder =
      SQLActionBuilder(Seq("?"), new SetParameter[Unit] {
        def apply(u: Unit, pp: PositionedParameters): Unit = {
          val position = pp.pos + 1
          pp.ps.setObject(position, value.asInstanceOf[AnyRef])
          pp.pos = position
        }
      })

    def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect {
      case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x)
    }

    // Joins the conditions with the operator and wraps the result into parentheses.
    def multipleSqlToAction(op: String, conditions: Seq[SQLActionBuilder]): SQLActionBuilder = {
      val joined = conditions.reduceOption((left, right) => left concat sql" #$op " concat right).getOrElse(sql"")
      sql"(" concat joined concat sql")"
    }

    filter match {
      case x if isEmpty(x) =>
        sql""

      // Comparisons rather than bare 1/0 literals: the latter are not boolean expressions in PostgreSQL.
      case AllowAll =>
        sql"1 = 1"

      case DenyAll =>
        sql"1 = 0"

      case Atom.Binary(dimension, Eq, value) if isNull(value) =>
        sql"#${escapeDimension(dimension)} is NULL"

      case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
        sql"#${escapeDimension(dimension)} is not NULL"

      case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
        // In MySQL NULL <> Any === NULL
        // So, to handle NotEq for nullable fields we need to use more complex SQL expression.
        // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
        val escapedColumn = escapeDimension(dimension)
        sql"(#$escapedColumn is null or #$escapedColumn != " concat bound(value) concat sql")"

      case Atom.Binary(dimension, op, value) =>
        val operator = op match {
          case Eq    => "="
          case NotEq => "!="
          case Like  => "like"
          case Gt    => ">"
          case GtEq  => ">="
          case Lt    => "<"
          case LtEq  => "<="
        }
        sql"#${escapeDimension(dimension)} #$operator " concat bound(value)

      case Atom.NAry(dimension, op, values) =>
        val sqlOp = op match {
          case SearchFilterNAryOperation.In    => "in"
          case SearchFilterNAryOperation.NotIn => "not in"
        }
        // An empty list is rendered as (NULL), which never matches any row.
        val formattedValues = values.toSeq
          .map(bound)
          .reduceOption((left, right) => left concat sql", " concat right)
          .getOrElse(sql"NULL")
        sql"#${escapeDimension(dimension)} #$sqlOp (" concat formattedValues concat sql")"

      case Intersection(operands) =>
        multipleSqlToAction("and", filterToSqlMultiple(operands))

      case Union(operands) =>
        multipleSqlToAction("or", filterToSqlMultiple(operands))
    }
  }

  /** Renders the dialect specific limit clause; an empty fragment when there is no pagination. */
  protected def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder

  /**
    * Renders the body of the `order by` clause (without the keyword itself).
    *
    * @param escapedMainTableName Should be escaped
    */
  protected def sortingToSql(escapedMainTableName: String, sorting: Sorting)(
          implicit profile: JdbcProfile): SQLActionBuilder = {
    import profile.api._

    sorting match {
      case Dimension(optSortingTableName, field, order) =>
        val sortingTableName = optSortingTableName.getOrElse(escapedMainTableName)
        val fullName         = s"$sortingTableName.$field"
        sql"#$fullName #${orderToSql(order)}"

      case Sequential(xs) =>
        // Nested sortings that render to nothing are dropped to avoid dangling commas.
        xs.map(sortingToSql(escapedMainTableName, _))
          .filterNot(isBlank)
          .reduceOption((left, right) => left concat sql", " concat right)
          .getOrElse(sql"")
    }
  }

  protected def orderToSql(x: SortingOrder): String = x match {
    case Ascending  => "asc"
    case Descending => "desc"
  }

  /** Sets the bindings on the statement by position (1-based, in list order) and returns the statement. */
  protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = {
    bindings.zipWithIndex.foreach {
      case (binding, index) =>
        bind.setObject(index + 1, binding)
    }

    bind
  }

}

/**
  * Query parameters rendering pagination as `limit <pageSize> OFFSET <offset>`.
  *
  * @param links Links to other tables, keyed by the foreign table name
  */
final case class SlickPostgresQueryBuilderParameters(tableData: SlickQueryBuilder.TableData,
                                                     links: Map[String, SlickTableLink] = Map.empty,
                                                     filter: SearchFilterExpr = SearchFilterExpr.Empty,
                                                     sorting: Sorting = Sorting.Empty,
                                                     pagination: Option[Pagination] = None)
    extends SlickQueryBuilderParameters {

  def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = {
    import profile.api._

    pagination.fold(sql"") { page =>
      // The offset computation treats page numbers as 1-based.
      val startFrom = (page.pageNumber - 1) * page.pageSize
      sql"limit #${page.pageSize} OFFSET #$startFrom"
    }
  }

}

/**
  * Query parameters rendering pagination as `limit <offset>, <pageSize>`.
  *
  * @param links Links to other tables, keyed by the foreign table name
  */
final case class SlickMysqlQueryBuilderParameters(tableData: SlickQueryBuilder.TableData,
                                                  links: Map[String, SlickTableLink] = Map.empty,
                                                  filter: SearchFilterExpr = SearchFilterExpr.Empty,
                                                  sorting: Sorting = Sorting.Empty,
                                                  pagination: Option[Pagination] = None)
    extends SlickQueryBuilderParameters {

  def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder = {
    import profile.api._

    pagination.fold(sql"") { page =>
      // The offset computation treats page numbers as 1-based.
      val startFrom = (page.pageNumber - 1) * page.pageSize
      sql"limit #$startFrom, #${page.pageSize}"
    }
  }

}

/**
  * Base of the query builders: holds the parameters and delegates the execution to the
  * implicitly provided runners. Concrete builders define how a builder with an updated
  * filter, sorting or pagination is produced.
  */
abstract class SlickQueryBuilder[T](val parameters: SlickQueryBuilderParameters)(
        implicit runner: SlickQueryBuilder.Runner[T],
        countRunner: SlickQueryBuilder.CountRunner) {

  /** Runs the query with the current parameters. */
  def run()(implicit ec: ExecutionContext): Future[Seq[T]] = runner(parameters)

  /** Runs the count query with the current parameters. */
  def runCount()(implicit ec: ExecutionContext): SlickQueryBuilder.CountResult = countRunner(parameters)

  /**
    * Runs the query and then the count query, returning the rows together with the count result.
    * The count statement is rendered without the limit clause, so the total ignores pagination.
    */
  def runWithCount()(implicit ec: ExecutionContext): Future[(Seq[T], Int, Option[LocalDateTime])] = {
    for {
      all                 <- run()
      (total, lastUpdate) <- runCount()
    } yield (all, total, lastUpdate)
  }

  def withFilter(newFilter: SearchFilterExpr): SlickQueryBuilder[T]

  /** Applies the filter when it is defined, otherwise returns this builder unchanged. */
  def withFilter(filter: Option[SearchFilterExpr]): SlickQueryBuilder[T] = {
    filter.fold(this)(withFilter)
  }

  def resetFilter: SlickQueryBuilder[T] = withFilter(SearchFilterExpr.Empty)

  def withSorting(newSorting: Sorting): SlickQueryBuilder[T]

  /** Applies the sorting when it is defined, otherwise returns this builder unchanged. */
  def withSorting(sorting: Option[Sorting]): SlickQueryBuilder[T] = {
    sorting.fold(this)(withSorting)
  }

  def resetSorting: SlickQueryBuilder[T] = withSorting(Sorting.Empty)

  def withPagination(newPagination: Pagination): SlickQueryBuilder[T]

  /** Applies the pagination when it is defined, otherwise returns this builder unchanged. */
  def withPagination(pagination: Option[Pagination]): SlickQueryBuilder[T] = {
    pagination.fold(this)(withPagination)
  }

  def resetPagination: SlickQueryBuilder[T]

}