aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzhichao.li <zhichao.li@intel.com>2015-07-02 20:37:31 -0700
committerReynold Xin <rxin@databricks.com>2015-07-02 20:37:31 -0700
commit1a7a7d7d579c5cba104daffbda977915802bf9b9 (patch)
tree1f379c921b8a50e3738368d23802c5533dd1ff9b
parentaa7bbc143844020e4711b3aa4ce75c1b7733a80d (diff)
downloadspark-1a7a7d7d579c5cba104daffbda977915802bf9b9.tar.gz
spark-1a7a7d7d579c5cba104daffbda977915802bf9b9.tar.bz2
spark-1a7a7d7d579c5cba104daffbda977915802bf9b9.zip
[SPARK-8213][SQL]Add function factorial
Author: zhichao.li <zhichao.li@intel.com> Closes #6822 from zhichao-li/factorial and squashes the following commits: 26edf4f [zhichao.li] add factorial
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala80
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala15
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala16
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala13
5 files changed, 122 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e7e4d1c4ef..9163b032ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -113,6 +113,7 @@ object FunctionRegistry {
expression[Exp]("exp"),
expression[Expm1]("expm1"),
expression[Floor]("floor"),
+ expression[Factorial]("factorial"),
expression[Hypot]("hypot"),
expression[Hex]("hex"),
expression[Logarithm]("log"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 035980da56..701ab9912a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -21,8 +21,10 @@ import java.lang.{Long => JLong}
import java.util.Arrays
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{StringType}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.types.{DataType, DoubleType, LongType, IntegerType}
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -159,6 +161,82 @@ case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXP
case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR")
+/**
+ * Table-backed factorial helper used by both the interpreted and the generated-code paths of
+ * the `Factorial` expression.
+ */
+object Factorial {
+
+  // Running products 0!, 1!, ..., 20!; 20! is the last factorial representable as a Long.
+  private val factorials: Array[Long] = (1 to 20).scanLeft(1L)(_ * _).toArray
+
+  /**
+   * Returns `n!` for `n` in [0, 20], and `Long.MaxValue` for any `n` beyond the table.
+   * A negative `n` is not guarded here: it is passed straight through as an array index.
+   */
+  def factorial(n: Int): Long = {
+    if (n >= factorials.length) Long.MaxValue else factorials(n)
+  }
+}
+
+/**
+ * Expression that computes the factorial of its integer child and returns it as a long.
+ *
+ * Evaluates to null when the child evaluates to null or when its value lies outside [0, 20].
+ */
+case class Factorial(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[DataType] = Seq(IntegerType)
+
+ override def dataType: DataType = LongType
+
+ override def foldable: Boolean = child.foldable
+
+ // Inputs outside the range [0, 20] evaluate to null, so the result is always nullable.
+ override def nullable: Boolean = true
+
+ override def toString: String = s"factorial($child)"
+
+ // Interpreted path: null for a null child or an out-of-range value, otherwise the table lookup.
+ override def eval(input: InternalRow): Any = {
+ val evalE = child.eval(input)
+ if (evalE == null) {
+ null
+ } else {
+ // NOTE(review): this local shadows the `input` row parameter; below, `input` is the Integer.
+ val input = evalE.asInstanceOf[Integer]
+ if (input > 20 || input < 0) {
+ null
+ } else {
+ Factorial.factorial(input)
+ }
+ }
+ }
+
+ // Code-generation path: emits Java after the child's code that copies the child's null flag,
+ // initialises the result with ctx.defaultValue(dataType), and then either marks the result
+ // null (value outside [0, 20]) or calls Factorial.factorial via its fully-qualified name.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val eval = child.gen(ctx)
+ eval.code + s"""
+ boolean ${ev.isNull} = ${eval.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ if (${eval.primitive} > 20 || ${eval.primitive} < 0) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.primitive} =
+ org.apache.spark.sql.catalyst.expressions.Factorial.factorial(${eval.primitive});
+ }
+ }
+ """
+ }
+}
+
case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")
case class Log2(child: Expression)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index aa27fe3cd5..8457864d17 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -17,9 +17,12 @@
package org.apache.spark.sql.catalyst.expressions
+import com.google.common.math.LongMath
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType}
+import org.apache.spark.sql.types.{DataType, LongType}
+import org.apache.spark.sql.types.{IntegerType, DoubleType}
class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -157,6 +160,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testUnary(Floor, math.floor)
}
+ test("factorial") {
+   // Every in-range input must agree with Guava's reference implementation.
+   (0 to 20).foreach { value =>
+     checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow)
+   }
+   // Null input yields null. The literal is wrapped in Factorial so that the expression
+   // itself, not just the bare literal, is what gets evaluated.
+   checkEvaluation(Factorial(Literal.create(null, IntegerType)), null, create_row(null))
+   checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow)
+   // Values on either side of [0, 20] evaluate to null.
+   checkEvaluation(Factorial(Literal(21)), null, EmptyRow)
+   checkEvaluation(Factorial(Literal(-1)), null, EmptyRow)
+ }
+
test("rint") {
testUnary(Rint, math.rint)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4ee1fb8374..0d5d49c3dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1023,6 +1023,22 @@ object functions {
def expm1(columnName: String): Column = expm1(Column(columnName))
/**
+ * Computes the factorial of the given value.
+ *
+ * The result is a long; it is null when the input is null or lies outside [0, 20].
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def factorial(e: Column): Column = Factorial(e.expr)
+
+ /**
+ * Computes the factorial of the given column.
+ *
+ * Looks the column up by name and delegates to `factorial(Column)`; the result is a long,
+ * and is null when the input is null or lies outside [0, 20].
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def factorial(columnName: String): Column = factorial(Column(columnName))
+
+ /**
* Computes the floor of the given value.
*
* @group math_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 4c5696deaf..dc8f994adb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.{log => logarithm}
-
private object MathExpressionsTestData {
case class DoubleData(a: java.lang.Double, b: java.lang.Double)
case class NullDoubles(a: java.lang.Double)
@@ -183,6 +182,18 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneMathFunction(floor, math.floor)
}
+ test("factorial") {
+   // Inputs 0..5 in column "a"; both the DataFrame API and the SQL expression form must
+   // produce 0!..5!.
+   val frame = (0 to 5).map(i => (i, i)).toDF("a", "b")
+   val expected = Seq(1, 1, 2, 6, 24, 120).map(v => Row(v))
+   checkAnswer(frame.select(factorial('a)), expected)
+   checkAnswer(frame.selectExpr("factorial(a)"), expected)
+ }
+
test("rint") {
testOneToOneMathFunction(rint, math.rint)
}