aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJon Pretty <jon.pretty@propensive.com>2017-06-11 21:36:32 +0200
committerJon Pretty <jon.pretty@propensive.com>2017-06-11 21:36:32 +0200
commit1ea072197cf8a992b37d7efe0636358a236b9d6d (patch)
treefed6804a2e65f783881ce1029bae60303b175803
parent9ff2305bcdd742529bd184ba90ecdef32ca2fe4d (diff)
downloadmagnolia-1ea072197cf8a992b37d7efe0636358a236b9d6d.tar.gz
magnolia-1ea072197cf8a992b37d7efe0636358a236b9d6d.tar.bz2
magnolia-1ea072197cf8a992b37d7efe0636358a236b9d6d.zip
Appears to be working for both covariant and contravariant typeclasses
-rw-r--r--core/src/main/scala/magnolia.scala98
-rw-r--r--examples/src/main/scala/example.scala35
-rw-r--r--tests/shared/src/main/scala/magnolia/main.scala6
3 files changed, 109 insertions, 30 deletions
diff --git a/core/src/main/scala/magnolia.scala b/core/src/main/scala/magnolia.scala
index 9aafb44..f3b7691 100644
--- a/core/src/main/scala/magnolia.scala
+++ b/core/src/main/scala/magnolia.scala
@@ -42,7 +42,7 @@ class Macros(val c: whitebox.Context) {
genericType: c.universe.Type,
typeConstructor: c.universe.Type,
assignedName: c.TermName,
- dereferencerImplicit: c.Tree): c.Tree = {
+ derivationImplicit: Either[c.Tree, c.Tree]): c.Tree = {
findType(genericType).map { methodName =>
val methodAsString = methodName.encodedName.toString
@@ -51,7 +51,7 @@ class Macros(val c: whitebox.Context) {
}.orElse {
val searchType = appliedType(typeConstructor, genericType)
findType(genericType).map { _ =>
- directInferImplicit(genericType, typeConstructor, dereferencerImplicit)
+ directInferImplicit(genericType, typeConstructor, derivationImplicit)
}.getOrElse {
scala.util.Try {
val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase
@@ -63,7 +63,7 @@ class Macros(val c: whitebox.Context) {
$assignedName
}"""
}.get
- }.toOption.orElse(directInferImplicit(genericType, typeConstructor, dereferencerImplicit))
+ }.toOption.orElse(directInferImplicit(genericType, typeConstructor, derivationImplicit))
}
}.getOrElse {
val currentStack: Stack = recursionStack(c.enclosingPosition)
@@ -79,7 +79,7 @@ class Macros(val c: whitebox.Context) {
private def directInferImplicit(genericType: c.universe.Type,
typeConstructor: c.universe.Type,
- dereferencerImplicit: c.Tree): Option[c.Tree] = {
+ derivationImplicit: Either[c.Tree, c.Tree]): Option[c.Tree] = {
val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase
val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass"))
@@ -97,59 +97,95 @@ class Macros(val c: whitebox.Context) {
}.map { param =>
val paramName = param.name.encodedName.toString
val derivedImplicit = recurse(ProductType(paramName, genericType.toString), genericType, assignedName) {
- getImplicit(Some(paramName), param.returnType, typeConstructor, assignedName, dereferencerImplicit)
+ getImplicit(Some(paramName), param.returnType, typeConstructor, assignedName, derivationImplicit)
}.getOrElse {
c.abort(c.enclosingPosition, s"failed to get implicit for type $genericType")
}
- val dereferencedValue = q"$dereferencerImplicit.dereference(sourceParameter, ${param.name.toString})"
-
- q"$dereferencerImplicit.delegate($derivedImplicit, $dereferencedValue)"
+ derivationImplicit match {
+ case Left(impl) =>
+ val dereferencedValue = q"$impl.dereference(sourceParameter, ${param.name.toString})"
+ q"$impl.call($derivedImplicit, $dereferencedValue)"
+ case Right(impl) =>
+ val paramName = TermName(param.name.toString)
+ val dereferencedValue = q"sourceParameter.$paramName"
+ q"$impl.call($derivedImplicit, $dereferencedValue)"
+ }
}
- Some(q"new $genericType(..$implicits)")
+ derivationImplicit match {
+ case Left(_) =>
+ Some(q"new $genericType(..$implicits)")
+ case Right(impl) =>
+ Some(q"$impl.join(_root_.scala.List(..$implicits))")
+ }
} else if(isSealedTrait) {
val subtypes = classType.get.knownDirectSubclasses.to[List]
Some {
- val reduction = subtypes.map(_.asType.toType).map { searchType =>
+ val components = subtypes.map(_.asType.toType).map { searchType =>
recurse(CoproductType(genericType.toString), genericType, assignedName) {
- getImplicit(None, searchType, typeConstructor, assignedName, dereferencerImplicit)
+ getImplicit(None, searchType, typeConstructor, assignedName, derivationImplicit)
}.getOrElse {
c.abort(c.enclosingPosition, s"failed to get implicit for type $searchType")
}
- }.reduce { (left, right) => q"$dereferencerImplicit.combine($left, $right)" }
+ }
- q"$dereferencerImplicit.delegate($reduction, sourceParameter)"
+ derivationImplicit match {
+ case Left(impl) =>
+ val reduction = components.reduce { (left, right) => q"$impl.combine($left, $right)" }
+ q"$impl.call($reduction, sourceParameter)"
+ case Right(impl) =>
+ val parts = subtypes.zip(components)
+ parts.tail.foldLeft(q"$impl.call(${parts.head._2}, sourceParameter.asInstanceOf[${parts.head._1}])") { case (aggregated, (componentType, derivedImplicit)) =>
+ q"if(sourceParameter.isInstanceOf[$componentType]) $impl.call($derivedImplicit, sourceParameter.asInstanceOf[$componentType]) else $aggregated"
+ }
+ }
}
} else None
construct.map { const =>
- q"""{
- def $assignedName: $resultType = $dereferencerImplicit.construct { sourceParameter => $const }
+ val impl = derivationImplicit.merge
+ val res = q"""{
+ def $assignedName: $resultType = $impl.construct { sourceParameter => $const }
$assignedName
}"""
+
+ try c.typecheck(res) catch {
+ case e: Exception =>
+ e.printStackTrace()
+ }
+ res
}
}
- def magnolia[T: c.WeakTypeTag, Typeclass: c.WeakTypeTag]: c.Tree = {
+ def magnolia[T: c.WeakTypeTag, Typeclass: c.WeakTypeTag]: c.Tree = try {
import c.universe._
val genericType: Type = weakTypeOf[T]
val currentStack: List[Frame] = recursionStack.get(c.enclosingPosition).map(_.frames).getOrElse(List())
val directlyReentrant = Some(genericType) == currentStack.headOption.map(_.genericType)
val typeConstructor: Type = weakTypeOf[Typeclass].typeConstructor
- val dereferencerTypeclass = weakTypeOf[Dereferencer[_]].typeConstructor
- val dereferencerType = appliedType(dereferencerTypeclass, typeConstructor)
- val dereferencerImplicit = c.untypecheck(c.inferImplicitValue(dereferencerType, false, false))
+
+ val coDerivationTypeclass = weakTypeOf[CovariantDerivation[_]].typeConstructor
+ val contraDerivationTypeclass = weakTypeOf[ContravariantDerivation[_]].typeConstructor
+
+ val coDerivationType = appliedType(coDerivationTypeclass, List(typeConstructor))
+ val contraDerivationType = appliedType(contraDerivationTypeclass, List(typeConstructor))
+ val derivationImplicit = try {
+ Left(c.untypecheck(c.inferImplicitValue(coDerivationType, false, false)))
+ } catch {
+ case e: Exception =>
+ Right(c.untypecheck(c.inferImplicitValue(contraDerivationType)))
+ }
if(directlyReentrant) throw DirectlyReentrantException()
val result: Option[c.Tree] = if(!recursionStack.isEmpty) {
findType(genericType) match {
case None =>
- directInferImplicit(genericType, typeConstructor, dereferencerImplicit)
+ directInferImplicit(genericType, typeConstructor, derivationImplicit)
case Some(enclosingRef) =>
val methodAsString = enclosingRef.toString
val searchType = appliedType(typeConstructor, genericType)
@@ -157,16 +193,20 @@ class Macros(val c: whitebox.Context) {
}
} else {
val typeConstructor: Type = weakTypeOf[Typeclass].typeConstructor
- directInferImplicit(genericType, typeConstructor, dereferencerImplicit)
+ directInferImplicit(genericType, typeConstructor, derivationImplicit)
}
-
+
+ if(currentStack.isEmpty) recursionStack = Map()
+
result.map { tree =>
if(currentStack.isEmpty) c.untypecheck(removeLazy.transform(tree)) else tree
}.getOrElse {
- if(currentStack.isEmpty) println("Foo")
c.abort(c.enclosingPosition, "could not infer typeclass for type $genericType")
}
+ } catch {
+ case DirectlyReentrantException() => ???
+ case e: Exception => e.printStackTrace(); ???
}
}
@@ -210,10 +250,18 @@ private[magnolia] object CompileTimeState {
Map()
}
-trait Dereferencer[Typeclass[_]] {
+trait CovariantDerivation[Typeclass[_]] {
type Value
def dereference(value: Value, param: String): Value
- def delegate[T](typeclass: Typeclass[T], value: Value): T
+ def call[T](typeclass: Typeclass[T], value: Value): T
def combine[Supertype, Right <: Supertype](left: Typeclass[_ <: Supertype], right: Typeclass[Right]): Typeclass[Supertype]
def construct[T](body: Value => T): Typeclass[T]
}
+
+trait ContravariantDerivation[Typeclass[_]] {
+ type Return
+ def call[T](typeclass: Typeclass[T], value: T): Return
+ def construct[T](body: T => Return): Typeclass[T]
+ def join(elements: List[Return]): Return
+
+}
diff --git a/examples/src/main/scala/example.scala b/examples/src/main/scala/example.scala
index e61f8bd..c719ad8 100644
--- a/examples/src/main/scala/example.scala
+++ b/examples/src/main/scala/example.scala
@@ -30,10 +30,10 @@ object Extractor extends Extractor_1 {
implicit val stringExtractor: Extractor[String] = Extractor(_.str)
implicit val doubleExtractor: Extractor[Double] = Extractor(_.str.length.toDouble)
- implicit val dereferencer: Dereferencer[Extractor] { type Value = Thing } = new Dereferencer[Extractor] {
+ implicit val derivation: CovariantDerivation[Extractor] { type Value = Thing } = new CovariantDerivation[Extractor] {
type Value = Thing
def dereference(value: Thing, param: String): Thing = value.access(param)
- def delegate[T](extractor: Extractor[T], value: Thing): T = extractor.extract(value)
+ def call[T](extractor: Extractor[T], value: Thing): T = extractor.extract(value)
def combine[Supertype, Right <: Supertype](left: Extractor[_ <: Supertype],
right: Extractor[Right]): Extractor[Supertype] = left.orElse(right)
@@ -41,7 +41,6 @@ object Extractor extends Extractor_1 {
def extract(source: Thing): T = body(source)
}
}
-
}
trait Extractor_1 extends Extractor_2 {
@@ -49,6 +48,36 @@ trait Extractor_1 extends Extractor_2 {
def extract(source: Thing): List[T] = List(implicitly[Extractor[T]].extract(source))
}
}
+
trait Extractor_2 {
implicit def generic[T]: Extractor[T] = macro Macros.magnolia[T, Extractor[_]]
}
+
+trait Serializer[T] {
+ def serialize(src: T): String
+}
+
+object Serializer extends Serializer_1 {
+ implicit val deriv: ContravariantDerivation[Serializer] { type Return = String } = new ContravariantDerivation[Serializer] {
+ type Return = String
+ def call[T](typeclass: Serializer[T], value: T): String = typeclass.serialize(value)
+ def construct[T](body: T => String): Serializer[T] = new Serializer[T] {
+ def serialize(value: T): String = body(value)
+ }
+ def join(xs: List[String]): String = xs.mkString(", ")
+ }
+}
+
+trait Serializer_1 extends Serializer_2 {
+ implicit val intSerializer: Serializer[Int] = { t => "int" }
+ implicit val strSerializer: Serializer[String] = { t => "string" }
+ implicit val doubleSerializer: Serializer[Double] = { t => "double" }
+ implicit def listSerializer[T: Serializer]: Serializer[List[T]] = { ts =>
+ println(ts)
+ s"List[${ts.map { t => implicitly[Serializer[T]].serialize(t) }.mkString("-")}]"
+ }
+}
+
+trait Serializer_2 {
+ implicit def generic[T]: Serializer[T] = macro Macros.magnolia[T, Serializer[_]]
+}
diff --git a/tests/shared/src/main/scala/magnolia/main.scala b/tests/shared/src/main/scala/magnolia/main.scala
index 9c99d69..889757c 100644
--- a/tests/shared/src/main/scala/magnolia/main.scala
+++ b/tests/shared/src/main/scala/magnolia/main.scala
@@ -2,7 +2,7 @@ package magnolia
sealed trait Tree
-case class Branch(left: List[Leaf]) extends Tree
+case class Branch(left: List[Twig]) extends Tree
case class Leaf(node: List[String], right: List[Branch], left2: List[Branch], another: List[Leaf], broken: Double) extends Tree
case class Twig(alpha: List[Twig], beta: List[Leaf], gamma: Double, delta: List[Tree]) extends Tree
@@ -10,7 +10,9 @@ object Main {
def main(args: Array[String]): Unit = {
- println(implicitly[Extractor[List[Twig]]].extract(Thing("42")))
+
+ println(implicitly[Serializer[List[Tree]]].serialize(List(Branch(List(Twig(Nil, Nil, 43, Nil))))))
+ println(implicitly[Extractor[List[Tree]]].extract(Thing("42")))
}
}