summaryrefslogtreecommitdiff
path: root/test/files/run/patmat-exprs.scala
diff options
context:
space:
mode:
authorPaul Phillips <paulp@improving.org>2011-04-29 20:13:10 +0000
committerPaul Phillips <paulp@improving.org>2011-04-29 20:13:10 +0000
commite439b24609ae40ff0f019026c71affc52e69a8fd (patch)
tree01cba666530f1834fa41aa876b69bc55f7482553 /test/files/run/patmat-exprs.scala
parent80bb9cfb7b2c05a48eb43852b252c36136697835 (diff)
downloadscala-e439b24609ae40ff0f019026c71affc52e69a8fd.tar.gz
scala-e439b24609ae40ff0f019026c71affc52e69a8fd.tar.bz2
scala-e439b24609ae40ff0f019026c71affc52e69a8fd.zip
Fixing the pattern matcher regression I introdu...
Fixing the pattern matcher regression I introduced between rc1 and rc2. Not done with this situation but at least I managed to boil out the big problem and keep my five closed pattern matcher tickets to boot. Closes #4523, no review.
Diffstat (limited to 'test/files/run/patmat-exprs.scala')
-rw-r--r--test/files/run/patmat-exprs.scala579
1 files changed, 579 insertions, 0 deletions
diff --git a/test/files/run/patmat-exprs.scala b/test/files/run/patmat-exprs.scala
new file mode 100644
index 0000000000..f502a14832
--- /dev/null
+++ b/test/files/run/patmat-exprs.scala
@@ -0,0 +1,579 @@
+import runtime.ScalaRunTime
+
+object Test {
+ val p = new Pattern { }
+ import p._
+ implicit object IntOps extends NumericOps[Int] {
+ def zero = 0
+ def one = 1
+
+ def add(a: Int, b: Int): Int = a + b
+ def sub(a: Int, b: Int): Int = a - b
+ def mul(a: Int, b: Int): Int = a * b
+ def mul(a: Int, b: Double): Int = (a * b).toInt
+ def div(a: Int, b: Int): Int = a / b
+ def div(a: Int, b: Double): Int = (a / b).toInt
+ def similar(a: Int, b: Int): Boolean = a == b
+ def abs(a: Int): Double = math.abs(a).toDouble
+ def sqr(a: Int): Int = a * a
+ def sqrt(a: Int): Int = math.sqrt(a).toInt
+ def log(a: Int): Int = math.log(a).toInt
+ def exp(a: Int): Int = math.exp(a).toInt
+ def sin(a: Int): Int = math.sin(a).toInt
+ def cos(a: Int): Int = math.cos(a).toInt
+
+ def fromDouble(a: Double): Int = a.toInt
+ def fromInt(a: Int): Int = a
+ }
+
+ def main(args: Array[String]): Unit = {
+ println((5: Expr[Int]) + 10 + 15 * 20)
+ }
+}
+
+
+trait Pattern {
+ // For trying out 2.7.7
+ //
+ // type Numeric[T]
+ // import java.io.Serializable
+ //
+ // implicit def compat27a[T](x: Iterable[T]) = new {
+ // def iterator: Iterator[T] = x.elements
+ // def sum: Int = 5
+ // def collect[U](pf: PartialFunction[T, U]): Iterable[U] = x map pf
+ // }
+
+ /** Function that returns object of the same type it was passed */
+ trait EndoFunction[-A] {
+ def apply[B <: A](x: B): B
+ }
+
+ /** Allows for smart construction of EndoFunction from an ordinary function */
+ object EndoFunction {
+ def apply[A](f: A => A): EndoFunction[A] = new EndoFunction[A] {
+ def apply[B <: A](x: B): B = f(x).asInstanceOf[B]
+ }
+ }
+
+ trait NumericOps[T] extends Serializable {
+ def zero: T
+ def one: T
+ def two = add(one, one)
+ def three = add(two, one)
+
+ def add(a: T, b: T): T
+ def add(a: T, b: T, c: T): T = add(a, add(b, c))
+ def sub(a: T, b: T): T
+ def mul(a: T, b: T): T
+ def mul(a: T, b: Double): T
+ def div(a: T, b: T): T
+ def div(a: T, b: Double): T
+ def similar(a: T, b: T): Boolean
+ def neg(a: T) = sub(zero, a)
+ def abs(a: T): Double
+ def sqr(a: T): T
+ def sqrt(a: T): T
+ def log(a: T): T
+ def exp(a: T): T
+ def sin(a: T): T
+ def cos(a: T): T
+ def tan(a: T): T = div(sin(a), cos(a))
+
+ def fromDouble(a: Double): T
+ def fromInt(a: Int): T
+
+ def sum(terms: Iterable[T]) = terms.foldLeft(zero)(add)
+ def sum(terms: Iterator[T]) = terms.foldLeft(zero)(add)
+ def product(terms: Iterable[T]) = terms.foldLeft(one)(mul)
+ def product(terms: Iterator[T]) = terms.foldLeft(one)(mul)
+
+
+ def similar(a: Iterable[T], b: Iterable[T]): Boolean = {
+ val i1 = a.iterator
+ val i2 = b.iterator
+ while (i1.hasNext && i2.hasNext)
+ if (!similar(i1.next, i2.next))
+ return false;
+ true;
+ }
+ }
+
+ /**
+ * Simple expression interpreter with some basic symbolic manipulation.
+ * Able to evaluate derivatives.
+ */
+
+ trait Expr[T] {
+
+ import Expr._
+
+ /** Evaluates value of the expression. */
+ def eval(context: Any => Any): T
+
+ /** Symbolically calculates derivative of this expression. Does not simplify it. */
+ def derivative(variable: Var[T]): Expr[T]
+
+ /** Returns arguments of this operator */
+ def args: Iterable[Expr[_]]
+
+ /** Transforms arguments of this operator by applying given function. */
+ def mapArgs(f: EndoFunction[Expr[_]]): Expr[T]
+
+ /** Transforms this operator and its arguments by applying given function */
+ def map(f: EndoFunction[Expr[_]]): Expr[T] =
+ f(mapArgs(EndoFunction[Expr[_]](x => x.map(f))))
+
+ /** Folds all subexpressions in this expression in depth-first order */
+ def fold[A](v: A)(f: (A, Expr[_]) => A): A =
+ f(args.foldLeft(v) { (a, b) => b.fold(a)(f) }, this)
+
+ /** Replaces all occurrences of one subexpression with another one */
+ def replace(from: Expr[_], to: Expr[_]): Expr[T] =
+ map(EndoFunction[Expr[_]](x => if (x == from) to else x))
+
+ /** Returns true if this expression contains given subexpression */
+ def contains(s: Expr[_]): Boolean =
+ this == s || args.exists(_ contains s)
+
+ /** Counts number of occurrences of the given subexpression. */
+ def count(condition: Expr[_] => Boolean): Int =
+ (if (condition(this)) 1 else 0) + args.map(_.count(condition)).sum
+
+ /** Executes some code for every subexpression in the depth-first order */
+ def foreach[U](block: Expr[_] => U): Unit = {
+ args.foreach(_.foreach(block))
+ block(this)
+ }
+
+ /** Collects subexpressions successfully transformed by the given partial function, in depth-first order. */
+ def collect[U](f: PartialFunction[Expr[_], U]): List[U] = {
+ val a = args.flatMap(_.collect(f)).toList
+ if (f.isDefinedAt(this)) (f(this) :: a) else a
+ }
+
+ def leaves: List[Leaf[T]] = collect { case l: Leaf[T] => l }
+
+ def + (other: Expr[T])(implicit n: NumericOps[T]) = Add(List(this, other))
+ def - (other: Expr[T])(implicit n: NumericOps[T]) = Sub(this, other)
+ def * (other: Expr[T])(implicit n: NumericOps[T]) = Mul(this, other)
+ def / (other: Expr[T])(implicit n: NumericOps[T]) = Div(this, other)
+
+ def unary_- (implicit n: NumericOps[T]) = Neg(this)
+ def sqr(implicit n: NumericOps[T]) = Sqr(this)
+
+ def < (other: Expr[T])(implicit n: NumericOps[T], o: Ordering[T]) = LT(this, other)
+ def <= (other: Expr[T])(implicit n: NumericOps[T], o: Ordering[T]) = LE(this, other)
+ def > (other: Expr[T])(implicit n: NumericOps[T], o: Ordering[T]) = GT(this, other)
+ def >= (other: Expr[T])(implicit n: NumericOps[T], o: Ordering[T]) = GE(this, other)
+
+ private def generalize(implicit num: NumericOps[T]): Expr[T] = {
+ this match {
+ case Add2(a, b) => Add(a :: b :: Nil)
+ case Add3(a, b, c) => Add(a :: b :: c :: Nil)
+ case Sub(a, b) => Add(a :: Neg(b) :: Nil)
+ case Add(x) => Add(x flatMap {
+ case Neg(Add(y)) => y.map(Neg(_))
+ case Add(y) => y
+ case y => y :: Nil
+ })
+ case x => x
+ }
+ }
+
+ private def specialize(implicit num: NumericOps[T]): Expr[T] = {
+ this match {
+ case Add(Seq(a, b)) => Add2(a, b)
+ case Add(Seq(a, b, c)) => Add3(a, b, c)
+ case x => x
+ }
+ }
+
+ /** Eliminates common negated components of a sum */
+ private def reduceComponents(components: List[Expr[T]])(implicit num: NumericOps[T]): List[Expr[T]] = {
+ val pairs =
+ for (a <- components; b <- components if Neg(a) == b || a == Neg(b))
+ yield (a, b)
+ pairs.foldLeft(components) { (c, pair) =>
+ if (c.contains(pair._1) && c.contains(pair._2))
+ c.diff(pair._1 :: pair._2 :: Nil)
+ else
+ c
+ }
+ }
+
+
+ /** Simplifies this expression to make evaluation faster and more accurate.
+ * Performs only one pass. */
+ private def reduce(implicit num: NumericOps[T]): Expr[T] = {
+ this match {
+ case Add(Seq(Neg(x), Neg(y), Neg(z))) => Neg(Add(List(x, y, z)))
+ case Add(Seq(Mul(x, y), z)) if (x == z) => Mul(x, Add(List(y, One[T])))
+ case Add(Seq(Mul(x, y), z)) if (y == z) => Mul(y, Add(List(z, One[T])))
+ case Add(Seq(Mul(x, y), Mul(u, w))) if (x == u) => Mul(x, Add(List(y, w)))
+ case Add(Seq(Mul(x, y), Mul(u, w))) if (y == w) => Mul(y, Add(List(x, u)))
+ case Add(Seq(Add(x), Add(y))) => Add(x.toList ::: y.toList).simplify
+ case Add(Seq(Add(x), y)) => Add(y :: x.toList).simplify
+ case Add(Seq(x, Add(y))) => Add(x :: y.toList).simplify
+ case Add(x) => {
+ val noZeros = x.filter(_ != Zero[T])
+ val noOnes = noZeros.map { case y: One[_] => Const(num.one); case y => y }
+ val constant = num.sum(noOnes.collect { case c: Const[T] => c.value })
+ val rest = noOnes.filter(x => !x.isInstanceOf[Const[_]]).toList
+ val reduced = reduceComponents(rest)
+ val args = if (num.similar(constant, num.zero)) reduced else reduced ::: Const(constant) :: Nil
+ args.size match {
+ case 0 => Zero[T]
+ case 1 => args.head
+ case 2 => Add2(args(0), args(1))
+ case 3 => Add3(args(0), args(1), args(2))
+ case _ => Add(args)
+ }
+ }
+ case Sub(x: Zero[_], y) => Neg(y)
+ case Sub(x, y: Zero[_]) => x
+ case Sub(x, y) if x == y => Zero[T]
+ case Sub(Mul(x, y), z) if (x == z) => Mul(x, Sub(y, One[T]))
+ case Sub(Mul(x, y), z) if (y == z) => Mul(y, Sub(z, One[T]))
+ case Sub(Mul(x, y), Mul(u, w)) if (x == u) => Mul(x, Sub(y, w))
+ case Sub(Mul(x, y), Mul(u, w)) if (y == w) => Mul(y, Sub(x, u))
+ case Mul(x: Zero[_], y) => Zero[T]
+ case Mul(x, y: Zero[_]) => Zero[T]
+ case Mul(x: One[_], y) => y
+ case Mul(x, y: One[_]) => x
+ case Mul(Neg(x: One[_]), y) => Neg(y)
+ case Mul(x, Neg(y: One[_])) => Neg(x)
+
+ case Mul(x, y) if (x == y) => Sqr(x)
+ case Div(x: Zero[_], y) => Zero[T] // warning: possibly extends domain
+ case Div(x, y: One[_]) => x
+ case Div(Sqr(x), y) if x == y => x
+ case Div(Mul(x, y), z) if (x == z) => y
+ case Div(Mul(x, y), z) if (y == z) => y
+ case Div(Mul(Mul(x, y), z), w) if (x == w) => Mul(y, z)
+ case Div(Mul(Mul(x, y), z), w) if (y == w) => Mul(x, z)
+ case Div(Mul(z, Mul(x, y)), w) if (x == w) => Mul(y, z)
+ case Div(Mul(z, Mul(x, y)), w) if (y == w) => Mul(x, z)
+ case Div(Mul(x, y), Mul(u, w)) if (x == u) => Div(y, w)
+ case Div(Mul(x, y), Mul(u, w)) if (y == w) => Div(x, u)
+ case Div(x: One[_], y) => Inv(y)
+ case Div(x, Sqr(y)) if x == y => Inv(y)
+ case Div(Mul(x, y), Sqr(Mul(u, w))) if x == u && y == w => Inv(Mul(x, y))
+ case Div(x, y) if x == y => One[T]
+
+ case Mul(Neg(a), Neg(b)) => Mul(a, b)
+ case Div(Neg(a), Neg(b)) => Div(a, b)
+
+ case Neg(x: Zero[_]) => Zero[T]
+ case Neg(x: One[_]) => Const(num.neg(num.one))
+ case Sub(Const(x), Const(y)) => const(num.sub(x, y))
+ case Mul(Const(x), Const(y)) => const(num.mul(x, y))
+ case Div(Const(x), Const(y)) => const(num.div(x, y))
+ case Neg(Const(x)) => const(num.neg(x))
+ case Sqr(Const(x)) => const(num.sqr(x))
+
+ case Mul(Const(x), Mul(Const(y), z)) => Mul(const(num.mul(x, y)), z)
+ case Mul(Const(x), Mul(y, Const(z))) => Mul(const(num.mul(x, z)), y)
+ case Mul(Mul(Const(y), z), Const(x)) => Mul(const(num.mul(x, y)), z)
+ case Mul(Mul(y, Const(z)), Const(x)) => Mul(const(num.mul(x, z)), y)
+
+ case Const(x) if x == num.one => One[T]
+ case Const(x) if x == num.zero => Zero[T]
+
+ case Sub(x, Neg(y)) => Add(List(x, y))
+ case Sub(Neg(x), y) => Neg(Add(List(x, y)))
+ case Neg(Neg(x)) => x
+ case Neg(Mul(a: Const[T], x)) => Mul(const(num.neg(a.value)), x)
+ case Neg(Mul(x, a: Const[T])) => Mul(const(num.neg(a.value)), x)
+ case Neg(Div(Neg(a), b)) => Div(a, b)
+ case Neg(Div(a, Neg(b))) => Div(a, b)
+ case Neg(Mul(Neg(a), b)) => Mul(a, b)
+ case Neg(Mul(a, Neg(b))) => Mul(a, b)
+
+ case Log(Exp(x)) => x
+ case x => x
+ }
+ }
+
+ private def optimizeWith(f: Expr[T] => Expr[T]): Expr[T] = {
+ f(mapArgs(EndoFunction[Expr[_]](
+ a => a match { case x: Expr[T] => x.optimizeWith(f) }
+ )))
+ }
+
+ /** Simplifies this expression to make evaluation faster and more accurate.*/
+ def simplify(implicit num: NumericOps[T]): Expr[T] = {
+ val a1 = optimizeWith(_.generalize)
+ val a2 = a1.optimizeWith(_.generalize)
+ val b = a2.optimizeWith(_.reduce)
+ val c = b.optimizeWith(_.reduce)
+ val d = c.optimizeWith(_.specialize)
+ d
+ }
+ }
+
+
+ trait Leaf[T] extends Expr[T] {
+ val args = List[Expr[T]]()
+ def mapArgs(f: EndoFunction[Expr[_]]) = this
+ }
+
+ trait OneArg[T] extends Expr[T] {
+ val expr: Expr[T]
+ val args = List(expr)
+ }
+
+
+ trait TwoArg[T] extends Expr[T] {
+ val left: Expr[T]
+ val right: Expr[T]
+ val args = List(left, right)
+ }
+
+ trait ManyArg[T] extends Expr[T]
+
+ /** Marker trait for specifying that you can safely divide by this */
+ trait NonZero[T] extends Expr[T]
+
+ case class Const[T](value: T)(implicit num: NumericOps[T]) extends Leaf[T] with NonZero[T] {
+ def derivative(variable: Var[T]) = Zero[T]
+ def eval(f: Any => Any) = value
+ override def toString = value.toString
+ }
+
+
+ case class Zero[T] (implicit num: NumericOps[T]) extends Leaf[T] {
+ def derivative(variable: Var[T]) = Zero[T]
+ def eval(f: Any => Any) = num.zero
+ override def toString = "0"
+ }
+
+ case class One[T] (implicit num: NumericOps[T]) extends Leaf[T] {
+ def derivative(variable: Var[T]) = Zero[T]
+ def eval(f: Any => Any) = num.one
+ override def toString = "1"
+ }
+
+ abstract class Var[T](implicit num: NumericOps[T]) extends Leaf[T] {
+ def derivative(variable: Var[T]) = if (variable == this) One[T] else Zero[T]
+ def eval(f: Any => Any) = f(this).asInstanceOf[T]
+ }
+
+ case class NamedVar[T](name: String)(implicit num: NumericOps[T]) extends Var[T] {
+ override lazy val hashCode = ScalaRunTime._hashCode(this)
+ override def toString = name
+ }
+
+ case class Add[T](args: Iterable[Expr[T]])(implicit num: NumericOps[T]) extends ManyArg[T] {
+ def eval(f: Any => Any) = num.sum(for (i <- args.iterator) yield i.eval(f))
+ def derivative(v: Var[T]) = Add(args.map(_.derivative(v)))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Add(args map (x => f(x)))
+ override def toString = "(" + args.mkString(" + ") + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+
+ case class Add2[T](left: Expr[T], right: Expr[T])
+ (implicit num: NumericOps[T]) extends TwoArg[T] {
+ def eval(f: Any => Any) = num.add(left.eval(f), right.eval(f))
+ def derivative(v: Var[T]) = Add2(left.derivative(v), right.derivative(v))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Add2(f(left), f(right))
+ override def toString = "(" + left + " + " + right + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+ case class Add3[T](a1: Expr[T], a2: Expr[T], a3: Expr[T])
+ (implicit num: NumericOps[T]) extends ManyArg[T] {
+ val args = List(a1, a2, a3)
+ def eval(f: Any => Any) = num.add(a1.eval(f), a2.eval(f), a3.eval(f))
+ def derivative(v: Var[T]) = Add3(a1.derivative(v), a2.derivative(v), a3.derivative(v))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Add3(f(a1), f(a2), f(a3))
+ override def toString = "(" + a1 + " + " + a2 + " + " + a3 + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+
+ case class Sub[T](left: Expr[T], right: Expr[T])
+ (implicit num: NumericOps[T]) extends TwoArg[T] {
+ def derivative(v: Var[T]) = Sub(left.derivative(v), right.derivative(v))
+ def eval(f: Any => Any) = num.sub(left.eval(f), right.eval(f))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Sub(f(left), f(right))
+ override def toString = "(" + left + " - " + right + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+ case class Neg[T](expr: Expr[T])
+ (implicit num: NumericOps[T]) extends OneArg[T] {
+ def derivative(v: Var[T]) = Neg(expr.derivative(v))
+ def eval(f: Any => Any) = num.neg(expr.eval(f))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Neg(f(expr))
+ override def toString = "(-" + expr + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+
+ }
+
+ case class Mul[T](left: Expr[T], right: Expr[T])
+ (implicit num: NumericOps[T]) extends TwoArg[T] {
+ def derivative(v: Var[T]) =
+ Add(List(
+ Mul(left, right.derivative(v)),
+ Mul(right, left.derivative(v))))
+
+ def eval(f: Any => Any) = num.mul(left.eval(f), right.eval(f))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Mul(f(left), f(right))
+ override def toString = "(" + left + " * " + right + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+ case class Div[T](left: Expr[T], right: Expr[T])
+ (implicit num: NumericOps[T]) extends TwoArg[T] {
+
+ // [f(x) / g(x)]' = [f(x) * 1 / g(x)]' = f'(x) * 1 / g(x) + f(x) * [1 / g(x)]' =
+ // f'(x) / g(x) + f(x) * [-1 / g(x) ^ 2] * g'(x) = (f'(x) * g(x) - f(x) * g'(x)) / g(x)^2
+ def derivative(v: Var[T]) =
+ Div(
+ Sub(
+ Mul(left.derivative(v), right),
+ Mul(left, right.derivative(v))),
+ Sqr(right)
+ )
+
+ def eval(f: Any => Any) = num.div(left.eval(f), right.eval(f))
+ def mapArgs(f: EndoFunction[Expr[_]]) =
+ Div(f(left), f(right))
+ override def toString = "(" + left + " / " + right + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+ case class Inv[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] {
+
+ // [1 / f(x)]' = - f'(x) / f(x) ^ 2
+ def derivative(v: Var[T]) = Neg(Div(expr.derivative(v), Sqr(expr)))
+ def eval(f: Any => Any) = num.div(num.one, expr.eval(f))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Inv(f(expr))
+ override def toString = "(1 / " + expr + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+ case class Sqr[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] {
+ // [f(x) ^ 2]' = 2 * f(x) * f'(x)
+ def derivative(v: Var[T]) = Mul(Mul(Const(num.two), expr), expr.derivative(v))
+ def eval(f: Any => Any) = num.sqr(expr.eval(f))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Sqr(f(expr))
+ override def toString = expr + " ^ 2"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+
+ case class Log[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] {
+ def derivative(v: Var[T]) = Div(expr.derivative(v), expr)
+ def eval(f: Any => Any) = num.log(expr.eval(f))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Log(f(expr))
+ override def toString = "log(" + expr + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+
+ case class Exp[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] {
+ def derivative(v: Var[T]) = Mul(expr.derivative(v), Exp(expr))
+ def eval(f: Any => Any) = num.exp(expr.eval(f))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Exp(f(expr))
+ override def toString = "exp(" + expr + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+ case class Sqrt[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] {
+ def derivative(v: Var[T]) = Neg(Div(expr.derivative(v), Sqrt(expr)))
+ def eval(f: Any => Any) = num.sqrt(expr.eval(f))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Sqrt(f(expr))
+ override def toString = "sqrt(" + expr + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+
+ case class Sin[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] {
+ def derivative(v: Var[T]) = Mul(expr.derivative(v), Cos(expr))
+ def eval(f: Any => Any) = num.sin(expr.eval(f))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Sin(f(expr))
+ override def toString = "sin(" + expr + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+
+ case class Cos[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] {
+ def derivative(v: Var[T]) = Neg(Mul(expr.derivative(v), Sin(expr)))
+ def eval(f: Any => Any) = num.cos(expr.eval(f))
+ def mapArgs(f: EndoFunction[Expr[_]]) = Cos(f(expr))
+ override def toString = "cos(" + expr + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+
+ abstract class Compare[T](left: Expr[T], right: Expr[T], cmp: (T, T) => Boolean)(implicit num: NumericOps[T])
+ extends Expr[Boolean] {
+ def derivative(v: Var[Boolean]) = throw new IllegalStateException("Derivative of Boolean not allowed")
+ def eval(f: Any => Any) = cmp(left.eval(f), right.eval(f))
+ val args = List(left, right)
+ }
+
+ case class LE[T](left: Expr[T], right: Expr[T])(implicit num: NumericOps[T], ord: Ordering[T])
+ extends Compare[T](left, right, ord.compare(_, _) <= 0) {
+ def mapArgs(f: EndoFunction[Expr[_]]) = LE(
+ f(left), f(right))
+ override def toString = left.toString + " <= " + right.toString
+ }
+
+ case class LT[T](left: Expr[T], right: Expr[T])(implicit num: NumericOps[T], ord: Ordering[T])
+ extends Compare[T](left, right, ord.compare(_, _) < 0) {
+ def mapArgs(f: EndoFunction[Expr[_]]) = LT(
+ f(left), f(right))
+ override def toString = left.toString + " < " + right.toString
+ }
+
+ case class GE[T](left: Expr[T], right: Expr[T])(implicit num: NumericOps[T], ord: Ordering[T])
+ extends Compare[T](left, right, ord.compare(_, _) >= 0) {
+ def mapArgs(f: EndoFunction[Expr[_]]) = GE(
+ f(left), f(right))
+ override def toString = left.toString + " >= " + right.toString
+ }
+
+ case class GT[T](left: Expr[T], right: Expr[T])(implicit num: NumericOps[T], ord: Ordering[T])
+ extends Compare[T](left, right, ord.compare(_, _) > 0) {
+ def mapArgs(f: EndoFunction[Expr[_]]) = GT(
+ f(left), f(right))
+ override def toString = left.toString + " > " + right.toString
+ }
+
+ case class IfElse[T <: Numeric[T]]
+ (condition: Expr[Boolean], left: Expr[T], right: Expr[T])(implicit num: NumericOps[T]) extends Expr[T] {
+
+ val args = List(condition, left, right)
+ def derivative(v: Var[T]) = IfElse(condition, left.derivative(v), right.derivative(v))
+ def eval(f: Any => Any) = if (condition.eval(f)) left.eval(f) else right.eval(f)
+ def mapArgs(f: EndoFunction[Expr[_]]) = IfElse(
+ f(condition).asInstanceOf[Expr[Boolean]],
+ f(left),
+ f(right))
+ override def toString = "if (" + condition + ")(" + left + ") else (" + right + ")"
+ override lazy val hashCode = ScalaRunTime._hashCode(this);
+ }
+
+ object Expr {
+ /** Creates a constant expression */
+ def const[T](value: T)(implicit num: NumericOps[T]): Leaf[T] =
+ if (num.zero == value) Zero[T]
+ else Const(value)
+
+ implicit def double2Constant[T](d: Double)(implicit num: NumericOps[T]): Leaf[T] =
+ const(num.fromDouble(d))
+
+ implicit def float2Constant[T](f: Float)(implicit num: NumericOps[T]): Leaf[T] =
+ const(num.fromDouble(f.toDouble))
+
+ implicit def int2Constant[T](i: Int)(implicit num: NumericOps[T]): Leaf[T] =
+ const(num.fromDouble(i.toDouble))
+
+ implicit def long2Constant[T](l: Long)(implicit num: NumericOps[T]): Leaf[T] =
+ const(num.fromDouble(l.toDouble))
+ }
+} \ No newline at end of file