aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/common/db/QueryBuilder.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/common/db/QueryBuilder.scala')
-rw-r--r--src/main/scala/xyz/driver/common/db/QueryBuilder.scala344
1 files changed, 344 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/common/db/QueryBuilder.scala b/src/main/scala/xyz/driver/common/db/QueryBuilder.scala
new file mode 100644
index 0000000..f0beca6
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/QueryBuilder.scala
@@ -0,0 +1,344 @@
+package xyz.driver.common.db
+
+import java.sql.PreparedStatement
+import java.time.LocalDateTime
+
+import io.getquill.NamingStrategy
+import io.getquill.context.sql.idiom.SqlIdiom
+import xyz.driver.common.db.Sorting.{Dimension, Sequential}
+import xyz.driver.common.db.SortingOrder.{Ascending, Descending}
+
+import scala.collection.mutable.ListBuffer
+import scala.concurrent.{ExecutionContext, Future}
+
+object QueryBuilder {
+
+ type Runner[T] = (QueryBuilderParameters) => Future[Seq[T]]
+
+ type CountResult = (Int, Option[LocalDateTime])
+
+ type CountRunner = (QueryBuilderParameters) => Future[CountResult]
+
+ /**
+ * Binder for PreparedStatement
+ */
+ type Binder = PreparedStatement => PreparedStatement
+
+ case class TableData(tableName: String,
+ lastUpdateFieldName: Option[String] = None,
+ nullableFields: Set[String] = Set.empty)
+
+ val AllFields = Set("*")
+
+}
+
+case class TableLink(keyColumnName: String,
+ foreignTableName: String,
+ foreignKeyColumnName: String)
+
+object QueryBuilderParameters {
+ val AllFields = Set("*")
+}
+
+sealed trait QueryBuilderParameters {
+
+ def tableData: QueryBuilder.TableData
+ def links: Map[String, TableLink]
+ def filter: SearchFilterExpr
+ def sorting: Sorting
+ def pagination: Option[Pagination]
+
+ def findLink(tableName: String): TableLink = links.get(tableName) match {
+ case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`")
+ case Some(link) => link
+ }
+
+ def toSql(countQuery: Boolean = false, namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = {
+ toSql(countQuery, QueryBuilderParameters.AllFields, namingStrategy)
+ }
+
+ def toSql(countQuery: Boolean,
+ fields: Set[String],
+ namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = {
+ val escapedTableName = namingStrategy.table(tableData.tableName)
+ val fieldsSql: String = if (countQuery) {
+ "count(*)" + (tableData.lastUpdateFieldName match {
+ case Some(lastUpdateField) => s", max($escapedTableName.${namingStrategy.column(lastUpdateField)})"
+ case None => ""
+ })
+ } else {
+ if (fields == QueryBuilderParameters.AllFields) {
+ s"$escapedTableName.*"
+ } else {
+ fields
+ .map { field =>
+ s"$escapedTableName.${namingStrategy.column(field)}"
+ }
+ .mkString(", ")
+ }
+ }
+ val (where, bindings) = filterToSql(escapedTableName, filter, namingStrategy)
+ val orderBy = sortingToSql(escapedTableName, sorting, namingStrategy)
+
+ val limitSql = limitToSql()
+
+ val sql = new StringBuilder()
+ sql.append("select ")
+ sql.append(fieldsSql)
+ sql.append("\nfrom ")
+ sql.append(escapedTableName)
+
+ val filtersTableLinks: Seq[TableLink] = {
+ import SearchFilterExpr._
+ def aux(expr: SearchFilterExpr): Seq[TableLink] = expr match {
+ case Atom.TableName(tableName) => List(findLink(tableName))
+ case Intersection(xs) => xs.flatMap(aux)
+ case Union(xs) => xs.flatMap(aux)
+ case _ => Nil
+ }
+ aux(filter)
+ }
+
+ val sortingTableLinks: Seq[TableLink] = Sorting.collect(sorting) {
+ case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
+ }
+
+ // Combine links from sorting and filter without duplicates
+ val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct
+
+ foreignTableLinks.foreach {
+ case TableLink(keyColumnName, foreignTableName, foreignKeyColumnName) =>
+ val escapedForeignTableName = namingStrategy.table(foreignTableName)
+
+ sql.append("\ninner join ")
+ sql.append(escapedForeignTableName)
+ sql.append(" on ")
+
+ sql.append(escapedTableName)
+ sql.append('.')
+ sql.append(namingStrategy.column(keyColumnName))
+
+ sql.append(" = ")
+
+ sql.append(escapedForeignTableName)
+ sql.append('.')
+ sql.append(namingStrategy.column(foreignKeyColumnName))
+ }
+
+ if (where.nonEmpty) {
+ sql.append("\nwhere ")
+ sql.append(where)
+ }
+
+ if (orderBy.nonEmpty && !countQuery) {
+ sql.append("\norder by ")
+ sql.append(orderBy)
+ }
+
+ if (limitSql.nonEmpty && !countQuery) {
+ sql.append("\n")
+ sql.append(limitSql)
+ }
+
+ (sql.toString, binder(bindings))
+ }
+
+ /**
+ * Converts filter expression to SQL expression.
+ *
+ * @return Returns SQL string and list of values for binding in prepared statement.
+ */
+ protected def filterToSql(escapedTableName: String,
+ filter: SearchFilterExpr,
+ namingStrategy: NamingStrategy): (String, List[AnyRef]) = {
+ import SearchFilterBinaryOperation._
+ import SearchFilterExpr._
+
+ def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null"
+
+ def placeholder(field: String) = "?"
+
+ def escapeDimension(dimension: SearchFilterExpr.Dimension) = {
+ val tableName = dimension.tableName.fold(escapedTableName)(namingStrategy.table)
+ s"$tableName.${namingStrategy.column(dimension.name)}"
+ }
+
+ def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect {
+ case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x, namingStrategy)
+ }
+
+ filter match {
+ case x if isEmpty(x) =>
+ ("", List.empty)
+
+ case AllowAll =>
+ ("1", List.empty)
+
+ case DenyAll =>
+ ("0", List.empty)
+
+ case Atom.Binary(dimension, Eq, value) if isNull(value) =>
+ (s"${escapeDimension(dimension)} is NULL", List.empty)
+
+ case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
+ (s"${escapeDimension(dimension)} is not NULL", List.empty)
+
+ case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
+ // In MySQL NULL <> Any === NULL
+ // So, to handle NotEq for nullable fields we need to use more complex SQL expression.
+ // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
+ val escapedColumn = escapeDimension(dimension)
+ val sql = s"($escapedColumn is null or $escapedColumn != ${placeholder(dimension.name)})"
+ (sql, List(value))
+
+ case Atom.Binary(dimension, op, value) =>
+ val operator = op match {
+ case Eq => "="
+ case NotEq => "!="
+ case Like => "like"
+ case Gt => ">"
+ case GtEq => ">="
+ case Lt => "<"
+ case LtEq => "<="
+ }
+ (s"${escapeDimension(dimension)} $operator ${placeholder(dimension.name)}", List(value))
+
+ case Atom.NAry(dimension, op, values) =>
+ val sqlOp = op match {
+ case SearchFilterNAryOperation.In => "in"
+ case SearchFilterNAryOperation.NotIn => "not in"
+ }
+
+ val bindings = ListBuffer[AnyRef]()
+ val sqlPlaceholder = placeholder(dimension.name)
+ val formattedValues = values.map { value =>
+ bindings += value
+ sqlPlaceholder
+ }.mkString(", ")
+ (s"${escapeDimension(dimension)} $sqlOp ($formattedValues)", bindings.toList)
+
+ case Intersection(operands) =>
+ val (sql, bindings) = filterToSqlMultiple(operands).unzip
+ (sql.mkString("(", " and ", ")"), bindings.flatten.toList)
+
+ case Union(operands) =>
+ val (sql, bindings) = filterToSqlMultiple(operands).unzip
+ (sql.mkString("(", " or ", ")"), bindings.flatten.toList)
+ }
+ }
+
+ protected def limitToSql(): String
+
+ /**
+ * @param escapedMainTableName Should be escaped
+ */
+ protected def sortingToSql(escapedMainTableName: String,
+ sorting: Sorting,
+ namingStrategy: NamingStrategy): String = {
+ sorting match {
+ case Dimension(optSortingTableName, field, order) =>
+ val sortingTableName = optSortingTableName.map(namingStrategy.table).getOrElse(escapedMainTableName)
+ val fullName = s"$sortingTableName.${namingStrategy.column(field)}"
+
+ s"$fullName ${orderToSql(order)}"
+
+ case Sequential(xs) =>
+ xs.map(sortingToSql(escapedMainTableName, _, namingStrategy)).mkString(", ")
+ }
+ }
+
+ protected def orderToSql(x: SortingOrder): String = x match {
+ case Ascending => "asc"
+ case Descending => "desc"
+ }
+
+ protected def binder(bindings: List[AnyRef])
+ (bind: PreparedStatement): PreparedStatement = {
+ bindings.zipWithIndex.foreach { case (binding, index) =>
+ bind.setObject(index + 1, binding)
+ }
+
+ bind
+ }
+
+}
+
+case class PostgresQueryBuilderParameters(tableData: QueryBuilder.TableData,
+ links: Map[String, TableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None) extends QueryBuilderParameters {
+
+ def limitToSql(): String = {
+ pagination.map { pagination =>
+ val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
+ s"limit ${pagination.pageSize} OFFSET $startFrom"
+ } getOrElse ""
+ }
+
+}
+
+/**
+ * @param links Links to another tables grouped by foreignTableName
+ */
+case class MysqlQueryBuilderParameters(tableData: QueryBuilder.TableData,
+ links: Map[String, TableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None) extends QueryBuilderParameters {
+
+ def limitToSql(): String = pagination.map { pagination =>
+ val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
+ s"limit $startFrom, ${pagination.pageSize}"
+ }.getOrElse("")
+
+}
+
+abstract class QueryBuilder[T, D <: SqlIdiom, N <: NamingStrategy](val parameters: QueryBuilderParameters)
+ (implicit runner: QueryBuilder.Runner[T],
+ countRunner: QueryBuilder.CountRunner,
+ ec: ExecutionContext) {
+
+ def run: Future[Seq[T]] = runner(parameters)
+
+ def runCount: Future[QueryBuilder.CountResult] = countRunner(parameters)
+
+ /**
+ * Runs the query and returns total found rows without considering of pagination.
+ */
+ def runWithCount: Future[(Seq[T], Int, Option[LocalDateTime])] = {
+ val countFuture = runCount
+ val selectAllFuture = run
+ for {
+ (total, lastUpdate) <- countFuture
+ all <- selectAllFuture
+ } yield (all, total, lastUpdate)
+ }
+
+ def withFilter(newFilter: SearchFilterExpr): QueryBuilder[T, D, N]
+
+ def withFilter(filter: Option[SearchFilterExpr]): QueryBuilder[T, D, N] = {
+ filter.fold(this)(withFilter)
+ }
+
+ def resetFilter: QueryBuilder[T, D, N] = withFilter(SearchFilterExpr.Empty)
+
+
+ def withSorting(newSorting: Sorting): QueryBuilder[T, D, N]
+
+ def withSorting(sorting: Option[Sorting]): QueryBuilder[T, D, N] = {
+ sorting.fold(this)(withSorting)
+ }
+
+ def resetSorting: QueryBuilder[T, D, N] = withSorting(Sorting.Empty)
+
+
+ def withPagination(newPagination: Pagination): QueryBuilder[T, D, N]
+
+ def withPagination(pagination: Option[Pagination]): QueryBuilder[T, D, N] = {
+ pagination.fold(this)(withPagination)
+ }
+
+ def resetPagination: QueryBuilder[T, D, N]
+
+}