package xyz.driver.pdsuicommon.db

import java.sql.PreparedStatement
import java.time.LocalDateTime

import io.getquill.NamingStrategy
import io.getquill.context.sql.idiom.SqlIdiom
import xyz.driver.pdsuicommon.db.Sorting.{Dimension, Sequential}
import xyz.driver.pdsuicommon.db.SortingOrder.{Ascending, Descending}

import scala.collection.mutable.ListBuffer

object QueryBuilder {

  /** Executes the query described by the parameters and returns the resulting rows. */
  type Runner[T] = QueryBuilderParameters => Seq[T]

  /** A row count paired with an optional last-update timestamp. */
  type CountResult = (Int, Option[LocalDateTime])

  /** Executes the count query described by the parameters. */
  type CountRunner = QueryBuilderParameters => CountResult

  /**
   * Binds query parameters to a PreparedStatement.
   */
  type Binder = PreparedStatement => PreparedStatement

  /**
   * Describes the main table of a query.
   *
   * @param tableName           raw table name; it is escaped through the NamingStrategy when SQL is built
   * @param lastUpdateFieldName if set, count queries additionally select `max` of this column
   * @param nullableFields      columns for which a `NotEq` filter also matches rows where the column is NULL
   */
  final case class TableData(tableName: String,
                             lastUpdateFieldName: Option[String] = None,
                             nullableFields: Set[String] = Set.empty)

  val AllFields = Set("*")
}

/**
 * Inner-join condition `main.keyColumnName = foreignTableName.foreignKeyColumnName`.
 */
final case class TableLink(keyColumnName: String, foreignTableName: String, foreignKeyColumnName: String)

object QueryBuilderParameters {
  val AllFields = Set("*")
}

sealed trait QueryBuilderParameters {

  def tableData: QueryBuilder.TableData
  def links: Map[String, TableLink]
  def filter: SearchFilterExpr
  def sorting: Sorting
  def pagination: Option[Pagination]

  /**
   * @return the link used to join `tableName` to the main table
   * @throws IllegalArgumentException if no link is registered for `tableName`
   */
  def findLink(tableName: String): TableLink =
    links.getOrElse(tableName, throw new IllegalArgumentException(s"Cannot find a link for `$tableName`"))

