diff options
Diffstat (limited to 'src/main')
-rw-r--r-- | src/main/scala/CodegenPlugin.scala | 8 | ||||
-rw-r--r-- | src/main/scala/NamespacedCodegen.scala | 38 |
2 files changed, 26 insertions, 20 deletions
diff --git a/src/main/scala/CodegenPlugin.scala b/src/main/scala/CodegenPlugin.scala index f9351c1..92e595f 100644 --- a/src/main/scala/CodegenPlugin.scala +++ b/src/main/scala/CodegenPlugin.scala @@ -24,6 +24,10 @@ object CodegenPlugin extends AutoPlugin { lazy val codegenSchemaBaseClassParts = SettingKey[List[String]]( "codegen-schema-base-class-parts", "parts inherited by each generated schema object") + lazy val codegenIdType = SettingKey[String]( + "codegen-id-type", + "The in-scope type `T` of kind `T[TableRow]` to apply in place T for id columns." + ) lazy val slickCodeGenTask = TaskKey[Unit]("gen-tables", "generate the table definitions") @@ -36,6 +40,7 @@ object CodegenPlugin extends AutoPlugin { codegenSchemaWhitelist := List.empty, codegenForeignKeys := Map.empty, codegenSchemaBaseClassParts := List.empty, + codegenIdType := "Id", slickCodeGenTask := Def.taskDyn { Def.task { Generator.run( @@ -47,7 +52,8 @@ object CodegenPlugin extends AutoPlugin { codegenSchemaBaseClassParts.value match { case Nil => "AnyRef" case parts => parts.mkString(" with ") - } + }, + codegenIdType.value ) } }.value diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala index 3e4e5e1..6928cde 100644 --- a/src/main/scala/NamespacedCodegen.scala +++ b/src/main/scala/NamespacedCodegen.scala @@ -14,7 +14,7 @@ import slick.dbio.DBIO import slick.driver.JdbcProfile import slick.jdbc.meta.MTable import slick.{model => sModel} -import slick.model.{Column, Model, Table} +import slick.model.{Column, Model, Table, QualifiedName} object Generator { @@ -23,7 +23,8 @@ object Generator { schemaNames: Option[List[String]], outputPath: String, manualForeignKeys: Map[(String, String), (String, String)], - schemaBaseClass: String) = { + schemaBaseClass: String, + idTypeName: String) = { val dc: DatabaseConfig[JdbcProfile] = DatabaseConfig.forURI[JdbcProfile](uri) val parsedSchemasOpt: Option[Map[String, List[String]]] = @@ -37,7 +38,8 @@ object Generator { dbModel, outputPath, manualForeignKeys, - schemaBaseClass) + schemaBaseClass, + idTypeName) generator.code // Yes... Files are written as a side effect parsedSchemasOpt .getOrElse(Map()) @@ -61,7 +63,7 @@ class PackageNameGenerator(pkg: String, dbModel: Model) class ImportGenerator(dbModel: Model) extends SourceCodeGenerator(dbModel) { val baseImports: String = s""" - |import xyz.driver.core._ + |import xyz.driver.core.Id | |""".stripMargin @@ -92,7 +94,8 @@ class Generator(uri: URI, dbModel: Model, outputPath: String, manualForeignKeys: Map[(String, String), (String, String)], - schemaBaseClass: String) + schemaBaseClass: String, + idTypeName: String) extends SourceCodeGenerator(dbModel) with OutputHelpers { @@ -199,11 +202,19 @@ class Generator(uri: URI, .getOrElse((table, column)) } + def idType(tableName: QualifiedName) = { + val schemaObjectName = tableName.schema.getOrElse("`public`") + val rowTypeName = entityName(tableName.table) + s"$idTypeName[$schemaObjectName.$rowTypeName]" + } + // re-write ids, and time types override def rawType: String = { - val (t, c) = derefColumn(table.model, column.model) - if (c.options.contains(slick.ast.ColumnOption.PrimaryKey)) - TypeGenerator.idType(pkg, t) + val (referencedTable, referencedColumn) = + derefColumn(table.model, column.model) + if (referencedColumn.options.contains( + slick.ast.ColumnOption.PrimaryKey)) + idType(referencedTable.name) else model.tpe match { // TODO: There should be a way to add adhoc custom time mappings @@ -289,17 +300,6 @@ object SchemaParser { } -object TypeGenerator extends StringGeneratorHelpers { - // generate the id types - def idType(pkg: String, t: sModel.Table): String = { - val header = s"Id[" - val schemaName = t.name.schema.fold("")(_ + ".") - val tableName = (t.name.table).toCamelCase - val footer = "]" - s"${header}${pkg}.${schemaName}${tableName}Row${footer}" - } -} - object FileHelpers { def schemaOutputPath(path: String, schemaName: String): String = Paths.get(path, s"${schemaName}.scala").toAbsolutePath().toString() |