path: root/src/main/scala/NamespacedCodegen.scala
author    Stewart Stewart <stewinsalot@gmail.com>    2016-09-06 08:01:20 -0700
committer Stewart Stewart <stewinsalot@gmail.com>    2016-09-06 09:06:56 -0700
commit    b32724fc7ac3d45de3635c1a8602e509179716f7 (patch)
tree      9186a36012caa315b79bfdaed5a013972190b35b    /src/main/scala/NamespacedCodegen.scala
parent    c964bc59c65f31d3071f8c13e7182546b88ebee5 (diff)
Expect URI with dbconfig rather than hardcoding
Diffstat (limited to 'src/main/scala/NamespacedCodegen.scala')
-rw-r--r--  src/main/scala/NamespacedCodegen.scala  455
1 file changed, 232 insertions(+), 223 deletions(-)
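
The commit replaces the positional-argument main entry point with a run method driven by a Typesafe Config URI (DatabaseConfig.forURI), so connection details and the target package come from configuration instead of being hardcoded. Below is a minimal sketch of how a caller might invoke the new entry point; the config path, database name, file names, and schema list are illustrative assumptions, and only the codegen.package and codegen.outputDir keys actually appear in this diff.

// Hedged usage sketch (not part of the commit). Assumes a Slick 3.1-style
// DatabaseConfig entry in application.conf, e.g.:
//
//   mydb {
//     driver = "slick.driver.PostgresDriver$"
//     db { url = "jdbc:postgresql://localhost/test", user = "user", password = "pass" }
//     codegen.package   = "dbmodels"        // read by run via dc.config
//     codegen.outputDir = "src/main/scala"
//   }
//
// DatabaseConfig.forURI resolves the fragment after '#' as the config path.
import java.net.URI

object GenerateTables {
  def main(args: Array[String]): Unit =
    NamespacedCodegen.run(
      uri = new URI("#mydb"),                                    // hypothetical config path
      outputDir = None,                                          // fall back to codegen.outputDir
      filename = "src/main/scala/dbmodels/Tables.scala",         // hypothetical
      typesFilename = "src/main/scala/dbmodels/rows/Rows.scala", // hypothetical
      schemaList = "patients,portal.case_tumor_info"             // schemas and schema.table entries
    )
}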
diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala
index 7d5beac..1069c68 100644
--- a/src/main/scala/NamespacedCodegen.scala
+++ b/src/main/scala/NamespacedCodegen.scala
@@ -22,7 +22,7 @@ import java.io.FileWriter
// generator places the relevant generated classes into separate
// objects--an "a" object and a "b" object)
object NamespacedCodegen {
- def parseSchemaList(schemaList: String) = {
+ def parseSchemaList(schemaList: String): Map[String, List[String]] = {
val sl = schemaList.split(",").filter(_.trim.nonEmpty)
val tables: List[String] = sl.filter(_.contains(".")).toList
val schemas: List[String] = sl.filter(s => !(s.contains("."))).toList
@@ -33,254 +33,263 @@ object NamespacedCodegen {
}
mappedSchemas ++ mappedTables
-
}
- def main(args: Array[String]) = {
- args.toList match {
- case List(slickDriver, jdbcDriver, url, pkg, schemaList, fname, typesfname,
- user, password) => {
+ import slick.dbio.DBIO
+ import slick.model.Model
+
+ def createFilteredModel(driver: JdbcProfile, mappedSchemas: Map[String, List[String]]): DBIO[Model] =
+ driver.createModel(Some(
+ MTable.getTables.map(_.filter((t: MTable) => mappedSchemas
+ .get(t.name.schema.getOrElse(""))
+ .fold(false)(ts => ts.isEmpty || ts.contains(t.name.name))))))
+
+ val manualForeignKeys: Map[(String, String), (String, String)] =
+ Map(
+ ("portal.case_tumor_info", "patient_id") -> (("patients.patients", "patient_id")),
+ ("portal.case_tumor_info", "case_id") -> (("work_queues.reports", "case_id")),
+ ("portal.case_tumor_info", "cancer_id") -> (("patients.cancer", "cancer_id")),
+ ("confidential.join_pat", "patient_id") -> (("patients.patients", "patient_id")),
+ ("portal.case_tumor_info", "ordering_physician") -> (("patients.oncologists", "oncologist_id")),
+ ("patients.oncologists_case_permissions_view", "oncologist_id") -> (("patients.oncologists", "oncologist_id")),
+ ("patients.oncologists_case_permissions_view", "case_id") -> (("work_queues.reports", "case_id")),
+ ("case_accessioning.case_accessioning", "case_id") -> (("work_queues.reports", "case_id")),
+ ("case_accessioning.case_accessioning", "cancer_id") -> (("patients.cancer", "cancer_id")),
+ ("experiments.somatic_snvs_indels_filtered", "cancer_id") -> (("patients.cancer", "cancer_id")),
+ ("experiments.experiments", "case_id") -> (("work_queues.reports", "case_id")),
+ ("samples.samples", "case_id") -> (("work_queues.reports", "case_id")) // TODO: Several of these can be added in a PR on postgres.
+ )
+
+ def references(dbModel: Model, tcMappings: Map[(String, String), (String, String)]): Map[(String, String), (Table, Column)] = {
+ def getTableColumn(tc: (String, String)) : (Table, Column) = {
+ val (tableName, columnName) = tc
+ val table = dbModel.tables.find(_.name.asString == tableName)
+ .getOrElse(throw new RuntimeException("No table " + tableName))
+ val column = table.columns.find(_.name == columnName)
+ .getOrElse(throw new RuntimeException("No column " + columnName + " in table " + tableName))
+ (table, column)
+ }
- val driver: JdbcProfile = {
- val module = currentMirror.staticModule(slickDriver)
- val reflectedModule = currentMirror.reflectModule(module)
- val driver = reflectedModule.instance.asInstanceOf[JdbcProfile]
- driver
- }
+ tcMappings.map{case (from, to) => ({getTableColumn(from); from}, getTableColumn(to))}
+ }
- val schemas = schemaList.split(",").filter(_.trim.nonEmpty).toSet
- val mappedSchemas: Map[String, List[String]] = parseSchemaList(schemaList)
- implicit class BlockingRun(v: driver.api.Database) {
- def blockingRun[R](a: DBIOAction[R, NoStream, Nothing]) =
- Await.result(v.run(a), Duration.Inf)
+ import java.net.URI
+ import slick.backend.DatabaseConfig
+ import slick.util.ConfigExtensionMethods.configExtensionMethods
+
+ def run(
+ uri: URI,
+ outputDir: Option[String],
+ filename: String,
+ typesFilename: String,
+ schemaList: String
+ ): Unit = {
+ val dc = DatabaseConfig.forURI[JdbcProfile](uri)
+ val pkg = dc.config.getString("codegen.package")
+ val out = outputDir.getOrElse(dc.config.getStringOr("codegen.outputDir", "."))
+ val slickDriver = if(dc.driverIsObject) dc.driverName else "new " + dc.driverName
+
+ // The following three parameters are unique to our code generator
+ // TODO: Decide: Put these in Typesafe Config or make it part of plugin interface?
+ // val filename = dc.config.getString("codegen.filename")
+ // val typesFilename = dc.config.getString("codegen.typesFilename")
+ // val schemaList = dc.config.getString("codegen.schemaList")
+
+ val mappedSchemas = parseSchemaList(schemaList)
+ val dbModel = Await.result(dc.db.run(createFilteredModel(dc.driver, mappedSchemas)), Duration.Inf)
+ //finally dc.db.close
+
+ val manualReferences = references(dbModel, manualForeignKeys)
+
+ def codegen(typeFile: Boolean) = new SourceCodeGenerator(dbModel){
+ def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) =
+ (table.foreignKeys.toList
+ .filter(_.referencingColumns.forall(_ == column))
+ .flatMap(fk =>
+ fk.referencedColumns match {
+ case Seq(c) => dbModel.tablesByName.get(fk.referencedTable).map{(_, c)}
+ case _ => None
+ }) ++
+ manualReferences.get((table.name.asString, column.name)))
+ .headOption
+ .map((derefColumn _).tupled)
+ .getOrElse((table, column))
+
+ // Is this compatible with ***REMOVED*** Id? How do we make it generic?
+ def idType(t: m.Table) : String =
+ "Id[rows."+ t.name.schema.fold("")(_ + ".") + t.name.table.toCamelCase+"Row]"
+
+ override def code = {
+ //imports is copied right out of
+ //scala.slick.model.codegen.AbstractSourceCodeGenerator
+ // Why can't we simply re-use?
+
+ var imports =
+ // acyclic is unnecessary in generic projects
+ //if (typeFile) "import acyclic.file\nimport dbmodels.rows\n"
+ //else
+ "import slick.model.ForeignKeyAction\n" +
+ "import rows._\n" +
+ ( if(tables.exists(_.hlistEnabled)){
+ "import slick.collection.heterogeneous._\n"+
+ "import slick.collection.heterogeneous.syntax._\n"
+ } else ""
+ ) +
+ ( if(tables.exists(_.PlainSqlMapper.enabled)){
+ "import slick.jdbc.{GetResult => GR}\n"+
+ "// NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n"
+ } else ""
+ ) + "\n\n" // We didn't copy ddl though
+
+
+ val bySchema = tables.groupBy(t => {
+ t.model.name.schema
+ })
+
+ val schemaFor = (schema: String) => {
+ bySchema(Option(schema)).sortBy(_.model.name.table).map(
+ _.code.mkString("\n") // TODO explore here
+ ).mkString("\n\n")
}
- val dbModel = driver.api.Database
- .forURL(url, user, password, driver = jdbcDriver)
- .blockingRun(
- driver.createModel(Some(MTable.getTables.map(_.filter(
- (t: MTable) => mappedSchemas
- .get(t.name.schema.getOrElse(""))
- .fold(false)(ts => ts.isEmpty || ts.contains(t.name.name))
- )))))
-
- // TODO take this map in as a parameter if we want the generator code to independent of our project
- def getTableColumn(tc: (String, String)) : (Table, Column) = {
- val tb = dbModel.tables.find(_.name.asString == tc._1).getOrElse(throw new RuntimeException("No table " + tc._1))
- val co = tb.columns.find(_.name == tc._2).getOrElse(throw new RuntimeException("No column " + tc._2 + " in table " + tc._1))
- (tb, co)
- }
- def manualReferences : Map[(String, String), (Table, Column)] =
- Map(
- ("portal.case_tumor_info", "patient_id") -> ("patients.patients", "patient_id"),
- ("portal.case_tumor_info", "case_id") -> ("work_queues.reports", "case_id"),
- ("portal.case_tumor_info", "cancer_id") -> ("patients.cancer", "cancer_id"),
- ("confidential.join_pat", "patient_id") -> ("patients.patients", "patient_id"),
- ("portal.case_tumor_info", "ordering_physician") -> ("patients.oncologists", "oncologist_id"),
- ("patients.oncologists_case_permissions_view", "oncologist_id") -> ("patients.oncologists", "oncologist_id"),
- ("patients.oncologists_case_permissions_view", "case_id") -> ("work_queues.reports", "case_id"),
- ("case_accessioning.case_accessioning", "case_id") -> ("work_queues.reports", "case_id"),
- ("case_accessioning.case_accessioning", "cancer_id") -> ("patients.cancer", "cancer_id"),
- ("experiments.somatic_snvs_indels_filtered", "cancer_id") -> ("patients.cancer", "cancer_id"),
- ("experiments.experiments", "case_id") -> ("work_queues.reports", "case_id"),
- ("samples.samples", "case_id") -> ("work_queues.reports", "case_id") // remove me when the foreign keys are fixed
- ).map{case (from, to) => ({getTableColumn(from); from}, getTableColumn(to))}
-
- def codegen(typeFile: Boolean) = new SourceCodeGenerator(dbModel){
- def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) =
- (table.foreignKeys.toList
- .filter(_.referencingColumns.forall(_ == column))
- .flatMap(fk =>
- fk.referencedColumns match {
- case Seq(c) => dbModel.tablesByName.get(fk.referencedTable).map{(_, c)}
- case _ => None
- }) ++
- manualReferences.get((table.name.asString, column.name)))
- .headOption
- .map((derefColumn _).tupled)
- .getOrElse((table, column))
-
- def idType(t: m.Table) : String =
- "Id[rows."+ t.name.schema.fold("")(_ + ".") + t.name.table.toCamelCase+"Row]"
-
- override def code = {
- //imports is copied right out of
- //scala.slick.model.codegen.AbstractSourceCodeGenerator
-
- var imports =
- if (typeFile) "import acyclic.file\nimport dbmodels.rows\n"
- else
- "import slick.model.ForeignKeyAction\n" +
- "import rows._\n" +
- ( if(tables.exists(_.hlistEnabled)){
- "import slick.collection.heterogeneous._\n"+
- "import slick.collection.heterogeneous.syntax._\n"
- } else ""
- ) +
- ( if(tables.exists(_.PlainSqlMapper.enabled)){
- "import slick.jdbc.{GetResult => GR}\n"+
- "// NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n"
- } else ""
- ) + "\n\n"
-
- val bySchema = tables.groupBy(t => {
- t.model.name.schema
- })
-
- val schemaFor = (schema: String) => {
- bySchema(Option(schema)).sortBy(_.model.name.table).map(
- _.code.mkString("\n")
- ).mkString("\n\n")
- }
-
- val schemata = mappedSchemas.keys.toList.sorted.map(
- s => indent("object" + " " + s + " {\n" + schemaFor(s)) + "\n}\n"
- ).mkString("\n\n")
-
- val idType =
- if (typeFile)
- """|case class Id[T](v: Int)
- |""".stripMargin
- else
- """|implicit def idTypeMapper[A] : BaseColumnType[Id[A]] =
- | MappedColumnType.base[Id[A], Int](_.v, Id(_))
- |import play.api.mvc.PathBindable
- |implicit def idPathBindable[A] : PathBindable[Id[A]] = implicitly[PathBindable[Int]].transform[Id[A]](Id(_),_.v)
- |""".stripMargin
-
-
- imports + idType + schemata
- }
+ val schemata = mappedSchemas.keys.toList.sorted.map(
+ s => indent("object" + " " + s + " {\n" + schemaFor(s)) + "\n}\n"
+ ).mkString("\n\n")
+
+ val idType =
+ if (typeFile) // Should not be defined here.
+ """|case class Id[T](v: Int)
+ |""".stripMargin
+ else
+ // This should be in a separate Implicits trait
+ """|implicit def idTypeMapper[A] : BaseColumnType[Id[A]] =
+ | MappedColumnType.base[Id[A], Int](_.v, Id(_))
+ |import play.api.mvc.PathBindable
+ |implicit def idPathBindable[A] : PathBindable[Id[A]] = implicitly[PathBindable[Int]].transform[Id[A]](Id(_),_.v)
+ |""".stripMargin
+ //pathbindable is play specific
+ // Id works only with labdash Id
+ imports + idType + schemata
+ }
- override def Table = new Table(_) {
- table =>
- // case classes go in the typeFile (but not types based on hlists)
- override def definitions =
- if (typeFile) Seq[Def]( EntityType )
- else Seq[Def](EntityTypeRef, PlainSqlMapper, TableClassRef, TableValue )
+ // This is overridden to output classfiles elsewhere
+ override def Table = new Table(_) {
+ table =>
+ // case classes go in the typeFile (but not types based on hlists)
+ override def definitions =
+ if (typeFile) Seq[Def](EntityType)
+ else Seq[Def](EntityTypeRef, PlainSqlMapper, TableClassRef, TableValue)
- def EntityTypeRef = new EntityType() {
- override def code =
- s"type $name = $pkg.rows.${model.name.schema.get}.$name\n" ++
- s"val $name = $pkg.rows.${model.name.schema.get}.$name"
- }
+ def EntityTypeRef = new EntityType() {
+ override def code =
+ s"type $name = $pkg.rows.${model.name.schema.get}.$name\n" ++
+ s"val $name = $pkg.rows.${model.name.schema.get}.$name"
+ }
- /** Creates a compound type from a given sequence of types.
- * Uses HList if hlistEnabled else tuple.
- */
- override def compoundType(types: Seq[String]): String =
- /** Creates a compound value from a given sequence of values.
- * Uses HList if hlistEnabled else tuple.
- */
- if (hlistEnabled){
- def mkHList(types: List[String]): String = types match {
- case Nil => "HNil"
- case e :: tail => s"HCons[$e," + mkHList(tail) + "]"
- }
- mkHList(types.toList)
+ /** Creates a compound type from a given sequence of types.
+ * Uses HList if hlistEnabled else tuple.
+ */
+ override def compoundType(types: Seq[String]): String =
+ /** Creates a compound value from a given sequence of values.
+ * Uses HList if hlistEnabled else tuple.
+ */
+ // Yes! This is part of Slick now, yes?
+ if (hlistEnabled){
+ def mkHList(types: List[String]): String = types match {
+ case Nil => "HNil"
+ case e :: tail => s"HCons[$e," + mkHList(tail) + "]"
}
- else compoundValue(types)
+ mkHList(types.toList)
+ }
+ else compoundValue(types)
- override def mappingEnabled = true
+ //why?
+ override def mappingEnabled = true
- override def compoundValue(values: Seq[String]): String =
- if (hlistEnabled) values.mkString(" :: ") + " :: HNil"
- else if (values.size == 1) values.head
- else if(values.size <= 22) s"""(${values.mkString(", ")})"""
- else throw new Exception("Cannot generate tuple for > 22 columns, please set hlistEnable=true or override compound.")
+ override def compoundValue(values: Seq[String]): String =
+ if (hlistEnabled) values.mkString(" :: ") + " :: HNil"
+ else if (values.size == 1) values.head
+ else if(values.size <= 22) s"""(${values.mkString(", ")})"""
+ else throw new Exception("Cannot generate tuple for > 22 columns, please set hlistEnable=true or override compound.")
- def TableClassRef = new TableClass() {
- // We disable the option mapping for >22 columns, as it is a bit more complex to support and we don't appear to need it
- override def option = if(columns.size <= 22) super.option else ""
- }
+ def TableClassRef = new TableClass() {
+ // We disable the option mapping for >22 columns, as it is a bit more complex to support and we don't appear to need it
+ override def option = if(columns.size <= 22) super.option else ""
+ }
- override def factory =
- if(columns.size <= 22) super.factory
- else {
- val args = columns.zipWithIndex.map("a"+_._2)
- val hlist = args.mkString("::") + ":: HNil"
- val hlistType = columns.map(_.actualType).mkString("::") + ":: HNil.type"
- s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})"
- }
- override def extractor =
- if(columns.size <= 22) super.extractor
- else s"(a : ${TableClass.elementType}) => Some(" + columns.map("a."+_.name ).mkString("::") + ":: HNil)"
-
- // make foreign keys refer to namespaced referents
- // if the referent is in a different namespace
- override def ForeignKey = new ForeignKey(_) {
- override def code = {
- val fkColumns = compoundValue(referencingColumns.map(_.name))
- // Add the schema name to qualify the referenced table name if:
- // 1. it's in a different schema from referencingTable, and
- // 2. it's not None
- val qualifier = if (referencedTable.model.name.schema
- != referencingTable.model.name.schema) {
- referencedTable.model.name.schema match {
- case Some(schema) => schema + "."
- case None => ""
- }
- } else {
- ""
- }
- val qualifiedName = qualifier + referencedTable.TableValue.name
- val pkColumns = compoundValue(referencedColumns.map(c => s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" else ""}"))
- val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk"
- s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=${onUpdate}, onDelete=${onDelete})"""
- }
- }
+ override def factory =
+ if(columns.size <= 22) super.factory
+ else {
+ val args = columns.zipWithIndex.map("a"+_._2)
+ val hlist = args.mkString("::") + ":: HNil"
+ val hlistType = columns.map(_.actualType).mkString("::") + ":: HNil.type"
+ s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})"
+ }
+ override def extractor =
+ if(columns.size <= 22) super.extractor
+ else s"(a : ${TableClass.elementType}) => Some(" + columns.map("a."+_.name ).mkString("::") + ":: HNil)"
- override def Column = new Column(_) { column =>
- // customize db type -> scala type mapping, pls adjust it according to your environment
-
- override def rawType = {
- val (t, c) = derefColumn(table.model, column.model)
- //System.out.print(s"${table.model.name.asString}:${column.model.name} -> ${t.name.asString}:${c.name}\n")
- if (c.options.exists(_.toString.contains("PrimaryKey"))) idType(t)
- else model.tpe match {
- case "java.sql.Date" => "tools.Date"
- case "java.sql.Time" => "tools.Time"
- case "java.sql.Timestamp" => "tools.Time"
- case _ => super.rawType
- }
+ // make foreign keys refer to namespaced referents
+ // if the referent is in a different namespace
+ override def ForeignKey = new ForeignKey(_) {
+ override def code = {
+ val fkColumns = compoundValue(referencingColumns.map(_.name))
+ // Add the schema name to qualify the referenced table name if:
+ // 1. it's in a different schema from referencingTable, and
+ // 2. it's not None
+ val qualifier = if (referencedTable.model.name.schema
+ != referencingTable.model.name.schema) {
+ referencedTable.model.name.schema match {
+ case Some(schema) => schema + "."
+ case None => ""
}
+ } else {
+ ""
}
+ val qualifiedName = qualifier + referencedTable.TableValue.name
+ val pkColumns = compoundValue(referencedColumns.map(c => s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" else ""}"))
+ val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk"
+ s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=${onUpdate}, onDelete=${onDelete})"""
}
}
- def write(c: String, name: String) = {
- (new File(name).getParentFile).mkdirs()
- val fw = new FileWriter(name)
- fw.write(c)
- fw.close()
+ override def Column = new Column(_) { column =>
+ // customize db type -> scala type mapping, pls adjust it according to your environment
+
+ override def rawType = {
+ val (t, c) = derefColumn(table.model, column.model)
+ //System.out.print(s"${table.model.name.asString}:${column.model.name} -> ${t.name.asString}:${c.name}\n")
+ if (c.options.exists(_.toString.contains("PrimaryKey"))) idType(t)
+ // ^ahahaha This is hacky
+ // This should be customizeable by client
+
+ else model.tpe match {
+ // how does this type work out?
+ // There should be a way to add adhoc custom time mappings
+ case "java.sql.Date" => "tools.Date"
+ case "java.sql.Time" => "tools.Time"
+ case "java.sql.Timestamp" => "tools.Time"
+ case _ => super.rawType
+ }
+ }
}
- val disableScalariform = "// format: OFF\n"
- val tablesSource = codegen(false).packageCode(slickDriver, pkg, "Tables", None)
- val rowsSource = s"package $pkg.rows\n\n" + codegen(true).code
-
- write(disableScalariform + tablesSource, fname)
- write(disableScalariform + rowsSource, typesfname)
}
- case _ => {
- println("""
-Usage: NamespacedCodegen.main(Array( slickDriver, jdbcDriver, url, pkg, schemaList, fileName ))
-
-slickDriver: Fully qualified name of Slick driver class, e.g. 'scala.slick.driver.PostgresDriver'
-
-jdbcDriver: Fully qualified name of jdbc driver class, e.g. 'org.postgresql.Driver'
-
-url: jdbc url, e.g. 'jdbc:postgresql://localhost/test'
-
-pkg: Scala package the generated code should be placed in
-
-schemaList: string with comma-separated list of schemas to include
+ }
-fName: name of output file
-""".trim)
- }
+ def write(c: String, name: String) = {
+ (new File(name).getParentFile).mkdirs()
+ val fw = new FileWriter(name)
+ fw.write(c)
+ fw.close()
}
+ val disableScalariform = "// format: OFF\n"
+ val tablesSource = codegen(false).packageCode(slickDriver, pkg, "Tables", None)
+ val rowsSource = s"package $pkg.rows\n\n" + codegen(true).code
+
+ write(disableScalariform + tablesSource, filename)
+ write(disableScalariform + rowsSource, typesFilename)
}
}
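
For reference, a sketch of the schemaList format that run forwards to parseSchemaList, inferred from the filter in createFilteredModel (an empty whitelist keeps every table in a schema, a non-empty one keeps only the listed tables). The expected result is an assumption drawn from that logic, not captured plugin output.

// Illustrative only: bare entries name whole schemas, dotted entries whitelist single tables.
val mapped: Map[String, List[String]] =
  NamespacedCodegen.parseSchemaList("patients,portal.case_tumor_info")
// Expected shape (assuming bare schemas map to an empty whitelist):
//   Map("patients" -> List(), "portal" -> List("case_tumor_info"))
// createFilteredModel then keeps every table in "patients" and only case_tumor_info in "portal".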