summaryrefslogtreecommitdiff
path: root/src/scalap/scala/tools/scalap/scalax/rules/Memoisable.scala
blob: b4ce8cab2324ba6064d539b4f33aa5b0996f2e29 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// -----------------------------------------------------------------------------
//
//  Scalax - The Scala Community Library
//  Copyright (c) 2005-8 The Scalax Project. All rights reserved.
//
//  The primary distribution site is http://scalax.scalaforge.org/
//
//  This software is released under the terms of the Revised BSD License.
//  There is NO WARRANTY.  See the file LICENSE for the full text.
//
// -----------------------------------------------------------------------------

package scala.tools.scalap
package scalax
package rules

import scala.collection.mutable

/** Mixin for `Rules` that memoises rule results on inputs implementing `Memoisable`.
  *
  * Results are routed through the input's `Memoisable.memo`, so whether a re-applied
  * rule returns a cached result depends on that input's `memo` implementation.
  */
trait MemoisableRules extends Rules {

  /** Builds a rule whose result on each input is memoised under `key`.
    *
    * `toRule` is by-name and captured in a `lazy val`, so it is not evaluated until the
    * returned rule is first applied (presumably so rules can refer to themselves
    * recursively — confirm at call sites).
    *
    * @param key    memoisation key handed to the input's `memo`
    * @param toRule by-name producer of the underlying rule function
    */
  def memo[In <: Memoisable, Out, A, X](key : AnyRef)(toRule : => In => Result[Out, A, X]) = {
    lazy val underlying = toRule
    from[In] { input =>
      input.memo(key, underlying(input))
    }
  }

  /** Names a rule and, for `Memoisable` inputs, memoises its result under that name.
    *
    * Inputs that are not `Memoisable` are passed straight to `f`. The rule name doubles
    * as the memo key, so distinct rules sharing a name may share cached results on the
    * same input.
    */
  override def ruleWithName[In, Out, A, X](name : String, f : In => rules.Result[Out, A, X]) = {
    val memoising: In => rules.Result[Out, A, X] = input => input match {
      case memoisable : Memoisable => memoisable.memo(name, f(input))
      case _ => f(input)
    }
    super.ruleWithName(name, memoising)
  }
}

/** Something able to cache computations under a key.
  *
  * The value to cache is passed by name, so it is only evaluated if and when the
  * implementation chooses to compute it.
  */
trait Memoisable {
  /** Returns the value associated with `key`, computing it from `a` when required.
    *
    * @param key identifies the memoised computation
    * @param a   by-name computation producing the value
    * @tparam A  type of the memoised value
    */
  def memo[A](key: AnyRef, a: => A): A
}


/** Global settings shared by every `DefaultMemoisable` instance. */
object DefaultMemoisable {
  /** When `true`, freshly computed memo results are printed to standard output. Off by default. */
  var debug: Boolean = false
}

/** A `Memoisable` that caches values in a per-instance mutable hash map.
  *
  * The first `memo` call for a key evaluates the by-name value (via `compute`) and stores
  * it; later calls with the same key return the stored value without re-evaluating.
  * When `DefaultMemoisable.debug` is set, each freshly computed value is printed.
  */
trait DefaultMemoisable extends Memoisable {
  /** Cache of computed values keyed by memo key; values are stored untyped. */
  protected val map = new mutable.HashMap[AnyRef, Any]

  /** Returns the cached value for `key`, computing and caching it on first use.
    *
    * The map stores `Any`, so the cast back to `A` is unchecked: callers must always use
    * the same result type for a given key.
    */
  def memo[A](key : AnyRef, a : => A): A =
    map.getOrElseUpdate(key, compute(key, a)).asInstanceOf[A]

  /** Evaluates `a` once and returns it, reporting the result.
    *
    * `Success` results are handed to `onSuccess`; any other result is printed when
    * debugging is enabled.
    */
  protected def compute[A](key : AnyRef, a : => A): Any = a match {
    case success : Success[_, _] =>
      onSuccess(key, success)
      success
    case other =>
      if (DefaultMemoisable.debug) println(s"$key -> $other")
      other
  }

  /** Hook invoked for each freshly computed `Success`; prints it when debugging. */
  protected def onSuccess[S, T](key : AnyRef, result : Success[S, T]): Unit = {
    val Success(out, t) = result
    if (DefaultMemoisable.debug) println(s"$key -> $t ($out)")
  }
}