diff options
-rw-r--r-- | src/compiler/scala/tools/nsc/matching/Matrix.scala | 4 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/matching/Patterns.scala | 32 | ||||
-rw-r--r-- | test/files/run/patmat-exprs.check | 1 | ||||
-rw-r--r-- | test/files/run/patmat-exprs.scala | 579 |
4 files changed, 598 insertions, 18 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/Matrix.scala b/src/compiler/scala/tools/nsc/matching/Matrix.scala index b9b9b51384..6b691826a9 100644 --- a/src/compiler/scala/tools/nsc/matching/Matrix.scala +++ b/src/compiler/scala/tools/nsc/matching/Matrix.scala @@ -128,6 +128,10 @@ trait Matrix extends MatrixAdditions { } def recordSyntheticSym(sym: Symbol): Symbol = { _syntheticSyms += sym + if (_syntheticSyms.size > 25000) { + cunit.error(owner.pos, "Sanity check failed: over 25000 symbols created for pattern match.") + abort("This is a bug in the pattern matcher.") + } sym } diff --git a/src/compiler/scala/tools/nsc/matching/Patterns.scala b/src/compiler/scala/tools/nsc/matching/Patterns.scala index e1f8204960..29563264eb 100644 --- a/src/compiler/scala/tools/nsc/matching/Patterns.scala +++ b/src/compiler/scala/tools/nsc/matching/Patterns.scala @@ -180,38 +180,35 @@ trait Patterns extends ast.TreeDSL { override def description = "Unapply(%s => %s)".format(necessaryType, resTypesString) } - case class SequenceExtractorPattern(tree: UnApply, elems: List[Tree]) extends UnapplyPattern with SequenceLikePattern { - override def simplify(pv: PatternVar) = { - pv.sym setFlag NO_EXHAUSTIVE - - if (pv.tpe <:< arg.tpe) this - else this rebindTo uaTyped - } - - override def description = "UnapplySeq(%s => %s)".format(necessaryType, resTypesString) - } - // Special List handling. It was like that when I got here. case class ListExtractorPattern(tree: UnApply, tpt: Tree, elems: List[Tree]) extends UnapplyPattern with SequenceLikePattern { private val cons = ConsClass.primaryConstructor.tpe.resultType private val consRef = typeRef(cons.prefix, ConsClass, List(tpt.tpe)) private val listRef = typeRef(cons.prefix, ListClass, List(tpt.tpe)) + private val seqRef = typeRef(cons.prefix, SeqClass, List(tpt.tpe)) + private def thisSeqRef = { + val tc = (tree.tpe baseType SeqClass).typeConstructor + if (tc.typeParams.size == 1) appliedType(tc, List(tpt.tpe)) + else seqRef + } // Fold a list into a well-typed x :: y :: etc :: tree. - private def listFolder(x: Tree, xs: Tree) = unbind(x) match { - case _: Star => Pattern(x) rebindTo WILD(x.tpe) boundTree + private def listFolder(x: Pattern, xs: Pattern): Pattern = x match { + case Pattern(Star(_), _) => x rebindTo WILD(x.tpe) case _ => val dummyMethod = new TermSymbol(NoSymbol, NoPosition, "matching$dummy") val consType = MethodType(dummyMethod newSyntheticValueParams List(tpt.tpe, listRef), consRef) - Apply(TypeTree(consType), List(x, xs)) setType consRef + Pattern(Apply(TypeTree(consType), List(x.boundTree, xs.boundTree)) setType consRef) } + private def foldedPatterns = elems.foldRight(NilPattern)((x, y) => listFolder(Pattern(x), y)) override def necessaryType = if (nonStarPatterns.nonEmpty) consRef else listRef + override def simplify(pv: PatternVar) = { pv.sym setFlag NO_EXHAUSTIVE if (pv.tpe <:< necessaryType) - this rebindTo elems.foldRight(gen.mkNil)(listFolder) + foldedPatterns else this rebindTo (Typed(tree, TypeTree(necessaryType)) setType necessaryType) } @@ -318,9 +315,8 @@ trait Patterns extends ast.TreeDSL { } def apply(x: UnApply): Pattern = x match { - case UnapplySeq(container, tpt, elems) => - if (container == ListModule) ListExtractorPattern(x, tpt, elems) - else SequenceExtractorPattern(x, elems) + case UnapplySeq(ListModule, tpt, elems) => + ListExtractorPattern(x, tpt, elems) case _ => ExtractorPattern(x) } diff --git a/test/files/run/patmat-exprs.check b/test/files/run/patmat-exprs.check new file mode 100644 index 0000000000..b6df9385fa --- /dev/null +++ b/test/files/run/patmat-exprs.check @@ -0,0 +1 @@ +((5 + 10) + 300) diff --git a/test/files/run/patmat-exprs.scala b/test/files/run/patmat-exprs.scala new file mode 100644 index 0000000000..f502a14832 --- /dev/null +++ b/test/files/run/patmat-exprs.scala @@ -0,0 +1,579 @@ +import runtime.ScalaRunTime + +object Test { + val p = new Pattern { } + import p._ + implicit object IntOps extends NumericOps[Int] { + def zero = 0 + def one = 1 + + def add(a: Int, b: Int): Int = a + b + def sub(a: Int, b: Int): Int = a - b + def mul(a: Int, b: Int): Int = a * b + def mul(a: Int, b: Double): Int = (a * b).toInt + def div(a: Int, b: Int): Int = a / b + def div(a: Int, b: Double): Int = (a / b).toInt + def similar(a: Int, b: Int): Boolean = a == b + def abs(a: Int): Double = math.abs(a).toDouble + def sqr(a: Int): Int = a * a + def sqrt(a: Int): Int = math.sqrt(a).toInt + def log(a: Int): Int = math.log(a).toInt + def exp(a: Int): Int = math.exp(a).toInt + def sin(a: Int): Int = math.sin(a).toInt + def cos(a: Int): Int = math.cos(a).toInt + + def fromDouble(a: Double): Int = a.toInt + def fromInt(a: Int): Int = a + } + + def main(args: Array[String]): Unit = { + println((5: Expr[Int]) + 10 + 15 * 20) + } +} + + +trait Pattern { + // For trying out 2.7.7 + // + // type Numeric[T] + // import java.io.Serializable + // + // implicit def compat27a[T](x: Iterable[T]) = new { + // def iterator: Iterator[T] = x.elements + // def sum: Int = 5 + // def collect[U](pf: PartialFunction[T, U]): Iterable[U] = x map pf + // } + + /** Function that returns object of the same type it was passed */ + trait EndoFunction[-A] { + def apply[B <: A](x: B): B + } + + /** Allows for smart construction of EndoFunction from an ordinary function */ + object EndoFunction { + def apply[A](f: A => A): EndoFunction[A] = new EndoFunction[A] { + def apply[B <: A](x: B): B = f(x).asInstanceOf[B] + } + } + + trait NumericOps[T] extends Serializable { + def zero: T + def one: T + def two = add(one, one) + def three = add(two, one) + + def add(a: T, b: T): T + def add(a: T, b: T, c: T): T = add(a, add(b, c)) + def sub(a: T, b: T): T + def mul(a: T, b: T): T + def mul(a: T, b: Double): T + def div(a: T, b: T): T + def div(a: T, b: Double): T + def similar(a: T, b: T): Boolean + def neg(a: T) = sub(zero, a) + def abs(a: T): Double + def sqr(a: T): T + def sqrt(a: T): T + def log(a: T): T + def exp(a: T): T + def sin(a: T): T + def cos(a: T): T + def tan(a: T): T = div(sin(a), cos(a)) + + def fromDouble(a: Double): T + def fromInt(a: Int): T + + def sum(terms: Iterable[T]) = terms.foldLeft(zero)(add) + def sum(terms: Iterator[T]) = terms.foldLeft(zero)(add) + def product(terms: Iterable[T]) = terms.foldLeft(one)(mul) + def product(terms: Iterator[T]) = terms.foldLeft(one)(mul) + + + def similar(a: Iterable[T], b: Iterable[T]): Boolean = { + val i1 = a.iterator + val i2 = b.iterator + while (i1.hasNext && i2.hasNext) + if (!similar(i1.next, i2.next)) + return false; + true; + } + } + + /** + * Simple expression interpreter with some basic symbolic manipulation. + * Able to evaluate derivatives. + */ + + trait Expr[T] { + + import Expr._ + + /** Evaluates value of the expression. */ + def eval(context: Any => Any): T + + /** Symbolically calculates derivative of this expression. Does not simplify it. */ + def derivative(variable: Var[T]): Expr[T] + + /** Returns arguments of this operator */ + def args: Iterable[Expr[_]] + + /** Transforms arguments of this operator by applying given function. */ + def mapArgs(f: EndoFunction[Expr[_]]): Expr[T] + + /** Transforms this operator and its arguments by applying given function */ + def map(f: EndoFunction[Expr[_]]): Expr[T] = + f(mapArgs(EndoFunction[Expr[_]](x => x.map(f)))) + + /** Folds all subexpressions in this expression in depth-first order */ + def fold[A](v: A)(f: (A, Expr[_]) => A): A = + f(args.foldLeft(v) { (a, b) => b.fold(a)(f) }, this) + + /** Replaces all occurrences of one subexpression with another one */ + def replace(from: Expr[_], to: Expr[_]): Expr[T] = + map(EndoFunction[Expr[_]](x => if (x == from) to else x)) + + /** Returns true if this expression contains given subexpression */ + def contains(s: Expr[_]): Boolean = + this == s || args.exists(_ contains s) + + /** Counts number of occurrences of the given subexpression. */ + def count(condition: Expr[_] => Boolean): Int = + (if (condition(this)) 1 else 0) + args.map(_.count(condition)).sum + + /** Executes some code for every subexpression in the depth-first order */ + def foreach[U](block: Expr[_] => U): Unit = { + args.foreach(_.foreach(block)) + block(this) + } + + /** Collects subexpressions successfully transformed by the given partial function, in depth-first order. */ + def collect[U](f: PartialFunction[Expr[_], U]): List[U] = { + val a = args.flatMap(_.collect(f)).toList + if (f.isDefinedAt(this)) (f(this) :: a) else a + } + + def leaves: List[Leaf[T]] = collect { case l: Leaf[T] => l } + + def + (other: Expr[T])(implicit n: NumericOps[T]) = Add(List(this, other)) + def - (other: Expr[T])(implicit n: NumericOps[T]) = Sub(this, other) + def * (other: Expr[T])(implicit n: NumericOps[T]) = Mul(this, other) + def / (other: Expr[T])(implicit n: NumericOps[T]) = Div(this, other) + + def unary_- (implicit n: NumericOps[T]) = Neg(this) + def sqr(implicit n: NumericOps[T]) = Sqr(this) + + def < (other: Expr[T])(implicit n: NumericOps[T], o: Ordering[T]) = LT(this, other) + def <= (other: Expr[T])(implicit n: NumericOps[T], o: Ordering[T]) = LE(this, other) + def > (other: Expr[T])(implicit n: NumericOps[T], o: Ordering[T]) = GT(this, other) + def >= (other: Expr[T])(implicit n: NumericOps[T], o: Ordering[T]) = GE(this, other) + + private def generalize(implicit num: NumericOps[T]): Expr[T] = { + this match { + case Add2(a, b) => Add(a :: b :: Nil) + case Add3(a, b, c) => Add(a :: b :: c :: Nil) + case Sub(a, b) => Add(a :: Neg(b) :: Nil) + case Add(x) => Add(x flatMap { + case Neg(Add(y)) => y.map(Neg(_)) + case Add(y) => y + case y => y :: Nil + }) + case x => x + } + } + + private def specialize(implicit num: NumericOps[T]): Expr[T] = { + this match { + case Add(Seq(a, b)) => Add2(a, b) + case Add(Seq(a, b, c)) => Add3(a, b, c) + case x => x + } + } + + /** Eliminates common negated components of a sum */ + private def reduceComponents(components: List[Expr[T]])(implicit num: NumericOps[T]): List[Expr[T]] = { + val pairs = + for (a <- components; b <- components if Neg(a) == b || a == Neg(b)) + yield (a, b) + pairs.foldLeft(components) { (c, pair) => + if (c.contains(pair._1) && c.contains(pair._2)) + c.diff(pair._1 :: pair._2 :: Nil) + else + c + } + } + + + /** Simplifies this expression to make evaluation faster and more accurate. + * Performs only one pass. */ + private def reduce(implicit num: NumericOps[T]): Expr[T] = { + this match { + case Add(Seq(Neg(x), Neg(y), Neg(z))) => Neg(Add(List(x, y, z))) + case Add(Seq(Mul(x, y), z)) if (x == z) => Mul(x, Add(List(y, One[T]))) + case Add(Seq(Mul(x, y), z)) if (y == z) => Mul(y, Add(List(z, One[T]))) + case Add(Seq(Mul(x, y), Mul(u, w))) if (x == u) => Mul(x, Add(List(y, w))) + case Add(Seq(Mul(x, y), Mul(u, w))) if (y == w) => Mul(y, Add(List(x, u))) + case Add(Seq(Add(x), Add(y))) => Add(x.toList ::: y.toList).simplify + case Add(Seq(Add(x), y)) => Add(y :: x.toList).simplify + case Add(Seq(x, Add(y))) => Add(x :: y.toList).simplify + case Add(x) => { + val noZeros = x.filter(_ != Zero[T]) + val noOnes = noZeros.map { case y: One[_] => Const(num.one); case y => y } + val constant = num.sum(noOnes.collect { case c: Const[T] => c.value }) + val rest = noOnes.filter(x => !x.isInstanceOf[Const[_]]).toList + val reduced = reduceComponents(rest) + val args = if (num.similar(constant, num.zero)) reduced else reduced ::: Const(constant) :: Nil + args.size match { + case 0 => Zero[T] + case 1 => args.head + case 2 => Add2(args(0), args(1)) + case 3 => Add3(args(0), args(1), args(2)) + case _ => Add(args) + } + } + case Sub(x: Zero[_], y) => Neg(y) + case Sub(x, y: Zero[_]) => x + case Sub(x, y) if x == y => Zero[T] + case Sub(Mul(x, y), z) if (x == z) => Mul(x, Sub(y, One[T])) + case Sub(Mul(x, y), z) if (y == z) => Mul(y, Sub(z, One[T])) + case Sub(Mul(x, y), Mul(u, w)) if (x == u) => Mul(x, Sub(y, w)) + case Sub(Mul(x, y), Mul(u, w)) if (y == w) => Mul(y, Sub(x, u)) + case Mul(x: Zero[_], y) => Zero[T] + case Mul(x, y: Zero[_]) => Zero[T] + case Mul(x: One[_], y) => y + case Mul(x, y: One[_]) => x + case Mul(Neg(x: One[_]), y) => Neg(y) + case Mul(x, Neg(y: One[_])) => Neg(x) + + case Mul(x, y) if (x == y) => Sqr(x) + case Div(x: Zero[_], y) => Zero[T] // warning: possibly extends domain + case Div(x, y: One[_]) => x + case Div(Sqr(x), y) if x == y => x + case Div(Mul(x, y), z) if (x == z) => y + case Div(Mul(x, y), z) if (y == z) => y + case Div(Mul(Mul(x, y), z), w) if (x == w) => Mul(y, z) + case Div(Mul(Mul(x, y), z), w) if (y == w) => Mul(x, z) + case Div(Mul(z, Mul(x, y)), w) if (x == w) => Mul(y, z) + case Div(Mul(z, Mul(x, y)), w) if (y == w) => Mul(x, z) + case Div(Mul(x, y), Mul(u, w)) if (x == u) => Div(y, w) + case Div(Mul(x, y), Mul(u, w)) if (y == w) => Div(x, u) + case Div(x: One[_], y) => Inv(y) + case Div(x, Sqr(y)) if x == y => Inv(y) + case Div(Mul(x, y), Sqr(Mul(u, w))) if x == u && y == w => Inv(Mul(x, y)) + case Div(x, y) if x == y => One[T] + + case Mul(Neg(a), Neg(b)) => Mul(a, b) + case Div(Neg(a), Neg(b)) => Div(a, b) + + case Neg(x: Zero[_]) => Zero[T] + case Neg(x: One[_]) => Const(num.neg(num.one)) + case Sub(Const(x), Const(y)) => const(num.sub(x, y)) + case Mul(Const(x), Const(y)) => const(num.mul(x, y)) + case Div(Const(x), Const(y)) => const(num.div(x, y)) + case Neg(Const(x)) => const(num.neg(x)) + case Sqr(Const(x)) => const(num.sqr(x)) + + case Mul(Const(x), Mul(Const(y), z)) => Mul(const(num.mul(x, y)), z) + case Mul(Const(x), Mul(y, Const(z))) => Mul(const(num.mul(x, z)), y) + case Mul(Mul(Const(y), z), Const(x)) => Mul(const(num.mul(x, y)), z) + case Mul(Mul(y, Const(z)), Const(x)) => Mul(const(num.mul(x, z)), y) + + case Const(x) if x == num.one => One[T] + case Const(x) if x == num.zero => Zero[T] + + case Sub(x, Neg(y)) => Add(List(x, y)) + case Sub(Neg(x), y) => Neg(Add(List(x, y))) + case Neg(Neg(x)) => x + case Neg(Mul(a: Const[T], x)) => Mul(const(num.neg(a.value)), x) + case Neg(Mul(x, a: Const[T])) => Mul(const(num.neg(a.value)), x) + case Neg(Div(Neg(a), b)) => Div(a, b) + case Neg(Div(a, Neg(b))) => Div(a, b) + case Neg(Mul(Neg(a), b)) => Mul(a, b) + case Neg(Mul(a, Neg(b))) => Mul(a, b) + + case Log(Exp(x)) => x + case x => x + } + } + + private def optimizeWith(f: Expr[T] => Expr[T]): Expr[T] = { + f(mapArgs(EndoFunction[Expr[_]]( + a => a match { case x: Expr[T] => x.optimizeWith(f) } + ))) + } + + /** Simplifies this expression to make evaluation faster and more accurate.*/ + def simplify(implicit num: NumericOps[T]): Expr[T] = { + val a1 = optimizeWith(_.generalize) + val a2 = a1.optimizeWith(_.generalize) + val b = a2.optimizeWith(_.reduce) + val c = b.optimizeWith(_.reduce) + val d = c.optimizeWith(_.specialize) + d + } + } + + + trait Leaf[T] extends Expr[T] { + val args = List[Expr[T]]() + def mapArgs(f: EndoFunction[Expr[_]]) = this + } + + trait OneArg[T] extends Expr[T] { + val expr: Expr[T] + val args = List(expr) + } + + + trait TwoArg[T] extends Expr[T] { + val left: Expr[T] + val right: Expr[T] + val args = List(left, right) + } + + trait ManyArg[T] extends Expr[T] + + /** Marker trait for specifying that you can safely divide by this */ + trait NonZero[T] extends Expr[T] + + case class Const[T](value: T)(implicit num: NumericOps[T]) extends Leaf[T] with NonZero[T] { + def derivative(variable: Var[T]) = Zero[T] + def eval(f: Any => Any) = value + override def toString = value.toString + } + + + case class Zero[T] (implicit num: NumericOps[T]) extends Leaf[T] { + def derivative(variable: Var[T]) = Zero[T] + def eval(f: Any => Any) = num.zero + override def toString = "0" + } + + case class One[T] (implicit num: NumericOps[T]) extends Leaf[T] { + def derivative(variable: Var[T]) = Zero[T] + def eval(f: Any => Any) = num.one + override def toString = "1" + } + + abstract class Var[T](implicit num: NumericOps[T]) extends Leaf[T] { + def derivative(variable: Var[T]) = if (variable == this) One[T] else Zero[T] + def eval(f: Any => Any) = f(this).asInstanceOf[T] + } + + case class NamedVar[T](name: String)(implicit num: NumericOps[T]) extends Var[T] { + override lazy val hashCode = ScalaRunTime._hashCode(this) + override def toString = name + } + + case class Add[T](args: Iterable[Expr[T]])(implicit num: NumericOps[T]) extends ManyArg[T] { + def eval(f: Any => Any) = num.sum(for (i <- args.iterator) yield i.eval(f)) + def derivative(v: Var[T]) = Add(args.map(_.derivative(v))) + def mapArgs(f: EndoFunction[Expr[_]]) = Add(args map (x => f(x))) + override def toString = "(" + args.mkString(" + ") + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + + case class Add2[T](left: Expr[T], right: Expr[T]) + (implicit num: NumericOps[T]) extends TwoArg[T] { + def eval(f: Any => Any) = num.add(left.eval(f), right.eval(f)) + def derivative(v: Var[T]) = Add2(left.derivative(v), right.derivative(v)) + def mapArgs(f: EndoFunction[Expr[_]]) = Add2(f(left), f(right)) + override def toString = "(" + left + " + " + right + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + case class Add3[T](a1: Expr[T], a2: Expr[T], a3: Expr[T]) + (implicit num: NumericOps[T]) extends ManyArg[T] { + val args = List(a1, a2, a3) + def eval(f: Any => Any) = num.add(a1.eval(f), a2.eval(f), a3.eval(f)) + def derivative(v: Var[T]) = Add3(a1.derivative(v), a2.derivative(v), a3.derivative(v)) + def mapArgs(f: EndoFunction[Expr[_]]) = Add3(f(a1), f(a2), f(a3)) + override def toString = "(" + a1 + " + " + a2 + " + " + a3 + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + + case class Sub[T](left: Expr[T], right: Expr[T]) + (implicit num: NumericOps[T]) extends TwoArg[T] { + def derivative(v: Var[T]) = Sub(left.derivative(v), right.derivative(v)) + def eval(f: Any => Any) = num.sub(left.eval(f), right.eval(f)) + def mapArgs(f: EndoFunction[Expr[_]]) = Sub(f(left), f(right)) + override def toString = "(" + left + " - " + right + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + case class Neg[T](expr: Expr[T]) + (implicit num: NumericOps[T]) extends OneArg[T] { + def derivative(v: Var[T]) = Neg(expr.derivative(v)) + def eval(f: Any => Any) = num.neg(expr.eval(f)) + def mapArgs(f: EndoFunction[Expr[_]]) = Neg(f(expr)) + override def toString = "(-" + expr + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + + } + + case class Mul[T](left: Expr[T], right: Expr[T]) + (implicit num: NumericOps[T]) extends TwoArg[T] { + def derivative(v: Var[T]) = + Add(List( + Mul(left, right.derivative(v)), + Mul(right, left.derivative(v)))) + + def eval(f: Any => Any) = num.mul(left.eval(f), right.eval(f)) + def mapArgs(f: EndoFunction[Expr[_]]) = Mul(f(left), f(right)) + override def toString = "(" + left + " * " + right + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + case class Div[T](left: Expr[T], right: Expr[T]) + (implicit num: NumericOps[T]) extends TwoArg[T] { + + // [f(x) / g(x)]' = [f(x) * 1 / g(x)]' = f'(x) * 1 / g(x) + f(x) * [1 / g(x)]' = + // f'(x) / g(x) + f(x) * [-1 / g(x) ^ 2] * g'(x) = (f'(x) * g(x) - f(x) * g'(x)) / g(x)^2 + def derivative(v: Var[T]) = + Div( + Sub( + Mul(left.derivative(v), right), + Mul(left, right.derivative(v))), + Sqr(right) + ) + + def eval(f: Any => Any) = num.div(left.eval(f), right.eval(f)) + def mapArgs(f: EndoFunction[Expr[_]]) = + Div(f(left), f(right)) + override def toString = "(" + left + " / " + right + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + case class Inv[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] { + + // [1 / f(x)]' = - f'(x) / f(x) ^ 2 + def derivative(v: Var[T]) = Neg(Div(expr.derivative(v), Sqr(expr))) + def eval(f: Any => Any) = num.div(num.one, expr.eval(f)) + def mapArgs(f: EndoFunction[Expr[_]]) = Inv(f(expr)) + override def toString = "(1 / " + expr + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + case class Sqr[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] { + // [f(x) ^ 2]' = 2 * f(x) * f'(x) + def derivative(v: Var[T]) = Mul(Mul(Const(num.two), expr), expr.derivative(v)) + def eval(f: Any => Any) = num.sqr(expr.eval(f)) + def mapArgs(f: EndoFunction[Expr[_]]) = Sqr(f(expr)) + override def toString = expr + " ^ 2" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + + case class Log[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] { + def derivative(v: Var[T]) = Div(expr.derivative(v), expr) + def eval(f: Any => Any) = num.log(expr.eval(f)) + def mapArgs(f: EndoFunction[Expr[_]]) = Log(f(expr)) + override def toString = "log(" + expr + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + + case class Exp[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] { + def derivative(v: Var[T]) = Mul(expr.derivative(v), Exp(expr)) + def eval(f: Any => Any) = num.exp(expr.eval(f)) + def mapArgs(f: EndoFunction[Expr[_]]) = Exp(f(expr)) + override def toString = "exp(" + expr + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + case class Sqrt[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] { + def derivative(v: Var[T]) = Neg(Div(expr.derivative(v), Sqrt(expr))) + def eval(f: Any => Any) = num.sqrt(expr.eval(f)) + def mapArgs(f: EndoFunction[Expr[_]]) = Sqrt(f(expr)) + override def toString = "sqrt(" + expr + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + + case class Sin[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] { + def derivative(v: Var[T]) = Mul(expr.derivative(v), Cos(expr)) + def eval(f: Any => Any) = num.sin(expr.eval(f)) + def mapArgs(f: EndoFunction[Expr[_]]) = Sin(f(expr)) + override def toString = "sin(" + expr + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + + case class Cos[T](expr: Expr[T])(implicit num: NumericOps[T]) extends OneArg[T] { + def derivative(v: Var[T]) = Neg(Mul(expr.derivative(v), Sin(expr))) + def eval(f: Any => Any) = num.cos(expr.eval(f)) + def mapArgs(f: EndoFunction[Expr[_]]) = Cos(f(expr)) + override def toString = "cos(" + expr + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + + abstract class Compare[T](left: Expr[T], right: Expr[T], cmp: (T, T) => Boolean)(implicit num: NumericOps[T]) + extends Expr[Boolean] { + def derivative(v: Var[Boolean]) = throw new IllegalStateException("Derivative of Boolean not allowed") + def eval(f: Any => Any) = cmp(left.eval(f), right.eval(f)) + val args = List(left, right) + } + + case class LE[T](left: Expr[T], right: Expr[T])(implicit num: NumericOps[T], ord: Ordering[T]) + extends Compare[T](left, right, ord.compare(_, _) <= 0) { + def mapArgs(f: EndoFunction[Expr[_]]) = LE( + f(left), f(right)) + override def toString = left.toString + " <= " + right.toString + } + + case class LT[T](left: Expr[T], right: Expr[T])(implicit num: NumericOps[T], ord: Ordering[T]) + extends Compare[T](left, right, ord.compare(_, _) < 0) { + def mapArgs(f: EndoFunction[Expr[_]]) = LT( + f(left), f(right)) + override def toString = left.toString + " < " + right.toString + } + + case class GE[T](left: Expr[T], right: Expr[T])(implicit num: NumericOps[T], ord: Ordering[T]) + extends Compare[T](left, right, ord.compare(_, _) >= 0) { + def mapArgs(f: EndoFunction[Expr[_]]) = GE( + f(left), f(right)) + override def toString = left.toString + " >= " + right.toString + } + + case class GT[T](left: Expr[T], right: Expr[T])(implicit num: NumericOps[T], ord: Ordering[T]) + extends Compare[T](left, right, ord.compare(_, _) > 0) { + def mapArgs(f: EndoFunction[Expr[_]]) = GT( + f(left), f(right)) + override def toString = left.toString + " > " + right.toString + } + + case class IfElse[T <: Numeric[T]] + (condition: Expr[Boolean], left: Expr[T], right: Expr[T])(implicit num: NumericOps[T]) extends Expr[T] { + + val args = List(condition, left, right) + def derivative(v: Var[T]) = IfElse(condition, left.derivative(v), right.derivative(v)) + def eval(f: Any => Any) = if (condition.eval(f)) left.eval(f) else right.eval(f) + def mapArgs(f: EndoFunction[Expr[_]]) = IfElse( + f(condition).asInstanceOf[Expr[Boolean]], + f(left), + f(right)) + override def toString = "if (" + condition + ")(" + left + ") else (" + right + ")" + override lazy val hashCode = ScalaRunTime._hashCode(this); + } + + object Expr { + /** Creates a constant expression */ + def const[T](value: T)(implicit num: NumericOps[T]): Leaf[T] = + if (num.zero == value) Zero[T] + else Const(value) + + implicit def double2Constant[T](d: Double)(implicit num: NumericOps[T]): Leaf[T] = + const(num.fromDouble(d)) + + implicit def float2Constant[T](f: Float)(implicit num: NumericOps[T]): Leaf[T] = + const(num.fromDouble(f.toDouble)) + + implicit def int2Constant[T](i: Int)(implicit num: NumericOps[T]): Leaf[T] = + const(num.fromDouble(i.toDouble)) + + implicit def long2Constant[T](l: Long)(implicit num: NumericOps[T]): Leaf[T] = + const(num.fromDouble(l.toDouble)) + } +}
\ No newline at end of file |