package xyz.driver.restquery.db

import java.sql.PreparedStatement

import slick.jdbc.{JdbcProfile, SQLActionBuilder}
import xyz.driver.restquery.query.Sorting.{Dimension, Sequential}
import xyz.driver.restquery.query.SortingOrder.{Ascending, Descending}
import xyz.driver.restquery.query._

object SlickQueryBuilderParameters {

  /** Field selection meaning "every column of the main table" (rendered as `table.*`). */
  val AllFields: Set[String] = Set("*")
}

/**
  * Parameters of a single SELECT query (main table, joinable tables, filter, sorting) together
  * with the logic that renders them into a Slick [[SQLActionBuilder]].
  *
  * NOTE(review): database, table and column names are spliced into the SQL text verbatim (`#$`),
  * only wrapped in `qs`; filter values are bound as statement parameters. Identifiers are
  * therefore expected to be validated by the caller.
  */
trait SlickQueryBuilderParameters {
  import SlickQueryBuilder._

  def databaseName: String
  def tableData: SlickQueryBuilder.TableData
  def links: Map[String, SlickTableLink]
  def filter: SearchFilterExpr
  def sorting: Sorting
  def pagination: Option[Pagination]

  /** Quote string wrapped around database, table and column identifiers. */
  def qs: String

  /**
    * Returns the join description registered for `tableName`.
    *
    * @throws IllegalArgumentException if `links` has no entry for `tableName`
    */
  def findLink(tableName: String): SlickTableLink =
    links.getOrElse(tableName, throw new IllegalArgumentException(s"Cannot find a link for `$tableName`"))

  /** Builds the query selecting every field of the main table, or the count query. */
  def toSql(countQuery: Boolean = false)(implicit profile: JdbcProfile): SQLActionBuilder =
    toSql(countQuery, SlickQueryBuilderParameters.AllFields)

  /**
    * Builds `select ... from ... [inner join ...] [where ...] [order by ...] [limit]`.
    *
    * @param countQuery when true, selects `count(*)` (plus `max(lastUpdateField)` when the table
    *                   defines one) and omits the `order by` and limit clauses
    * @param fields     columns of the main table to select;
    *                   [[SlickQueryBuilderParameters.AllFields]] selects `table.*`.
    *                   Ignored for count queries.
    * @throws IllegalArgumentException if `fields` is empty for a non-count query, or if the filter
    *                                  or sorting references a table missing from `links`
    */
  def toSql(countQuery: Boolean, fields: Set[String])(implicit profile: JdbcProfile): SQLActionBuilder = {
    import profile.api._

    // An empty projection would render `select  from ...`, which no database accepts.
    require(countQuery || fields.nonEmpty, "At least one field must be selected")

    val escapedTableName = escapeTable(tableData.tableName)

    val fieldsSql: String =
      if (countQuery) {
        val suffix = tableData.lastUpdateFieldName
          .map(lastUpdateField => s", max($escapedTableName.$qs$lastUpdateField$qs)")
          .getOrElse("")
        s"count(*) $suffix"
      } else if (fields == SlickQueryBuilderParameters.AllFields) {
        s"$escapedTableName.*"
      } else {
        fields.map(field => s"$escapedTableName.$qs$field$qs").mkString(", ")
      }

    val where = filterToSql(escapedTableName, filter)
    val orderBy = sortingToSql(escapedTableName, sorting)
    val limitSql = limitToSql()

    // Foreign tables referenced anywhere inside the filter expression.
    def filterTableLinks(expr: SearchFilterExpr): Seq[SlickTableLink] = {
      import SearchFilterExpr._
      expr match {
        case Atom.TableName(tableName) => List(findLink(tableName))
        case Intersection(xs)          => xs.flatMap(filterTableLinks)
        case Union(xs)                 => xs.flatMap(filterTableLinks)
        case _                         => Nil
      }
    }

    // Every table referenced by the filter or by the sorting is joined exactly once.
    val foreignTableLinks: Seq[SlickTableLink] = {
      val fromFilter = filterTableLinks(filter)
      val fromSorting: Seq[SlickTableLink] = Sorting.collect(sorting) {
        case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
      }
      (fromFilter ++ fromSorting).distinct
    }

    val joinsSql = concatSqlParts(foreignTableLinks.map {
      case SlickTableLink(keyColumnName, foreignTableName, foreignKeyColumnName) =>
        val escapedForeignTableName = escapeTable(foreignTableName)
        val joinCondition =
          s"$escapedTableName.$qs$keyColumnName$qs=$escapedForeignTableName.$qs$foreignKeyColumnName$qs"
        sql" inner join #$escapedForeignTableName on #$joinCondition"
    })

    // Clause presence is decided by the rendered text, not by the number of query parts:
    // splice-free conditions such as `1=0` (DenyAll) consist of a single part and must still
    // produce a where clause.
    val whereSql = if (isBlankSql(where)) sql"" else sql" where " concat where
    val orderSql = if (orderBy.nonEmpty && !countQuery) sql" order by #$orderBy" else sql""
    val limSql = if (!countQuery && !isBlankSql(limitSql)) sql" " concat limitSql else sql""

    sql"select #$fieldsSql from #$escapedTableName" concat joinsSql concat whereSql concat orderSql concat limSql
  }

