diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2015-06-17 23:31:30 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-06-17 23:31:30 -0700 |
commit | fee3438a32136a8edbca71efb566965587a88826 (patch) | |
tree | 2fa725811cd231a35130bcc0e75b22df1257c73d /sql | |
parent | 78a430ea4d2aef58a8bf38ce488553ca6acea428 (diff) | |
download | spark-fee3438a32136a8edbca71efb566965587a88826.tar.gz spark-fee3438a32136a8edbca71efb566965587a88826.tar.bz2 spark-fee3438a32136a8edbca71efb566965587a88826.zip |
[SPARK-8218][SQL] Add binary log math function
JIRA: https://issues.apache.org/jira/browse/SPARK-8218
Because there is already `log` unary function defined, the binary log function is called `logarithm` for now.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #6725 from viirya/expr_binary_log and squashes the following commits:
bf96bd9 [Liang-Chi Hsieh] Compare log result in string.
102070d [Liang-Chi Hsieh] Round log result to better comparing in python test.
fd01863 [Liang-Chi Hsieh] For comments.
beed631 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
6089d11 [Liang-Chi Hsieh] Remove unnecessary override.
8cf37b7 [Liang-Chi Hsieh] For comments.
bc89597 [Liang-Chi Hsieh] For comments.
db7dc38 [Liang-Chi Hsieh] Use ctor instead of companion object.
0634ef7 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
1750034 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
3d75bfc [Liang-Chi Hsieh] Fix scala style.
5b39c02 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
23c54a3 [Liang-Chi Hsieh] Fix scala style.
ebc9929 [Liang-Chi Hsieh] Let Logarithm accept one parameter too.
605574d [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
21c3bfd [Liang-Chi Hsieh] Fix scala style.
c6c187f [Liang-Chi Hsieh] For comments.
c795342 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log
f373bac [Liang-Chi Hsieh] Add binary log expression.
Diffstat (limited to 'sql')
5 files changed, 68 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 97b123ec2f..13b2bb05f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -112,6 +112,7 @@ object FunctionRegistry { expression[Expm1]("expm1"), expression[Floor]("floor"), expression[Hypot]("hypot"), + expression[Logarithm]("log"), expression[Log]("ln"), expression[Log10]("log10"), expression[Log1p]("log1p"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 42c596b5b3..67cb0b508c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -255,3 +255,23 @@ case class Pow(left: Expression, right: Expression) """ } } + +case class Logarithm(left: Expression, right: Expression) + extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { + def this(child: Expression) = { + this(EulerNumber(), child) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val logCode = if (left.isInstanceOf[EulerNumber]) { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)") + } else { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") + } + logCode + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 864c954ee8..0050ad3fe8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -204,4 +204,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testBinary(Atan2, math.atan2) } + test("binary log") { + val f = (c1: Double, c2: Double) => math.log(c2) / math.log(c1) + val domain = (1 to 20).map(v => (v * 0.1, v * 0.2)) + + domain.foreach { case (v1, v2) => + checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) + checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow) + } + checkEvaluation( + Logarithm(Literal.create(null, DoubleType), Literal(1.0)), + null, + create_row(null)) + checkEvaluation( + Logarithm(Literal(1.0), Literal.create(null, DoubleType)), + null, + create_row(null)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c5b77724aa..dff0932c45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1084,6 +1084,22 @@ object functions { def log(columnName: String): Column = log(Column(columnName)) /** + * Returns the first argument-base logarithm of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr) + + /** + * Returns the first argument-base logarithm of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(base: Double, columnName: String): Column = log(base, Column(columnName)) + + /** * Computes the logarithm of the given value in base 10. * * @group math_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index e2daaf6b73..7c9c121b95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -236,6 +236,19 @@ class MathExpressionsSuite extends QueryTest { testOneToOneNonNegativeMathFunction(log1p, math.log1p) } + test("binary log") { + val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") + checkAnswer( + df.select(org.apache.spark.sql.functions.log("a"), + org.apache.spark.sql.functions.log(2.0, "a"), + org.apache.spark.sql.functions.log("b")), + Row(math.log(123), math.log(123) / math.log(2), null)) + + checkAnswer( + df.selectExpr("log(a)", "log(2.0, a)", "log(b)"), + Row(math.log(123), math.log(123) / math.log(2), null)) + } + test("abs") { val input = Seq[(java.lang.Double, java.lang.Double)]((null, null), (0.0, 0.0), (1.5, 1.5), (-2.5, 2.5)) |