  /** Same as the three-argument `toSql`, selecting all columns of the main table. */
  def toSql(countQuery: Boolean = false, namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = {
    toSql(countQuery, QueryBuilderParameters.AllFields, namingStrategy)
  }

  /**
   * Builds the SQL text and the binder that sets its `?` placeholders.
   *
   * @param countQuery     if true, selects `count(*)` (plus the max last-update column when configured)
   *                       and omits ordering and pagination
   * @param fields         columns of the main table to select; `AllFields` selects `table.*`
   * @param namingStrategy used to escape table and column names
   */
  def toSql(countQuery: Boolean,
            fields: Set[String],
            namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = {
    val escapedTableName  = namingStrategy.table(tableData.tableName)
    val fieldsSql         = selectListToSql(escapedTableName, countQuery, fields, namingStrategy)
    val (where, bindings) = filterToSql(escapedTableName, filter, namingStrategy)
    val orderBy           = sortingToSql(escapedTableName, sorting, namingStrategy)
    val limitSql          = limitToSql()

    val sql = new StringBuilder()
    sql.append("select ").append(fieldsSql)
    sql.append("\nfrom ").append(escapedTableName)

    foreignTableLinks.foreach { link =>
      val escapedForeignTableName = namingStrategy.table(link.foreignTableName)
      sql
        .append("\ninner join ")
        .append(escapedForeignTableName)
        .append(" on ")
        .append(escapedTableName)
        .append('.')
        .append(namingStrategy.column(link.keyColumnName))
        .append(" = ")
        .append(escapedForeignTableName)
        .append('.')
        .append(namingStrategy.column(link.foreignKeyColumnName))
    }

    if (where.nonEmpty) {
      sql.append("\nwhere ").append(where)
    }
    // Ordering and pagination are not applied to count queries.
    if (orderBy.nonEmpty && !countQuery) {
      sql.append("\norder by ").append(orderBy)
    }
    if (limitSql.nonEmpty && !countQuery) {
      sql.append("\n").append(limitSql)
    }

    (sql.toString, binder(bindings))
  }

  /** Builds the select list of the query. */
  private def selectListToSql(escapedTableName: String,
                              countQuery: Boolean,
                              fields: Set[String],
                              namingStrategy: NamingStrategy): String = {
    if (countQuery) {
      val maxLastUpdate = tableData.lastUpdateFieldName.fold("") { lastUpdateField =>
        s", max($escapedTableName.${namingStrategy.column(lastUpdateField)})"
      }
      "count(*)" + maxLastUpdate
    } else if (fields == QueryBuilderParameters.AllFields) {
      s"$escapedTableName.*"
    } else {
      fields
        .map(field => s"$escapedTableName.${namingStrategy.column(field)}")
        .mkString(", ")
    }
  }

  /** Links for every foreign table referenced by the filter or the sorting, in first-use order, without duplicates. */
  private def foreignTableLinks: Seq[TableLink] = {
    val filterLinks: Seq[TableLink] = {
      import SearchFilterExpr._
      def collectLinks(expr: SearchFilterExpr): Seq[TableLink] = expr match {
        case Atom.TableName(tableName) => List(findLink(tableName))
        case Intersection(xs)          => xs.flatMap(collectLinks)
        case Union(xs)                 => xs.flatMap(collectLinks)
        case _                         => Nil
      }
      collectLinks(filter)
    }
    // Kept outside the block above so that `Dimension` refers to `Sorting.Dimension`,
    // not to anything brought in by `SearchFilterExpr._`.
    val sortingLinks: Seq[TableLink] = Sorting.collect(sorting) {
      case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
    }
    (filterLinks ++ sortingLinks).distinct
  }

  /**
   * SQL literal for an always-true condition. Dialects that require boolean operands in
   * WHERE/AND/OR override this.
   */
  protected def trueLiteral: String = "1"

  /** SQL literal for an always-false condition; see `trueLiteral`. */
  protected def falseLiteral: String = "0"

  /**
   * Converts filter expression to SQL expression.
   *
   * @return Returns SQL string and list of values for binding in prepared statement.
   *         An empty SQL string means "no condition".
   */
  protected def filterToSql(escapedTableName: String,
                            filter: SearchFilterExpr,
                            namingStrategy: NamingStrategy): (String, List[AnyRef]) = {
    import SearchFilterBinaryOperation._
    import SearchFilterExpr._

    // Both a real null and the string "null" (any case) are treated as SQL NULL.
    def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null"

    def placeholder(field: String) = "?"

    def escapeDimension(dimension: SearchFilterExpr.Dimension) = {
      val tableName = dimension.tableName.fold(escapedTableName)(namingStrategy.table)
      s"$tableName.${namingStrategy.column(dimension.name)}"
    }

    // Translates the non-empty operands and joins them with `separator`. When nothing is left,
    // returns an empty condition instead of the invalid SQL `()`.
    def combineOperands(operands: Seq[SearchFilterExpr], separator: String): (String, List[AnyRef]) = {
      val translated = operands
        .filterNot(x => SearchFilterExpr.isEmpty(x))
        .map(x => filterToSql(escapedTableName, x, namingStrategy))
        .filter { case (sql, _) => sql.nonEmpty }
      if (translated.isEmpty) {
        ("", List.empty)
      } else {
        val (sqls, bindings) = translated.unzip
        (sqls.mkString("(", separator, ")"), bindings.flatten.toList)
      }
    }

    filter match {
      case x if isEmpty(x) => ("", List.empty)

      case AllowAll => (trueLiteral, List.empty)

      case DenyAll => (falseLiteral, List.empty)

      case Atom.Binary(dimension, Eq, value) if isNull(value) =>
        (s"${escapeDimension(dimension)} is NULL", List.empty)

      case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
        (s"${escapeDimension(dimension)} is not NULL", List.empty)

      case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
        // In MySQL NULL <> Any === NULL
        // So, to handle NotEq for nullable fields we need to use more complex SQL expression.
        // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
        val escapedColumn = escapeDimension(dimension)
        val sql           = s"($escapedColumn is null or $escapedColumn != ${placeholder(dimension.name)})"
        (sql, List(value))

      case Atom.Binary(dimension, op, value) =>
        val operator = op match {
          case Eq    => "="
          case NotEq => "!="
          case Like  => "like"
          case Gt    => ">"
          case GtEq  => ">="
          case Lt    => "<"
          case LtEq  => "<="
        }
        (s"${escapeDimension(dimension)} $operator ${placeholder(dimension.name)}", List(value))

      case Atom.NAry(_, SearchFilterNAryOperation.NotIn, values) if values.isEmpty =>
        // Excluding nothing matches every row. `col not in (NULL)` would evaluate to NULL for
        // every row and therefore match nothing.
        (trueLiteral, List.empty)

      case Atom.NAry(dimension, op, values) =>
        val sqlOp = op match {
          case SearchFilterNAryOperation.In    => "in"
          case SearchFilterNAryOperation.NotIn => "not in"
        }
        val bindings = ListBuffer[AnyRef]()
        values.foreach(value => bindings += value)
        // One placeholder per binding. An empty list (only reachable for `in`) becomes `(NULL)`,
        // because `in ()` is not valid SQL and `in (NULL)` matches no rows.
        val formattedValues =
          if (bindings.nonEmpty) bindings.map(_ => placeholder(dimension.name)).mkString(", ")
          else "NULL"
        (s"${escapeDimension(dimension)} $sqlOp ($formattedValues)", bindings.toList)

      case Intersection(operands) => combineOperands(operands, " and ")

      case Union(operands) => combineOperands(operands, " or ")
    }
  }

  /** Dialect-specific pagination clause; an empty string means no pagination. */
  protected def limitToSql(): String

  /**
   * @param escapedMainTableName Should be escaped
   */
  protected def sortingToSql(escapedMainTableName: String, sorting: Sorting, namingStrategy: NamingStrategy): String = {
    sorting match {
      case Dimension(optSortingTableName, field, order) =>
        val sortingTableName = optSortingTableName.fold(escapedMainTableName)(namingStrategy.table)
        s"$sortingTableName.${namingStrategy.column(field)} ${orderToSql(order)}"

      case Sequential(xs) =>
        xs.map(sortingToSql(escapedMainTableName, _, namingStrategy)).mkString(", ")
    }
  }

  protected def orderToSql(x: SortingOrder): String = x match {
    case Ascending  => "asc"
    case Descending => "desc"
  }

  /** Sets `bindings` on `bind` in order and returns the same statement. */
  protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = {
    // JDBC parameter indices are 1-based.
    bindings.zipWithIndex.foreach {
      case (binding, index) =>
        bind.setObject(index + 1, binding)
    }
    bind
  }
}

final case class PostgresQueryBuilderParameters(tableData: QueryBuilder.TableData,
                                                links: Map[String, TableLink] = Map.empty,
                                                filter: SearchFilterExpr = SearchFilterExpr.Empty,
                                                sorting: Sorting = Sorting.Empty,
                                                pagination: Option[Pagination] = None)
    extends QueryBuilderParameters {

  // PostgreSQL requires boolean operands in WHERE/AND/OR and rejects the integers 1/0 there.
  override protected def trueLiteral: String  = "TRUE"
  override protected def falseLiteral: String = "FALSE"

  def limitToSql(): String = {
    pagination.fold("") { pagination =>
      val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
      s"limit ${pagination.pageSize} OFFSET $startFrom"
    }
  }
}

/**
 * @param links Links to other tables, keyed by the foreign table name
 */
final case class MysqlQueryBuilderParameters(tableData: QueryBuilder.TableData,
                                             links: Map[String, TableLink] = Map.empty,
                                             filter: SearchFilterExpr = SearchFilterExpr.Empty,
                                             sorting: Sorting = Sorting.Empty,
                                             pagination: Option[Pagination] = None)
    extends QueryBuilderParameters {

  def limitToSql(): String =
    pagination.fold("") { pagination =>
      val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
      s"limit $startFrom, ${pagination.pageSize}"
    }
}

abstract class QueryBuilder[T, D <: SqlIdiom, N <: NamingStrategy](val parameters: QueryBuilderParameters)(
    implicit runner: QueryBuilder.Runner[T],
    countRunner: QueryBuilder.CountRunner) {

  def run: Seq[T] = runner(parameters)

  def runCount: QueryBuilder.CountResult = countRunner(parameters)

  /**
   * Runs the query and also returns the total number of matching rows (ignoring pagination)
   * and the last-update value reported by the count runner.
   * The count and the row query are executed as two separate calls.
   */
  def runWithCount: (Seq[T], Int, Option[LocalDateTime]) = {
    val (total, lastUpdate) = runCount
    (run, total, lastUpdate)
  }

  def withFilter(newFilter: SearchFilterExpr): QueryBuilder[T, D, N]

  def withFilter(filter: Option[SearchFilterExpr]): QueryBuilder[T, D, N] = {
    filter.fold(this)(withFilter)
  }

  def resetFilter: QueryBuilder[T, D, N] = withFilter(SearchFilterExpr.Empty)

  def withSorting(newSorting: Sorting): QueryBuilder[T, D, N]

  def withSorting(sorting: Option[Sorting]): QueryBuilder[T, D, N] = {
    sorting.fold(this)(withSorting)
  }

  def resetSorting: QueryBuilder[T, D, N] = withSorting(Sorting.Empty)

  def withPagination(newPagination: Pagination): QueryBuilder[T, D, N]

  def withPagination(pagination: Option[Pagination]): QueryBuilder[T, D, N] = {
    pagination.fold(this)(withPagination)
  }

  def resetPagination: QueryBuilder[T, D, N]
}