summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/transform/patmat/ScalacPatternExpanders.scala
blob: d4f44303bbde7faaacc4f724f8cb15a5b2d72051 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
/* NSC -- new Scala compiler
 * Copyright 2005-2013 LAMP/EPFL
 * @author  Paul Phillips
 */

package scala
package tools
package nsc
package transform
package patmat

/** This is scalac-specific logic layered on top of the scalac-agnostic
 *  "matching products to patterns" logic defined in PatternExpander.
 */
trait ScalacPatternExpanders {
  val global: Global

  import global._
  import definitions._
  import treeInfo._
  import analyzer._

  // Type projection: names the `Aligned` type of any ScalacPatternExpander,
  // so it can be used without a particular expander instance in scope.
  type PatternAligned = ScalacPatternExpander#Aligned

  /** Enrichment adding two convenience accessors to an aligned
   *  pattern/extractor pair.
   */
  implicit class AlignedOps(val aligned: PatternAligned) {
    /** The `tpe` of each of the aligned typed patterns, in order. */
    def expectedTypes     = aligned.typedPatterns.map(_.tpe)
    /** Whatever the underlying extractor reports as its `varargsTypes`. */
    def unexpandedFormals = aligned.extractor.varargsTypes
  }
  trait ScalacPatternExpander extends PatternExpander[Tree, Type] {
    // The "no pattern" sentinel; newPatterns passes it as the star slot when
    // the argument list does not end in a star pattern.
    // NOTE(review): presumably implements an abstract member of PatternExpander — confirm.
    def NoPattern = EmptyTree
    // Qualified with `global.` because an unqualified `NoType` in this trait
    // would refer to this very definition.
    def NoType    = global.NoType

    /** Splits a list of pattern trees into fixed patterns plus an optional
     *  trailing star pattern. When the last tree satisfies `isStar` it becomes
     *  the star slot; otherwise every tree is fixed and the star slot is
     *  `NoPattern`.
     */
    def newPatterns(patterns: List[Tree]): Patterns =
      patterns.lastOption match {
        case Some(star) if isStar(star) => Patterns(patterns.init, star)
        case _                          => Patterns(patterns, NoPattern)
      }
    /** Computes the element type for `tpe` after passing it through
     *  `repeatedToSeq`. Candidates are tried in order, each one only if the
     *  previous yielded nothing: `typeOfMemberNamedHead`, then
     *  `typeOfMemberNamedApply`, then `definitions.elementType` for `ArrayClass`.
     */
    def elementTypeOf(tpe: Type) = {
      val seqTpe   = repeatedToSeq(tpe)
      val fromHead = typeOfMemberNamedHead(seqTpe)

      fromHead orElse typeOfMemberNamedApply(seqTpe) orElse definitions.elementType(ArrayClass, seqTpe)
    }
    /** Builds an `Extractor` from its four components, routing the
     *  construction through `logResult` together with a message that echoes
     *  the arguments.
     *
     *  @param whole               type of the value being extracted from
     *  @param fixed               types of the fixed (non-repeated) sub-patterns
     *  @param repeated            the repeated part, or `NoRepeated`
     *  @param typeOfSinglePattern type to use when matched by one single pattern
     */
    def newExtractor(whole: Type, fixed: List[Type], repeated: Repeated, typeOfSinglePattern: Type): Extractor =
      // The closing parenthesis inside the message keeps the logged call text balanced.
      logResult(s"newExtractor($whole, $fixed, $repeated, $typeOfSinglePattern)")(Extractor(whole, fixed, repeated, typeOfSinglePattern))

    /** Overload that uses `tupleType(fixed)` as the type of the single pattern. */
    def newExtractor(whole: Type, fixed: List[Type], repeated: Repeated): Extractor = newExtractor(whole, fixed, repeated, tupleType(fixed))

    // Given Seq[A], produce Repeated(Seq[A], A, A*): the element type comes
    // from elementTypeOf and the A* type from scalaRepeatedType.
    def repeatedFromSeq(seqType: Type): Repeated = {
      val elem = elementTypeOf(seqType)
      Repeated(seqType, elem, scalaRepeatedType(elem))
    }
    // Given A*, produce Repeated(Seq[A], A, A*): the sequence and element
    // types are derived from the varargs type itself.
    def repeatedFromVarargs(repeated: Type): Repeated = {
      val asSeq    = repeatedToSeq(repeated)
      val asSingle = repeatedToSingle(repeated)
      Repeated(asSeq, asSingle, repeated)
    }

    /** Pattern expansion driven by a case class constructor.
     *  `method` is the MethodType carried by the primary constructor: its
     *  final result type becomes the extractor's `whole`, and its parameter
     *  types become the fixed types. A trailing Scala repeated parameter is
     *  split off as the repeated part.
     */
    def applyMethodTypes(method: Type): Extractor = {
      val whole  = method.finalResultType
      val params = method.paramTypes

      params.lastOption match {
        case Some(varargs) if isScalaRepeatedParamType(varargs) =>
          newExtractor(whole, params.init, repeatedFromVarargs(varargs))
        case _ =>
          newExtractor(whole, params, NoRepeated)
      }
    }

