aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorzhichao.li <zhichao.li@intel.com>2015-07-15 10:43:38 -0700
committerReynold Xin <rxin@databricks.com>2015-07-15 10:43:38 -0700
commita9385271a9f6b97ec6aa619cf56ee556ba2fb0de (patch)
treeee3d548fcf187df1445d5b91e68d7511a6d17e89 /sql
parentfa4ec3606a965238423f977808163983c9d56e0a (diff)
downloadspark-a9385271a9f6b97ec6aa619cf56ee556ba2fb0de.tar.gz
spark-a9385271a9f6b97ec6aa619cf56ee556ba2fb0de.tar.bz2
spark-a9385271a9f6b97ec6aa619cf56ee556ba2fb0de.zip
[SPARK-8221][SQL] Add pmod function
https://issues.apache.org/jira/browse/SPARK-8221 One concern is that the result would be negative if the divisor is not positive (i.e. pmod(7, -3)), but this behavior matches Hive. Author: zhichao.li <zhichao.li@intel.com> Closes #6783 from zhichao-li/pmod2 and squashes the following commits: 7083eb9 [zhichao.li] update to the latest type checking d26dba7 [zhichao.li] add pmod
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala94
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala16
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala17
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala37
6 files changed, 170 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ec75f51d5e..d2678ce860 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -115,6 +115,7 @@ object FunctionRegistry {
expression[Log2]("log2"),
expression[Pow]("pow"),
expression[Pow]("power"),
+ expression[Pmod]("pmod"),
expression[UnaryPositive]("positive"),
expression[Rint]("rint"),
expression[Round]("round"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 15da5eecc8..25087915b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -426,6 +426,12 @@ object HiveTypeCoercion {
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)
+ case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ )
+
// When we compare 2 decimal types with different precisions, cast them to the smallest
// common precision.
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1a55a0876f..394ef556e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -377,3 +377,97 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "min"
override def prettyName: String = symbol
}
+
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
+
+ override def toString: String = s"pmod($left, $right)"
+
+ override def symbol: String = "pmod"
+
+ protected def checkTypesInternal(t: DataType) =
+ TypeUtils.checkForNumericExpr(t, "pmod")
+
+ override def inputType: AbstractDataType = NumericType
+
+ protected override def nullSafeEval(left: Any, right: Any) =
+ dataType match {
+ case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int])
+ case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long])
+ case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short])
+ case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte])
+ case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float])
+ case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double])
+ case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal])
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+ dataType match {
+ case dt: DecimalType =>
+ val decimalAdd = "$plus"
+ s"""
+ ${ctx.javaType(dataType)} r = $eval1.remainder($eval2);
+ if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
+ ${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2);
+ } else {
+ ${ev.primitive} = r;
+ }
+ """
+ // byte and short are casted into int when add, minus, times or divide
+ case ByteType | ShortType =>
+ s"""
+ ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2);
+ if (r < 0) {
+ ${ev.primitive} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2);
+ } else {
+ ${ev.primitive} = r;
+ }
+ """
+ case _ =>
+ s"""
+ ${ctx.javaType(dataType)} r = $eval1 % $eval2;
+ if (r < 0) {
+ ${ev.primitive} = (r + $eval2) % $eval2;
+ } else {
+ ${ev.primitive} = r;
+ }
+ """
+ }
+ })
+ }
+
+ private def pmod(a: Int, n: Int): Int = {
+ val r = a % n
+ if (r < 0) {(r + n) % n} else r
+ }
+
+ private def pmod(a: Long, n: Long): Long = {
+ val r = a % n
+ if (r < 0) {(r + n) % n} else r
+ }
+
+ private def pmod(a: Byte, n: Byte): Byte = {
+ val r = a % n
+ if (r < 0) {((r + n) % n).toByte} else r.toByte
+ }
+
+ private def pmod(a: Double, n: Double): Double = {
+ val r = a % n
+ if (r < 0) {(r + n) % n} else r
+ }
+
+ private def pmod(a: Short, n: Short): Short = {
+ val r = a % n
+ if (r < 0) {((r + n) % n).toShort} else r.toShort
+ }
+
+ private def pmod(a: Float, n: Float): Float = {
+ val r = a % n
+ if (r < 0) {(r + n) % n} else r
+ }
+
+ private def pmod(a: Decimal, n: Decimal): Decimal = {
+ val r = a % n
+ if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 6c93698f80..e7e5231d32 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.Decimal
-
class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
/**
@@ -158,4 +157,19 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)),
Array(1.toByte, 2.toByte))
}
+
+ test("pmod") {
+ testNumericDataTypes { convert =>
+ val left = Literal(convert(7))
+ val right = Literal(convert(3))
+ checkEvaluation(Pmod(left, right), convert(1))
+ checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null)
+ checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null)
+ checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0
+ }
+ checkEvaluation(Pmod(-7, 3), 2)
+ checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005)
+ checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1))
+ checkEvaluation(Pmod(2L, Long.MaxValue), 2)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5119ee31d8..c7deaca843 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1372,6 +1372,23 @@ object functions {
def pow(l: Double, rightName: String): Column = pow(l, Column(rightName))
/**
+ * Returns the positive value of dividend mod divisor.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr)
+
+ /**
+ * Returns the positive value of dividend mod divisor.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def pmod(dividendColName: String, divisorColName: String): Column =
+ pmod(Column(dividendColName), Column(divisorColName))
+
+ /**
* Returns the double value that is closest in value to the argument and
* is equal to a mathematical integer.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 6cebec95d2..70bd78737f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -403,4 +403,41 @@ class DataFrameFunctionsSuite extends QueryTest {
Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
)
}
+
+ test("pmod") {
+ val intData = Seq((7, 3), (-7, 3)).toDF("a", "b")
+ checkAnswer(
+ intData.select(pmod('a, 'b)),
+ Seq(Row(1), Row(2))
+ )
+ checkAnswer(
+ intData.select(pmod('a, lit(3))),
+ Seq(Row(1), Row(2))
+ )
+ checkAnswer(
+ intData.select(pmod(lit(-7), 'b)),
+ Seq(Row(2), Row(2))
+ )
+ checkAnswer(
+ intData.selectExpr("pmod(a, b)"),
+ Seq(Row(1), Row(2))
+ )
+ checkAnswer(
+ intData.selectExpr("pmod(a, 3)"),
+ Seq(Row(1), Row(2))
+ )
+ checkAnswer(
+ intData.selectExpr("pmod(-7, b)"),
+ Seq(Row(2), Row(2))
+ )
+ val doubleData = Seq((7.2, 4.1)).toDF("a", "b")
+ checkAnswer(
+ doubleData.select(pmod('a, 'b)),
+ Seq(Row(3.1000000000000005)) // same as hive
+ )
+ checkAnswer(
+ doubleData.select(pmod(lit(2), lit(Int.MaxValue))),
+ Seq(Row(2))
+ )
+ }
}