Diffstat (limited to 'src/main/scala/xyz/driver/common/db')
-rw-r--r--  src/main/scala/xyz/driver/common/db/DbCommand.scala  15
-rw-r--r--  src/main/scala/xyz/driver/common/db/DbCommandFactory.scala  14
-rw-r--r--  src/main/scala/xyz/driver/common/db/EntityExtractorDerivation.scala  71
-rw-r--r--  src/main/scala/xyz/driver/common/db/EntityNotFoundException.scala  10
-rw-r--r--  src/main/scala/xyz/driver/common/db/MysqlQueryBuilder.scala  90
-rw-r--r--  src/main/scala/xyz/driver/common/db/Pagination.scala  20
-rw-r--r--  src/main/scala/xyz/driver/common/db/QueryBuilder.scala  344
-rw-r--r--  src/main/scala/xyz/driver/common/db/SearchFilterExpr.scala  210
-rw-r--r--  src/main/scala/xyz/driver/common/db/Sorting.scala  62
-rw-r--r--  src/main/scala/xyz/driver/common/db/SqlContext.scala  184
-rw-r--r--  src/main/scala/xyz/driver/common/db/Transactions.scala  23
-rw-r--r--  src/main/scala/xyz/driver/common/db/repositories/BridgeUploadQueueRepository.scala  24
-rw-r--r--  src/main/scala/xyz/driver/common/db/repositories/Repository.scala  4
-rw-r--r--  src/main/scala/xyz/driver/common/db/repositories/RepositoryLogging.scala  62
14 files changed, 1133 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/common/db/DbCommand.scala b/src/main/scala/xyz/driver/common/db/DbCommand.scala
new file mode 100644
index 0000000..fec8b9f
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/DbCommand.scala
@@ -0,0 +1,15 @@
+package xyz.driver.common.db
+
+import scala.concurrent.Future
+
+trait DbCommand {
+ def runSync(): Unit
+ def runAsync(transactions: Transactions): Future[Unit]
+}
+
+object DbCommand {
+ val Empty: DbCommand = new DbCommand {
+ override def runSync(): Unit = {}
+ override def runAsync(transactions: Transactions): Future[Unit] = Future.successful(())
+ }
+}
diff --git a/src/main/scala/xyz/driver/common/db/DbCommandFactory.scala b/src/main/scala/xyz/driver/common/db/DbCommandFactory.scala
new file mode 100644
index 0000000..84c1383
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/DbCommandFactory.scala
@@ -0,0 +1,14 @@
+package xyz.driver.common.db
+
+import scala.concurrent.{ExecutionContext, Future}
+
+trait DbCommandFactory[T] {
+ def createCommand(orig: T)(implicit ec: ExecutionContext): Future[DbCommand]
+}
+
+object DbCommandFactory {
+ def empty[T]: DbCommandFactory[T] = new DbCommandFactory[T] {
+ override def createCommand(orig: T)(implicit ec: ExecutionContext): Future[DbCommand] = Future.successful(DbCommand.Empty)
+ }
+}
+
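A minimal sketch of how these two traits compose, not drawn from the change itself: the AuditRecord entity and the injected insertSync function below are hypothetical stand-ins for a real entity and repository call.

import scala.concurrent.{ExecutionContext, Future}

import xyz.driver.common.db.{DbCommand, DbCommandFactory, Transactions}

// Hypothetical entity used only for this sketch.
final case class AuditRecord(message: String)

// Wraps the insertion of an AuditRecord into a DbCommand that can be run
// either synchronously or inside a Transactions-managed future.
class AuditRecordCommandFactory(insertSync: AuditRecord => Unit) extends DbCommandFactory[AuditRecord] {

  override def createCommand(orig: AuditRecord)(implicit ec: ExecutionContext): Future[DbCommand] =
    Future.successful(new DbCommand {
      override def runSync(): Unit = insertSync(orig)
      override def runAsync(transactions: Transactions): Future[Unit] =
        transactions.run(_ => insertSync(orig))
    })
}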
diff --git a/src/main/scala/xyz/driver/common/db/EntityExtractorDerivation.scala b/src/main/scala/xyz/driver/common/db/EntityExtractorDerivation.scala
new file mode 100644
index 0000000..0396ea5
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/EntityExtractorDerivation.scala
@@ -0,0 +1,71 @@
+package xyz.driver.common.db
+
+import java.sql.ResultSet
+
+import io.getquill.NamingStrategy
+import io.getquill.dsl.EncodingDsl
+
+import scala.language.experimental.macros
+import scala.reflect.macros.blackbox
+
+trait EntityExtractorDerivation[Naming <: NamingStrategy] {
+ this: EncodingDsl =>
+
+ /**
+ * Simple Quill extractor derivation for [[T]].
+ * Only case classes are supported; type parameters are not.
+ *
+ * @tparam T the entity type to extract from a result set
+ * @return a function that reads one row of a [[java.sql.ResultSet]] into a [[T]]
+ */
+ def entityExtractor[T]: (ResultSet => T) = macro EntityExtractorDerivation.impl[T]
+}
+
+object EntityExtractorDerivation {
+ def impl[T: c.WeakTypeTag](c: blackbox.Context): c.Tree = {
+ import c.universe._
+ val namingStrategy = c.prefix.actualType
+ .baseType(c.weakTypeOf[EntityExtractorDerivation[NamingStrategy]].typeSymbol)
+ .typeArgs
+ .head
+ .typeSymbol
+ .companion
+ val functionBody = {
+ val tpe = weakTypeOf[T]
+ val resultOpt = tpe.decls.collectFirst {
+ // Find the first constructor of T
+ case cons: MethodSymbol if cons.isConstructor =>
+ // Create param list for constructor
+ val params = cons.paramLists.flatten.map { param =>
+ val t = param.typeSignature
+ val paramName = param.name.toString
+ val col = q"$namingStrategy.column($paramName)"
+ // Resolve implicit decoders (from SqlContext) and apply ResultSet for each
+ val d = q"implicitly[${c.prefix}.Decoder[$t]]"
+ // Subtract 1 because Quill JDBC decoders add one to the index.
+ // ¯\_(ツ)_/¯
+ val i = q"row.findColumn($col) - 1"
+ val decoderName = TermName(paramName + "Decoder")
+ val valueName = TermName(paramName + "Value")
+ (
+ q"val $decoderName = $d",
+ q"val $valueName = $decoderName($i, row)",
+ valueName
+ )
+ }
+ // Call constructor with param list
+ q"""
+ ..${params.map(_._1)}
+ ..${params.map(_._2)}
+ new $tpe(..${params.map(_._3)})
+ """
+ }
+ resultOpt match {
+ case Some(result) => result
+ case None => c.abort(c.enclosingPosition,
+ s"Cannot derive an extractor for $tpe: constructor not found.")
+ }
+ }
+ q"(row: java.sql.ResultSet) => $functionBody"
+ }
+}
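A sketch of how the derived extractor might be used together with the SqlContext introduced later in this change; UserRow, the table, and the SQL string are illustrative assumptions, and executeQuery is used in the same shape as in MysqlQueryBuilder below.

import java.sql.{PreparedStatement, ResultSet}

import xyz.driver.common.db.SqlContext

// Hypothetical row type; every constructor parameter type needs an implicit Quill decoder on the context.
final case class UserRow(id: String, email: String)

object EntityExtractorExample {

  def readUsers(sqlContext: SqlContext): Seq[UserRow] = {
    // Expands to a (ResultSet => UserRow) that resolves a column per constructor
    // parameter via the naming strategy and the implicit decoders of `sqlContext`.
    val extractor: ResultSet => UserRow = sqlContext.entityExtractor[UserRow]

    sqlContext.executeQuery[UserRow]("SELECT * FROM user", (ps: PreparedStatement) => ps, extractor)
  }
}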
diff --git a/src/main/scala/xyz/driver/common/db/EntityNotFoundException.scala b/src/main/scala/xyz/driver/common/db/EntityNotFoundException.scala
new file mode 100644
index 0000000..d4c11ac
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/EntityNotFoundException.scala
@@ -0,0 +1,10 @@
+package xyz.driver.common.db
+
+import xyz.driver.common.domain.Id
+
+class EntityNotFoundException private(id: String, tableName: String)
+ extends RuntimeException(s"Entity with id $id was not found in the $tableName table") {
+
+ def this(id: Id[_], tableName: String) = this(id.toString, tableName)
+ def this(id: Long, tableName: String) = this(id.toString, tableName)
+}
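A small usage sketch; findById and the table name are hypothetical placeholders for a real repository lookup.

import xyz.driver.common.db.EntityNotFoundException

object EntityNotFoundExample {

  // Loads an entity by a numeric id or fails with EntityNotFoundException.
  def getOrThrow[T](id: Long, findById: Long => Option[T], tableName: String): T =
    findById(id).getOrElse(throw new EntityNotFoundException(id, tableName))
}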
diff --git a/src/main/scala/xyz/driver/common/db/MysqlQueryBuilder.scala b/src/main/scala/xyz/driver/common/db/MysqlQueryBuilder.scala
new file mode 100644
index 0000000..d6b53d9
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/MysqlQueryBuilder.scala
@@ -0,0 +1,90 @@
+package xyz.driver.common.db
+
+import java.sql.ResultSet
+
+import io.getquill.{MySQLDialect, MysqlEscape}
+
+import scala.collection.breakOut
+import scala.concurrent.{ExecutionContext, Future}
+
+object MysqlQueryBuilder {
+ import xyz.driver.common.db.QueryBuilder._
+
+ def apply[T](tableName: String,
+ lastUpdateFieldName: Option[String],
+ nullableFields: Set[String],
+ links: Set[TableLink],
+ runner: Runner[T],
+ countRunner: CountRunner)
+ (implicit ec: ExecutionContext): MysqlQueryBuilder[T] = {
+ val parameters = MysqlQueryBuilderParameters(
+ tableData = TableData(tableName, lastUpdateFieldName, nullableFields),
+ links = links.map(x => x.foreignTableName -> x)(breakOut)
+ )
+ new MysqlQueryBuilder[T](parameters)(runner, countRunner, ec)
+ }
+
+ def apply[T](tableName: String,
+ lastUpdateFieldName: Option[String],
+ nullableFields: Set[String],
+ links: Set[TableLink],
+ extractor: (ResultSet) => T)
+ (implicit sqlContext: SqlContext): MysqlQueryBuilder[T] = {
+
+ val runner = (parameters: QueryBuilderParameters) => {
+ Future {
+ val (sql, binder) = parameters.toSql(namingStrategy = MysqlEscape)
+ sqlContext.executeQuery[T](sql, binder, { resultSet =>
+ extractor(resultSet)
+ })
+ }(sqlContext.executionContext)
+ }
+
+ val countRunner = (parameters: QueryBuilderParameters) => {
+ Future {
+ val (sql, binder) = parameters.toSql(countQuery = true, namingStrategy = MysqlEscape)
+ sqlContext.executeQuery[CountResult](sql, binder, { resultSet =>
+ val count = resultSet.getInt(1)
+ val lastUpdate = if (parameters.tableData.lastUpdateFieldName.isDefined) {
+ Option(sqlContext.localDateTimeDecoder.decoder(2, resultSet))
+ } else None
+
+ (count, lastUpdate)
+ }).head
+ }(sqlContext.executionContext)
+ }
+
+ apply[T](
+ tableName = tableName,
+ lastUpdateFieldName = lastUpdateFieldName,
+ nullableFields = nullableFields,
+ links = links,
+ runner = runner,
+ countRunner = countRunner
+ )(sqlContext.executionContext)
+ }
+}
+
+class MysqlQueryBuilder[T](parameters: MysqlQueryBuilderParameters)
+ (implicit runner: QueryBuilder.Runner[T],
+ countRunner: QueryBuilder.CountRunner,
+ ec: ExecutionContext)
+ extends QueryBuilder[T, MySQLDialect, MysqlEscape](parameters) {
+
+ def withFilter(newFilter: SearchFilterExpr): QueryBuilder[T, MySQLDialect, MysqlEscape] = {
+ new MysqlQueryBuilder[T](parameters.copy(filter = newFilter))
+ }
+
+ def withSorting(newSorting: Sorting): QueryBuilder[T, MySQLDialect, MysqlEscape] = {
+ new MysqlQueryBuilder[T](parameters.copy(sorting = newSorting))
+ }
+
+ def withPagination(newPagination: Pagination): QueryBuilder[T, MySQLDialect, MysqlEscape] = {
+ new MysqlQueryBuilder[T](parameters.copy(pagination = Some(newPagination)))
+ }
+
+ def resetPagination: QueryBuilder[T, MySQLDialect, MysqlEscape] = {
+ new MysqlQueryBuilder[T](parameters.copy(pagination = None))
+ }
+
+}
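A sketch of wiring a MysqlQueryBuilder for a hypothetical `sample` table; the table, column names, lastUpdate field, and the hand-written extractor are illustrative only (entityExtractor could be used instead).

import java.time.LocalDateTime

import scala.concurrent.Future

import xyz.driver.common.db._

// Hypothetical row type mapped by hand for this sketch.
final case class SampleRow(id: Long, name: String)

object MysqlQueryBuilderExample {

  def listSamples(filter: SearchFilterExpr, pagination: Pagination)
                 (implicit sqlContext: SqlContext): Future[(Seq[SampleRow], Int, Option[LocalDateTime])] = {
    val builder = MysqlQueryBuilder[SampleRow](
      tableName = "sample",
      lastUpdateFieldName = Some("lastUpdate"),
      nullableFields = Set("name"),
      links = Set.empty,
      extractor = rs => SampleRow(rs.getLong("id"), rs.getString("name"))
    )

    builder
      .withFilter(filter)
      .withSorting(Sorting.Dimension(None, "id", SortingOrder.Ascending))
      .withPagination(pagination)
      .runWithCount
  }
}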
diff --git a/src/main/scala/xyz/driver/common/db/Pagination.scala b/src/main/scala/xyz/driver/common/db/Pagination.scala
new file mode 100644
index 0000000..d4a96d3
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/Pagination.scala
@@ -0,0 +1,20 @@
+package xyz.driver.common.db
+
+import xyz.driver.common.logging._
+
+/**
+ * @param pageNumber Starts with 1
+ */
+case class Pagination(pageSize: Int, pageNumber: Int)
+
+object Pagination {
+
+ // @see https://driverinc.atlassian.net/wiki/display/RA/REST+API+Specification#RESTAPISpecification-CommonRequestQueryParametersForWebServices
+ val Default = Pagination(pageSize = 100, pageNumber = 1)
+
+ implicit def toPhiString(x: Pagination): PhiString = {
+ import x._
+ phi"Pagination(pageSize=${Unsafe(pageSize)}, pageNumber=${Unsafe(pageNumber)})"
+ }
+
+}
diff --git a/src/main/scala/xyz/driver/common/db/QueryBuilder.scala b/src/main/scala/xyz/driver/common/db/QueryBuilder.scala
new file mode 100644
index 0000000..f0beca6
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/QueryBuilder.scala
@@ -0,0 +1,344 @@
+package xyz.driver.common.db
+
+import java.sql.PreparedStatement
+import java.time.LocalDateTime
+
+import io.getquill.NamingStrategy
+import io.getquill.context.sql.idiom.SqlIdiom
+import xyz.driver.common.db.Sorting.{Dimension, Sequential}
+import xyz.driver.common.db.SortingOrder.{Ascending, Descending}
+
+import scala.collection.mutable.ListBuffer
+import scala.concurrent.{ExecutionContext, Future}
+
+object QueryBuilder {
+
+ type Runner[T] = (QueryBuilderParameters) => Future[Seq[T]]
+
+ type CountResult = (Int, Option[LocalDateTime])
+
+ type CountRunner = (QueryBuilderParameters) => Future[CountResult]
+
+ /**
+ * Binder for PreparedStatement
+ */
+ type Binder = PreparedStatement => PreparedStatement
+
+ case class TableData(tableName: String,
+ lastUpdateFieldName: Option[String] = None,
+ nullableFields: Set[String] = Set.empty)
+
+ val AllFields = Set("*")
+
+}
+
+case class TableLink(keyColumnName: String,
+ foreignTableName: String,
+ foreignKeyColumnName: String)
+
+object QueryBuilderParameters {
+ val AllFields = Set("*")
+}
+
+sealed trait QueryBuilderParameters {
+
+ def tableData: QueryBuilder.TableData
+ def links: Map[String, TableLink]
+ def filter: SearchFilterExpr
+ def sorting: Sorting
+ def pagination: Option[Pagination]
+
+ def findLink(tableName: String): TableLink = links.get(tableName) match {
+ case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`")
+ case Some(link) => link
+ }
+
+ def toSql(countQuery: Boolean = false, namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = {
+ toSql(countQuery, QueryBuilderParameters.AllFields, namingStrategy)
+ }
+
+ def toSql(countQuery: Boolean,
+ fields: Set[String],
+ namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = {
+ val escapedTableName = namingStrategy.table(tableData.tableName)
+ val fieldsSql: String = if (countQuery) {
+ "count(*)" + (tableData.lastUpdateFieldName match {
+ case Some(lastUpdateField) => s", max($escapedTableName.${namingStrategy.column(lastUpdateField)})"
+ case None => ""
+ })
+ } else {
+ if (fields == QueryBuilderParameters.AllFields) {
+ s"$escapedTableName.*"
+ } else {
+ fields
+ .map { field =>
+ s"$escapedTableName.${namingStrategy.column(field)}"
+ }
+ .mkString(", ")
+ }
+ }
+ val (where, bindings) = filterToSql(escapedTableName, filter, namingStrategy)
+ val orderBy = sortingToSql(escapedTableName, sorting, namingStrategy)
+
+ val limitSql = limitToSql()
+
+ val sql = new StringBuilder()
+ sql.append("select ")
+ sql.append(fieldsSql)
+ sql.append("\nfrom ")
+ sql.append(escapedTableName)
+
+ val filtersTableLinks: Seq[TableLink] = {
+ import SearchFilterExpr._
+ def aux(expr: SearchFilterExpr): Seq[TableLink] = expr match {
+ case Atom.TableName(tableName) => List(findLink(tableName))
+ case Intersection(xs) => xs.flatMap(aux)
+ case Union(xs) => xs.flatMap(aux)
+ case _ => Nil
+ }
+ aux(filter)
+ }
+
+ val sortingTableLinks: Seq[TableLink] = Sorting.collect(sorting) {
+ case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName)
+ }
+
+ // Combine links from sorting and filter without duplicates
+ val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct
+
+ foreignTableLinks.foreach {
+ case TableLink(keyColumnName, foreignTableName, foreignKeyColumnName) =>
+ val escapedForeignTableName = namingStrategy.table(foreignTableName)
+
+ sql.append("\ninner join ")
+ sql.append(escapedForeignTableName)
+ sql.append(" on ")
+
+ sql.append(escapedTableName)
+ sql.append('.')
+ sql.append(namingStrategy.column(keyColumnName))
+
+ sql.append(" = ")
+
+ sql.append(escapedForeignTableName)
+ sql.append('.')
+ sql.append(namingStrategy.column(foreignKeyColumnName))
+ }
+
+ if (where.nonEmpty) {
+ sql.append("\nwhere ")
+ sql.append(where)
+ }
+
+ if (orderBy.nonEmpty && !countQuery) {
+ sql.append("\norder by ")
+ sql.append(orderBy)
+ }
+
+ if (limitSql.nonEmpty && !countQuery) {
+ sql.append("\n")
+ sql.append(limitSql)
+ }
+
+ (sql.toString, binder(bindings))
+ }
+
+ /**
+ * Converts a filter expression into a SQL expression.
+ *
+ * @return The SQL string and the list of values to bind in the prepared statement.
+ */
+ protected def filterToSql(escapedTableName: String,
+ filter: SearchFilterExpr,
+ namingStrategy: NamingStrategy): (String, List[AnyRef]) = {
+ import SearchFilterBinaryOperation._
+ import SearchFilterExpr._
+
+ def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null"
+
+ def placeholder(field: String) = "?"
+
+ def escapeDimension(dimension: SearchFilterExpr.Dimension) = {
+ val tableName = dimension.tableName.fold(escapedTableName)(namingStrategy.table)
+ s"$tableName.${namingStrategy.column(dimension.name)}"
+ }
+
+ def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect {
+ case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x, namingStrategy)
+ }
+
+ filter match {
+ case x if isEmpty(x) =>
+ ("", List.empty)
+
+ case AllowAll =>
+ ("1", List.empty)
+
+ case DenyAll =>
+ ("0", List.empty)
+
+ case Atom.Binary(dimension, Eq, value) if isNull(value) =>
+ (s"${escapeDimension(dimension)} is NULL", List.empty)
+
+ case Atom.Binary(dimension, NotEq, value) if isNull(value) =>
+ (s"${escapeDimension(dimension)} is not NULL", List.empty)
+
+ case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) =>
+ // In MySQL, NULL <> anything evaluates to NULL,
+ // so to handle NotEq for nullable fields we need a more complex SQL expression.
+ // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html
+ val escapedColumn = escapeDimension(dimension)
+ val sql = s"($escapedColumn is null or $escapedColumn != ${placeholder(dimension.name)})"
+ (sql, List(value))
+
+ case Atom.Binary(dimension, op, value) =>
+ val operator = op match {
+ case Eq => "="
+ case NotEq => "!="
+ case Like => "like"
+ case Gt => ">"
+ case GtEq => ">="
+ case Lt => "<"
+ case LtEq => "<="
+ }
+ (s"${escapeDimension(dimension)} $operator ${placeholder(dimension.name)}", List(value))
+
+ case Atom.NAry(dimension, op, values) =>
+ val sqlOp = op match {
+ case SearchFilterNAryOperation.In => "in"
+ case SearchFilterNAryOperation.NotIn => "not in"
+ }
+
+ val bindings = ListBuffer[AnyRef]()
+ val sqlPlaceholder = placeholder(dimension.name)
+ val formattedValues = values.map { value =>
+ bindings += value
+ sqlPlaceholder
+ }.mkString(", ")
+ (s"${escapeDimension(dimension)} $sqlOp ($formattedValues)", bindings.toList)
+
+ case Intersection(operands) =>
+ val (sql, bindings) = filterToSqlMultiple(operands).unzip
+ (sql.mkString("(", " and ", ")"), bindings.flatten.toList)
+
+ case Union(operands) =>
+ val (sql, bindings) = filterToSqlMultiple(operands).unzip
+ (sql.mkString("(", " or ", ")"), bindings.flatten.toList)
+ }
+ }
+
+ protected def limitToSql(): String
+
+ /**
+ * @param escapedMainTableName The main table name, already escaped
+ */
+ protected def sortingToSql(escapedMainTableName: String,
+ sorting: Sorting,
+ namingStrategy: NamingStrategy): String = {
+ sorting match {
+ case Dimension(optSortingTableName, field, order) =>
+ val sortingTableName = optSortingTableName.map(namingStrategy.table).getOrElse(escapedMainTableName)
+ val fullName = s"$sortingTableName.${namingStrategy.column(field)}"
+
+ s"$fullName ${orderToSql(order)}"
+
+ case Sequential(xs) =>
+ xs.map(sortingToSql(escapedMainTableName, _, namingStrategy)).mkString(", ")
+ }
+ }
+
+ protected def orderToSql(x: SortingOrder): String = x match {
+ case Ascending => "asc"
+ case Descending => "desc"
+ }
+
+ protected def binder(bindings: List[AnyRef])
+ (bind: PreparedStatement): PreparedStatement = {
+ bindings.zipWithIndex.foreach { case (binding, index) =>
+ bind.setObject(index + 1, binding)
+ }
+
+ bind
+ }
+
+}
+
+case class PostgresQueryBuilderParameters(tableData: QueryBuilder.TableData,
+ links: Map[String, TableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None) extends QueryBuilderParameters {
+
+ def limitToSql(): String = {
+ pagination.map { pagination =>
+ val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
+ s"limit ${pagination.pageSize} OFFSET $startFrom"
+ } getOrElse ""
+ }
+
+}
+
+/**
+ * @param links Links to other tables, grouped by foreignTableName
+ */
+case class MysqlQueryBuilderParameters(tableData: QueryBuilder.TableData,
+ links: Map[String, TableLink] = Map.empty,
+ filter: SearchFilterExpr = SearchFilterExpr.Empty,
+ sorting: Sorting = Sorting.Empty,
+ pagination: Option[Pagination] = None) extends QueryBuilderParameters {
+
+ def limitToSql(): String = pagination.map { pagination =>
+ val startFrom = (pagination.pageNumber - 1) * pagination.pageSize
+ s"limit $startFrom, ${pagination.pageSize}"
+ }.getOrElse("")
+
+}
+
+abstract class QueryBuilder[T, D <: SqlIdiom, N <: NamingStrategy](val parameters: QueryBuilderParameters)
+ (implicit runner: QueryBuilder.Runner[T],
+ countRunner: QueryBuilder.CountRunner,
+ ec: ExecutionContext) {
+
+ def run: Future[Seq[T]] = runner(parameters)
+
+ def runCount: Future[QueryBuilder.CountResult] = countRunner(parameters)
+
+ /**
+ * Runs the query and returns the total number of matching rows, ignoring pagination.
+ */
+ def runWithCount: Future[(Seq[T], Int, Option[LocalDateTime])] = {
+ val countFuture = runCount
+ val selectAllFuture = run
+ for {
+ (total, lastUpdate) <- countFuture
+ all <- selectAllFuture
+ } yield (all, total, lastUpdate)
+ }
+
+ def withFilter(newFilter: SearchFilterExpr): QueryBuilder[T, D, N]
+
+ def withFilter(filter: Option[SearchFilterExpr]): QueryBuilder[T, D, N] = {
+ filter.fold(this)(withFilter)
+ }
+
+ def resetFilter: QueryBuilder[T, D, N] = withFilter(SearchFilterExpr.Empty)
+
+
+ def withSorting(newSorting: Sorting): QueryBuilder[T, D, N]
+
+ def withSorting(sorting: Option[Sorting]): QueryBuilder[T, D, N] = {
+ sorting.fold(this)(withSorting)
+ }
+
+ def resetSorting: QueryBuilder[T, D, N] = withSorting(Sorting.Empty)
+
+
+ def withPagination(newPagination: Pagination): QueryBuilder[T, D, N]
+
+ def withPagination(pagination: Option[Pagination]): QueryBuilder[T, D, N] = {
+ pagination.fold(this)(withPagination)
+ }
+
+ def resetPagination: QueryBuilder[T, D, N]
+
+}
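A sketch of the SQL that MysqlQueryBuilderParameters.toSql produces; the `document`/`record` tables, columns, and the output shown in the comment are illustrative approximations, not captured output.

import io.getquill.MysqlEscape

import xyz.driver.common.db._

object QueryBuilderSqlExample {

  val parameters = MysqlQueryBuilderParameters(
    tableData = QueryBuilder.TableData("document", lastUpdateFieldName = Some("lastUpdate")),
    links = Map("record" -> TableLink("recordId", "record", "id")),
    filter = SearchFilterExpr.Atom.Binary("status", SearchFilterBinaryOperation.Eq, "New"),
    sorting = Sorting.Dimension(Some("record"), "lastUpdate", SortingOrder.Descending),
    pagination = Some(Pagination(pageSize = 10, pageNumber = 2))
  )

  val (sql, binder) = parameters.toSql(namingStrategy = MysqlEscape)
  // `sql` should look roughly like:
  //   select `document`.*
  //   from `document`
  //   inner join `record` on `document`.`recordId` = `record`.`id`
  //   where `document`.`status` = ?
  //   order by `record`.`lastUpdate` desc
  //   limit 10, 10
  // and `binder` binds "New" as the first prepared-statement parameter.
}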
diff --git a/src/main/scala/xyz/driver/common/db/SearchFilterExpr.scala b/src/main/scala/xyz/driver/common/db/SearchFilterExpr.scala
new file mode 100644
index 0000000..06b21cd
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/SearchFilterExpr.scala
@@ -0,0 +1,210 @@
+package xyz.driver.common.db
+
+import xyz.driver.common.logging._
+
+sealed trait SearchFilterExpr {
+ def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr]
+ def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr
+}
+
+object SearchFilterExpr {
+
+ val Empty = Intersection.Empty
+ val Forbid = Atom.Binary(
+ dimension = Dimension(None, "true"),
+ op = SearchFilterBinaryOperation.Eq,
+ value = "false"
+ )
+
+ case class Dimension(tableName: Option[String], name: String) {
+ def isForeign: Boolean = tableName.isDefined
+ }
+
+ sealed trait Atom extends SearchFilterExpr {
+ override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = {
+ if (p(this)) Some(this)
+ else None
+ }
+
+ override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = {
+ if (f.isDefinedAt(this)) f(this)
+ else this
+ }
+ }
+
+ object Atom {
+ case class Binary(dimension: Dimension, op: SearchFilterBinaryOperation, value: AnyRef) extends Atom
+ object Binary {
+ def apply(field: String, op: SearchFilterBinaryOperation, value: AnyRef): Binary =
+ Binary(Dimension(None, field), op, value)
+ }
+
+ case class NAry(dimension: Dimension, op: SearchFilterNAryOperation, values: Seq[AnyRef]) extends Atom
+ object NAry {
+ def apply(field: String, op: SearchFilterNAryOperation, values: Seq[AnyRef]): NAry =
+ NAry(Dimension(None, field), op, values)
+ }
+
+ /** dimension.tableName extractor */
+ object TableName {
+ def unapply(value: Atom): Option[String] = value match {
+ case Binary(Dimension(tableNameOpt, _), _, _) => tableNameOpt
+ case NAry(Dimension(tableNameOpt, _), _, _) => tableNameOpt
+ }
+ }
+ }
+
+ case class Intersection private(operands: Seq[SearchFilterExpr])
+ extends SearchFilterExpr with SearchFilterExprSeqOps {
+
+ override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = {
+ if (f.isDefinedAt(this)) f(this)
+ else {
+ this.copy(operands.map(_.replace(f)))
+ }
+ }
+
+ }
+
+ object Intersection {
+
+ val Empty = Intersection(Seq())
+
+ def create(operands: SearchFilterExpr*): SearchFilterExpr = {
+ val filtered = operands.filterNot(SearchFilterExpr.isEmpty)
+ filtered.size match {
+ case 0 => Empty
+ case 1 => filtered.head
+ case _ => Intersection(filtered)
+ }
+ }
+ }
+
+
+ case class Union private(operands: Seq[SearchFilterExpr]) extends SearchFilterExpr with SearchFilterExprSeqOps {
+
+ override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = {
+ if (f.isDefinedAt(this)) f(this)
+ else {
+ this.copy(operands.map(_.replace(f)))
+ }
+ }
+
+ }
+
+ object Union {
+
+ val Empty = Union(Seq())
+
+ def create(operands: SearchFilterExpr*): SearchFilterExpr = {
+ val filtered = operands.filterNot(SearchFilterExpr.isEmpty)
+ filtered.size match {
+ case 0 => Empty
+ case 1 => filtered.head
+ case _ => Union(filtered)
+ }
+ }
+
+ def create(dimension: Dimension, values: String*): SearchFilterExpr = values.size match {
+ case 0 => SearchFilterExpr.Empty
+ case 1 => SearchFilterExpr.Atom.Binary(dimension, SearchFilterBinaryOperation.Eq, values.head)
+ case _ =>
+ val filters = values.map { value =>
+ SearchFilterExpr.Atom.Binary(dimension, SearchFilterBinaryOperation.Eq, value)
+ }
+
+ create(filters: _*)
+ }
+
+ def create(dimension: Dimension, values: Set[String]): SearchFilterExpr =
+ create(dimension, values.toSeq: _*)
+
+ // Backwards compatible API
+
+ /** Create SearchFilterExpr with empty tableName */
+ def create(field: String, values: String*): SearchFilterExpr =
+ create(Dimension(None, field), values:_*)
+
+ /** Create SearchFilterExpr with empty tableName */
+ def create(field: String, values: Set[String]): SearchFilterExpr =
+ create(Dimension(None, field), values)
+ }
+
+
+ case object AllowAll extends SearchFilterExpr {
+ override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = {
+ if (p(this)) Some(this)
+ else None
+ }
+
+ override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = {
+ if (f.isDefinedAt(this)) f(this)
+ else this
+ }
+ }
+
+ case object DenyAll extends SearchFilterExpr {
+ override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = {
+ if (p(this)) Some(this)
+ else None
+ }
+
+ override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = {
+ if (f.isDefinedAt(this)) f(this)
+ else this
+ }
+ }
+
+ def isEmpty(expr: SearchFilterExpr): Boolean = {
+ expr == Intersection.Empty || expr == Union.Empty
+ }
+
+ sealed trait SearchFilterExprSeqOps {
+ this: SearchFilterExpr =>
+
+ val operands: Seq[SearchFilterExpr]
+
+ override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = {
+ if (p(this)) Some(this)
+ else {
+ // Find the first expr among the operands that satisfies p.
+ // It's ok to use foldLeft here; if performance becomes an issue, replace it with a recursive loop.
+ operands.foldLeft(Option.empty[SearchFilterExpr]) {
+ case (None, expr) => expr.find(p)
+ case (x, _) => x
+ }
+ }
+ }
+
+ }
+
+ // There is currently no case in which this is unsafe.
+ implicit def toPhiString(x: SearchFilterExpr): PhiString = {
+ if (isEmpty(x)) Unsafe("SearchFilterExpr.Empty")
+ else Unsafe(x.toString)
+ }
+
+}
+
+sealed trait SearchFilterBinaryOperation
+
+object SearchFilterBinaryOperation {
+
+ case object Eq extends SearchFilterBinaryOperation
+ case object NotEq extends SearchFilterBinaryOperation
+ case object Like extends SearchFilterBinaryOperation
+ case object Gt extends SearchFilterBinaryOperation
+ case object GtEq extends SearchFilterBinaryOperation
+ case object Lt extends SearchFilterBinaryOperation
+ case object LtEq extends SearchFilterBinaryOperation
+
+}
+
+sealed trait SearchFilterNAryOperation
+
+object SearchFilterNAryOperation {
+
+ case object In extends SearchFilterNAryOperation
+ case object NotIn extends SearchFilterNAryOperation
+
+}
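A sketch of composing filter expressions; the column names and values are hypothetical.

import xyz.driver.common.db.{SearchFilterBinaryOperation, SearchFilterExpr, SearchFilterNAryOperation}

object SearchFilterExample {
  import SearchFilterExpr._

  // Roughly: (status in ('New', 'Verified') and assignee = '42') or previousAssignee = '42'
  val filter: SearchFilterExpr = Union.create(
    Intersection.create(
      Atom.NAry("status", SearchFilterNAryOperation.In, Seq("New", "Verified")),
      Atom.Binary("assignee", SearchFilterBinaryOperation.Eq, "42")
    ),
    Atom.Binary("previousAssignee", SearchFilterBinaryOperation.Eq, "42")
  )

  // Empty operands are dropped by create, so adding SearchFilterExpr.Empty changes nothing:
  val sameFilter: SearchFilterExpr = Union.create(filter, SearchFilterExpr.Empty)
}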
diff --git a/src/main/scala/xyz/driver/common/db/Sorting.scala b/src/main/scala/xyz/driver/common/db/Sorting.scala
new file mode 100644
index 0000000..70c25f2
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/Sorting.scala
@@ -0,0 +1,62 @@
+package xyz.driver.common.db
+
+import xyz.driver.common.logging._
+
+import scala.collection.generic.CanBuildFrom
+
+sealed trait SortingOrder
+object SortingOrder {
+
+ case object Ascending extends SortingOrder
+ case object Descending extends SortingOrder
+
+}
+
+sealed trait Sorting
+
+object Sorting {
+
+ val Empty = Sequential(Seq.empty)
+
+ /**
+ * @param tableName None if the dimension belongs to the default (main) table
+ * @param name Dimension name
+ * @param order Order
+ */
+ case class Dimension(tableName: Option[String], name: String, order: SortingOrder) extends Sorting {
+ def isForeign: Boolean = tableName.isDefined
+ }
+
+ case class Sequential(sorting: Seq[Dimension]) extends Sorting {
+ override def toString: String = if (isEmpty(this)) "Empty" else super.toString
+ }
+
+ def isEmpty(input: Sorting): Boolean = {
+ input match {
+ case Sequential(Seq()) => true
+ case _ => false
+ }
+ }
+
+ def filter(sorting: Sorting, p: Dimension => Boolean): Seq[Dimension] = sorting match {
+ case x: Dimension if p(x) => Seq(x)
+ case x: Dimension => Seq.empty
+ case Sequential(xs) => xs.filter(p)
+ }
+
+ def collect[B, That](sorting: Sorting)
+ (f: PartialFunction[Dimension, B])
+ (implicit bf: CanBuildFrom[Seq[Dimension], B, That]): That = sorting match {
+ case x: Dimension if f.isDefinedAt(x) =>
+ val r = bf.apply()
+ r += f(x)
+ r.result()
+
+ case x: Dimension => bf.apply().result()
+ case Sequential(xs) => xs.collect(f)
+ }
+
+ // Contains dimensions and ordering only, thus it is safe.
+ implicit def toPhiString(x: Sorting): PhiString = Unsafe(x.toString)
+
+}
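A sketch of a compound sort order and of the collect helper used by the query builder; the table and column names are hypothetical.

import xyz.driver.common.db.{Sorting, SortingOrder}

object SortingExample {
  import Sorting._

  // Roughly "order by record.lastUpdate desc, id asc", where `record` is a joined table.
  val sorting: Sorting = Sequential(Seq(
    Dimension(Some("record"), "lastUpdate", SortingOrder.Descending),
    Dimension(None, "id", SortingOrder.Ascending)
  ))

  // Extract the foreign table names, e.g. to decide which joins are required.
  val foreignTables: Seq[String] = Sorting.collect(sorting) {
    case Dimension(Some(tableName), _, _) => tableName
  }
  // foreignTables == Seq("record")
}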
diff --git a/src/main/scala/xyz/driver/common/db/SqlContext.scala b/src/main/scala/xyz/driver/common/db/SqlContext.scala
new file mode 100644
index 0000000..4b9d676
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/SqlContext.scala
@@ -0,0 +1,184 @@
+package xyz.driver.common.db
+
+import java.io.Closeable
+import java.net.URI
+import java.time._
+import java.util.UUID
+import java.util.concurrent.Executors
+import javax.sql.DataSource
+
+import xyz.driver.common.logging.{PhiLogging, Unsafe}
+import xyz.driver.common.concurrent.MdcExecutionContext
+import xyz.driver.common.db.SqlContext.Settings
+import xyz.driver.common.domain._
+import xyz.driver.common.error.IncorrectIdException
+import xyz.driver.common.utils.JsonSerializer
+import com.typesafe.config.Config
+import io.getquill._
+
+import scala.concurrent.ExecutionContext
+import scala.util.control.NonFatal
+import scala.util.{Failure, Success, Try}
+
+object SqlContext extends PhiLogging {
+
+ case class DbCredentials(user: String,
+ password: String,
+ host: String,
+ port: Int,
+ dbName: String,
+ dbCreateFlag: Boolean,
+ dbContext: String,
+ connectionParams: String,
+ url: String)
+
+ case class Settings(credentials: DbCredentials,
+ connection: Config,
+ connectionAttemptsOnStartup: Int,
+ threadPoolSize: Int)
+
+ def apply(settings: Settings): SqlContext = {
+ // Prevent leaking credentials to a log
+ Try(JdbcContextConfig(settings.connection).dataSource) match {
+ case Success(dataSource) => new SqlContext(dataSource, settings)
+ case Failure(NonFatal(e)) =>
+ logger.error(phi"Cannot load the dataSource, error: ${Unsafe(e.getClass.getName)}")
+ throw new IllegalArgumentException("Cannot load the dataSource from the config. Check your database and config.")
+ }
+ }
+
+}
+
+class SqlContext(dataSource: DataSource with Closeable, settings: Settings)
+ extends MysqlJdbcContext[MysqlEscape](dataSource)
+ with EntityExtractorDerivation[Literal] {
+
+ private val tpe = Executors.newFixedThreadPool(settings.threadPoolSize)
+
+ implicit val executionContext: ExecutionContext = {
+ val orig = ExecutionContext.fromExecutor(tpe)
+ MdcExecutionContext.from(orig)
+ }
+
+ override def close(): Unit = {
+ super.close()
+ tpe.shutdownNow()
+ }
+
+ // ///////// Encoders/Decoders ///////////
+
+ /**
+ * Overridden because the Quill JDBC optionDecoder passes null into the underlying decoders.
+ * If a custom decoder has no special null handling, it will fail.
+ *
+ * @see https://github.com/getquill/quill/issues/535
+ */
+ implicit override def optionDecoder[T](implicit d: Decoder[T]): Decoder[Option[T]] =
+ decoder(
+ sqlType = d.sqlType,
+ row => index => {
+ try {
+ val res = d(index - 1, row)
+ if (row.wasNull) {
+ None
+ }
+ else {
+ Some(res)
+ }
+ } catch {
+ case _: NullPointerException => None
+ case _: IncorrectIdException => None
+ }
+ }
+ )
+
+ implicit def encodeStringId[T] = MappedEncoding[StringId[T], String](_.id)
+ implicit def decodeStringId[T] = MappedEncoding[String, StringId[T]] {
+ case "" => throw IncorrectIdException("'' is an invalid Id value")
+ case x => StringId(x)
+ }
+
+ def decodeOptStringId[T] = MappedEncoding[Option[String], Option[StringId[T]]] {
+ case None | Some("") => None
+ case Some(x) => Some(StringId(x))
+ }
+
+ implicit def encodeLongId[T] = MappedEncoding[LongId[T], Long](_.id)
+ implicit def decodeLongId[T] = MappedEncoding[Long, LongId[T]] {
+ case 0 => throw IncorrectIdException("0 is an invalid Id value")
+ case x => LongId(x)
+ }
+
+ // TODO Dirty hack, see REP-475
+ def decodeOptLongId[T] = MappedEncoding[Option[Long], Option[LongId[T]]] {
+ case None | Some(0) => None
+ case Some(x) => Some(LongId(x))
+ }
+
+ implicit def encodeUuidId[T] = MappedEncoding[UuidId[T], String](_.toString)
+ implicit def decodeUuidId[T] = MappedEncoding[String, UuidId[T]] {
+ case "" => throw IncorrectIdException("'' is an invalid Id value")
+ case x => UuidId(x)
+ }
+
+ def decodeOptUuidId[T] = MappedEncoding[Option[String], Option[UuidId[T]]] {
+ case None | Some("") => None
+ case Some(x) => Some(UuidId(x))
+ }
+
+ implicit def encodeTextJson[T: Manifest] = MappedEncoding[TextJson[T], String](x => JsonSerializer.serialize(x.content))
+ implicit def decodeTextJson[T: Manifest] = MappedEncoding[String, TextJson[T]] { x =>
+ TextJson(JsonSerializer.deserialize[T](x))
+ }
+
+ implicit val encodeUserRole = MappedEncoding[User.Role, Int](_.bit)
+ implicit val decodeUserRole = MappedEncoding[Int, User.Role] {
+ // 0 is treated as null for numeric types
+ case 0 => throw new NullPointerException("0 means no roles. A user must have a role")
+ case x => User.Role(x)
+ }
+
+ implicit val encodeEmail = MappedEncoding[Email, String](_.value.toString)
+ implicit val decodeEmail = MappedEncoding[String, Email](Email)
+
+ implicit val encodePasswordHash = MappedEncoding[PasswordHash, Array[Byte]](_.value)
+ implicit val decodePasswordHash = MappedEncoding[Array[Byte], PasswordHash](PasswordHash(_))
+
+ implicit val encodeUri = MappedEncoding[URI, String](_.toString)
+ implicit val decodeUri = MappedEncoding[String, URI](URI.create)
+
+ implicit val encodeCaseId = MappedEncoding[CaseId, String](_.id.toString)
+ implicit val decodeCaseId = MappedEncoding[String, CaseId](CaseId(_))
+
+ implicit val encodeFuzzyValue = {
+ MappedEncoding[FuzzyValue, String] {
+ case FuzzyValue.No => "No"
+ case FuzzyValue.Yes => "Yes"
+ case FuzzyValue.Maybe => "Maybe"
+ }
+ }
+ implicit val decodeFuzzyValue = MappedEncoding[String, FuzzyValue] {
+ case "Yes" => FuzzyValue.Yes
+ case "No" => FuzzyValue.No
+ case "Maybe" => FuzzyValue.Maybe
+ case x =>
+ Option(x).fold {
+ throw new NullPointerException("FuzzyValue is null") // See catch in optionDecoder
+ } { _ =>
+ throw new IllegalStateException(s"Unknown fuzzy value: $x")
+ }
+ }
+
+
+ implicit val encodeRecordRequestId = MappedEncoding[RecordRequestId, String](_.id.toString)
+ implicit val decodeRecordRequestId = MappedEncoding[String, RecordRequestId] { x =>
+ RecordRequestId(UUID.fromString(x))
+ }
+
+ final implicit class LocalDateTimeDbOps(val left: LocalDateTime) {
+
+ // scalastyle:off
+ def <=(right: LocalDateTime): Quoted[Boolean] = quote(infix"$left <= $right".as[Boolean])
+ }
+
+}
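A sketch of constructing the context; the credential values and the `db` config path are placeholders, and the connection Config is expected to be whatever Quill's JdbcContextConfig understands.

import com.typesafe.config.ConfigFactory

import xyz.driver.common.db.SqlContext

object SqlContextExample {

  // Placeholder credentials; real values normally come from the application configuration.
  private val credentials = SqlContext.DbCredentials(
    user = "app",
    password = "secret",
    host = "localhost",
    port = 3306,
    dbName = "driver",
    dbCreateFlag = false,
    dbContext = "dev",
    connectionParams = "useSSL=false",
    url = "jdbc:mysql://localhost:3306/driver"
  )

  private val settings = SqlContext.Settings(
    credentials = credentials,
    connection = ConfigFactory.load().getConfig("db"), // hypothetical config path
    connectionAttemptsOnStartup = 5,
    threadPoolSize = 16
  )

  implicit lazy val sqlContext: SqlContext = SqlContext(settings)
}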
diff --git a/src/main/scala/xyz/driver/common/db/Transactions.scala b/src/main/scala/xyz/driver/common/db/Transactions.scala
new file mode 100644
index 0000000..2f5a2cc
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/Transactions.scala
@@ -0,0 +1,23 @@
+package xyz.driver.common.db
+
+import xyz.driver.common.logging.PhiLogging
+
+import scala.concurrent.Future
+import scala.util.{Failure, Success, Try}
+
+class Transactions()(implicit context: SqlContext) extends PhiLogging {
+ def run[T](f: SqlContext => T): Future[T] = {
+ import context.executionContext
+
+ Future(context.transaction(f(context))).andThen {
+ case Failure(e) => logger.error(phi"Can't run a transaction: $e")
+ }
+ }
+
+ def runSync[T](f: SqlContext => T): Unit = {
+ Try(context.transaction(f(context))) match {
+ case Success(_) =>
+ case Failure(e) => logger.error(phi"Can't run a transaction: $e")
+ }
+ }
+}
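A sketch of running several statements in one transaction; insertOne stands in for real repository calls executed against the context.

import scala.concurrent.Future

import xyz.driver.common.db.{SqlContext, Transactions}

object TransactionsExample {

  def saveBoth(insertOne: String => Unit)(implicit sqlContext: SqlContext): Future[Unit] = {
    val transactions = new Transactions()

    // Both inserts run inside a single context.transaction block;
    // on failure the error is logged and the returned future fails.
    transactions.run { _ =>
      insertOne("first")
      insertOne("second")
    }
  }
}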
diff --git a/src/main/scala/xyz/driver/common/db/repositories/BridgeUploadQueueRepository.scala b/src/main/scala/xyz/driver/common/db/repositories/BridgeUploadQueueRepository.scala
new file mode 100644
index 0000000..e0d6ff2
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/repositories/BridgeUploadQueueRepository.scala
@@ -0,0 +1,24 @@
+package xyz.driver.common.db.repositories
+
+import xyz.driver.common.concurrent.BridgeUploadQueue
+import xyz.driver.common.domain.LongId
+
+import scala.concurrent.Future
+
+trait BridgeUploadQueueRepository extends Repository {
+
+ type EntityT = BridgeUploadQueue.Item
+ type IdT = LongId[EntityT]
+
+ def add(draft: EntityT): EntityT
+
+ def getById(id: LongId[EntityT]): Option[EntityT]
+
+ def isCompleted(kind: String, tag: String): Future[Boolean]
+
+ def getOne(kind: String): Future[Option[BridgeUploadQueue.Item]]
+
+ def update(entity: EntityT): EntityT
+
+ def delete(id: IdT): Unit
+}
diff --git a/src/main/scala/xyz/driver/common/db/repositories/Repository.scala b/src/main/scala/xyz/driver/common/db/repositories/Repository.scala
new file mode 100644
index 0000000..ae2a3e6
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/repositories/Repository.scala
@@ -0,0 +1,4 @@
+package xyz.driver.common.db.repositories
+
+// For future use and an eventual migration to Postgres and Slick
+trait Repository extends RepositoryLogging
diff --git a/src/main/scala/xyz/driver/common/db/repositories/RepositoryLogging.scala b/src/main/scala/xyz/driver/common/db/repositories/RepositoryLogging.scala
new file mode 100644
index 0000000..cb2c438
--- /dev/null
+++ b/src/main/scala/xyz/driver/common/db/repositories/RepositoryLogging.scala
@@ -0,0 +1,62 @@
+package xyz.driver.common.db.repositories
+
+import xyz.driver.common.logging._
+
+trait RepositoryLogging extends PhiLogging {
+
+ protected def logCreatedOne[T](x: T)(implicit toPhiString: T => PhiString): T = {
+ logger.info(phi"An entity was created: $x")
+ x
+ }
+
+ protected def logCreatedMultiple[T <: Iterable[_]](xs: T)(implicit toPhiString: T => PhiString): T = {
+ if (xs.nonEmpty) {
+ logger.info(phi"Entities were created: $xs")
+ }
+ xs
+ }
+
+ protected def logUpdatedOne(rowsAffected: Long): Long = {
+ rowsAffected match {
+ case 0 => logger.trace(phi"The entity is up to date")
+ case 1 => logger.info(phi"The entity was updated")
+ case x => logger.warn(phi"${Unsafe(x)} entities were updated")
+ }
+ rowsAffected
+ }
+
+ protected def logUpdatedOneUnimportant(rowsAffected: Long): Long = {
+ rowsAffected match {
+ case 0 => logger.trace(phi"The entity is up to date")
+ case 1 => logger.trace(phi"The entity was updated")
+ case x => logger.warn(phi"${Unsafe(x)} entities were updated")
+ }
+ rowsAffected
+ }
+
+ protected def logUpdatedMultiple(rowsAffected: Long): Long = {
+ rowsAffected match {
+ case 0 => logger.trace(phi"All entities are up to date")
+ case x => logger.info(phi"${Unsafe(x)} entities were updated")
+ }
+ rowsAffected
+ }
+
+ protected def logDeletedOne(rowsAffected: Long): Long = {
+ rowsAffected match {
+ case 0 => logger.trace(phi"The entity does not exist")
+ case 1 => logger.info(phi"The entity was deleted")
+ case x => logger.warn(phi"Deleted ${Unsafe(x)} entities, expected one")
+ }
+ rowsAffected
+ }
+
+ protected def logDeletedMultiple(rowsAffected: Long): Long = {
+ rowsAffected match {
+ case 0 => logger.trace(phi"Entities do not exist")
+ case x => logger.info(phi"Deleted ${Unsafe(x)} entities")
+ }
+ rowsAffected
+ }
+
+}
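A sketch of a repository mixing in these helpers; the Label entity, its PhiString conversion, and the injected database functions are hypothetical, and PhiLogging is assumed to provide its logger on its own, as it does for the SqlContext companion above.

import xyz.driver.common.db.repositories.Repository
import xyz.driver.common.logging._

object RepositoryLoggingExample {

  // Hypothetical entity with the PhiString conversion required by logCreatedOne.
  final case class Label(id: Long, name: String)
  implicit def labelToPhiString(x: Label): PhiString = Unsafe(s"Label(${x.id})")

  class LabelRepository extends Repository {

    // `insertSync` and `updateSync` stand in for real database calls.
    def create(label: Label)(insertSync: Label => Label): Label =
      logCreatedOne(insertSync(label))

    def update(label: Label)(updateSync: Label => Long): Long =
      logUpdatedOne(updateSync(label))
  }
}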