aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorBurak Yavuz <brkyvz@gmail.com>2015-04-27 23:10:14 -0700
committerReynold Xin <rxin@databricks.com>2015-04-27 23:10:14 -0700
commit29576e786072bd4218e10036ddfc8d367b1c1446 (patch)
tree94bacf8192903166469c9d74887b695f9d9e8017 /sql/catalyst
parent874a2ca93d095a0dfa1acfdacf0e9d80388c4422 (diff)
downloadspark-29576e786072bd4218e10036ddfc8d367b1c1446.tar.gz
spark-29576e786072bd4218e10036ddfc8d367b1c1446.tar.bz2
spark-29576e786072bd4218e10036ddfc8d367b1c1446.zip
[SPARK-6829] Added math functions for DataFrames
Implemented almost all math functions found in scala.math (max, min and abs were already present). cc mengxr marmbrus Author: Burak Yavuz <brkyvz@gmail.com> Closes #5616 from brkyvz/math-udfs and squashes the following commits: fb27153 [Burak Yavuz] reverted exception message 836a098 [Burak Yavuz] fixed test and addressed small comment e5f0d13 [Burak Yavuz] addressed code review v2.2 b26c5fb [Burak Yavuz] addressed review v2.1 2761f08 [Burak Yavuz] addressed review v2 6588a5b [Burak Yavuz] fixed merge conflicts b084e10 [Burak Yavuz] Addressed code review 029e739 [Burak Yavuz] fixed atan2 test 534cc11 [Burak Yavuz] added more tests, addressed comments fa68dbe [Burak Yavuz] added double specific test data 937d5a5 [Burak Yavuz] use doubles instead of ints 8e28fff [Burak Yavuz] Added apache header 7ec8f7f [Burak Yavuz] Added math functions for DataFrames
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala19
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala93
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala168
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala165
5 files changed, 455 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 35c7f00d4e..73c9a1c7af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -79,6 +79,7 @@ trait HiveTypeCoercion {
CaseWhenCoercion ::
Division ::
PropagateTypes ::
+ ExpectedInputConversion ::
Nil
/**
@@ -643,4 +644,22 @@ trait HiveTypeCoercion {
}
}
+  /**
+   * Inserts casts so that the children of every expression mixing in `ExpectsInputTypes`
+   * have exactly the types listed in its `expectedChildTypes`.
+   */
+  object ExpectedInputConversion extends Rule[LogicalPlan] {
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+      // Skip nodes whose children have not been resolved yet.
+      case e if !e.childrenResolved => e
+
+      case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes =>
+        // Pair each child with its expected type; only mismatching children are wrapped.
+        val coerced = e.children.zip(e.expectedChildTypes).map {
+          case (child, expected) if child.dataType == expected => child
+          case (child, expected) => Cast(child, expected)
+        }
+        e.withNewChildren(coerced)
+    }
+  }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 4e3bbc06a5..1d71c1b4b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -109,3 +109,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = false
override def dataType: DataType = throw new UnsupportedOperationException
}
+
+/**
+ * Expressions that require a specific `DataType` as input should implement this trait
+ * so that the proper type conversions can be performed in the analyzer.
+ */
+trait ExpectsInputTypes {
+
+ /**
+  * The data types this expression expects of its children, in child order. The
+  * `ExpectedInputConversion` rule compares this sequence with the children's actual types
+  * and casts each child whose type differs.
+  */
+ def expectedChildTypes: Seq[DataType]
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
new file mode 100644
index 0000000000..5b4d912a64
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.mathfuncs
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row}
+import org.apache.spark.sql.types._
+
+/**
+ * A binary expression specifically for math functions that take two `Double`s as input and
+ * return a `Double`. Evaluation yields null when either operand is null or when the function
+ * result is `NaN`.
+ * @param f The math function.
+ * @param name The short name of the function
+ */
+abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
+  extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product =>
+  type EvaluatedType = Any
+  override def symbol: String = null
+  override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
+
+  // `eval` maps a NaN result to null, so the output may be null even when neither child is
+  // nullable (e.g. pow(-1.0, 0.9)). Nullability therefore cannot be derived from the children.
+  override def nullable: Boolean = true
+  override def toString: String = s"$name($left, $right)"
+
+  override lazy val resolved =
+    left.resolved && right.resolved &&
+      left.dataType == right.dataType &&
+      !DecimalType.isFixed(left.dataType)
+
+  // `f` always produces a Double, so the result type is independent of the children's types.
+  override def dataType: DataType = {
+    if (!resolved) {
+      throw new UnresolvedException(this,
+        s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
+    }
+    DoubleType
+  }
+
+  override def eval(input: Row): Any = {
+    val evalE1 = left.eval(input)
+    if (evalE1 == null) {
+      null
+    } else {
+      // The right operand is only evaluated once the left one is known to be non-null.
+      val evalE2 = right.eval(input)
+      if (evalE2 == null) {
+        null
+      } else {
+        val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double])
+        if (result.isNaN) null else result
+      }
+    }
+  }
+}
+
+/** Applies `math.pow` to `left` and `right`; rendered as `POWER(left, right)`. */
+case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER")
+
+/** Applies `math.hypot` to `left` and `right`; rendered as `HYPOT(left, right)`. */
+case class Hypot(
+ left: Expression,
+ right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT")
+
+/**
+ * Applies `math.atan2` to `left` and `right`; rendered as `ATAN2(left, right)`. Both operands
+ * have `0.0` added before the call, and a null operand or `NaN` result evaluates to null.
+ */
+case class Atan2(
+    left: Expression,
+    right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") {
+  override def eval(input: Row): Any = left.eval(input) match {
+    case null => null
+    case first =>
+      // The right operand is only evaluated when the left one is non-null.
+      right.eval(input) match {
+        case null => null
+        case second =>
+          // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
+          val angle = math.atan2(
+            first.asInstanceOf[Double] + 0.0,
+            second.asInstanceOf[Double] + 0.0)
+          if (angle.isNaN) null else angle
+      }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
new file mode 100644
index 0000000000..96cb77d487
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.mathfuncs
+
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression}
+import org.apache.spark.sql.types._
+
+/**
+ * Common base for unary math expressions. Because these functions need their input in a
+ * specific type, the class mixes in `ExpectsInputTypes`. By default the result is a nullable
+ * `DoubleType` and the expression is foldable whenever its child is.
+ * @param name The short name of the function, used when rendering the expression
+ */
+abstract class MathematicalExpression(name: String)
+  extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product =>
+
+  type EvaluatedType = Any
+
+  override def toString: String = s"$name($child)"
+
+  override def dataType: DataType = DoubleType
+  override def nullable: Boolean = true
+  override def foldable: Boolean = child.foldable
+}
+
+/**
+ * Unary math expression backed by a `Double => Double` function. The child is expected to be
+ * of `DoubleType`; a null input or a `NaN` result evaluates to null.
+ * @param f The math function.
+ * @param name The short name of the function
+ */
+abstract class MathematicalExpressionForDouble(f: Double => Double, name: String)
+  extends MathematicalExpression(name) { self: Product =>
+
+  override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
+
+  override def eval(input: Row): Any = child.eval(input) match {
+    case null => null
+    case operand =>
+      val computed = f(operand.asInstanceOf[Double])
+      if (computed.isNaN) null else computed
+  }
+}
+
+/**
+ * Unary math expression backed by an `Int => Int` function. The child is expected to be of
+ * `IntegerType` and the result is an `IntegerType`; a null input evaluates to null.
+ * @param f The math function.
+ * @param name The short name of the function
+ */
+abstract class MathematicalExpressionForInt(f: Int => Int, name: String)
+  extends MathematicalExpression(name) { self: Product =>
+
+  override def expectedChildTypes: Seq[DataType] = Seq(IntegerType)
+  override def dataType: DataType = IntegerType
+
+  override def eval(input: Row): Any = child.eval(input) match {
+    case null => null
+    case operand => f(operand.asInstanceOf[Int])
+  }
+}
+
+/**
+ * Unary math expression backed by a `Float => Float` function. The child is expected to be of
+ * `FloatType` and the result is a `FloatType`; a null input or a `NaN` result evaluates to
+ * null.
+ * @param f The math function.
+ * @param name The short name of the function
+ */
+abstract class MathematicalExpressionForFloat(f: Float => Float, name: String)
+  extends MathematicalExpression(name) { self: Product =>
+
+  override def expectedChildTypes: Seq[DataType] = Seq(FloatType)
+  override def dataType: DataType = FloatType
+
+  override def eval(input: Row): Any = child.eval(input) match {
+    case null => null
+    case operand =>
+      val computed = f(operand.asInstanceOf[Float])
+      if (computed.isNaN) null else computed
+  }
+}
+
+/**
+ * Unary math expression backed by a `Long => Long` function. The child is expected to be of
+ * `LongType` and the result is a `LongType`; a null input evaluates to null.
+ * @param f The math function.
+ * @param name The short name of the function
+ */
+abstract class MathematicalExpressionForLong(f: Long => Long, name: String)
+  extends MathematicalExpression(name) { self: Product =>
+
+  override def expectedChildTypes: Seq[DataType] = Seq(LongType)
+  override def dataType: DataType = LongType
+
+  override def eval(input: Row): Any = child.eval(input) match {
+    case null => null
+    case operand => f(operand.asInstanceOf[Long])
+  }
+}
+
+// Concrete unary math expressions. Each one binds a `scala.math` function to the short name
+// used when the expression is rendered via `toString`.
+
+// Trigonometric, inverse trigonometric and hyperbolic functions on Doubles.
+case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN")
+
+case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN")
+
+case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh, "SINH")
+
+case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS")
+
+case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS")
+
+case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH")
+
+case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN")
+
+case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN")
+
+case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH")
+
+// Rounding functions on Doubles.
+case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL")
+
+case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR")
+
+// NOTE(review): wraps `math.rint` but is rendered as "ROUND" — confirm the name is intentional.
+case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND")
+
+case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT")
+
+// Signum for Double, Int, Float and Long inputs respectively.
+case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM")
+
+case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM")
+
+case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM")
+
+case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum, "LSIGNUM")
+
+// Angle unit conversions.
+case class ToDegrees(child: Expression)
+ extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES")
+
+case class ToRadians(child: Expression)
+ extends MathematicalExpressionForDouble(math.toRadians, "RADIANS")
+
+// Logarithms and exponentials on Doubles.
+case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG")
+
+case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10")
+
+case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p, "LOG1P")
+
+case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP")
+
+case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 76298f03c9..5390ce43c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -28,6 +28,7 @@ import org.scalatest.Matchers._
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.mathfuncs._
import org.apache.spark.sql.types._
@@ -1152,6 +1153,170 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
checkEvaluation(c1 ^ c2, 3, row)
checkEvaluation(~c1, -2, row)
}
+
+  /**
+   * Checks a unary math expression against its `scala.math` counterpart.
+   * @param c Builds the expression under test from its child expression
+   * @param f The reference function in scala.math
+   * @param domain The set of values to run the function with
+   * @param expectNull When true, every value in `domain` is expected to evaluate to null
+   *                   instead of `f(value)`
+   * @tparam T Generic type for primitives
+   */
+  def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T](
+      c: Expression => Expression,
+      f: T => T,
+      domain: Iterable[T] = (-20 to 20).map(_ * 0.1),
+      expectNull: Boolean = false): Unit = {
+    for (value <- domain) {
+      val expected = if (expectNull) null else f(value)
+      checkEvaluation(c(Literal(value)), expected, EmptyRow)
+    }
+    // Regardless of the domain, a null input must evaluate to null.
+    checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null))
+  }
+
+  test("sin") {
+    unaryMathFunctionEvaluation(Sin, math.sin)
+  }
+
+  test("asin") {
+    // asin is defined on [-1, 1]; outside that interval the NaN result evaluates to null.
+    unaryMathFunctionEvaluation(Asin, math.asin, (-10 to 10).map(_ * 0.1))
+    unaryMathFunctionEvaluation(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNull = true)
+  }
+
+  test("sinh") {
+    unaryMathFunctionEvaluation(Sinh, math.sinh)
+  }
+
+  test("cos") {
+    unaryMathFunctionEvaluation(Cos, math.cos)
+  }
+
+  test("acos") {
+    // acos is defined on [-1, 1]; outside that interval the NaN result evaluates to null.
+    unaryMathFunctionEvaluation(Acos, math.acos, (-10 to 10).map(_ * 0.1))
+    unaryMathFunctionEvaluation(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNull = true)
+  }
+
+  test("cosh") {
+    unaryMathFunctionEvaluation(Cosh, math.cosh)
+  }
+
+  test("tan") {
+    unaryMathFunctionEvaluation(Tan, math.tan)
+  }
+
+  test("atan") {
+    unaryMathFunctionEvaluation(Atan, math.atan)
+  }
+
+  test("tanh") {
+    unaryMathFunctionEvaluation(Tanh, math.tanh)
+  }
+
+  test("toDeg") {
+    unaryMathFunctionEvaluation(ToDegrees, math.toDegrees)
+  }
+
+  test("toRad") {
+    unaryMathFunctionEvaluation(ToRadians, math.toRadians)
+  }
+
+  test("cbrt") {
+    unaryMathFunctionEvaluation(Cbrt, math.cbrt)
+  }
+
+  test("ceil") {
+    unaryMathFunctionEvaluation(Ceil, math.ceil)
+  }
+
+  test("floor") {
+    unaryMathFunctionEvaluation(Floor, math.floor)
+  }
+
+  test("rint") {
+    unaryMathFunctionEvaluation(Rint, math.rint)
+  }
+
+  test("exp") {
+    unaryMathFunctionEvaluation(Exp, math.exp)
+  }
+
+  test("expm1") {
+    unaryMathFunctionEvaluation(Expm1, math.expm1)
+  }
+
+  // `math.signum` is overloaded, so the type parameter selects the variant under test.
+  test("signum") {
+    unaryMathFunctionEvaluation[Double](Signum, math.signum)
+  }
+
+  test("isignum") {
+    unaryMathFunctionEvaluation[Int](ISignum, math.signum, -5 to 5)
+  }
+
+  test("fsignum") {
+    unaryMathFunctionEvaluation[Float](FSignum, math.signum, (-5 to 5).map(_.toFloat))
+  }
+
+  test("lsignum") {
+    // Cover negative, zero and positive inputs, matching the isignum and fsignum tests.
+    unaryMathFunctionEvaluation[Long](LSignum, math.signum, (-5 to 5).map(_.toLong))
+  }
+
+  test("log") {
+    // Negative inputs give NaN, which the expression evaluates to null.
+    unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1))
+    unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), expectNull = true)
+  }
+
+  test("log10") {
+    // Negative inputs give NaN, which the expression evaluates to null.
+    unaryMathFunctionEvaluation(Log10, math.log10, (0 to 20).map(_ * 0.1))
+    unaryMathFunctionEvaluation(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNull = true)
+  }
+
+  test("log1p") {
+    // Inputs below -1 give NaN, which the expression evaluates to null.
+    unaryMathFunctionEvaluation(Log1p, math.log1p, (-1 to 20).map(_ * 0.1))
+    unaryMathFunctionEvaluation(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true)
+  }
+
+  /**
+   * Checks a binary math expression against its `scala.math` counterpart.
+   * @param c Builds the expression under test from its two child expressions
+   * @param f The reference function in scala.math
+   * @param domain The set of value pairs to run the function with
+   * @param expectNull When true, every pair in `domain` is expected to evaluate to null
+   */
+  def binaryMathFunctionEvaluation(
+      c: (Expression, Expression) => Expression,
+      f: (Double, Double) => Double,
+      domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)),
+      expectNull: Boolean = false): Unit = {
+    for ((v1, v2) <- domain) {
+      if (expectNull) {
+        checkEvaluation(c(v1, v2), null, create_row(null))
+      } else {
+        // Each pair is checked in both argument orders; the reference function gets the
+        // operands with 0.0 added, exactly as in the expected-value computation.
+        checkEvaluation(c(v1, v2), f(v1 + 0.0, v2 + 0.0), EmptyRow)
+        checkEvaluation(c(v2, v1), f(v2 + 0.0, v1 + 0.0), EmptyRow)
+      }
+    }
+    // A null on either side must evaluate to null.
+    checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null))
+    checkEvaluation(c(1.0, Literal.create(null, DoubleType)), null, create_row(null))
+  }
+
+  test("pow") {
+    binaryMathFunctionEvaluation(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
+    // A negative base with a non-integer exponent gives NaN, which evaluates to null.
+    binaryMathFunctionEvaluation(
+      Pow,
+      math.pow,
+      Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)),
+      expectNull = true)
+  }
+
+  test("hypot") {
+    binaryMathFunctionEvaluation(Hypot, math.hypot)
+  }
+
+  test("atan2") {
+    binaryMathFunctionEvaluation(Atan2, math.atan2)
+  }
}
// TODO: Make the tests work with codegen.