diff options
-rw-r--r-- | src/main/scala/CodegenPlugin.scala | 21 | ||||
-rw-r--r-- | src/main/scala/NamespacedCodegen.scala | 226 | ||||
-rw-r--r-- | src/main/scala/OutputHelpers.scala | 28 |
3 files changed, 117 insertions, 158 deletions
diff --git a/src/main/scala/CodegenPlugin.scala b/src/main/scala/CodegenPlugin.scala index 55d4089..c90fe66 100644 --- a/src/main/scala/CodegenPlugin.scala +++ b/src/main/scala/CodegenPlugin.scala @@ -46,6 +46,10 @@ object CodegenPlugin extends AutoPlugin { "codegen-type-replacements", "A map of types to find and replace" ) + lazy val codegenHeader = SettingKey[String]( + "codegen-header", + "Comments that go at the top of generated files; notices and tooling directives." + ) lazy val slickCodeGenTask = TaskKey[Unit]("gen-tables", "generate the table definitions") @@ -58,6 +62,7 @@ object CodegenPlugin extends AutoPlugin { codegenIdType := Option.empty, codegenSchemaImports := List.empty, codegenTypeReplacements := Map.empty, + codegenHeader := "AUTO-GENERATED Slick data model", slickCodeGenTask := Def.taskDyn { Def.task { codegenDatabaseConfigs.value.foreach { @@ -67,15 +72,17 @@ object CodegenPlugin extends AutoPlugin { config.outputPackage, Some(config.schemaWhitelist).filter(_.nonEmpty), config.outputPath, - config.foreignKeys, - (if (codegenIdType.value.isEmpty) - codegenSchemaBaseClassParts.value :+ "DefaultIdTypeMapper" - else - codegenSchemaBaseClassParts.value) match { - case Nil => "AnyRef" - case parts => parts.mkString(" with ") + config.foreignKeys, { + val parts = + (if (codegenIdType.value.isEmpty) + codegenSchemaBaseClassParts.value :+ "DefaultIdTypeMapper" + else + codegenSchemaBaseClassParts.value) + + Some(parts).filter(_.nonEmpty).map(_.mkString(" with ")) }, codegenIdType.value, + codegenHeader.value, codegenSchemaImports.value, codegenTypeReplacements.value ) diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala index c9f613a..85e21bf 100644 --- a/src/main/scala/NamespacedCodegen.scala +++ b/src/main/scala/NamespacedCodegen.scala @@ -6,7 +6,6 @@ import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global import slick.backend.DatabaseConfig import slick.codegen.{ - OutputHelpers, SourceCodeGenerator, StringGeneratorHelpers } @@ -23,8 +22,9 @@ object Generator { schemaNames: Option[List[String]], outputPath: String, manualForeignKeys: Map[(String, String), (String, String)], - schemaBaseClass: String, + parentType: Option[String], idType: Option[String], + header: String, schemaImports: List[String], typeReplacements: Map[String, String]) = { val dc: DatabaseConfig[JdbcProfile] = @@ -32,76 +32,58 @@ object Generator { val parsedSchemasOpt: Option[Map[String, List[String]]] = schemaNames.map(SchemaParser.parse) - val dbModel: Model = try { - Await.result( + try { + val dbModel: Model = Await.result( dc.db.run(SchemaParser.createModel(dc.driver, parsedSchemasOpt)), Duration.Inf) + + parsedSchemasOpt.getOrElse(Map.empty).foreach { + case (schemaName, tables) => + val profile = + s"""slick.backend.DatabaseConfig.forConfig[slick.driver.JdbcProfile]("${uri + .getFragment()}").driver""" + + val schemaOnlyModel = Await.result( + dc.db.run( + SchemaParser.createModel(dc.driver, + Some(Map(schemaName -> tables)))), + Duration.Inf) + + val generator = new Generator(pkg, + dbModel, + schemaOnlyModel, + manualForeignKeys, + parentType, + idType, + header, + schemaImports, + typeReplacements) + generator.writeToFile(profile = profile, + folder = outputPath, + pkg = pkg, + container = schemaName, + fileName = s"${schemaName}.scala") + } } finally { dc.db.close() } - - val generator = new Generator(uri, - pkg, - dbModel, - outputPath, - manualForeignKeys, - schemaBaseClass, - idType, - schemaImports, - typeReplacements) - generator.code // Yes... Files are written as a side effect - parsedSchemasOpt - .getOrElse(Map()) - .keys - .map(schemaName => FileHelpers.schemaOutputPath(outputPath, schemaName)) } } -class PackageNameGenerator(pkg: String, dbModel: Model) - extends SourceCodeGenerator(dbModel) { - override def code: String = - s"""|// scalastyle:off - |package ${pkg} - | - |""".stripMargin -} - -class ImportGenerator(dbModel: Model, schemaImports: List[String]) - extends SourceCodeGenerator(dbModel) { - - val baseImports: String = schemaImports.map("import " + _).mkString("\n") + "\n" - - val hlistImports: String = - """|import slick.collection.heterogeneous._ - |import slick.collection.heterogeneous.syntax._ - |""".stripMargin - - val plainSqlMapperImports: String = - if (tables.exists(_.PlainSqlMapper.enabled)) - """|import slick.jdbc.{GetResult => GR} - |//NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n - |""".stripMargin - else "" - - override def code: String = - baseImports + hlistImports + plainSqlMapperImports -} - -class Generator(uri: URI, - pkg: String, - dbModel: Model, - outputPath: String, +class Generator(pkg: String, + fullDatabaseModel: Model, + schemaOnlyModel: Model, manualForeignKeys: Map[(String, String), (String, String)], - schemaBaseClass: String, + override val parentType: Option[String], idType: Option[String], + override val headerComment: String, schemaImports: List[String], typeReplacements: Map[String, String]) - extends SourceCodeGenerator(dbModel) + extends SourceCodeGenerator(schemaOnlyModel) with OutputHelpers { - val packageName = new PackageNameGenerator(pkg, dbModel).code - val allImports: String = new ImportGenerator(dbModel, schemaImports).code + override val imports = schemaImports.map("import " + _).mkString("\n") val defaultIdImplementation = """|final case class Id[T](v: Int) @@ -112,78 +94,16 @@ class Generator(uri: URI, |} |""".stripMargin - override def code: String = { - - val sortedSchemaTables: List[(String, Seq[TableDef])] = tables - .groupBy(t => t.model.name.schema.getOrElse("`public`")) - .toList - .sortBy(_._1) - - val schemata: String = sortedSchemaTables.map { - case (schemaName, tableDefs) => - val tableCode = tableDefs - .sortBy(_.model.name.table) - .map(_.code.mkString("\n")) - .mkString("\n\n") - - val ddlCode = - (if (ddlEnabled) { - "\n/** DDL for all tables. Call .create to execute. */" + - ( - if (tableDefs.length > 5) - "\nlazy val schema: profile.SchemaDescription = Array(" + tableDefs - .map(_.TableValue.name + ".schema") - .mkString(", ") + ").reduceLeft(_ ++ _)" - else if (tableDefs.nonEmpty) - "\nlazy val schema: profile.SchemaDescription = " + tableDefs - .map(_.TableValue.name + ".schema") - .mkString(" ++ ") - else - "\nlazy val schema: profile.SchemaDescription = profile.DDL(Nil, Nil)" - ) + - "\n\n" - } else "") - - val generatedSchema = s""" - |object ${schemaName} extends { - | val profile = slick.backend.DatabaseConfig.forConfig[slick.driver.JdbcProfile]("${uri - .getFragment()}").driver - |} with $schemaBaseClass { - | import profile.api._ - | ${tableCode} - | ${ddlCode} - |} - |// scalastyle:on""".stripMargin - - writeStringToFile( - packageName + allImports + generatedSchema, - outputPath, - pkg, - s"${schemaName}.scala" - ) - - if (idType.isEmpty) { - writeStringToFile(packageName + defaultIdImplementation, - outputPath, - pkg, - "Id.scala") - } - - generatedSchema - }.mkString("\n\n") - - allImports + schemata - } + override def code = super.code.lines.drop(1).mkString("\n") + // Drops needless import: `"import slick.model.ForeignKeyAction\n"`. + // Alias to ForeignKeyAction is in profile.api + // TODO: fix upstream override def Table = new Table(_) { table => - // need this in order to use our own TableClass generator - override def definitions = - Seq[Def](EntityTypeRef, PlainSqlMapper, TableClassRef, TableValue) - - def TableClassRef = new TableClass() { + override def TableClass = new TableClass() { // We disable the option mapping, as it is a bit more complex to support and we don't appear to need it - override def option = "" // if(hlistEnabled) "" else super.option + override def optionEnabled = false } // use hlists all the time @@ -212,15 +132,19 @@ class Generator(uri: URI, .map("a." + _.name) .mkString("::") + ":: HNil)" - def EntityTypeRef = new EntityTypeDef { + override def EntityType = new EntityTypeDef { override def code: String = - (if (classEnabled) "final " else "") + super.code + // Wartremover wants `final` + // But can't have the final case class inside the trait + // TODO: Fix by putting case classes in package or object + // TODO: Upstream default should be false. + (if (classEnabled) "sealed " else "") + super.code } override def Column = new Column(_) { column => - + // use fullDatabasemodel model here for cross-schema foreign keys val manualReferences = - SchemaParser.references(dbModel, manualForeignKeys) + SchemaParser.references(fullDatabaseModel, manualForeignKeys) // work out the destination of the foreign key def derefColumn(table: sModel.Table, @@ -230,7 +154,7 @@ class Generator(uri: URI, .filter(tableFk => tableFk.referencingColumns.forall(_ == column)) .filter(columnFk => columnFk.referencedColumns.length == 1) .flatMap(_.referencedColumns.map(c => - (dbModel.tablesByName(c.table), c))) + (fullDatabaseModel.tablesByName(c.table), c))) assert(referencedColumn.distinct.length <= 1, referencedColumn) referencedColumn.headOption @@ -280,9 +204,7 @@ class Generator(uri: URI, s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=$onUpdate, onDelete=$onDelete)""" } } - } - } object SchemaParser { @@ -315,24 +237,26 @@ object SchemaParser { def createModel( jdbcProfile: JdbcProfile, mappedSchemasOpt: Option[Map[String, List[String]]]): DBIO[Model] = { - val allTables: DBIO[Vector[MTable]] = MTable.getTables - - val filteredTables = mappedSchemasOpt.map( - mappedSchemas => - allTables.map( - (tables: Vector[MTable]) => - tables.filter( - table => - table.name.schema - .flatMap(mappedSchemas.get) - .exists(ts => ts.isEmpty || ts.contains(table.name.name))))) - - jdbcProfile.createModel(filteredTables orElse Some(allTables)) - } - -} + import slick.jdbc.meta.MQName + + val filteredTables = mappedSchemasOpt.map { mappedSchemas => + MTable.getTables.map { (tables: Vector[MTable]) => + mappedSchemas.flatMap { + case (schemaName, tableNames) => + tableNames.map( + tableName => + tables + .find(table => + table.name match { + case MQName(_, Some(`schemaName`), `tableName`) => true + case _ => false + }) + .getOrElse(throw new IllegalArgumentException( + s"$schemaName.$tableName does not exist in the connected database."))) + }.toList + } + } -object FileHelpers { - def schemaOutputPath(path: String, schemaName: String): String = - Paths.get(path, s"${schemaName}.scala").toAbsolutePath().toString() + jdbcProfile.createModel(filteredTables) + } } diff --git a/src/main/scala/OutputHelpers.scala b/src/main/scala/OutputHelpers.scala new file mode 100644 index 0000000..ce22f2a --- /dev/null +++ b/src/main/scala/OutputHelpers.scala @@ -0,0 +1,28 @@ +trait OutputHelpers extends slick.codegen.OutputHelpers { + + def imports: String + + def headerComment: String = "" + + override def packageCode(profile: String, + pkg: String, + container: String, + parentType: Option[String]): String = { + val traitName = container.capitalize + "SchemaDef" + s"""|${headerComment.trim().lines.map("// " + _).mkString("\n")} + |package $pkg + | + |$imports + | + |/** Stand-alone Slick data model for immediate use */ + |object $container extends { + | val profile = $profile + |} with $traitName + | + |/** Slick data model trait for extension, choice of backend or usage in the cake pattern. (Make sure to initialize this late.) */ + |trait $traitName${parentType.fold("")(" extends " + _)} { + | import profile.api._ + | ${indent(code)} + |}""".stripMargin.trim() + } +} |