  /**
    * Converts a filter expression to a SQL condition.
    *
    * Filter values are bound as statement parameters; column and table names are spliced as
    * literal text.
    *
    * @param escapedTableName already-escaped main table, used for dimensions without a table name
    * @return the condition (its SQL text plus the parameter bindings), or an empty builder when
    *         the expression renders no condition at all
    */
  protected def filterToSql(escapedTableName: String, filter: SearchFilterExpr)(
      implicit profile: JdbcProfile): SQLActionBuilder = {
    import SearchFilterBinaryOperation._
    import SearchFilterExpr._
    import profile.api._

    // A missing value or the string "null" (in any letter case) stands for SQL NULL.
    def isNull(value: AnyRef): Boolean = Option(value).forall(_.toString.equalsIgnoreCase("null"))

    def escapeDimension(dimension: SearchFilterExpr.Dimension): String = {
      val table = dimension.tableName.map(name => escapeTable(name)).getOrElse(escapedTableName)
      s"$table.$qs${dimension.name}$qs"
    }

    // Binds one value as a statement parameter (`?`).
    def bindValue(value: AnyRef): SQLActionBuilder = sql"$value"

    // Parenthesised `separator`-joined conditions of the non-empty operands; an empty builder when
    // nothing is left, instead of the invalid `()`.
    def groupSql(operands: Seq[SearchFilterExpr], separator: String): SQLActionBuilder = {
      val conditions = operands
        .filterNot(operand => SearchFilterExpr.isEmpty(operand))
        .map(operand => filterToSql(escapedTableName, operand))
        .filterNot(condition => isBlankSql(condition))
      if (conditions.isEmpty) sql"" else sql"(" concat joinSql(conditions, separator) concat sql")"
    }

    filter match {
      case x if SearchFilterExpr.isEmpty(x) => sql""
      case AllowAll                          => sql"1=1"
      case DenyAll                           => sql"1=0"

      case Atom.Binary(dimension, Eq, value) if isNull(value) =>
        sql"#${escapeDimension(dimension)} is NULL"

      case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
        sql"#${escapeDimension(dimension)} is not NULL"

      case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
        // In MySQL `NULL <> x` is NULL, not true, so rows holding NULL would be dropped.
        // Nullable fields therefore need an explicit `is null` alternative.
        // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
        val escapedColumn = escapeDimension(dimension)
        sql"(#${escapedColumn} is null or #${escapedColumn} != $value)"

      case Atom.Binary(dimension, op, value) =>
        val operator = op match {
          case Eq    => sql"="
          case NotEq => sql"!="
          case Like  => sql" like "
          case Gt    => sql">"
          case GtEq  => sql">="
          case Lt    => sql"<"
          case LtEq  => sql"<="
        }
        sql"#${escapeDimension(dimension)}" concat operator concat sql"""$value"""

      case Atom.NAry(_, op, values) if values.isEmpty =>
        // `()` is not valid SQL: "in nothing" matches no row, "not in nothing" matches every row.
        op match {
          case SearchFilterNAryOperation.In    => sql"1=0"
          case SearchFilterNAryOperation.NotIn => sql"1=1"
        }

      case Atom.NAry(dimension, op, values) =>
        val sqlOp = op match {
          case SearchFilterNAryOperation.In    => sql" in "
          case SearchFilterNAryOperation.NotIn => sql" not in "
        }
        val valueList = sql"(" concat joinSql(values.map(bindValue), ",") concat sql")"
        sql"#${escapeDimension(dimension)}" concat sqlOp concat valueList

      case Intersection(operands) => groupSql(operands, " and ")
      case Union(operands)        => groupSql(operands, " or ")
    }
  }

