aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJon Pretty <jon.pretty@propensive.com>2017-06-22 18:59:34 +0200
committerJon Pretty <jon.pretty@propensive.com>2017-06-22 18:59:34 +0200
commit70919a58e73e928d1e2c04a22a0636897e45816c (patch)
treef74b3eb2b541203692ac10d4537307484c1187ee
parent7c544b6255dbf4730e14ff132d2c44611ea176b6 (diff)
parent9d8199139ed99ee5131b3d498c628b0d4427605c (diff)
downloadmagnolia-70919a58e73e928d1e2c04a22a0636897e45816c.tar.gz
magnolia-70919a58e73e928d1e2c04a22a0636897e45816c.tar.bz2
magnolia-70919a58e73e928d1e2c04a22a0636897e45816c.zip
Merge branch 'krzemin-feature/support-2-arg-contravariant-derivations'
-rw-r--r--core/src/main/scala/magnolia.scala167
-rw-r--r--examples/src/main/scala/example.scala32
-rw-r--r--tests/shared/src/main/scala/magnolia/main.scala12
3 files changed, 145 insertions, 66 deletions
diff --git a/core/src/main/scala/magnolia.scala b/core/src/main/scala/magnolia.scala
index d2978b5..27172d3 100644
--- a/core/src/main/scala/magnolia.scala
+++ b/core/src/main/scala/magnolia.scala
@@ -2,7 +2,6 @@ package magnolia
import scala.reflect._, macros._
import macrocompat.bundle
-import scala.util.Try
import scala.collection.immutable.ListMap
import language.existentials
import language.higherKinds
@@ -12,6 +11,14 @@ class Macros(val c: whitebox.Context) {
import c.universe._
import CompileTimeState._
+
+ sealed trait DerivationImplicit { def tree: Tree }
+ case class CovariantDerivationImplicit(tree: Tree) extends DerivationImplicit
+ sealed trait ContravariantDerivationImplicit extends DerivationImplicit
+ case class ContravariantDerivation1Implicit(tree: Tree) extends ContravariantDerivationImplicit
+ case class ContravariantDerivation2Implicit(tree: Tree) extends ContravariantDerivationImplicit
+
+
private def findType(key: Type): Option[TermName] =
recursionStack(c.enclosingPosition).frames.find(_.genericType == key).map(_.termName(c))
@@ -22,14 +29,14 @@ class Macros(val c: whitebox.Context) {
recursionStack.get(c.enclosingPosition).map(_.push(path, key, value)).getOrElse(
Stack(List(Frame(path, key, value)), Nil))
)
-
+
try Some(fn) catch { case e: Exception => None } finally {
val currentStack = recursionStack(c.enclosingPosition)
recursionStack = recursionStack.updated(c.enclosingPosition,
currentStack.pop())
}
}
-
+
private val removeLazy: Transformer = new Transformer {
override def transform(tree: Tree): Tree = tree match {
case q"_root_.magnolia.Lazy.apply[$returnType](${Literal(Constant(method: String))})" =>
@@ -38,49 +45,44 @@ class Macros(val c: whitebox.Context) {
super.transform(tree)
}
}
-
+
private def getImplicit(paramName: Option[String],
genericType: Type,
typeConstructor: Type,
assignedName: TermName,
- derivationImplicit: Either[Tree, Tree]): Tree = {
-
+ derivationImplicit: DerivationImplicit): Tree = {
+
+ val searchType = appliedType(typeConstructor, genericType)
findType(genericType).map { methodName =>
val methodAsString = methodName.encodedName.toString
- val searchType = appliedType(typeConstructor, genericType)
q"_root_.magnolia.Lazy.apply[$searchType]($methodAsString)"
}.orElse {
- val searchType = appliedType(typeConstructor, genericType)
- findType(genericType).map { _ =>
- directInferImplicit(genericType, typeConstructor, derivationImplicit)
- }.getOrElse {
- scala.util.Try {
- val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase
- val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass"))
- recurse(ChainedImplicit(genericType.toString), genericType, assignedName) {
- val inferredImplicit = c.inferImplicitValue(searchType, false, false)
- q"""{
- def $assignedName: $searchType = $inferredImplicit
- $assignedName
- }"""
- }.get
- }.toOption.orElse(directInferImplicit(genericType, typeConstructor, derivationImplicit))
- }
+ scala.util.Try {
+ val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase
+ val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass"))
+ recurse(ChainedImplicit(genericType.toString), genericType, assignedName) {
+ val inferredImplicit = c.inferImplicitValue(searchType, false, false)
+ q"""{
+ def $assignedName: $searchType = $inferredImplicit
+ $assignedName
+ }"""
+ }.get
+ }.toOption.orElse(directInferImplicit(genericType, typeConstructor, derivationImplicit))
}.getOrElse {
val currentStack: Stack = recursionStack(c.enclosingPosition)
-
+
val error = ImplicitNotFound(genericType.toString,
recursionStack(c.enclosingPosition).frames.map(_.path))
-
- val updatedStack = currentStack.copy(errors = error :: currentStack.errors)
+
+ val updatedStack = currentStack.copy(errors = error :: currentStack.errors)
recursionStack = recursionStack.updated(c.enclosingPosition, updatedStack)
c.abort(c.enclosingPosition, s"Could not find type class for type $genericType")
}
}
-
+
private def directInferImplicit(genericType: Type,
typeConstructor: Type,
- derivationImplicit: Either[Tree, Tree]): Option[Tree] = {
+ derivationImplicit: DerivationImplicit): Option[Tree] = {
val genericTypeName: String = genericType.typeSymbol.name.encodedName.toString.toLowerCase
val assignedName: TermName = TermName(c.freshName(s"${genericTypeName}Typeclass"))
@@ -89,7 +91,7 @@ class Macros(val c: whitebox.Context) {
val isCaseClass = classType.map(_.isCaseClass).getOrElse(false)
val isSealedTrait = classType.map(_.isSealed).getOrElse(false)
val isValueClass = genericType <:< typeOf[AnyVal]
-
+
val resultType = appliedType(typeConstructor, genericType)
val construct = if(isCaseClass) {
@@ -99,37 +101,42 @@ class Macros(val c: whitebox.Context) {
val implicits = caseClassParameters.map { param =>
val paramName = param.name.encodedName.toString
-
+
val derivedImplicit = recurse(ProductType(paramName, genericType.toString), genericType,
assignedName) {
-
+
getImplicit(Some(paramName), param.returnType, typeConstructor, assignedName,
derivationImplicit)
-
+
}.getOrElse {
c.abort(c.enclosingPosition, s"failed to get implicit for type $genericType")
}
-
+
derivationImplicit match {
- case Left(impl) =>
+ case CovariantDerivationImplicit(impl) =>
val dereferencedValue = q"$impl.dereference(sourceParameter, ${param.name.toString})"
q"$impl.call($derivedImplicit, $dereferencedValue)"
- case Right(impl) =>
+ case ContravariantDerivation1Implicit(impl) =>
val paramName = TermName(param.name.toString)
val dereferencedValue = q"sourceParameter.$paramName"
q"$impl.call($derivedImplicit, $dereferencedValue)"
+ case ContravariantDerivation2Implicit(impl) =>
+ val paramName = TermName(param.name.toString)
+ val dereferencedValue1 = q"sourceParameter1.$paramName"
+ val dereferencedValue2 = q"sourceParameter2.$paramName"
+ q"$impl.call($derivedImplicit, $dereferencedValue1, $dereferencedValue2)"
}
}
derivationImplicit match {
- case Left(_) =>
+ case CovariantDerivationImplicit(_) =>
Some(q"new $genericType(..$implicits)")
- case Right(impl) =>
+ case contra: ContravariantDerivationImplicit =>
val namedImplicits = caseClassParameters.zip(implicits).map { case (param, tree) =>
q"(${param.name.encodedName.toString}, $tree)"
}
- Some(q"$impl.join(_root_.scala.collection.immutable.ListMap(..$namedImplicits))")
+ Some(q"${contra.tree}.join(_root_.scala.collection.immutable.ListMap(..$namedImplicits))")
}
} else if(isSealedTrait) {
@@ -150,36 +157,58 @@ class Macros(val c: whitebox.Context) {
}
derivationImplicit match {
- case Left(impl) =>
+ case CovariantDerivationImplicit(impl) =>
val reduction = components.reduce { (left, right) => q"$impl.combine($left, $right)" }
q"$impl.call($reduction, sourceParameter)"
- case Right(impl) =>
+
+ case ContravariantDerivation1Implicit(impl) =>
val parts = subtypes.tail.zip(components.tail)
-
val base = q"""
$impl.call(${components.head}, sourceParameter.asInstanceOf[${subtypes.head}])
"""
-
parts.foldLeft(base) { case (aggregated, (componentType, derivedImplicit)) =>
q"""
if(sourceParameter.isInstanceOf[$componentType])
$impl.call($derivedImplicit, sourceParameter.asInstanceOf[$componentType])
else $aggregated"""
}
+
+ case ContravariantDerivation2Implicit(impl) =>
+ val parts = subtypes.tail.zip(components.tail)
+ val base = q"""
+ $impl.call(${components.head}, sourceParameter1.asInstanceOf[${subtypes.head}], sourceParameter2.asInstanceOf[${subtypes.head}])
+ """
+ parts.foldLeft(base) { case (aggregated, (componentType, derivedImplicit)) =>
+ q"""
+ if(sourceParameter1.isInstanceOf[$componentType] && sourceParameter2.isInstanceOf[$componentType])
+ $impl.call($derivedImplicit, sourceParameter1.asInstanceOf[$componentType], sourceParameter2.asInstanceOf[$componentType])
+ else $aggregated"""
+ }
}
}
} else None
construct.map { const =>
- val impl = derivationImplicit.merge
- q"""{
- def $assignedName: $resultType = $impl.construct { sourceParameter => $const }
- $assignedName
- }"""
+
+ derivationImplicit match {
+ case CovariantDerivationImplicit(_) =>
+ ???
+ case ContravariantDerivation1Implicit(impl) =>
+ q"""{
+ def $assignedName: $resultType = $impl.construct { sourceParameter => $const }
+ $assignedName
+ }"""
+ case ContravariantDerivation2Implicit(impl) =>
+ q"""{
+ def $assignedName: $resultType = $impl.construct { case (sourceParameter1, sourceParameter2) => $const }
+ $assignedName
+ }"""
+ }
}
}
def magnolia[T: WeakTypeTag, Typeclass: WeakTypeTag]: Tree = {
+ import scala.util.{Try, Success, Failure}
val genericType: Type = weakTypeOf[T]
val currentStack: Stack = recursionStack.get(c.enclosingPosition).getOrElse(Stack(List(), List()))
@@ -188,23 +217,29 @@ class Macros(val c: whitebox.Context) {
val coDerivationTypeclass = weakTypeOf[CovariantDerivation[_]].typeConstructor
val contraDerivationTypeclass = weakTypeOf[ContravariantDerivation[_]].typeConstructor
-
+ val contraDerivation2Typeclass = weakTypeOf[ContravariantDerivation2[_]].typeConstructor
+
val coDerivationType = appliedType(coDerivationTypeclass, List(typeConstructor))
val contraDerivationType = appliedType(contraDerivationTypeclass, List(typeConstructor))
- val derivationImplicit = try {
- Left(c.untypecheck(c.inferImplicitValue(coDerivationType, false, false)))
- } catch {
- case e: Exception =>
- try Right(c.untypecheck(c.inferImplicitValue(contraDerivationType, false, false))) catch {
- case e: Exception =>
- c.info(c.enclosingPosition, s"could not find an implicit instance of "+
- s"CovariantDerivation[$typeConstructor] or "+
- s"ContravariantDerivation[$typeConstructor]", true)
-
- throw e
- }
- }
-
+ val contraDerivation2Type = appliedType(contraDerivation2Typeclass, List(typeConstructor))
+
+ def findDerivationImplicit[T <: DerivationImplicit](tpe: c.Type, cons: Tree => T): Try[DerivationImplicit] =
+ Try(cons(c.untypecheck(c.inferImplicitValue(tpe, false, false))))
+
+ val derivationImplicit =
+ findDerivationImplicit(coDerivationType, CovariantDerivationImplicit)
+ .orElse(findDerivationImplicit(contraDerivationType, ContravariantDerivation1Implicit))
+ .orElse(findDerivationImplicit(contraDerivation2Type, ContravariantDerivation2Implicit)) match {
+ case Failure(e) =>
+ c.info(c.enclosingPosition, s"could not find an implicit instance of "+
+ s"CovariantDerivation[$typeConstructor] or "+
+ s"ContravariantDerivation[$typeConstructor] or "+
+ s"ContravariantDerivation2[$typeConstructor]", true)
+ throw e
+ case Success(di) =>
+ di
+ }
+
if(directlyReentrant) throw DirectlyReentrantException()
currentStack.errors.foreach { error =>
@@ -226,7 +261,6 @@ class Macros(val c: whitebox.Context) {
Some(q"_root_.magnolia.Lazy[$searchType]($methodAsString)")
}
} else {
- val typeConstructor: Type = weakTypeOf[Typeclass].typeConstructor
directInferImplicit(genericType, typeConstructor, derivationImplicit)
}
@@ -301,5 +335,12 @@ trait ContravariantDerivation[Typeclass[_]] {
def call[T](typeclass: Typeclass[T], value: T): Return
def construct[T](body: T => Return): Typeclass[T]
def join(elements: ListMap[String, Return]): Return
+}
+trait ContravariantDerivation2[Typeclass[_]] {
+ type Return
+ def call[T](typeclass: Typeclass[T], value1: T, value2: T): Return
+ def construct[T](body: (T, T) => Return): Typeclass[T]
+ def join(elements: ListMap[String, Return]): Return
}
+
diff --git a/examples/src/main/scala/example.scala b/examples/src/main/scala/example.scala
index e649b88..c649fba 100644
--- a/examples/src/main/scala/example.scala
+++ b/examples/src/main/scala/example.scala
@@ -14,13 +14,23 @@ object `package` {
implicit val showBool: Show[Boolean] = _.toString
implicit def showList[T: Show]: Show[List[T]] = xs => xs.map { x => s"list:${implicitly[Show[T]].show(x)}" }.mkString(";")
implicit def showSet[T: Show]: Show[Set[T]] = s => "set"
+
+ implicit class Equable[T: Eq](t: T) {
+ def isEqualTo(other: T): Boolean = implicitly[Eq[T]].isEqual(t, other)
+ }
+ implicit val eqString: Eq[String] = _ == _
+ implicit val eqBool: Eq[Boolean] = _ == _
+ implicit def eqList[T: Eq]: Eq[List[T]] =
+ (l1, l2) => l1.size == l2.size && (l1 zip l2).forall { case (e1, e2) => e1 isEqualTo e2 }
+ implicit def eqSet[T: Eq]: Eq[Set[T]] =
+ (s1, s2) => s1.size == s2.size && (s1 zip s2).forall { case (e1, e2) => e1 isEqualTo e2 }
}
sealed trait EmptyType
sealed trait Tree
case class Branch(left: Tree, right: Tree) extends Tree
-case class Leaf(value: Int, no: EmptyType) extends Tree
+case class Leaf(value: Int, no: String) extends Tree
sealed trait Entity
case class Person(name: String, address: Address) extends Entity
@@ -42,3 +52,23 @@ object Show extends Show_1 {
trait Show_1 {
implicit def generic[T]: Show[T] = macro Macros.magnolia[T, Show[_]]
}
+
+trait Eq[T] { def isEqual(a: T, b: T): Boolean }
+
+object Eq extends Eq_1 {
+
+ implicit val eqInt: Eq[Int] = _ == _
+
+ implicit val derivation = new ContravariantDerivation2[Eq] {
+ type Return = Boolean
+ def call[T](eq: Eq[T], value1: T, value2: T): Boolean =
+ if(value1.getClass == value2.getClass) eq.isEqual(value1, value2) else false
+ def construct[T](body: (T, T) => Boolean): Eq[T] = body(_, _)
+ def join(elements: ListMap[String, Boolean]): Boolean = elements.forall(_._2)
+ }
+}
+
+trait Eq_1 {
+ implicit def generic[T]: Eq[T] = macro Macros.magnolia[T, Eq[_]]
+}
+
diff --git a/tests/shared/src/main/scala/magnolia/main.scala b/tests/shared/src/main/scala/magnolia/main.scala
index 014ec33..9449f8d 100644
--- a/tests/shared/src/main/scala/magnolia/main.scala
+++ b/tests/shared/src/main/scala/magnolia/main.scala
@@ -4,10 +4,18 @@ import examples._
object Main {
def main(args: Array[String]): Unit = {
- println(Branch(Branch(Leaf(1, null), Leaf(2, null)), Leaf(3, null)).show)
+
+ val tree1: Tree = Branch(Branch(Leaf(1, "abc"), Leaf(2, "def")), Leaf(3, "ghi"))
+ val tree2: Tree = Branch(Leaf(1, "abc"), Leaf(2, "def"))
+
+ println(tree1.show)
+ println(tree1 isEqualTo tree1)
+ println(tree1 isEqualTo tree2)
+
println(List[Entity](Person("John Smith",
- Address(List("1 High Street", "London", "SW1A 1AA"),
+ Address(List("1 High Street", "London", "SW1A 1AA"),
Country("UK", "GBR", false)))).show)
+
}
}