diff options
Diffstat (limited to 'src/main/scala/xyz/driver/pdsuicommon/db')
14 files changed, 1128 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/DbCommand.scala b/src/main/scala/xyz/driver/pdsuicommon/db/DbCommand.scala new file mode 100644 index 0000000..911ecee --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/DbCommand.scala @@ -0,0 +1,15 @@ +package xyz.driver.pdsuicommon.db + +import scala.concurrent.Future + +trait DbCommand { + def runSync(): Unit + def runAsync(transactions: Transactions): Future[Unit] +} + +object DbCommand { + val Empty: DbCommand = new DbCommand { + override def runSync(): Unit = {} + override def runAsync(transactions: Transactions): Future[Unit] = Future.successful(()) + } +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/DbCommandFactory.scala b/src/main/scala/xyz/driver/pdsuicommon/db/DbCommandFactory.scala new file mode 100644 index 0000000..f12b437 --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/DbCommandFactory.scala @@ -0,0 +1,14 @@ +package xyz.driver.pdsuicommon.db + +import scala.concurrent.{ExecutionContext, Future} + +trait DbCommandFactory[T] { + def createCommand(orig: T)(implicit ec: ExecutionContext): Future[DbCommand] +} + +object DbCommandFactory { + def empty[T]: DbCommandFactory[T] = new DbCommandFactory[T] { + override def createCommand(orig: T)(implicit ec: ExecutionContext): Future[DbCommand] = Future.successful(DbCommand.Empty) + } +} + diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/EntityExtractorDerivation.scala b/src/main/scala/xyz/driver/pdsuicommon/db/EntityExtractorDerivation.scala new file mode 100644 index 0000000..e13ea39 --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/EntityExtractorDerivation.scala @@ -0,0 +1,68 @@ +package xyz.driver.pdsuicommon.db + +import java.sql.ResultSet + +import io.getquill.NamingStrategy +import io.getquill.dsl.EncodingDsl + +import scala.language.experimental.macros +import scala.reflect.macros.blackbox + +trait EntityExtractorDerivation[Naming <: NamingStrategy] { + this: EncodingDsl => + + /** + * Simple Quill extractor derivation for [[T]] + * Only case classes available. Type parameters is not supported + */ + def entityExtractor[T]: (ResultSet => T) = macro EntityExtractorDerivation.impl[T] +} + +object EntityExtractorDerivation { + def impl[T: c.WeakTypeTag](c: blackbox.Context): c.Tree = { + import c.universe._ + val namingStrategy = c.prefix.actualType + .baseType(c.weakTypeOf[EntityExtractorDerivation[NamingStrategy]].typeSymbol) + .typeArgs + .head + .typeSymbol + .companion + val functionBody = { + val tpe = weakTypeOf[T] + val resultOpt = tpe.decls.collectFirst { + // Find first constructor of T + case cons: MethodSymbol if cons.isConstructor => + // Create param list for constructor + val params = cons.paramLists.flatten.map { param => + val t = param.typeSignature + val paramName = param.name.toString + val col = q"$namingStrategy.column($paramName)" + // Resolve implicit decoders (from SqlContext) and apply ResultSet for each + val d = q"implicitly[${c.prefix}.Decoder[$t]]" + // Minus 1 cause Quill JDBC decoders make plus one. + // ¯\_(ツ)_/¯ + val i = q"row.findColumn($col) - 1" + val decoderName = TermName(paramName + "Decoder") + val valueName = TermName(paramName + "Value") + ( + q"val $decoderName = $d", + q"val $valueName = $decoderName($i, row)", + valueName + ) + } + // Call constructor with param list + q""" + ..${params.map(_._1)} + ..${params.map(_._2)} + new $tpe(..${params.map(_._3)}) + """ + } + resultOpt match { + case Some(result) => result + case None => c.abort(c.enclosingPosition, + s"Can not derive extractor for $tpe. Constructor not found.") + } + } + q"(row: java.sql.ResultSet) => $functionBody" + } +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/EntityNotFoundException.scala b/src/main/scala/xyz/driver/pdsuicommon/db/EntityNotFoundException.scala new file mode 100644 index 0000000..3b3bbdf --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/EntityNotFoundException.scala @@ -0,0 +1,10 @@ +package xyz.driver.pdsuicommon.db + +import xyz.driver.pdsuicommon.domain.Id + +class EntityNotFoundException private(id: String, tableName: String) + extends RuntimeException(s"Entity with id $id is not found in $tableName table") { + + def this(id: Id[_], tableName: String) = this(id.toString, tableName) + def this(id: Long, tableName: String) = this(id.toString, tableName) +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/MysqlQueryBuilder.scala b/src/main/scala/xyz/driver/pdsuicommon/db/MysqlQueryBuilder.scala new file mode 100644 index 0000000..672a34e --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/MysqlQueryBuilder.scala @@ -0,0 +1,90 @@ +package xyz.driver.pdsuicommon.db + +import java.sql.ResultSet + +import io.getquill.{MySQLDialect, MysqlEscape} + +import scala.collection.breakOut +import scala.concurrent.{ExecutionContext, Future} + +object MysqlQueryBuilder { + import xyz.driver.pdsuicommon.db.QueryBuilder._ + + def apply[T](tableName: String, + lastUpdateFieldName: Option[String], + nullableFields: Set[String], + links: Set[TableLink], + runner: Runner[T], + countRunner: CountRunner) + (implicit ec: ExecutionContext): MysqlQueryBuilder[T] = { + val parameters = MysqlQueryBuilderParameters( + tableData = TableData(tableName, lastUpdateFieldName, nullableFields), + links = links.map(x => x.foreignTableName -> x)(breakOut) + ) + new MysqlQueryBuilder[T](parameters)(runner, countRunner, ec) + } + + def apply[T](tableName: String, + lastUpdateFieldName: Option[String], + nullableFields: Set[String], + links: Set[TableLink], + extractor: (ResultSet) => T) + (implicit sqlContext: SqlContext): MysqlQueryBuilder[T] = { + + val runner = (parameters: QueryBuilderParameters) => { + Future { + val (sql, binder) = parameters.toSql(namingStrategy = MysqlEscape) + sqlContext.executeQuery[T](sql, binder, { resultSet => + extractor(resultSet) + }) + }(sqlContext.executionContext) + } + + val countRunner = (parameters: QueryBuilderParameters) => { + Future { + val (sql, binder) = parameters.toSql(countQuery = true, namingStrategy = MysqlEscape) + sqlContext.executeQuery[CountResult](sql, binder, { resultSet => + val count = resultSet.getInt(1) + val lastUpdate = if (parameters.tableData.lastUpdateFieldName.isDefined) { + Option(sqlContext.localDateTimeDecoder.decoder(2, resultSet)) + } else None + + (count, lastUpdate) + }).head + }(sqlContext.executionContext) + } + + apply[T]( + tableName = tableName, + lastUpdateFieldName = lastUpdateFieldName, + nullableFields = nullableFields, + links = links, + runner = runner, + countRunner = countRunner + )(sqlContext.executionContext) + } +} + +class MysqlQueryBuilder[T](parameters: MysqlQueryBuilderParameters) + (implicit runner: QueryBuilder.Runner[T], + countRunner: QueryBuilder.CountRunner, + ec: ExecutionContext) + extends QueryBuilder[T, MySQLDialect, MysqlEscape](parameters) { + + def withFilter(newFilter: SearchFilterExpr): QueryBuilder[T, MySQLDialect, MysqlEscape] = { + new MysqlQueryBuilder[T](parameters.copy(filter = newFilter)) + } + + def withSorting(newSorting: Sorting): QueryBuilder[T, MySQLDialect, MysqlEscape] = { + new MysqlQueryBuilder[T](parameters.copy(sorting = newSorting)) + } + + def withPagination(newPagination: Pagination): QueryBuilder[T, MySQLDialect, MysqlEscape] = { + new MysqlQueryBuilder[T](parameters.copy(pagination = Some(newPagination))) + } + + def resetPagination: QueryBuilder[T, MySQLDialect, MysqlEscape] = { + new MysqlQueryBuilder[T](parameters.copy(pagination = None)) + } + +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/Pagination.scala b/src/main/scala/xyz/driver/pdsuicommon/db/Pagination.scala new file mode 100644 index 0000000..e72b5c2 --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/Pagination.scala @@ -0,0 +1,19 @@ +package xyz.driver.pdsuicommon.db + +import xyz.driver.pdsuicommon.logging._ + +/** + * @param pageNumber Starts with 1 + */ +case class Pagination(pageSize: Int, pageNumber: Int) + +object Pagination { + + // @see https://driverinc.atlassian.net/wiki/display/RA/REST+API+Specification#RESTAPISpecification-CommonRequestQueryParametersForWebServices + val Default = Pagination(pageSize = 100, pageNumber = 1) + + implicit def toPhiString(x: Pagination): PhiString = { + import x._ + phi"Pagination(pageSize=${Unsafe(pageSize)}, pageNumber=${Unsafe(pageNumber)})" + } +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/QueryBuilder.scala b/src/main/scala/xyz/driver/pdsuicommon/db/QueryBuilder.scala new file mode 100644 index 0000000..9b798d8 --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/QueryBuilder.scala @@ -0,0 +1,344 @@ +package xyz.driver.pdsuicommon.db + +import java.sql.PreparedStatement +import java.time.LocalDateTime + +import io.getquill.NamingStrategy +import io.getquill.context.sql.idiom.SqlIdiom +import xyz.driver.pdsuicommon.db.Sorting.{Dimension, Sequential} +import xyz.driver.pdsuicommon.db.SortingOrder.{Ascending, Descending} + +import scala.collection.mutable.ListBuffer +import scala.concurrent.{ExecutionContext, Future} + +object QueryBuilder { + + type Runner[T] = (QueryBuilderParameters) => Future[Seq[T]] + + type CountResult = (Int, Option[LocalDateTime]) + + type CountRunner = (QueryBuilderParameters) => Future[CountResult] + + /** + * Binder for PreparedStatement + */ + type Binder = PreparedStatement => PreparedStatement + + case class TableData(tableName: String, + lastUpdateFieldName: Option[String] = None, + nullableFields: Set[String] = Set.empty) + + val AllFields = Set("*") + +} + +case class TableLink(keyColumnName: String, + foreignTableName: String, + foreignKeyColumnName: String) + +object QueryBuilderParameters { + val AllFields = Set("*") +} + +sealed trait QueryBuilderParameters { + + def tableData: QueryBuilder.TableData + def links: Map[String, TableLink] + def filter: SearchFilterExpr + def sorting: Sorting + def pagination: Option[Pagination] + + def findLink(tableName: String): TableLink = links.get(tableName) match { + case None => throw new IllegalArgumentException(s"Cannot find a link for `$tableName`") + case Some(link) => link + } + + def toSql(countQuery: Boolean = false, namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = { + toSql(countQuery, QueryBuilderParameters.AllFields, namingStrategy) + } + + def toSql(countQuery: Boolean, + fields: Set[String], + namingStrategy: NamingStrategy): (String, QueryBuilder.Binder) = { + val escapedTableName = namingStrategy.table(tableData.tableName) + val fieldsSql: String = if (countQuery) { + "count(*)" + (tableData.lastUpdateFieldName match { + case Some(lastUpdateField) => s", max($escapedTableName.${namingStrategy.column(lastUpdateField)})" + case None => "" + }) + } else { + if (fields == QueryBuilderParameters.AllFields) { + s"$escapedTableName.*" + } else { + fields + .map { field => + s"$escapedTableName.${namingStrategy.column(field)}" + } + .mkString(", ") + } + } + val (where, bindings) = filterToSql(escapedTableName, filter, namingStrategy) + val orderBy = sortingToSql(escapedTableName, sorting, namingStrategy) + + val limitSql = limitToSql() + + val sql = new StringBuilder() + sql.append("select ") + sql.append(fieldsSql) + sql.append("\nfrom ") + sql.append(escapedTableName) + + val filtersTableLinks: Seq[TableLink] = { + import SearchFilterExpr._ + def aux(expr: SearchFilterExpr): Seq[TableLink] = expr match { + case Atom.TableName(tableName) => List(findLink(tableName)) + case Intersection(xs) => xs.flatMap(aux) + case Union(xs) => xs.flatMap(aux) + case _ => Nil + } + aux(filter) + } + + val sortingTableLinks: Seq[TableLink] = Sorting.collect(sorting) { + case Dimension(Some(foreignTableName), _, _) => findLink(foreignTableName) + } + + // Combine links from sorting and filter without duplicates + val foreignTableLinks = (filtersTableLinks ++ sortingTableLinks).distinct + + foreignTableLinks.foreach { + case TableLink(keyColumnName, foreignTableName, foreignKeyColumnName) => + val escapedForeignTableName = namingStrategy.table(foreignTableName) + + sql.append("\ninner join ") + sql.append(escapedForeignTableName) + sql.append(" on ") + + sql.append(escapedTableName) + sql.append('.') + sql.append(namingStrategy.column(keyColumnName)) + + sql.append(" = ") + + sql.append(escapedForeignTableName) + sql.append('.') + sql.append(namingStrategy.column(foreignKeyColumnName)) + } + + if (where.nonEmpty) { + sql.append("\nwhere ") + sql.append(where) + } + + if (orderBy.nonEmpty && !countQuery) { + sql.append("\norder by ") + sql.append(orderBy) + } + + if (limitSql.nonEmpty && !countQuery) { + sql.append("\n") + sql.append(limitSql) + } + + (sql.toString, binder(bindings)) + } + + /** + * Converts filter expression to SQL expression. + * + * @return Returns SQL string and list of values for binding in prepared statement. + */ + protected def filterToSql(escapedTableName: String, + filter: SearchFilterExpr, + namingStrategy: NamingStrategy): (String, List[AnyRef]) = { + import SearchFilterBinaryOperation._ + import SearchFilterExpr._ + + def isNull(string: AnyRef) = Option(string).isEmpty || string.toString.toLowerCase == "null" + + def placeholder(field: String) = "?" + + def escapeDimension(dimension: SearchFilterExpr.Dimension) = { + val tableName = dimension.tableName.fold(escapedTableName)(namingStrategy.table) + s"$tableName.${namingStrategy.column(dimension.name)}" + } + + def filterToSqlMultiple(operands: Seq[SearchFilterExpr]) = operands.collect { + case x if !SearchFilterExpr.isEmpty(x) => filterToSql(escapedTableName, x, namingStrategy) + } + + filter match { + case x if isEmpty(x) => + ("", List.empty) + + case AllowAll => + ("1", List.empty) + + case DenyAll => + ("0", List.empty) + + case Atom.Binary(dimension, Eq, value) if isNull(value) => + (s"${escapeDimension(dimension)} is NULL", List.empty) + + case Atom.Binary(dimension, NotEq, value) if isNull(value) => + (s"${escapeDimension(dimension)} is not NULL", List.empty) + + case Atom.Binary(dimension, NotEq, value) if tableData.nullableFields.contains(dimension.name) => + // In MySQL NULL <> Any === NULL + // So, to handle NotEq for nullable fields we need to use more complex SQL expression. + // http://dev.mysql.com/doc/refman/5.7/en/working-with-null.html + val escapedColumn = escapeDimension(dimension) + val sql = s"($escapedColumn is null or $escapedColumn != ${placeholder(dimension.name)})" + (sql, List(value)) + + case Atom.Binary(dimension, op, value) => + val operator = op match { + case Eq => "=" + case NotEq => "!=" + case Like => "like" + case Gt => ">" + case GtEq => ">=" + case Lt => "<" + case LtEq => "<=" + } + (s"${escapeDimension(dimension)} $operator ${placeholder(dimension.name)}", List(value)) + + case Atom.NAry(dimension, op, values) => + val sqlOp = op match { + case SearchFilterNAryOperation.In => "in" + case SearchFilterNAryOperation.NotIn => "not in" + } + + val bindings = ListBuffer[AnyRef]() + val sqlPlaceholder = placeholder(dimension.name) + val formattedValues = values.map { value => + bindings += value + sqlPlaceholder + }.mkString(", ") + (s"${escapeDimension(dimension)} $sqlOp ($formattedValues)", bindings.toList) + + case Intersection(operands) => + val (sql, bindings) = filterToSqlMultiple(operands).unzip + (sql.mkString("(", " and ", ")"), bindings.flatten.toList) + + case Union(operands) => + val (sql, bindings) = filterToSqlMultiple(operands).unzip + (sql.mkString("(", " or ", ")"), bindings.flatten.toList) + } + } + + protected def limitToSql(): String + + /** + * @param escapedMainTableName Should be escaped + */ + protected def sortingToSql(escapedMainTableName: String, + sorting: Sorting, + namingStrategy: NamingStrategy): String = { + sorting match { + case Dimension(optSortingTableName, field, order) => + val sortingTableName = optSortingTableName.map(namingStrategy.table).getOrElse(escapedMainTableName) + val fullName = s"$sortingTableName.${namingStrategy.column(field)}" + + s"$fullName ${orderToSql(order)}" + + case Sequential(xs) => + xs.map(sortingToSql(escapedMainTableName, _, namingStrategy)).mkString(", ") + } + } + + protected def orderToSql(x: SortingOrder): String = x match { + case Ascending => "asc" + case Descending => "desc" + } + + protected def binder(bindings: List[AnyRef]) + (bind: PreparedStatement): PreparedStatement = { + bindings.zipWithIndex.foreach { case (binding, index) => + bind.setObject(index + 1, binding) + } + + bind + } + +} + +case class PostgresQueryBuilderParameters(tableData: QueryBuilder.TableData, + links: Map[String, TableLink] = Map.empty, + filter: SearchFilterExpr = SearchFilterExpr.Empty, + sorting: Sorting = Sorting.Empty, + pagination: Option[Pagination] = None) extends QueryBuilderParameters { + + def limitToSql(): String = { + pagination.map { pagination => + val startFrom = (pagination.pageNumber - 1) * pagination.pageSize + s"limit ${pagination.pageSize} OFFSET $startFrom" + } getOrElse "" + } + +} + +/** + * @param links Links to another tables grouped by foreignTableName + */ +case class MysqlQueryBuilderParameters(tableData: QueryBuilder.TableData, + links: Map[String, TableLink] = Map.empty, + filter: SearchFilterExpr = SearchFilterExpr.Empty, + sorting: Sorting = Sorting.Empty, + pagination: Option[Pagination] = None) extends QueryBuilderParameters { + + def limitToSql(): String = pagination.map { pagination => + val startFrom = (pagination.pageNumber - 1) * pagination.pageSize + s"limit $startFrom, ${pagination.pageSize}" + }.getOrElse("") + +} + +abstract class QueryBuilder[T, D <: SqlIdiom, N <: NamingStrategy](val parameters: QueryBuilderParameters) + (implicit runner: QueryBuilder.Runner[T], + countRunner: QueryBuilder.CountRunner, + ec: ExecutionContext) { + + def run: Future[Seq[T]] = runner(parameters) + + def runCount: Future[QueryBuilder.CountResult] = countRunner(parameters) + + /** + * Runs the query and returns total found rows without considering of pagination. + */ + def runWithCount: Future[(Seq[T], Int, Option[LocalDateTime])] = { + val countFuture = runCount + val selectAllFuture = run + for { + (total, lastUpdate) <- countFuture + all <- selectAllFuture + } yield (all, total, lastUpdate) + } + + def withFilter(newFilter: SearchFilterExpr): QueryBuilder[T, D, N] + + def withFilter(filter: Option[SearchFilterExpr]): QueryBuilder[T, D, N] = { + filter.fold(this)(withFilter) + } + + def resetFilter: QueryBuilder[T, D, N] = withFilter(SearchFilterExpr.Empty) + + + def withSorting(newSorting: Sorting): QueryBuilder[T, D, N] + + def withSorting(sorting: Option[Sorting]): QueryBuilder[T, D, N] = { + sorting.fold(this)(withSorting) + } + + def resetSorting: QueryBuilder[T, D, N] = withSorting(Sorting.Empty) + + + def withPagination(newPagination: Pagination): QueryBuilder[T, D, N] + + def withPagination(pagination: Option[Pagination]): QueryBuilder[T, D, N] = { + pagination.fold(this)(withPagination) + } + + def resetPagination: QueryBuilder[T, D, N] + +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/SearchFilterExpr.scala b/src/main/scala/xyz/driver/pdsuicommon/db/SearchFilterExpr.scala new file mode 100644 index 0000000..60db303 --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/SearchFilterExpr.scala @@ -0,0 +1,210 @@ +package xyz.driver.pdsuicommon.db + +import xyz.driver.pdsuicommon.logging._ + +sealed trait SearchFilterExpr { + def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] + def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr +} + +object SearchFilterExpr { + + val Empty = Intersection.Empty + val Forbid = Atom.Binary( + dimension = Dimension(None, "true"), + op = SearchFilterBinaryOperation.Eq, + value = "false" + ) + + case class Dimension(tableName: Option[String], name: String) { + def isForeign: Boolean = tableName.isDefined + } + + sealed trait Atom extends SearchFilterExpr { + override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = { + if (p(this)) Some(this) + else None + } + + override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = { + if (f.isDefinedAt(this)) f(this) + else this + } + } + + object Atom { + case class Binary(dimension: Dimension, op: SearchFilterBinaryOperation, value: AnyRef) extends Atom + object Binary { + def apply(field: String, op: SearchFilterBinaryOperation, value: AnyRef): Binary = + Binary(Dimension(None, field), op, value) + } + + case class NAry(dimension: Dimension, op: SearchFilterNAryOperation, values: Seq[AnyRef]) extends Atom + object NAry { + def apply(field: String, op: SearchFilterNAryOperation, values: Seq[AnyRef]): NAry = + NAry(Dimension(None, field), op, values) + } + + /** dimension.tableName extractor */ + object TableName { + def unapply(value: Atom): Option[String] = value match { + case Binary(Dimension(tableNameOpt, _), _, _) => tableNameOpt + case NAry(Dimension(tableNameOpt, _), _, _) => tableNameOpt + } + } + } + + case class Intersection private(operands: Seq[SearchFilterExpr]) + extends SearchFilterExpr with SearchFilterExprSeqOps { + + override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = { + if (f.isDefinedAt(this)) f(this) + else { + this.copy(operands.map(_.replace(f))) + } + } + + } + + object Intersection { + + val Empty = Intersection(Seq()) + + def create(operands: SearchFilterExpr*): SearchFilterExpr = { + val filtered = operands.filterNot(SearchFilterExpr.isEmpty) + filtered.size match { + case 0 => Empty + case 1 => filtered.head + case _ => Intersection(filtered) + } + } + } + + + case class Union private(operands: Seq[SearchFilterExpr]) extends SearchFilterExpr with SearchFilterExprSeqOps { + + override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = { + if (f.isDefinedAt(this)) f(this) + else { + this.copy(operands.map(_.replace(f))) + } + } + + } + + object Union { + + val Empty = Union(Seq()) + + def create(operands: SearchFilterExpr*): SearchFilterExpr = { + val filtered = operands.filterNot(SearchFilterExpr.isEmpty) + filtered.size match { + case 0 => Empty + case 1 => filtered.head + case _ => Union(filtered) + } + } + + def create(dimension: Dimension, values: String*): SearchFilterExpr = values.size match { + case 0 => SearchFilterExpr.Empty + case 1 => SearchFilterExpr.Atom.Binary(dimension, SearchFilterBinaryOperation.Eq, values.head) + case _ => + val filters = values.map { value => + SearchFilterExpr.Atom.Binary(dimension, SearchFilterBinaryOperation.Eq, value) + } + + create(filters: _*) + } + + def create(dimension: Dimension, values: Set[String]): SearchFilterExpr = + create(dimension, values.toSeq: _*) + + // Backwards compatible API + + /** Create SearchFilterExpr with empty tableName */ + def create(field: String, values: String*): SearchFilterExpr = + create(Dimension(None, field), values:_*) + + /** Create SearchFilterExpr with empty tableName */ + def create(field: String, values: Set[String]): SearchFilterExpr = + create(Dimension(None, field), values) + } + + + case object AllowAll extends SearchFilterExpr { + override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = { + if (p(this)) Some(this) + else None + } + + override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = { + if (f.isDefinedAt(this)) f(this) + else this + } + } + + case object DenyAll extends SearchFilterExpr { + override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = { + if (p(this)) Some(this) + else None + } + + override def replace(f: PartialFunction[SearchFilterExpr, SearchFilterExpr]): SearchFilterExpr = { + if (f.isDefinedAt(this)) f(this) + else this + } + } + + def isEmpty(expr: SearchFilterExpr): Boolean = { + expr == Intersection.Empty || expr == Union.Empty + } + + sealed trait SearchFilterExprSeqOps { + this: SearchFilterExpr => + + val operands: Seq[SearchFilterExpr] + + override def find(p: SearchFilterExpr => Boolean): Option[SearchFilterExpr] = { + if (p(this)) Some(this) + else { + // Search the first expr among operands, which satisfy p + // Is's ok to use foldLeft. If there will be performance issues, replace it by recursive loop + operands.foldLeft(Option.empty[SearchFilterExpr]) { + case (None, expr) => expr.find(p) + case (x, _) => x + } + } + } + + } + + // There is no case, when this is unsafe. At this time. + implicit def toPhiString(x: SearchFilterExpr): PhiString = { + if (isEmpty(x)) Unsafe("SearchFilterExpr.Empty") + else Unsafe(x.toString) + } + +} + +sealed trait SearchFilterBinaryOperation + +object SearchFilterBinaryOperation { + + case object Eq extends SearchFilterBinaryOperation + case object NotEq extends SearchFilterBinaryOperation + case object Like extends SearchFilterBinaryOperation + case object Gt extends SearchFilterBinaryOperation + case object GtEq extends SearchFilterBinaryOperation + case object Lt extends SearchFilterBinaryOperation + case object LtEq extends SearchFilterBinaryOperation + +} + +sealed trait SearchFilterNAryOperation + +object SearchFilterNAryOperation { + + case object In extends SearchFilterNAryOperation + case object NotIn extends SearchFilterNAryOperation + +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/Sorting.scala b/src/main/scala/xyz/driver/pdsuicommon/db/Sorting.scala new file mode 100644 index 0000000..4b2427c --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/Sorting.scala @@ -0,0 +1,62 @@ +package xyz.driver.pdsuicommon.db + +import xyz.driver.pdsuicommon.logging._ + +import scala.collection.generic.CanBuildFrom + +sealed trait SortingOrder +object SortingOrder { + + case object Ascending extends SortingOrder + case object Descending extends SortingOrder + +} + +sealed trait Sorting + +object Sorting { + + val Empty = Sequential(Seq.empty) + + /** + * @param tableName None if the table is default (same) + * @param name Dimension name + * @param order Order + */ + case class Dimension(tableName: Option[String], name: String, order: SortingOrder) extends Sorting { + def isForeign: Boolean = tableName.isDefined + } + + case class Sequential(sorting: Seq[Dimension]) extends Sorting { + override def toString: String = if (isEmpty(this)) "Empty" else super.toString + } + + def isEmpty(input: Sorting): Boolean = { + input match { + case Sequential(Seq()) => true + case _ => false + } + } + + def filter(sorting: Sorting, p: Dimension => Boolean): Seq[Dimension] = sorting match { + case x: Dimension if p(x) => Seq(x) + case x: Dimension => Seq.empty + case Sequential(xs) => xs.filter(p) + } + + def collect[B, That](sorting: Sorting) + (f: PartialFunction[Dimension, B]) + (implicit bf: CanBuildFrom[Seq[Dimension], B, That]): That = sorting match { + case x: Dimension if f.isDefinedAt(x) => + val r = bf.apply() + r += f(x) + r.result() + + case x: Dimension => bf.apply().result() + case Sequential(xs) => xs.collect(f) + } + + // Contains dimensions and ordering only, thus it is safe. + implicit def toPhiString(x: Sorting): PhiString = Unsafe(x.toString) + +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/SqlContext.scala b/src/main/scala/xyz/driver/pdsuicommon/db/SqlContext.scala new file mode 100644 index 0000000..5aa0084 --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/SqlContext.scala @@ -0,0 +1,183 @@ +package xyz.driver.pdsuicommon.db + +import java.io.Closeable +import java.net.URI +import java.time._ +import java.util.UUID +import java.util.concurrent.Executors +import javax.sql.DataSource + +import xyz.driver.pdsuicommon.logging.{PhiLogging, Unsafe} +import xyz.driver.pdsuicommon.concurrent.MdcExecutionContext +import xyz.driver.pdsuicommon.db.SqlContext.Settings +import xyz.driver.pdsuicommon.domain._ +import xyz.driver.pdsuicommon.error.IncorrectIdException +import xyz.driver.pdsuicommon.utils.JsonSerializer +import com.typesafe.config.Config +import io.getquill._ +import xyz.driver.pdsuidomain.entities.{CaseId, RecordRequestId} + +import scala.concurrent.ExecutionContext +import scala.util.control.NonFatal +import scala.util.{Failure, Success, Try} + +object SqlContext extends PhiLogging { + + case class DbCredentials(user: String, + password: String, + host: String, + port: Int, + dbName: String, + dbCreateFlag: Boolean, + dbContext: String, + connectionParams: String, + url: String) + + case class Settings(credentials: DbCredentials, + connection: Config, + connectionAttemptsOnStartup: Int, + threadPoolSize: Int) + + def apply(settings: Settings): SqlContext = { + // Prevent leaking credentials to a log + Try(JdbcContextConfig(settings.connection).dataSource) match { + case Success(dataSource) => new SqlContext(dataSource, settings) + case Failure(NonFatal(e)) => + logger.error(phi"Can not load dataSource, error: ${Unsafe(e.getClass.getName)}") + throw new IllegalArgumentException("Can not load dataSource from config. Check your database and config") + } + } + +} + +class SqlContext(dataSource: DataSource with Closeable, settings: Settings) + extends MysqlJdbcContext[MysqlEscape](dataSource) + with EntityExtractorDerivation[Literal] { + + private val tpe = Executors.newFixedThreadPool(settings.threadPoolSize) + + implicit val executionContext: ExecutionContext = { + val orig = ExecutionContext.fromExecutor(tpe) + MdcExecutionContext.from(orig) + } + + override def close(): Unit = { + super.close() + tpe.shutdownNow() + } + + // ///////// Encodes/Decoders /////////// + + /** + * Overrode, because Quill JDBC optionDecoder pass null inside decoders. + * If custom decoder don't have special null handler, it will failed. + * + * @see https://github.com/getquill/quill/issues/535 + */ + implicit override def optionDecoder[T](implicit d: Decoder[T]): Decoder[Option[T]] = + decoder( + sqlType = d.sqlType, + row => index => { + try { + val res = d(index - 1, row) + if (row.wasNull) { + None + } + else { + Some(res) + } + } catch { + case _: NullPointerException => None + case _: IncorrectIdException => None + } + } + ) + + implicit def encodeStringId[T] = MappedEncoding[StringId[T], String](_.id) + implicit def decodeStringId[T] = MappedEncoding[String, StringId[T]] { + case "" => throw IncorrectIdException("'' is an invalid Id value") + case x => StringId(x) + } + + def decodeOptStringId[T] = MappedEncoding[Option[String], Option[StringId[T]]] { + case None | Some("") => None + case Some(x) => Some(StringId(x)) + } + + implicit def encodeLongId[T] = MappedEncoding[LongId[T], Long](_.id) + implicit def decodeLongId[T] = MappedEncoding[Long, LongId[T]] { + case 0 => throw IncorrectIdException("0 is an invalid Id value") + case x => LongId(x) + } + + // TODO Dirty hack, see REP-475 + def decodeOptLongId[T] = MappedEncoding[Option[Long], Option[LongId[T]]] { + case None | Some(0) => None + case Some(x) => Some(LongId(x)) + } + + implicit def encodeUuidId[T] = MappedEncoding[UuidId[T], String](_.toString) + implicit def decodeUuidId[T] = MappedEncoding[String, UuidId[T]] { + case "" => throw IncorrectIdException("'' is an invalid Id value") + case x => UuidId(x) + } + + def decodeOptUuidId[T] = MappedEncoding[Option[String], Option[UuidId[T]]] { + case None | Some("") => None + case Some(x) => Some(UuidId(x)) + } + + implicit def encodeTextJson[T: Manifest] = MappedEncoding[TextJson[T], String](x => JsonSerializer.serialize(x.content)) + implicit def decodeTextJson[T: Manifest] = MappedEncoding[String, TextJson[T]] { x => + TextJson(JsonSerializer.deserialize[T](x)) + } + + implicit val encodeUserRole = MappedEncoding[User.Role, Int](_.bit) + implicit val decodeUserRole = MappedEncoding[Int, User.Role] { + // 0 is treated as null for numeric types + case 0 => throw new NullPointerException("0 means no roles. A user must have a role") + case x => User.Role(x) + } + + implicit val encodeEmail = MappedEncoding[Email, String](_.value.toString) + implicit val decodeEmail = MappedEncoding[String, Email](Email) + + implicit val encodePasswordHash = MappedEncoding[PasswordHash, Array[Byte]](_.value) + implicit val decodePasswordHash = MappedEncoding[Array[Byte], PasswordHash](PasswordHash(_)) + + implicit val encodeUri = MappedEncoding[URI, String](_.toString) + implicit val decodeUri = MappedEncoding[String, URI](URI.create) + + implicit val encodeCaseId = MappedEncoding[CaseId, String](_.id.toString) + implicit val decodeCaseId = MappedEncoding[String, CaseId](CaseId(_)) + + implicit val encodeFuzzyValue = { + MappedEncoding[FuzzyValue, String] { + case FuzzyValue.No => "No" + case FuzzyValue.Yes => "Yes" + case FuzzyValue.Maybe => "Maybe" + } + } + implicit val decodeFuzzyValue = MappedEncoding[String, FuzzyValue] { + case "Yes" => FuzzyValue.Yes + case "No" => FuzzyValue.No + case "Maybe" => FuzzyValue.Maybe + case x => + Option(x).fold { + throw new NullPointerException("FuzzyValue is null") // See catch in optionDecoder + } { _ => + throw new IllegalStateException(s"Unknown fuzzy value: $x") + } + } + + implicit val encodeRecordRequestId = MappedEncoding[RecordRequestId, String](_.id.toString) + implicit val decodeRecordRequestId = MappedEncoding[String, RecordRequestId] { x => + RecordRequestId(UUID.fromString(x)) + } + + final implicit class LocalDateTimeDbOps(val left: LocalDateTime) { + + // scalastyle:off + def <=(right: LocalDateTime): Quoted[Boolean] = quote(infix"$left <= $right".as[Boolean]) + } +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/Transactions.scala b/src/main/scala/xyz/driver/pdsuicommon/db/Transactions.scala new file mode 100644 index 0000000..72c358a --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/Transactions.scala @@ -0,0 +1,23 @@ +package xyz.driver.pdsuicommon.db + +import xyz.driver.pdsuicommon.logging.PhiLogging + +import scala.concurrent.Future +import scala.util.{Failure, Success, Try} + +class Transactions()(implicit context: SqlContext) extends PhiLogging { + def run[T](f: SqlContext => T): Future[T] = { + import context.executionContext + + Future(context.transaction(f(context))).andThen { + case Failure(e) => logger.error(phi"Can't run a transaction: $e") + } + } + + def runSync[T](f: SqlContext => T): Unit = { + Try(context.transaction(f(context))) match { + case Success(_) => + case Failure(e) => logger.error(phi"Can't run a transaction: $e") + } + } +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/repositories/BridgeUploadQueueRepository.scala b/src/main/scala/xyz/driver/pdsuicommon/db/repositories/BridgeUploadQueueRepository.scala new file mode 100644 index 0000000..edef2eb --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/repositories/BridgeUploadQueueRepository.scala @@ -0,0 +1,24 @@ +package xyz.driver.pdsuicommon.db.repositories + +import xyz.driver.pdsuicommon.concurrent.BridgeUploadQueue +import xyz.driver.pdsuicommon.domain.LongId + +import scala.concurrent.Future + +trait BridgeUploadQueueRepository extends Repository { + + type EntityT = BridgeUploadQueue.Item + type IdT = LongId[EntityT] + + def add(draft: EntityT): EntityT + + def getById(id: LongId[EntityT]): Option[EntityT] + + def isCompleted(kind: String, tag: String): Future[Boolean] + + def getOne(kind: String): Future[Option[BridgeUploadQueue.Item]] + + def update(entity: EntityT): EntityT + + def delete(id: IdT): Unit +} diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/repositories/Repository.scala b/src/main/scala/xyz/driver/pdsuicommon/db/repositories/Repository.scala new file mode 100644 index 0000000..d671e80 --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/repositories/Repository.scala @@ -0,0 +1,4 @@ +package xyz.driver.pdsuicommon.db.repositories + +// For further usage and migration to Postgres and slick +trait Repository extends RepositoryLogging diff --git a/src/main/scala/xyz/driver/pdsuicommon/db/repositories/RepositoryLogging.scala b/src/main/scala/xyz/driver/pdsuicommon/db/repositories/RepositoryLogging.scala new file mode 100644 index 0000000..d1ec1da --- /dev/null +++ b/src/main/scala/xyz/driver/pdsuicommon/db/repositories/RepositoryLogging.scala @@ -0,0 +1,62 @@ +package xyz.driver.pdsuicommon.db.repositories + +import xyz.driver.pdsuicommon.logging._ + +trait RepositoryLogging extends PhiLogging { + + protected def logCreatedOne[T](x: T)(implicit toPhiString: T => PhiString): T = { + logger.info(phi"An entity was created: $x") + x + } + + protected def logCreatedMultiple[T <: Iterable[_]](xs: T)(implicit toPhiString: T => PhiString): T = { + if (xs.nonEmpty) { + logger.info(phi"Entities were created: $xs") + } + xs + } + + protected def logUpdatedOne(rowsAffected: Long): Long = { + rowsAffected match { + case 0 => logger.trace(phi"The entity is up to date") + case 1 => logger.info(phi"The entity was updated") + case x => logger.warn(phi"The ${Unsafe(x)} entities were updated") + } + rowsAffected + } + + protected def logUpdatedOneUnimportant(rowsAffected: Long): Long = { + rowsAffected match { + case 0 => logger.trace(phi"The entity is up to date") + case 1 => logger.trace(phi"The entity was updated") + case x => logger.warn(phi"The ${Unsafe(x)} entities were updated") + } + rowsAffected + } + + protected def logUpdatedMultiple(rowsAffected: Long): Long = { + rowsAffected match { + case 0 => logger.trace(phi"All entities are up to date") + case x => logger.info(phi"The ${Unsafe(x)} entities were updated") + } + rowsAffected + } + + protected def logDeletedOne(rowsAffected: Long): Long = { + rowsAffected match { + case 0 => logger.trace(phi"The entity does not exist") + case 1 => logger.info(phi"The entity was deleted") + case x => logger.warn(phi"Deleted ${Unsafe(x)} entities, expected one") + } + rowsAffected + } + + protected def logDeletedMultiple(rowsAffected: Long): Long = { + rowsAffected match { + case 0 => logger.trace(phi"Entities do not exist") + case x => logger.info(phi"Deleted ${Unsafe(x)} entities") + } + rowsAffected + } + +} |