diff options
author | Petro Verkhogliad <vpetro@gmail.com> | 2016-09-22 15:59:16 -0400 |
---|---|---|
committer | Petro Verkhogliad <vpetro@gmail.com> | 2016-09-22 15:59:16 -0400 |
commit | f9c3158fa1bbb67e661b712c9269a7e8cfee60d3 (patch) | |
tree | 844ab76aa7b1ae5f60c5d50247172c2957da2d3c /src | |
parent | 480e0c04d0a154066c6765c04f2e45ab25322008 (diff) | |
download | slick-codegen-plugin-f9c3158fa1bbb67e661b712c9269a7e8cfee60d3.tar.gz slick-codegen-plugin-f9c3158fa1bbb67e661b712c9269a7e8cfee60d3.tar.bz2 slick-codegen-plugin-f9c3158fa1bbb67e661b712c9269a7e8cfee60d3.zip |
Output generated code into one file per schema
Diffstat (limited to 'src')
-rw-r--r-- | src/main/scala/CodegenPlugin.scala | 12 | ||||
-rw-r--r-- | src/main/scala/NamespacedCodegen.scala | 414 |
2 files changed, 213 insertions, 213 deletions
diff --git a/src/main/scala/CodegenPlugin.scala b/src/main/scala/CodegenPlugin.scala index 408cf9e..366ef19 100644 --- a/src/main/scala/CodegenPlugin.scala +++ b/src/main/scala/CodegenPlugin.scala @@ -7,26 +7,22 @@ object CodegenPlugin extends AutoPlugin { type TableColumn = (String, String) object autoImport { - lazy val codegen = TaskKey[Seq[File]]("gen-tables", "generate slick database schema") + lazy val codegen = TaskKey[Unit]("gen-tables", "generate slick database schema") lazy val codegenURI = SettingKey[String]("codegen-uri", "uri for the database configuration") lazy val codegenPackage = SettingKey[String]("codegen-package", "package in which to place generated code") - lazy val codegenTablesFile = SettingKey[String]("codegen-tables-file", "path for slick table models") - lazy val codegenRowsFile = SettingKey[String]("codegen-rows-file", "path for row case classes") + lazy val codegenOutputPath = SettingKey[String]("codegen-output-path", "directory to with the generated code will be written") lazy val codegenSchemaWhitelist = SettingKey[List[String]]("codegen-schema-whitelist", "schemas and tables to process") lazy val codegenForeignKeys = SettingKey[Map[TableColumn, TableColumn]]("codegen-foreign-keys", "foreign key references to data models add manually") lazy val slickCodeGenTask = Def.task { - NamespacedCodegen.run( + Generator.run( new java.net.URI(codegenURI.value), codegenPackage.value, - codegenTablesFile.value, - codegenRowsFile.value, codegenSchemaWhitelist.value, + codegenOutputPath.value, codegenForeignKeys.value ) - - Seq(file(codegenTablesFile.value), file(codegenRowsFile.value)) } } } diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala index c5f96e6..c7d4020 100644 --- a/src/main/scala/NamespacedCodegen.scala +++ b/src/main/scala/NamespacedCodegen.scala @@ -1,77 +1,140 @@ -import java.io.{FileWriter, File} import java.net.URI +import java.io.{BufferedWriter, File, FileWriter} +import java.nio.file.Paths import scala.concurrent.Await -import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration.Duration -import scala.reflect.runtime.currentMirror - -import slick.ast.ColumnOption +import scala.concurrent.ExecutionContext.Implicits.global import slick.backend.DatabaseConfig -import slick.codegen.{AbstractGenerator, SourceCodeGenerator} -import slick.dbio.{DBIO, DBIOAction, NoStream} +import slick.codegen.SourceCodeGenerator +import slick.dbio.DBIO import slick.driver.JdbcProfile import slick.jdbc.meta.MTable -import slick.{model => m} -import slick.model.{Column, Model, Table} -import slick.util.ConfigExtensionMethods.configExtensionMethods - -// NamespacedCodegen handles tables within schemas by namespacing them -// within objects here -// (e.g., table a.foo and table b.foo can co-exist, because this code -// generator places the relevant generated classes into separate -// objects--a "a" object, and a "b" object) -object NamespacedCodegen { - def parseSchemaList(schemaTableNames: List[String]): Map[String, List[String]] = - schemaTableNames.map(_.split('.')) - .groupBy(_.head) - .mapValues(_.flatMap(_.tail)) - .toMap +import slick.{model => sModel} +import slick.model.{Model, Column, Table} - def createFilteredModel(driver: JdbcProfile, mappedSchemas: Map[String, List[String]]): DBIO[Model] = - driver.createModel(Some( - MTable.getTables.map(_.filter((t: MTable) => - t.name.schema.flatMap(mappedSchemas.get).exists(tables => - tables.isEmpty || tables.contains(t.name.name)))))) +object Generator { - def references(dbModel: Model, tcMappings: Map[(String, String), (String, String)]): Map[(String, String), (Table, Column)] = { - def getTableColumn(tc: (String, String)) : (Table, Column) = { - val (tableName, columnName) = tc - val table = dbModel.tables.find(_.name.asString == tableName) - .getOrElse(throw new RuntimeException("No table " + tableName)) - val column = table.columns.find(_.name == columnName) - .getOrElse(throw new RuntimeException("No column " + columnName + " in table " + tableName)) - (table, column) - } - tcMappings.map{case (from, to) => ({getTableColumn(from); from}, getTableColumn(to))} + def run(uri: URI, pkg: String, schemaNames: List[String], outputPath: String, manualForeignKeys: Map[(String, String), (String, String)]) = { + val dc: DatabaseConfig[JdbcProfile] = DatabaseConfig.forURI[JdbcProfile](uri) + val parsedSchemas: Map[String, List[String]] = SchemaParser.parse(schemaNames) + val dbModel: Model = Await.result(dc.db.run(SchemaParser.createModel(dc.driver, parsedSchemas)), Duration.Inf) + + FileHelpers.createOutputPath(outputPath) + + val generator = new Generator(uri, pkg, dbModel, outputPath, manualForeignKeys) + val generatedCode = generator.code + parsedSchemas.keys.map(schemaName => FileHelpers.schemaOutputPath(outputPath, schemaName)) } - def run( - uri: URI, - pkg: String, - filename: String, - typesFilename: String, - schemaTableNames: List[String], - manualForeignKeys: Map[(String, String), (String, String)] - ): Unit = { - val dc = DatabaseConfig.forURI[JdbcProfile](uri) - val slickDriver = if(dc.driverIsObject) dc.driverName else "new " + dc.driverName - val mappedSchemas = parseSchemaList(schemaTableNames) - val dbModel = Await.result(dc.db.run(createFilteredModel(dc.driver, mappedSchemas)), Duration.Inf) - //finally dc.db.close - - val manualReferences = references(dbModel, manualForeignKeys) - - def codegen(typeFile: Boolean) = new SourceCodeGenerator(dbModel){ - - def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = { - val referencedColumn: Seq[(m.Table, m.Column)] = table.foreignKeys +} + +class Generator(uri: URI, pkg: String, dbModel: Model, outputPath: String, manualForeignKeys: Map[(String, String), (String, String)]) extends SourceCodeGenerator(dbModel) { + + override def code: String = { + val baseImport: String = + s""" + |package ${pkg} + | + |import com.drivergrp.core._ + |import com.drivergrp.core.database._ + | + |""".stripMargin + + val hlistImports: String = + if (tables.exists(_.hlistEnabled)) + """ + |import slick.collection.heterogeneous._ + |import slick.collection.heterogeneous.syntax._ + | + |""".stripMargin + else "" + + val plainSqlMapperImports: String = + if (tables.exists(_.PlainSqlMapper.enabled)) + """ + |import slick.jdbc.{GetResult => GR} + |//NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n + | + |""".stripMargin + else "" + + val allImports: String = baseImport + hlistImports + plainSqlMapperImports + + val sortedSchemaTables: List[(String, Seq[TableDef])] = tables + .groupBy(t => t.model.name.schema.getOrElse("`public`")) + .toList.sortBy(_._1) + + val schemata: String = sortedSchemaTables.map { + case (schemaName, tableDefs) => + val tableCode = tableDefs.sortBy(_.model.name.table).map(_.code.mkString("\n")) .mkString("\n\n") + val generatedSchema = s""" + |object ${schemaName} extends IdColumnTypes { + | override val database = com.drivergrp.core.database.Database.fromConfig("${uri.getFragment()}") + | import database.profile.api._ + | // TODO: the name for this implicit should be changed in driver core + | implicit val tColType = MappedColumnType.base[com.drivergrp.core.time.Time, Long](time => time.millis, com.drivergrp.core.time.Time(_)) + | ${tableCode} + |} + """.stripMargin + + FileHelpers.write( + FileHelpers.schemaOutputPath(outputPath, schemaName), + allImports + generatedSchema + ) + + generatedSchema + }.mkString("\n\n") + + allImports + schemata + } + + + override def Table = new Table(_) { table => + + // need this in order to use our own TableClass generator + override def definitions = Seq[Def]( EntityType, PlainSqlMapper, TableClassRef, TableValue ) + + def TableClassRef = new TableClass() { + // We disable the option mapping, as it is a bit more complex to support and we don't appear to need it + override def option = "" // if(hlistEnabled) "" else super.option + } + + // use hlists all the time + override def hlistEnabled: Boolean = true + + // if false rows are type aliases to hlists, if true rows are case classes + override def mappingEnabled: Boolean = true + + // create case class from colums + override def factory: String = + if(!hlistEnabled) super.factory + else { + val args = columns.zipWithIndex.map("a"+_._2) + val hlist = args.mkString("::") + ":: HNil" + val hlistType = columns.map(_.actualType).mkString("::") + ":: HNil.type" + s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})" + } + + // from case class create columns + override def extractor: String = + if(!hlistEnabled) super.extractor + else s"(a : ${TableClass.elementType}) => Some(" + columns.map("a."+_.name ).mkString("::") + ":: HNil)" + + + override def Column = new Column(_) { + column => + + val manualReferences = SchemaParser.references(dbModel, manualForeignKeys) + + // work out the destination of the foreign key + def derefColumn(table: sModel.Table, column: sModel.Column): (sModel.Table, sModel.Column) = { + val referencedColumn: Seq[(sModel.Table, sModel.Column)] = table.foreignKeys .filter(tableFk => tableFk.referencingColumns.forall(_ == column)) .filter(columnFk => columnFk.referencedColumns.length == 1) .flatMap(_.referencedColumns .map(c => (dbModel.tablesByName(c.table), c))) - assert(referencedColumn.distinct.length <= 1, referencedColumn) referencedColumn.headOption @@ -80,164 +143,105 @@ object NamespacedCodegen { .getOrElse((table, column)) } - // Is this compatible with ***REMOVED*** Id? How do we make it generic? - def idType(t: m.Table) : String = - "Id["+ t.name.schema.fold("public")(_ + ".") + t.name.table.toCamelCase+"Row]" + // re-write ids, and time types + override def rawType: String = { + val (t, c) = derefColumn(table.model, column.model) + if (c.options.contains(slick.ast.ColumnOption.PrimaryKey)) TypeGenerator.idType(pkg, t) + else model.tpe match { + // TODO: There should be a way to add adhoc custom time mappings + case "java.sql.Time" => "com.drivergrp.core.time.Time" + case "java.sql.Timestamp" => "com.drivergrp.core.time.Time" + case _ => super.rawType + } + } + } + override def ForeignKey = new ForeignKey(_) { override def code = { - //imports is copied right out of - //scala.slick.model.codegen.AbstractSourceCodeGenerator - // Why can't we simply re-use? - - var imports = - "import slick.model.ForeignKeyAction\n" + - ( if(tables.exists(_.hlistEnabled)){ - "import slick.collection.heterogeneous._\n" + - "import slick.collection.heterogeneous.syntax._\n" + - "import com.drivergrp.core._\n" + - "import com.drivergrp.core.database._\n" - } else "" - ) + - ( if(tables.exists(_.PlainSqlMapper.enabled)){ - "import slick.jdbc.{GetResult => GR}\n"+ - "// NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n" - } else "" - ) + "\n\n" // We didn't copy ddl though - - val sortedSchemaTables: List[(String, Seq[TableDef])] = tables - .groupBy(t => t.model.name.schema.getOrElse("`public`")) - .toList.sortBy(_._1) - - val schemata: String = sortedSchemaTables.map { - case (schemaName, tables) => - val tableCode = tables - .sortBy(_.model.name.table) - .map(_.code.mkString("\n")) - .mkString("\n\n") - indent(s"object $schemaName extends CoreDBMappers {\n$tableCode")+"\n}\n" - }.mkString("\n\n") - - val mapperTrait: String = """trait CoreDBMappers extends com.drivergrp.core.database.IdColumnTypes { override val database = com.drivergrp.core.database.Database.fromConfig("slick.db.default") }""" - - imports + mapperTrait + "\n\n" + schemata + val fkColumns = compoundValue(referencingColumns.map(_.name)) + val qualifier = + if (referencedTable.model.name.schema == referencingTable.model.name.schema) "" + else referencedTable.model.name.schema.fold("")(sname => s"$pkg.$sname.") + + val qualifiedName = qualifier + referencedTable.TableValue.name + val pkColumns = compoundValue(referencedColumns.map(c => s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" else ""}")) + val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk" + s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=$onUpdate, onDelete=$onDelete)""" } + } - // This is overridden to output classfiles elsewhere - override def Table = new Table(_) { - table => - // case classes go in the typeFile (but not types based on hlists) - override def definitions = - if (typeFile) Seq[Def](EntityType) - else Seq[Def](EntityTypeRef, PlainSqlMapper, TableClassRef, TableValue) + } +} - def EntityTypeRef = new EntityType() { - override def code = - s"type $name = $pkg.rows.${model.name.schema.get}.$name\n" ++ - s"val $name = $pkg.rows.${model.name.schema.get}.$name" - } +object SchemaParser { + def references(dbModel: Model, tcMappings: Map[(String, String), (String, String)]): Map[(String, String), (Table, Column)] = { + def getTableColumn(tc: (String, String)) : (Table, Column) = { + val (tableName, columnName) = tc + val table = dbModel.tables.find(_.name.asString == tableName) + .getOrElse(throw new RuntimeException("No table " + tableName)) + val column = table.columns.find(_.name == columnName) + .getOrElse(throw new RuntimeException("No column " + columnName + " in table " + tableName)) + (table, column) + } - /** Creates a compound type from a given sequence of types. - * Uses HList if hlistEnabled else tuple. - */ - override def compoundType(types: Seq[String]): String = - /** Creates a compound value from a given sequence of values. - * Uses HList if hlistEnabled else tuple. - */ - // Yes! This is part of Slick now, yes? - if (hlistEnabled){ - def mkHList(types: List[String]): String = types match { - case Nil => "HNil" - case e :: tail => s"HCons[$e," + mkHList(tail) + "]" - } - mkHList(types.toList) - } - else compoundValue(types) - - //why? - override def mappingEnabled = true - - override def compoundValue(values: Seq[String]): String = - if (hlistEnabled) values.mkString(" :: ") + " :: HNil" - else if (values.size == 1) values.head - else if(values.size <= 22) s"""(${values.mkString(", ")})""" - else throw new Exception("Cannot generate tuple for > 22 columns, please set hlistEnable=true or override compound.") - - def TableClassRef = new TableClass() { - // We disable the option mapping for >22 columns, as it is a bit more complex to support and we don't appear to need it - override def option = if(columns.size <= 22) super.option else "" - } + tcMappings.map{case (from, to) => ({getTableColumn(from); from}, getTableColumn(to))} + } - override def factory = - if(columns.size <= 22) super.factory - else { - val args = columns.zipWithIndex.map("a"+_._2) - val hlist = args.mkString("::") + ":: HNil" - val hlistType = columns.map(_.actualType).mkString("::") + ":: HNil.type" - s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})" - } - override def extractor = - if(columns.size <= 22) super.extractor - else s"(a : ${TableClass.elementType}) => Some(" + columns.map("a."+_.name ).mkString("::") + ":: HNil)" - - // make foreign keys refer to namespaced referents - // if the referent is in a different namespace - override def ForeignKey = new ForeignKey(_) { - override def code = { - val fkColumns = compoundValue(referencingColumns.map(_.name)) - // Add the schema name to qualify the referenced table name if: - // 1. it's in a different schema from referencingTable, and - // 2. it's not None - val qualifier = if (referencedTable.model.name.schema - != referencingTable.model.name.schema) { - referencedTable.model.name.schema match { - case Some(schema) => schema + "." - case None => "" - } - } else { - "" - } - val qualifiedName = qualifier + referencedTable.TableValue.name - val pkColumns = compoundValue(referencedColumns.map(c => s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" else ""}")) - val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk" - s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=${onUpdate}, onDelete=${onDelete})""" - } - } + def parse(schemaTableNames: List[String]): Map[String, List[String]] = + schemaTableNames.map(_.split('.')) + .groupBy(_.head) + .mapValues(_.flatMap(_.tail)) - override def Column = new Column(_) { column => - // customize db type -> scala type mapping, pls adjust it according to your environment - - override def rawType = { - val (t, c) = derefColumn(table.model, column.model) - //System.out.print(s"${table.model.name.asString}:${column.model.name} -> ${t.name.asString}:${c.name}\n") - if (c.options.exists(_.toString.contains("PrimaryKey"))) idType(t) - // ^ahahaha This is hacky - // This should be customizeable by client - - else model.tpe match { - // how does this type work out? - // There should be a way to add adhoc custom time mappings - case "java.sql.Time" => "com.drivergrp.core.time.Time" - case "java.sql.Timestamp" => "com.drivergrp.core.time.Time" - case _ => super.rawType - } - } - } - } - } + def createModel(jdbcProfile: JdbcProfile, mappedSchemas: Map[String, List[String]]): DBIO[Model] = { + val allTables: DBIO[Vector[MTable]] = MTable.getTables - def write(c: String, name: String) = { - (new File(name).getParentFile).mkdirs() - val fw = new FileWriter(name) - fw.write(c) - fw.close() - } - val disableScalaStyle = "// scalastyle:off\n" - val tablesSource = codegen(false).packageCode(slickDriver, pkg, "Tables", None) - val rowsSource = s"package $pkg.rows\n\n" + codegen(true).code + val filteredTables: DBIO[Vector[MTable]] = allTables.map( + (tables: Vector[MTable]) => tables.filter(table => + table.name.schema.flatMap(mappedSchemas.get).exists(ts => + ts.isEmpty || ts.contains(table.name.name)) + ) + ) + jdbcProfile.createModel(Some(filteredTables)) + } + +} - write(disableScalaStyle + tablesSource, filename) - write(disableScalaStyle + rowsSource, typesFilename) + +object TypeGenerator { + // generate the id types + def idType(pkg: String, t: sModel.Table): String = { + val header = s"Id[" + val schemaName = t.name.schema.fold("")(_ + ".") + val tableName = StringHelpers.toCamelCase(t.name.table) + val footer = "]" + s"${header}${pkg}.${schemaName}${tableName}Row${footer}" } } + +object FileHelpers { + + def schemaOutputPath(path: String, schemaName: String): String = + Paths.get(path, s"${schemaName}.scala").toAbsolutePath().toString() + + def createOutputPath(path: String): Boolean = (new File(path)).mkdirs() + + def write(filename: String, content: String) = { + val file = new File(filename) + val bw = new BufferedWriter(new FileWriter(file)) + bw.write(content) + bw.close() + } + +} + +object StringHelpers { + // copied from GeneratorHelpers.StringExtensions + final def toCamelCase(value: String) = value.toLowerCase + .split("_") + .map{ case "" => "_" case s => s } // avoid possible collisions caused by multiple '_' + .map(_.capitalize) + .mkString("") + +} |