aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/NamespacedCodegen.scala
diff options
context:
space:
mode:
authorPetro Verkhogliad <vpetro@gmail.com>2016-09-22 15:59:16 -0400
committerPetro Verkhogliad <vpetro@gmail.com>2016-09-22 15:59:16 -0400
commitf9c3158fa1bbb67e661b712c9269a7e8cfee60d3 (patch)
tree844ab76aa7b1ae5f60c5d50247172c2957da2d3c /src/main/scala/NamespacedCodegen.scala
parent480e0c04d0a154066c6765c04f2e45ab25322008 (diff)
downloadslick-codegen-plugin-f9c3158fa1bbb67e661b712c9269a7e8cfee60d3.tar.gz
slick-codegen-plugin-f9c3158fa1bbb67e661b712c9269a7e8cfee60d3.tar.bz2
slick-codegen-plugin-f9c3158fa1bbb67e661b712c9269a7e8cfee60d3.zip
Output generated code into one file per schema
Diffstat (limited to 'src/main/scala/NamespacedCodegen.scala')
-rw-r--r--src/main/scala/NamespacedCodegen.scala414
1 files changed, 209 insertions, 205 deletions
diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala
index c5f96e6..c7d4020 100644
--- a/src/main/scala/NamespacedCodegen.scala
+++ b/src/main/scala/NamespacedCodegen.scala
@@ -1,77 +1,140 @@
-import java.io.{FileWriter, File}
import java.net.URI
+import java.io.{BufferedWriter, File, FileWriter}
+import java.nio.file.Paths
import scala.concurrent.Await
-import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
-import scala.reflect.runtime.currentMirror
-
-import slick.ast.ColumnOption
+import scala.concurrent.ExecutionContext.Implicits.global
import slick.backend.DatabaseConfig
-import slick.codegen.{AbstractGenerator, SourceCodeGenerator}
-import slick.dbio.{DBIO, DBIOAction, NoStream}
+import slick.codegen.SourceCodeGenerator
+import slick.dbio.DBIO
import slick.driver.JdbcProfile
import slick.jdbc.meta.MTable
-import slick.{model => m}
-import slick.model.{Column, Model, Table}
-import slick.util.ConfigExtensionMethods.configExtensionMethods
-
-// NamespacedCodegen handles tables within schemas by namespacing them
-// within objects here
-// (e.g., table a.foo and table b.foo can co-exist, because this code
-// generator places the relevant generated classes into separate
-// objects--a "a" object, and a "b" object)
-object NamespacedCodegen {
- def parseSchemaList(schemaTableNames: List[String]): Map[String, List[String]] =
- schemaTableNames.map(_.split('.'))
- .groupBy(_.head)
- .mapValues(_.flatMap(_.tail))
- .toMap
+import slick.{model => sModel}
+import slick.model.{Model, Column, Table}
- def createFilteredModel(driver: JdbcProfile, mappedSchemas: Map[String, List[String]]): DBIO[Model] =
- driver.createModel(Some(
- MTable.getTables.map(_.filter((t: MTable) =>
- t.name.schema.flatMap(mappedSchemas.get).exists(tables =>
- tables.isEmpty || tables.contains(t.name.name))))))
+object Generator {
- def references(dbModel: Model, tcMappings: Map[(String, String), (String, String)]): Map[(String, String), (Table, Column)] = {
- def getTableColumn(tc: (String, String)) : (Table, Column) = {
- val (tableName, columnName) = tc
- val table = dbModel.tables.find(_.name.asString == tableName)
- .getOrElse(throw new RuntimeException("No table " + tableName))
- val column = table.columns.find(_.name == columnName)
- .getOrElse(throw new RuntimeException("No column " + columnName + " in table " + tableName))
- (table, column)
- }
- tcMappings.map{case (from, to) => ({getTableColumn(from); from}, getTableColumn(to))}
+ def run(uri: URI, pkg: String, schemaNames: List[String], outputPath: String, manualForeignKeys: Map[(String, String), (String, String)]) = {
+ val dc: DatabaseConfig[JdbcProfile] = DatabaseConfig.forURI[JdbcProfile](uri)
+ val parsedSchemas: Map[String, List[String]] = SchemaParser.parse(schemaNames)
+ val dbModel: Model = Await.result(dc.db.run(SchemaParser.createModel(dc.driver, parsedSchemas)), Duration.Inf)
+
+ FileHelpers.createOutputPath(outputPath)
+
+ val generator = new Generator(uri, pkg, dbModel, outputPath, manualForeignKeys)
+ val generatedCode = generator.code
+ parsedSchemas.keys.map(schemaName => FileHelpers.schemaOutputPath(outputPath, schemaName))
}
- def run(
- uri: URI,
- pkg: String,
- filename: String,
- typesFilename: String,
- schemaTableNames: List[String],
- manualForeignKeys: Map[(String, String), (String, String)]
- ): Unit = {
- val dc = DatabaseConfig.forURI[JdbcProfile](uri)
- val slickDriver = if(dc.driverIsObject) dc.driverName else "new " + dc.driverName
- val mappedSchemas = parseSchemaList(schemaTableNames)
- val dbModel = Await.result(dc.db.run(createFilteredModel(dc.driver, mappedSchemas)), Duration.Inf)
- //finally dc.db.close
-
- val manualReferences = references(dbModel, manualForeignKeys)
-
- def codegen(typeFile: Boolean) = new SourceCodeGenerator(dbModel){
-
- def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = {
- val referencedColumn: Seq[(m.Table, m.Column)] = table.foreignKeys
+}
+
+class Generator(uri: URI, pkg: String, dbModel: Model, outputPath: String, manualForeignKeys: Map[(String, String), (String, String)]) extends SourceCodeGenerator(dbModel) {
+
+ override def code: String = {
+ val baseImport: String =
+ s"""
+ |package ${pkg}
+ |
+ |import com.drivergrp.core._
+ |import com.drivergrp.core.database._
+ |
+ |""".stripMargin
+
+ val hlistImports: String =
+ if (tables.exists(_.hlistEnabled))
+ """
+ |import slick.collection.heterogeneous._
+ |import slick.collection.heterogeneous.syntax._
+ |
+ |""".stripMargin
+ else ""
+
+ val plainSqlMapperImports: String =
+ if (tables.exists(_.PlainSqlMapper.enabled))
+ """
+ |import slick.jdbc.{GetResult => GR}
+ |//NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n
+ |
+ |""".stripMargin
+ else ""
+
+ val allImports: String = baseImport + hlistImports + plainSqlMapperImports
+
+ val sortedSchemaTables: List[(String, Seq[TableDef])] = tables
+ .groupBy(t => t.model.name.schema.getOrElse("`public`"))
+ .toList.sortBy(_._1)
+
+ val schemata: String = sortedSchemaTables.map {
+ case (schemaName, tableDefs) =>
+ val tableCode = tableDefs.sortBy(_.model.name.table).map(_.code.mkString("\n")) .mkString("\n\n")
+ val generatedSchema = s"""
+ |object ${schemaName} extends IdColumnTypes {
+ | override val database = com.drivergrp.core.database.Database.fromConfig("${uri.getFragment()}")
+ | import database.profile.api._
+ | // TODO: the name for this implicit should be changed in driver core
+ | implicit val tColType = MappedColumnType.base[com.drivergrp.core.time.Time, Long](time => time.millis, com.drivergrp.core.time.Time(_))
+ | ${tableCode}
+ |}
+ """.stripMargin
+
+ FileHelpers.write(
+ FileHelpers.schemaOutputPath(outputPath, schemaName),
+ allImports + generatedSchema
+ )
+
+ generatedSchema
+ }.mkString("\n\n")
+
+ allImports + schemata
+ }
+
+
+ override def Table = new Table(_) { table =>
+
+ // need this in order to use our own TableClass generator
+ override def definitions = Seq[Def]( EntityType, PlainSqlMapper, TableClassRef, TableValue )
+
+ def TableClassRef = new TableClass() {
+ // We disable the option mapping, as it is a bit more complex to support and we don't appear to need it
+ override def option = "" // if(hlistEnabled) "" else super.option
+ }
+
+ // use hlists all the time
+ override def hlistEnabled: Boolean = true
+
+ // if false rows are type aliases to hlists, if true rows are case classes
+ override def mappingEnabled: Boolean = true
+
+ // create case class from colums
+ override def factory: String =
+ if(!hlistEnabled) super.factory
+ else {
+ val args = columns.zipWithIndex.map("a"+_._2)
+ val hlist = args.mkString("::") + ":: HNil"
+ val hlistType = columns.map(_.actualType).mkString("::") + ":: HNil.type"
+ s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})"
+ }
+
+ // from case class create columns
+ override def extractor: String =
+ if(!hlistEnabled) super.extractor
+ else s"(a : ${TableClass.elementType}) => Some(" + columns.map("a."+_.name ).mkString("::") + ":: HNil)"
+
+
+ override def Column = new Column(_) {
+ column =>
+
+ val manualReferences = SchemaParser.references(dbModel, manualForeignKeys)
+
+ // work out the destination of the foreign key
+ def derefColumn(table: sModel.Table, column: sModel.Column): (sModel.Table, sModel.Column) = {
+ val referencedColumn: Seq[(sModel.Table, sModel.Column)] = table.foreignKeys
.filter(tableFk => tableFk.referencingColumns.forall(_ == column))
.filter(columnFk => columnFk.referencedColumns.length == 1)
.flatMap(_.referencedColumns
.map(c => (dbModel.tablesByName(c.table), c)))
-
assert(referencedColumn.distinct.length <= 1, referencedColumn)
referencedColumn.headOption
@@ -80,164 +143,105 @@ object NamespacedCodegen {
.getOrElse((table, column))
}
- // Is this compatible with ***REMOVED*** Id? How do we make it generic?
- def idType(t: m.Table) : String =
- "Id["+ t.name.schema.fold("public")(_ + ".") + t.name.table.toCamelCase+"Row]"
+ // re-write ids, and time types
+ override def rawType: String = {
+ val (t, c) = derefColumn(table.model, column.model)
+ if (c.options.contains(slick.ast.ColumnOption.PrimaryKey)) TypeGenerator.idType(pkg, t)
+ else model.tpe match {
+ // TODO: There should be a way to add adhoc custom time mappings
+ case "java.sql.Time" => "com.drivergrp.core.time.Time"
+ case "java.sql.Timestamp" => "com.drivergrp.core.time.Time"
+ case _ => super.rawType
+ }
+ }
+ }
+ override def ForeignKey = new ForeignKey(_) {
override def code = {
- //imports is copied right out of
- //scala.slick.model.codegen.AbstractSourceCodeGenerator
- // Why can't we simply re-use?
-
- var imports =
- "import slick.model.ForeignKeyAction\n" +
- ( if(tables.exists(_.hlistEnabled)){
- "import slick.collection.heterogeneous._\n" +
- "import slick.collection.heterogeneous.syntax._\n" +
- "import com.drivergrp.core._\n" +
- "import com.drivergrp.core.database._\n"
- } else ""
- ) +
- ( if(tables.exists(_.PlainSqlMapper.enabled)){
- "import slick.jdbc.{GetResult => GR}\n"+
- "// NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n"
- } else ""
- ) + "\n\n" // We didn't copy ddl though
-
- val sortedSchemaTables: List[(String, Seq[TableDef])] = tables
- .groupBy(t => t.model.name.schema.getOrElse("`public`"))
- .toList.sortBy(_._1)
-
- val schemata: String = sortedSchemaTables.map {
- case (schemaName, tables) =>
- val tableCode = tables
- .sortBy(_.model.name.table)
- .map(_.code.mkString("\n"))
- .mkString("\n\n")
- indent(s"object $schemaName extends CoreDBMappers {\n$tableCode")+"\n}\n"
- }.mkString("\n\n")
-
- val mapperTrait: String = """trait CoreDBMappers extends com.drivergrp.core.database.IdColumnTypes { override val database = com.drivergrp.core.database.Database.fromConfig("slick.db.default") }"""
-
- imports + mapperTrait + "\n\n" + schemata
+ val fkColumns = compoundValue(referencingColumns.map(_.name))
+ val qualifier =
+ if (referencedTable.model.name.schema == referencingTable.model.name.schema) ""
+ else referencedTable.model.name.schema.fold("")(sname => s"$pkg.$sname.")
+
+ val qualifiedName = qualifier + referencedTable.TableValue.name
+ val pkColumns = compoundValue(referencedColumns.map(c => s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" else ""}"))
+ val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk"
+ s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=$onUpdate, onDelete=$onDelete)"""
}
+ }
- // This is overridden to output classfiles elsewhere
- override def Table = new Table(_) {
- table =>
- // case classes go in the typeFile (but not types based on hlists)
- override def definitions =
- if (typeFile) Seq[Def](EntityType)
- else Seq[Def](EntityTypeRef, PlainSqlMapper, TableClassRef, TableValue)
+ }
+}
- def EntityTypeRef = new EntityType() {
- override def code =
- s"type $name = $pkg.rows.${model.name.schema.get}.$name\n" ++
- s"val $name = $pkg.rows.${model.name.schema.get}.$name"
- }
+object SchemaParser {
+ def references(dbModel: Model, tcMappings: Map[(String, String), (String, String)]): Map[(String, String), (Table, Column)] = {
+ def getTableColumn(tc: (String, String)) : (Table, Column) = {
+ val (tableName, columnName) = tc
+ val table = dbModel.tables.find(_.name.asString == tableName)
+ .getOrElse(throw new RuntimeException("No table " + tableName))
+ val column = table.columns.find(_.name == columnName)
+ .getOrElse(throw new RuntimeException("No column " + columnName + " in table " + tableName))
+ (table, column)
+ }
- /** Creates a compound type from a given sequence of types.
- * Uses HList if hlistEnabled else tuple.
- */
- override def compoundType(types: Seq[String]): String =
- /** Creates a compound value from a given sequence of values.
- * Uses HList if hlistEnabled else tuple.
- */
- // Yes! This is part of Slick now, yes?
- if (hlistEnabled){
- def mkHList(types: List[String]): String = types match {
- case Nil => "HNil"
- case e :: tail => s"HCons[$e," + mkHList(tail) + "]"
- }
- mkHList(types.toList)
- }
- else compoundValue(types)
-
- //why?
- override def mappingEnabled = true
-
- override def compoundValue(values: Seq[String]): String =
- if (hlistEnabled) values.mkString(" :: ") + " :: HNil"
- else if (values.size == 1) values.head
- else if(values.size <= 22) s"""(${values.mkString(", ")})"""
- else throw new Exception("Cannot generate tuple for > 22 columns, please set hlistEnable=true or override compound.")
-
- def TableClassRef = new TableClass() {
- // We disable the option mapping for >22 columns, as it is a bit more complex to support and we don't appear to need it
- override def option = if(columns.size <= 22) super.option else ""
- }
+ tcMappings.map{case (from, to) => ({getTableColumn(from); from}, getTableColumn(to))}
+ }
- override def factory =
- if(columns.size <= 22) super.factory
- else {
- val args = columns.zipWithIndex.map("a"+_._2)
- val hlist = args.mkString("::") + ":: HNil"
- val hlistType = columns.map(_.actualType).mkString("::") + ":: HNil.type"
- s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})"
- }
- override def extractor =
- if(columns.size <= 22) super.extractor
- else s"(a : ${TableClass.elementType}) => Some(" + columns.map("a."+_.name ).mkString("::") + ":: HNil)"
-
- // make foreign keys refer to namespaced referents
- // if the referent is in a different namespace
- override def ForeignKey = new ForeignKey(_) {
- override def code = {
- val fkColumns = compoundValue(referencingColumns.map(_.name))
- // Add the schema name to qualify the referenced table name if:
- // 1. it's in a different schema from referencingTable, and
- // 2. it's not None
- val qualifier = if (referencedTable.model.name.schema
- != referencingTable.model.name.schema) {
- referencedTable.model.name.schema match {
- case Some(schema) => schema + "."
- case None => ""
- }
- } else {
- ""
- }
- val qualifiedName = qualifier + referencedTable.TableValue.name
- val pkColumns = compoundValue(referencedColumns.map(c => s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" else ""}"))
- val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk"
- s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=${onUpdate}, onDelete=${onDelete})"""
- }
- }
+ def parse(schemaTableNames: List[String]): Map[String, List[String]] =
+ schemaTableNames.map(_.split('.'))
+ .groupBy(_.head)
+ .mapValues(_.flatMap(_.tail))
- override def Column = new Column(_) { column =>
- // customize db type -> scala type mapping, pls adjust it according to your environment
-
- override def rawType = {
- val (t, c) = derefColumn(table.model, column.model)
- //System.out.print(s"${table.model.name.asString}:${column.model.name} -> ${t.name.asString}:${c.name}\n")
- if (c.options.exists(_.toString.contains("PrimaryKey"))) idType(t)
- // ^ahahaha This is hacky
- // This should be customizeable by client
-
- else model.tpe match {
- // how does this type work out?
- // There should be a way to add adhoc custom time mappings
- case "java.sql.Time" => "com.drivergrp.core.time.Time"
- case "java.sql.Timestamp" => "com.drivergrp.core.time.Time"
- case _ => super.rawType
- }
- }
- }
- }
- }
+ def createModel(jdbcProfile: JdbcProfile, mappedSchemas: Map[String, List[String]]): DBIO[Model] = {
+ val allTables: DBIO[Vector[MTable]] = MTable.getTables
- def write(c: String, name: String) = {
- (new File(name).getParentFile).mkdirs()
- val fw = new FileWriter(name)
- fw.write(c)
- fw.close()
- }
- val disableScalaStyle = "// scalastyle:off\n"
- val tablesSource = codegen(false).packageCode(slickDriver, pkg, "Tables", None)
- val rowsSource = s"package $pkg.rows\n\n" + codegen(true).code
+ val filteredTables: DBIO[Vector[MTable]] = allTables.map(
+ (tables: Vector[MTable]) => tables.filter(table =>
+ table.name.schema.flatMap(mappedSchemas.get).exists(ts =>
+ ts.isEmpty || ts.contains(table.name.name))
+ )
+ )
+ jdbcProfile.createModel(Some(filteredTables))
+ }
+
+}
- write(disableScalaStyle + tablesSource, filename)
- write(disableScalaStyle + rowsSource, typesFilename)
+
+object TypeGenerator {
+ // generate the id types
+ def idType(pkg: String, t: sModel.Table): String = {
+ val header = s"Id["
+ val schemaName = t.name.schema.fold("")(_ + ".")
+ val tableName = StringHelpers.toCamelCase(t.name.table)
+ val footer = "]"
+ s"${header}${pkg}.${schemaName}${tableName}Row${footer}"
}
}
+
+object FileHelpers {
+
+ def schemaOutputPath(path: String, schemaName: String): String =
+ Paths.get(path, s"${schemaName}.scala").toAbsolutePath().toString()
+
+ def createOutputPath(path: String): Boolean = (new File(path)).mkdirs()
+
+ def write(filename: String, content: String) = {
+ val file = new File(filename)
+ val bw = new BufferedWriter(new FileWriter(file))
+ bw.write(content)
+ bw.close()
+ }
+
+}
+
+object StringHelpers {
+ // copied from GeneratorHelpers.StringExtensions
+ final def toCamelCase(value: String) = value.toLowerCase
+ .split("_")
+ .map{ case "" => "_" case s => s } // avoid possible collisions caused by multiple '_'
+ .map(_.capitalize)
+ .mkString("")
+
+}