aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStewart Stewart <stewinsalot@gmail.com>2017-02-24 20:36:19 -0800
committerGitHub <noreply@github.com>2017-02-24 20:36:19 -0800
commit77d13f93c0a67e6190cd184c42f6c2d3fe7eb476 (patch)
treee6479d022ec48dd3edf899dfd87db8b950b30ed0
parentae6eb86aa9396e43104520697b3aa5e42f79d6e2 (diff)
parent36339f501c5dfb59deb70d1bcfa0c6270894cf38 (diff)
downloadslick-codegen-plugin-77d13f93c0a67e6190cd184c42f6c2d3fe7eb476.tar.gz
slick-codegen-plugin-77d13f93c0a67e6190cd184c42f6c2d3fe7eb476.tar.bz2
slick-codegen-plugin-77d13f93c0a67e6190cd184c42f6c2d3fe7eb476.zip
Merge pull request #25 from drivergroup/tables-and-row-generators
Separate tables and row generators
-rw-r--r--src/main/scala/Generators.scala137
-rw-r--r--src/main/scala/Main.scala108
-rw-r--r--src/main/scala/NamespacedCodegen.scala262
-rw-r--r--src/main/scala/OutputHelpers.scala72
-rw-r--r--src/main/scala/SchemaParser.scala60
-rw-r--r--src/main/scala/TypedIdTable.scala50
6 files changed, 415 insertions, 274 deletions
diff --git a/src/main/scala/Generators.scala b/src/main/scala/Generators.scala
new file mode 100644
index 0000000..fb75cc1
--- /dev/null
+++ b/src/main/scala/Generators.scala
@@ -0,0 +1,137 @@
+import slick.codegen.SourceCodeGenerator
+import slick.{model => m}
+
+class RowSourceCodeGenerator(
+ model: m.Model,
+ override val headerComment: String,
+ override val imports: String,
+ override val schemaName: String,
+ fullDatabaseModel: m.Model,
+ idType: Option[String],
+ manualForeignKeys: Map[(String, String), (String, String)],
+ typeReplacements: Map[String, String]
+) extends TypedIdSourceCodeGenerator(
+ model,
+ fullDatabaseModel,
+ idType,
+ manualForeignKeys
+ )
+ with RowOutputHelpers {
+
+ override def Table = new TypedIdTable(_) { table =>
+
+ override def Column = new TypedIdColumn(_) {
+ override def rawType: String = {
+ typeReplacements.getOrElse(model.tpe, super.rawType)
+ }
+ }
+
+ override def EntityType = new EntityType {
+ override def code: String =
+ (if (classEnabled) "final " else "") + super.code
+ }
+
+ override def code = Seq[Def](EntityType).map(_.docWithCode)
+ }
+
+ override def code = tables.map(_.code.mkString("\n")).mkString("\n\n")
+}
+
+class TableSourceCodeGenerator(
+ schemaOnlyModel: m.Model,
+ override val headerComment: String,
+ override val imports: String,
+ override val schemaName: String,
+ fullDatabaseModel: m.Model,
+ pkg: String,
+ manualForeignKeys: Map[(String, String), (String, String)],
+ override val parentType: Option[String],
+ idType: Option[String],
+ typeReplacements: Map[String, String])
+ extends TypedIdSourceCodeGenerator(schemaOnlyModel,
+ fullDatabaseModel,
+ idType,
+ manualForeignKeys)
+ with TableOutputHelpers {
+
+ val defaultIdImplementation =
+ """|final case class Id[T](v: Int)
+ |trait DefaultIdTypeMapper {
+ | val profile: slick.driver.JdbcProfile
+ | import profile.api._
+ | implicit def idTypeMapper[A]: BaseColumnType[Id[A]] = MappedColumnType.base[Id[A], Int](_.v, Id(_))
+ |}
+ |""".stripMargin
+
+ override def code = super.code.lines.drop(1).mkString("\n")
+ // Drops needless import: `"import slick.model.ForeignKeyAction\n"`.
+ // Alias to ForeignKeyAction is in profile.api
+ // TODO: fix upstream
+
+ override def Table = new this.TypedIdTable(_) { table =>
+ override def TableClass = new TableClass() {
+ // We disable the option mapping, as it is a bit more complex to support and we don't appear to need it
+ override def optionEnabled = false
+ }
+
+ // use hlists all the time
+ override def hlistEnabled: Boolean = true
+
+ // if false rows are type aliases to hlists, if true rows are case classes
+ override def mappingEnabled: Boolean = true
+
+ // create case class from colums
+ override def factory: String =
+ if (!hlistEnabled) super.factory
+ else {
+ val args = columns.zipWithIndex.map("a" + _._2)
+ val hlist = args.mkString("::") + ":: HNil"
+ val hlistType = columns
+ .map(_.actualType)
+ .mkString("::") + ":: HNil.type"
+ s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})"
+ }
+
+ // from case class create columns
+ override def extractor: String =
+ if (!hlistEnabled) super.extractor
+ else
+ s"(a : ${TableClass.elementType}) => Some(" + columns
+ .map("a." + _.name)
+ .mkString("::") + ":: HNil)"
+
+ override def EntityType = new EntityType {
+ override def enabled = false
+ }
+
+ override def Column = new TypedIdColumn(_) {
+ override def rawType: String = {
+ typeReplacements.getOrElse(model.tpe, super.rawType)
+ }
+ }
+
+ override def ForeignKey = new ForeignKey(_) {
+ override def code = {
+ val fkColumns = compoundValue(referencingColumns.map(_.name))
+ val qualifier =
+ if (referencedTable.model.name.schema == referencingTable.model.name.schema)
+ ""
+ else
+ referencedTable.model.name.schema.fold("")(sname =>
+ s"$pkg.$sname.")
+
+ val qualifiedName = qualifier + referencedTable.TableValue.name
+ val pkColumns = compoundValue(referencedColumns.map(c =>
+ s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?"
+ else ""}"))
+ val fkName = referencingColumns
+ .map(_.name)
+ .flatMap(_.split("_"))
+ .map(_.capitalize)
+ .mkString
+ .uncapitalize + "Fk"
+ s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=$onUpdate, onDelete=$onDelete)"""
+ }
+ }
+ }
+}
diff --git a/src/main/scala/Main.scala b/src/main/scala/Main.scala
new file mode 100644
index 0000000..55275a3
--- /dev/null
+++ b/src/main/scala/Main.scala
@@ -0,0 +1,108 @@
+import java.net.URI
+import java.nio.file.Paths
+
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+import scala.concurrent.ExecutionContext.Implicits.global
+import slick.backend.DatabaseConfig
+import slick.codegen.SourceCodeGenerator
+import slick.driver.JdbcProfile
+
+trait TableFileGenerator { self: SourceCodeGenerator =>
+ def writeTablesToFile(profile: String,
+ folder: String,
+ pkg: String,
+ fileName: String): Unit
+}
+
+trait RowFileGenerator { self: SourceCodeGenerator =>
+ def writeRowsToFile(folder: String, pkg: String, fileName: String): Unit
+}
+
+object Generator {
+
+ private def outputSchemaCode(schemaName: String,
+ profile: String,
+ folder: String,
+ pkg: String,
+ tableGen: TableFileGenerator,
+ rowGen: RowFileGenerator): Unit = {
+ val camelSchemaName = schemaName.split('_').map(_.capitalize).mkString("")
+
+ tableGen.writeTablesToFile(profile: String,
+ folder: String,
+ pkg: String,
+ fileName = s"${camelSchemaName}Tables.scala")
+ rowGen.writeRowsToFile(folder: String,
+ pkg: String,
+ fileName = s"${camelSchemaName}Rows.scala")
+ }
+
+ def run(uri: URI,
+ pkg: String,
+ schemaNames: Option[List[String]],
+ outputPath: String,
+ manualForeignKeys: Map[(String, String), (String, String)],
+ parentType: Option[String],
+ idType: Option[String],
+ header: String,
+ schemaImports: List[String],
+ typeReplacements: Map[String, String]) = {
+ val dc: DatabaseConfig[JdbcProfile] =
+ DatabaseConfig.forURI[JdbcProfile](uri)
+ val parsedSchemasOpt: Option[Map[String, List[String]]] =
+ schemaNames.map(SchemaParser.parse)
+ val imports = schemaImports.map("import " + _).mkString("\n")
+
+ try {
+ val dbModel: slick.model.Model = Await.result(
+ dc.db.run(SchemaParser.createModel(dc.driver, parsedSchemasOpt)),
+ Duration.Inf)
+
+ parsedSchemasOpt.getOrElse(Map.empty).foreach {
+ case (schemaName, tables) =>
+ val profile =
+ s"""slick.backend.DatabaseConfig.forConfig[slick.driver.JdbcProfile]("${uri
+ .getFragment()}").driver"""
+
+ val schemaOnlyModel = Await.result(
+ dc.db.run(
+ SchemaParser.createModel(dc.driver,
+ Some(Map(schemaName -> tables)))),
+ Duration.Inf)
+
+ val rowGenerator = new RowSourceCodeGenerator(
+ schemaOnlyModel,
+ headerComment = header,
+ imports = imports,
+ schemaName = schemaName,
+ dbModel,
+ idType,
+ manualForeignKeys,
+ typeReplacements
+ )
+
+ val tableGenerator =
+ new TableSourceCodeGenerator(schemaOnlyModel = schemaOnlyModel,
+ headerComment = header,
+ imports = imports,
+ schemaName = schemaName,
+ fullDatabaseModel = dbModel,
+ pkg = pkg,
+ manualForeignKeys,
+ parentType = parentType,
+ idType,
+ typeReplacements)
+
+ outputSchemaCode(schemaName = schemaName,
+ profile = profile,
+ folder = outputPath,
+ pkg = pkg,
+ tableGen = tableGenerator,
+ rowGen = rowGenerator)
+ }
+ } finally {
+ dc.db.close()
+ }
+ }
+}
diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala
deleted file mode 100644
index 85e21bf..0000000
--- a/src/main/scala/NamespacedCodegen.scala
+++ /dev/null
@@ -1,262 +0,0 @@
-import java.net.URI
-import java.nio.file.Paths
-
-import scala.concurrent.Await
-import scala.concurrent.duration.Duration
-import scala.concurrent.ExecutionContext.Implicits.global
-import slick.backend.DatabaseConfig
-import slick.codegen.{
- SourceCodeGenerator,
- StringGeneratorHelpers
-}
-import slick.dbio.DBIO
-import slick.driver.JdbcProfile
-import slick.jdbc.meta.MTable
-import slick.{model => sModel}
-import slick.model.{Column, Model, Table, QualifiedName}
-
-object Generator {
-
- def run(uri: URI,
- pkg: String,
- schemaNames: Option[List[String]],
- outputPath: String,
- manualForeignKeys: Map[(String, String), (String, String)],
- parentType: Option[String],
- idType: Option[String],
- header: String,
- schemaImports: List[String],
- typeReplacements: Map[String, String]) = {
- val dc: DatabaseConfig[JdbcProfile] =
- DatabaseConfig.forURI[JdbcProfile](uri)
- val parsedSchemasOpt: Option[Map[String, List[String]]] =
- schemaNames.map(SchemaParser.parse)
-
- try {
- val dbModel: Model = Await.result(
- dc.db.run(SchemaParser.createModel(dc.driver, parsedSchemasOpt)),
- Duration.Inf)
-
- parsedSchemasOpt.getOrElse(Map.empty).foreach {
- case (schemaName, tables) =>
- val profile =
- s"""slick.backend.DatabaseConfig.forConfig[slick.driver.JdbcProfile]("${uri
- .getFragment()}").driver"""
-
- val schemaOnlyModel = Await.result(
- dc.db.run(
- SchemaParser.createModel(dc.driver,
- Some(Map(schemaName -> tables)))),
- Duration.Inf)
-
- val generator = new Generator(pkg,
- dbModel,
- schemaOnlyModel,
- manualForeignKeys,
- parentType,
- idType,
- header,
- schemaImports,
- typeReplacements)
- generator.writeToFile(profile = profile,
- folder = outputPath,
- pkg = pkg,
- container = schemaName,
- fileName = s"${schemaName}.scala")
- }
- } finally {
- dc.db.close()
- }
- }
-
-}
-
-class Generator(pkg: String,
- fullDatabaseModel: Model,
- schemaOnlyModel: Model,
- manualForeignKeys: Map[(String, String), (String, String)],
- override val parentType: Option[String],
- idType: Option[String],
- override val headerComment: String,
- schemaImports: List[String],
- typeReplacements: Map[String, String])
- extends SourceCodeGenerator(schemaOnlyModel)
- with OutputHelpers {
-
- override val imports = schemaImports.map("import " + _).mkString("\n")
-
- val defaultIdImplementation =
- """|final case class Id[T](v: Int)
- |trait DefaultIdTypeMapper {
- | val profile: slick.driver.JdbcProfile
- | import profile.api._
- | implicit def idTypeMapper[A]: BaseColumnType[Id[A]] = MappedColumnType.base[Id[A], Int](_.v, Id(_))
- |}
- |""".stripMargin
-
- override def code = super.code.lines.drop(1).mkString("\n")
- // Drops needless import: `"import slick.model.ForeignKeyAction\n"`.
- // Alias to ForeignKeyAction is in profile.api
- // TODO: fix upstream
-
- override def Table = new Table(_) { table =>
-
- override def TableClass = new TableClass() {
- // We disable the option mapping, as it is a bit more complex to support and we don't appear to need it
- override def optionEnabled = false
- }
-
- // use hlists all the time
- override def hlistEnabled: Boolean = true
-
- // if false rows are type aliases to hlists, if true rows are case classes
- override def mappingEnabled: Boolean = true
-
- // create case class from colums
- override def factory: String =
- if (!hlistEnabled) super.factory
- else {
- val args = columns.zipWithIndex.map("a" + _._2)
- val hlist = args.mkString("::") + ":: HNil"
- val hlistType = columns
- .map(_.actualType)
- .mkString("::") + ":: HNil.type"
- s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})"
- }
-
- // from case class create columns
- override def extractor: String =
- if (!hlistEnabled) super.extractor
- else
- s"(a : ${TableClass.elementType}) => Some(" + columns
- .map("a." + _.name)
- .mkString("::") + ":: HNil)"
-
- override def EntityType = new EntityTypeDef {
- override def code: String =
- // Wartremover wants `final`
- // But can't have the final case class inside the trait
- // TODO: Fix by putting case classes in package or object
- // TODO: Upstream default should be false.
- (if (classEnabled) "sealed " else "") + super.code
- }
-
- override def Column = new Column(_) { column =>
- // use fullDatabasemodel model here for cross-schema foreign keys
- val manualReferences =
- SchemaParser.references(fullDatabaseModel, manualForeignKeys)
-
- // work out the destination of the foreign key
- def derefColumn(table: sModel.Table,
- column: sModel.Column): (sModel.Table, sModel.Column) = {
- val referencedColumn: Seq[(sModel.Table, sModel.Column)] =
- table.foreignKeys
- .filter(tableFk => tableFk.referencingColumns.forall(_ == column))
- .filter(columnFk => columnFk.referencedColumns.length == 1)
- .flatMap(_.referencedColumns.map(c =>
- (fullDatabaseModel.tablesByName(c.table), c)))
- assert(referencedColumn.distinct.length <= 1, referencedColumn)
-
- referencedColumn.headOption
- .orElse(manualReferences.get((table.name.asString, column.name)))
- .map((derefColumn _).tupled)
- .getOrElse((table, column))
- }
-
- def tableReferenceName(tableName: QualifiedName) = {
- val schemaObjectName = tableName.schema.getOrElse("`public`")
- val rowTypeName = entityName(tableName.table)
- val idTypeName = idType.getOrElse("Id")
- s"$idTypeName[$schemaObjectName.$rowTypeName]"
- }
-
- // re-write ids other custom types
- override def rawType: String = {
- val (referencedTable, referencedColumn) =
- derefColumn(table.model, column.model)
- if (referencedColumn.options.contains(
- slick.ast.ColumnOption.PrimaryKey))
- tableReferenceName(referencedTable.name)
- else typeReplacements.getOrElse(model.tpe, model.tpe)
- }
- }
-
- override def ForeignKey = new ForeignKey(_) {
- override def code = {
- val fkColumns = compoundValue(referencingColumns.map(_.name))
- val qualifier =
- if (referencedTable.model.name.schema == referencingTable.model.name.schema)
- ""
- else
- referencedTable.model.name.schema.fold("")(sname =>
- s"$pkg.$sname.")
-
- val qualifiedName = qualifier + referencedTable.TableValue.name
- val pkColumns = compoundValue(referencedColumns.map(c =>
- s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?"
- else ""}"))
- val fkName = referencingColumns
- .map(_.name)
- .flatMap(_.split("_"))
- .map(_.capitalize)
- .mkString
- .uncapitalize + "Fk"
- s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=$onUpdate, onDelete=$onDelete)"""
- }
- }
- }
-}
-
-object SchemaParser {
- def references(dbModel: Model,
- tcMappings: Map[(String, String), (String, String)])
- : Map[(String, String), (Table, Column)] = {
- def getTableColumn(tc: (String, String)): (Table, Column) = {
- val (tableName, columnName) = tc
- val table = dbModel.tables
- .find(_.name.asString == tableName)
- .getOrElse(throw new RuntimeException("No table " + tableName))
- val column = table.columns
- .find(_.name == columnName)
- .getOrElse(throw new RuntimeException(
- "No column " + columnName + " in table " + tableName))
- (table, column)
- }
-
- tcMappings.map {
- case (from, to) => ({ getTableColumn(from); from }, getTableColumn(to))
- }
- }
-
- def parse(schemaTableNames: List[String]): Map[String, List[String]] =
- schemaTableNames
- .map(_.split('.'))
- .groupBy(_.head)
- .mapValues(_.flatMap(_.tail))
-
- def createModel(
- jdbcProfile: JdbcProfile,
- mappedSchemasOpt: Option[Map[String, List[String]]]): DBIO[Model] = {
- import slick.jdbc.meta.MQName
-
- val filteredTables = mappedSchemasOpt.map { mappedSchemas =>
- MTable.getTables.map { (tables: Vector[MTable]) =>
- mappedSchemas.flatMap {
- case (schemaName, tableNames) =>
- tableNames.map(
- tableName =>
- tables
- .find(table =>
- table.name match {
- case MQName(_, Some(`schemaName`), `tableName`) => true
- case _ => false
- })
- .getOrElse(throw new IllegalArgumentException(
- s"$schemaName.$tableName does not exist in the connected database.")))
- }.toList
- }
- }
-
- jdbcProfile.createModel(filteredTables)
- }
-}
diff --git a/src/main/scala/OutputHelpers.scala b/src/main/scala/OutputHelpers.scala
index ce22f2a..97fb4e9 100644
--- a/src/main/scala/OutputHelpers.scala
+++ b/src/main/scala/OutputHelpers.scala
@@ -1,28 +1,76 @@
-trait OutputHelpers extends slick.codegen.OutputHelpers {
+import slick.codegen.{SourceCodeGenerator, OutputHelpers}
- def imports: String
+trait TableOutputHelpers extends TableFileGenerator with OutputHelpers {
+ self: SourceCodeGenerator =>
- def headerComment: String = ""
+ def headerComment: String
+ def schemaName: String
+ def imports: String
- override def packageCode(profile: String,
- pkg: String,
- container: String,
- parentType: Option[String]): String = {
- val traitName = container.capitalize + "SchemaDef"
+ def packageTableCode(headerComment: String,
+ pkg: String,
+ schemaName: String,
+ imports: String,
+ profile: String): String =
s"""|${headerComment.trim().lines.map("// " + _).mkString("\n")}
|package $pkg
+ |package $schemaName
|
|$imports
|
|/** Stand-alone Slick data model for immediate use */
- |object $container extends {
+ |// TODO: change this to `object tables`
+ |object `package` extends {
| val profile = $profile
- |} with $traitName
+ |} with Tables
|
|/** Slick data model trait for extension, choice of backend or usage in the cake pattern. (Make sure to initialize this late.) */
- |trait $traitName${parentType.fold("")(" extends " + _)} {
+ |trait Tables${parentType.fold("")(" extends " + _)} {
| import profile.api._
| ${indent(code)}
- |}""".stripMargin.trim()
+ |}
+ |""".stripMargin.trim()
+
+ def writeTablesToFile(profile: String,
+ folder: String,
+ pkg: String,
+ fileName: String): Unit = {
+ writeStringToFile(
+ content =
+ packageTableCode(headerComment, pkg, schemaName, imports, profile),
+ folder = folder,
+ pkg = s"$pkg.$schemaName",
+ fileName = fileName)
+ }
+}
+
+trait RowOutputHelpers extends RowFileGenerator with OutputHelpers {
+ self: SourceCodeGenerator =>
+
+ def headerComment: String
+ def schemaName: String
+ def imports: String
+
+ def packageRowCode(headerComment: String,
+ schemaName: String,
+ pkg: String,
+ imports: String): String =
+ s"""|${headerComment.trim().lines.map("// " + _).mkString("\n")}
+ |/** Definitions for table rows types of database schema $schemaName */
+ |package $pkg
+ |package $schemaName
+ |
+ |$imports
+ |
+ |$code
+ |""".stripMargin.trim()
+
+ def writeRowsToFile(folder: String, pkg: String, fileName: String): Unit = {
+
+ writeStringToFile(
+ content = packageRowCode(headerComment, schemaName, pkg, imports),
+ folder = folder,
+ pkg = s"$pkg.$schemaName",
+ fileName = fileName)
}
}
diff --git a/src/main/scala/SchemaParser.scala b/src/main/scala/SchemaParser.scala
new file mode 100644
index 0000000..1186f11
--- /dev/null
+++ b/src/main/scala/SchemaParser.scala
@@ -0,0 +1,60 @@
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import slick.dbio.DBIO
+import slick.driver.JdbcProfile
+import slick.jdbc.meta.MTable
+import slick.{model => m}
+
+object SchemaParser {
+ def references(dbModel: m.Model,
+ tcMappings: Map[(String, String), (String, String)])
+ : Map[(String, String), (m.Table, m.Column)] = {
+ def getTableColumn(tc: (String, String)): (m.Table, m.Column) = {
+ val (tableName, columnName) = tc
+ val table = dbModel.tables
+ .find(_.name.asString == tableName)
+ .getOrElse(throw new RuntimeException("No table " + tableName))
+ val column = table.columns
+ .find(_.name == columnName)
+ .getOrElse(throw new RuntimeException(
+ "No column " + columnName + " in table " + tableName))
+ (table, column)
+ }
+
+ tcMappings.map {
+ case (from, to) => ({ getTableColumn(from); from }, getTableColumn(to))
+ }
+ }
+
+ def parse(schemaTableNames: List[String]): Map[String, List[String]] =
+ schemaTableNames
+ .map(_.split('.'))
+ .groupBy(_.head)
+ .mapValues(_.flatMap(_.tail))
+
+ def createModel(
+ jdbcProfile: JdbcProfile,
+ mappedSchemasOpt: Option[Map[String, List[String]]]): DBIO[m.Model] = {
+ import slick.jdbc.meta.MQName
+
+ val filteredTables = mappedSchemasOpt.map { mappedSchemas =>
+ MTable.getTables.map { (tables: Vector[MTable]) =>
+ mappedSchemas.flatMap {
+ case (schemaName, tableNames) =>
+ tableNames.map(
+ tableName =>
+ tables
+ .find(table =>
+ table.name match {
+ case MQName(_, Some(`schemaName`), `tableName`) => true
+ case _ => false
+ })
+ .getOrElse(throw new IllegalArgumentException(
+ s"$schemaName.$tableName does not exist in the connected database.")))
+ }.toList
+ }
+ }
+
+ jdbcProfile.createModel(filteredTables)
+ }
+}
diff --git a/src/main/scala/TypedIdTable.scala b/src/main/scala/TypedIdTable.scala
new file mode 100644
index 0000000..1a8f986
--- /dev/null
+++ b/src/main/scala/TypedIdTable.scala
@@ -0,0 +1,50 @@
+import slick.codegen.SourceCodeGenerator
+import slick.{model => m}
+
+class TypedIdSourceCodeGenerator(
+ singleSchemaModel: m.Model,
+ databaseModel: m.Model,
+ idType: Option[String],
+ manualForeignKeys: Map[(String, String), (String, String)]
+) extends SourceCodeGenerator(singleSchemaModel) {
+ val manualReferences =
+ SchemaParser.references(databaseModel, manualForeignKeys)
+
+ def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = {
+ val referencedColumn: Seq[(m.Table, m.Column)] =
+ table.foreignKeys
+ .filter(tableFk => tableFk.referencingColumns.forall(_ == column))
+ .filter(columnFk => columnFk.referencedColumns.length == 1)
+ .flatMap(_.referencedColumns.map(c =>
+ (databaseModel.tablesByName(c.table), c)))
+ assert(referencedColumn.distinct.length <= 1, referencedColumn)
+
+ referencedColumn.headOption
+ .orElse(manualReferences.get((table.name.asString, column.name)))
+ .map((derefColumn _).tupled)
+ .getOrElse((table, column))
+ }
+
+ class TypedIdTable(model: m.Table) extends Table(model) { table =>
+ class TypedIdColumn(override val model: m.Column) extends Column(model) {
+ column =>
+
+ def tableReferenceName(tableName: m.QualifiedName) = {
+ val schemaObjectName = tableName.schema.getOrElse("`public`")
+ val rowTypeName = entityName(tableName.table)
+ val idTypeName = idType.getOrElse("Id")
+ s"$idTypeName[$schemaObjectName.$rowTypeName]"
+ }
+
+ override def rawType: String = {
+ // write key columns as Id types
+ val (referencedTable, referencedColumn) =
+ derefColumn(table.model, column.model)
+ if (referencedColumn.options.contains(
+ slick.ast.ColumnOption.PrimaryKey))
+ tableReferenceName(referencedTable.name)
+ else super.rawType
+ }
+ }
+ }
+}