From 29691456b69fb0e2f81f503530b315c5f9a0275e Mon Sep 17 00:00:00 2001
From: Stewart Stewart
Date: Mon, 22 Aug 2016 15:21:09 -0700
Subject: Copied codegen source from LabDash

---
 build.sbt                              |  12 ++
 project/build.properties               |   1 +
 src/main/scala/CodegenPlugin.scala     |  39 +++++
 src/main/scala/NamespacedCodegen.scala | 286 +++++++++++++++++++++++++++++++++
 4 files changed, 338 insertions(+)
 create mode 100644 build.sbt
 create mode 100644 project/build.properties
 create mode 100644 src/main/scala/CodegenPlugin.scala
 create mode 100644 src/main/scala/NamespacedCodegen.scala

diff --git a/build.sbt b/build.sbt
new file mode 100644
index 0000000..f583d50
--- /dev/null
+++ b/build.sbt
@@ -0,0 +1,12 @@
+// sbtPlugin := true
+
+val scalaVersionValue = "2.11.8"
+
+scalaVersion := scalaVersionValue
+
+libraryDependencies ++= Seq(
+  "com.typesafe.slick" %% "slick" % "3.1.1",
+  "com.typesafe.slick" %% "slick-codegen" % "3.1.1",
+  "org.scala-lang" % "scala-reflect" % scalaVersionValue,
+  "org.postgresql" % "postgresql" % "9.3-1102-jdbc41"
+)
diff --git a/project/build.properties b/project/build.properties
new file mode 100644
index 0000000..59e7c05
--- /dev/null
+++ b/project/build.properties
@@ -0,0 +1 @@
+sbt.version=0.13.11
\ No newline at end of file
diff --git a/src/main/scala/CodegenPlugin.scala b/src/main/scala/CodegenPlugin.scala
new file mode 100644
index 0000000..36f92a3
--- /dev/null
+++ b/src/main/scala/CodegenPlugin.scala
@@ -0,0 +1,39 @@
+/*
+import sbt._
+import sbt.Keys._
+import complete.DefaultParsers._
+
+object CodegenPlugin extends AutoPlugin {
+  lazy val slick = TaskKey[Seq[File]]("gen-tables")
+  lazy val slickCodeGenTask = (baseDirectory, //sourceManaged in Compile,
+    dependencyClasspath in Compile,
+    runner in Compile, streams) map {
+    (dir, cp, r, s) =>
+      val url = "jdbc:postgresql://postgres/ctig"
+      val jdbcDriver = "org.postgresql.Driver"
+      val slickDriver = "slick.driver.PostgresDriver"
+      val pkg = "dbmodels"
+      val outputDir = (dir / "app" / pkg).getPath
+      val fname = outputDir + "/Tables.scala"
+      val typesfname = (file(sharedSrcDir) / "src" / "main" / "scala" / pkg / "rows" / "TableTypes.scala").getPath
+      val schemas = "patients,portal,work_queues,confidential,case_accessioning,samples.samples,samples.subsamples,samples.shipment_preps,samples.collection_methods,experiments.experiments,experiments.exp_types,experiments.somatic_snvs_indels_filtered,samples.basic_diagnosis,samples.molecular_tests,samples.sample_pathology,samples.path_molecular_tests"
+      val user = "ctig_portal"
+      val password = "coolnurseconspiracyhandbook"
+      toError(r.run(
+        "codegen.NamespacedCodegen",
+        cp.files,
+        Array(
+          slickDriver,
+          jdbcDriver,
+          url,
+          pkg,
+          schemas,
+          fname,
+          typesfname,
+          user,
+          password),
+        s.log))
+      Seq(file(fname))
+  }
+}
+*/
diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala
new file mode 100644
index 0000000..7d5beac
--- /dev/null
+++ b/src/main/scala/NamespacedCodegen.scala
@@ -0,0 +1,286 @@
+package codegen
+
+import slick.dbio.{NoStream, DBIOAction}
+
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+import scala.reflect.runtime.currentMirror
+import slick.ast.ColumnOption
+import slick.driver.JdbcProfile
+import slick.jdbc.meta.MTable
+import slick.codegen.{AbstractGenerator, SourceCodeGenerator}
+import slick.model._
+import slick.{model => m}
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import java.io.File
+import java.io.FileWriter
+
+// NamespacedCodegen handles tables within schemas by namespacing them
+// within objects here
+// (e.g., table a.foo and table b.foo can co-exist, because this code
+// generator places the relevant generated classes into separate
+// objects--an "a" object and a "b" object)
+object NamespacedCodegen {
+  def parseSchemaList(schemaList: String) = {
+    val sl = schemaList.split(",").filter(_.trim.nonEmpty)
+    val tables: List[String] = sl.filter(_.contains(".")).toList
+    val schemas: List[String] = sl.filter(s => !(s.contains("."))).toList
+
+    val mappedSchemas = schemas.map(_ -> List()).toMap
+    val mappedTables = tables.groupBy(_.split("\\.")(0)).map {
+      case (key, value) => (key, value.map(_.split("\\.")(1)).asInstanceOf[List[String]])
+    }
+
+    mappedSchemas ++ mappedTables
+
+  }
+
+  def main(args: Array[String]) = {
+    args.toList match {
+      case List(slickDriver, jdbcDriver, url, pkg, schemaList, fname, typesfname,
+        user, password) => {
+
+        val driver: JdbcProfile = {
+          val module = currentMirror.staticModule(slickDriver)
+          val reflectedModule = currentMirror.reflectModule(module)
+          val driver = reflectedModule.instance.asInstanceOf[JdbcProfile]
+          driver
+        }
+
+        val schemas = schemaList.split(",").filter(_.trim.nonEmpty).toSet
+        val mappedSchemas: Map[String, List[String]] = parseSchemaList(schemaList)
+
+        implicit class BlockingRun(v: driver.api.Database) {
+          def blockingRun[R](a: DBIOAction[R, NoStream, Nothing]) =
+            Await.result(v.run(a), Duration.Inf)
+        }
+
+        val dbModel = driver.api.Database
+          .forURL(url, user, password, driver = jdbcDriver)
+          .blockingRun(
+            driver.createModel(Some(MTable.getTables.map(_.filter(
+              (t: MTable) => mappedSchemas
+                .get(t.name.schema.getOrElse(""))
+                .fold(false)(ts => ts.isEmpty || ts.contains(t.name.name))
+            )))))
+
+        // TODO take this map in as a parameter if we want the generator code to be independent of our project
+        def getTableColumn(tc: (String, String)) : (Table, Column) = {
+          val tb = dbModel.tables.find(_.name.asString == tc._1).getOrElse(throw new RuntimeException("No table " + tc._1))
+          val co = tb.columns.find(_.name == tc._2).getOrElse(throw new RuntimeException("No column " + tc._2 + " in table " + tc._1))
+          (tb, co)
+        }
+        def manualReferences : Map[(String, String), (Table, Column)] =
+          Map(
+            ("portal.case_tumor_info", "patient_id") -> ("patients.patients", "patient_id"),
+            ("portal.case_tumor_info", "case_id") -> ("work_queues.reports", "case_id"),
+            ("portal.case_tumor_info", "cancer_id") -> ("patients.cancer", "cancer_id"),
+            ("confidential.join_pat", "patient_id") -> ("patients.patients", "patient_id"),
+            ("portal.case_tumor_info", "ordering_physician") -> ("patients.oncologists", "oncologist_id"),
+            ("patients.oncologists_case_permissions_view", "oncologist_id") -> ("patients.oncologists", "oncologist_id"),
+            ("patients.oncologists_case_permissions_view", "case_id") -> ("work_queues.reports", "case_id"),
+            ("case_accessioning.case_accessioning", "case_id") -> ("work_queues.reports", "case_id"),
+            ("case_accessioning.case_accessioning", "cancer_id") -> ("patients.cancer", "cancer_id"),
+            ("experiments.somatic_snvs_indels_filtered", "cancer_id") -> ("patients.cancer", "cancer_id"),
+            ("experiments.experiments", "case_id") -> ("work_queues.reports", "case_id"),
+            ("samples.samples", "case_id") -> ("work_queues.reports", "case_id") // remove me when the foreign keys are fixed
+          ).map{case (from, to) => ({getTableColumn(from); from}, getTableColumn(to))}
+
+        def codegen(typeFile: Boolean) = new SourceCodeGenerator(dbModel){
+          def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) =
+            (table.foreignKeys.toList
+              .filter(_.referencingColumns.forall(_ == column))
+              .flatMap(fk =>
+                fk.referencedColumns match {
+                  case Seq(c) => dbModel.tablesByName.get(fk.referencedTable).map{(_, c)}
+                  case _ => None
+                }) ++
+              manualReferences.get((table.name.asString, column.name)))
+              .headOption
+              .map((derefColumn _).tupled)
+              .getOrElse((table, column))
+
+          def idType(t: m.Table) : String =
+            "Id[rows."+ t.name.schema.fold("")(_ + ".") + t.name.table.toCamelCase+"Row]"
+
+          override def code = {
+            //imports is copied right out of
+            //scala.slick.model.codegen.AbstractSourceCodeGenerator
+
+            var imports =
+              if (typeFile) "import acyclic.file\nimport dbmodels.rows\n"
+              else
+                "import slick.model.ForeignKeyAction\n" +
+                "import rows._\n" +
+                ( if(tables.exists(_.hlistEnabled)){
+                  "import slick.collection.heterogeneous._\n"+
+                  "import slick.collection.heterogeneous.syntax._\n"
+                } else ""
+                ) +
+                ( if(tables.exists(_.PlainSqlMapper.enabled)){
+                  "import slick.jdbc.{GetResult => GR}\n"+
+                  "// NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n"
+                } else ""
+                ) + "\n\n"
+
+            val bySchema = tables.groupBy(t => {
+              t.model.name.schema
+            })
+
+            val schemaFor = (schema: String) => {
+              bySchema(Option(schema)).sortBy(_.model.name.table).map(
+                _.code.mkString("\n")
+              ).mkString("\n\n")
+            }
+
+            val schemata = mappedSchemas.keys.toList.sorted.map(
+              s => indent("object" + " " + s + " {\n" + schemaFor(s)) + "\n}\n"
+            ).mkString("\n\n")
+
+            val idType =
+              if (typeFile)
+                """|case class Id[T](v: Int)
+                   |""".stripMargin
+              else
+                """|implicit def idTypeMapper[A] : BaseColumnType[Id[A]] =
+                   |  MappedColumnType.base[Id[A], Int](_.v, Id(_))
+                   |import play.api.mvc.PathBindable
+                   |implicit def idPathBindable[A] : PathBindable[Id[A]] = implicitly[PathBindable[Int]].transform[Id[A]](Id(_),_.v)
+                   |""".stripMargin
+
+
+            imports + idType + schemata
+          }
+
+          override def Table = new Table(_) {
+            table =>
+            // case classes go in the typeFile (but not types based on hlists)
+            override def definitions =
+              if (typeFile) Seq[Def]( EntityType )
+              else Seq[Def](EntityTypeRef, PlainSqlMapper, TableClassRef, TableValue )
+
+
+            def EntityTypeRef = new EntityType() {
+              override def code =
+                s"type $name = $pkg.rows.${model.name.schema.get}.$name\n" ++
+                s"val $name = $pkg.rows.${model.name.schema.get}.$name"
+            }
+
+            /** Creates a compound type from a given sequence of types.
+             * Uses HList if hlistEnabled else tuple.
+             */
+            override def compoundType(types: Seq[String]): String =
+              /** Creates a compound value from a given sequence of values.
+               * Uses HList if hlistEnabled else tuple.
+               */
+              if (hlistEnabled){
+                def mkHList(types: List[String]): String = types match {
+                  case Nil => "HNil"
+                  case e :: tail => s"HCons[$e," + mkHList(tail) + "]"
+                }
+                mkHList(types.toList)
+              }
+              else compoundValue(types)
+
+            override def mappingEnabled = true
+
+            override def compoundValue(values: Seq[String]): String =
+              if (hlistEnabled) values.mkString(" :: ") + " :: HNil"
+              else if (values.size == 1) values.head
+              else if(values.size <= 22) s"""(${values.mkString(", ")})"""
+              else throw new Exception("Cannot generate tuple for > 22 columns, please set hlistEnable=true or override compound.")
+
+            def TableClassRef = new TableClass() {
+              // We disable the option mapping for >22 columns, as it is a bit more complex to support and we don't appear to need it
+              override def option = if(columns.size <= 22) super.option else ""
+            }
+
+            override def factory =
+              if(columns.size <= 22) super.factory
+              else {
+                val args = columns.zipWithIndex.map("a"+_._2)
+                val hlist = args.mkString("::") + ":: HNil"
+                val hlistType = columns.map(_.actualType).mkString("::") + ":: HNil.type"
+                s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})"
+              }
+            override def extractor =
+              if(columns.size <= 22) super.extractor
+              else s"(a : ${TableClass.elementType}) => Some(" + columns.map("a."+_.name ).mkString("::") + ":: HNil)"
+
+            // make foreign keys refer to namespaced referents
+            // if the referent is in a different namespace
+            override def ForeignKey = new ForeignKey(_) {
+              override def code = {
+                val fkColumns = compoundValue(referencingColumns.map(_.name))
+                // Add the schema name to qualify the referenced table name if:
+                // 1. it's in a different schema from referencingTable, and
+                // 2. it's not None
+                val qualifier = if (referencedTable.model.name.schema
+                  != referencingTable.model.name.schema) {
+                  referencedTable.model.name.schema match {
+                    case Some(schema) => schema + "."
+                    case None => ""
+                  }
+                } else {
+                  ""
+                }
+                val qualifiedName = qualifier + referencedTable.TableValue.name
+                val pkColumns = compoundValue(referencedColumns.map(c => s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" else ""}"))
else ""}")) + val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk" + s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=${onUpdate}, onDelete=${onDelete})""" + } + } + + override def Column = new Column(_) { column => + // customize db type -> scala type mapping, pls adjust it according to your environment + + override def rawType = { + val (t, c) = derefColumn(table.model, column.model) + //System.out.print(s"${table.model.name.asString}:${column.model.name} -> ${t.name.asString}:${c.name}\n") + if (c.options.exists(_.toString.contains("PrimaryKey"))) idType(t) + else model.tpe match { + case "java.sql.Date" => "tools.Date" + case "java.sql.Time" => "tools.Time" + case "java.sql.Timestamp" => "tools.Time" + case _ => super.rawType + } + } + } + } + } + + def write(c: String, name: String) = { + (new File(name).getParentFile).mkdirs() + val fw = new FileWriter(name) + fw.write(c) + fw.close() + } + val disableScalariform = "// format: OFF\n" + val tablesSource = codegen(false).packageCode(slickDriver, pkg, "Tables", None) + val rowsSource = s"package $pkg.rows\n\n" + codegen(true).code + + write(disableScalariform + tablesSource, fname) + write(disableScalariform + rowsSource, typesfname) + } + case _ => { + println(""" +Usage: NamespacedCodegen.main(Array( slickDriver, jdbcDriver, url, pkg, schemaList, fileName )) + +slickDriver: Fully qualified name of Slick driver class, e.g. 'scala.slick.driver.PostgresDriver' + +jdbcDriver: Fully qualified name of jdbc driver class, e.g. 'org.postgresql.Driver' + +url: jdbc url, e.g. 'jdbc:postgresql://localhost/test' + +pkg: Scala package the generated code should be placed in + +schemaList: string with comma-separated list of schemas to include + +fName: name of output file +""".trim) + } + } + } +} + -- cgit v1.2.3