package xyz.driver.pdsuicommon.db

import java.io.Closeable
import java.net.URI
import java.time._
import java.util.UUID
import java.util.concurrent.Executors
import javax.sql.DataSource

import xyz.driver.pdsuicommon.logging.{PhiLogging, Unsafe}
import xyz.driver.pdsuicommon.concurrent.MdcExecutionContext
import xyz.driver.pdsuicommon.db.SqlContext.Settings
import xyz.driver.pdsuicommon.domain._
import xyz.driver.pdsuicommon.error.IncorrectIdException
import xyz.driver.pdsuicommon.utils.JsonSerializer
import com.typesafe.config.Config
import io.getquill._
import xyz.driver.pdsuidomain.entities.{CaseId, RecordRequestId}

import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}

object SqlContext extends PhiLogging {

  /**
   * Database connection credentials.
   *
   * NOTE(review): none of these fields are read in this file. As a case class, the generated
   * `toString` includes `password`, so avoid logging instances directly.
   */
  case class DbCredentials(user: String,
                           password: String,
                           host: String,
                           port: Int,
                           dbName: String,
                           dbCreateFlag: Boolean,
                           dbContext: String,
                           connectionParams: String,
                           url: String)

  /**
   * Configuration for a [[SqlContext]].
   *
   * @param credentials                 database credentials (not read in this file)
   * @param connection                  Quill JDBC configuration the data source is built from
   * @param connectionAttemptsOnStartup not read in this file
   * @param threadPoolSize              number of threads in the context's fixed thread pool
   */
  case class Settings(credentials: DbCredentials,
                      connection: Config,
                      connectionAttemptsOnStartup: Int,
                      threadPoolSize: Int)

  /**
   * Builds a [[SqlContext]] whose data source is created from `settings.connection`.
   *
   * If the context cannot be constructed after the data source was created, the data source is
   * closed before the failure is re-thrown, so it is not leaked.
   *
   * @throws IllegalArgumentException if the data source cannot be created from the config
   */
  def apply(settings: Settings): SqlContext = {
    // Prevent leaking credentials to a log: only the exception's class name is logged, and the
    // original exception is deliberately not attached as the cause.
    Try(JdbcContextConfig(settings.connection).dataSource) match {
      case Success(dataSource) =>
        try new SqlContext(dataSource, settings)
        catch {
          case NonFatal(e) =>
            // e.g. Executors.newFixedThreadPool rejects a non-positive threadPoolSize; release the
            // already-opened data source and keep any close failure as a suppressed exception.
            Try(dataSource.close()).failed.foreach(e.addSuppressed)
            throw e
        }
      case Failure(NonFatal(e)) =>
        logger.error(phi"Can not load dataSource, error: ${Unsafe(e.getClass.getName)}")
        throw new IllegalArgumentException("Can not load dataSource from config. Check your database and config")
    }
  }
}

/**
 * Quill MySQL JDBC context over `dataSource`, together with the project's custom encoders/decoders.
 *
 * The context owns a fixed thread pool of `settings.threadPoolSize` threads. The pool is exposed as
 * the implicit [[executionContext]], wrapped via `MdcExecutionContext.from` (presumably to propagate
 * the logging MDC). The pool is shut down by [[close]].
 */
