aboutsummaryrefslogblamecommitdiff
path: root/src/main/scala/NamespacedCodegen.scala
blob: bdda76afa1be3402c6ef45d94223c4da219bb1b1 (plain) (tree)





















                                                                      

                                                                                    





                                                                                         

   








                                                                                                       








                                                                                                                                 
 

                                                                                         
 
 





                                                                 
                

                          

                                                              

                                                    
                                                                                     
                                                         





















































                                                                                                                                           

         


















                                                                                                                                
 






                                                                                 

 




                                                                         
 











                                                                       
             


                                   
 

                                          
 




                                                                                                                               
 



                                                                                                                                 
 










                                                                                                                  
 












                                                                           
               

                    
             



                                                                                                                                                                          


           


















                                                                                                                      
         
       
     
 




                                             
     





                                                                                   


   
import slick.dbio.{NoStream, DBIOAction}

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.reflect.runtime.currentMirror
import slick.ast.ColumnOption
import slick.driver.JdbcProfile
import slick.jdbc.meta.MTable
import slick.codegen.{AbstractGenerator, SourceCodeGenerator}
import slick.model._
import slick.{model => m}
import scala.concurrent.ExecutionContext.Implicits.global

import java.io.File
import java.io.FileWriter

// NamespacedCodegen handles tables within schemas by namespacing them
// within objects here
// (e.g., table a.foo and table b.foo can co-exist, because this code
// generator places the relevant generated classes into separate
// objects--a "a" object, and a "b" object)
object NamespacedCodegen {
  def parseSchemaList(schemaTableNames: List[String]): Map[String, List[String]] = {
    val (tables, schemas) = schemaTableNames.partition(_.contains("."))
    val mappedSchemas = schemas.map(_ -> List()).toMap
    val mappedTables = tables.groupBy(_.split("\\.")(0)).map {
      case (key, value) => (key, value.map(_.split("\\.")(1)).asInstanceOf[List[String]])
    }

    mappedSchemas ++ mappedTables
  }

  import slick.dbio.DBIO
  import slick.model.Model

  def createFilteredModel(driver: JdbcProfile, mappedSchemas: Map[String, List[String]]): DBIO[Model] =
    driver.createModel(Some(
      MTable.getTables.map(_.filter((t: MTable) => mappedSchemas
        .get(t.name.schema.getOrElse(""))
        .fold(false)(ts => ts.isEmpty || ts.contains(t.name.name))))))

  def references(dbModel: Model, tcMappings: Map[(String, String), (String, String)]): Map[(String, String), (Table, Column)] = {
    def getTableColumn(tc: (String, String)) : (Table, Column) = {
      val (tableName, columnName) = tc
      val table = dbModel.tables.find(_.name.asString == tableName)
        .getOrElse(throw new RuntimeException("No table " + tableName))
      val column = table.columns.find(_.name == columnName)
        .getOrElse(throw new RuntimeException("No column " + columnName + " in table " + tableName))
      (table, column)
    }

    tcMappings.map{case (from, to) => ({getTableColumn(from); from}, getTableColumn(to))}
  }


  import java.net.URI
  import slick.backend.DatabaseConfig
  import slick.util.ConfigExtensionMethods.configExtensionMethods

