package xyz.driver.pdsuicommon.db
import java.sql.PreparedStatement
import java.time.LocalDateTime
import io.getquill.NamingStrategy
import io.getquill.context.sql.idiom.SqlIdiom
import xyz.driver.pdsuicommon.db.Sorting.{Dimension, Sequential}
import xyz.driver.pdsuicommon.db.SortingOrder.{Ascending, Descending}
import scala.collection.mutable.ListBuffer
/**
  * Type aliases and table metadata shared by [[QueryBuilder]] and [[QueryBuilderParameters]].
  */
object QueryBuilder {

  /** Executes the query described by the parameters and returns the resulting rows. */
  type Runner[T] = QueryBuilderParameters => Seq[T]

  /**
    * Result of a count query: a count and an optional timestamp.
    * NOTE(review): presumably these are the `count(*)` and `max(<lastUpdateField>)` columns
    * selected by `QueryBuilderParameters.toSql(countQuery = true, ...)` - confirm against the runners.
    */
  type CountResult = (Int, Option[LocalDateTime])

  /** Executes the count variant of the query described by the parameters. */
  type CountRunner = QueryBuilderParameters => CountResult

  /**
    * Binder for PreparedStatement: takes a statement and returns a statement.
    */
  type Binder = PreparedStatement => PreparedStatement

  /**
    * Metadata of the main table of a query.
    *
    * @param tableName           Unescaped table name; it is escaped with the naming strategy when SQL is built.
    * @param lastUpdateFieldName Column whose `max(...)` is added to count queries, if any.
    * @param nullableFields      Column names for which a `!=` filter is rendered with an extra `is null` alternative.
    */
  final case class TableData(tableName: String,
                             lastUpdateFieldName: Option[String] = None,
                             nullableFields: Set[String] = Set.empty)

  /** Field set consisting of the single wildcard `*`. */
  val AllFields: Set[String] = Set("*")
}
/**
  * Describes how the main table is joined to a foreign table. It is rendered as
  * `inner join <foreignTableName> on <mainTable>.<keyColumnName> = <foreignTableName>.<foreignKeyColumnName>`.
  *
  * @param keyColumnName        Column of the main table used in the join condition.
  * @param foreignTableName     Table that is joined.
  * @param foreignKeyColumnName Column of the foreign table used in the join condition.
  */
final case class TableLink(
    keyColumnName: String,
    foreignTableName: String,
    foreignKeyColumnName: String
)
object QueryBuilderParameters {

  /**
    * Field set consisting of the single wildcard `*`.
    * When `toSql` receives exactly this set it selects `<table>.*` instead of listing columns.
    */
  val AllFields: Set[String] = Set("*")
}
/**
  * Describes a query over one main table (plus optional inner joins, filter, sorting and pagination)
  * and renders it to a SQL string together with a binder for its positional parameters.
  *
  * Implementations provide the dialect-specific pagination clause through [[limitToSql]].
  */
sealed trait QueryBuilderParameters {

  /** Main table name plus the column metadata used for count queries and `!=` filters. */
  def tableData: QueryBuilder.TableData

  /** Join descriptions, looked up by table name in [[findLink]]. */
  def links: Map[String, TableLink]

  /** Filter rendered into the `where` clause. */
  def filter: SearchFilterExpr

  /** Sorting rendered into the `order by` clause (omitted for count queries). */
  def sorting: Sorting

  /** Optional pagination; implementations render it in [[limitToSql]]. */
  def pagination: Option[Pagination]

  /**
    * Looks up the link registered for the given table name.
    *
    * @throws IllegalArgumentException if `links` has no entry for `tableName`
    */
  def findLink(tableName: String): TableLink =
    links.getOrElse(tableName, throw new IllegalArgumentException(s"Cannot find a link for `$tableName`"))

