From a9385271a9f6b97ec6aa619cf56ee556ba2fb0de Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 15 Jul 2015 10:43:38 -0700 Subject: [SPARK-8221][SQL] Add pmod function https://issues.apache.org/jira/browse/SPARK-8221 One concern is that the result would be negative if the divisor is not positive (i.e. pmod(7, -3)), but this behavior is the same as Hive's. Author: zhichao.li Closes #6783 from zhichao-li/pmod2 and squashes the following commits: 7083eb9 [zhichao.li] update to the latest type checking d26dba7 [zhichao.li] add pmod --- .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/analysis/HiveTypeCoercion.scala | 6 ++ .../sql/catalyst/expressions/arithmetic.scala | 94 ++++++++++++++++++++++ .../expressions/ArithmeticExpressionSuite.scala | 16 +++- .../scala/org/apache/spark/sql/functions.scala | 17 ++++ .../apache/spark/sql/DataFrameFunctionsSuite.scala | 37 +++++++++ 6 files changed, 170 insertions(+), 1 deletion(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ec75f51d5e..d2678ce860 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -115,6 +115,7 @@ object FunctionRegistry { expression[Log2]("log2"), expression[Pow]("pow"), expression[Pow]("power"), + expression[Pmod]("pmod"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), expression[Round]("round"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 15da5eecc8..25087915b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -426,6 +426,12 @@ object HiveTypeCoercion {
             DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
           )
 
+        case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+          Cast(
+            Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+            DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+          )
+
         // When we compare 2 decimal types with different precisions, cast them to the smallest
         // common precision.
         case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1a55a0876f..394ef556e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -377,3 +377,133 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "min"
   override def prettyName: String = symbol
 }
