aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/Main.scala
blob: 2a4624ffa993532b6bda9077471b583d81d2cf0a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import java.net.URI

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import slick.basic.DatabaseConfig
import slick.codegen.SourceCodeGenerator
import slick.jdbc.JdbcProfile

/** Capability to write generated Slick table definitions to a Scala source file.
  *
  * The self-type restricts this trait to being mixed into a `SourceCodeGenerator`.
  */
trait TableFileGenerator { this: SourceCodeGenerator =>

  /** Writes the generated table code to a source file.
    *
    * @param profile  Scala source expression for the Slick profile the generated code should use
    * @param folder   output directory
    * @param pkg      package for the generated code
    * @param fileName name of the file to write (presumably relative to `folder`)
    */
  def writeTablesToFile(
      profile: String,
      folder: String,
      pkg: String,
      fileName: String
  ): Unit
}

/** Capability to write generated row (record) definitions to a Scala source file.
  *
  * The self-type restricts this trait to being mixed into a `SourceCodeGenerator`.
  */
trait RowFileGenerator { this: SourceCodeGenerator =>

  /** Writes the generated row code to a source file.
    *
    * @param folder   output directory
    * @param pkg      package for the generated code
    * @param fileName name of the file to write (presumably relative to `folder`)
    */
  def writeRowsToFile(
      folder: String,
      pkg: String,
      fileName: String
  ): Unit
}

/** Drives Slick source generation: for each requested database schema it writes one
  * `<Schema>Tables.scala` file and one `<Schema>Rows.scala` file.
  */
object Generator {

  /** Writes the tables file and the rows file for a single schema.
    *
    * Output file names are derived from `schemaName` by splitting on `_` and capitalizing each
    * segment, e.g. `user_data` yields `UserDataTables.scala` and `UserDataRows.scala`.
    *
    * @param schemaName schema whose name determines the output file names
    * @param profile    Scala source expression for the Slick profile, forwarded to `tableGen`
    * @param folder     output directory, forwarded to both generators
    * @param pkg        target package, forwarded to both generators
    * @param tableGen   writer for the tables file
    * @param rowGen     writer for the rows file
    */
  private def outputSchemaCode(schemaName: String,
                               profile: String,
                               folder: String,
                               pkg: String,
                               tableGen: TableFileGenerator,
                               rowGen: RowFileGenerator): Unit = {
    val camelSchemaName = schemaName.split('_').map(_.capitalize).mkString

    // Real named arguments (the previous `profile: String` form was a type ascription, not naming).
    tableGen.writeTablesToFile(profile = profile,
                               folder = folder,
                               pkg = pkg,
                               fileName = s"${camelSchemaName}Tables.scala")
    rowGen.writeRowsToFile(folder = folder, pkg = pkg, fileName = s"${camelSchemaName}Rows.scala")
  }

  /** Generates tables and rows source files for every schema listed in `schemaNames`.
    *
    * Opens the Slick `DatabaseConfig` located by `uri`, loads a model for all requested schemas
    * (passed to each generator as `fullDatabaseModel`), then for each schema loads a model limited
    * to that schema and writes its files into `outputPath`. The database handle is closed in a
    * `finally`, so it is released even when generation fails. Database calls block via
    * `Await.result` with no timeout.
    *
    * When `schemaNames` is `None`, no files are written (the full model is still loaded).
    *
    * @param uri               Slick config URI; its fragment is also embedded in the generated tables
    *                          code as the config path used to obtain the profile
    * @param pkg               package passed to both generators
    * @param schemaNames       schema list, parsed by `ModelTransformation.parseSchemaList` into a
    *                          schema -> tables map
    * @param outputPath        directory the generated files are written to
    * @param manualForeignKeys extra foreign-key mappings, forwarded unchanged to both generators
    * @param parentType        forwarded to the tables generator
    * @param idType            forwarded to both generators
    * @param header            header comment placed in both generated files
    * @param tablesFileImports names turned into `import` lines for the tables file
    * @param rowsFileImports   names turned into `import` lines for the rows file
    * @param typeReplacements  type-name replacements, forwarded unchanged to both generators
    */
  def run(uri: URI,
          pkg: String,
          schemaNames: Option[List[String]],
          outputPath: String,
          manualForeignKeys: Map[(String, String), (String, String)],
          parentType: Option[String],
          idType: Option[String],
          header: String,
          tablesFileImports: List[String],
          rowsFileImports: List[String],
          typeReplacements: Map[String, String]): Unit = {
    val dc: DatabaseConfig[JdbcProfile] =
      DatabaseConfig.forURI[JdbcProfile](uri)
    val parsedSchemasOpt: Option[Map[String, List[String]]] =
      schemaNames.map(ModelTransformation.parseSchemaList)

    def importStatements(imports: List[String]) = imports.map("import " + _).mkString("\n")

    // Loop-invariant inputs to the per-schema generators, computed once.
    val rowsImports = importStatements(rowsFileImports)
    val tablesImports = importStatements(tablesFileImports)

    // URI#getFragment returns null when the URI has no fragment; interpolating it directly would
    // emit the config path "null". Fall back to the empty path, which (as with forURI on a
    // fragment-less URI) should select the config root — confirm against the Slick version in use.
    val configPath = Option(uri.getFragment).getOrElse("")
    // NOTE(review): this emits the pre-3.2 Slick API (slick.backend / slick.driver / `.driver`)
    // although the generator itself uses slick.basic / slick.jdbc; confirm the Slick version the
    // generated code compiles against still provides these deprecated names.
    val profile =
      s"""slick.backend.DatabaseConfig.forConfig[slick.driver.JdbcProfile]("$configPath").driver"""

    try {
      val dbModel: slick.model.Model =
        Await.result(dc.db.run(ModelTransformation.createModel(dc.profile, parsedSchemasOpt)), Duration.Inf)

      parsedSchemasOpt.getOrElse(Map.empty).foreach {
        case (schemaName, tables) =>
          // Model restricted to just this schema's tables.
          val schemaOnlyModel = Await.result(
            dc.db.run(ModelTransformation.createModel(dc.profile, Some(Map(schemaName -> tables)))),
            Duration.Inf)

          val rowGenerator = new RowSourceCodeGenerator(
            model = schemaOnlyModel,
            headerComment = header,
            imports = rowsImports,
            schemaName = schemaName,
            fullDatabaseModel = dbModel,
            idType,
            manualForeignKeys,
            typeReplacements
          )

          val tableGenerator =
            new TableSourceCodeGenerator(
              schemaOnlyModel = schemaOnlyModel,
              headerComment = header,
              imports = tablesImports,
              schemaName = schemaName,
              fullDatabaseModel = dbModel,
              pkg = pkg,
              manualForeignKeys,
              parentType = parentType,
              idType,
              typeReplacements
            )

          outputSchemaCode(schemaName = schemaName,
                           profile = profile,
                           folder = outputPath,
                           pkg = pkg,
                           tableGen = tableGenerator,
                           rowGen = rowGenerator)
      }
    } finally {
      dc.db.close()
    }
  }
}