diff options
Diffstat (limited to 'src/main/scala/xyz/driver/pdsuicommon/db/QueryBuilder.scala')
-rw-r--r-- | src/main/scala/xyz/driver/pdsuicommon/db/QueryBuilder.scala | 340 |
1 files changed, 0 insertions, 340 deletions
diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/QueryBuilder.scala b/src/main/scala/xyz/driver/pdsuicommon/db/QueryBuilder.scala deleted file mode 100644 index 0bf1ed6..0000000 --- a/src/main/scala/xyz/driver/pdsuicommon/db/QueryBuilder.scala +++ /dev/null @@ -1,340 +0,0 @@ -package xyz.driver.pdsuicommon.db - -import java.sql.PreparedStatement -import java.time.LocalDateTime - -import io.getquill.NamingStrategy -import io.getquill.context.sql.idiom.SqlIdiom -import xyz.driver.pdsuicommon.db.Sorting.{Dimension, Sequential} -import xyz.driver.pdsuicommon.db.SortingOrder.{Ascending, Descending} - -import scala.collection.mutable.ListBuffer - -object QueryBuilder { - - type Runner[T] = QueryBuilderParameters => Seq[T] - - type CountResult = (Int, Option[LocalDateTime]) - - type CountRunner = QueryBuilderParameters => CountResult - - /** - * Binder for PreparedStatement - */ - type Binder = PreparedStatement => PreparedStatement - - final case class TableData(tableName: String, - lastUpdateFieldName: Option[String] = None, - nullableFields: Set[String] = Set.empty) - - val AllFields = Set("*") - -} - -final case class TableLink(keyColumnName: String, foreignTableName: String, foreignKeyColumnName: String) - -object QueryBuilderParameters { - val AllFields = Set("*") -} - -sealed trait QueryBuilderParameters { - - def tableData: QueryBuilder.TableData - def links: Map[String, TableLink] - def filter: SearchFilterExpr - def sorting: Sorting - def pagination: Option[Pagination] - - def findLink(tableName: String): TableLink = links.get(tableName) match { - case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`") - case Some(link) => link - } - - def toSql(countQuery: Boolean = false, namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = { - toSql(countQuery, QueryBuilderParameters.AllFields, namingStrategy) - } - - def toSql(countQuery: Boolean, fields: Set[String], namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = { - val escapedTableName = namingStrategy.table(tableData.tableName) - val fieldsSql: String = if (countQuery) { - val suffix: String = tableData.lastUpdateFieldName match { - case Some(lastUpdateField) => s", max($escapedTableName.${namingStrategy.column(lastUpdateField)})" - case None => "" - } - "count(*)" + suffix - } else { - if (fields == QueryBuilderParameters.AllFields) { - s"$escapedTableName.*" - } else { - fields - .map { field => - s"$escapedTableName.${namingStrategy.column(field)}" - } - .mkString(", ") - } - } - val (where, bindings) = filterToSql(escapedTableName, filter, namingStrategy) - val orderBy = sortingToSql(escapedTableName, sorting, namingStrategy) - - val limitSql = limitToSql() - - val sql = new StringBuilder() - sql.append("select ") - sql.append(fieldsSql) - sql.append("\nfrom ") - sql.append(escapedTableName) - - val filtersTableLinks: Seq[TableLink] = { - import SearchFilterExpr._ - def aux(expr: SearchFilterExpr): Seq[TableLink] = expr match { - case Atom.TableName(tableName) => List(findLink(tableName)) - case Intersection(xs) => xs.flatMap(aux) - case Union(xs) => xs.flatMap(aux) - case _ => Nil - } - aux(filter) - } - - val sortingTableLinks: Seq[TableLink] = Sorting.collect(sorting) { - case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName) - } - - // Combine links from sorting and filter without duplicates - val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct - - foreignTableLinks.foreach { - case TableLink(keyColumnName, foreignTableName, foreignKeyColumnName) => - val escapedForeignTableName = namingStrategy.table(foreignTableName) - - sql.append("\ninner join ") - sql.append(escapedForeignTableName) - sql.append(" on ") - - sql.append(escapedTableName) - sql.append('.') - sql.append(namingStrategy.column(keyColumnName)) - - sql.append(" = ") - - sql.append(escapedForeignTableName) - sql.append('.') - sql.append(namingStrategy.column(foreignKeyColumnName)) - } - - if (where.nonEmpty) { - sql.append("\nwhere ") - sql.append(where) - } - - if (orderBy.nonEmpty && !countQuery) { - sql.append("\norder by ") - sql.append(orderBy) - } - - if (limitSql.nonEmpty && !countQuery) { - sql.append("\n") - sql.append(limitSql) - } - - (sql.toString, binder(bindings)) - } - - /** - * Converts filter expression to SQL expression. - * - * @return Returns SQL string and list of values for binding in prepared statement. - */ - protected def filterToSql(escapedTableName: String, - filter: SearchFilterExpr, - namingStrategy: NamingStrategy): (String, List[AnyRef]) = { - import SearchFilterBinaryOperation._ - import SearchFilterExpr._ - - def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null" - - def placeholder(field: String) = "?" - - def escapeDimension(dimension: SearchFilterExpr.Dimension) = { - val tableName = dimension.tableName.fold(escapedTableName)(namingStrategy.table) - s"$tableName.${namingStrategy.column(dimension.name)}" - } - - def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect { - case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x, namingStrategy) - } - - filter match { - case x if isEmpty(x) => - ("", List.empty) - - case AllowAll => - ("1", List.empty) - - case DenyAll => - ("0", List.empty) - - case Atom.Binary(dimension, Eq, value) if isNull(value) => - (s"${escapeDimension(dimension)} is NULL", List.empty) - - case Atom.Binary(dimension, NotEq, value) if isNull(value) => - (s"${escapeDimension(dimension)} is not NULL", List.empty) - - case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) => - // In MySQL NULL <> Any === NULL - // So, to handle NotEq for nullable fields we need to use more complex SQL expression. - // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html - val escapedColumn = escapeDimension(dimension) - val sql = s"($escapedColumn is null or $escapedColumn != ${placeholder(dimension.name)})" - (sql, List(value)) - - case Atom.Binary(dimension, op, value) => - val operator = op match { - case Eq => "=" - case NotEq => "!=" - case Like => "like" - case Gt => ">" - case GtEq => ">=" - case Lt => "<" - case LtEq => "<=" - } - (s"${escapeDimension(dimension)} $operator ${placeholder(dimension.name)}", List(value)) - - case Atom.NAry(dimension, op, values) => - val sqlOp = op match { - case SearchFilterNAryOperation.In => "in" - case SearchFilterNAryOperation.NotIn => "not in" - } - - val bindings = ListBuffer[AnyRef]() - val sqlPlaceholder = placeholder(dimension.name) - val formattedValues = if (values.nonEmpty) { - values - .map { value => - bindings += value - sqlPlaceholder - } - .mkString(", ") - } else "NULL" - (s"${escapeDimension(dimension)} $sqlOp ($formattedValues)", bindings.toList) - - case Intersection(operands) => - val (sql, bindings) = filterToSqlMultiple(operands).unzip - (sql.mkString("(", " and ", ")"), bindings.flatten.toList) - - case Union(operands) => - val (sql, bindings) = filterToSqlMultiple(operands).unzip - (sql.mkString("(", " or ", ")"), bindings.flatten.toList) - } - } - - protected def limitToSql(): String - - /** - * @param escapedMainTableName Should be escaped - */ - protected def sortingToSql(escapedMainTableName: String, sorting: Sorting, namingStrategy: NamingStrategy): String = { - sorting match { - case Dimension(optSortingTableName, field, order) => - val sortingTableName = optSortingTableName.map(namingStrategy.table).getOrElse(escapedMainTableName) - val fullName = s"$sortingTableName.${namingStrategy.column(field)}" - - s"$fullName ${orderToSql(order)}" - - case Sequential(xs) => - xs.map(sortingToSql(escapedMainTableName, _, namingStrategy)).mkString(", ") - } - } - - protected def orderToSql(x: SortingOrder): String = x match { - case Ascending => "asc" - case Descending => "desc" - } - - protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = { - bindings.zipWithIndex.foreach { - case (binding, index) => - bind.setObject(index + 1, binding) - } - - bind - } - -} - -final case class PostgresQueryBuilderParameters(tableData: QueryBuilder.TableData, - links: Map[String, TableLink] = Map.empty, - filter: SearchFilterExpr = SearchFilterExpr.Empty, - sorting: Sorting = Sorting.Empty, - pagination: Option[Pagination] = None) - extends QueryBuilderParameters { - - def limitToSql(): String = { - pagination.map { pagination => - val startFrom = (pagination.pageNumber - 1) * pagination.pageSize - s"limit ${pagination.pageSize} OFFSET $startFrom" - } getOrElse "" - } - -} - -/** - * @param links Links to another tables grouped by foreignTableName - */ -final case class MysqlQueryBuilderParameters(tableData: QueryBuilder.TableData, - links: Map[String, TableLink] = Map.empty, - filter: SearchFilterExpr = SearchFilterExpr.Empty, - sorting: Sorting = Sorting.Empty, - pagination: Option[Pagination] = None) - extends QueryBuilderParameters { - - def limitToSql(): String = - pagination - .map { pagination => - val startFrom = (pagination.pageNumber - 1) * pagination.pageSize - s"limit $startFrom, ${pagination.pageSize}" - } - .getOrElse("") - -} - -abstract class QueryBuilder[T, D <: SqlIdiom, N <: NamingStrategy](val parameters: QueryBuilderParameters)( - implicit runner: QueryBuilder.Runner[T], - countRunner: QueryBuilder.CountRunner) { - - def run: Seq[T] = runner(parameters) - - def runCount: QueryBuilder.CountResult = countRunner(parameters) - - /** - * Runs the query and returns total found rows without considering of pagination. - */ - def runWithCount: (Seq[T], Int, Option[LocalDateTime]) = { - val (total, lastUpdate) = runCount - (run, total, lastUpdate) - } - - def withFilter(newFilter: SearchFilterExpr): QueryBuilder[T, D, N] - - def withFilter(filter: Option[SearchFilterExpr]): QueryBuilder[T, D, N] = { - filter.fold(this)(withFilter) - } - - def resetFilter: QueryBuilder[T, D, N] = withFilter(SearchFilterExpr.Empty) - - def withSorting(newSorting: Sorting): QueryBuilder[T, D, N] - - def withSorting(sorting: Option[Sorting]): QueryBuilder[T, D, N] = { - sorting.fold(this)(withSorting) - } - - def resetSorting: QueryBuilder[T, D, N] = withSorting(Sorting.Empty) - - def withPagination(newPagination: Pagination): QueryBuilder[T, D, N] - - def withPagination(pagination: Option[Pagination]): QueryBuilder[T, D, N] = { - pagination.fold(this)(withPagination) - } - - def resetPagination: QueryBuilder[T, D, N] - -} |