From 1ea072197cf8a992b37d7efe0636358a236b9d6d Mon Sep 17 00:00:00 2001 From: Jon Pretty Date: Sun, 11 Jun 2017 21:36:32 +0200 Subject: Appears to be working for both covariant and contravariant typeclasses --- core/src/main/scala/magnolia.scala | 98 ++++++++++++++++++------- examples/src/main/scala/example.scala | 35 ++++++++- tests/shared/src/main/scala/magnolia/main.scala | 6 +- 3 files changed, 109 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/magnolia.scala b/core/src/main/scala/magnolia.scala index 9aafb44..f3b7691 100644 --- a/core/src/main/scala/magnolia.scala +++ b/core/src/main/scala/magnolia.scala @@ -42,7 +42,7 @@ class Macros(val c: whitebox.Context) { genericType: c.universe.Type, typeConstructor: c.universe.Type, assignedName: c.TermName, - dereferencerImplicit: c.Tree): c.Tree = { + derivationImplicit: Either[c.Tree, c.Tree]): c.Tree = { findType(genericType).map { methodName => val methodAsString = methodName.encodedName.toString @@ -51,7 +51,7 @@ class Macros(val c: whitebox.Context) { }.orElse { val searchType = appliedType(typeConstructor, genericType) findType(genericType).map { _ => - directInferImplicit(genericType, typeConstructor, dereferencerImplicit) + directInferImplicit(genericType, typeConstructor, derivationImplicit) }.getOrElse { scala.util.Try { val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase @@ -63,7 +63,7 @@ class Macros(val c: whitebox.Context) { $assignedName }""" }.get - }.toOption.orElse(directInferImplicit(genericType, typeConstructor, dereferencerImplicit)) + }.toOption.orElse(directInferImplicit(genericType, typeConstructor, derivationImplicit)) } }.getOrElse { val currentStack: Stack = recursionStack(c.enclosingPosition) @@ -79,7 +79,7 @@ class Macros(val c: whitebox.Context) { private def directInferImplicit(genericType: c.universe.Type, typeConstructor: c.universe.Type, - dereferencerImplicit: c.Tree): Option[c.Tree] = { + derivationImplicit: Either[c.Tree, c.Tree]): Option[c.Tree] = { val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass")) @@ -97,59 +97,95 @@ class Macros(val c: whitebox.Context) { }.map { param => val paramName = param.name.encodedName.toString val derivedImplicit = recurse(ProductType(paramName, genericType.toString), genericType, assignedName) { - getImplicit(Some(paramName), param.returnType, typeConstructor, assignedName, dereferencerImplicit) + getImplicit(Some(paramName), param.returnType, typeConstructor, assignedName, derivationImplicit) }.getOrElse { c.abort(c.enclosingPosition, s"failed to get implicit for type $genericType") } - val dereferencedValue = q"$dereferencerImplicit.dereference(sourceParameter, ${param.name.toString})" - - q"$dereferencerImplicit.delegate($derivedImplicit, $dereferencedValue)" + derivationImplicit match { + case Left(impl) => + val dereferencedValue = q"$impl.dereference(sourceParameter, ${param.name.toString})" + q"$impl.call($derivedImplicit, $dereferencedValue)" + case Right(impl) => + val paramName = TermName(param.name.toString) + val dereferencedValue = q"sourceParameter.$paramName" + q"$impl.call($derivedImplicit, $dereferencedValue)" + } } - Some(q"new $genericType(..$implicits)") + derivationImplicit match { + case Left(_) => + Some(q"new $genericType(..$implicits)") + case Right(impl) => + Some(q"$impl.join(_root_.scala.List(..$implicits))") + } } else if(isSealedTrait) { val subtypes = classType.get.knownDirectSubclasses.to[List] Some { - val reduction = subtypes.map(_.asType.toType).map { searchType => + val components = subtypes.map(_.asType.toType).map { searchType => recurse(CoproductType(genericType.toString), genericType, assignedName) { - getImplicit(None, searchType, typeConstructor, assignedName, dereferencerImplicit) + getImplicit(None, searchType, typeConstructor, assignedName, derivationImplicit) }.getOrElse { c.abort(c.enclosingPosition, s"failed to get implicit for type $searchType") } - }.reduce { (left, right) => q"$dereferencerImplicit.combine($left, $right)" } + } - q"$dereferencerImplicit.delegate($reduction, sourceParameter)" + derivationImplicit match { + case Left(impl) => + val reduction = components.reduce { (left, right) => q"$impl.combine($left, $right)" } + q"$impl.call($reduction, sourceParameter)" + case Right(impl) => + val parts = subtypes.zip(components) + parts.tail.foldLeft(q"$impl.call(${parts.head._2}, sourceParameter.asInstanceOf[${parts.head._1}])") { case (aggregated, (componentType, derivedImplicit)) => + q"if(sourceParameter.isInstanceOf[$componentType]) $impl.call($derivedImplicit, sourceParameter.asInstanceOf[$componentType]) else $aggregated" + } + } } } else None construct.map { const => - q"""{ - def $assignedName: $resultType = $dereferencerImplicit.construct { sourceParameter => $const } + val impl = derivationImplicit.merge + val res = q"""{ + def $assignedName: $resultType = $impl.construct { sourceParameter => $const } $assignedName }""" + + try c.typecheck(res) catch { + case e: Exception => + e.printStackTrace() + } + res } } - def magnolia[T: c.WeakTypeTag, Typeclass: c.WeakTypeTag]: c.Tree = { + def magnolia[T: c.WeakTypeTag, Typeclass: c.WeakTypeTag]: c.Tree = try { import c.universe._ val genericType: Type = weakTypeOf[T] val currentStack: List[Frame] = recursionStack.get(c.enclosingPosition).map(_.frames).getOrElse(List()) val directlyReentrant = Some(genericType) == currentStack.headOption.map(_.genericType) val typeConstructor: Type = weakTypeOf[Typeclass].typeConstructor - val dereferencerTypeclass = weakTypeOf[Dereferencer[_]].typeConstructor - val dereferencerType = appliedType(dereferencerTypeclass, typeConstructor) - val dereferencerImplicit = c.untypecheck(c.inferImplicitValue(dereferencerType, false, false)) + + val coDerivationTypeclass = weakTypeOf[CovariantDerivation[_]].typeConstructor + val contraDerivationTypeclass = weakTypeOf[ContravariantDerivation[_]].typeConstructor + + val coDerivationType = appliedType(coDerivationTypeclass, List(typeConstructor)) + val contraDerivationType = appliedType(contraDerivationTypeclass, List(typeConstructor)) + val derivationImplicit = try { + Left(c.untypecheck(c.inferImplicitValue(coDerivationType, false, false))) + } catch { + case e: Exception => + Right(c.untypecheck(c.inferImplicitValue(contraDerivationType))) + } if(directlyReentrant) throw DirectlyReentrantException() val result: Option[c.Tree] = if(!recursionStack.isEmpty) { findType(genericType) match { case None => - directInferImplicit(genericType, typeConstructor, dereferencerImplicit) + directInferImplicit(genericType, typeConstructor, derivationImplicit) case Some(enclosingRef) => val methodAsString = enclosingRef.toString val searchType = appliedType(typeConstructor, genericType) @@ -157,16 +193,20 @@ class Macros(val c: whitebox.Context) { } } else { val typeConstructor: Type = weakTypeOf[Typeclass].typeConstructor - directInferImplicit(genericType, typeConstructor, dereferencerImplicit) + directInferImplicit(genericType, typeConstructor, derivationImplicit) } - + + if(currentStack.isEmpty) recursionStack = Map() + result.map { tree => if(currentStack.isEmpty) c.untypecheck(removeLazy.transform(tree)) else tree }.getOrElse { - if(currentStack.isEmpty) println("Foo") c.abort(c.enclosingPosition, "could not infer typeclass for type $genericType") } + } catch { + case DirectlyReentrantException() => ??? + case e: Exception => e.printStackTrace(); ??? } } @@ -210,10 +250,18 @@ private[magnolia] object CompileTimeState { Map() } -trait Dereferencer[Typeclass[_]] { +trait CovariantDerivation[Typeclass[_]] { type Value def dereference(value: Value, param: String): Value - def delegate[T](typeclass: Typeclass[T], value: Value): T + def call[T](typeclass: Typeclass[T], value: Value): T def combine[Supertype, Right <: Supertype](left: Typeclass[_ <: Supertype], right: Typeclass[Right]): Typeclass[Supertype] def construct[T](body: Value => T): Typeclass[T] } + +trait ContravariantDerivation[Typeclass[_]] { + type Return + def call[T](typeclass: Typeclass[T], value: T): Return + def construct[T](body: T => Return): Typeclass[T] + def join(elements: List[Return]): Return + +} diff --git a/examples/src/main/scala/example.scala b/examples/src/main/scala/example.scala index e61f8bd..c719ad8 100644 --- a/examples/src/main/scala/example.scala +++ b/examples/src/main/scala/example.scala @@ -30,10 +30,10 @@ object Extractor extends Extractor_1 { implicit val stringExtractor: Extractor[String] = Extractor(_.str) implicit val doubleExtractor: Extractor[Double] = Extractor(_.str.length.toDouble) - implicit val dereferencer: Dereferencer[Extractor] { type Value = Thing } = new Dereferencer[Extractor] { + implicit val derivation: CovariantDerivation[Extractor] { type Value = Thing } = new CovariantDerivation[Extractor] { type Value = Thing def dereference(value: Thing, param: String): Thing = value.access(param) - def delegate[T](extractor: Extractor[T], value: Thing): T = extractor.extract(value) + def call[T](extractor: Extractor[T], value: Thing): T = extractor.extract(value) def combine[Supertype, Right <: Supertype](left: Extractor[_ <: Supertype], right: Extractor[Right]): Extractor[Supertype] = left.orElse(right) @@ -41,7 +41,6 @@ object Extractor extends Extractor_1 { def extract(source: Thing): T = body(source) } } - } trait Extractor_1 extends Extractor_2 { @@ -49,6 +48,36 @@ trait Extractor_1 extends Extractor_2 { def extract(source: Thing): List[T] = List(implicitly[Extractor[T]].extract(source)) } } + trait Extractor_2 { implicit def generic[T]: Extractor[T] = macro Macros.magnolia[T, Extractor[_]] } + +trait Serializer[T] { + def serialize(src: T): String +} + +object Serializer extends Serializer_1 { + implicit val deriv: ContravariantDerivation[Serializer] { type Return = String } = new ContravariantDerivation[Serializer] { + type Return = String + def call[T](typeclass: Serializer[T], value: T): String = typeclass.serialize(value) + def construct[T](body: T => String): Serializer[T] = new Serializer[T] { + def serialize(value: T): String = body(value) + } + def join(xs: List[String]): String = xs.mkString(", ") + } +} + +trait Serializer_1 extends Serializer_2 { + implicit val intSerializer: Serializer[Int] = { t => "int" } + implicit val strSerializer: Serializer[String] = { t => "string" } + implicit val doubleSerializer: Serializer[Double] = { t => "double" } + implicit def listSerializer[T: Serializer]: Serializer[List[T]] = { ts => + println(ts) + s"List[${ts.map { t => implicitly[Serializer[T]].serialize(t) }.mkString("-")}]" + } +} + +trait Serializer_2 { + implicit def generic[T]: Serializer[T] = macro Macros.magnolia[T, Serializer[_]] +} diff --git a/tests/shared/src/main/scala/magnolia/main.scala b/tests/shared/src/main/scala/magnolia/main.scala index 9c99d69..889757c 100644 --- a/tests/shared/src/main/scala/magnolia/main.scala +++ b/tests/shared/src/main/scala/magnolia/main.scala @@ -2,7 +2,7 @@ package magnolia sealed trait Tree -case class Branch(left: List[Leaf]) extends Tree +case class Branch(left: List[Twig]) extends Tree case class Leaf(node: List[String], right: List[Branch], left2: List[Branch], another: List[Leaf], broken: Double) extends Tree case class Twig(alpha: List[Twig], beta: List[Leaf], gamma: Double, delta: List[Tree]) extends Tree @@ -10,7 +10,9 @@ object Main { def main(args: Array[String]): Unit = { - println(implicitly[Extractor[List[Twig]]].extract(Thing("42"))) + + println(implicitly[Serializer[List[Tree]]].serialize(List(Branch(List(Twig(Nil, Nil, 43, Nil)))))) + println(implicitly[Extractor[List[Tree]]].extract(Thing("42"))) } } -- cgit v1.2.3