  /** Same as the three-argument `toSql`, selecting all fields of the main table. */
  def toSql(countQuery: Boolean = false, namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = {
    toSql(countQuery, QueryBuilderParameters.AllFields, namingStrategy)
  }

  /**
    * Renders the query to SQL.
    *
    * @param countQuery     When true, selects `count(*)` (plus `max` of the last-update column, if configured)
    *                       and omits the `order by` and pagination clauses.
    * @param fields         Columns of the main table to select; `QueryBuilderParameters.AllFields` selects `<table>.*`.
    *                       Ignored for count queries.
    * @param namingStrategy Strategy used to escape table and column names.
    * @return The SQL text and a binder that sets the filter values on a prepared statement.
    * @throws IllegalArgumentException if the filter or the sorting references a table that has no entry in `links`
    */
  def toSql(countQuery: Boolean, fields: Set[String], namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = {
    val escapedTableName = namingStrategy.table(tableData.tableName)

    val fieldsSql: String =
      if (countQuery) {
        val lastUpdateSql = tableData.lastUpdateFieldName.fold("") { lastUpdateField =>
          s", max($escapedTableName.${namingStrategy.column(lastUpdateField)})"
        }
        "count(*)" + lastUpdateSql
      } else if (fields == QueryBuilderParameters.AllFields) {
        s"$escapedTableName.*"
      } else {
        fields
          .map(field => s"$escapedTableName.${namingStrategy.column(field)}")
          .mkString(", ")
      }

    val (where, bindings) = filterToSql(escapedTableName, filter, namingStrategy)
    val orderBy           = sortingToSql(escapedTableName, sorting, namingStrategy)
    val limitSql          = limitToSql()

    val sql = new StringBuilder()
    sql.append("select ")
    sql.append(fieldsSql)
    sql.append("\nfrom ")
    sql.append(escapedTableName)

    // Links required by filter atoms that reference another table.
    val filtersTableLinks: Seq[TableLink] = {
      import SearchFilterExpr._
      def aux(expr: SearchFilterExpr): Seq[TableLink] = expr match {
        case Atom.TableName(tableName) => List(findLink(tableName))
        case Intersection(xs)          => xs.flatMap(aux)
        case Union(xs)                 => xs.flatMap(aux)
        case _                         => Nil
      }
      aux(filter)
    }

    // Links required by sorting dimensions that reference another table.
    val sortingTableLinks: Seq[TableLink] = Sorting.collect(sorting) {
      case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
    }

    // Combine links from sorting and filter without duplicates
    val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct

    foreignTableLinks.foreach {
      case TableLink(keyColumnName, foreignTableName, foreignKeyColumnName) =>
        val escapedForeignTableName = namingStrategy.table(foreignTableName)
        sql.append("\ninner join ")
        sql.append(escapedForeignTableName)
        sql.append(" on ")
        sql.append(escapedTableName)
        sql.append('.')
        sql.append(namingStrategy.column(keyColumnName))
        sql.append(" = ")
        sql.append(escapedForeignTableName)
        sql.append('.')
        sql.append(namingStrategy.column(foreignKeyColumnName))
    }

    if (where.nonEmpty) {
      sql.append("\nwhere ")
      sql.append(where)
    }

    // Ordering and pagination are meaningless for an aggregate count, so they are skipped there.
    if (orderBy.nonEmpty && !countQuery) {
      sql.append("\norder by ")
      sql.append(orderBy)
    }

    if (limitSql.nonEmpty && !countQuery) {
      sql.append("\n")
      sql.append(limitSql)
    }

    (sql.toString, binder(bindings))
  }

  /**
    * Converts filter expression to SQL expression.
    *
    * An expression that renders to nothing (an empty filter, or a group whose operands all render
    * to nothing) yields an empty SQL string, so the caller can omit the `where` clause.
    *
    * @param escapedTableName Already escaped name of the main table; used for dimensions without a table name.
    * @return Returns SQL string and list of values for binding in prepared statement.
    */
  protected def filterToSql(escapedTableName: String,
                            filter: SearchFilterExpr,
                            namingStrategy: NamingStrategy): (String, List[AnyRef]) = {
    import SearchFilterBinaryOperation._
    import SearchFilterExpr._

    // A value counts as NULL when it is a null reference or its string form is "null" in any letter case.
    def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null"

    // Every value is bound positionally, so the placeholder does not depend on the field.
    def placeholder(field: String) = "?"

    def escapeDimension(dimension: SearchFilterExpr.Dimension) = {
      val tableName = dimension.tableName.fold(escapedTableName)(namingStrategy.table)
      s"$tableName.${namingStrategy.column(dimension.name)}"
    }

    // Renders the non-empty operands and drops those that produce no SQL, so a group never
    // ends up with a dangling connective such as "(a and )".
    def filterToSqlMultiple(operands: Seq[SearchFilterExpr]): Seq[(String, List[AnyRef])] =
      operands
        .collect {
          case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x, namingStrategy)
        }
        .filter { case (operandSql, _) => operandSql.nonEmpty }

    // Joins the rendered operands with the connective; a group without any rendered operand
    // yields an empty filter instead of the invalid SQL "()".
    def groupToSql(operands: Seq[SearchFilterExpr], connective: String): (String, List[AnyRef]) = {
      val (sqlParts, bindings) = filterToSqlMultiple(operands).unzip
      if (sqlParts.isEmpty) ("", List.empty)
      else (sqlParts.mkString("(", connective, ")"), bindings.flatten.toList)
    }

    filter match {
      case x if isEmpty(x) =>
        ("", List.empty)

      // Boolean-valued comparisons are used instead of bare 1/0 because PostgreSQL requires
      // the argument of WHERE/AND/OR to be of type boolean; MySQL accepts both forms.
      case AllowAll =>
        ("1 = 1", List.empty)

      case DenyAll =>
        ("1 = 0", List.empty)

      case Atom.Binary(dimension, Eq, value) if isNull(value) =>
        (s"${escapeDimension(dimension)} is NULL", List.empty)

      case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
        (s"${escapeDimension(dimension)} is not NULL", List.empty)

      case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
        // In MySQL NULL <> Any === NULL
        // So, to handle NotEq for nullable fields we need to use more complex SQL expression.
        // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
        val escapedColumn = escapeDimension(dimension)
        val sql           = s"($escapedColumn is null or $escapedColumn != ${placeholder(dimension.name)})"
        (sql, List(value))

      case Atom.Binary(dimension, op, value) =>
        val operator = op match {
          case Eq    => "="
          case NotEq => "!="
          case Like  => "like"
          case Gt    => ">"
          case GtEq  => ">="
          case Lt    => "<"
          case LtEq  => "<="
        }
        (s"${escapeDimension(dimension)} $operator ${placeholder(dimension.name)}", List(value))

      case Atom.NAry(dimension, op, values) =>
        val sqlOp = op match {
          case SearchFilterNAryOperation.In    => "in"
          case SearchFilterNAryOperation.NotIn => "not in"
        }

        val bindings       = ListBuffer[AnyRef]()
        val sqlPlaceholder = placeholder(dimension.name)

        // One placeholder per value; an empty value list is rendered as "(NULL)" to keep the SQL valid.
        val formattedValues = if (values.nonEmpty) {
          values
            .map { value =>
              bindings += value
              sqlPlaceholder
            }
            .mkString(", ")
        } else "NULL"
        (s"${escapeDimension(dimension)} $sqlOp ($formattedValues)", bindings.toList)

      case Intersection(operands) =>
        groupToSql(operands, " and ")

      case Union(operands) =>
        groupToSql(operands, " or ")
    }
  }

  /** Renders the dialect-specific pagination clause. */
  protected def limitToSql(): String

  /**
    * Renders the sorting as a comma-separated `order by` list (without the `order by` keyword).
    *
    * @param escapedMainTableName Should be escaped
    */
  protected def sortingToSql(escapedMainTableName: String, sorting: Sorting, namingStrategy: NamingStrategy): String = {
    sorting match {
      case Dimension(optSortingTableName, field, order) =>
        // Dimensions without an explicit table refer to the main table.
        val sortingTableName = optSortingTableName.fold(escapedMainTableName)(namingStrategy.table)
        val fullName         = s"$sortingTableName.${namingStrategy.column(field)}"
        s"$fullName ${orderToSql(order)}"

      case Sequential(xs) =>
        xs.map(sortingToSql(escapedMainTableName, _, namingStrategy)).mkString(", ")
    }
  }

  /** Maps the sorting order to its SQL keyword. */
  protected def orderToSql(x: SortingOrder): String = x match {
    case Ascending  => "asc"
    case Descending => "desc"
  }

  /**
    * Sets the given values on the prepared statement, in list order, and returns the same statement.
    */
  protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = {
    bindings.zipWithIndex.foreach {
      case (binding, index) =>
        // JDBC parameter indices are 1-based.
        bind.setObject(index + 1, binding)
    }
    bind
  }
}
/**
  * [[QueryBuilderParameters]] whose pagination clause is rendered as `limit <pageSize> OFFSET <startFrom>`.
  */
final case class PostgresQueryBuilderParameters(tableData: QueryBuilder.TableData,
                                                links: Map[String, TableLink] = Map.empty,
                                                filter: SearchFilterExpr = SearchFilterExpr.Empty,
                                                sorting: Sorting = Sorting.Empty,
                                                pagination: Option[Pagination] = None)
    extends QueryBuilderParameters {

  /**
    * Renders the pagination clause, or an empty string when no pagination is set.
    * The offset is `(pageNumber - 1) * pageSize`, so page number 1 starts at offset 0.
    */
  def limitToSql(): String = pagination match {
    case Some(page) =>
      val offset = (page.pageNumber - 1) * page.pageSize
      s"limit ${page.pageSize} OFFSET $offset"
    case None =>
      ""
  }
}
/**
  * [[QueryBuilderParameters]] whose pagination clause is rendered as `limit <startFrom>, <pageSize>`.
  *
  * @param links Links to other tables, keyed by foreign table name
  */
final case class MysqlQueryBuilderParameters(tableData: QueryBuilder.TableData,
                                             links: Map[String, TableLink] = Map.empty,
                                             filter: SearchFilterExpr = SearchFilterExpr.Empty,
                                             sorting: Sorting = Sorting.Empty,
                                             pagination: Option[Pagination] = None)
    extends QueryBuilderParameters {

  /**
    * Renders the pagination clause, or an empty string when no pagination is set.
    * The offset is `(pageNumber - 1) * pageSize`, so page number 1 starts at offset 0.
    */
  def limitToSql(): String =
    pagination.fold("") { page =>
      val offset = (page.pageNumber - 1) * page.pageSize
      s"limit $offset, ${page.pageSize}"
    }
}
/**
  * Runs the query described by `parameters` through the implicitly supplied runners and offers
  * `with...`/`reset...` methods that return a builder with a different filter, sorting or pagination.
  *
  * @param parameters  Description of the query handed to both runners.
  * @param runner      Executes the row query.
  * @param countRunner Executes the count query.
  * @tparam T Row type returned by `run`.
  * @tparam D SQL idiom; not referenced by the members defined in this class.
  * @tparam N Naming strategy; not referenced by the members defined in this class.
  */
abstract class QueryBuilder[T, D <: SqlIdiom, N <: NamingStrategy](val parameters: QueryBuilderParameters)(
    implicit runner: QueryBuilder.Runner[T],
    countRunner: QueryBuilder.CountRunner) {

  /** Runs the row query. */
  def run: Seq[T] = runner(parameters)

  /** Runs the count query. */
  def runCount: QueryBuilder.CountResult = countRunner(parameters)

  /**
    * Runs the query and returns total found rows without considering of pagination.
    * The count query is executed before the row query.
    */
  def runWithCount: (Seq[T], Int, Option[LocalDateTime]) = {
    val countResult = runCount
    val rows        = run
    (rows, countResult._1, countResult._2)
  }

  /** Returns a builder that uses the given filter. */
  def withFilter(newFilter: SearchFilterExpr): QueryBuilder[T, D, N]

  /** Applies the filter when one is given; otherwise returns this builder unchanged. */
  def withFilter(filter: Option[SearchFilterExpr]): QueryBuilder[T, D, N] = filter match {
    case Some(expr) => withFilter(expr)
    case None       => this
  }

  /** Returns a builder with the empty filter. */
  def resetFilter: QueryBuilder[T, D, N] = withFilter(SearchFilterExpr.Empty)

  /** Returns a builder that uses the given sorting. */
  def withSorting(newSorting: Sorting): QueryBuilder[T, D, N]

  /** Applies the sorting when one is given; otherwise returns this builder unchanged. */
  def withSorting(sorting: Option[Sorting]): QueryBuilder[T, D, N] = sorting match {
    case Some(order) => withSorting(order)
    case None        => this
  }

  /** Returns a builder with the empty sorting. */
  def resetSorting: QueryBuilder[T, D, N] = withSorting(Sorting.Empty)

  /** Returns a builder that uses the given pagination. */
  def withPagination(newPagination: Pagination): QueryBuilder[T, D, N]

  /** Applies the pagination when one is given; otherwise returns this builder unchanged. */
  def withPagination(pagination: Option[Pagination]): QueryBuilder[T, D, N] = pagination match {
    case Some(page) => withPagination(page)
    case None       => this
  }

  /** Returns a builder without pagination; left to implementations. */
  def resetPagination: QueryBuilder[T, D, N]
}