From 00688a208108506e062ad494fa88704b8caf15e2 Mon Sep 17 00:00:00 2001 From: Stewart Stewart Date: Mon, 21 Nov 2016 19:33:09 -0500 Subject: add parameter for schema super class --- src/main/scala/CodegenPlugin.scala | 8 +++++++- src/main/scala/NamespacedCodegen.scala | 9 ++++----- 2 files changed, 11 insertions(+), 6 deletions(-) (limited to 'src/main') diff --git a/src/main/scala/CodegenPlugin.scala b/src/main/scala/CodegenPlugin.scala index d74eda7..98d832c 100644 --- a/src/main/scala/CodegenPlugin.scala +++ b/src/main/scala/CodegenPlugin.scala @@ -12,6 +12,7 @@ object CodegenPlugin extends AutoPlugin { lazy val codegenOutputPath = SettingKey[String]("codegen-output-path", "directory to with the generated code will be written") lazy val codegenSchemaWhitelist = SettingKey[List[String]]("codegen-schema-whitelist", "schemas and tables to process") lazy val codegenForeignKeys = SettingKey[Map[TableColumn, TableColumn]]("codegen-foreign-keys", "foreign key references to data models add manually") + lazy val codegenSchemaBaseClassParts = SettingKey[List[String]]("codegen-schema-base-class-parts", "parts inherited by each generated schema object") lazy val slickCodeGenTask = TaskKey[Unit]("gen-tables", "generate the table definitions") @@ -22,6 +23,7 @@ object CodegenPlugin extends AutoPlugin { override lazy val projectSettings = Seq( codegenSchemaWhitelist := List.empty, codegenForeignKeys := Map.empty, + codegenSchemaBaseClassParts := List.empty, slickCodeGenTask := Def.taskDyn { Def.task { Generator.run( @@ -29,7 +31,11 @@ object CodegenPlugin extends AutoPlugin { codegenPackage.value, Some(codegenSchemaWhitelist.value).filter(_.nonEmpty), codegenOutputPath.value, - codegenForeignKeys.value + codegenForeignKeys.value, + codegenSchemaBaseClassParts.value match { + case Nil => "AnyRef" + case parts => parts.mkString(" with ") + } ) } }.value diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala index beed852..41905f0 100644 --- a/src/main/scala/NamespacedCodegen.scala +++ b/src/main/scala/NamespacedCodegen.scala @@ -16,12 +16,12 @@ import slick.model.{Column, Model, Table} object Generator { - def run(uri: URI, pkg: String, schemaNames: Option[List[String]], outputPath: String, manualForeignKeys: Map[(String, String), (String, String)]) = { + def run(uri: URI, pkg: String, schemaNames: Option[List[String]], outputPath: String, manualForeignKeys: Map[(String, String), (String, String)], schemaBaseClass: String) = { val dc: DatabaseConfig[JdbcProfile] = DatabaseConfig.forURI[JdbcProfile](uri) val parsedSchemasOpt: Option[Map[String, List[String]]] = schemaNames.map(SchemaParser.parse) val dbModel: Model = Await.result(dc.db.run(SchemaParser.createModel(dc.driver, parsedSchemasOpt)), Duration.Inf) - val generator = new Generator(uri, pkg, dbModel, outputPath, manualForeignKeys) + val generator = new Generator(uri, pkg, dbModel, outputPath, manualForeignKeys, schemaBaseClass) val generatedCode = generator.code parsedSchemasOpt.getOrElse(Map()).keys.map(schemaName => FileHelpers.schemaOutputPath(outputPath, schemaName)) } @@ -43,7 +43,6 @@ class ImportGenerator(dbModel: Model) extends SourceCodeGenerator(dbModel) { val baseImports: String = s""" |import xyz.driver.core._ - |import xyz.driver.core.database._ | |""".stripMargin @@ -68,7 +67,7 @@ class ImportGenerator(dbModel: Model) extends SourceCodeGenerator(dbModel) { override def code: String = baseImports + hlistImports + plainSqlMapperImports } -class Generator(uri: URI, pkg: String, dbModel: Model, outputPath: String, manualForeignKeys: Map[(String, String), (String, String)]) extends SourceCodeGenerator(dbModel) with OutputHelpers { +class Generator(uri: URI, pkg: String, dbModel: Model, outputPath: String, manualForeignKeys: Map[(String, String), (String, String)], schemaBaseClass: String) extends SourceCodeGenerator(dbModel) with OutputHelpers { val packageName = new PackageNameGenerator(pkg, dbModel).code val allImports: String = new ImportGenerator(dbModel).code @@ -84,7 +83,7 @@ class Generator(uri: URI, pkg: String, dbModel: Model, outputPath: String, manua case (schemaName, tableDefs) => val tableCode = tableDefs.sortBy(_.model.name.table).map(_.code.mkString("\n")) .mkString("\n\n") val generatedSchema = s""" - |object ${schemaName} extends IdColumnTypes { + |object ${schemaName} extends $schemaBaseClass { | override val database = xyz.driver.core.database.Database.fromConfig("${uri.getFragment()}") | import database.profile.api._ | ${tableCode} -- cgit v1.2.3