summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/matching/PatternNodes.scala
blob: 6a0e48f9c02189b4003bb02ea319b0ce810c4aa4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
/* NSC -- new Scala compiler
 * Copyright 2005-2009 LAMP/EPFL
 * @author Burak Emir
 */
// $Id$

package scala.tools.nsc
package matching

import scala.tools.nsc.util.NoPosition

/**
 *  @author Burak Emir
 */
trait PatternNodes extends ast.TreeDSL
{
  self: transform.ExplicitOuter =>

  import global.{ typer => _, _ }
  import CODE._
  import definitions.{ ListClass, ConsClass }

  /** Convenience predicates on compiler types, exposed through an implicit
   *  wrapper so they can be called directly on any `Type`.
   */
  object Types {
    import definitions._

    /** Implicit view from a `Type` to its `RichType` wrapper. */
    implicit def enrichType(x: Type): RichType = new RichType(x)

    /** Wrapper adding classification tests to a type. All tests are made
     *  against `tpe`, i.e. the wrapped type after it has gone through
     *  `decodedEqualsType`, never against the raw constructor argument.
     */
    class RichType(undecodedTpe: Type) {
      /** The wrapped type as returned by `decodedEqualsType`.
       *  This is a `def`, so it is re-evaluated on every access.
       */
      def tpe = decodedEqualsType(undecodedTpe)

      /** Whether `tpe` conforms to `AnyRef`. */
      def isAnyRef: Boolean = tpe <:< AnyRefClass.tpe

      // The classes tested below are final, so it is enough to compare the
      // type symbol (by reference) instead of doing a conformance check.
      private def hasTypeSymbol(cls: Symbol): Boolean = tpe.typeSymbol eq cls

      def isByte: Boolean    = hasTypeSymbol(ByteClass)
      def isShort: Boolean   = hasTypeSymbol(ShortClass)
      def isInt: Boolean     = hasTypeSymbol(IntClass)
      def isChar: Boolean    = hasTypeSymbol(CharClass)
      def isBoolean: Boolean = hasTypeSymbol(BooleanClass)
      def isNothing: Boolean = hasTypeSymbol(NothingClass)
      def isArray: Boolean   = hasTypeSymbol(ArrayClass)
    }
  }

  /** Produces the combining function used to fold a list of element patterns
   *  into a well-typed `x :: y :: etc :: tree`.
   *
   *  The returned function takes the current element pattern and the tree
   *  already built for the remainder of the list:
   *   - if the element, once unbound, is a `Star`, the remainder is dropped
   *     and the result is a wildcard of the element's type, wrapped in binds
   *     for the variables the element defines;
   *   - otherwise the result is a cons application of the element onto the
   *     remainder, typed as the cons class applied to `tpe`.
   *
   *  @param tpe  the element type the list and cons types are instantiated with
   */
  private def listFolder(tpe: Type) = {
    // Take the prefix and class symbol from the result type of the cons
    // class's primary constructor; only its type arguments are replaced.
    val MethodType(_, TypeRef(prefix, consSym, _)) = ConsClass.primaryConstructor.tpe
    val elemTypeArgs = List(tpe)
    val consTpe      = typeRef(prefix, consSym, elemTypeArgs)
    val listTpe      = typeRef(prefix, ListClass, elemTypeArgs)

    // Builds `head :: tail` as an Apply whose function position is a TypeTree
    // carrying a synthetic method type (element, list) => cons. A fresh dummy
    // symbol owns the synthetic parameters of every node built here.
    def mkCons(head: Tree, tail: Tree) = {
      val paramOwner = new TermSymbol(NoSymbol, NoPosition, "matching$dummy")
      val params     = paramOwner.newSyntheticValueParams(List(tpe, listTpe))
      val consMethod = MethodType(params, consTpe)

      Apply(TypeTree(consMethod), List(head, tail)).setType(consTpe)
    }

    (x: Tree, xs: Tree) => unbind(x) match {
      case _: Star => makeBind(definedVars(x), WILD(x.tpe))
      case _       => mkCons(x, xs)
    }
  }

  /** Folds a list of element patterns, right to left, into a single list
   *  pattern tree, starting from the tree produced by `gen.mkNil`.
   *
   *  @param pats    the element patterns, in list order
   *  @param tptArg  the element type handed to `listFolder`
   *  @return        the tree obtained by combining every pattern with `listFolder(tptArg)`
   */
  def normalizedListPattern(pats: List[Tree], tptArg: Type): Tree = {
    val emptyList = gen.mkNil
    val combine   = listFolder(tptArg)

    pats.foldRight(emptyList)(combine)
  }
}