aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/pdsuicommon/db/PostgresContext.scala
blob: 1b7e2fba463fa379bb8f9a51afadb1172756fe69 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package xyz.driver.pdsuicommon.db

import java.io.Closeable
import java.sql.Types
import java.time._
import java.util.UUID
import java.util.concurrent.Executors
import javax.sql.DataSource

import io.getquill._
import xyz.driver.pdsuicommon.concurrent.MdcExecutionContext
import xyz.driver.pdsuicommon.db.PostgresContext.Settings
import xyz.driver.pdsuicommon.domain.UuidId
import xyz.driver.pdsuicommon.logging._

import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}

object PostgresContext extends PhiLogging {

  /** Configuration required to build a [[PostgresContext]].
    *
    * @param connection                  Typesafe config section passed to Quill's `JdbcContextConfig`
    *                                    to construct the data source.
    * @param connectionAttemptsOnStartup number of connection attempts at startup (not read in this file).
    * @param threadPoolSize              size of the fixed thread pool backing the context's execution context.
    */
  final case class Settings(connection: com.typesafe.config.Config,
                            connectionAttemptsOnStartup: Int,
                            threadPoolSize: Int)

  /** Creates a [[PostgresContext]] from `settings`.
    *
    * When the data source cannot be created, only the exception's class name is logged, and a
    * generic `IllegalArgumentException` is thrown without the original cause. The underlying
    * error may mention connection credentials, which must not end up in the logs.
    *
    * @throws IllegalArgumentException if the data source cannot be loaded from `settings.connection`.
    */
  def apply(settings: Settings): PostgresContext = {
    val attempt = Try(JdbcContextConfig(settings.connection).dataSource)
    attempt match {
      case Failure(NonFatal(cause)) =>
        // Only the class name is logged: the message itself may contain credentials.
        logger.error(phi"Can not load dataSource, error: ${Unsafe(cause.getClass.getName)}")
        throw new IllegalArgumentException("Can not load dataSource from config. Check your database and config")
      case Success(source) =>
        new PostgresContext(source, settings)
    }
  }

}

/** Quill JDBC context for PostgreSQL with snake_case naming.
  *
  * It owns a dedicated fixed-size thread pool exposed as `executionContext`. It also overrides the
  * `LocalDateTime` codecs so that timestamps are stored as BIGINT epoch milliseconds (UTC), and it
  * provides codecs for `UuidId` values stored as strings.
  *
  * @param dataSource pooled data source handed to the underlying Quill context.
  * @param settings   only `threadPoolSize` is read here, to size the execution-context pool.
  */
class PostgresContext(val dataSource: DataSource with Closeable, settings: Settings)
    extends PostgresJdbcContext[SnakeCase](dataSource) with TransactionalContext
    with EntityExtractorDerivation[SnakeCase] {

  /** Fixed-size pool backing `executionContext`; shut down in `close()`. */
  private val tpe = Executors.newFixedThreadPool(settings.threadPoolSize)

  /** Execution context running on `tpe`, wrapped by `MdcExecutionContext`
    * (presumably to propagate the logging MDC onto pool threads — confirm).
    */
  implicit val executionContext: ExecutionContext =
    MdcExecutionContext.from(ExecutionContext.fromExecutor(tpe))

  /** Closes the underlying Quill context, then shuts down the thread pool.
    *
    * The pool shutdown sits in `finally` so the pool's threads are released even if closing the
    * context throws; any exception from `super.close()` still propagates to the caller.
    */
  override def close(): Unit = {
    try super.close()
    finally {
      tpe.shutdownNow()
      ()
    }
  }

  /**
    * Usable for QueryBuilder's extractors: converts a JDBC timestamp to a `LocalDateTime` in UTC.
    */
  def timestampToLocalDateTime(timestamp: java.sql.Timestamp): LocalDateTime = {
    LocalDateTime.ofInstant(timestamp.toInstant, ZoneOffset.UTC)
  }

  // Override localDateTime encoder and decoder because
  // clinicaltrials.gov uses bigint to store timestamps.

  /** Writes a `LocalDateTime`, interpreted as UTC, as BIGINT epoch milliseconds. */
  override implicit val localDateTimeEncoder: Encoder[LocalDateTime] =
    encoder(Types.BIGINT,
            (index, value, row) => row.setLong(index, value.atZone(ZoneOffset.UTC).toInstant.toEpochMilli))

  /** Reads BIGINT epoch milliseconds as a UTC `LocalDateTime`.
    *
    * `ResultSet.getLong` returns 0 for SQL NULL, so a value of 0 is treated as NULL. This is done by
    * throwing `NullPointerException`, which Quill's Option decoder presumably maps to `None` (verify
    * against the Quill version in use). Consequence: the epoch instant itself cannot be round-tripped.
    */
  override implicit val localDateTimeDecoder: Decoder[LocalDateTime] =
    decoder(
      Types.BIGINT,
      (index, row) => {
        row.getLong(index) match {
          case 0 => throw new NullPointerException("0 is decoded as null")
          // ZoneOffset.UTC is what ZoneId.of("Z") returns; use the constant, consistent with the encoder.
          case x => LocalDateTime.ofInstant(Instant.ofEpochMilli(x), ZoneOffset.UTC)
        }
      }
    )

  /** Stores a `UuidId` as its string form. */
  implicit def encodeUuidId[T]: MappedEncoding[UuidId[T], String] =
    MappedEncoding[UuidId[T], String](_.toString)

  /** Parses a stored string into a `UuidId`; throws `IllegalArgumentException` on a malformed UUID. */
  implicit def decodeUuidId[T]: MappedEncoding[String, UuidId[T]] =
    MappedEncoding[String, UuidId[T]](uuid => UuidId[T](UUID.fromString(uuid)))

  /** Optional variant of `decodeUuidId`; a `Some(null)` input is treated as `None`. */
  def decodeOptUuidId[T]: MappedEncoding[Option[String], Option[UuidId[T]]] =
    MappedEncoding[Option[String], Option[UuidId[T]]] { maybeUuid =>
      maybeUuid.flatMap(raw => Option(raw)).map(raw => UuidId[T](UUID.fromString(raw)))
    }

  /** Parses a stored string into a `java.util.UUID`.
    * NOTE(review): the type parameter `T` is unused; it is kept only for source compatibility.
    */
  implicit def decodeUuid[T]: MappedEncoding[String, UUID] =
    MappedEncoding[String, UUID](uuid => UUID.fromString(uuid))
}