aboutsummaryrefslogblamecommitdiff
path: root/core/src/main/scala/generic.scala~
blob: 476743589a6d86e8ddc1cf04a402e229c88019ac (plain) (tree)












































































































































































                                                                                                                               
package magnolia

import scala.reflect._, macros._
import macrocompat.bundle
import scala.util.Try

import scala.collection.immutable.ListMap

/** Process-wide scratch space shared by every expansion of the macros in this file.
  *
  * It holds two independent pieces of mutable state:
  *
  *  - `state`, an insertion-ordered map used as a stack of "derivations in progress"
  *    (entries are appended by `push` and removed from the end by `pop`);
  *  - `searchType`, the type for which an implicit search was most recently started.
  *
  * Nothing here is synchronised.
  */
object GlobalMutableState {

  /** Entries in insertion order; the most recently pushed entry is the last one. */
  private[magnolia] var state: ListMap[AnyRef, AnyRef] = ListMap()

  /** Associates `value` with `key`.
    *
    * NOTE(review): `ListMap` holds at most one entry per key, so pushing a key that is
    * already present does not grow the map; a later `pop()` pair would then remove an
    * unrelated entry. Confirm callers never push the same key twice.
    */
  private[magnolia] def push(key: AnyRef, value: AnyRef): Unit =
    state = state.updated(key, value)

  /** Drops the last entry of `state`. Uses `init`, so it is not defined for an empty map. */
  private[magnolia] def pop(): Unit =
    state = state.init

  /** Looks `key` up in `state`.
    *
    * The stored value is returned through an unchecked cast to the tree type of `c`.
    * NOTE(review): the callers visible in this file push `TermName` values, not trees;
    * verify that consumers of the result cope with that.
    */
  private[magnolia] def has(c: whitebox.Context)(key: AnyRef): Option[c.universe.Tree] = {
    val stored = state.get(key)
    stored.asInstanceOf[Option[c.universe.Tree]]
  }

  /** The type currently being searched for via implicit resolution; `null` when unset. */
  private[magnolia] var searchType: AnyRef = null
}

/** Concrete macro bundle: supplies the tree fragments that [[GenericMacro]] stitches
  * together when it derives an instance.
  *
  * Every hook receives the context explicitly so that the trees it builds are typed
  * against that context's universe; the parameter deliberately shadows the `context`
  * constructor value.
  */
@bundle
class Macros(val context: whitebox.Context) extends GenericMacro(context) {

  /** Builds the single member of the generated instance:
    * `def extract(src: _root_.magnolia.Thing): <genericType> = <implementation>`.
    */
  override protected def classBody(context: whitebox.Context)(genericType: context.Type, implementation: context.Tree): context.Tree = {
    import context.universe._
    val extractMethod = q"def extract(src: _root_.magnolia.Thing): $genericType = $implementation"
    extractMethod
  }

  /** Builds `<value>.access(<elem>)`, with `elem` spliced in as a string literal. */
  override protected def dereferenceValue(context: whitebox.Context)(value: context.Tree, elem: String): context.Tree = {
    import context.universe._
    val accessCall = q"$value.access($elem)"
    accessCall
  }

  /** Builds `<value>.extract(<argument>)`. */
  override protected def callDelegateMethod(context: whitebox.Context)(value: context.Tree, argument: context.Tree): context.Tree = {
    import context.universe._
    val extractCall = q"$value.extract($argument)"
    extractCall
  }

  /** Combines the trees for two alternatives as `<left>.orElse(<right>)`. */
  override protected def coproductReduction(context: whitebox.Context)(left: context.Tree, right: context.Tree): context.Tree = {
    import context.universe._
    val fallbackChain = q"$left.orElse($right)"
    fallbackChain
  }
}

/** Skeleton of a whitebox macro that derives an instance of a type constructor `Tc` applied
  * to a type `T`, for case classes and sealed types.
  *
  * The shape of the generated code is delegated to four abstract hooks (`classBody`,
  * `coproductReduction`, `dereferenceValue`, `callDelegateMethod`); this class only decides
  * which instances are needed and how they are found. Bookkeeping that has to survive
  * nested macro expansions lives in [[GlobalMutableState]].
  */
abstract class GenericMacro(whiteboxContext: whitebox.Context) {

  val c = whiteboxContext

  /** Produces a tree evaluating to an instance of `typeConstructor[genericType]`.
    *
    * Sources are tried in this order:
    *  1. a name registered in [[GlobalMutableState]] for `genericType` (a derivation that is
    *     already in progress), which is referenced instead of being derived again;
    *  2. ordinary implicit resolution through `c.inferImplicitValue`;
    *  3. structural derivation through [[directInferImplicit]].
    *
    * Macro expansion is aborted if none of them yields a tree.
    *
    * @param genericType     the type an instance is needed for
    * @param typeConstructor the type constructor of the instance
    * @param myName          not used by this method; kept so existing call sites compile
    * @param count           depth counter, only used in the debug output
    */
  def getImplicit(genericType: c.universe.Type,
                  typeConstructor: c.universe.Type,
                  myName: c.universe.TermName,
                  count: Int): c.Tree = {

    import c.universe._
    println(s"getImplicit($genericType, $count)")

    // NOTE(review): the values pushed in this file are TermNames although `has` is typed as
    // returning trees; confirm that this unquote behaves as intended.
    GlobalMutableState.has(c)(genericType).map { nm => q"$nm" }.orElse {
      val searchType = appliedType(typeConstructor, genericType)

      // `searchType` in the global state is how a nested expansion of `generic` recognises
      // that it was triggered by this very search. It has to be put back afterwards, else a
      // later, unrelated expansion for the same type would be mistaken for a re-entrant one.
      val previousSearchType = GlobalMutableState.searchType
      val inferred = scala.util.Try {
        GlobalMutableState.searchType = genericType
        try c.inferImplicitValue(searchType, silent = false, withMacrosDisabled = false)
        finally GlobalMutableState.searchType = previousSearchType
      }.toOption

      inferred.orElse {
        println("Recursing")
        directInferImplicit(genericType, typeConstructor, count + 1)
      }
    }.getOrElse {
      println("Failed 2.")
      c.abort(c.enclosingPosition, "Could not find extractor for type "+genericType)
    }
  }