  def run(
    uri: URI,
    pkg: String,
    filename: String,
    typesFilename: String,
    schemaTableNames: List[String],
    manualForeignKeys: Map[(String, String), (String, String)]
  ): Unit = {
    val dc = DatabaseConfig.forURI[JdbcProfile](uri)
    val slickDriver = if(dc.driverIsObject) dc.driverName else "new " + dc.driverName
    val mappedSchemas = parseSchemaList(schemaTableNames)
    val dbModel = Await.result(dc.db.run(createFilteredModel(dc.driver, mappedSchemas)), Duration.Inf)
    //finally dc.db.close

    val manualReferences = references(dbModel, manualForeignKeys)

    def codegen(typeFile: Boolean) = new SourceCodeGenerator(dbModel){
      def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) =
        (table.foreignKeys.toList
          .filter(_.referencingColumns.forall(_ == column))
          .flatMap(fk =>
          fk.referencedColumns match {
            case Seq(c) => dbModel.tablesByName.get(fk.referencedTable).map{(_, c)}
            case _ => None
          }) ++
          manualReferences.get((table.name.asString, column.name)))
          .headOption
          .map((derefColumn _).tupled)
          .getOrElse((table, column))

      // Is this compatible with ***REMOVED*** Id? How do we make it generic?
      def idType(t: m.Table) : String =
        "Id[rows."+ t.name.schema.fold("")(_ + ".") + t.name.table.toCamelCase+"Row]"

      override def code = {
        //imports is copied right out of
        //scala.slick.model.codegen.AbstractSourceCodeGenerator
        // Why can't we simply re-use?

        var imports =
          // acyclic is unnecessary in generic projects
          //if (typeFile) "import acyclic.file\nimport dbmodels.rows\n"
          //else
          "import slick.model.ForeignKeyAction\n" +
        "import rows._\n" +
        ( if(tables.exists(_.hlistEnabled)){
          "import slick.collection.heterogeneous._\n"+
          "import slick.collection.heterogeneous.syntax._\n"
        } else ""
        ) +
        ( if(tables.exists(_.PlainSqlMapper.enabled)){
          "import slick.jdbc.{GetResult => GR}\n"+
          "// NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.\n"
        } else ""
        ) + "\n\n" // We didn't copy ddl though


        val bySchema = tables.groupBy(t => {
          t.model.name.schema
        })

        val schemaFor = (schema: String) => {
          bySchema(Option(schema)).sortBy(_.model.name.table).map(
            _.code.mkString("\n") // TODO explore here
          ).mkString("\n\n")
        }

        val schemata = mappedSchemas.keys.toList.sorted.map(
          s => indent("object" + " " + s + " {\n" + schemaFor(s)) + "\n}\n"
        ).mkString("\n\n")

        val idType =
          if (typeFile)// Should not be defined here.
            """|case class Id[T](v: Int)
               |""".stripMargin
          else
            // This should be in a separate Implicits trait
            """|implicit def idTypeMapper[A] : BaseColumnType[Id[A]] =
               |  MappedColumnType.base[Id[A], Int](_.v, Id(_))
               |import play.api.mvc.PathBindable
               |implicit def idPathBindable[A] : PathBindable[Id[A]] = implicitly[PathBindable[Int]].transform[Id[A]](Id(_),_.v)
               |""".stripMargin
        //pathbindable is play specific
        // Id works only with labdash Id
        imports + idType + schemata
      }

      // This is overridden to output classfiles elsewhere
      override def Table = new Table(_) {
        table =>
        // case classes go in the typeFile (but not types based on hlists)
        override def definitions =
          if (typeFile) Seq[Def](EntityType)
          else Seq[Def](EntityTypeRef, PlainSqlMapper, TableClassRef, TableValue)


        def EntityTypeRef = new EntityType() {
          override def code =
            s"type $name = $pkg.rows.${model.name.schema.get}.$name\n" ++
          s"val $name = $pkg.rows.${model.name.schema.get}.$name"
        }

        /** Creates a compound type from a given sequence of types.
          *  Uses HList if hlistEnabled else tuple.
          */
        override def compoundType(types: Seq[String]): String =
          /** Creates a compound value from a given sequence of values.
            *  Uses HList if hlistEnabled else tuple.
            */
          // Yes! This is part of Slick now, yes?
          if (hlistEnabled){
            def mkHList(types: List[String]): String = types match {
              case Nil => "HNil"
              case e :: tail => s"HCons[$e," + mkHList(tail) + "]"
            }
            mkHList(types.toList)
          }
          else compoundValue(types)

        //why?
        override def mappingEnabled = true

        override def compoundValue(values: Seq[String]): String =
          if (hlistEnabled) values.mkString(" :: ") + " :: HNil"
          else if (values.size == 1) values.head
          else if(values.size <= 22) s"""(${values.mkString(", ")})"""
          else throw new Exception("Cannot generate tuple for > 22 columns, please set hlistEnable=true or override compound.")

        def TableClassRef = new TableClass() {
          // We disable the option mapping for >22 columns, as it is a bit more complex to support and we don't appear to need it
          override def option = if(columns.size <= 22) super.option else ""
        }

        override def factory   =
          if(columns.size <= 22) super.factory
          else {
            val args = columns.zipWithIndex.map("a"+_._2)
            val hlist = args.mkString("::") + ":: HNil"
            val hlistType = columns.map(_.actualType).mkString("::") + ":: HNil.type"
            s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})"
          }
        override def extractor =
          if(columns.size <= 22) super.extractor
          else s"(a : ${TableClass.elementType}) => Some(" + columns.map("a."+_.name ).mkString("::") + ":: HNil)"

        // make foreign keys refer to namespaced referents
        // if the referent is in a different namespace
        override def ForeignKey = new ForeignKey(_) {
          override def code = {
            val fkColumns = compoundValue(referencingColumns.map(_.name))
            // Add the schema name to qualify the referenced table name if:
            // 1. it's in a different schema from referencingTable, and
            // 2. it's not None
            val qualifier = if (referencedTable.model.name.schema
              != referencingTable.model.name.schema) {
              referencedTable.model.name.schema match {
                case Some(schema) => schema + "."
                case None => ""
              }
            } else {
              ""
            }
            val qualifiedName = qualifier + referencedTable.TableValue.name
            val pkColumns = compoundValue(referencedColumns.map(c => s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" else ""}"))
            val fkName = referencingColumns.map(_.name).flatMap(_.split("_")).map(_.capitalize).mkString.uncapitalize + "Fk"
            s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=${onUpdate}, onDelete=${onDelete})"""
          }
        }

        override def Column = new Column(_) { column =>
          // customize db type -> scala type mapping, pls adjust it according to your environment

          override def rawType = {
            val (t, c) = derefColumn(table.model, column.model)
            //System.out.print(s"${table.model.name.asString}:${column.model.name} -> ${t.name.asString}:${c.name}\n")
            if (c.options.exists(_.toString.contains("PrimaryKey"))) idType(t)
            // ^ahahaha This is hacky
            // This should be customizeable by client
            
            else model.tpe match {
              // how does this type work out?
              // There should be a way to add adhoc custom time mappings
              case "java.sql.Date" => "tools.Date"
              case "java.sql.Time" => "tools.Time"
              case "java.sql.Timestamp" => "tools.Time"
              case _ => super.rawType
            }
          }
        }
      }
    }

    def write(c: String, name: String) = {
      (new File(name).getParentFile).mkdirs()
      val fw = new FileWriter(name)
      fw.write(c)
      fw.close()
    }
    val disableScalariform = "filename/ format: OFF\n"
    val tablesSource = codegen(false).packageCode(slickDriver, pkg, "Tables", None)
    val rowsSource = s"package $pkg.rows\n\n" + codegen(true).code

    write(disableScalariform + tablesSource, filename)
    write(disableScalariform + rowsSource, typesFilename)
  }
}