aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/TypedIdTable.scala
blob: 3867d1afbfbfdd10427142a5f2d1d71b4aa35e8b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import slick.codegen.SourceCodeGenerator
import slick.{model => m}

/** Slick source code generator that gives key columns phantom id types.
  *
  * A column that resolves (directly, or by following foreign keys and manual
  * references) to a primary-key column is generated with the Scala type
  * `Id[schema.XRow]` (or `<idType>[schema.XRow]` when `idType` is given).
  * Its definition is preceded by an implicit `BaseColumnType` for that type.
  *
  * NOTE(review): neither the `Table` nor the `Column` factory is overridden in
  * this file, so [[TypedIdTable]] / `TypedIdColumn` take effect only if the
  * caller wires them in — confirm at the call site.
  *
  * @param singleSchemaModel model of the schema whose code is generated
  * @param databaseModel     model used to look up the tables referenced by
  *                          foreign keys and manual references
  * @param idType            name of the generic id type to emit; `Id` if absent
  * @param manualForeignKeys extra references not declared in the database,
  *                          handed to `SchemaParser.references` (presumably
  *                          `(table, column) -> (table, column)`; TODO confirm)
  */
class TypedIdSourceCodeGenerator(
    singleSchemaModel: m.Model,
    databaseModel: m.Model,
    idType: Option[String],
    manualForeignKeys: Map[(String, String), (String, String)]
) extends SourceCodeGenerator(singleSchemaModel) {

  /** Manually declared references, looked up by `(table name, column name)`. */
  val manualReferences =
    SchemaParser.references(databaseModel, manualForeignKeys)

  /** Name of the key-mapper factory emitted into generated code for each
    * supported Scala column type. The factories are not defined here; they
    * must be in scope wherever the generated code is compiled.
    */
  val modelTypeToColumnMaper = Map(
    "java.util.UUID" -> "uuidKeyMapper",
    "String" -> "naturalKeyMapper",
    "Int" -> "serialKeyMapper"
  )

  /** Follows references from `column` of `table` until reaching a column that
    * references nothing further, and returns that column with its table.
    * Returns the input pair unchanged when `column` references nothing.
    *
    * A hop is either a single-column foreign key declared on the current
    * table (resolved through `databaseModel`) or an entry in
    * [[manualReferences]]; foreign keys take precedence.
    *
    * @throws IllegalArgumentException if the references form a cycle
    * @throws AssertionError if a column has several distinct FK targets
    */
  def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = {
    @annotation.tailrec
    def follow(
        current: (m.Table, m.Column),
        visited: Set[(m.QualifiedName, String)]
    ): (m.Table, m.Column) = {
      val (currentTable, currentColumn) = current
      val key = (currentTable.name, currentColumn.name)
      // Without this guard a self-referencing or cyclic reference would
      // recurse until the stack overflows.
      require(
        !visited(key),
        s"Cyclic column reference involving ${currentTable.name.asString}.${currentColumn.name}"
      )
      referencedColumnOf(currentTable, currentColumn) match {
        case Some(next) => follow(next, visited + key)
        case None       => current
      }
    }

    follow((table, column), Set.empty)
  }

  /** The single column directly referenced by `column` of `table`, if any. */
  private def referencedColumnOf(
      table: m.Table,
      column: m.Column
  ): Option[(m.Table, m.Column)] = {
    val viaForeignKeys: Seq[(m.Table, m.Column)] =
      table.foreignKeys.collect {
        // Only single-column keys whose sole referencing column is `column`.
        case fk
            if fk.referencingColumns == Seq(column) &&
              fk.referencedColumns.length == 1 =>
          val target = fk.referencedColumns.head
          (databaseModel.tablesByName(target.table), target)
      }
    assert(
      viaForeignKeys.distinct.length <= 1,
      s"Ambiguous references for ${table.name.asString}.${column.name}: $viaForeignKeys"
    )
    viaForeignKeys.headOption
      .orElse(manualReferences.get((table.name.asString, column.name)))
  }

  /** Table generator hosting [[TypedIdColumn]], the id-typing column generator. */
  class TypedIdTable(model: m.Table) extends Table(model) { table =>
    class TypedIdColumn(override val model: m.Column) extends Column(model) {
      column =>

      /** Scala path of the generated row type for `tableName`:
        * `<schema>.<EntityName>`, with schema "`public`" when absent.
        */
      def rowTypeFor(tableName: m.QualifiedName) = {
        val schemaObjectName = tableName.schema.getOrElse("`public`")
        val rowTypeName = entityName(tableName.table)
        s"$schemaObjectName.$rowTypeName"
      }

      /** Table owning the primary-key column this column resolves to, or
        * `None` when it resolves to a column without the PrimaryKey option.
        */
      private def primaryKeyTable: Option[m.Table] = {
        val (referencedTable, referencedColumn) =
          derefColumn(table.model, column.model)
        if (referencedColumn.options.contains(slick.ast.ColumnOption.PrimaryKey))
          Some(referencedTable)
        else None
      }

      /** Column definition code. For key columns it is preceded by an implicit
        * `BaseColumnType[<rawType>]` built with the mapper factory registered
        * in `modelTypeToColumnMaper` for the column's Scala type.
        */
      override def code: String = primaryKeyTable match {
        case Some(keyTable) =>
          val mapperFactory = modelTypeToColumnMaper.getOrElse(
            model.tpe,
            throw new NoSuchElementException(
              s"No key mapper registered for column type ${model.tpe} " +
                s"(${table.model.name.asString}.${model.name}); supported: " +
                modelTypeToColumnMaper.keys.mkString(", ")
            )
          )
          // `rawName`, not `name`: `name` is backtick-escaped for reserved
          // words, which would make `<name>KeyMapper` an invalid identifier.
          // stripMargin runs before super.code is appended so the column's
          // own code is never rewritten.
          val mapperDef =
            s"""|implicit val ${rawName}KeyMapper: BaseColumnType[$rawType] =
                |  $mapperFactory[${rowTypeFor(keyTable.name)}]
                |
                |""".stripMargin
          mapperDef + super.code
        case None =>
          super.code
      }

      /** Scala type of the column: `<idType>[<row type>]` (`Id` by default)
        * when the column resolves to a primary key, otherwise the default.
        */
      override def rawType: String = primaryKeyTable match {
        case Some(keyTable) =>
          val idTypeName = idType.getOrElse("Id")
          s"$idTypeName[${rowTypeFor(keyTable.name)}]"
        case None =>
          super.rawType
      }
    }
  }
}