  /** Derives an instance of `typeConstructor[genericType]` from the structure of the type.
    *
    *  - Case class: one instance is resolved per case accessor and the results are passed to
    *    the constructor of `genericType`.
    *  - Sealed type: one instance is resolved per known direct subclass and the results are
    *    folded together with `coproductReduction`.
    *
    * @return the tree of a block defining and returning the new instance, or `None` when
    *         `genericType` is neither a case class nor a sealed type, or is a sealed type
    *         without known direct subclasses
    */
  def directInferImplicit(genericType: c.universe.Type,
         typeConstructor: c.universe.Type,
         count: Int): Option[c.Tree] = {
    import c.universe._

    println(s"directInferImplicit($genericType, $count)")

    val typeSymbol = genericType.typeSymbol
    val myName: TermName = TermName(c.freshName(typeSymbol.name.encodedName.toString.toLowerCase+"Extractor"))
    val classType = if(typeSymbol.isClass) Some(typeSymbol.asClass) else None
    val resultType = appliedType(typeConstructor, genericType)

    // Resolves the instance for a component type while `genericType` is registered as being
    // "in progress". The registration is undone even if resolution aborts, so that a failed
    // expansion cannot leave stale entries behind in the global state.
    def resolveWhileRegistered(componentType: Type): c.Tree = {
      GlobalMutableState.push(genericType, myName)
      try getImplicit(componentType, typeConstructor, myName, count)
      finally GlobalMutableState.pop()
    }

    val construct: Option[c.Tree] = classType.flatMap { classSymbol =>
      if(classSymbol.isCaseClass) {
        val implicits = genericType.decls.collect {
          case m: MethodSymbol if m.isCaseAccessor => m.asMethod
        }.map { param =>
          val imp = resolveWhileRegistered(param.returnType)
          val dereferenced = dereferenceValue(c)(q"src", param.name.toString)
          callDelegateMethod(c)(imp, dereferenced)
        }

        Some(q"new $genericType(..$implicits)")
      } else if(classSymbol.isSealed) {
        // `reduceOption` rather than `reduce`: a sealed type without known subclasses simply
        // yields no derivation instead of throwing from inside the macro.
        classSymbol.knownDirectSubclasses.toList
          .map { subclass => resolveWhileRegistered(subclass.asType.toType) }
          .reduceOption(coproductReduction(c))
          .map { imp => callDelegateMethod(c)(imp, q"src") }
      } else None
    }

    construct.map { const =>
      val methodImplementation = classBody(c)(genericType, const)
      q"""{
        def $myName: $resultType = new $resultType {
          $methodImplementation
        }
        $myName
      }"""
    }
  }

  /** Builds the member(s) placed inside the generated `new Tc[T] { ... }`; `implementation`
    * is the expression computed by [[directInferImplicit]].
    */
  protected def classBody(c: whitebox.Context)(genericType: c.Type, implementation: c.Tree): c.Tree

  /** Combines the instance trees of two subclasses of a sealed type into one tree. */
  protected def coproductReduction(c: whitebox.Context)(left: c.Tree, right: c.Tree): c.Tree

  /** Builds the expression reading the case-class field named `elem` out of `value`. */
  protected def dereferenceValue(c: whitebox.Context)(value: c.Tree, elem: String): c.Tree

  /** Builds the call applying the instance tree `value` to `argument`. */
  protected def callDelegateMethod(c: whitebox.Context)(value: c.Tree, argument: c.Tree): c.Tree

  /** Macro entry point: derives an instance of `Tc` for `T`.
    *
    * Expansion is aborted when it was triggered by the implicit search for `T` that
    * [[getImplicit]] is currently running (the search then fails and `getImplicit` falls
    * back to structural derivation), and when no instance can be derived at all.
    */
  def generic[T: c.WeakTypeTag, Tc: c.WeakTypeTag]: c.Tree = {
    import c.universe._

    val genericType: Type = weakTypeOf[T]
    if(genericType == GlobalMutableState.searchType) {
      println("Reentrant.")
      c.abort(c.enclosingPosition, "Re-entrant derivation for type "+genericType)
    }
    val typeConstructor: Type = weakTypeOf[Tc].typeConstructor

    val result = directInferImplicit(genericType, typeConstructor, 0)

    println(result)
    // Diagnostic only: the typechecked tree is discarded, failures are merely reported.
    try result.foreach { tree => c.typecheck(tree) } catch {
      case scala.util.control.NonFatal(e) =>
        println(result)
        println("Failed to typecheck because: "+e)
    }

    result.getOrElse {
      println("Could not infer extractor...")
      c.abort(c.enclosingPosition, "Could not infer extractor. Sorry.")
    }
  }

}