    /** Pattern expansion driven by an unapply or unapplySeq method.
     *  The MethodType does not say which of the two it was, so the caller
     *  passes that in through `isSeq`.
     *
     *  A result type equivalent to Boolean gives an extractor with no
     *  sub-patterns. Otherwise the type of the result's `get` member is
     *  examined: if there is none, an error is reported on `context` and no
     *  sub-pattern types are produced; if it has selectors, their types are
     *  used; otherwise the `get` type itself is the only sub-pattern type.
     *  With `isSeq`, the last such type is turned into the repeated part.
     */
    def unapplyMethodTypes(context: Context, whole: Type, result: Type, isSeq: Boolean): Extractor =
      if (result =:= BooleanTpe) newExtractor(whole, Nil, NoRepeated)
      else {
        val getResult = typeOfMemberNamedGet(result)

        def reportMissingGet() = {
          val name = if (isSeq) "unapplySeq" else "unapply"
          context.error(context.tree.pos, s"The result type of an $name method must contain a member `get` to be used as an extractor pattern, no such member exists in ${result}")
        }

        val expanded = getResult match {
          case global.NoType                 => reportMissingGet(); Nil
          case rawGet if hasSelectors(rawGet) => typesOfSelectors(rawGet)
          case rawGet                         => rawGet :: Nil
        }

        // Only a non-empty expansion can donate its last type as the sequence.
        if (isSeq && expanded.nonEmpty)
          newExtractor(whole, expanded.init, repeatedFromSeq(expanded.last), getResult)
        else
          newExtractor(whole, expanded, NoRepeated, getResult)
      }
  }
  object alignPatterns extends ScalacPatternExpander {
    /** Reports arity and star-alignment problems for `aligned`, then returns
     *  `aligned` unchanged.
     *
     *  At most one diagnostic is issued at `tree.pos`, checked in this order:
     *   - error: a star pattern is present but the extractor has no sequence;
     *   - error: `elementArity < 0` ("not enough" patterns);
     *   - error: `elementArity > 0` without a sequence ("too many" patterns);
     *   - warning, only when `settings.warnStarsAlign` is set: there is a
     *     sequence, `productArity > 0` and `elementArity > 0`.
     *
     *  @param context where errors and warnings are reported
     *  @param tree    supplies the position and the symbol whose owner is named in arity errors
     *  @param aligned the pattern/extractor alignment to check
     */
    private def validateAligned(context: Context, tree: Tree, aligned: Aligned): Aligned = {
      import aligned._

      def owner         = tree.symbol.owner
      def offering      = extractor.offeringString
      // An erroneous extractor has nothing meaningful to offer, so the clause is omitted.
      def offerString   = if (extractor.isErroneous) "" else s" offering $offering"
      def arityExpected = ( if (extractor.hasSeq) "at least " else "" ) + productArity

      def err(msg: String)         = context.error(tree.pos, msg)
      def warn(msg: String)        = context.warning(tree.pos, msg)
      def arityError(what: String) = err(s"$what patterns for $owner$offerString: expected $arityExpected, found $totalArity")

      if (isStar && !isSeq)
        err("Star pattern must correspond with varargs or unapplySeq")
      else if (elementArity < 0)
        arityError("not enough")
      else if (elementArity > 0 && !isSeq)
        arityError("too many")
      else if (settings.warnStarsAlign && isSeq && productArity > 0 && elementArity > 0) warn {
        if (isStar) "Sequence wildcard (_*) does not align with repeated case parameter or extracted sequence; the result may be unexpected."
        else "A repeated case parameter or extracted sequence is not matched by a sequence wildcard (_*), and may fail at runtime."
      }

      aligned
    }

    /** Aligns the argument patterns `args` against what `sel` extracts.
     *
     *  The extractor is derived from the name of `sel`'s symbol: `unapply` and
     *  `unapplySeq` go through `unapplyMethodTypes`; anything else is treated
     *  as a constructor via `applyMethodTypes`. The resulting alignment is
     *  checked by `validateAligned` before being returned.
     */
    def apply(context: Context, sel: Tree, args: List[Tree]): Aligned = {
      val fn = sel match {
        case Unapplied(unwrapped) => unwrapped
        case _                    => sel
      }
      val patterns  = newPatterns(args)
      val selName   = sel.symbol.name
      val isUnapply = selName == nme.unapply

      def unapplyExtractor(seq: Boolean) =
        unapplyMethodTypes(context, firstParamType(fn.tpe), sel.tpe, isSeq = seq)

      val extractor = selName match {
        case nme.unapply    => unapplyExtractor(seq = false)
        case nme.unapplySeq => unapplyExtractor(seq = true)
        case _              => applyMethodTypes(fn.tpe)
      }

      // Rather than let the error that is SI-6675 pollute the entire matching
      // process, the extractor is tupled before creating Aligned so that it
      // contains known good values.
      def productArity    = extractor.productArity
      def acceptMessage   = if (extractor.isErroneous) "" else s" to hold ${extractor.offeringString}"
      val requiresTupling = isUnapply && patterns.totalArity == 1 && productArity > 1

      val normalizedExtractor =
        if (!requiresTupling) extractor
        else {
          val tupled           = extractor.asSinglePattern
          val crushesIntoTuple = effectivePatternArity(args) == 1 && isTupleType(extractor.typeOfSinglePattern)
          if (crushesIntoTuple) {
            val sym = sel.symbol.owner
            currentRun.reporting.deprecationWarning(sel.pos, sym, s"${sym} expects $productArity patterns$acceptMessage but crushing into $productArity-tuple to fit single pattern (SI-6675)")
          }
          tupled
        }

      validateAligned(context, fn, Aligned(patterns, normalizedExtractor))
    }

    /** Aligns an `Apply` or `UnApply` tree by handing its function and
     *  arguments to the three-argument overload. Only these two tree shapes
     *  are matched; any other tree fails the match with a MatchError.
     */
    def apply(context: Context, tree: Tree): Aligned = tree match {
      case UnApply(fun, pats) => apply(context, fun, pats)
      case Apply(fun, pats)   => apply(context, fun, pats)
    }
  }
}