From b32724fc7ac3d45de3635c1a8602e509179716f7 Mon Sep 17 00:00:00 2001 From: Stewart Stewart Date: Tue, 6 Sep 2016 08:01:20 -0700 Subject: Expect URI with dbconfig rather than hardcoding --- src/main/scala/CodegenPlugin.scala | 14 +- src/main/scala/NamespacedCodegen.scala | 455 +++++++++++++++++---------------- 2 files changed, 238 insertions(+), 231 deletions(-) (limited to 'src/main') diff --git a/src/main/scala/CodegenPlugin.scala b/src/main/scala/CodegenPlugin.scala index b565e8a..3a09585 100644 --- a/src/main/scala/CodegenPlugin.scala +++ b/src/main/scala/CodegenPlugin.scala @@ -10,19 +10,17 @@ object CodegenPlugin extends AutoPlugin { dependencyClasspath in Compile, runner in Compile, streams) map { (dir, cp, r, s) => - val url = "jdbc:postgresql://postgres/ctig" - val jdbcDriver = "org.postgresql.Driver" - val slickDriver = "slick.driver.PostgresDriver" + // TODO Move this block into application.conf#slick.db.default.codegen val pkg = "dbmodels" val outputDir = (dir / "app" / pkg).getPath val fname = outputDir + "/Tables.scala" - // TODO: typesfname should be a parameter val typesfname = (file("shared") / "src" / "main" / "scala" / pkg / "rows" / "TableTypes.scala").getPath val schemas = "patients,portal,work_queues,confidential,case_accessioning,samples.samples,samples.subsamples,samples.shipment_preps,samples.collection_methods,experiments.experiments,experiments.exp_types,experiments.somatic_snvs_indels_filtered,samples.basic_diagnosis,samples.molecular_tests,samples.sample_pathology,samples.path_molecular_tests" - val user = "ctig_portal" - val password = "coolnurseconspiracyhandbook" - codegen.NamespacedCodegen.main( - Array( slickDriver, jdbcDriver, url, pkg, schemas, fname, typesfname, user, password)) + + val uri = new java.net.URI("#slick.db.default") + + codegen.NamespacedCodegen.run(uri, Some(outputDir), fname, typesfname, schemas) + Seq(file(fname)) } } diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala index 7d5beac..1069c68 100644 --- a/src/main/scala/NamespacedCodegen.scala +++ b/src/main/scala/NamespacedCodegen.scala @@ -22,7 +22,7 @@ import java.io.FileWriter // generator places the relevant generated classes into separate // objects--a "a" object, and a "b" object) object NamespacedCodegen { - def parseSchemaList(schemaList: String) = { + def parseSchemaList(schemaList: String): Map[String, List[String]] = { val sl = schemaList.split(",").filter(_.trim.nonEmpty) val tables: List[String] = sl.filter(_.contains(".")).toList val schemas: List[String] = sl.filter(s => !(s.contains("."))).toList @@ -33,254 +33,263 @@ object NamespacedCodegen { } mappedSchemas ++ mappedTables - } - def main(args: Array[String]) = { - args.toList match { - case List(slickDriver, jdbcDriver, url, pkg, schemaList, fname, typesfname, - user, password) => { + import slick.dbio.DBIO + import slick.model.Model + + def createFilteredModel(driver: JdbcProfile, mappedSchemas: Map[String, List[String]]): DBIO[Model] = + driver.createModel(Some( + MTable.getTables.map(_.filter((t: MTable) => mappedSchemas + .get(t.name.schema.getOrElse("")) + .fold(false)(ts => ts.isEmpty || ts.contains(t.name.name)))))) + + val manualForeignKeys: Map[(String, String), (String, String)] = + Map( + ("portal.case_tumor_info", "patient_id") -> (("patients.patients", "patient_id")), + ("portal.case_tumor_info", "case_id") -> (("work_queues.reports", "case_id")), + ("portal.case_tumor_info", "cancer_id") -> (("patients.cancer", "cancer_id")), + ("confidential.join_pat", "patient_id") -> (("patients.patients", "patient_id")), + ("portal.case_tumor_info", "ordering_physician") -> (("patients.oncologists", "oncologist_id")), + ("patients.oncologists_case_permissions_view", "oncologist_id") -> (("patients.oncologists", "oncologist_id")), + ("patients.oncologists_case_permissions_view", "case_id") -> (("work_queues.reports", "case_id")), + ("case_accessioning.case_accessioning", "case_id") -> (("work_queues.reports", "case_id")), + ("case_accessioning.case_accessioning", "cancer_id") -> (("patients.cancer", "cancer_id")), + ("experiments.somatic_snvs_indels_filtered", "cancer_id") -> (("patients.cancer", "cancer_id")), + ("experiments.experiments", "case_id") -> (("work_queues.reports", "case_id")), + ("samples.samples", "case_id") -> (("work_queues.reports", "case_id")) // TODO: Several of these can be added in a PR on postgres. + ) + + def references(dbModel: Model, tcMappings: Map[(String, String), (String, String)]): Map[(String, String), (Table, Column)] = { + def getTableColumn(tc: (String, String)) : (Table, Column) = { + val (tableName, columnName) = tc + val table = dbModel.tables.find(_.name.asString == tableName) + .getOrElse(throw new RuntimeException("No table " + tableName)) + val column = table.columns.find(_.name == columnName) + .getOrElse(throw new RuntimeException("No column " + columnName + " in table " + tableName)) + (table, column) + } - val driver: JdbcProfile = { - val module = currentMirror.staticModule(slickDriver) - val reflectedModule = currentMirror.reflectModule(module) - val driver = reflectedModule.instance.asInstanceOf[JdbcProfile] - driver - } + tcMappings.map{case (from, to) => ({getTableColumn(from); from}, getTableColumn(to))} + } - val schemas = schemaList.split(",").filter(_.trim.nonEmpty).toSet - val mappedSchemas: Map[String, List[String]] = parseSchemaList(schemaList) - implicit class BlockingRun(v: driver.api.Database) { - def blockingRun[R](a: DBIOAction[R, NoStream, Nothing]) = - Await.result(v.run(a), Duration.Inf) + import java.net.URI + import slick.backend.DatabaseConfig + import slick.util.ConfigExtensionMethods.configExtensionMethods + + def run( + uri: URI, + outputDir: Option[String], + filename: String, + typesFilename: String, + schemaList: String + ): Unit = { + val dc = DatabaseConfig.forURI[JdbcProfile](uri) + val pkg = dc.config.getString("codegen.package") + val out = outputDir.getOrElse(dc.config.getStringOr("codegen.outputDir", ".")) + val slickDriver = if(dc.driverIsObject) dc.driverName else "new " + dc.driverName + + // The following three parameters are unique to our code generator + // TODO: Decide: Put these in Typsafe Config or make it part of plugin interface? + // val filename = dc.config.getString("codegen.filename") + // val typesFilename = dc.config.getString("codegen.typesFilename") + // val schemaList = dc.config.getString("codegen.schemaList") + + val mappedSchemas = parseSchemaList(schemaList) + val dbModel = Await.result(dc.db.run(createFilteredModel(dc.driver, mappedSchemas)), Duration.Inf) + //finally dc.db.close + + val manualReferences = references(dbModel, manualForeignKeys) + + def codegen(typeFile: Boolean) = new SourceCodeGenerator(dbModel){ + def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = + (table.foreignKeys.toList + .filter(_.referencingColumns.forall(_ == column)) + .flatMap(fk => + fk.referencedColumns match { + case Seq(c) => dbModel.tablesByName.get(fk.referencedTable).map{(_, c)} + case _ => None + }) ++ + manualReferences.get((table.name.asString, column.name))) + .headOption + .map((derefColumn _).tupled) + .getOrElse((table, column)) + + // Is this compatible with ***REMOVED*** Id? How do we make it generic? + def idType(t: m.Table) : String = + "Id[rows."+ t.name.schema.fold("")(_ + ".") + t.name.table.toCamelCase+"Row]" + + override def code = { + //imports is copied right out of + //scala.slick.model.codegen.AbstractSourceCodeGenerator + // Why can't we simply re-use? + + var imports = + // acyclic is unnecessary in generic projects + //if (typeFile) "import acyclic.file\nimport dbmodels.rows\n" + //else + "import slick.model.ForeignKeyAction\n" + + "import rows._\n" + + ( if(tables.exists(_.hlistEnabled)){ + "import slick.collection.heterogeneous._\n"+ + "import slick.collection.heterogeneous.syntax._\n" + } else "" + ) + + ( if(tables.exists(_.PlainSqlMapper.enabled)){ + "import slick.jdbc.{GetResult => GR}\n"+ + "// NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n" + } else "" + ) + "\n\n" // We didn't copy ddl though + + + val bySchema = tables.groupBy(t => { + t.model.name.schema + }) + + val schemaFor = (schema: String) => { + bySchema(Option(schema)).sortBy(_.model.name.table).map( + _.code.mkString("\n") // TODO explore here + ).mkString("\n\n") } - val dbModel = driver.api.Database - .forURL(url, user, password, driver = jdbcDriver) - .blockingRun( - driver.createModel(Some(MTable.getTables.map(_.filter( - (t: MTable) => mappedSchemas - .get(t.name.schema.getOrElse("")) - .fold(false)(ts => ts.isEmpty || ts.contains(t.name.name)) - ))))) - - // TODO take this map in as a parameter if we want the generator code to independent of our project - def getTableColumn(tc: (String, String)) : (Table, Column) = { - val tb = dbModel.tables.find(_.name.asString == tc._1).getOrElse(throw new RuntimeException("No table " + tc._1)) - val co = tb.columns.find(_.name == tc._2).getOrElse(throw new RuntimeException("No column " + tc._2 + " in table " + tc._1)) - (tb, co) - } - def manualReferences : Map[(String, String), (Table, Column)] = - Map( - ("portal.case_tumor_info", "patient_id") -> ("patients.patients", "patient_id"), - ("portal.case_tumor_info", "case_id") -> ("work_queues.reports", "case_id"), - ("portal.case_tumor_info", "cancer_id") -> ("patients.cancer", "cancer_id"), - ("confidential.join_pat", "patient_id") -> ("patients.patients", "patient_id"), - ("portal.case_tumor_info", "ordering_physician") -> ("patients.oncologists", "oncologist_id"), - ("patients.oncologists_case_permissions_view", "oncologist_id") -> ("patients.oncologists", "oncologist_id"), - ("patients.oncologists_case_permissions_view", "case_id") -> ("work_queues.reports", "case_id"), - ("case_accessioning.case_accessioning", "case_id") -> ("work_queues.reports", "case_id"), - ("case_accessioning.case_accessioning", "cancer_id") -> ("patients.cancer", "cancer_id"), - ("experiments.somatic_snvs_indels_filtered", "cancer_id") -> ("patients.cancer", "cancer_id"), - ("experiments.experiments", "case_id") -> ("work_queues.reports", "case_id"), - ("samples.samples", "case_id") -> ("work_queues.reports", "case_id") // remove me when the foreign keys are fixed - ).map{case (from, to) => ({getTableColumn(from); from}, getTableColumn(to))} - - def codegen(typeFile: Boolean) = new SourceCodeGenerator(dbModel){ - def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = - (table.foreignKeys.toList - .filter(_.referencingColumns.forall(_ == column)) - .flatMap(fk => - fk.referencedColumns match { - case Seq(c) => dbModel.tablesByName.get(fk.referencedTable).map{(_, c)} - case _ => None - }) ++ - manualReferences.get((table.name.asString, column.name))) - .headOption - .map((derefColumn _).tupled) - .getOrElse((table, column)) - - def idType(t: m.Table) : String = - "Id[rows."+ t.name.schema.fold("")(_ + ".") + t.name.table.toCamelCase+"Row]" - - override def code = { - //imports is copied right out of - //scala.slick.model.codegen.AbstractSourceCodeGenerator - - var imports = - if (typeFile) "import acyclic.file\nimport dbmodels.rows\n" - else - "import slick.model.ForeignKeyAction\n" + - "import rows._\n" + - ( if(tables.exists(_.hlistEnabled)){ - "import slick.collection.heterogeneous._\n"+ - "import slick.collection.heterogeneous.syntax._\n" - } else "" - ) + - ( if(tables.exists(_.PlainSqlMapper.enabled)){ - "import slick.jdbc.{GetResult => GR}\n"+ - "// NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n" - } else "" - ) + "\n\n" - - val bySchema = tables.groupBy(t => { - t.model.name.schema - }) - - val schemaFor = (schema: String) => { - bySchema(Option(schema)).sortBy(_.model.name.table).map( - _.code.mkString("\n") - ).mkString("\n\n") - } - - val schemata = mappedSchemas.keys.toList.sorted.map( - s => indent("object" + " " + s + " {\n" + schemaFor(s)) + "\n}\n" - ).mkString("\n\n") - - val idType = - if (typeFile) - """|case class Id[T](v: Int) - |""".stripMargin - else - """|implicit def idTypeMapper[A] : BaseColumnType[Id[A]] = - | MappedColumnType.base[Id[A], Int](_.v, Id(_)) - |import play.api.mvc.PathBindable - |implicit def idPathBindable[A] : PathBindable[Id[A]] = implicitly[PathBindable[Int]].transform[Id[A]](Id(_),_.v) - |""".stripMargin - - - imports + idType + schemata - } + val schemata = mappedSchemas.keys.toList.sorted.map( + s => indent("object" + " " + s + " {\n" + schemaFor(s)) + "\n}\n" + ).mkString("\n\n") + + val idType = + if (typeFile)// Should not be defined here. + """|case class Id[T](v: Int) + |""".stripMargin + else + // This should be in a separate Implicits trait + """|implicit def idTypeMapper[A] : BaseColumnType[Id[A]] = + | MappedColumnType.base[Id[A], Int](_.v, Id(_)) + |import play.api.mvc.PathBindable + |implicit def idPathBindable[A] : PathBindable[Id[A]] = implicitly[PathBindable[Int]].transform[Id[A]](Id(_),_.v) + |""".stripMargin + //pathbindable is play specific + // Id works only with labdash Id + imports + idType + schemata + } - override def Table = new Table(_) { - table => - // case classes go in the typeFile (but not types based on hlists) - override def definitions = - if (typeFile) Seq[Def]( EntityType ) - else Seq[Def](EntityTypeRef, PlainSqlMapper, TableClassRef, TableValue ) + // This is overridden to output classfiles elsewhere + override def Table = new Table(_) { + table => + // case classes go in the typeFile (but not types based on hlists) + override def definitions = + if (typeFile) Seq[Def](EntityType) + else Seq[Def](EntityTypeRef, PlainSqlMapper, TableClassRef, TableValue) - def EntityTypeRef = new EntityType() { - override def code = - s"type $name = $pkg.rows.${model.name.schema.get}.$name\n" ++ - s"val $name = $pkg.rows.${model.name.schema.get}.$name" - } + def EntityTypeRef = new EntityType() { + override def code = + s"type $name = $pkg.rows.${model.name.schema.get}.$name\n" ++ + s"val $name = $pkg.rows.${model.name.schema.get}.$name" + } - /** Creates a compound type from a given sequence of types. - * Uses HList if hlistEnabled else tuple. - */ - override def compoundType(types: Seq[String]): String = - /** Creates a compound value from a given sequence of values. - * Uses HList if hlistEnabled else tuple. - */ - if (hlistEnabled){ - def mkHList(types: List[String]): String = types match { - case Nil => "HNil" - case e :: tail => s"HCons[$e," + mkHList(tail) + "]" - } - mkHList(types.toList) + /** Creates a compound type from a given sequence of types. + * Uses HList if hlistEnabled else tuple. + */ + override def compoundType(types: Seq[String]): String = + /** Creates a compound value from a given sequence of values. + * Uses HList if hlistEnabled else tuple. + */ + // Yes! This is part of Slick now, yes? + if (hlistEnabled){ + def mkHList(types: List[String]): String = types match { + case Nil => "HNil" + case e :: tail => s"HCons[$e," + mkHList(tail) + "]" } - else compoundValue(types) + mkHList(types.toList) + } + else compoundValue(types) - override def mappingEnabled = true + //why? + override def mappingEnabled = true - override def compoundValue(values: Seq[String]): String = - if (hlistEnabled) values.mkString(" :: ") + " :: HNil" - else if (values.size == 1) values.head - else if(values.size <= 22) s"""(${values.mkString(", ")})""" - else throw new Exception("Cannot generate tuple for > 22 columns, please set hlistEnable=true or override compound.") + override def compoundValue(values: Seq[String]): String = + if (hlistEnabled) values.mkString(" :: ") + " :: HNil" + else if (values.size == 1) values.head + else if(values.size <= 22) s"""(${values.mkString(", ")})""" + else throw new Exception("Cannot generate tuple for > 22 columns, please set hlistEnable=true or override compound.") - def TableClassRef = new TableClass() { - // We disable the option mapping for >22 columns, as it is a bit more complex to support and we don't appear to need it - override def option = if(columns.size <= 22) super.option else "" - } + def TableClassRef = new TableClass() { + // We disable the option mapping for >22 columns, as it is a bit more complex to support and we don't appear to need it + override def option = if(columns.size <= 22) super.option else "" + } - override def factory = - if(columns.size <= 22) super.factory - else { - val args = columns.zipWithIndex.map("a"+_._2) - val hlist = args.mkString("::") + ":: HNil" - val hlistType = columns.map(_.actualType).mkString("::") + ":: HNil.type" - s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})" - } - override def extractor = - if(columns.size <= 22) super.extractor - else s"(a : ${TableClass.elementType}) => Some(" + columns.map("a."+_.name ).mkString("::") + ":: HNil)" - - // make foreign keys refer to namespaced referents - // if the referent is in a different namespace - override def ForeignKey = new ForeignKey(_) { - override def code = { - val fkColumns = compoundValue(referencingColumns.map(_.name)) - // Add the schema name to qualify the referenced table name if: - // 1. it's in a different schema from referencingTable, and - // 2. it's not None - val qualifier = if (referencedTable.model.name.schema - != referencingTable.model.name.schema) { - referencedTable.model.name.schema match { - case Some(schema) => schema + "." - case None => "" - } - } else { - "" - } - val qualifiedName = qualifier + referencedTable.TableValue.name - val pkColumns = compoundValue(referencedColumns.map(c => s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" else ""}")) - val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk" - s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=${onUpdate}, onDelete=${onDelete})""" - } - } + override def factory = + if(columns.size <= 22) super.factory + else { + val args = columns.zipWithIndex.map("a"+_._2) + val hlist = args.mkString("::") + ":: HNil" + val hlistType = columns.map(_.actualType).mkString("::") + ":: HNil.type" + s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})" + } + override def extractor = + if(columns.size <= 22) super.extractor + else s"(a : ${TableClass.elementType}) => Some(" + columns.map("a."+_.name ).mkString("::") + ":: HNil)" - override def Column = new Column(_) { column => - // customize db type -> scala type mapping, pls adjust it according to your environment - - override def rawType = { - val (t, c) = derefColumn(table.model, column.model) - //System.out.print(s"${table.model.name.asString}:${column.model.name} -> ${t.name.asString}:${c.name}\n") - if (c.options.exists(_.toString.contains("PrimaryKey"))) idType(t) - else model.tpe match { - case "java.sql.Date" => "tools.Date" - case "java.sql.Time" => "tools.Time" - case "java.sql.Timestamp" => "tools.Time" - case _ => super.rawType - } + // make foreign keys refer to namespaced referents + // if the referent is in a different namespace + override def ForeignKey = new ForeignKey(_) { + override def code = { + val fkColumns = compoundValue(referencingColumns.map(_.name)) + // Add the schema name to qualify the referenced table name if: + // 1. it's in a different schema from referencingTable, and + // 2. it's not None + val qualifier = if (referencedTable.model.name.schema + != referencingTable.model.name.schema) { + referencedTable.model.name.schema match { + case Some(schema) => schema + "." + case None => "" } + } else { + "" } + val qualifiedName = qualifier + referencedTable.TableValue.name + val pkColumns = compoundValue(referencedColumns.map(c => s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" else ""}")) + val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk" + s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=${onUpdate}, onDelete=${onDelete})""" } } - def write(c: String, name: String) = { - (new File(name).getParentFile).mkdirs() - val fw = new FileWriter(name) - fw.write(c) - fw.close() + override def Column = new Column(_) { column => + // customize db type -> scala type mapping, pls adjust it according to your environment + + override def rawType = { + val (t, c) = derefColumn(table.model, column.model) + //System.out.print(s"${table.model.name.asString}:${column.model.name} -> ${t.name.asString}:${c.name}\n") + if (c.options.exists(_.toString.contains("PrimaryKey"))) idType(t) + // ^ahahaha This is hacky + // This should be customizeable by client + + else model.tpe match { + // how does this type work out? + // There should be a way to add adhoc custom time mappings + case "java.sql.Date" => "tools.Date" + case "java.sql.Time" => "tools.Time" + case "java.sql.Timestamp" => "tools.Time" + case _ => super.rawType + } + } } - val disableScalariform = "// format: OFF\n" - val tablesSource = codegen(false).packageCode(slickDriver, pkg, "Tables", None) - val rowsSource = s"package $pkg.rows\n\n" + codegen(true).code - - write(disableScalariform + tablesSource, fname) - write(disableScalariform + rowsSource, typesfname) } - case _ => { - println(""" -Usage: NamespacedCodegen.main(Array( slickDriver, jdbcDriver, url, pkg, schemaList, fileName )) - -slickDriver: Fully qualified name of Slick driver class, e.g. 'scala.slick.driver.PostgresDriver' - -jdbcDriver: Fully qualified name of jdbc driver class, e.g. 'org.postgresql.Driver' - -url: jdbc url, e.g. 'jdbc:postgresql://localhost/test' - -pkg: Scala package the generated code should be placed in - -schemaList: string with comma-separated list of schemas to include + } -fName: name of output file -""".trim) - } + def write(c: String, name: String) = { + (new File(name).getParentFile).mkdirs() + val fw = new FileWriter(name) + fw.write(c) + fw.close() } + val disableScalariform = "filename/ format: OFF\n" + val tablesSource = codegen(false).packageCode(slickDriver, pkg, "Tables", None) + val rowsSource = s"package $pkg.rows\n\n" + codegen(true).code + + write(disableScalariform + tablesSource, filename) + write(disableScalariform + rowsSource, typesFilename) } } -- cgit v1.2.3