aboutsummaryrefslogtreecommitdiff
path: root/src/dotty/tools/dotc/transform/InterceptedMethods.scala
blob: b56985ffe6ef0f7a66f25ae7715c657fe0875907 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package dotty.tools.dotc
package transform

import TreeTransforms._
import core.DenotTransformers._
import core.Denotations._
import core.SymDenotations._
import core.Contexts._
import core.Types._
import ast.Trees._
import ast.tpd.{Apply, Tree, cpy}
import dotty.tools.dotc.ast.tpd
import scala.collection.mutable
import dotty.tools.dotc._
import core._
import Contexts._
import Symbols._
import Decorators._
import NameOps._
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransformer, TreeTransform}
import dotty.tools.dotc.ast.Trees._
import dotty.tools.dotc.ast.{untpd, tpd}
import dotty.tools.dotc.core.Constants.Constant
import dotty.tools.dotc.core.Types.MethodType
import dotty.tools.dotc.core.Names.Name
import dotty.runtime.LazyVals
import scala.collection.mutable.ListBuffer
import dotty.tools.dotc.core.Denotations.SingleDenotation
import dotty.tools.dotc.core.SymDenotations.SymDenotation
import dotty.tools.dotc.core.DenotTransformers.DenotTransformer
import StdNames._

/** Replace member references as follows:
  *
  * - `x == y` for == in class Any becomes `x equals y` with equals in class Object.
  * - `x != y` for != in class Any becomes `!(x equals y)` with equals in class Object.
  * - `x.##` for ## in classes Any and Object becomes a call to ScalaRunTime.hash,
  *     using the most precise overload available.
  * - `x.getClass()` for getClass in classes Any and AnyVal becomes `x.getClass()` with getClass in class Object.
  *
  * NOTE(review): `transformApply` only rewrites applications whose argument list is empty,
  * so the `==` / `!=` rewrites listed above are currently never performed (both methods take
  * one argument). Relaxing that guard as-is would turn `x == y` into `x.equals(y)`, which,
  * unlike Scala's `==`, is not null-safe in the receiver -- confirm intended semantics first.
  */
class InterceptedMethods extends TreeTransform {

  import tpd._

  override def name: String = "intercepted"

  /** `getClass` as declared on Any and AnyVal. */
  private var getClassMethods: Set[Symbol] = _
  /** `##` as declared on Any and Object. */
  private var poundPoundMethods: Set[Symbol] = _
  /** `==` and `!=` as declared on Any. */
  private var Any_comparisons: Set[Symbol] = _
  /** Union of all symbols whose applications `transformApply` considers rewriting. */
  private var interceptedMethods: Set[Symbol] = _
  /** `getClass` on Any/AnyVal plus the `getClass` member of every Scala value class. */
  private var primitiveGetClassMethods: Set[Symbol] = _

