From 0d2f133e92a0f9601d4fb82924b13d1a0416b222 Mon Sep 17 00:00:00 2001 From: Stewart Stewart Date: Mon, 27 Feb 2017 18:18:55 -0800 Subject: simplify code with foreign keys types --- src/main/scala/TypedIdTable.scala | 64 +++++++++++++++------------------------ 1 file changed, 24 insertions(+), 40 deletions(-) (limited to 'src') diff --git a/src/main/scala/TypedIdTable.scala b/src/main/scala/TypedIdTable.scala index 3867d1a..93ea8ee 100644 --- a/src/main/scala/TypedIdTable.scala +++ b/src/main/scala/TypedIdTable.scala @@ -16,53 +16,37 @@ class TypedIdSourceCodeGenerator( "Int" -> "serialKeyMapper" ) - def derefColumn(table: m.Table, column: m.Column): (m.Table, m.Column) = { - val referencedColumn: Seq[(m.Table, m.Column)] = - table.foreignKeys - .filter(tableFk => tableFk.referencingColumns.forall(_ == column)) - .filter(columnFk => columnFk.referencedColumns.length == 1) - .flatMap(_.referencedColumns.map(c => - (databaseModel.tablesByName(c.table), c))) - assert(referencedColumn.distinct.length <= 1, referencedColumn) + def pKeyTypeTag(columnRef: m.Column): String = { + val schemaName = columnRef.table.schema.getOrElse("`public`") + val tableName = entityName(columnRef.table.table) + s"$schemaName.$tableName." + } - referencedColumn.headOption - .orElse(manualReferences.get((table.name.asString, column.name))) - .map((derefColumn _).tupled) - .getOrElse((table, column)) + def pKeyType(columnRef: m.Column): String = { + s"${idType.getOrElse("Id")}[${pKeyTypeTag(columnRef)}]" } class TypedIdTable(model: m.Table) extends Table(model) { table => - class TypedIdColumn(override val model: m.Column) extends Column(model) { - column => - - def rowTypeFor(tableName: m.QualifiedName) = { - val schemaObjectName = tableName.schema.getOrElse("`public`") - val rowTypeName = entityName(tableName.table) - s"$schemaObjectName.$rowTypeName" - } - override def code = { - val (referencedTable, referencedColumn) = - derefColumn(table.model, column.model) - if (referencedColumn.options.contains( - slick.ast.ColumnOption.PrimaryKey)) - s"""|implicit val ${name}KeyMapper: BaseColumnType[${rawType}] = - | ${modelTypeToColumnMaper(model.tpe)}[${rowTypeFor(referencedTable.name)}]\n - |${super.code}""" - else - super.code - } + val keyReferences: Map[m.Column, m.Column] = { + val fks = model.foreignKeys + .filter(_.referencedColumns.length == 1) + .filter(_.referencedColumns.forall( + _.options.contains(slick.ast.ColumnOption.PrimaryKey))) + .flatMap(fk => + fk.referencingColumns.flatMap(from => + fk.referencedColumns.headOption.map(to => (from -> to)))) + val pk = model.primaryKey + .filter(_.columns.length == 1) + .flatMap(_.columns.headOption.map(c => (c -> c))) + + fks.toMap ++ pk + } + class TypedIdColumn(override val model: m.Column) extends Column(model) { + column => override def rawType: String = { - // write key columns as Id types - val (referencedTable, referencedColumn) = - derefColumn(table.model, column.model) - if (referencedColumn.options.contains( - slick.ast.ColumnOption.PrimaryKey)) { - val idTypeName = idType.getOrElse("Id") - s"$idTypeName[${rowTypeFor(referencedTable.name)}]" - } - else super.rawType + keyReferences.get(model).fold(super.rawType)(pKeyType) } } } -- cgit v1.2.3