aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorJon Pretty <jon.pretty@propensive.com>2017-06-11 21:36:32 +0200
committerJon Pretty <jon.pretty@propensive.com>2017-06-11 21:36:32 +0200
commit1ea072197cf8a992b37d7efe0636358a236b9d6d (patch)
treefed6804a2e65f783881ce1029bae60303b175803 /core
parent9ff2305bcdd742529bd184ba90ecdef32ca2fe4d (diff)
downloadmagnolia-1ea072197cf8a992b37d7efe0636358a236b9d6d.tar.gz
magnolia-1ea072197cf8a992b37d7efe0636358a236b9d6d.tar.bz2
magnolia-1ea072197cf8a992b37d7efe0636358a236b9d6d.zip
Appears to be working for both covariant and contravariant typeclasses
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/magnolia.scala98
1 files changed, 73 insertions, 25 deletions
diff --git a/core/src/main/scala/magnolia.scala b/core/src/main/scala/magnolia.scala
index 9aafb44..f3b7691 100644
--- a/core/src/main/scala/magnolia.scala
+++ b/core/src/main/scala/magnolia.scala
@@ -42,7 +42,7 @@ class Macros(val c: whitebox.Context) {
genericType: c.universe.Type,
typeConstructor: c.universe.Type,
assignedName: c.TermName,
- dereferencerImplicit: c.Tree): c.Tree = {
+ derivationImplicit: Either[c.Tree, c.Tree]): c.Tree = {
findType(genericType).map { methodName =>
val methodAsString = methodName.encodedName.toString
@@ -51,7 +51,7 @@ class Macros(val c: whitebox.Context) {
}.orElse {
val searchType = appliedType(typeConstructor, genericType)
findType(genericType).map { _ =>
- directInferImplicit(genericType, typeConstructor, dereferencerImplicit)
+ directInferImplicit(genericType, typeConstructor, derivationImplicit)
}.getOrElse {
scala.util.Try {
val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase
@@ -63,7 +63,7 @@ class Macros(val c: whitebox.Context) {
$assignedName
}"""
}.get
- }.toOption.orElse(directInferImplicit(genericType, typeConstructor, dereferencerImplicit))
+ }.toOption.orElse(directInferImplicit(genericType, typeConstructor, derivationImplicit))
}
}.getOrElse {
val currentStack: Stack = recursionStack(c.enclosingPosition)
@@ -79,7 +79,7 @@ class Macros(val c: whitebox.Context) {
private def directInferImplicit(genericType: c.universe.Type,
typeConstructor: c.universe.Type,
- dereferencerImplicit: c.Tree): Option[c.Tree] = {
+ derivationImplicit: Either[c.Tree, c.Tree]): Option[c.Tree] = {
val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase
val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass"))
@@ -97,59 +97,95 @@ class Macros(val c: whitebox.Context) {
}.map { param =>
val paramName = param.name.encodedName.toString
val derivedImplicit = recurse(ProductType(paramName, genericType.toString), genericType, assignedName) {
- getImplicit(Some(paramName), param.returnType, typeConstructor, assignedName, dereferencerImplicit)
+ getImplicit(Some(paramName), param.returnType, typeConstructor, assignedName, derivationImplicit)
}.getOrElse {
c.abort(c.enclosingPosition, s"failed to get implicit for type $genericType")
}
- val dereferencedValue = q"$dereferencerImplicit.dereference(sourceParameter, ${param.name.toString})"
-
- q"$dereferencerImplicit.delegate($derivedImplicit, $dereferencedValue)"
+ derivationImplicit match {
+ case Left(impl) =>
+ val dereferencedValue = q"$impl.dereference(sourceParameter, ${param.name.toString})"
+ q"$impl.call($derivedImplicit, $dereferencedValue)"
+ case Right(impl) =>
+ val paramName = TermName(param.name.toString)
+ val dereferencedValue = q"sourceParameter.$paramName"
+ q"$impl.call($derivedImplicit, $dereferencedValue)"
+ }
}
- Some(q"new $genericType(..$implicits)")
+ derivationImplicit match {
+ case Left(_) =>
+ Some(q"new $genericType(..$implicits)")
+ case Right(impl) =>
+ Some(q"$impl.join(_root_.scala.List(..$implicits))")
+ }
} else if(isSealedTrait) {
val subtypes = classType.get.knownDirectSubclasses.to[List]
Some {
- val reduction = subtypes.map(_.asType.toType).map { searchType =>
+ val components = subtypes.map(_.asType.toType).map { searchType =>
recurse(CoproductType(genericType.toString), genericType, assignedName) {
- getImplicit(None, searchType, typeConstructor, assignedName, dereferencerImplicit)
+ getImplicit(None, searchType, typeConstructor, assignedName, derivationImplicit)
}.getOrElse {
c.abort(c.enclosingPosition, s"failed to get implicit for type $searchType")
}
- }.reduce { (left, right) => q"$dereferencerImplicit.combine($left, $right)" }
+ }
- q"$dereferencerImplicit.delegate($reduction, sourceParameter)"
+ derivationImplicit match {
+ case Left(impl) =>
+ val reduction = components.reduce { (left, right) => q"$impl.combine($left, $right)" }
+ q"$impl.call($reduction, sourceParameter)"
+ case Right(impl) =>
+ val parts = subtypes.zip(components)
+ parts.tail.foldLeft(q"$impl.call(${parts.head._2}, sourceParameter.asInstanceOf[${parts.head._1}])") { case (aggregated, (componentType, derivedImplicit)) =>
+ q"if(sourceParameter.isInstanceOf[$componentType]) $impl.call($derivedImplicit, sourceParameter.asInstanceOf[$componentType]) else $aggregated"
+ }
+ }
}
} else None
construct.map { const =>
- q"""{
- def $assignedName: $resultType = $dereferencerImplicit.construct { sourceParameter => $const }
+ val impl = derivationImplicit.merge
+ val res = q"""{
+ def $assignedName: $resultType = $impl.construct { sourceParameter => $const }
$assignedName
}"""
+
+ try c.typecheck(res) catch {
+ case e: Exception =>
+ e.printStackTrace()
+ }
+ res
}
}
- def magnolia[T: c.WeakTypeTag, Typeclass: c.WeakTypeTag]: c.Tree = {
+ def magnolia[T: c.WeakTypeTag, Typeclass: c.WeakTypeTag]: c.Tree = try {
import c.universe._
val genericType: Type = weakTypeOf[T]
val currentStack: List[Frame] = recursionStack.get(c.enclosingPosition).map(_.frames).getOrElse(List())
val directlyReentrant = Some(genericType) == currentStack.headOption.map(_.genericType)
val typeConstructor: Type = weakTypeOf[Typeclass].typeConstructor
- val dereferencerTypeclass = weakTypeOf[Dereferencer[_]].typeConstructor
- val dereferencerType = appliedType(dereferencerTypeclass, typeConstructor)
- val dereferencerImplicit = c.untypecheck(c.inferImplicitValue(dereferencerType, false, false))
+
+ val coDerivationTypeclass = weakTypeOf[CovariantDerivation[_]].typeConstructor
+ val contraDerivationTypeclass = weakTypeOf[ContravariantDerivation[_]].typeConstructor
+
+ val coDerivationType = appliedType(coDerivationTypeclass, List(typeConstructor))
+ val contraDerivationType = appliedType(contraDerivationTypeclass, List(typeConstructor))
+ val derivationImplicit = try {
+ Left(c.untypecheck(c.inferImplicitValue(coDerivationType, false, false)))
+ } catch {
+ case e: Exception =>
+ Right(c.untypecheck(c.inferImplicitValue(contraDerivationType)))
+ }
if(directlyReentrant) throw DirectlyReentrantException()
val result: Option[c.Tree] = if(!recursionStack.isEmpty) {
findType(genericType) match {
case None =>
- directInferImplicit(genericType, typeConstructor, dereferencerImplicit)
+ directInferImplicit(genericType, typeConstructor, derivationImplicit)
case Some(enclosingRef) =>
val methodAsString = enclosingRef.toString
val searchType = appliedType(typeConstructor, genericType)
@@ -157,16 +193,20 @@ class Macros(val c: whitebox.Context) {
}
} else {
val typeConstructor: Type = weakTypeOf[Typeclass].typeConstructor
- directInferImplicit(genericType, typeConstructor, dereferencerImplicit)
+ directInferImplicit(genericType, typeConstructor, derivationImplicit)
}
-
+
+ if(currentStack.isEmpty) recursionStack = Map()
+
result.map { tree =>
if(currentStack.isEmpty) c.untypecheck(removeLazy.transform(tree)) else tree
}.getOrElse {
- if(currentStack.isEmpty) println("Foo")
c.abort(c.enclosingPosition, "could not infer typeclass for type $genericType")
}
+ } catch {
+ case DirectlyReentrantException() => ???
+ case e: Exception => e.printStackTrace(); ???
}
}
@@ -210,10 +250,18 @@ private[magnolia] object CompileTimeState {
Map()
}
-trait Dereferencer[Typeclass[_]] {
+trait CovariantDerivation[Typeclass[_]] {
type Value
def dereference(value: Value, param: String): Value
- def delegate[T](typeclass: Typeclass[T], value: Value): T
+ def call[T](typeclass: Typeclass[T], value: Value): T
def combine[Supertype, Right <: Supertype](left: Typeclass[_ <: Supertype], right: Typeclass[Right]): Typeclass[Supertype]
def construct[T](body: Value => T): Typeclass[T]
}
+
+trait ContravariantDerivation[Typeclass[_]] {
+ type Return
+ def call[T](typeclass: Typeclass[T], value: T): Return
+ def construct[T](body: T => Return): Typeclass[T]
+ def join(elements: List[Return]): Return
+
+}