diff options
Diffstat (limited to 'src/core/magnolia.scala')
-rw-r--r-- | src/core/magnolia.scala | 509 |
1 files changed, 509 insertions, 0 deletions
diff --git a/src/core/magnolia.scala b/src/core/magnolia.scala new file mode 100644 index 0000000..5318aa2 --- /dev/null +++ b/src/core/magnolia.scala @@ -0,0 +1,509 @@ +package magnolia + +import scala.collection.mutable +import scala.language.existentials +import scala.language.higherKinds +import scala.reflect.macros._ + +/** the object which defines the Magnolia macro */ +object Magnolia { + import CompileTimeState._ + + /** derives a generic typeclass instance for the type `T` + * + * This is a macro definition method which should be bound to a method defined inside a Magnolia + * generic derivation object, that is, one which defines the methods `combine`, `dispatch` and + * the type constructor, `Typeclass[_]`. This will typically look like, + * <pre> + * object Derivation { + * // other definitions + * implicit def gen[T]: Typeclass[T] = Magnolia.gen[T] + * } + * </pre> + * which would support automatic derivation of typeclass instances by calling + * `Derivation.gen[T]` or with `implicitly[Typeclass[T]]`, if the implicit method is imported + * into the current scope. + * + * The definition expects a type constructor called `Typeclass`, taking one *-kinded type + * parameter to be defined on the same object as a means of determining how the typeclass should + * be genericized. While this may be obvious for typeclasses like `Show[T]` which take only a + * single type parameter, Magnolia can also derive typeclass instances for types such as + * `Decoder[Format, Type]` which would typically fix the `Format` parameter while varying the + * `Type` parameter. + * + * While there is no "interface" for a derivation, in the object-oriented sense, the Magnolia + * macro expects to be able to call certain methods on the object within which it is bound to a + * method. + * + * Specifically, for deriving case classes (product types), the macro will attempt to call the + * `combine` method with an instance of [[CaseClass]], like so, + * <pre> + * <derivation>.combine(<caseClass>): Typeclass[T] + * </pre> + * That is to say, the macro expects there to exist a method called `combine` on the derivation + * object, which may be called with the code above, and for it to return a type which conforms + * to the type `Typeclass[T]`. The implementation of `combine` will therefore typically look + * like this, + * <pre> + * def combine[T](caseClass: CaseClass[Typeclass, T]): Typeclass[T] = ... + * </pre> + * however, there is the flexibility to provide additional type parameters or additional + * implicit parameters to the definition, provided these do not affect its ability to be invoked + * as described above. + * + * Likewise, for deriving sealed traits (coproduct or sum types), the macro will attempt to call + * the `dispatch` method with an instance of [[SealedTrait]], like so, + * <pre> + * <derivation>.dispatch(<sealedTrait>): Typeclass[T] + * </pre> + * so a definition such as, + * <pre> + * def dispatch[T](sealedTrait: SealedTrait[Typeclass, T]): Typeclass[T] = ... + * </pre> + * will suffice, however the qualifications regarding additional type parameters and implicit + * parameters apply equally to `dispatch` as to `combine`. + * */ + def gen[T: c.WeakTypeTag](c: whitebox.Context): c.Tree = Stack.withContext(c) { stack => + import c.universe._ + import internal._ + + val debug = c.macroApplication.symbol.annotations + .find(_.tree.tpe <:< typeOf[debug]) + .flatMap(_.tree.children.tail.collectFirst { case Literal(Constant(s: String)) => s }) + + val magnoliaPkg = c.mirror.staticPackage("magnolia") + val scalaPkg = c.mirror.staticPackage("scala") + + val repeatedParamClass = definitions.RepeatedParamClass + val scalaSeqType = typeOf[Seq[_]].typeConstructor + + val prefixType = c.prefix.tree.tpe + val prefixObject = prefixType.typeSymbol + val prefixName = prefixObject.name.decodedName + + val typeDefs = prefixType.baseClasses.flatMap { cls => + cls.asType.toType.decls.filter(_.isType).find(_.name.toString == "Typeclass").map { tpe => + tpe.asType.toType.asSeenFrom(prefixType, cls) + } + } + + val typeConstructor = typeDefs.headOption.fold { + c.abort( + c.enclosingPosition, + s"magnolia: the derivation $prefixObject does not define the Typeclass type constructor" + ) + }(_.typeConstructor) + + def checkMethod(termName: String, category: String, expected: String): Unit = { + val term = TermName(termName) + val combineClass = c.prefix.tree.tpe.baseClasses + .find { cls => + cls.asType.toType.decl(term) != NoSymbol + } + .getOrElse { + c.abort( + c.enclosingPosition, + s"magnolia: the method `$termName` must be defined on the derivation $prefixObject to derive typeclasses for $category" + ) + } + val firstParamBlock = combineClass.asType.toType.decl(term).asTerm.asMethod.paramLists.head + if (firstParamBlock.lengthCompare(1) != 0) + c.abort(c.enclosingPosition, + s"magnolia: the method `combine` should take a single parameter of type $expected") + } + + // FIXME: Only run these methods if they're used, particularly `dispatch` + checkMethod("combine", "case classes", "CaseClass[Typeclass, _]") + checkMethod("dispatch", "sealed traits", "SealedTrait[Typeclass, _]") + + val removeDeferred = new Transformer { + override def transform(tree: Tree) = tree match { + case q"$magnolia.Deferred.apply[$_](${Literal(Constant(method: String))})" + if magnolia.symbol == magnoliaPkg => + q"${TermName(method)}" + case _ => + super.transform(tree) + } + } + + def typeclassTree(genericType: Type, typeConstructor: Type): Tree = { + val searchType = appliedType(typeConstructor, genericType) + val deferredRef = for (methodName <- stack find searchType) yield { + val methodAsString = methodName.decodedName.toString + q"$magnoliaPkg.Deferred.apply[$searchType]($methodAsString)" + } + + deferredRef.getOrElse { + val path = ChainedImplicit(s"$prefixName.Typeclass", genericType.toString) + val frame = stack.Frame(path, searchType, termNames.EMPTY) + stack.recurse(frame, searchType) { + Option(c.inferImplicitValue(searchType)) + .filterNot(_.isEmpty) + .orElse(directInferImplicit(genericType, typeConstructor)) + .getOrElse { + val missingType = stack.top.fold(searchType)(_.searchType.asInstanceOf[Type]) + val typeClassName = s"${missingType.typeSymbol.name.decodedName}.Typeclass" + val genericType = missingType.typeArgs.head + val trace = stack.trace.mkString(" in ", "\n in ", "\n") + c.abort(c.enclosingPosition, + s"magnolia: could not find $typeClassName for type $genericType\n$trace") + } + } + } + } + + def directInferImplicit(genericType: Type, typeConstructor: Type): Option[Tree] = { + val genericTypeName = genericType.typeSymbol.name.decodedName.toString.toLowerCase + val assignedName = TermName(c.freshName(s"${genericTypeName}Typeclass")) + val typeSymbol = genericType.typeSymbol + val classType = if (typeSymbol.isClass) Some(typeSymbol.asClass) else None + val isCaseClass = classType.exists(_.isCaseClass) + val isCaseObject = classType.exists(_.isModuleClass) + val isSealedTrait = classType.exists(_.isSealed) + + val primitives = Set(typeOf[Double], + typeOf[Float], + typeOf[Short], + typeOf[Byte], + typeOf[Int], + typeOf[Long], + typeOf[Char], + typeOf[Boolean], + typeOf[Unit]) + + val isValueClass = genericType <:< typeOf[AnyVal] && !primitives.exists(_ =:= genericType) + + val resultType = appliedType(typeConstructor, genericType) + + val typeName = TermName(c.freshName("typeName")) + val typeNameDef = { + val ts = genericType.typeSymbol + q"val $typeName = $magnoliaPkg.TypeName(${ts.owner.fullName}, ${ts.name.decodedName.toString})" + } + + val result = if (isCaseObject) { + val impl = q""" + $typeNameDef + ${c.prefix}.combine($magnoliaPkg.Magnolia.caseClass[$typeConstructor, $genericType]( + $typeName, true, false, new $scalaPkg.Array(0), _ => ${genericType.typeSymbol.asClass.module}) + ) + """ + Some(impl) + } else if (isCaseClass || isValueClass) { + val caseClassParameters = genericType.decls.collect { + case m: MethodSymbol if m.isCaseAccessor || (isValueClass && m.isParamAccessor) => + m.asMethod + } + + case class CaseParam(sym: MethodSymbol, + repeated: Boolean, + typeclass: Tree, + paramType: Type, + ref: TermName) + + val caseParamsReversed = caseClassParameters.foldLeft[List[CaseParam]](Nil) { + (acc, param) => + val paramName = param.name.decodedName.toString + val paramTypeSubstituted = param.typeSignatureIn(genericType).resultType + + val (repeated, paramType) = paramTypeSubstituted match { + case TypeRef(_, `repeatedParamClass`, typeArgs) => + true -> appliedType(scalaSeqType, typeArgs) + case tpe => + false -> tpe + } + + acc + .find(_.paramType =:= paramType) + .fold { + val path = ProductType(paramName, genericType.toString) + val frame = stack.Frame(path, resultType, assignedName) + val derivedImplicit = + stack.recurse(frame, appliedType(typeConstructor, paramType)) { + typeclassTree(paramType, typeConstructor) + } + + val ref = TermName(c.freshName("paramTypeclass")) + val assigned = q"""lazy val $ref = $derivedImplicit""" + CaseParam(param, repeated, assigned, paramType, ref) :: acc + } { backRef => + CaseParam(param, repeated, q"()", paramType, backRef.ref) :: acc + } + } + + val caseParams = caseParamsReversed.reverse + + val paramsVal: TermName = TermName(c.freshName("parameters")) + val fieldValues: TermName = TermName(c.freshName("fieldValues")) + + val preAssignments = caseParams.map(_.typeclass) + + val defaults = if (!isValueClass) { + val companionRef = GlobalUtil.patchedCompanionRef(c)(genericType.dealias) + val companionSym = companionRef.symbol.asModule.info + + // If a companion object is defined with alternative apply methods + // it is needed get all the alternatives + val constructorMethods = + companionSym.decl(TermName("apply")).alternatives.map(_.asMethod) + + // The last apply method in the alternatives is the one that belongs + // to the case class, not the user defined companion object + val indexedConstructorParams = + constructorMethods.last.paramLists.head.map(_.asTerm).zipWithIndex + + indexedConstructorParams.map { + case (p, idx) => + if (p.isParamWithDefault) { + val method = TermName("apply$default$" + (idx + 1)) + q"$scalaPkg.Some($companionRef.$method)" + } else q"$scalaPkg.None" + } + } else List(q"$scalaPkg.None") + + val assignments = caseParams.zip(defaults).zipWithIndex.map { + case ((CaseParam(param, repeated, typeclass, paramType, ref), defaultVal), idx) => + q"""$paramsVal($idx) = $magnoliaPkg.Magnolia.param[$typeConstructor, $genericType, + $paramType]( + ${param.name.decodedName.toString}, $repeated, $ref, $defaultVal, _.${param.name} + )""" + } + + Some(q"""{ + ..$preAssignments + val $paramsVal: $scalaPkg.Array[$magnoliaPkg.Param[$typeConstructor, $genericType]] = + new $scalaPkg.Array(${assignments.length}) + ..$assignments + + $typeNameDef + + ${c.prefix}.combine($magnoliaPkg.Magnolia.caseClass[$typeConstructor, $genericType]( + $typeName, + false, + $isValueClass, + $paramsVal, + ($fieldValues: $scalaPkg.Seq[Any]) => { + if ($fieldValues.lengthCompare($paramsVal.length) != 0) { + val msg = "`" + $typeName.full + "` has " + $paramsVal.length + " fields, not " + $fieldValues.size + throw new java.lang.IllegalArgumentException(msg) + } + new $genericType(..${caseParams.zipWithIndex.map { + case (typeclass, idx) => + val arg = q"$fieldValues($idx).asInstanceOf[${typeclass.paramType}]" + if (typeclass.repeated) q"$arg: _*" else arg + }})})) + }""") + } else if (isSealedTrait) { + val genericSubtypes = classType.get.knownDirectSubclasses.to[List] + val subtypes = genericSubtypes.map { sub => + val subType = sub.asType.toType // FIXME: Broken for path dependent types + val typeParams = sub.asType.typeParams + val typeArgs = thisType(sub).baseType(genericType.typeSymbol).typeArgs + val mapping = (typeArgs.map(_.typeSymbol), genericType.typeArgs).zipped.toMap + val newTypeArgs = typeParams.map(mapping.withDefault(_.asType.toType)) + val applied = appliedType(subType.typeConstructor, newTypeArgs) + existentialAbstraction(typeParams, applied) + } + + if (subtypes.isEmpty) { + c.info(c.enclosingPosition, + s"magnolia: could not find any direct subtypes of $typeSymbol", + force = true) + + c.abort(c.enclosingPosition, "") + } + + val subtypesVal: TermName = TermName(c.freshName("subtypes")) + + val typeclasses = for (subType <- subtypes) yield { + val path = CoproductType(genericType.toString) + val frame = stack.Frame(path, resultType, assignedName) + subType -> stack.recurse(frame, appliedType(typeConstructor, subType)) { + typeclassTree(subType, typeConstructor) + } + } + + val assignments = typeclasses.zipWithIndex.map { + case ((typ, typeclass), idx) => + q"""$subtypesVal($idx) = $magnoliaPkg.Magnolia.subtype[$typeConstructor, $genericType, $typ]( + $magnoliaPkg.TypeName(${typ.typeSymbol.owner.fullName}, ${typ.typeSymbol.name.decodedName.toString}), + $typeclass, + (t: $genericType) => t.isInstanceOf[$typ], + (t: $genericType) => t.asInstanceOf[$typ] + )""" + } + + Some(q"""{ + val $subtypesVal: $scalaPkg.Array[$magnoliaPkg.Subtype[$typeConstructor, $genericType]] = + new $scalaPkg.Array(${assignments.size}) + + ..$assignments + + $typeNameDef + + ${c.prefix}.dispatch(new $magnoliaPkg.SealedTrait( + $typeName, + $subtypesVal: $scalaPkg.Array[$magnoliaPkg.Subtype[$typeConstructor, $genericType]]) + ): $resultType + }""") + } else None + + for (term <- result) yield q"""{ + lazy val $assignedName: $resultType = $term + $assignedName + }""" + } + + val genericType: Type = weakTypeOf[T] + val searchType = appliedType(typeConstructor, genericType) + val directlyReentrant = stack.top.exists(_.searchType =:= searchType) + if (directlyReentrant) throw DirectlyReentrantException() + + val result = stack + .find(searchType) + .map { enclosingRef => + q"$magnoliaPkg.Deferred[$searchType](${enclosingRef.toString})" + } + .orElse { + directInferImplicit(genericType, typeConstructor) + } + + for (tree <- result) if (debug.isDefined && genericType.toString.contains(debug.get)) { + c.echo(c.enclosingPosition, s"Magnolia macro expansion for $genericType") + c.echo(NoPosition, s"... = ${showCode(tree)}\n\n") + } + + val dereferencedResult = + if (stack.nonEmpty) result + else for (tree <- result) yield c.untypecheck(removeDeferred.transform(tree)) + + dereferencedResult.getOrElse { + c.abort(c.enclosingPosition, + s"magnolia: could not infer $prefixName.Typeclass for type $genericType") + } + } + + /** constructs a new [[Subtype]] instance + * + * This method is intended to be called only from code generated by the Magnolia macro, and + * should not be called directly from users' code. */ + def subtype[Tc[_], T, S <: T](name: TypeName, + tc: => Tc[S], + isType: T => Boolean, + asType: T => S): Subtype[Tc, T] = + new Subtype[Tc, T] with PartialFunction[T, S] { + type SType = S + def typeName: TypeName = name + def typeclass: Tc[SType] = tc + def cast: PartialFunction[T, SType] = this + def isDefinedAt(t: T) = isType(t) + def apply(t: T): SType = asType(t) + } + + /** constructs a new [[Param]] instance + * + * This method is intended to be called only from code generated by the Magnolia macro, and + * should not be called directly from users' code. */ + def param[Tc[_], T, P](name: String, + isRepeated: Boolean, + typeclassParam: => Tc[P], + defaultVal: => Option[P], + deref: T => P): Param[Tc, T] = new Param[Tc, T] { + type PType = P + def label: String = name + def repeated: Boolean = isRepeated + def default: Option[PType] = defaultVal + def typeclass: Tc[PType] = typeclassParam + def dereference(t: T): PType = deref(t) + } + + /** constructs a new [[CaseClass]] instance + * + * This method is intended to be called only from code generated by the Magnolia macro, and + * should not be called directly from users' code. */ + def caseClass[Tc[_], T](name: TypeName, + obj: Boolean, + valClass: Boolean, + params: Array[Param[Tc, T]], + constructor: Seq[Any] => T): CaseClass[Tc, T] = + new CaseClass[Tc, T](name, obj, valClass, params) { + def rawConstruct(fieldValues: Seq[Any]): T = constructor(fieldValues) + } +} + +private[magnolia] final case class DirectlyReentrantException() + extends Exception("attempt to recurse directly") + +private[magnolia] object Deferred { def apply[T](method: String): T = ??? } + +private[magnolia] object CompileTimeState { + + sealed abstract class TypePath(path: String) { override def toString = path } + final case class CoproductType(typeName: String) extends TypePath(s"coproduct type $typeName") + + final case class ProductType(paramName: String, typeName: String) + extends TypePath(s"parameter '$paramName' of product type $typeName") + + final case class ChainedImplicit(typeClassName: String, typeName: String) + extends TypePath(s"chained implicit $typeClassName for type $typeName") + + final class Stack[C <: whitebox.Context] { + private var frames = List.empty[Frame] + private val cache = mutable.Map.empty[C#Type, C#Tree] + + def isEmpty: Boolean = frames.isEmpty + def nonEmpty: Boolean = frames.nonEmpty + def top: Option[Frame] = frames.headOption + def pop(): Unit = frames = frames drop 1 + def push(frame: Frame): Unit = frames ::= frame + + def clear(): Unit = { + frames = Nil + cache.clear() + } + + def find(searchType: C#Type): Option[C#TermName] = frames.collectFirst { + case Frame(_, tpe, term) if tpe =:= searchType => term + } + + def recurse[T <: C#Tree](frame: Frame, searchType: C#Type)(fn: => T): T = { + push(frame) + val result = cache.getOrElseUpdate(searchType, fn) + pop() + result.asInstanceOf[T] + } + + def trace: List[TypePath] = + frames + .drop(1) + .foldLeft[(C#Type, List[TypePath])]((null, Nil)) { + case ((_, Nil), frame) => + (frame.searchType, frame.path :: Nil) + case (continue @ (tpe, acc), frame) => + if (tpe =:= frame.searchType) continue + else (frame.searchType, frame.path :: acc) + } + ._2 + .reverse + + override def toString: String = + frames.mkString("magnolia stack:\n", "\n", "\n") + + final case class Frame(path: TypePath, searchType: C#Type, term: C#TermName) + } + + object Stack { + private val global = new Stack[whitebox.Context] + private val workSet = mutable.Set.empty[whitebox.Context#Symbol] + + def withContext(c: whitebox.Context)(fn: Stack[c.type] => c.Tree): c.Tree = { + workSet += c.macroApplication.symbol + val depth = c.enclosingMacros.count(m => workSet(m.macroApplication.symbol)) + try fn(global.asInstanceOf[Stack[c.type]]) + finally if (depth <= 1) { + global.clear() + workSet.clear() + } + } + } +} |