aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala')
-rw-r--r--src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala240
1 files changed, 240 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala
new file mode 100644
index 0000000..6ab7eb4
--- /dev/null
+++ b/src/main/scala/xyz/driver/restquery/db/SlickQueryBuilderParameters.scala
@@ -0,0 +1,240 @@
+package xyz.driver.restquery.query
+
+import java.sql.PreparedStatement
+
+import slick.jdbc.{JdbcProfile, SQLActionBuilder}
+import xyz.driver.restquery.db.{SlickQueryBuilder, SlickQueryBuilderParameters, SlickTableLink}
+import xyz.driver.restquery.query.Sorting.{Dimension, Sequential}
+import xyz.driver.restquery.query.SortingOrder.{Ascending, Descending}
+
+trait SlickQueryBuilderParameters {
+ import SlickQueryBuilder._
+
+ def databaseName: String
+ def tableData: SlickQueryBuilder.TableData
+ def links: Map[String, SlickTableLink]
+ def filter: SearchFilterExpr
+ def sorting: Sorting
+ def pagination: Option[Pagination]
+
+ def qs: String
+
+ def findLink(tableName: String): SlickTableLink = links.get(tableName) match {
+ case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`")
+ case Some(link) => link
+ }
+
+ def toSql(countQuery: Boolean = false)(implicit profile: JdbcProfile): SQLActionBuilder = {
+ toSql(countQuery, SlickQueryBuilderParameters.AllFields)
+ }
+
+ def toSql(countQuery: Boolean, fields: Set[String])(implicit profile: JdbcProfile): SQLActionBuilder = {
+ import profile.api._
+ val escapedTableName = s"""$qs$databaseName$qs.$qs${tableData.tableName}$qs"""
+ val fieldsSql: String = if (countQuery) {
+ val suffix: String = tableData.lastUpdateFieldName match {
+ case Some(lastUpdateField) => s", max($escapedTableName.$qs$lastUpdateField$qs)"
+ case None => ""
+ }
+ s"count(*) $suffix"
+ } else {
+ if (fields == SlickQueryBuilderParameters.AllFields) {
+ s"$escapedTableName.*"
+ } else {
+ fields
+ .map { field =>
+ s"$escapedTableName.$qs$field$qs"
+ }
+ .mkString(", ")
+ }
+ }
+ val where = filterToSql(escapedTableName, filter)
+ val orderBy = sortingToSql(escapedTableName, sorting)
+
+ val limitSql = limitToSql()
+
+ val sql = sql"""select #$fieldsSql from #$escapedTableName"""
+
+ val filtersTableLinks: Seq[SlickTableLink] = {
+ import SearchFilterExpr._
+ def aux(expr: SearchFilterExpr): Seq[SlickTableLink] = expr match {
+ case Atom.TableName(tableName) => List(findLink(tableName))
+ case Intersection(xs) => xs.flatMap(aux)
+ case Union(xs) => xs.flatMap(aux)
+ case _ => Nil
+ }
+ aux(filter)
+ }
+
+ val sortingTableLinks: Seq[SlickTableLink] = Sorting.collect(sorting) {
+ case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
+ }
+
+ // Combine links from sorting and filter without duplicates
+ val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct
+
+ def fkSql(fkLinksSql: SQLActionBuilder, tableLinks: Seq[SlickTableLink]): SQLActionBuilder = {
+ if (tableLinks.nonEmpty) {
+ tableLinks.head match {
+ case SlickTableLink(keyColumnName, foreignTableName, foreignKeyColumnName) =>
+ val escapedForeignTableName = s"$qs$databaseName$qs.$qs$foreignTableName$qs"
+ val join = sql""" inner join #$escapedForeignTableName
+ on #$escapedTableName.#$qs#$keyColumnName#$qs=#$escapedForeignTableName.#$qs#$foreignKeyColumnName#$qs"""
+ fkSql(fkLinksSql concat join, tableLinks.tail)
+ }
+ } else fkLinksSql
+ }
+ val foreignTableLinksSql = fkSql(sql"", foreignTableLinks)
+
+ val whereSql = if (where.queryParts.size > 1) {
+ sql" where " concat where
+ } else sql""
+
+ val orderSql = if (orderBy.nonEmpty && !countQuery) {
+ sql" order by #$orderBy"
+ } else sql""
+
+ val limSql = if (limitSql.queryParts.size > 1 && !countQuery) {
+ sql" " concat limitSql
+ } else sql""
+
+ sql concat foreignTableLinksSql concat whereSql concat orderSql concat limSql
+ }
+
+ /**
+ * Converts filter expression to SQL expression.
+ *
+ * @return Returns SQL string and list of values for binding in prepared statement.
+ */
+ protected def filterToSql(escapedTableName: String, filter: SearchFilterExpr)(
+ implicit profile: JdbcProfile): SQLActionBuilder = {
+ import SearchFilterBinaryOperation._
+ import SearchFilterExpr._
+ import profile.api._
+
+ def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null"
+
+ def escapeDimension(dimension: SearchFilterExpr.Dimension) = {
+ s"${dimension.tableName.map(t => s"$qs$databaseName$qs.$qs$t$qs").getOrElse(escapedTableName)}.$qs${dimension.name}$qs"
+ }
+
+ def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect {
+ case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x)
+ }
+
+ def multipleSqlToAction(first: Boolean,
+ op: String,
+ conditions: Seq[SQLActionBuilder],
+ sql: SQLActionBuilder): SQLActionBuilder = {
+ if (conditions.nonEmpty) {
+ val condition = conditions.head
+ if (first) {
+ multipleSqlToAction(first = false, op, conditions.tail, condition)
+ } else {
+ multipleSqlToAction(first = false, op, conditions.tail, sql concat sql" #${op} " concat condition)
+ }
+ } else sql
+ }
+
+ def concatenateParameters(sql: SQLActionBuilder, first: Boolean, tail: Seq[AnyRef]): SQLActionBuilder = {
+ if (tail.nonEmpty) {
+ if (!first) {
+ concatenateParameters(sql concat sql""",${tail.head}""", first = false, tail.tail)
+ } else {
+ concatenateParameters(sql"""(${tail.head}""", first = false, tail.tail)
+ }
+ } else sql concat sql")"
+ }
+
+ filter match {
+ case x if isEmpty(x) =>
+ sql""
+
+ case AllowAll =>
+ sql"1=1"
+
+ case DenyAll =>
+ sql"1=0"
+
+ case Atom.Binary(dimension, Eq, value) if isNull(value) =>
+ sql"#${escapeDimension(dimension)} is NULL"
+
+ case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
+ sql"#${escapeDimension(dimension)} is not NULL"
+
+ case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
+ // In MySQL NULL <> Any === NULL
+ // So, to handle NotEq for nullable fields we need to use more complex SQL expression.
+ // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
+ val escapedColumn = escapeDimension(dimension)
+ sql"(#${escapedColumn} is null or #${escapedColumn} != $value)"
+
+ case Atom.Binary(dimension, op, value) =>
+ val operator = op match {
+ case Eq => sql"="
+ case NotEq => sql"!="
+ case Like => sql" like "
+ case Gt => sql">"
+ case GtEq => sql">="
+ case Lt => sql"<"
+ case LtEq => sql"<="
+ }
+ sql"#${escapeDimension(dimension)}" concat operator concat sql"""$value"""
+
+ case Atom.NAry(dimension, op, values) =>
+ val sqlOp = op match {
+ case SearchFilterNAryOperation.In => sql" in "
+ case SearchFilterNAryOperation.NotIn => sql" not in "
+ }
+
+ if (values.nonEmpty) {
+ val formattedValues = concatenateParameters(sql"", first = true, values)
+ sql"#${escapeDimension(dimension)}" concat sqlOp concat formattedValues
+ } else {
+ sql"1=0"
+ }
+
+ case Intersection(operands) =>
+ val filter = multipleSqlToAction(first = true, "and", filterToSqlMultiple(operands), sql"")
+ sql"(" concat filter concat sql")"
+
+ case Union(operands) =>
+ val filter = multipleSqlToAction(first = true, "or", filterToSqlMultiple(operands), sql"")
+ sql"(" concat filter concat sql")"
+ }
+ }
+
+ protected def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder
+
+ /**
+ * @param escapedMainTableName Should be escaped
+ */
+ protected def sortingToSql(escapedMainTableName: String, sorting: Sorting)(implicit profile: JdbcProfile): String = {
+ sorting match {
+ case Dimension(optSortingTableName, field, order) =>
+ val sortingTableName =
+ optSortingTableName.map(x => s"$qs$databaseName$qs.$qs$x$qs").getOrElse(escapedMainTableName)
+ val fullName = s"$sortingTableName.$qs$field$qs"
+
+ s"$fullName ${orderToSql(order)}"
+
+ case Sequential(xs) =>
+ xs.map(sortingToSql(escapedMainTableName, _)).mkString(", ")
+ }
+ }
+
+ protected def orderToSql(x: SortingOrder): String = x match {
+ case Ascending => "asc"
+ case Descending => "desc"
+ }
+
+ protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = {
+ bindings.zipWithIndex.foreach {
+ case (binding, index) =>
+ bind.setObject(index + 1, binding)
+ }
+
+ bind
+ }
+
+}