From ca9dde318c7bc7b780cf5745cdee71e8430d9f6e Mon Sep 17 00:00:00 2001 From: Stewart Stewart Date: Fri, 24 Feb 2017 09:06:40 -0500 Subject: remember to generate code only for a schema --- src/main/scala/Generators.scala | 4 +++- src/main/scala/TypedIdTable.scala | 10 ++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) (limited to 'src/main/scala') diff --git a/src/main/scala/Generators.scala b/src/main/scala/Generators.scala index 51e02ce..59c9d9f 100644 --- a/src/main/scala/Generators.scala +++ b/src/main/scala/Generators.scala @@ -10,6 +10,7 @@ class RowSourceCodeGenerator( idType: Option[String], manualForeignKeys: Map[(String, String), (String, String)] ) extends TypedIdSourceCodeGenerator( + model, fullDatabaseModel, idType, manualForeignKeys @@ -40,7 +41,8 @@ class TableSourceCodeGenerator( parentType: Option[String], idType: Option[String], typeReplacements: Map[String, String]) - extends TypedIdSourceCodeGenerator(fullDatabaseModel, + extends TypedIdSourceCodeGenerator(schemaOnlyModel, + fullDatabaseModel, idType, manualForeignKeys) with TableOutputHelpers { diff --git a/src/main/scala/TypedIdTable.scala b/src/main/scala/TypedIdTable.scala index c7f0151..1a8f986 100644 --- a/src/main/scala/TypedIdTable.scala +++ b/src/main/scala/TypedIdTable.scala @@ -2,11 +2,13 @@ import slick.codegen.SourceCodeGenerator import slick.{model => m} class TypedIdSourceCodeGenerator( - model: m.Model, + singleSchemaModel: m.Model, + databaseModel: m.Model, idType: Option[String], manualForeignKeys: Map[(String, String), (String, String)] -) extends SourceCodeGenerator(model) { - val manualReferences = SchemaParser.references(model, manualForeignKeys) +) extends SourceCodeGenerator(singleSchemaModel) { + val manualReferences = + SchemaParser.references(databaseModel, manualForeignKeys) def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = { val referencedColumn: Seq[(m.Table, m.Column)] = @@ -14,7 +16,7 @@ class TypedIdSourceCodeGenerator( .filter(tableFk => tableFk.referencingColumns.forall(_ == column)) .filter(columnFk => columnFk.referencedColumns.length == 1) .flatMap(_.referencedColumns.map(c => - (model.tablesByName(c.table), c))) + (databaseModel.tablesByName(c.table), c))) assert(referencedColumn.distinct.length <= 1, referencedColumn) referencedColumn.headOption -- cgit v1.2.3