// aboutsummaryrefslogblamecommitdiff
// path: root/src/main/scala/NamespacedCodegen.scala
// blob: 7d5beac11fa9743bd3d3f694ef7693f2ee8001fe (plain) (tree)





























































































































































































































































































                                                                                                                                                                              
package codegen

import slick.dbio.{NoStream, DBIOAction}

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.reflect.runtime.currentMirror
import slick.ast.ColumnOption
import slick.driver.JdbcProfile
import slick.jdbc.meta.MTable
import slick.codegen.{AbstractGenerator, SourceCodeGenerator}
import slick.model._
import slick.{model => m}
import scala.concurrent.ExecutionContext.Implicits.global

import java.io.File
import java.io.FileWriter

// NamespacedCodegen handles tables within schemas by namespacing them
// within objects here
// (e.g., table a.foo and table b.foo can co-exist, because this code
// generator places the relevant generated classes into separate
// objects--a "a" object, and a "b" object)
object NamespacedCodegen {
  /** Parses the comma-separated `schemaList` argument.
    *
    * Each entry is either a bare schema name (`"a"`, meaning every table of that
    * schema) or a `schema.table` pair (`"a.foo"`, meaning only that table).
    * Entries are trimmed and blank entries are ignored.
    *
    * @param schemaList comma-separated schema and `schema.table` entries
    * @return schema name -> tables requested from it; an empty list means the
    *         whole schema. If a schema is named both on its own and through
    *         `schema.table` entries, the explicit table list wins.
    */
  def parseSchemaList(schemaList: String): Map[String, List[String]] = {
    // Trim first so that "a, b.c" is read exactly like "a,b.c".
    val entries = schemaList.split(",").map(_.trim).filter(_.nonEmpty).toList
    val (tableEntries, schemaEntries) = entries.partition(_.contains("."))

    val wholeSchemas: Map[String, List[String]] =
      schemaEntries.map(_ -> List.empty[String]).toMap

    // "schema.table" -> (schema, Some(table)). A dangling "schema." selects the
    // whole schema and a lone "." is dropped, instead of failing with an
    // ArrayIndexOutOfBoundsException.
    val qualified: List[(String, Option[String])] =
      tableEntries.map(_.split("\\.").map(_.trim).toList).collect {
        case schema :: table :: _ => schema -> Some(table)
        case schema :: Nil => schema -> None
      }

    val tablesBySchema: Map[String, List[String]] =
      qualified.groupBy(_._1).map { case (schema, pairs) => schema -> pairs.flatMap(_._2) }

    wholeSchemas ++ tablesBySchema
  }

