aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStewart Stewart <stewinsalot@gmail.com>2016-09-13 17:59:06 -0700
committerStewart Stewart <stewinsalot@gmail.com>2016-09-14 20:43:02 -0700
commit01e2d258fc02e7c8706ffae6f42e353eed3228f1 (patch)
tree86886d747a3dd1ead17c40cbeb2dbfacb0f0f3ab
parentcbe217066f286d20624d045bf1690350d3848a11 (diff)
downloadslick-codegen-plugin-01e2d258fc02e7c8706ffae6f42e353eed3228f1.tar.gz
slick-codegen-plugin-01e2d258fc02e7c8706ffae6f42e353eed3228f1.tar.bz2
slick-codegen-plugin-01e2d258fc02e7c8706ffae6f42e353eed3228f1.zip
minor readability refactors
-rw-r--r--src/main/scala/NamespacedCodegen.scala72
1 files changed, 33 insertions, 39 deletions
diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala
index e9177a6..245a4ef 100644
--- a/src/main/scala/NamespacedCodegen.scala
+++ b/src/main/scala/NamespacedCodegen.scala
@@ -22,21 +22,17 @@ import slick.util.ConfigExtensionMethods.configExtensionMethods
// generator places the relevant generated classes into separate
// objects--a "a" object, and a "b" object)
object NamespacedCodegen {
- def parseSchemaList(schemaTableNames: List[String]): Map[String, List[String]] = {
- val (tables, schemas) = schemaTableNames.partition(_.contains("."))
- val mappedSchemas = schemas.map(_ -> List()).toMap
- val mappedTables = tables.groupBy(_.split("\\.")(0)).map {
- case (key, value) => (key, value.map(_.split("\\.")(1)).asInstanceOf[List[String]])
- }
-
- mappedSchemas ++ mappedTables
- }
+ def parseSchemaList(schemaTableNames: List[String]): Map[String, List[String]] =
+ schemaTableNames.map(_.split('.'))
+ .groupBy(_.head)
+ .mapValues(_.flatMap(_.tail))
+ .toMap
def createFilteredModel(driver: JdbcProfile, mappedSchemas: Map[String, List[String]]): DBIO[Model] =
driver.createModel(Some(
- MTable.getTables.map(_.filter((t: MTable) => mappedSchemas
- .get(t.name.schema.getOrElse(""))
- .fold(false)(ts => ts.isEmpty || ts.contains(t.name.name))))))
+ MTable.getTables.map(_.filter((t: MTable) =>
+ t.name.schema.flatMap(mappedSchemas.get).exists(tables =>
+ tables.isEmpty || tables.contains(t.name.name))))))
def references(dbModel: Model, tcMappings: Map[(String, String), (String, String)]): Map[(String, String), (Table, Column)] = {
def getTableColumn(tc: (String, String)) : (Table, Column) = {
@@ -68,18 +64,21 @@ object NamespacedCodegen {
val manualReferences = references(dbModel, manualForeignKeys)
def codegen(typeFile: Boolean) = new SourceCodeGenerator(dbModel){
- def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) =
- (table.foreignKeys.toList
- .filter(_.referencingColumns.forall(_ == column))
- .flatMap(fk =>
- fk.referencedColumns match {
- case Seq(c) => dbModel.tablesByName.get(fk.referencedTable).map{(_, c)}
- case _ => None
- }) ++
- manualReferences.get((table.name.asString, column.name)))
- .headOption
+
+ def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = {
+ val referencedColumn: Seq[(m.Table, m.Column)] = table.foreignKeys
+ .filter(tableFk => tableFk.referencingColumns.forall(_ == column))
+ .filter(columnFk => columnFk.referencedColumns.length == 1)
+ .flatMap(_.referencedColumns
+ .map(c => (dbModel.tablesByName(c.table), c)))
+
+ assert(referencedColumn.distinct.length <= 1, referencedColumn)
+
+ referencedColumn.headOption
+ .orElse(manualReferences.get((table.name.asString, column.name)))
.map((derefColumn _).tupled)
.getOrElse((table, column))
+ }
// Is this compatible with ***REMOVED*** Id? How do we make it generic?
def idType(t: m.Table) : String =
@@ -91,9 +90,6 @@ object NamespacedCodegen {
// Why can't we simply re-use?
var imports =
- // acyclic is unnecessary in generic projects
- //if (typeFile) "import acyclic.file\nimport dbmodels.rows\n"
- //else
"import slick.model.ForeignKeyAction\n" +
"import rows._\n" +
( if(tables.exists(_.hlistEnabled)){
@@ -107,20 +103,18 @@ object NamespacedCodegen {
} else ""
) + "\n\n" // We didn't copy ddl though
-
- val bySchema = tables.groupBy(t => {
- t.model.name.schema
- })
-
- val schemaFor = (schema: String) => {
- bySchema(Option(schema)).sortBy(_.model.name.table).map(
- _.code.mkString("\n") // TODO explore here
- ).mkString("\n\n")
- }
-
- val schemata = mappedSchemas.keys.toList.sorted.map(
- s => indent("object" + " " + s + " {\n" + schemaFor(s)) + "\n}\n"
- ).mkString("\n\n")
+ val sortedSchemaTables: List[(String, Seq[TableDef])] = tables
+ .groupBy(t => t.model.name.schema.getOrElse("`public`"))
+ .toList.sortBy(_._1)
+
+ val schemata: String = sortedSchemaTables.map {
+ case (schemaName, tables) =>
+ val tableCode = tables
+ .sortBy(_.model.name.table)
+ .map(_.code.mkString("\n"))
+ .mkString("\n\n")
+ indent(s"object $schemaName {\n$tableCode")+"\n}\n"
+ }.mkString("\n\n")
val idType =
if (typeFile)// Should not be defined here.