From 1a7a7d7d579c5cba104daffbda977915802bf9b9 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Thu, 2 Jul 2015 20:37:31 -0700 Subject: [SPARK-8213][SQL]Add function factorial Author: zhichao.li Closes #6822 from zhichao-li/factorial and squashes the following commits: 26edf4f [zhichao.li] add factorial --- .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 80 +++++++++++++++++++++- .../catalyst/expressions/MathFunctionsSuite.scala | 15 +++- .../scala/org/apache/spark/sql/functions.scala | 16 +++++ .../apache/spark/sql/MathExpressionsSuite.scala | 13 +++- 5 files changed, 122 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e7e4d1c4ef..9163b032ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -113,6 +113,7 @@ object FunctionRegistry { expression[Exp]("exp"), expression[Expm1]("expm1"), expression[Floor]("floor"), + expression[Factorial]("factorial"), expression[Hypot]("hypot"), expression[Hex]("hex"), expression[Logarithm]("log"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 035980da56..701ab9912a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -21,8 +21,10 @@ import java.lang.{Long => JLong} import java.util.Arrays import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{StringType} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types.{DataType, DoubleType, LongType, IntegerType} import org.apache.spark.unsafe.types.UTF8String /** @@ -159,6 +161,82 @@ case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXP case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") +object Factorial { + + def factorial(n: Int): Long = { + if (n < factorials.length) factorials(n) else Long.MaxValue + } + + private val factorials: Array[Long] = Array[Long]( + 1, + 1, + 2, + 6, + 24, + 120, + 720, + 5040, + 40320, + 362880, + 3628800, + 39916800, + 479001600, + 6227020800L, + 87178291200L, + 1307674368000L, + 20922789888000L, + 355687428096000L, + 6402373705728000L, + 121645100408832000L, + 2432902008176640000L + ) +} + +case class Factorial(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[DataType] = Seq(IntegerType) + + override def dataType: DataType = LongType + + override def foldable: Boolean = child.foldable + + // If the value not in the range of [0, 20], it still will be null, so set it to be true here. + override def nullable: Boolean = true + + override def toString: String = s"factorial($child)" + + override def eval(input: InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val input = evalE.asInstanceOf[Integer] + if (input > 20 || input < 0) { + null + } else { + Factorial.factorial(input) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + if (${eval.primitive} > 20 || ${eval.primitive} < 0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = + org.apache.spark.sql.catalyst.expressions.Factorial.factorial(${eval.primitive}); + } + } + """ + } +} + case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") case class Log2(child: Expression) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index aa27fe3cd5..8457864d17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import com.google.common.math.LongMath + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType} +import org.apache.spark.sql.types.{DataType, LongType} +import org.apache.spark.sql.types.{IntegerType, DoubleType} class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -157,6 +160,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testUnary(Floor, math.floor) } + test("factorial") { + val dataLong = (0 to 20) + dataLong.foreach { value => + checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow) + } + checkEvaluation((Literal.create(null, IntegerType)), null, create_row(null)) + checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) + checkEvaluation(Factorial(Literal(21)), null, EmptyRow) + } + test("rint") { testUnary(Rint, math.rint) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4ee1fb8374..0d5d49c3dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1022,6 +1022,22 @@ object functions { */ def expm1(columnName: String): Column = expm1(Column(columnName)) + /** + * Computes the factorial of the given value. + * + * @group math_funcs + * @since 1.5.0 + */ + def factorial(e: Column): Column = Factorial(e.expr) + + /** + * Computes the factorial of the given column. + * + * @group math_funcs + * @since 1.5.0 + */ + def factorial(columnName: String): Column = factorial(Column(columnName)) + /** * Computes the floor of the given value. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 4c5696deaf..dc8f994adb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} - private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) @@ -183,6 +182,18 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(floor, math.floor) } + test("factorial") { + val df = (0 to 5).map(i => (i, i)).toDF("a", "b") + checkAnswer( + df.select(factorial('a)), + Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) + ) + checkAnswer( + df.selectExpr("factorial(a)"), + Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) + ) + } + test("rint") { testOneToOneMathFunction(rint, math.rint) } -- cgit v1.2.3