aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJon Pretty <jon.pretty@propensive.com>2017-11-05 19:57:30 +0000
committerJon Pretty <jon.pretty@propensive.com>2017-11-05 19:57:30 +0000
commita42cceae99ca8517ecff77fecdb23eba4d2c1036 (patch)
tree8deb9af6764158556ba016a660aebaa5aa71723f
parent3f9ee733ac73f31337433227eb6871efce18981c (diff)
downloadmagnolia-a42cceae99ca8517ecff77fecdb23eba4d2c1036.tar.gz
magnolia-a42cceae99ca8517ecff77fecdb23eba4d2c1036.tar.bz2
magnolia-a42cceae99ca8517ecff77fecdb23eba4d2c1036.zip
Deduplication within case class parameter typeclasses
-rw-r--r--core/src/main/scala/magnolia.scala119
-rw-r--r--examples/src/main/scala/typeclasses.scala10
-rw-r--r--tests/src/main/scala/tests.scala2
3 files changed, 80 insertions, 51 deletions
diff --git a/core/src/main/scala/magnolia.scala b/core/src/main/scala/magnolia.scala
index 2d2bbcd..50d3ab5 100644
--- a/core/src/main/scala/magnolia.scala
+++ b/core/src/main/scala/magnolia.scala
@@ -48,8 +48,9 @@ object JoinContext {
}
}
-abstract class JoinContext[Tc[_], T](val typeName: String, val isObject: Boolean, val parameters: Array[Param[Tc, T]]) {
+abstract class JoinContext[Tc[_], T](val typeName: String, val isObject: Boolean, params: Array[Param[Tc, T]]) {
def construct(param: ((Param[Tc, T]) => Any)): T
+ def parameters: Seq[Param[Tc, T]] = params
}
object Magnolia {
@@ -65,6 +66,8 @@ object Magnolia {
def findType(key: Type): Option[TermName] =
recursionStack(c.enclosingPosition).frames.find(_.genericType == key).map(_.termName(c))
+ case class Typeclass(typ: c.Type, tree: c.Tree)
+
def recurse[T](path: TypePath, key: Type, value: TermName)(fn: => T):
Option[T] = {
recursionStack = recursionStack.updated(
@@ -89,11 +92,11 @@ object Magnolia {
}
}
- def typeclassTree(paramName: Option[String],
- genericType: Type,
- typeConstructor: Type,
- assignedName: TermName): Tree = {
+ def typeclassTree(paramName: Option[String], genericType: Type, typeConstructor: Type,
+ assignedName: TermName): Tree = {
+
val searchType = appliedType(typeConstructor, genericType)
+
findType(genericType).map { methodName =>
val methodAsString = methodName.encodedName.toString
q"_root_.magnolia.Deferred.apply[$searchType]($methodAsString)"
@@ -105,7 +108,7 @@ object Magnolia {
recurse(ChainedImplicit(genericType.toString), genericType, assignedName) {
c.inferImplicitValue(searchType, false, false)
}.get
- }.toOption.orElse(directInferImplicit(genericType, typeConstructor))
+ }.toOption.orElse(directInferImplicit(genericType, typeConstructor).map(_.tree))
}
recursionStack = recursionStack.updated(c.enclosingPosition, newStack)
inferredImplicit
@@ -123,7 +126,7 @@ object Magnolia {
}
def directInferImplicit(genericType: c.Type,
- typeConstructor: Type): Option[c.Tree] = {
+ typeConstructor: Type): Option[Typeclass] = {
val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase
val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass"))
@@ -138,51 +141,72 @@ object Magnolia {
// FIXME: Handle AnyVals
val result = if(isCaseObject) {
- val termSym = genericType.typeSymbol.companionSymbol
- val obj = termSym.asTerm
+ val obj = genericType.typeSymbol.companion.asTerm
val className = obj.name.toString
val impl = q"""
- ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType]($className, true, _root_.scala.Array(), $obj))
+ ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType]($className, true, new _root_.scala.Array(0), $obj))
"""
- Some(impl)
+ Some(Typeclass(genericType, impl))
} else if(isCaseClass) {
val caseClassParameters = genericType.decls.collect {
case m: MethodSymbol if m.isCaseAccessor => m.asMethod
}
val className = genericType.toString
- val typeclasses: List[(c.universe.MethodSymbol, c.Tree, c.Type)] = caseClassParameters.map { param =>
+ case class CaseParam(sym: c.universe.MethodSymbol, typeclass: c.Tree, paramType: c.Type, ref: c.TermName)
+
+ val caseParams: List[CaseParam] = caseClassParameters.foldLeft(List[CaseParam]()) { case (acc, param) =>
val paramName = param.name.encodedName.toString
val paramType = param.returnType.substituteTypes(genericType.etaExpand.typeParams, genericType.typeArgs)
- val derivedImplicit = recurse(ProductType(paramName, genericType.toString), genericType,
- assignedName) {
-
- typeclassTree(Some(paramName), paramType, typeConstructor, assignedName)
-
- }.getOrElse(c.abort(c.enclosingPosition, s"failed to get implicit for type $genericType"))
-
- (param, derivedImplicit, paramType)
- }.to[List]
+ acc.find(_.paramType == paramType).map { backRef =>
+ CaseParam(param, q"()", paramType, backRef.ref) :: acc
+ }.getOrElse {
+ val derivedImplicit = recurse(ProductType(paramName, genericType.toString), genericType,
+ assignedName) {
+ typeclassTree(Some(paramName), paramType, typeConstructor, assignedName)
+ }.getOrElse(c.abort(c.enclosingPosition, s"failed to get implicit for type $genericType"))
+
+ val ref = TermName(c.freshName("paramTypeclass"))
+ val assigned = q"""val $ref = $derivedImplicit"""
+ CaseParam(param, assigned, paramType, ref) :: acc
+ }
+ }.to[List].reverse
- val callables = typeclasses.map { case (param, typeclass, paramType) =>
- q"""_root_.magnolia.Param[$typeConstructor, $genericType, $paramType](${param.name.toString}, $typeclass, p => p.${TermName(param.name.toString)})"""
+ val paramsVal: TermName = TermName(c.freshName("parameters"))
+ val fnVal: TermName = TermName(c.freshName("fn"))
+
+ val preAssignments = caseParams.map(_.typeclass)
+
+ val assignments = caseParams.zipWithIndex.map { case (CaseParam(param, typeclass, paramType, ref), idx) =>
+ q"""$paramsVal($idx) = _root_.magnolia.Param[$typeConstructor, $genericType, $paramType](
+ ${param.name.toString}, $ref, _.${TermName(param.name.toString)}
+ )"""
}
- Some {
- q"""
- val parameters: _root_.scala.Array[Param[$typeConstructor, $genericType]] = _root_.scala.Array(..$callables)
- ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType]($className, false, parameters,
- (fn: (Param[$typeConstructor, $genericType] => Any)) => new $genericType(..${typeclasses.zipWithIndex.map { case (typeclass, idx) =>
- q"fn(parameters($idx)).asInstanceOf[${typeclass._3}]"
- } })
+ Some(Typeclass(genericType,
+ q"""{
+ ..$preAssignments
+ val $paramsVal: _root_.scala.Array[Param[$typeConstructor, $genericType]] =
+ new _root_.scala.Array(${assignments.length})
+ ..$assignments
+
+ ${c.prefix}.join(_root_.magnolia.JoinContext[$typeConstructor, $genericType](
+ $className,
+ false,
+ $paramsVal,
+ ($fnVal: Param[$typeConstructor, $genericType] => Any) =>
+ new $genericType(..${caseParams.zipWithIndex.map { case (typeclass, idx) =>
+ q"$fnVal($paramsVal($idx)).asInstanceOf[${typeclass.paramType}]"
+ } })
))
- """
- }
+ }"""
+ ))
} else if(isSealedTrait) {
val genericSubtypes = classType.get.knownDirectSubclasses.to[List]
val subtypes = genericSubtypes.map { sub =>
- val mapping = sub.asType.typeSignature.baseType(genericType.typeSymbol).typeArgs.zip(genericType.typeArgs).toMap
+ val typeArgs = sub.asType.typeSignature.baseType(genericType.typeSymbol).typeArgs
+ val mapping = typeArgs.zip(genericType.typeArgs).toMap
val newTypeParams = sub.asType.toType.typeArgs.map(mapping(_))
appliedType(sub.asType.toType.typeConstructor, newTypeParams)
}
@@ -193,15 +217,17 @@ object Magnolia {
c.abort(c.enclosingPosition, "")
}
+
+ val subclassesVal: TermName = TermName(c.freshName("subclasses"))
- val subclasses = subtypes.map { searchType =>
+ val assignments = subtypes.map { searchType =>
recurse(CoproductType(genericType.toString), genericType, assignedName) {
(searchType, typeclassTree(None, searchType, typeConstructor, assignedName))
}.getOrElse {
c.abort(c.enclosingPosition, s"failed to get implicit for type $searchType")
}
- }.map { case (typ, typeclass) =>
- q"""_root_.magnolia.Subclass[$typeConstructor, $genericType, $typ](
+ }.zipWithIndex.map { case ((typ, typeclass), idx) =>
+ q"""$subclassesVal($idx) = _root_.magnolia.Subclass[$typeConstructor, $genericType, $typ](
${typ.typeSymbol.name.toString},
$typeclass,
(t: $genericType) => t.isInstanceOf[$typ],
@@ -210,17 +236,22 @@ object Magnolia {
}
Some {
- q"""{
- ${c.prefix}.split(_root_.scala.collection.immutable.List[_root_.magnolia.Subclass[$typeConstructor, $genericType]](..$subclasses))
- }"""
+ Typeclass(genericType, q"""{
+ val $subclassesVal: _root_.scala.Array[_root_.magnolia.Subclass[$typeConstructor, $genericType]] =
+ new _root_.scala.Array(${assignments.size})
+
+ ..$assignments
+
+ ${c.prefix}.dispatch($subclassesVal: _root_.scala.Seq[_root_.magnolia.Subclass[$typeConstructor, $genericType]])
+ }""")
}
} else None
- result.map { r =>
- q"""{
+ result.map { case Typeclass(t, r) =>
+ Typeclass(t, q"""{
def $assignedName: $resultType = $r
$assignedName
- }"""
+ }""")
}
}
@@ -248,13 +279,13 @@ object Magnolia {
val result: Option[Tree] = if(!currentStack.frames.isEmpty) {
findType(genericType) match {
case None =>
- directInferImplicit(genericType, typeConstructor)
+ directInferImplicit(genericType, typeConstructor).map(_.tree)
case Some(enclosingRef) =>
val methodAsString = enclosingRef.toString
val searchType = appliedType(typeConstructor, genericType)
Some(q"_root_.magnolia.Deferred[$searchType]($methodAsString)")
}
- } else directInferImplicit(genericType, typeConstructor)
+ } else directInferImplicit(genericType, typeConstructor).map(_.tree)
if(currentStack.frames.isEmpty) recursionStack = ListMap()
diff --git a/examples/src/main/scala/typeclasses.scala b/examples/src/main/scala/typeclasses.scala
index c79a6d7..aaa74fd 100644
--- a/examples/src/main/scala/typeclasses.scala
+++ b/examples/src/main/scala/typeclasses.scala
@@ -19,7 +19,7 @@ object Show {
}.mkString(s"${context.typeName.split("\\.").last}(", ",", ")")
}
- def split[T](subclasses: List[Subclass[Typeclass, T]])(value: T): String =
+ def dispatch[T](subclasses: Seq[Subclass[Typeclass, T]])(value: T): String =
subclasses.map { sub => sub.cast.andThen { value =>
sub.typeclass.show(sub.cast(value))
} }.reduce(_ orElse _)(value)
@@ -38,7 +38,7 @@ object Eq {
context.parameters.forall { param => param.typeclass.equal(param.dereference(value1), param.dereference(value2)) }
}
- def split[T](subclasses: List[Subclass[Eq, T]]): Eq[T] = new Eq[T] {
+ def dispatch[T](subclasses: Seq[Subclass[Eq, T]]): Eq[T] = new Eq[T] {
def equal(value1: T, value2: T) =
subclasses.map { case subclass =>
subclass.cast.andThen { value => subclass.typeclass.equal(subclass.cast(value1), subclass.cast(value2)) }
@@ -58,7 +58,7 @@ object Default {
def default = context.construct { param => param.typeclass.default }
}
- def split[T](subclasses: List[Subclass[Default, T]])(): Default[T] = new Default[T] {
+ def dispatch[T](subclasses: Seq[Subclass[Default, T]])(): Default[T] = new Default[T] {
def default = subclasses.head.typeclass.default
}
@@ -73,7 +73,7 @@ object Decoder {
def join[T](context: JoinContext[Decoder, T])(value: String): T =
context.construct { param => param.typeclass.decode(value) }
- def split[T](subclasses: List[Subclass[Decoder, T]])(param: String): T =
+ def dispatch[T](subclasses: Seq[Subclass[Decoder, T]])(param: String): T =
subclasses.map { subclass =>
{ case _ if decodes(subclass.typeclass, param) => subclass.typeclass.decode(param) }: PartialFunction[String, T]
}.reduce(_ orElse _)(param)
@@ -106,7 +106,7 @@ case class Cyrillic(б: Letter, в: Letter, г: Letter, д: Letter, ж: Letter,
case class Latin(a: Letter, b: Letter, c: Letter, d: Letter, e: Letter, f: Letter, g: Letter, h: Letter, i: Letter, j: Letter, k: Letter, l: Letter, m: Letter) extends Alphabet
case class Letter(name: String, phonetic: String)
-//case class Country(name: String, language: Language, leader: Person)
+case class Country(name: String, language: Language, leader: Person)
case class Language(name: String, code: String, alphabet: Alphabet)
//case class Person(name: String, dateOfBirth: Date)
case class Date(year: Int, month: Month, day: Int)
diff --git a/tests/src/main/scala/tests.scala b/tests/src/main/scala/tests.scala
index f269643..3245e1b 100644
--- a/tests/src/main/scala/tests.scala
+++ b/tests/src/main/scala/tests.scala
@@ -25,8 +25,6 @@ object Tests extends TestApp {
Show.generic[Person].show(Person("John Smith", 34))
}.assert(_ == "Person(name=John Smith,age=34)")
- //Show.generic[Tree[String]]
-
test("serialize a Branch") {
import magnolia.examples._
implicitly[Show[String, Branch[String]]].show(Branch(Leaf("LHS"), Leaf("RHS")))