  /** Perform context-dependent initialization of the symbol sets above.
   *  They are built from `defn`, which is obtained through the implicit context passed here.
   */
  override def init(implicit ctx: Context, info: TransformerInfo): Unit = {
    getClassMethods = Set(defn.Any_getClass, defn.AnyVal_getClass)
    poundPoundMethods = Set(defn.Any_##, defn.Object_##)
    Any_comparisons = Set(defn.Any_==, defn.Any_!=)
    interceptedMethods = getClassMethods ++ poundPoundMethods ++ Any_comparisons
    primitiveGetClassMethods = getClassMethods ++
      defn.ScalaValueClasses.map(x => x.requiredMethod(nme.getClass_))
  }

  /** Rewrite a bare `x.##` selection into a `ScalaRunTime.hash` call.
   *  This should become unnecessary once `##` is guaranteed to reach us wrapped in an Apply node.
   */
  override def transformSelect(tree: tpd.Select)(implicit ctx: Context, info: TransformerInfo): Tree = {
    val sym = tree.symbol
    if (sym.isTerm && poundPoundMethods.contains(sym)) {
      val rewrite = PoundPoundValue(tree.qualifier)
      ctx.log(s"$name rewrote $tree to $rewrite")
      rewrite
    }
    else tree
  }

  /** Build the replacement tree for `tree.##`.
   *
   *  @param tree the receiver of `##`
   *  @return `0` when the receiver's widened type is Null, otherwise a call to the chosen
   *          `ScalaRunTime.hash` overload applied to `tree`
   */
  private def PoundPoundValue(tree: Tree)(implicit ctx: Context): Tree = {
    val receiverType = tree.tpe.widen
    val s = receiverType.typeSymbol
    if (s == defn.NullClass) Literal(Constant(0))
    else {
      // Since we are past typer, we need to avoid creating trees carrying
      // overloaded types.  This logic is custom (and technically incomplete,
      // although serviceable) for def hash.  What is really needed is for
      // the overloading logic presently hidden away in a few different
      // places to be properly exposed so we can just call "resolveOverload"
      // after typer.  Until then:
      val alts = defn.ScalaRuntimeModule.info.member(nme.hash_)

      val hashMethod =
        if (defn.ScalaNumericValueClasses contains s)
          // A numeric value type gets the overload whose parameter type is exactly the
          // receiver's type, taking into account that null.asInstanceOf[Int] == 0.
          alts.suchThat(_.info.firstParamTypes.head =:= receiverType)
        else
          // Everything else goes through the `Any` overload, which also handles a
          // `null` runtime value.
          alts.suchThat(_.info.firstParamTypes.head.typeSymbol == defn.AnyClass)

      tpd.Apply(Ident(hashMethod.termRef), List(tree))
    }
  }

  /** Rewrite applications of intercepted methods (see the class comment).
   *  Anything not intercepted is returned unchanged.
   */
  override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo): Tree = {
    val sym = tree.fun.symbol

    // Fails loudly when an intercepted symbol falls through every case below; if assertions
    // are disabled, the tree is left as is.
    def unknown = {
      assert(false, s"The symbol '$sym' was intercepted but didn't match any cases, " +
        s"that means the intercepted methods set doesn't match the code")
      tree
    }

    // NOTE(review): `tree.args.isEmpty` excludes `==` / `!=` (one argument each), so the
    // comparison branch below is unreachable. See the class comment before changing this.
    if (sym.isTerm && tree.args.isEmpty && interceptedMethods.contains(sym)) {
      val rewrite: Tree = tree.fun match {
        case Select(qual, _) =>
          if (poundPoundMethods contains sym) {
            PoundPoundValue(qual)
          } else if (Any_comparisons contains sym) {
            if (sym eq defn.Any_==)
              Apply(Select(qual, defn.Object_equals.termRef), tree.args)
            else if (sym eq defn.Any_!=)
              Select(Apply(Select(qual, defn.Object_equals.termRef), tree.args), defn.Boolean_!.termRef)
            else unknown
          } /* else if (isPrimitiveValueClass(qual.tpe.typeSymbol)) {
            // todo: this is needed to support value classes
            // Rewrite 5.getClass to ScalaRunTime.anyValClass(5)
            global.typer.typed(gen.mkRuntimeCall(nme.anyValClass,
              List(qual, typer.resolveClassTag(tree.pos, qual.tpe.widen))))
          }*/
          else if (primitiveGetClassMethods contains sym) {
            // If we got here, we are sending a primitive getClass method to either
            // a) an Any, in which case Object_getClass works because Any erases to Object; or
            //
            // b) a non-primitive, e.g. because the qualifier's type is a refinement type where one parent
            //    of the refinement is a primitive and another is AnyRef. In that case
            //    we get a primitive form of _getClass trying to target a boxed value,
            //    so we need to replace that method with Object_getClass to get correct behavior.
            //    See SI-5568.
            Apply(Select(qual, defn.Object_getClass.termRef), Nil)
          } else {
            unknown
          }
        case _ =>
          unknown
      }
      ctx.log(s"$name rewrote $tree to $rewrite")
      rewrite
    }
    else tree
  }
}