+
+/**
+ * Positive modulo of `left` by `right`: takes the plain remainder `r` and, when `r` is
+ * negative, returns `(r + right) % right` instead, so the result is never negative for a
+ * positive divisor (e.g. `pmod(-7, 3)` is 2).
+ *
+ * A zero divisor yields null rather than an ArithmeticException on integral types, so the
+ * expression is nullable even when both children are not.
+ */
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
+
+  override def toString: String = s"pmod($left, $right)"
+
+  override def symbol: String = "pmod"
+
+  // A zero divisor produces null even when both inputs are non-null.
+  override def nullable: Boolean = true
+
+  protected def checkTypesInternal(t: DataType) =
+    TypeUtils.checkForNumericExpr(t, "pmod")
+
+  override def inputType: AbstractDataType = NumericType
+
+  // Returns null for a zero divisor, otherwise dispatches on the result type.
+  protected override def nullSafeEval(left: Any, right: Any): Any =
+    if (isZero(right)) {
+      null
+    } else {
+      dataType match {
+        case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int])
+        case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long])
+        case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short])
+        case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte])
+        case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float])
+        case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double])
+        case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal])
+      }
+    }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+      val javaType = ctx.javaType(dataType)
+      // Pair of (Java test for a zero divisor, Java statements computing the result).
+      val (divisorIsZero, compute) = dataType match {
+        case _: DecimalType =>
+          val decimalAdd = "$plus"
+          val zero = "new org.apache.spark.sql.types.Decimal().set(0)"
+          (s"$eval2.compare($zero) == 0",
+            s"""
+              $javaType r = $eval1.remainder($eval2);
+              if (r.compare($zero) < 0) {
+                ${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2);
+              } else {
+                ${ev.primitive} = r;
+              }
+            """)
+        // byte and short are promoted to int by Java arithmetic, so cast the results back
+        case ByteType | ShortType =>
+          (s"$eval2 == 0",
+            s"""
+              $javaType r = ($javaType)($eval1 % $eval2);
+              if (r < 0) {
+                ${ev.primitive} = ($javaType)((r + $eval2) % $eval2);
+              } else {
+                ${ev.primitive} = r;
+              }
+            """)
+        case _ =>
+          (s"$eval2 == 0",
+            s"""
+              $javaType r = $eval1 % $eval2;
+              if (r < 0) {
+                ${ev.primitive} = (r + $eval2) % $eval2;
+              } else {
+                ${ev.primitive} = r;
+              }
+            """)
+      }
+      // Flag the result as null for a zero divisor instead of evaluating `%`.
+      s"""
+        if ($divisorIsZero) {
+          ${ev.isNull} = true;
+        } else {
+          $compute
+        }
+      """
+    })
+  }
+
+  // True when the divisor is numerically zero; Decimal needs an explicit comparison.
+  private def isZero(divisor: Any): Boolean = divisor match {
+    case d: Decimal => d.compare(Decimal(0)) == 0
+    case other => other == 0
+  }
+
+  private def pmod(a: Int, n: Int): Int = {
+    val r = a % n
+    if (r < 0) {(r + n) % n} else r
+  }
+
+  private def pmod(a: Long, n: Long): Long = {
+    val r = a % n
+    if (r < 0) {(r + n) % n} else r
+  }
+
+  private def pmod(a: Byte, n: Byte): Byte = {
+    val r = a % n
+    if (r < 0) {((r + n) % n).toByte} else r.toByte
+  }
+
+  private def pmod(a: Double, n: Double): Double = {
+    val r = a % n
+    if (r < 0) {(r + n) % n} else r
+  }
+
+  private def pmod(a: Short, n: Short): Short = {
+    val r = a % n
+    if (r < 0) {((r + n) % n).toShort} else r.toShort
+  }
+
+  private def pmod(a: Float, n: Float): Float = {
+    val r = a % n
+    if (r < 0) {(r + n) % n} else r
+  }
+
+  private def pmod(a: Decimal, n: Decimal): Decimal = {
+    val r = a % n
+    if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 6c93698f80..e7e5231d32 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types.Decimal
 
-
 class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   /**
@@ -158,4 +157,19 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)),
       Array(1.toByte, 2.toByte))
   }
+
+  test("pmod") {
+    testNumericDataTypes { convert =>
+      val left = Literal(convert(7))
+      val right = Literal(convert(3))
+      checkEvaluation(Pmod(left, right), convert(1))
+      checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null)
+      checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null)
+      checkEvaluation(Pmod(left, Literal(convert(0))), null)  // pmod by 0
+    }
+    checkEvaluation(Pmod(-7, 3), 2)
+    checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005)
+    checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1))
+    checkEvaluation(Pmod(2L, Long.MaxValue), 2)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5119ee31d8..c7deaca843 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1371,6 +1371,23 @@ object functions {
    */
   def pow(l: Double, rightName: String): Column = pow(l, Column(rightName))
 
+  /**
+   * Returns the positive value of dividend mod
divisor. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr) + + /** + * Returns the positive value of dividend mod divisor, with both operands given by column name. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividendColName: String, divisorColName: String): Column = + pmod(Column(dividendColName), Column(divisorColName)) + /** * Returns the double value that is closest in value to the argument and * is equal to a mathematical integer. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6cebec95d2..70bd78737f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -403,4 +403,41 @@ class DataFrameFunctionsSuite extends QueryTest { Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) ) } + + test("pmod") { + val intData = Seq((7, 3), (-7, 3)).toDF("a", "b") + checkAnswer( + intData.select(pmod('a, 'b)), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod('a, lit(3))), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod(lit(-7), 'b)), + Seq(Row(2), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, b)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, 3)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(-7, b)"), + Seq(Row(2), Row(2)) + ) + val doubleData = Seq((7.2, 4.1)).toDF("a", "b") + checkAnswer( + doubleData.select(pmod('a, 'b)), + Seq(Row(3.1000000000000005)) // inexact floating-point remainder; same result as Hive + ) + checkAnswer( + doubleData.select(pmod(lit(2), lit(Int.MaxValue))), + Seq(Row(2)) + ) + } } -- cgit v1.2.3