diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main/scala/Generators.scala | 8 | ||||
-rw-r--r-- | src/main/scala/Main.scala | 4 | ||||
-rw-r--r-- | src/main/scala/TypedIdTable.scala | 93 |
3 files changed, 72 insertions, 33 deletions
diff --git a/src/main/scala/Generators.scala b/src/main/scala/Generators.scala index fb75cc1..7138a3c 100644 --- a/src/main/scala/Generators.scala +++ b/src/main/scala/Generators.scala @@ -11,8 +11,8 @@ class RowSourceCodeGenerator( manualForeignKeys: Map[(String, String), (String, String)], typeReplacements: Map[String, String] ) extends TypedIdSourceCodeGenerator( - model, - fullDatabaseModel, + singleSchemaModel = model, + databaseModel = fullDatabaseModel, idType, manualForeignKeys ) @@ -48,8 +48,8 @@ class TableSourceCodeGenerator( override val parentType: Option[String], idType: Option[String], typeReplacements: Map[String, String]) - extends TypedIdSourceCodeGenerator(schemaOnlyModel, - fullDatabaseModel, + extends TypedIdSourceCodeGenerator(singleSchemaModel = schemaOnlyModel, + databaseModel = fullDatabaseModel, idType, manualForeignKeys) with TableOutputHelpers { diff --git a/src/main/scala/Main.scala b/src/main/scala/Main.scala index 55275a3..5bd84d1 100644 --- a/src/main/scala/Main.scala +++ b/src/main/scala/Main.scala @@ -72,11 +72,11 @@ object Generator { Duration.Inf) val rowGenerator = new RowSourceCodeGenerator( - schemaOnlyModel, + model = schemaOnlyModel, headerComment = header, imports = imports, schemaName = schemaName, - dbModel, + fullDatabaseModel = dbModel, idType, manualForeignKeys, typeReplacements diff --git a/src/main/scala/TypedIdTable.scala b/src/main/scala/TypedIdTable.scala index 1a8f986..2e58cd5 100644 --- a/src/main/scala/TypedIdTable.scala +++ b/src/main/scala/TypedIdTable.scala @@ -10,40 +10,79 @@ class TypedIdSourceCodeGenerator( val manualReferences = SchemaParser.references(databaseModel, manualForeignKeys) - def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = { - val referencedColumn: Seq[(m.Table, m.Column)] = - table.foreignKeys - .filter(tableFk => tableFk.referencingColumns.forall(_ == column)) - .filter(columnFk => columnFk.referencedColumns.length == 1) - .flatMap(_.referencedColumns.map(c => - (databaseModel.tablesByName(c.table), c))) - assert(referencedColumn.distinct.length <= 1, referencedColumn) - - referencedColumn.headOption - .orElse(manualReferences.get((table.name.asString, column.name))) - .map((derefColumn _).tupled) - .getOrElse((table, column)) + val modelTypeToColumnMaper = Map( + "java.util.UUID" -> "uuidKeyMapper", + "String" -> "naturalKeyMapper", + "Int" -> "serialKeyMapper" + ) + + val keyReferences: Map[m.Column, m.Column] = { + val pks = databaseModel.tables + .flatMap(_.columns) + .filter(_.options.contains(slick.ast.ColumnOption.PrimaryKey)) + .map(c => (c -> c)) + + val fks: Seq[(m.Column, m.Column)] = databaseModel.tables + .flatMap(_.foreignKeys) + .filter(_.referencedColumns.length == 1) + .filter(_.referencedColumns.forall( + _.options.contains(slick.ast.ColumnOption.PrimaryKey))) + .flatMap(fk => + fk.referencingColumns.flatMap(from => + fk.referencedColumns.headOption.map(to => (from -> to)))) + + (pks ++ fks).toMap + } + + def pKeyTypeTag(columnRef: m.Column): String = { + val schemaName = columnRef.table.schema.getOrElse("`public`") + val tableName = entityName(columnRef.table.table) + s"$schemaName.$tableName" + } + + def pKeyType(columnRef: m.Column): String = { + s"${idType.getOrElse("Id")}[${pKeyTypeTag(columnRef)}]" } class TypedIdTable(model: m.Table) extends Table(model) { table => + override def definitions = + Seq[Def](EntityType, + PlainSqlMapper, + TableClass, + TableValue, + PrimaryKeyMapper) + class TypedIdColumn(override val model: m.Column) extends Column(model) { - column => + override def rawType: String = { + keyReferences.get(model).fold(super.rawType)(pKeyType) + } + } + + type PrimaryKeyMapper = PrimaryKeyMapperDef + + def PrimaryKeyMapper = new PrimaryKeyMapper {} - def tableReferenceName(tableName: m.QualifiedName) = { - val schemaObjectName = tableName.schema.getOrElse("`public`") - val rowTypeName = entityName(tableName.table) - val idTypeName = idType.getOrElse("Id") - s"$idTypeName[$schemaObjectName.$rowTypeName]" + class PrimaryKeyMapperDef extends TermDef { + def primaryKeyColumn: Option[Column] = { + table.model.columns + .flatMap(c => keyReferences.get(c).filter(_ == c)) + .headOption + .map(c => table.columnsByName(c.name)) } - override def rawType: String = { - // write key columns as Id types - val (referencedTable, referencedColumn) = - derefColumn(table.model, column.model) - if (referencedColumn.options.contains( - slick.ast.ColumnOption.PrimaryKey)) - tableReferenceName(referencedTable.name) - else super.rawType + override def enabled = primaryKeyColumn.isDefined + + override def doc = + s"Implicit for mapping primary key of ${tableName(table.model.name.table)} to a base column" + + override def rawName = tableName(table.model.name.table) + "KeyMapper" + + override def code = primaryKeyColumn.fold("") { column => + val tpe = s"BaseColumnType[${column.rawType}]" + s"""|implicit def $name: $tpe = + |${modelTypeToColumnMaper(column.model.tpe)}[${pKeyTypeTag( + column.model)}] + |""".stripMargin.lines.mkString("").trim } } } |