  /**
    * Dialect-specific limit clause, appended to non-count queries.
    * A builder whose rendered text is blank means "no limit clause".
    */
  protected def limitToSql()(implicit profile: JdbcProfile): SQLActionBuilder

  /**
    * Renders `sorting` as the body of an `order by` clause; the result is empty when there is
    * nothing to sort by.
    *
    * @param escapedMainTableName Should be escaped; used for dimensions without a table name
    */
  protected def sortingToSql(escapedMainTableName: String, sorting: Sorting)(implicit profile: JdbcProfile): String =
    sorting match {
      case Dimension(optSortingTableName, field, order) =>
        val sortingTableName = optSortingTableName.map(name => escapeTable(name)).getOrElse(escapedMainTableName)
        s"$sortingTableName.$qs$field$qs ${orderToSql(order)}"
      case Sequential(xs) =>
        // Skip empty nested sequences so they do not leave dangling ", " separators.
        xs.map(sortingToSql(escapedMainTableName, _)).filter(_.nonEmpty).mkString(", ")
    }

  protected def orderToSql(x: SortingOrder): String = x match {
    case Ascending  => "asc"
    case Descending => "desc"
  }

  /** Binds `bindings` to `bind` in order, at 1-based parameter positions, and returns `bind`. */
  protected def binder(bindings: List[AnyRef])(bind: PreparedStatement): PreparedStatement = {
    bindings.zipWithIndex.foreach {
      case (binding, index) => bind.setObject(index + 1, binding)
    }
    bind
  }

  /** Quotes `tableName` and qualifies it with the quoted `databaseName`. */
  private def escapeTable(tableName: String): String = s"$qs$databaseName$qs.$qs$tableName$qs"

  /** True when the SQL text of `builder` (its query parts concatenated) is empty or whitespace only. */
  private def isBlankSql(builder: SQLActionBuilder): Boolean =
    builder.queryParts.forall(part => s"$part".trim.isEmpty)

  /**
    * Concatenates `parts` in order, combining them as a balanced tree.
    * Each `concat` nests the parameter setters of both sides. A left-leaning chain over a long
    * `in (...)` list would therefore be n levels deep; the balanced form is O(log n).
    */
  private def concatSqlParts(parts: Seq[SQLActionBuilder])(implicit profile: JdbcProfile): SQLActionBuilder = {
    import profile.api._

    def loop(xs: IndexedSeq[SQLActionBuilder]): SQLActionBuilder = xs.length match {
      case 0 => sql""
      case 1 => xs.head
      case n =>
        val (left, right) = xs.splitAt(n / 2)
        loop(left) concat loop(right)
    }

    loop(parts.toIndexedSeq)
  }

  /** Concatenates `parts` with the literal `separator` text between consecutive elements. */
  private def joinSql(parts: Seq[SQLActionBuilder], separator: String)(
      implicit profile: JdbcProfile): SQLActionBuilder = {
    import profile.api._

    val separated = parts.zipWithIndex.flatMap {
      case (part, 0) => List(part)
      case (part, _) => List(sql"#$separator", part)
    }
    concatSqlParts(separated)
  }
}