aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--build.sbt12
-rw-r--r--project/build.properties1
-rw-r--r--src/main/scala/CodegenPlugin.scala39
-rw-r--r--src/main/scala/NamespacedCodegen.scala286
4 files changed, 338 insertions, 0 deletions
diff --git a/build.sbt b/build.sbt
new file mode 100644
index 0000000..f583d50
--- /dev/null
+++ b/build.sbt
@@ -0,0 +1,12 @@
+// sbtPlugin := true
+
+val scalaVersionValue = "2.11.8"
+
+scalaVersion := scalaVersionValue
+
+libraryDependencies ++= Seq(
+ "com.typesafe.slick" %% "slick" % "3.1.1",
+ "com.typesafe.slick" %% "slick-codegen" % "3.1.1",
+ "org.scala-lang" % "scala-reflect" % scalaVersionValue,
+ "org.postgresql" % "postgresql" % "9.3-1102-jdbc41"
+)
diff --git a/project/build.properties b/project/build.properties
new file mode 100644
index 0000000..59e7c05
--- /dev/null
+++ b/project/build.properties
@@ -0,0 +1 @@
+sbt.version=0.13.11 \ No newline at end of file
diff --git a/src/main/scala/CodegenPlugin.scala b/src/main/scala/CodegenPlugin.scala
new file mode 100644
index 0000000..36f92a3
--- /dev/null
+++ b/src/main/scala/CodegenPlugin.scala
@@ -0,0 +1,39 @@
+/*
+import sbt._
+import sbt.Keys._
+import complete.DefaultParsers._
+
+object CodegenPlugin extends AutoPlugin {
+ lazy val slick = TaskKey[Seq[File]]("gen-tables")
+ lazy val slickCodeGenTask = (baseDirectory, //sourceManaged in Compile,
+ dependencyClasspath in Compile,
+ runner in Compile, streams) map {
+ (dir, cp, r, s) =>
+ val url = "jdbc:postgresql://postgres/ctig"
+ val jdbcDriver = "org.postgresql.Driver"
+ val slickDriver = "slick.driver.PostgresDriver"
+ val pkg = "dbmodels"
+ val outputDir = (dir / "app" / pkg).getPath
+ val fname = outputDir + "/Tables.scala"
+ val typesfname = (file(sharedSrcDir) / "src" / "main" / "scala" / pkg / "rows" / "TableTypes.scala").getPath
+ val schemas = "patients,portal,work_queues,confidential,case_accessioning,samples.samples,samples.subsamples,samples.shipment_preps,samples.collection_methods,experiments.experiments,experiments.exp_types,experiments.somatic_snvs_indels_filtered,samples.basic_diagnosis,samples.molecular_tests,samples.sample_pathology,samples.path_molecular_tests"
+ val user = "ctig_portal"
+ val password = "coolnurseconspiracyhandbook"
+ toError(r.run(
+ "codegen.NamespacedCodegen",
+ cp.files,
+ Array(
+ slickDriver,
+ jdbcDriver,
+ url,
+ pkg,
+ schemas,
+ fname,
+ typesfname,
+ user,
+ password),
+ s.log))
+ Seq(file(fname))
+ }
+}
+*/
diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala
new file mode 100644
index 0000000..7d5beac
--- /dev/null
+++ b/src/main/scala/NamespacedCodegen.scala
@@ -0,0 +1,286 @@
+package codegen
+
+import slick.dbio.{NoStream, DBIOAction}
+
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+import scala.reflect.runtime.currentMirror
+import slick.ast.ColumnOption
+import slick.driver.JdbcProfile
+import slick.jdbc.meta.MTable
+import slick.codegen.{AbstractGenerator, SourceCodeGenerator}
+import slick.model._
+import slick.{model => m}
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import java.io.File
+import java.io.FileWriter
+
+// NamespacedCodegen handles tables within schemas by namespacing them
+// within objects here
+// (e.g., table a.foo and table b.foo can co-exist, because this code
+// generator places the relevant generated classes into separate
+// objects--a "a" object, and a "b" object)
+object NamespacedCodegen {
+ def parseSchemaList(schemaList: String) = {
+ val sl = schemaList.split(",").filter(_.trim.nonEmpty)
+ val tables: List[String] = sl.filter(_.contains(".")).toList
+ val schemas: List[String] = sl.filter(s => !(s.contains("."))).toList
+
+ val mappedSchemas = schemas.map(_ -> List()).toMap
+ val mappedTables = tables.groupBy(_.split("\\.")(0)).map {
+ case (key, value) => (key, value.map(_.split("\\.")(1)).asInstanceOf[List[String]])
+ }
+
+ mappedSchemas ++ mappedTables
+
+ }
+
+ def main(args: Array[String]) = {
+ args.toList match {
+ case List(slickDriver, jdbcDriver, url, pkg, schemaList, fname, typesfname,
+ user, password) => {
+
+ val driver: JdbcProfile = {
+ val module = currentMirror.staticModule(slickDriver)
+ val reflectedModule = currentMirror.reflectModule(module)
+ val driver = reflectedModule.instance.asInstanceOf[JdbcProfile]
+ driver
+ }
+
+ val schemas = schemaList.split(",").filter(_.trim.nonEmpty).toSet
+ val mappedSchemas: Map[String, List[String]] = parseSchemaList(schemaList)
+
+ implicit class BlockingRun(v: driver.api.Database) {
+ def blockingRun[R](a: DBIOAction[R, NoStream, Nothing]) =
+ Await.result(v.run(a), Duration.Inf)
+ }
+
+ val dbModel = driver.api.Database
+ .forURL(url, user, password, driver = jdbcDriver)
+ .blockingRun(
+ driver.createModel(Some(MTable.getTables.map(_.filter(
+ (t: MTable) => mappedSchemas
+ .get(t.name.schema.getOrElse(""))
+ .fold(false)(ts => ts.isEmpty || ts.contains(t.name.name))
+ )))))
+
+ // TODO take this map in as a parameter if we want the generator code to independent of our project
+ def getTableColumn(tc: (String, String)) : (Table, Column) = {
+ val tb = dbModel.tables.find(_.name.asString == tc._1).getOrElse(throw new RuntimeException("No table " + tc._1))
+ val co = tb.columns.find(_.name == tc._2).getOrElse(throw new RuntimeException("No column " + tc._2 + " in table " + tc._1))
+ (tb, co)
+ }
+ def manualReferences : Map[(String, String), (Table, Column)] =
+ Map(
+ ("portal.case_tumor_info", "patient_id") -> ("patients.patients", "patient_id"),
+ ("portal.case_tumor_info", "case_id") -> ("work_queues.reports", "case_id"),
+ ("portal.case_tumor_info", "cancer_id") -> ("patients.cancer", "cancer_id"),
+ ("confidential.join_pat", "patient_id") -> ("patients.patients", "patient_id"),
+ ("portal.case_tumor_info", "ordering_physician") -> ("patients.oncologists", "oncologist_id"),
+ ("patients.oncologists_case_permissions_view", "oncologist_id") -> ("patients.oncologists", "oncologist_id"),
+ ("patients.oncologists_case_permissions_view", "case_id") -> ("work_queues.reports", "case_id"),
+ ("case_accessioning.case_accessioning", "case_id") -> ("work_queues.reports", "case_id"),
+ ("case_accessioning.case_accessioning", "cancer_id") -> ("patients.cancer", "cancer_id"),
+ ("experiments.somatic_snvs_indels_filtered", "cancer_id") -> ("patients.cancer", "cancer_id"),
+ ("experiments.experiments", "case_id") -> ("work_queues.reports", "case_id"),
+ ("samples.samples", "case_id") -> ("work_queues.reports", "case_id") // remove me when the foreign keys are fixed
+ ).map{case (from, to) => ({getTableColumn(from); from}, getTableColumn(to))}
+
+ def codegen(typeFile: Boolean) = new SourceCodeGenerator(dbModel){
+ def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) =
+ (table.foreignKeys.toList
+ .filter(_.referencingColumns.forall(_ == column))
+ .flatMap(fk =>
+ fk.referencedColumns match {
+ case Seq(c) => dbModel.tablesByName.get(fk.referencedTable).map{(_, c)}
+ case _ => None
+ }) ++
+ manualReferences.get((table.name.asString, column.name)))
+ .headOption
+ .map((derefColumn _).tupled)
+ .getOrElse((table, column))
+
+ def idType(t: m.Table) : String =
+ "Id[rows."+ t.name.schema.fold("")(_ + ".") + t.name.table.toCamelCase+"Row]"
+
+ override def code = {
+ //imports is copied right out of
+ //scala.slick.model.codegen.AbstractSourceCodeGenerator
+
+ var imports =
+ if (typeFile) "import acyclic.file\nimport dbmodels.rows\n"
+ else
+ "import slick.model.ForeignKeyAction\n" +
+ "import rows._\n" +
+ ( if(tables.exists(_.hlistEnabled)){
+ "import slick.collection.heterogeneous._\n"+
+ "import slick.collection.heterogeneous.syntax._\n"
+ } else ""
+ ) +
+ ( if(tables.exists(_.PlainSqlMapper.enabled)){
+ "import slick.jdbc.{GetResult => GR}\n"+
+ "// NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n"
+ } else ""
+ ) + "\n\n"
+
+ val bySchema = tables.groupBy(t => {
+ t.model.name.schema
+ })
+
+ val schemaFor = (schema: String) => {
+ bySchema(Option(schema)).sortBy(_.model.name.table).map(
+ _.code.mkString("\n")
+ ).mkString("\n\n")
+ }
+
+ val schemata = mappedSchemas.keys.toList.sorted.map(
+ s => indent("object" + " " + s + " {\n" + schemaFor(s)) + "\n}\n"
+ ).mkString("\n\n")
+
+ val idType =
+ if (typeFile)
+ """|case class Id[T](v: Int)
+ |""".stripMargin
+ else
+ """|implicit def idTypeMapper[A] : BaseColumnType[Id[A]] =
+ | MappedColumnType.base[Id[A], Int](_.v, Id(_))
+ |import play.api.mvc.PathBindable
+ |implicit def idPathBindable[A] : PathBindable[Id[A]] = implicitly[PathBindable[Int]].transform[Id[A]](Id(_),_.v)
+ |""".stripMargin
+
+
+ imports + idType + schemata
+ }
+
+ override def Table = new Table(_) {
+ table =>
+ // case classes go in the typeFile (but not types based on hlists)
+ override def definitions =
+ if (typeFile) Seq[Def]( EntityType )
+ else Seq[Def](EntityTypeRef, PlainSqlMapper, TableClassRef, TableValue )
+
+
+ def EntityTypeRef = new EntityType() {
+ override def code =
+ s"type $name = $pkg.rows.${model.name.schema.get}.$name\n" ++
+ s"val $name = $pkg.rows.${model.name.schema.get}.$name"
+ }
+
+ /** Creates a compound type from a given sequence of types.
+ * Uses HList if hlistEnabled else tuple.
+ */
+ override def compoundType(types: Seq[String]): String =
+ /** Creates a compound value from a given sequence of values.
+ * Uses HList if hlistEnabled else tuple.
+ */
+ if (hlistEnabled){
+ def mkHList(types: List[String]): String = types match {
+ case Nil => "HNil"
+ case e :: tail => s"HCons[$e," + mkHList(tail) + "]"
+ }
+ mkHList(types.toList)
+ }
+ else compoundValue(types)
+
+ override def mappingEnabled = true
+
+ override def compoundValue(values: Seq[String]): String =
+ if (hlistEnabled) values.mkString(" :: ") + " :: HNil"
+ else if (values.size == 1) values.head
+ else if(values.size <= 22) s"""(${values.mkString(", ")})"""
+ else throw new Exception("Cannot generate tuple for > 22 columns, please set hlistEnable=true or override compound.")
+
+ def TableClassRef = new TableClass() {
+ // We disable the option mapping for >22 columns, as it is a bit more complex to support and we don't appear to need it
+ override def option = if(columns.size <= 22) super.option else ""
+ }
+
+ override def factory =
+ if(columns.size <= 22) super.factory
+ else {
+ val args = columns.zipWithIndex.map("a"+_._2)
+ val hlist = args.mkString("::") + ":: HNil"
+ val hlistType = columns.map(_.actualType).mkString("::") + ":: HNil.type"
+ s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})"
+ }
+ override def extractor =
+ if(columns.size <= 22) super.extractor
+ else s"(a : ${TableClass.elementType}) => Some(" + columns.map("a."+_.name ).mkString("::") + ":: HNil)"
+
+ // make foreign keys refer to namespaced referents
+ // if the referent is in a different namespace
+ override def ForeignKey = new ForeignKey(_) {
+ override def code = {
+ val fkColumns = compoundValue(referencingColumns.map(_.name))
+ // Add the schema name to qualify the referenced table name if:
+ // 1. it's in a different schema from referencingTable, and
+ // 2. it's not None
+ val qualifier = if (referencedTable.model.name.schema
+ != referencingTable.model.name.schema) {
+ referencedTable.model.name.schema match {
+ case Some(schema) => schema + "."
+ case None => ""
+ }
+ } else {
+ ""
+ }
+ val qualifiedName = qualifier + referencedTable.TableValue.name
+ val pkColumns = compoundValue(referencedColumns.map(c => s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" else ""}"))
+ val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk"
+ s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=${onUpdate}, onDelete=${onDelete})"""
+ }
+ }
+
+ override def Column = new Column(_) { column =>
+ // customize db type -> scala type mapping, pls adjust it according to your environment
+
+ override def rawType = {
+ val (t, c) = derefColumn(table.model, column.model)
+ //System.out.print(s"${table.model.name.asString}:${column.model.name} -> ${t.name.asString}:${c.name}\n")
+ if (c.options.exists(_.toString.contains("PrimaryKey"))) idType(t)
+ else model.tpe match {
+ case "java.sql.Date" => "tools.Date"
+ case "java.sql.Time" => "tools.Time"
+ case "java.sql.Timestamp" => "tools.Time"
+ case _ => super.rawType
+ }
+ }
+ }
+ }
+ }
+
+ def write(c: String, name: String) = {
+ (new File(name).getParentFile).mkdirs()
+ val fw = new FileWriter(name)
+ fw.write(c)
+ fw.close()
+ }
+ val disableScalariform = "// format: OFF\n"
+ val tablesSource = codegen(false).packageCode(slickDriver, pkg, "Tables", None)
+ val rowsSource = s"package $pkg.rows\n\n" + codegen(true).code
+
+ write(disableScalariform + tablesSource, fname)
+ write(disableScalariform + rowsSource, typesfname)
+ }
+ case _ => {
+ println("""
+Usage: NamespacedCodegen.main(Array( slickDriver, jdbcDriver, url, pkg, schemaList, fileName ))
+
+slickDriver: Fully qualified name of Slick driver class, e.g. 'scala.slick.driver.PostgresDriver'
+
+jdbcDriver: Fully qualified name of jdbc driver class, e.g. 'org.postgresql.Driver'
+
+url: jdbc url, e.g. 'jdbc:postgresql://localhost/test'
+
+pkg: Scala package the generated code should be placed in
+
+schemaList: string with comma-separated list of schemas to include
+
+fName: name of output file
+""".trim)
+ }
+ }
+ }
+}
+