about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Daoyuan Wang <daoyuan.wang@intel.com> 2015-06-10 09:45:45 -0700
committer: Reynold Xin <rxin@databricks.com> 2015-06-10 09:45:45 -0700
commit: c6ba7cca3338e3f4f719d86dbcff4406d949edc7 (patch)
tree: b2bb4038929aa10fdc571017fb09c3b59a16d38e
parent: e90035e676e492de840f44b61b330db526313019 (diff)
download: spark-c6ba7cca3338e3f4f719d86dbcff4406d949edc7.tar.gz
          spark-c6ba7cca3338e3f4f719d86dbcff4406d949edc7.tar.bz2
          spark-c6ba7cca3338e3f4f719d86dbcff4406d949edc7.zip
[SPARK-8215] [SPARK-8212] [SQL] add leaf math expression for e and pi
Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #6716 from adrian-wang/epi and squashes the following commits:

e2e8dbd [Daoyuan Wang] move tests
11b351c [Daoyuan Wang] add tests and remove pu
db331c9 [Daoyuan Wang] py style
599ddd8 [Daoyuan Wang] add py
e6783ef [Daoyuan Wang] register function
82d426e [Daoyuan Wang] add function entry
dbf3ab5 [Daoyuan Wang] add PI and E
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 2
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala | 35
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala | 22
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 18
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 19
5 files changed, 96 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 936ffc7d5f..ba89a5c8d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -106,6 +106,7 @@ object FunctionRegistry {
expression[Cbrt]("cbrt"),
expression[Ceil]("ceil"),
expression[Cos]("cos"),
+ expression[EulerNumber]("e"),
expression[Exp]("exp"),
expression[Expm1]("expm1"),
expression[Floor]("floor"),
@@ -113,6 +114,7 @@ object FunctionRegistry {
expression[Log]("log"),
expression[Log10]("log10"),
expression[Log1p]("log1p"),
+ expression[Pi]("pi"),
expression[Pow]("pow"),
expression[Rint]("rint"),
expression[Signum]("signum"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 7dacb6a9b4..e1d8c9a0cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -21,8 +21,33 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{DataType, DoubleType}
/**
+ * A leaf expression specifically for math constants. Math constants expect no input.
+ * @param c The math constant.
+ * @param name The name of the constant's field in java.lang.Math (e.g. "E", "PI")
+ */
+abstract class LeafMathExpression(c: Double, name: String)
+ extends LeafExpression with Serializable {
+ self: Product =>
+
+ override def dataType: DataType = DoubleType
+ override def foldable: Boolean = true
+ override def nullable: Boolean = false
+ override def toString: String = s"$name()"
+
+ override def eval(input: Row): Any = c
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ s"""
+ boolean ${ev.isNull} = false;
+ ${ctx.javaType(dataType)} ${ev.primitive} = java.lang.Math.$name;
+ """
+ }
+}
+
+/**
* A unary expression specifically for math functions. Math Functions expect a specific type of
* input format, therefore these functions extend `ExpectsInputTypes`.
+ * @param f The math function.
* @param name The short name of the function
*/
abstract class UnaryMathExpression(f: Double => Double, name: String)
@@ -100,6 +125,16 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
+// Leaf math functions
+////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+case class EulerNumber() extends LeafMathExpression(math.E, "E")
+
+case class Pi() extends LeafMathExpression(math.Pi, "PI")
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////////////////////////
// Unary math functions
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 25ebc70d09..1fe69059d3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -23,6 +23,20 @@ import org.apache.spark.sql.types.DoubleType
class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
/**
+ * Used for testing leaf math expressions.
+ *
+ * @param e expression
+ * @param c The constants in scala.math
+ * @tparam T Generic type for primitives
+ */
+ private def testLeaf[T](
+ e: () => Expression,
+ c: T): Unit = {
+ checkEvaluation(e(), c, EmptyRow)
+ checkEvaluation(e(), c, create_row(null))
+ }
+
+ /**
* Used for testing unary math expressions.
*
* @param c expression
@@ -74,6 +88,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null))
}
+ test("e") {
+ testLeaf(EulerNumber, math.E)
+ }
+
+ test("pi") {
+ testLeaf(Pi, math.Pi)
+ }
+
test("sin") {
testUnary(Sin, math.sin)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 454af47913..b3fc1e6cd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -945,6 +945,15 @@ object functions {
def cosh(columnName: String): Column = cosh(Column(columnName))
/**
+ * Returns the double value that is closer than any other to e, the base of the natural
+ * logarithms.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def e(): Column = EulerNumber()
+
+ /**
* Computes the exponential of the given value.
*
* @group math_funcs
@@ -1106,6 +1115,15 @@ object functions {
def log1p(columnName: String): Column = log1p(Column(columnName))
/**
+ * Returns the double value that is closer than any other to pi, the ratio of the circumference
+ * of a circle to its diameter.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def pi(): Column = Pi()
+
+ /**
* Returns the value of the first argument raised to the power of the second argument.
*
* @group math_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 53c2befb73..b93ad39f5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -85,6 +85,25 @@ class DataFrameFunctionsSuite extends QueryTest {
}
}
+ test("constant functions") {
+ checkAnswer(
+ testData2.select(e()).limit(1),
+ Row(scala.math.E)
+ )
+ checkAnswer(
+ testData2.select(pi()).limit(1),
+ Row(scala.math.Pi)
+ )
+ checkAnswer(
+ ctx.sql("SELECT E()"),
+ Row(scala.math.E)
+ )
+ checkAnswer(
+ ctx.sql("SELECT PI()"),
+ Row(scala.math.Pi)
+ )
+ }
+
test("bitwiseNOT") {
checkAnswer(
testData2.select(bitwiseNOT($"a")),