diff options
author | Stewart Stewart <stewinsalot@gmail.com> | 2017-02-27 18:18:55 -0800 |
---|---|---|
committer | Stewart Stewart <stewinsalot@gmail.com> | 2017-02-27 18:18:55 -0800 |
commit | 0d2f133e92a0f9601d4fb82924b13d1a0416b222 (patch) | |
tree | d733201cec9eff18f9764c91723db772fc8d995e | |
parent | 6177b8733dca879845e035321ee585a78df7c399 (diff) | |
download | slick-codegen-plugin-0d2f133e92a0f9601d4fb82924b13d1a0416b222.tar.gz slick-codegen-plugin-0d2f133e92a0f9601d4fb82924b13d1a0416b222.tar.bz2 slick-codegen-plugin-0d2f133e92a0f9601d4fb82924b13d1a0416b222.zip |
simplify code with foreign keys types
-rw-r--r-- | src/main/scala/TypedIdTable.scala | 64 |
1 files changed, 24 insertions, 40 deletions
diff --git a/src/main/scala/TypedIdTable.scala b/src/main/scala/TypedIdTable.scala index 3867d1a..93ea8ee 100644 --- a/src/main/scala/TypedIdTable.scala +++ b/src/main/scala/TypedIdTable.scala @@ -16,53 +16,37 @@ class TypedIdSourceCodeGenerator( "Int" -> "serialKeyMapper" ) - def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = { - val referencedColumn: Seq[(m.Table, m.Column)] = - table.foreignKeys - .filter(tableFk => tableFk.referencingColumns.forall(_ == column)) - .filter(columnFk => columnFk.referencedColumns.length == 1) - .flatMap(_.referencedColumns.map(c => - (databaseModel.tablesByName(c.table), c))) - assert(referencedColumn.distinct.length <= 1, referencedColumn) + def pKeyTypeTag(columnRef: m.Column): String = { + val schemaName = columnRef.table.schema.getOrElse("`public`") + val tableName = entityName(columnRef.table.table) + s"$schemaName.$tableName." + } - referencedColumn.headOption - .orElse(manualReferences.get((table.name.asString, column.name))) - .map((derefColumn _).tupled) - .getOrElse((table, column)) + def pKeyType(columnRef: m.Column): String = { + s"${idType.getOrElse("Id")}[${pKeyTypeTag(columnRef)}]" } class TypedIdTable(model: m.Table) extends Table(model) { table => - class TypedIdColumn(override val model: m.Column) extends Column(model) { - column => - - def rowTypeFor(tableName: m.QualifiedName) = { - val schemaObjectName = tableName.schema.getOrElse("`public`") - val rowTypeName = entityName(tableName.table) - s"$schemaObjectName.$rowTypeName" - } - override def code = { - val (referencedTable, referencedColumn) = - derefColumn(table.model, column.model) - if (referencedColumn.options.contains( - slick.ast.ColumnOption.PrimaryKey)) - s"""|implicit val ${name}KeyMapper: BaseColumnType[${rawType}] = - | ${modelTypeToColumnMaper(model.tpe)}[${rowTypeFor(referencedTable.name)}]\n - |${super.code}""" - else - super.code - } + val keyReferences: Map[m.Column, m.Column] = { + val fks = model.foreignKeys + .filter(_.referencedColumns.length == 1) + .filter(_.referencedColumns.forall( + _.options.contains(slick.ast.ColumnOption.PrimaryKey))) + .flatMap(fk => + fk.referencingColumns.flatMap(from => + fk.referencedColumns.headOption.map(to => (from -> to)))) + val pk = model.primaryKey + .filter(_.columns.length == 1) + .flatMap(_.columns.headOption.map(c => (c -> c))) + + fks.toMap ++ pk + } + class TypedIdColumn(override val model: m.Column) extends Column(model) { + column => override def rawType: String = { - // write key columns as Id types - val (referencedTable, referencedColumn) = - derefColumn(table.model, column.model) - if (referencedColumn.options.contains( - slick.ast.ColumnOption.PrimaryKey)) { - val idTypeName = idType.getOrElse("Id") - s"$idTypeName[${rowTypeFor(referencedTable.name)}]" - } - else super.rawType + keyReferences.get(model).fold(super.rawType)(pKeyType) } } } |