From bc60e6aef2a22bdf2167f56417477da121cdeed1 Mon Sep 17 00:00:00 2001 From: vlad Date: Sat, 26 Aug 2017 18:41:47 -0700 Subject: Making Query builder to escape table names with schemas properly --- .../pdsuicommon/db/PostgresQueryBuilder.scala | 26 +++++++++++++++------- 1 file changed, 18 insertions(+), 8 deletions(-) (limited to 'src/main/scala/xyz/driver/pdsuicommon/db/PostgresQueryBuilder.scala') diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/PostgresQueryBuilder.scala b/src/main/scala/xyz/driver/pdsuicommon/db/PostgresQueryBuilder.scala index 8ef1829..0ddf811 100644 --- a/src/main/scala/xyz/driver/pdsuicommon/db/PostgresQueryBuilder.scala +++ b/src/main/scala/xyz/driver/pdsuicommon/db/PostgresQueryBuilder.scala @@ -3,6 +3,7 @@ package xyz.driver.pdsuicommon.db import java.sql.ResultSet import io.getquill.{PostgresDialect, PostgresEscape} +import xyz.driver.pdsuicommon.db.PostgresQueryBuilder.SmartPostgresEscape import scala.collection.breakOut @@ -10,8 +11,17 @@ object PostgresQueryBuilder { import xyz.driver.pdsuicommon.db.QueryBuilder._ - type Escape = PostgresEscape - val Escape = PostgresEscape + trait SmartPostgresEscape extends PostgresEscape { + override def column(s: String): String = + if (s.startsWith("$")) s else super.column(s) + override def default(s: String): String = + s.split("\\.").map(ss => s""""$ss"""").mkString(".") + } + + object SmartPostgresEscape extends SmartPostgresEscape + + type Escape = SmartPostgresEscape + val Escape = SmartPostgresEscape def apply[T](tableName: String, lastUpdateFieldName: Option[String], @@ -42,14 +52,14 @@ object PostgresQueryBuilder { extractor: ResultSet => T)(implicit sqlContext: PostgresContext): PostgresQueryBuilder[T] = { val runner: Runner[T] = { parameters => - val (sql, binder) = parameters.toSql(countQuery = false, fields = fields, namingStrategy = PostgresEscape) + val (sql, binder) = parameters.toSql(countQuery = false, fields = fields, namingStrategy = SmartPostgresEscape) sqlContext.executeQuery[T](sql, binder, { resultSet => extractor(resultSet) }) } val countRunner: CountRunner = { parameters => - val (sql, binder) = parameters.toSql(countQuery = true, namingStrategy = PostgresEscape) + val (sql, binder) = parameters.toSql(countQuery = true, namingStrategy = SmartPostgresEscape) sqlContext .executeQuery[CountResult]( sql, @@ -80,19 +90,19 @@ class PostgresQueryBuilder[T](parameters: PostgresQueryBuilderParameters)(implic countRunner: QueryBuilder.CountRunner) extends QueryBuilder[T, PostgresDialect, PostgresQueryBuilder.Escape](parameters) { - def withFilter(newFilter: SearchFilterExpr): QueryBuilder[T, PostgresDialect, PostgresEscape] = { + def withFilter(newFilter: SearchFilterExpr): QueryBuilder[T, PostgresDialect, SmartPostgresEscape] = { new PostgresQueryBuilder[T](parameters.copy(filter = newFilter)) } - def withSorting(newSorting: Sorting): QueryBuilder[T, PostgresDialect, PostgresEscape] = { + def withSorting(newSorting: Sorting): QueryBuilder[T, PostgresDialect, SmartPostgresEscape] = { new PostgresQueryBuilder[T](parameters.copy(sorting = newSorting)) } - def withPagination(newPagination: Pagination): QueryBuilder[T, PostgresDialect, PostgresEscape] = { + def withPagination(newPagination: Pagination): QueryBuilder[T, PostgresDialect, SmartPostgresEscape] = { new PostgresQueryBuilder[T](parameters.copy(pagination = Some(newPagination))) } - def resetPagination: QueryBuilder[T, PostgresDialect, PostgresEscape] = { + def resetPagination: QueryBuilder[T, PostgresDialect, SmartPostgresEscape] = { new PostgresQueryBuilder[T](parameters.copy(pagination = None)) } } -- cgit v1.2.3