  def main(args: Array[String]) = {
    args.toList match {
      case List(slickDriver, jdbcDriver, url, pkg, schemaList, fname, typesfname,
        user, password) => {

        // Reflectively look up the singleton object named by `slickDriver` and
        // use it as the Slick JDBC profile for the rest of the run.
        val driver: JdbcProfile =
          currentMirror
            .reflectModule(currentMirror.staticModule(slickDriver))
            .instance
            .asInstanceOf[JdbcProfile]

        // Schema name -> tables to generate from it (an empty list selects every
        // table of the schema).
        val mappedSchemas: Map[String, List[String]] = parseSchemaList(schemaList)

        // Adds a synchronous `blockingRun` to the profile's Database: it runs the
        // action and waits, without a timeout, for the result.
        implicit class BlockingRun(db: driver.api.Database) {
          def blockingRun[R](action: DBIOAction[R, NoStream, Nothing]) = {
            val pending = db.run(action)
            Await.result(pending, Duration.Inf)
          }
        }

        // Load the Slick model for the requested schemas/tables. The database
        // handle is only needed while the model is being read, so it is closed
        // as soon as the (fully materialised) model is available, even on failure.
        val dbModel = {
          val db = driver.api.Database.forURL(url, user, password, driver = jdbcDriver)

          // Keep a table when its schema was requested and either the whole schema
          // was selected (empty table list) or the table is listed explicitly.
          def selected(t: MTable): Boolean =
            mappedSchemas
              .get(t.name.schema.getOrElse(""))
              .fold(false)(ts => ts.isEmpty || ts.contains(t.name.name))

          try db.blockingRun(driver.createModel(Some(MTable.getTables.map(_.filter(selected)))))
          finally db.close
        }

        // TODO take the manualReferences map in as a parameter if we want the generator code to be independent of our project
        // Resolves a (qualified table name, column name) pair against the loaded
        // model. The table name is compared with `name.asString`. Throws a
        // RuntimeException naming whichever part could not be found.
        def getTableColumn(tc: (String, String)) : (Table, Column) = {
          val (tableName, columnName) = tc
          val resolvedTable = dbModel.tables.find(_.name.asString == tableName) match {
            case Some(found) => found
            case None => throw new RuntimeException("No table " + tableName)
          }
          val resolvedColumn = resolvedTable.columns.find(_.name == columnName) match {
            case Some(found) => found
            case None => throw new RuntimeException("No column " + columnName + " in table " + tableName)
          }
          (resolvedTable, resolvedColumn)
        }
        // Project-specific references that are treated like foreign keys although
        // the database does not declare them: (table, column) -> the (table, column)
        // it points at. Both sides of every entry are looked up in the loaded model,
        // so first use fails with a RuntimeException if any table or column is missing.
        //
        // A lazy val rather than a def: the map is consulted for every generated
        // column, and rebuilding it (with its model lookups) on each call is wasted work.
        lazy val manualReferences : Map[(String, String), (Table, Column)] =
          Map(
            ("portal.case_tumor_info", "patient_id") -> ("patients.patients", "patient_id"),
            ("portal.case_tumor_info", "case_id") -> ("work_queues.reports", "case_id"),
            ("portal.case_tumor_info", "cancer_id") -> ("patients.cancer", "cancer_id"),
            ("confidential.join_pat", "patient_id") -> ("patients.patients", "patient_id"),
            ("portal.case_tumor_info", "ordering_physician") -> ("patients.oncologists", "oncologist_id"),
            ("patients.oncologists_case_permissions_view", "oncologist_id") -> ("patients.oncologists", "oncologist_id"),
            ("patients.oncologists_case_permissions_view", "case_id") -> ("work_queues.reports", "case_id"),
            ("case_accessioning.case_accessioning", "case_id") -> ("work_queues.reports", "case_id"),
            ("case_accessioning.case_accessioning", "cancer_id") -> ("patients.cancer", "cancer_id"),
            ("experiments.somatic_snvs_indels_filtered", "cancer_id") -> ("patients.cancer", "cancer_id"),
            ("experiments.experiments", "case_id") -> ("work_queues.reports", "case_id"),
            ("samples.samples", "case_id") -> ("work_queues.reports", "case_id") // remove me when the foreign keys are fixed
          ).map { case (from, to) =>
            getTableColumn(from) // validation only: the referencing side must exist too
            (from, getTableColumn(to))
          }

        def codegen(typeFile: Boolean) = new SourceCodeGenerator(dbModel){
          /** Follows references from `column` to the column it ultimately points at.
            *
            * Candidates are the table's foreign keys whose referencing columns all
            * equal `column` and that reference exactly one column, followed by the
            * `manualReferences` entry for this table/column. The first candidate is
            * followed recursively; when there is none, `(table, column)` itself is
            * returned. There is no cycle detection.
            */
          def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = {
            val viaForeignKeys: List[(m.Table, m.Column)] =
              for {
                fk <- table.foreignKeys.toList
                if fk.referencingColumns.forall(_ == column)
                target <- fk.referencedColumns match {
                  case Seq(c) => dbModel.tablesByName.get(fk.referencedTable).map(t => (t, c))
                  case _ => None
                }
              } yield target
            val viaManual = manualReferences.get((table.name.asString, column.name))

            viaForeignKeys.headOption.orElse(viaManual) match {
              case Some((nextTable, nextColumn)) => derefColumn(nextTable, nextColumn)
              case None => (table, column)
            }
          }

          /** Source text of the `Id[...]` type for table `t`: `Id[rows.<schema>.<Table>Row]`,
            * with the schema segment omitted when the table has no schema.
            */
          def idType(t: m.Table) : String = {
            val schemaPrefix = t.name.schema.map(_ + ".").getOrElse("")
            val rowClass = t.name.table.toCamelCase + "Row"
            s"Id[rows.$schemaPrefix$rowClass]"
          }

          /** Assembles the generated source text: the import header, the `Id`
            * support definitions, and one `object <schema> { ... }` per requested
            * schema (sorted by schema name, tables sorted by table name).
            *
            * With `typeFile` set this is the body of the row-types file, otherwise
            * the body of the tables file. A requested schema for which no tables
            * were loaded produces an empty object rather than a "key not found" failure.
            */
          override def code = {
            // The tables-file imports are copied from Slick's
            // AbstractSourceCodeGenerator, with `rows._` added. The type file is
            // emitted into `package $pkg.rows`, so it imports that package's
            // parent-qualified name instead of a hard-coded project package.
            val imports =
              if (typeFile) s"import acyclic.file\nimport $pkg.rows\n"
              else {
                val hlistImports =
                  if (tables.exists(_.hlistEnabled))
                    "import slick.collection.heterogeneous._\n" +
                      "import slick.collection.heterogeneous.syntax._\n"
                  else ""
                val plainSqlImports =
                  if (tables.exists(_.PlainSqlMapper.enabled))
                    "import slick.jdbc.{GetResult => GR}\n" +
                      "// NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n"
                  else ""
                "import slick.model.ForeignKeyAction\n" +
                  "import rows._\n" +
                  hlistImports +
                  plainSqlImports +
                  "\n\n"
              }

            val bySchema = tables.groupBy(_.model.name.schema)

            // Generated code of every table in `schema`, ordered by table name.
            def schemaFor(schema: String): String =
              bySchema
                .getOrElse(Option(schema), Seq.empty)
                .sortBy(_.model.name.table)
                .map(_.code.mkString("\n"))
                .mkString("\n\n")

            val schemata = mappedSchemas.keys.toList.sorted.map(
              s => indent("object" + " " + s + " {\n" + schemaFor(s)) + "\n}\n"
            ).mkString("\n\n")

            // Named `idSupport` so it does not shadow the `idType` method.
            val idSupport =
              if (typeFile)
                "case class Id[T](v: Int)\n"
              else
                """|implicit def idTypeMapper[A] : BaseColumnType[Id[A]] =
                   |  MappedColumnType.base[Id[A], Int](_.v, Id(_))
                   |import play.api.mvc.PathBindable
                   |implicit def idPathBindable[A] : PathBindable[Id[A]] = implicitly[PathBindable[Int]].transform[Id[A]](Id(_),_.v)
                   |""".stripMargin

            imports + idSupport + schemata
          }

          override def Table = new Table(_) {
            table =>
            // Case classes go in the typeFile (but not types based on hlists); the
            // tables file gets a reference to the row type plus the plain-SQL mapper,
            // the table class and the table value.
            override def definitions: Seq[Def] = typeFile match {
              case true => Seq(EntityType)
              case false => Seq(EntityTypeRef, PlainSqlMapper, TableClassRef, TableValue)
            }


            // Stand-in for the entity type in the tables file: rather than defining
            // the row class again, it emits a type alias and a value alias pointing
            // at `$pkg.rows.<schema>.<name>`. The table must have a schema
            // (`schema.get` throws otherwise).
            def EntityTypeRef = new EntityType() {
              override def code = {
                val rowPath = s"$pkg.rows.${model.name.schema.get}.$name"
                s"type $name = $rowPath\nval $name = $rowPath"
              }
            }

            /** Creates a compound type from a given sequence of types.
              *  With hlistEnabled this is a right-nested `HCons[T1,HCons[T2,...HNil]]`;
              *  otherwise the result of `compoundValue` for the same elements is used.
              */
            override def compoundType(types: Seq[String]): String =
              if (hlistEnabled) types.foldRight("HNil")((tpe, rest) => s"HCons[$tpe,$rest]")
              else compoundValue(types)

          // Force Slick's `mappingEnabled` switch on for every table
          // (presumably so wide hlist-based tables are mapped too - TODO confirm).
          override def mappingEnabled: Boolean = true

          /** Creates a compound value from a given sequence of values.
            *  Uses an `a :: b :: HNil` chain if hlistEnabled; otherwise the bare value
            *  for a single element and a tuple for up to 22 elements.
            *  Throws for more than 22 elements without hlists.
            */
          override def compoundValue(values: Seq[String]): String =
            if (hlistEnabled) values.mkString("", " :: ", " :: HNil")
            else values match {
              case Seq(single) => single
              case _ if values.size <= 22 => values.mkString("(", ", ", ")")
              case _ => throw new Exception("Cannot generate tuple for > 22 columns, please set hlistEnable=true or override compound.")
            }

            // Table class used in the tables file. The option mapping is only
            // generated for tables of at most 22 columns: it is a bit more complex
            // to support for wider tables and we don't appear to need it.
            def TableClassRef = new TableClass() {
              override def option = if (columns.size > 22) "" else super.option
            }

          // Source text of the function that builds an entity from a row. Up to 22
          // columns Slick's default is used; beyond that the row is an HList, so a
          // function is generated that pattern-matches the HList into a0, a1, ...
          // and passes them to the entity constructor.
          override def factory =
            if (columns.size > 22) {
              val binders = columns.indices.map(i => s"a$i")
              val pattern = binders.mkString("", "::", ":: HNil")
              val signature = columns.map(_.actualType).mkString("", "::", ":: HNil.type")
              val construct = s"${TableClass.elementType}(${binders.mkString(",")})"
              s"((h : $signature) => h match {case $pattern => $construct})"
            } else super.factory
          // Source text of the function that turns an entity back into a row. Up to
          // 22 columns Slick's default is used; beyond that a function is generated
          // that chains the entity's fields into an HList wrapped in Some.
          override def extractor =
            if (columns.size > 22) {
              val fields = columns.map(c => s"a.${c.name}").mkString("::")
              s"(a : ${TableClass.elementType}) => Some(${fields}:: HNil)"
            } else super.extractor

            // Make foreign keys refer to namespaced referents: when the referenced
            // table is in a different schema than the referencing one, its table
            // value is prefixed with that schema name.
            override def ForeignKey = new ForeignKey(_) {
              override def code = {
                val localSchema = referencingTable.model.name.schema
                val targetSchema = referencedTable.model.name.schema
                // Qualify only when the schemas differ and the target actually has one.
                val qualifier = targetSchema.filter(_ => targetSchema != localSchema).fold("")(_ + ".")
                val qualifiedName = qualifier + referencedTable.TableValue.name

                val fkColumns = compoundValue(referencingColumns.map(_.name))
                // A non-nullable referenced column gets `.?` appended when every
                // referencing column is nullable.
                val allReferencingNullable = referencingColumns.forall(_.model.nullable)
                val pkColumns = compoundValue(referencedColumns.map { c =>
                  val lift = if (!c.model.nullable && allReferencingNullable) ".?" else ""
                  s"r.${c.name}$lift"
                })
                // Name parts are split on "_", capitalized, joined, and suffixed with "Fk".
                val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk"
                s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=${onUpdate}, onDelete=${onDelete})"""
              }
            }

            override def Column = new Column(_) { column =>
              // customize db type -> scala type mapping, pls adjust it according to your environment

              /** Scala type used for this column in the generated code.
                *
                * The column is first dereferenced through `derefColumn`. If the
                * column it ends at is a primary key, the type is the `Id[...]` of
                * that table. Otherwise the java.sql date/time types are replaced by
                * the `tools` types and everything else keeps Slick's default.
                */
              override def rawType = {
                val (t, c) = derefColumn(table.model, column.model)
                // Compare against the option itself: a substring match on toString
                // would also fire for unrelated options whose text happens to
                // contain "PrimaryKey" (e.g. a default value).
                if (c.options.exists(_ == ColumnOption.PrimaryKey)) idType(t)
                else model.tpe match {
                  case "java.sql.Date" => "tools.Date"
                  case "java.sql.Time" => "tools.Time"
                  case "java.sql.Timestamp" => "tools.Time"
                  case _ => super.rawType
                }
              }
            }
          }
        }

        /** Writes `c` to the file `name`, creating missing parent directories first.
          *
          * @param c    text to write (platform default encoding, as with FileWriter)
          * @param name path of the output file; may be a bare file name
          */
        def write(c: String, name: String): Unit = {
          // getParentFile is null for a bare file name such as "Tables.scala".
          Option(new File(name).getParentFile).foreach(_.mkdirs())
          val fw = new FileWriter(name)
          // Close the writer even when writing fails.
          try fw.write(c)
          finally fw.close()
        }
        // Header line placed at the top of both generated files (formatter off).
        val formatterOff = "// format: OFF\n"
        // Both sources are generated, tables first, before anything is written.
        val generated = Seq(
          fname -> codegen(false).packageCode(slickDriver, pkg, "Tables", None),
          typesfname -> (s"package $pkg.rows\n\n" + codegen(true).code)
        )
        generated.foreach { case (path, source) => write(formatterOff + source, path) }
      }
      // Any other argument count: print the usage text. It lists all nine
      // arguments that the branch above destructures, in the same order.
      case _ => {
        println("""
Usage: NamespacedCodegen.main(Array( slickDriver, jdbcDriver, url, pkg, schemaList, fileName, typesFileName, user, password ))

slickDriver: Fully qualified name of Slick driver object, e.g. 'slick.driver.PostgresDriver'

jdbcDriver: Fully qualified name of jdbc driver class, e.g. 'org.postgresql.Driver'

url: jdbc url, e.g. 'jdbc:postgresql://localhost/test'

pkg: Scala package the generated code should be placed in

schemaList: string with comma-separated list of schemas to include; use 'schema.table' entries to include only specific tables of a schema

fileName: name of the output file for the generated table definitions

typesFileName: name of the output file for the generated row types

user: database user name

password: database password
""".trim)
      }
    }
  }
}