diff options
-rw-r--r-- | src/main/scala/Generators.scala | 137 | ||||
-rw-r--r-- | src/main/scala/Main.scala | 108 | ||||
-rw-r--r-- | src/main/scala/NamespacedCodegen.scala | 262 | ||||
-rw-r--r-- | src/main/scala/OutputHelpers.scala | 72 | ||||
-rw-r--r-- | src/main/scala/SchemaParser.scala | 60 | ||||
-rw-r--r-- | src/main/scala/TypedIdTable.scala | 50 |
6 files changed, 415 insertions, 274 deletions
diff --git a/src/main/scala/Generators.scala b/src/main/scala/Generators.scala new file mode 100644 index 0000000..fb75cc1 --- /dev/null +++ b/src/main/scala/Generators.scala @@ -0,0 +1,137 @@ +import slick.codegen.SourceCodeGenerator +import slick.{model => m} + +class RowSourceCodeGenerator( + model: m.Model, + override val headerComment: String, + override val imports: String, + override val schemaName: String, + fullDatabaseModel: m.Model, + idType: Option[String], + manualForeignKeys: Map[(String, String), (String, String)], + typeReplacements: Map[String, String] +) extends TypedIdSourceCodeGenerator( + model, + fullDatabaseModel, + idType, + manualForeignKeys + ) + with RowOutputHelpers { + + override def Table = new TypedIdTable(_) { table => + + override def Column = new TypedIdColumn(_) { + override def rawType: String = { + typeReplacements.getOrElse(model.tpe, super.rawType) + } + } + + override def EntityType = new EntityType { + override def code: String = + (if (classEnabled) "final " else "") + super.code + } + + override def code = Seq[Def](EntityType).map(_.docWithCode) + } + + override def code = tables.map(_.code.mkString("\n")).mkString("\n\n") +} + +class TableSourceCodeGenerator( + schemaOnlyModel: m.Model, + override val headerComment: String, + override val imports: String, + override val schemaName: String, + fullDatabaseModel: m.Model, + pkg: String, + manualForeignKeys: Map[(String, String), (String, String)], + override val parentType: Option[String], + idType: Option[String], + typeReplacements: Map[String, String]) + extends TypedIdSourceCodeGenerator(schemaOnlyModel, + fullDatabaseModel, + idType, + manualForeignKeys) + with TableOutputHelpers { + + val defaultIdImplementation = + """|final case class Id[T](v: Int) + |trait DefaultIdTypeMapper { + | val profile: slick.driver.JdbcProfile + | import profile.api._ + | implicit def idTypeMapper[A]: BaseColumnType[Id[A]] = MappedColumnType.base[Id[A], Int](_.v, Id(_)) + |} + |""".stripMargin + + override def code = super.code.lines.drop(1).mkString("\n") + // Drops needless import: `"import slick.model.ForeignKeyAction\n"`. + // Alias to ForeignKeyAction is in profile.api + // TODO: fix upstream + + override def Table = new this.TypedIdTable(_) { table => + override def TableClass = new TableClass() { + // We disable the option mapping, as it is a bit more complex to support and we don't appear to need it + override def optionEnabled = false + } + + // use hlists all the time + override def hlistEnabled: Boolean = true + + // if false rows are type aliases to hlists, if true rows are case classes + override def mappingEnabled: Boolean = true + + // create case class from colums + override def factory: String = + if (!hlistEnabled) super.factory + else { + val args = columns.zipWithIndex.map("a" + _._2) + val hlist = args.mkString("::") + ":: HNil" + val hlistType = columns + .map(_.actualType) + .mkString("::") + ":: HNil.type" + s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})" + } + + // from case class create columns + override def extractor: String = + if (!hlistEnabled) super.extractor + else + s"(a : ${TableClass.elementType}) => Some(" + columns + .map("a." + _.name) + .mkString("::") + ":: HNil)" + + override def EntityType = new EntityType { + override def enabled = false + } + + override def Column = new TypedIdColumn(_) { + override def rawType: String = { + typeReplacements.getOrElse(model.tpe, super.rawType) + } + } + + override def ForeignKey = new ForeignKey(_) { + override def code = { + val fkColumns = compoundValue(referencingColumns.map(_.name)) + val qualifier = + if (referencedTable.model.name.schema == referencingTable.model.name.schema) + "" + else + referencedTable.model.name.schema.fold("")(sname => + s"$pkg.$sname.") + + val qualifiedName = qualifier + referencedTable.TableValue.name + val pkColumns = compoundValue(referencedColumns.map(c => + s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" + else ""}")) + val fkName = referencingColumns + .map(_.name) + .flatMap(_.split("_")) + .map(_.capitalize) + .mkString + .uncapitalize + "Fk" + s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=$onUpdate, onDelete=$onDelete)""" + } + } + } +} diff --git a/src/main/scala/Main.scala b/src/main/scala/Main.scala new file mode 100644 index 0000000..55275a3 --- /dev/null +++ b/src/main/scala/Main.scala @@ -0,0 +1,108 @@ +import java.net.URI +import java.nio.file.Paths + +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.concurrent.ExecutionContext.Implicits.global +import slick.backend.DatabaseConfig +import slick.codegen.SourceCodeGenerator +import slick.driver.JdbcProfile + +trait TableFileGenerator { self: SourceCodeGenerator => + def writeTablesToFile(profile: String, + folder: String, + pkg: String, + fileName: String): Unit +} + +trait RowFileGenerator { self: SourceCodeGenerator => + def writeRowsToFile(folder: String, pkg: String, fileName: String): Unit +} + +object Generator { + + private def outputSchemaCode(schemaName: String, + profile: String, + folder: String, + pkg: String, + tableGen: TableFileGenerator, + rowGen: RowFileGenerator): Unit = { + val camelSchemaName = schemaName.split('_').map(_.capitalize).mkString("") + + tableGen.writeTablesToFile(profile: String, + folder: String, + pkg: String, + fileName = s"${camelSchemaName}Tables.scala") + rowGen.writeRowsToFile(folder: String, + pkg: String, + fileName = s"${camelSchemaName}Rows.scala") + } + + def run(uri: URI, + pkg: String, + schemaNames: Option[List[String]], + outputPath: String, + manualForeignKeys: Map[(String, String), (String, String)], + parentType: Option[String], + idType: Option[String], + header: String, + schemaImports: List[String], + typeReplacements: Map[String, String]) = { + val dc: DatabaseConfig[JdbcProfile] = + DatabaseConfig.forURI[JdbcProfile](uri) + val parsedSchemasOpt: Option[Map[String, List[String]]] = + schemaNames.map(SchemaParser.parse) + val imports = schemaImports.map("import " + _).mkString("\n") + + try { + val dbModel: slick.model.Model = Await.result( + dc.db.run(SchemaParser.createModel(dc.driver, parsedSchemasOpt)), + Duration.Inf) + + parsedSchemasOpt.getOrElse(Map.empty).foreach { + case (schemaName, tables) => + val profile = + s"""slick.backend.DatabaseConfig.forConfig[slick.driver.JdbcProfile]("${uri + .getFragment()}").driver""" + + val schemaOnlyModel = Await.result( + dc.db.run( + SchemaParser.createModel(dc.driver, + Some(Map(schemaName -> tables)))), + Duration.Inf) + + val rowGenerator = new RowSourceCodeGenerator( + schemaOnlyModel, + headerComment = header, + imports = imports, + schemaName = schemaName, + dbModel, + idType, + manualForeignKeys, + typeReplacements + ) + + val tableGenerator = + new TableSourceCodeGenerator(schemaOnlyModel = schemaOnlyModel, + headerComment = header, + imports = imports, + schemaName = schemaName, + fullDatabaseModel = dbModel, + pkg = pkg, + manualForeignKeys, + parentType = parentType, + idType, + typeReplacements) + + outputSchemaCode(schemaName = schemaName, + profile = profile, + folder = outputPath, + pkg = pkg, + tableGen = tableGenerator, + rowGen = rowGenerator) + } + } finally { + dc.db.close() + } + } +} diff --git a/src/main/scala/NamespacedCodegen.scala b/src/main/scala/NamespacedCodegen.scala deleted file mode 100644 index 85e21bf..0000000 --- a/src/main/scala/NamespacedCodegen.scala +++ /dev/null @@ -1,262 +0,0 @@ -import java.net.URI -import java.nio.file.Paths - -import scala.concurrent.Await -import scala.concurrent.duration.Duration -import scala.concurrent.ExecutionContext.Implicits.global -import slick.backend.DatabaseConfig -import slick.codegen.{ - SourceCodeGenerator, - StringGeneratorHelpers -} -import slick.dbio.DBIO -import slick.driver.JdbcProfile -import slick.jdbc.meta.MTable -import slick.{model => sModel} -import slick.model.{Column, Model, Table, QualifiedName} - -object Generator { - - def run(uri: URI, - pkg: String, - schemaNames: Option[List[String]], - outputPath: String, - manualForeignKeys: Map[(String, String), (String, String)], - parentType: Option[String], - idType: Option[String], - header: String, - schemaImports: List[String], - typeReplacements: Map[String, String]) = { - val dc: DatabaseConfig[JdbcProfile] = - DatabaseConfig.forURI[JdbcProfile](uri) - val parsedSchemasOpt: Option[Map[String, List[String]]] = - schemaNames.map(SchemaParser.parse) - - try { - val dbModel: Model = Await.result( - dc.db.run(SchemaParser.createModel(dc.driver, parsedSchemasOpt)), - Duration.Inf) - - parsedSchemasOpt.getOrElse(Map.empty).foreach { - case (schemaName, tables) => - val profile = - s"""slick.backend.DatabaseConfig.forConfig[slick.driver.JdbcProfile]("${uri - .getFragment()}").driver""" - - val schemaOnlyModel = Await.result( - dc.db.run( - SchemaParser.createModel(dc.driver, - Some(Map(schemaName -> tables)))), - Duration.Inf) - - val generator = new Generator(pkg, - dbModel, - schemaOnlyModel, - manualForeignKeys, - parentType, - idType, - header, - schemaImports, - typeReplacements) - generator.writeToFile(profile = profile, - folder = outputPath, - pkg = pkg, - container = schemaName, - fileName = s"${schemaName}.scala") - } - } finally { - dc.db.close() - } - } - -} - -class Generator(pkg: String, - fullDatabaseModel: Model, - schemaOnlyModel: Model, - manualForeignKeys: Map[(String, String), (String, String)], - override val parentType: Option[String], - idType: Option[String], - override val headerComment: String, - schemaImports: List[String], - typeReplacements: Map[String, String]) - extends SourceCodeGenerator(schemaOnlyModel) - with OutputHelpers { - - override val imports = schemaImports.map("import " + _).mkString("\n") - - val defaultIdImplementation = - """|final case class Id[T](v: Int) - |trait DefaultIdTypeMapper { - | val profile: slick.driver.JdbcProfile - | import profile.api._ - | implicit def idTypeMapper[A]: BaseColumnType[Id[A]] = MappedColumnType.base[Id[A], Int](_.v, Id(_)) - |} - |""".stripMargin - - override def code = super.code.lines.drop(1).mkString("\n") - // Drops needless import: `"import slick.model.ForeignKeyAction\n"`. - // Alias to ForeignKeyAction is in profile.api - // TODO: fix upstream - - override def Table = new Table(_) { table => - - override def TableClass = new TableClass() { - // We disable the option mapping, as it is a bit more complex to support and we don't appear to need it - override def optionEnabled = false - } - - // use hlists all the time - override def hlistEnabled: Boolean = true - - // if false rows are type aliases to hlists, if true rows are case classes - override def mappingEnabled: Boolean = true - - // create case class from colums - override def factory: String = - if (!hlistEnabled) super.factory - else { - val args = columns.zipWithIndex.map("a" + _._2) - val hlist = args.mkString("::") + ":: HNil" - val hlistType = columns - .map(_.actualType) - .mkString("::") + ":: HNil.type" - s"((h : $hlistType) => h match {case $hlist => ${TableClass.elementType}(${args.mkString(",")})})" - } - - // from case class create columns - override def extractor: String = - if (!hlistEnabled) super.extractor - else - s"(a : ${TableClass.elementType}) => Some(" + columns - .map("a." + _.name) - .mkString("::") + ":: HNil)" - - override def EntityType = new EntityTypeDef { - override def code: String = - // Wartremover wants `final` - // But can't have the final case class inside the trait - // TODO: Fix by putting case classes in package or object - // TODO: Upstream default should be false. - (if (classEnabled) "sealed " else "") + super.code - } - - override def Column = new Column(_) { column => - // use fullDatabasemodel model here for cross-schema foreign keys - val manualReferences = - SchemaParser.references(fullDatabaseModel, manualForeignKeys) - - // work out the destination of the foreign key - def derefColumn(table: sModel.Table, - column: sModel.Column): (sModel.Table, sModel.Column) = { - val referencedColumn: Seq[(sModel.Table, sModel.Column)] = - table.foreignKeys - .filter(tableFk => tableFk.referencingColumns.forall(_ == column)) - .filter(columnFk => columnFk.referencedColumns.length == 1) - .flatMap(_.referencedColumns.map(c => - (fullDatabaseModel.tablesByName(c.table), c))) - assert(referencedColumn.distinct.length <= 1, referencedColumn) - - referencedColumn.headOption - .orElse(manualReferences.get((table.name.asString, column.name))) - .map((derefColumn _).tupled) - .getOrElse((table, column)) - } - - def tableReferenceName(tableName: QualifiedName) = { - val schemaObjectName = tableName.schema.getOrElse("`public`") - val rowTypeName = entityName(tableName.table) - val idTypeName = idType.getOrElse("Id") - s"$idTypeName[$schemaObjectName.$rowTypeName]" - } - - // re-write ids other custom types - override def rawType: String = { - val (referencedTable, referencedColumn) = - derefColumn(table.model, column.model) - if (referencedColumn.options.contains( - slick.ast.ColumnOption.PrimaryKey)) - tableReferenceName(referencedTable.name) - else typeReplacements.getOrElse(model.tpe, model.tpe) - } - } - - override def ForeignKey = new ForeignKey(_) { - override def code = { - val fkColumns = compoundValue(referencingColumns.map(_.name)) - val qualifier = - if (referencedTable.model.name.schema == referencingTable.model.name.schema) - "" - else - referencedTable.model.name.schema.fold("")(sname => - s"$pkg.$sname.") - - val qualifiedName = qualifier + referencedTable.TableValue.name - val pkColumns = compoundValue(referencedColumns.map(c => - s"r.${c.name}${if (!c.model.nullable && referencingColumns.forall(_.model.nullable)) ".?" - else ""}")) - val fkName = referencingColumns - .map(_.name) - .flatMap(_.split("_")) - .map(_.capitalize) - .mkString - .uncapitalize + "Fk" - s"""lazy val $fkName = foreignKey("$dbName", $fkColumns, $qualifiedName)(r => $pkColumns, onUpdate=$onUpdate, onDelete=$onDelete)""" - } - } - } -} - -object SchemaParser { - def references(dbModel: Model, - tcMappings: Map[(String, String), (String, String)]) - : Map[(String, String), (Table, Column)] = { - def getTableColumn(tc: (String, String)): (Table, Column) = { - val (tableName, columnName) = tc - val table = dbModel.tables - .find(_.name.asString == tableName) - .getOrElse(throw new RuntimeException("No table " + tableName)) - val column = table.columns - .find(_.name == columnName) - .getOrElse(throw new RuntimeException( - "No column " + columnName + " in table " + tableName)) - (table, column) - } - - tcMappings.map { - case (from, to) => ({ getTableColumn(from); from }, getTableColumn(to)) - } - } - - def parse(schemaTableNames: List[String]): Map[String, List[String]] = - schemaTableNames - .map(_.split('.')) - .groupBy(_.head) - .mapValues(_.flatMap(_.tail)) - - def createModel( - jdbcProfile: JdbcProfile, - mappedSchemasOpt: Option[Map[String, List[String]]]): DBIO[Model] = { - import slick.jdbc.meta.MQName - - val filteredTables = mappedSchemasOpt.map { mappedSchemas => - MTable.getTables.map { (tables: Vector[MTable]) => - mappedSchemas.flatMap { - case (schemaName, tableNames) => - tableNames.map( - tableName => - tables - .find(table => - table.name match { - case MQName(_, Some(`schemaName`), `tableName`) => true - case _ => false - }) - .getOrElse(throw new IllegalArgumentException( - s"$schemaName.$tableName does not exist in the connected database."))) - }.toList - } - } - - jdbcProfile.createModel(filteredTables) - } -} diff --git a/src/main/scala/OutputHelpers.scala b/src/main/scala/OutputHelpers.scala index ce22f2a..97fb4e9 100644 --- a/src/main/scala/OutputHelpers.scala +++ b/src/main/scala/OutputHelpers.scala @@ -1,28 +1,76 @@ -trait OutputHelpers extends slick.codegen.OutputHelpers { +import slick.codegen.{SourceCodeGenerator, OutputHelpers} - def imports: String +trait TableOutputHelpers extends TableFileGenerator with OutputHelpers { + self: SourceCodeGenerator => - def headerComment: String = "" + def headerComment: String + def schemaName: String + def imports: String - override def packageCode(profile: String, - pkg: String, - container: String, - parentType: Option[String]): String = { - val traitName = container.capitalize + "SchemaDef" + def packageTableCode(headerComment: String, + pkg: String, + schemaName: String, + imports: String, + profile: String): String = s"""|${headerComment.trim().lines.map("// " + _).mkString("\n")} |package $pkg + |package $schemaName | |$imports | |/** Stand-alone Slick data model for immediate use */ - |object $container extends { + |// TODO: change this to `object tables` + |object `package` extends { | val profile = $profile - |} with $traitName + |} with Tables | |/** Slick data model trait for extension, choice of backend or usage in the cake pattern. (Make sure to initialize this late.) */ - |trait $traitName${parentType.fold("")(" extends " + _)} { + |trait Tables${parentType.fold("")(" extends " + _)} { | import profile.api._ | ${indent(code)} - |}""".stripMargin.trim() + |} + |""".stripMargin.trim() + + def writeTablesToFile(profile: String, + folder: String, + pkg: String, + fileName: String): Unit = { + writeStringToFile( + content = + packageTableCode(headerComment, pkg, schemaName, imports, profile), + folder = folder, + pkg = s"$pkg.$schemaName", + fileName = fileName) + } +} + +trait RowOutputHelpers extends RowFileGenerator with OutputHelpers { + self: SourceCodeGenerator => + + def headerComment: String + def schemaName: String + def imports: String + + def packageRowCode(headerComment: String, + schemaName: String, + pkg: String, + imports: String): String = + s"""|${headerComment.trim().lines.map("// " + _).mkString("\n")} + |/** Definitions for table rows types of database schema $schemaName */ + |package $pkg + |package $schemaName + | + |$imports + | + |$code + |""".stripMargin.trim() + + def writeRowsToFile(folder: String, pkg: String, fileName: String): Unit = { + + writeStringToFile( + content = packageRowCode(headerComment, schemaName, pkg, imports), + folder = folder, + pkg = s"$pkg.$schemaName", + fileName = fileName) } } diff --git a/src/main/scala/SchemaParser.scala b/src/main/scala/SchemaParser.scala new file mode 100644 index 0000000..1186f11 --- /dev/null +++ b/src/main/scala/SchemaParser.scala @@ -0,0 +1,60 @@ +import scala.concurrent.ExecutionContext.Implicits.global + +import slick.dbio.DBIO +import slick.driver.JdbcProfile +import slick.jdbc.meta.MTable +import slick.{model => m} + +object SchemaParser { + def references(dbModel: m.Model, + tcMappings: Map[(String, String), (String, String)]) + : Map[(String, String), (m.Table, m.Column)] = { + def getTableColumn(tc: (String, String)): (m.Table, m.Column) = { + val (tableName, columnName) = tc + val table = dbModel.tables + .find(_.name.asString == tableName) + .getOrElse(throw new RuntimeException("No table " + tableName)) + val column = table.columns + .find(_.name == columnName) + .getOrElse(throw new RuntimeException( + "No column " + columnName + " in table " + tableName)) + (table, column) + } + + tcMappings.map { + case (from, to) => ({ getTableColumn(from); from }, getTableColumn(to)) + } + } + + def parse(schemaTableNames: List[String]): Map[String, List[String]] = + schemaTableNames + .map(_.split('.')) + .groupBy(_.head) + .mapValues(_.flatMap(_.tail)) + + def createModel( + jdbcProfile: JdbcProfile, + mappedSchemasOpt: Option[Map[String, List[String]]]): DBIO[m.Model] = { + import slick.jdbc.meta.MQName + + val filteredTables = mappedSchemasOpt.map { mappedSchemas => + MTable.getTables.map { (tables: Vector[MTable]) => + mappedSchemas.flatMap { + case (schemaName, tableNames) => + tableNames.map( + tableName => + tables + .find(table => + table.name match { + case MQName(_, Some(`schemaName`), `tableName`) => true + case _ => false + }) + .getOrElse(throw new IllegalArgumentException( + s"$schemaName.$tableName does not exist in the connected database."))) + }.toList + } + } + + jdbcProfile.createModel(filteredTables) + } +} diff --git a/src/main/scala/TypedIdTable.scala b/src/main/scala/TypedIdTable.scala new file mode 100644 index 0000000..1a8f986 --- /dev/null +++ b/src/main/scala/TypedIdTable.scala @@ -0,0 +1,50 @@ +import slick.codegen.SourceCodeGenerator +import slick.{model => m} + +class TypedIdSourceCodeGenerator( + singleSchemaModel: m.Model, + databaseModel: m.Model, + idType: Option[String], + manualForeignKeys: Map[(String, String), (String, String)] +) extends SourceCodeGenerator(singleSchemaModel) { + val manualReferences = + SchemaParser.references(databaseModel, manualForeignKeys) + + def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = { + val referencedColumn: Seq[(m.Table, m.Column)] = + table.foreignKeys + .filter(tableFk => tableFk.referencingColumns.forall(_ == column)) + .filter(columnFk => columnFk.referencedColumns.length == 1) + .flatMap(_.referencedColumns.map(c => + (databaseModel.tablesByName(c.table), c))) + assert(referencedColumn.distinct.length <= 1, referencedColumn) + + referencedColumn.headOption + .orElse(manualReferences.get((table.name.asString, column.name))) + .map((derefColumn _).tupled) + .getOrElse((table, column)) + } + + class TypedIdTable(model: m.Table) extends Table(model) { table => + class TypedIdColumn(override val model: m.Column) extends Column(model) { + column => + + def tableReferenceName(tableName: m.QualifiedName) = { + val schemaObjectName = tableName.schema.getOrElse("`public`") + val rowTypeName = entityName(tableName.table) + val idTypeName = idType.getOrElse("Id") + s"$idTypeName[$schemaObjectName.$rowTypeName]" + } + + override def rawType: String = { + // write key columns as Id types + val (referencedTable, referencedColumn) = + derefColumn(table.model, column.model) + if (referencedColumn.options.contains( + slick.ast.ColumnOption.PrimaryKey)) + tableReferenceName(referencedTable.name) + else super.rawType + } + } + } +} |