aboutsummaryrefslogblamecommitdiff
path: root/src/main/scala/TypedIdTable.scala
blob: 64f90a2df970e7552458fa2515e50cddbadd9c35 (plain) (tree)
1
2
3
4
5
6
7
8
9
10



                                        



                                                              

                                                  
                                                                    
 
                                   
                                        


                                           

   
                                                

                                  


                                                                    

                                                             
                                              

                                                                                                                    



                      

                                                                 
                                                      
                             
   
 

                                                           

   
                                                                    
                              
                                                                                    
 
                                                                             
                                      
                                                              

       
 

                                               
                                                  

                                               

                                              
                                                            

                                                
       
 

                                                       

                                                                                                    
 
                                                                            

                                                               
                                                      

                                                                                      
                                                    

       

   
import slick.codegen.SourceCodeGenerator
import slick.{model => m}

/** Slick source code generator that renders key columns with a tagged id type
  * (for example `Id[schema.Entity]`) instead of their raw database type.
  *
  * @param singleSchemaModel model handed to the underlying [[slick.codegen.SourceCodeGenerator]];
  *                          code is generated for the tables of this model
  * @param databaseModel     model whose tables are scanned for primary keys and foreign keys
  *                          when building [[keyReferences]]
  * @param idType            name of the generic id type to wrap keys in; `Id` is used when empty
  * @param manualForeignKeys extra key relations forwarded to `ModelTransformation.references`
  *                          (presumably (table, column) -> (table, column) pairs -- TODO confirm)
  */
class TypedIdSourceCodeGenerator(
    singleSchemaModel: m.Model,
    databaseModel: m.Model,
    idType: Option[String],
    manualForeignKeys: Map[(String, String), (String, String)]
) extends SourceCodeGenerator(singleSchemaModel) {

  // NOTE(review): this value is not read anywhere in this class; confirm whether it was
  // meant to be merged into `keyReferences` or is only consumed by external callers.
  val manualReferences =
    ModelTransformation.references(databaseModel, manualForeignKeys)

  /** Maps the Scala type name of a primary key column (as reported by the Slick model)
    * to the name of the mapper function referenced by the generated implicit.
    */
  val modelTypeToColumnMaper: Map[String, String] = Map(
    "java.util.UUID" -> "uuidKeyMapper",
    "String"         -> "naturalKeyMapper",
    "Int"            -> "serialKeyMapper",
    "Long"           -> "serialKeyMapper"
  )

  /** For every key column, the primary key column that determines its id type.
    *
    * Primary key columns map to themselves. A foreign key column maps to the column it
    * references, provided the foreign key references exactly one column and that column
    * is a primary key. Foreign key entries are added last, so a column that is both a
    * primary key and such a foreign key ends up mapped to the referenced column.
    */
  val keyReferences: Map[m.Column, m.Column] = {
    def isPrimaryKey(column: m.Column): Boolean =
      column.options.contains(slick.ast.ColumnOption.PrimaryKey)

    val pks: Seq[(m.Column, m.Column)] =
      for {
        table  <- databaseModel.tables
        column <- table.columns
        if isPrimaryKey(column)
      } yield column -> column

    val fks: Seq[(m.Column, m.Column)] =
      for {
        table <- databaseModel.tables
        fk    <- table.foreignKeys
        if fk.referencedColumns.length == 1 // composite references are ignored
        to    <- fk.referencedColumns
        if isPrimaryKey(to)
        from  <- fk.referencingColumns
      } yield from -> to

    (pks ++ fks).toMap
  }

  /** Type tag for the table owning `columnRef`: `<schema>.<EntityName>`.
    * The back-quoted name `public` is used when the table has no schema.
    */
  def pKeyTypeTag(columnRef: m.Column): String = {
    val schemaName = columnRef.table.schema.getOrElse("`public`")
    val tableName  = entityName(columnRef.table.table)
    s"$schemaName.$tableName"
  }

  /** Full id type for `columnRef`, e.g. `Id[schema.Entity]`; the wrapper name comes from `idType`. */
  def pKeyType(columnRef: m.Column): String =
    s"${idType.getOrElse("Id")}[${pKeyTypeTag(columnRef)}]"

  /** Table generator that additionally emits an implicit column mapper for the primary key. */
  class TypedIdTable(model: m.Table) extends Table(model) { table =>
    override def definitions: Seq[Def] =
      Seq[Def](EntityType, PlainSqlMapper, TableClass, TableValue, PrimaryKeyMapper)

    /** Column generator that substitutes the typed id for every column found in `keyReferences`. */
    class TypedIdColumn(override val model: m.Column) extends Column(model) {
      override def rawType: String =
        keyReferences.get(model).fold(super.rawType)(pKeyType)
    }

    type PrimaryKeyMapper = PrimaryKeyMapperDef

    def PrimaryKeyMapper = new PrimaryKeyMapper {}

    /** Definition of the `implicit def <Table>KeyMapper` that maps the table's typed
      * primary key to a base column type.
      */
    class PrimaryKeyMapperDef extends TermDef {

      /** First column of this table that `keyReferences` maps to itself, if any. */
      def primaryKeyColumn: Option[Column] =
        table.model.columns
          .find(c => keyReferences.get(c).exists(_ == c))
          .map(c => table.columnsByName(c.name))

      // Nothing is generated for tables without such a column.
      override def enabled: Boolean = primaryKeyColumn.isDefined

      override def doc: String =
        s"Implicit for mapping primary key of ${tableName(table.model.name.table)} to a base column"

      override def rawName: String = tableName(table.model.name.table) + "KeyMapper"

      /** Source of the implicit mapper, or the empty string when there is no primary key column.
        *
        * @throws java.util.NoSuchElementException if the primary key's type has no entry in
        *                                          `modelTypeToColumnMaper`
        */
      override def code: String = primaryKeyColumn.fold("") { column =>
        val keyType = column.model.tpe
        val mapper = modelTypeToColumnMaper.getOrElse(
          keyType,
          throw new NoSuchElementException(
            s"No key mapper is registered for primary key type '$keyType' of table " +
              s"'${table.model.name.table}'; supported types: " +
              modelTypeToColumnMaper.keys.mkString(", ")
          )
        )
        val tpe = s"BaseColumnType[${column.rawType}]"
        // `linesIterator` instead of `lines`: on JDK 11+ `String#lines` resolves to the Java
        // method returning a `java.util.stream.Stream`, which has no `mkString`.
        s"""|implicit def $name: $tpe =
            |$mapper[${pKeyTypeTag(column.model)}]
            |""".stripMargin.linesIterator.mkString("").trim
      }
    }
  }
}