author    Daoyuan Wang <daoyuan.wang@intel.com>  2015-06-10 20:22:32 -0700
committer Reynold Xin <rxin@databricks.com>      2015-06-10 20:22:32 -0700
commit    2758ff0a96f03a61e10999b2462acf7a13236b7c (patch)
tree      62bfdf3fcf69c961c58b2c5c8a006f61204fcfcb
parent    9fe3adccef687c92ff1ac17d946af089c8e28d66 (diff)
[SPARK-8217] [SQL] math function log2
Author: Daoyuan Wang <daoyuan.wang@intel.com>

This patch had conflicts when merged, resolved by
Committer: Reynold Xin <rxin@databricks.com>

Closes #6718 from adrian-wang/udflog2 and squashes the following commits:

3909f48 [Daoyuan Wang] math function: log2
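What the patch adds, sketched as end-user usage (an illustration only, not part of the patch; it assumes a Spark 1.5-era SQLContext bound to `sqlContext` with its implicits imported):

    import sqlContext.implicits._
    import org.apache.spark.sql.functions.log2

    val df = Seq((2.0, 8.0)).toDF("a", "b")
    df.select(log2("a"), log2("b")).show()     // 1.0, 3.0

    sqlContext.sql("SELECT LOG2(8)").show()    // 3.0
    sqlContext.sql("SELECT LOG2(-1)").show()   // null: a NaN result is mapped to null

The pieces below wire this up: the name registration in FunctionRegistry, the Catalyst expression with code generation, the DataFrame DSL functions, and tests for both paths.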
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala |  1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala | 17
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala |  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 20
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 12
5 files changed, 54 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 39875d7f21..a7816e3275 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -111,6 +111,7 @@ object FunctionRegistry {
     expression[Log10]("log10"),
     expression[Log1p]("log1p"),
     expression[Pi]("pi"),
+    expression[Log2]("log2"),
     expression[Pow]("pow"),
     expression[Rint]("rint"),
     expression[Signum]("signum"),
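The single registration line above is what makes LOG2 resolvable by name in SQL. Conceptually (a simplified sketch, not the actual FunctionRegistry implementation, which builds the constructor call via reflection), the registration maps a function name to a builder over the parsed argument expressions:

    // Hypothetical, simplified view of what expression[Log2]("log2") registers:
    // a builder turning the argument list into a Log2 expression node.
    val log2Builder: Seq[Expression] => Expression = {
      case Seq(child) => Log2(child)
    }
    // The registry then behaves like Map("log2" -> log2Builder), consulted during analysis.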
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index e1d8c9a0cd..97e960b8d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -161,6 +161,23 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")
+case class Log2(child: Expression)
+  extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val eval = child.gen(ctx)
+    eval.code + s"""
+      boolean ${ev.isNull} = ${eval.isNull};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2);
+        if (Double.valueOf(${ev.primitive}).isNaN()) {
+          ${ev.isNull} = true;
+        }
+      }
+    """
+  }
+}
+
case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10")
case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P")
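Two evaluation paths are covered above: the interpreted path reuses UnaryMathExpression with math.log(x) / math.log(2), and the overridden genCode emits equivalent Java that also maps a NaN result to SQL NULL, which is the behaviour the expectNull tests below rely on. A standalone sketch of that null/NaN behaviour (plain Scala for illustration; log2OrNull is a hypothetical helper, not Spark API):

    // NaN (e.g. from a negative input) becomes null; note log(0) is -Infinity, not NaN, so it passes through.
    def log2OrNull(input: java.lang.Double): java.lang.Double = {
      if (input == null) return null
      val result = math.log(input) / math.log(2)
      if (result.isNaN) null else java.lang.Double.valueOf(result)
    }

    log2OrNull(8.0)    // 3.0
    log2OrNull(-1.0)   // null
    log2OrNull(null)   // null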
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 1fe69059d3..864c954ee8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -185,6 +185,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true)
}
+  test("log2") {
+    def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2)
+    testUnary(Log2, f, (0 to 20).map(_ * 0.1))
+    testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true)
+  }
+
test("pow") {
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
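Spot values for the first testUnary call above, for reference: f(0.1) = ln(0.1)/ln(2) ≈ -3.32, f(1.0) = 0.0, f(2.0) = 1.0. Negative inputs make math.log return NaN, which Log2 surfaces as null, hence the separate expectNull = true case; 0.0 stays in the non-null range because log(0) evaluates to -Infinity rather than NaN.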
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 083f6b6bce..c5b77724aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1084,7 +1084,7 @@ object functions {
def log(columnName: String): Column = log(Column(columnName))
/**
- * Computes the logarithm of the given value in Base 10.
+ * Computes the logarithm of the given value in base 10.
*
* @group math_funcs
* @since 1.4.0
@@ -1092,7 +1092,7 @@ object functions {
def log10(e: Column): Column = Log10(e.expr)
/**
- * Computes the logarithm of the given value in Base 10.
+ * Computes the logarithm of the given value in base 10.
*
* @group math_funcs
* @since 1.4.0
@@ -1125,6 +1125,22 @@ object functions {
def pi(): Column = Pi()
/**
+   * Computes the logarithm of the given column in base 2.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def log2(expr: Column): Column = Log2(expr.expr)
+
+  /**
+   * Computes the logarithm of the given value in base 2.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def log2(columnName: String): Column = log2(Column(columnName))
+
+  /**
* Returns the value of the first argument raised to the power of the second argument.
*
* @group math_funcs
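The two new DSL overloads follow the convention of the other math functions in functions.scala: one takes a Column, the other a column name and delegates to the first. A brief usage sketch (illustrative only; assumes a DataFrame df with a numeric column named "value"):

    import org.apache.spark.sql.functions.{col, log2}

    df.select(log2(col("value")))   // Column overload, wraps the Catalyst Log2 expression
    df.select(log2("value"))        // String overload, delegates to log2(Column("value"))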
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 171a2151e6..659b64c185 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -128,5 +128,17 @@ class DataFrameFunctionsSuite extends QueryTest {
})
}
+ test("log2 functions test") {
+ val df = Seq((1, 2)).toDF("a", "b")
+ checkAnswer(
+ df.select(log2("b") + log2("a")),
+ Row(1))
+ checkAnswer(
+ ctx.sql("SELECT LOG2(8)"),
+ Row(3))
+ checkAnswer(
+ ctx.sql("SELECT LOG2(null)"),
+ Row(null))
+ }
}