aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/TypedIdTable.scala48
1 files changed, 48 insertions, 0 deletions
diff --git a/src/main/scala/TypedIdTable.scala b/src/main/scala/TypedIdTable.scala
new file mode 100644
index 0000000..b5c6056
--- /dev/null
+++ b/src/main/scala/TypedIdTable.scala
@@ -0,0 +1,48 @@
+import slick.codegen.SourceCodeGenerator
+import slick.{model => m}
+
+class TypedIdSourceCodeGenerator(
+ model: m.Model,
+ idType: Option[String],
+ manualForeignKeys: Map[(String, String), (String, String)]
+) extends SourceCodeGenerator(model) {
+ val manualReferences = SchemaParser.references(model, manualForeignKeys)
+
+ def derefColumn(table: m.Table,
+ column: m.Column): (m.Table, m.Column) = {
+ val referencedColumn: Seq[(m.Table, m.Column)] =
+ table.foreignKeys
+ .filter(tableFk => tableFk.referencingColumns.forall(_ == column))
+ .filter(columnFk => columnFk.referencedColumns.length == 1)
+ .flatMap(_.referencedColumns.map(c =>
+ (model.tablesByName(c.table), c)))
+ assert(referencedColumn.distinct.length <= 1, referencedColumn)
+
+ referencedColumn.headOption
+ .orElse(manualReferences.get((table.name.asString, column.name)))
+ .map((derefColumn _).tupled)
+ .getOrElse((table, column))
+ }
+
+ override def Table = new Table(_){ table =>
+ override def Column = new Column(_) { column =>
+
+ def tableReferenceName(tableName: m.QualifiedName) = {
+ val schemaObjectName = tableName.schema.getOrElse("`public`")
+ val rowTypeName = entityName(tableName.table)
+ val idTypeName = idType.getOrElse("Id")
+ s"$idTypeName[$schemaObjectName.$rowTypeName]"
+ }
+
+ override def rawType: String = {
+ // write key columns as Id types
+ val (referencedTable, referencedColumn) =
+ derefColumn(table.model, column.model)
+ if (referencedColumn.options.contains(
+ slick.ast.ColumnOption.PrimaryKey))
+ tableReferenceName(referencedTable.name)
+ else super.rawType
+ }
+ }
+ }
+}