From 6ea742073940d2d5809fb637d3d3a785341fc4f9 Mon Sep 17 00:00:00 2001 From: vlad Date: Thu, 17 Aug 2017 14:15:20 -0700 Subject: Moving PostgresContext to the common code (cherry picked from commit 8acbc7e) --- .../driver/pdsuicommon/db/PostgresContext.scala | 92 ++++++++++++++++++++ .../pdsuicommon/db/PostgresQueryBuilder.scala | 98 ++++++++++++++++++++++ 2 files changed, 190 insertions(+) create mode 100644 src/main/scala/xyz/driver/pdsuicommon/db/PostgresContext.scala create mode 100644 src/main/scala/xyz/driver/pdsuicommon/db/PostgresQueryBuilder.scala diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/PostgresContext.scala b/src/main/scala/xyz/driver/pdsuicommon/db/PostgresContext.scala new file mode 100644 index 0000000..1b7e2fb --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/PostgresContext.scala @@ -0,0 +1,92 @@ +package xyz.driver.pdsuicommon.db + +import java.io.Closeable +import java.sql.Types +import java.time._ +import java.util.UUID +import java.util.concurrent.Executors +import javax.sql.DataSource + +import io.getquill._ +import xyz.driver.pdsuicommon.concurrent.MdcExecutionContext +import xyz.driver.pdsuicommon.db.PostgresContext.Settings +import xyz.driver.pdsuicommon.domain.UuidId +import xyz.driver.pdsuicommon.logging._ + +import scala.concurrent.ExecutionContext +import scala.util.control.NonFatal +import scala.util.{Failure, Success, Try} + +object PostgresContext extends PhiLogging { + + final case class Settings(connection: com.typesafe.config.Config, + connectionAttemptsOnStartup: Int, + threadPoolSize: Int) + + def apply(settings: Settings): PostgresContext = { + // Prevent leaking credentials to a log + Try(JdbcContextConfig(settings.connection).dataSource) match { + case Success(dataSource) => new PostgresContext(dataSource, settings) + case Failure(NonFatal(e)) => + logger.error(phi"Can not load dataSource, error: ${Unsafe(e.getClass.getName)}") + throw new IllegalArgumentException("Can not load dataSource from config. Check your database and config") + } + } + +} + +class PostgresContext(val dataSource: DataSource with Closeable, settings: Settings) + extends PostgresJdbcContext[SnakeCase](dataSource) with TransactionalContext + with EntityExtractorDerivation[SnakeCase] { + + private val tpe = Executors.newFixedThreadPool(settings.threadPoolSize) + + implicit val executionContext: ExecutionContext = { + val orig = ExecutionContext.fromExecutor(tpe) + MdcExecutionContext.from(orig) + } + + override def close(): Unit = { + super.close() + tpe.shutdownNow() + } + + /** + * Usable for QueryBuilder's extractors + */ + def timestampToLocalDateTime(timestamp: java.sql.Timestamp): LocalDateTime = { + LocalDateTime.ofInstant(timestamp.toInstant, ZoneOffset.UTC) + } + + // Override localDateTime encoder and decoder cause + // clinicaltrials.gov uses bigint to store timestamps + + override implicit val localDateTimeEncoder: Encoder[LocalDateTime] = + encoder(Types.BIGINT, + (index, value, row) => row.setLong(index, value.atZone(ZoneOffset.UTC).toInstant.toEpochMilli)) + + override implicit val localDateTimeDecoder: Decoder[LocalDateTime] = + decoder( + Types.BIGINT, + (index, row) => { + row.getLong(index) match { + case 0 => throw new NullPointerException("0 is decoded as null") + case x => LocalDateTime.ofInstant(Instant.ofEpochMilli(x), ZoneId.of("Z")) + } + } + ) + + implicit def encodeUuidId[T] = MappedEncoding[UuidId[T], String](_.toString) + implicit def decodeUuidId[T] = MappedEncoding[String, UuidId[T]] { uuid => + UuidId[T](UUID.fromString(uuid)) + } + + def decodeOptUuidId[T] = MappedEncoding[Option[String], Option[UuidId[T]]] { + case Some(x) => Option(x).map(y => UuidId[T](UUID.fromString(y))) + case None => None + } + + implicit def decodeUuid[T] = MappedEncoding[String, UUID] { uuid => + UUID.fromString(uuid) + } +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/PostgresQueryBuilder.scala b/src/main/scala/xyz/driver/pdsuicommon/db/PostgresQueryBuilder.scala new file mode 100644 index 0000000..8ef1829 --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/PostgresQueryBuilder.scala @@ -0,0 +1,98 @@ +package xyz.driver.pdsuicommon.db + +import java.sql.ResultSet + +import io.getquill.{PostgresDialect, PostgresEscape} + +import scala.collection.breakOut + +object PostgresQueryBuilder { + + import xyz.driver.pdsuicommon.db.QueryBuilder._ + + type Escape = PostgresEscape + val Escape = PostgresEscape + + def apply[T](tableName: String, + lastUpdateFieldName: Option[String], + nullableFields: Set[String], + links: Set[TableLink], + runner: Runner[T], + countRunner: CountRunner): PostgresQueryBuilder[T] = { + val parameters = PostgresQueryBuilderParameters( + tableData = TableData(tableName, lastUpdateFieldName, nullableFields), + links = links.map(x => x.foreignTableName -> x)(breakOut) + ) + new PostgresQueryBuilder[T](parameters)(runner, countRunner) + } + + def apply[T](tableName: String, + lastUpdateFieldName: Option[String], + nullableFields: Set[String], + links: Set[TableLink], + extractor: ResultSet => T)(implicit sqlContext: PostgresContext): PostgresQueryBuilder[T] = { + apply(tableName, QueryBuilderParameters.AllFields, lastUpdateFieldName, nullableFields, links, extractor) + } + + def apply[T](tableName: String, + fields: Set[String], + lastUpdateFieldName: Option[String], + nullableFields: Set[String], + links: Set[TableLink], + extractor: ResultSet => T)(implicit sqlContext: PostgresContext): PostgresQueryBuilder[T] = { + + val runner: Runner[T] = { parameters => + val (sql, binder) = parameters.toSql(countQuery = false, fields = fields, namingStrategy = PostgresEscape) + sqlContext.executeQuery[T](sql, binder, { resultSet => + extractor(resultSet) + }) + } + + val countRunner: CountRunner = { parameters => + val (sql, binder) = parameters.toSql(countQuery = true, namingStrategy = PostgresEscape) + sqlContext + .executeQuery[CountResult]( + sql, + binder, { resultSet => + val count = resultSet.getInt(1) + val lastUpdate = if (parameters.tableData.lastUpdateFieldName.isDefined) { + Option(resultSet.getTimestamp(2)).map(sqlContext.timestampToLocalDateTime) + } else None + + (count, lastUpdate) + } + ) + .head + } + + apply[T]( + tableName = tableName, + lastUpdateFieldName = lastUpdateFieldName, + nullableFields = nullableFields, + links = links, + runner = runner, + countRunner = countRunner + ) + } +} + +class PostgresQueryBuilder[T](parameters: PostgresQueryBuilderParameters)(implicit runner: QueryBuilder.Runner[T], + countRunner: QueryBuilder.CountRunner) + extends QueryBuilder[T, PostgresDialect, PostgresQueryBuilder.Escape](parameters) { + + def withFilter(newFilter: SearchFilterExpr): QueryBuilder[T, PostgresDialect, PostgresEscape] = { + new PostgresQueryBuilder[T](parameters.copy(filter = newFilter)) + } + + def withSorting(newSorting: Sorting): QueryBuilder[T, PostgresDialect, PostgresEscape] = { + new PostgresQueryBuilder[T](parameters.copy(sorting = newSorting)) + } + + def withPagination(newPagination: Pagination): QueryBuilder[T, PostgresDialect, PostgresEscape] = { + new PostgresQueryBuilder[T](parameters.copy(pagination = Some(newPagination))) + } + + def resetPagination: QueryBuilder[T, PostgresDialect, PostgresEscape] = { + new PostgresQueryBuilder[T](parameters.copy(pagination = None)) + } +} -- cgit v1.2.3