class SqlContext(dataSource: DataSource with Closeable, settings: Settings)
    extends MysqlJdbcContext[MysqlEscape](dataSource)
    with EntityExtractorDerivation[Literal] {

  private val tpe = Executors.newFixedThreadPool(settings.threadPoolSize)

  implicit val executionContext: ExecutionContext = {
    val orig = ExecutionContext.fromExecutor(tpe)
    MdcExecutionContext.from(orig)
  }

  /**
   * Closes the underlying Quill context, then shuts the thread pool down.
   *
   * The pool is shut down even when `super.close()` throws, so its threads are never leaked.
   */
  override def close(): Unit = {
    try super.close()
    finally {
      // shutdownNow interrupts tasks still running on the pool.
      tpe.shutdownNow()
    }
  }

  // ///////// Encoders/Decoders ///////////

  /**
   * Overridden because Quill's JDBC `optionDecoder` passes `null` into the wrapped decoder, so any
   * custom decoder without special null handling would fail.
   *
   * The value decodes to `None` in three cases:
   *  - the column was SQL NULL;
   *  - the wrapped decoder throws `NullPointerException` (the null-signalling convention used by
   *    several decoders below);
   *  - the wrapped decoder throws `IncorrectIdException` (an empty or zero id).
   *
   * Any other exception propagates.
   *
   * @see https://github.com/getquill/quill/issues/535
   */
  implicit override def optionDecoder[T](implicit d: Decoder[T]): Decoder[Option[T]] =
    decoder(
      sqlType = d.sqlType,
      row =>
        index => {
          try {
            // NOTE(review): `index - 1` presumably undoes the offset `decoder(...)` adds, because
            // `d` applies it again; confirm against the Quill version in use.
            val res = d(index - 1, row)
            if (row.wasNull) {
              None
            } else {
              Some(res)
            }
          } catch {
            case _: NullPointerException  => None
            case _: IncorrectIdException => None
          }
        }
    )

  implicit def encodeStringId[T]: MappedEncoding[StringId[T], String] =
    MappedEncoding[StringId[T], String](_.id)

  /** Rejects the empty string as an id (mapped to `None` by `optionDecoder` for optional columns). */
  implicit def decodeStringId[T]: MappedEncoding[String, StringId[T]] =
    MappedEncoding[String, StringId[T]] {
      case "" => throw IncorrectIdException("'' is an invalid Id value")
      case x  => StringId(x)
    }

  /** Non-implicit variant that maps both a missing value and the empty string to `None`. */
  def decodeOptStringId[T]: MappedEncoding[Option[String], Option[StringId[T]]] =
    MappedEncoding[Option[String], Option[StringId[T]]] {
      case None | Some("") => None
      case Some(x)         => Some(StringId(x))
    }

  implicit def encodeLongId[T]: MappedEncoding[LongId[T], Long] =
    MappedEncoding[LongId[T], Long](_.id)

  /** Rejects `0` as an id (mapped to `None` by `optionDecoder` for optional columns). */
  implicit def decodeLongId[T]: MappedEncoding[Long, LongId[T]] =
    MappedEncoding[Long, LongId[T]] {
      case 0 => throw IncorrectIdException("0 is an invalid Id value")
      case x => LongId(x)
    }

  // TODO Dirty hack, see REP-475
  def decodeOptLongId[T]: MappedEncoding[Option[Long], Option[LongId[T]]] =
    MappedEncoding[Option[Long], Option[LongId[T]]] {
      case None | Some(0) => None
      case Some(x)        => Some(LongId(x))
    }

  // NOTE(review): relies on UuidId#toString rendering the bare value that decodeUuidId parses
  // back; confirm.
  implicit def encodeUuidId[T]: MappedEncoding[UuidId[T], String] =
    MappedEncoding[UuidId[T], String](_.toString)

  /** Rejects the empty string as an id (mapped to `None` by `optionDecoder` for optional columns). */
  implicit def decodeUuidId[T]: MappedEncoding[String, UuidId[T]] =
    MappedEncoding[String, UuidId[T]] {
      case "" => throw IncorrectIdException("'' is an invalid Id value")
      case x  => UuidId(x)
    }

  /** Non-implicit variant that maps both a missing value and the empty string to `None`. */
  def decodeOptUuidId[T]: MappedEncoding[Option[String], Option[UuidId[T]]] =
    MappedEncoding[Option[String], Option[UuidId[T]]] {
      case None | Some("") => None
      case Some(x)         => Some(UuidId(x))
    }

  /** Stores `TextJson` content as a JSON string via `JsonSerializer`. */
  implicit def encodeTextJson[T: Manifest]: MappedEncoding[TextJson[T], String] =
    MappedEncoding[TextJson[T], String](x => JsonSerializer.serialize(x.content))

  implicit def decodeTextJson[T: Manifest]: MappedEncoding[String, TextJson[T]] =
    MappedEncoding[String, TextJson[T]] { x =>
      TextJson(JsonSerializer.deserialize[T](x))
    }

  implicit val encodeUserRole: MappedEncoding[User.Role, Int] =
    MappedEncoding[User.Role, Int](_.bit)

  implicit val decodeUserRole: MappedEncoding[Int, User.Role] =
    MappedEncoding[Int, User.Role] {
      // 0 is treated as null for numeric types
      case 0 => throw new NullPointerException("0 means no roles. A user must have a role")
      case x => User.Role(x)
    }

  implicit val encodeEmail: MappedEncoding[Email, String] =
    MappedEncoding[Email, String](_.value.toString)

  implicit val decodeEmail: MappedEncoding[String, Email] =
    MappedEncoding[String, Email](Email)

  implicit val encodePasswordHash: MappedEncoding[PasswordHash, Array[Byte]] =
    MappedEncoding[PasswordHash, Array[Byte]](_.value)

  implicit val decodePasswordHash: MappedEncoding[Array[Byte], PasswordHash] =
    MappedEncoding[Array[Byte], PasswordHash](PasswordHash(_))

  implicit val encodeUri: MappedEncoding[URI, String] =
    MappedEncoding[URI, String](_.toString)

  implicit val decodeUri: MappedEncoding[String, URI] =
    MappedEncoding[String, URI](URI.create)

  implicit val encodeCaseId: MappedEncoding[CaseId, String] =
    MappedEncoding[CaseId, String](_.id.toString)

  implicit val decodeCaseId: MappedEncoding[String, CaseId] =
    MappedEncoding[String, CaseId](CaseId(_))

  implicit val encodeFuzzyValue: MappedEncoding[FuzzyValue, String] = {
    MappedEncoding[FuzzyValue, String] {
      case FuzzyValue.No    => "No"
      case FuzzyValue.Yes   => "Yes"
      case FuzzyValue.Maybe => "Maybe"
    }
  }

  /**
   * Decodes "Yes"/"No"/"Maybe".
   *
   * A `null` input throws `NullPointerException`, which `optionDecoder` turns into `None`. Any other
   * string throws `IllegalStateException`.
   */
  implicit val decodeFuzzyValue: MappedEncoding[String, FuzzyValue] =
    MappedEncoding[String, FuzzyValue] {
      case "Yes"   => FuzzyValue.Yes
      case "No"    => FuzzyValue.No
      case "Maybe" => FuzzyValue.Maybe
      case null    => throw new NullPointerException("FuzzyValue is null") // See catch in optionDecoder
      case x       => throw new IllegalStateException(s"Unknown fuzzy value: $x")
    }

  implicit val encodeRecordRequestId: MappedEncoding[RecordRequestId, String] =
    MappedEncoding[RecordRequestId, String](_.id.toString)

  implicit val decodeRecordRequestId: MappedEncoding[String, RecordRequestId] =
    MappedEncoding[String, RecordRequestId] { x =>
      RecordRequestId(UUID.fromString(x))
    }

  /** Adds `<=` on `LocalDateTime` for use inside Quill quotations, emitted as an SQL infix `<=`. */
  final implicit class LocalDateTimeDbOps(val left: LocalDateTime) {

    // scalastyle:off
    def <=(right: LocalDateTime): Quoted[Boolean] = quote(infix"$left <= $right".as[Boolean])
  }
}