aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql/functions.py13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala49
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala17
6 files changed, 113 insertions, 0 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 12263e6a75..69e563ef36 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -436,6 +436,19 @@ def shiftRight(col, numBits):
return Column(jc)
+@since(1.5)
+def shiftRightUnsigned(col, numBits):
+ """Unsigned shift the the given value numBits right.
+
+ >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\
+ .collect()
+ [Row(r=9223372036854775787)]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits)
+ return Column(jc)
+
+
@since(1.4)
def sparkPartitionId():
"""A column for partition ID of the Spark task.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 9163b032ad..cd5ba1217c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -129,6 +129,7 @@ object FunctionRegistry {
expression[Rint]("rint"),
expression[ShiftLeft]("shiftleft"),
expression[ShiftRight]("shiftright"),
+ expression[ShiftRightUnsigned]("shiftrightunsigned"),
expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 273a6c5016..0fc320fb08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -521,6 +521,55 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
}
}
+case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression {
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
+ case (_, IntegerType) => left.dataType match {
+ case LongType | IntegerType | ShortType | ByteType =>
+ return TypeCheckResult.TypeCheckSuccess
+ case _ => // failed
+ }
+ case _ => // failed
+ }
+ TypeCheckResult.TypeCheckFailure(
+ s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " +
+ s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val valueLeft = left.eval(input)
+ if (valueLeft != null) {
+ val valueRight = right.eval(input)
+ if (valueRight != null) {
+ valueLeft match {
+ case l: Long => l >>> valueRight.asInstanceOf[Integer]
+ case i: Integer => i >>> valueRight.asInstanceOf[Integer]
+ case s: Short => s >>> valueRight.asInstanceOf[Integer]
+ case b: Byte => b >>> valueRight.asInstanceOf[Integer]
+ }
+ } else {
+ null
+ }
+ } else {
+ null
+ }
+ }
+
+ override def dataType: DataType = {
+ left.dataType match {
+ case LongType => LongType
+ case IntegerType | ShortType | ByteType => IntegerType
+ case _ => NullType
+ }
+ }
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;")
+ }
+}
+
/**
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 8457864d17..20839c83d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -264,6 +264,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
}
+ test("shift right unsigned") {
+ checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null)
+ checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null)
+ checkEvaluation(
+ ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
+ checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21)
+ checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21)
+ checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21)
+ checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)
+
+ checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
+ }
+
test("hex") {
checkEvaluation(Hex(Literal(28)), "1C")
checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0d5d49c3dd..4b70dc5fdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1344,6 +1344,26 @@ object functions {
def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)
/**
+ * Unsigned shift the the given value numBits right. If the given value is a long value,
+ * it will return a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftRightUnsigned(columnName: String, numBits: Int): Column =
+ shiftRightUnsigned(Column(columnName), numBits)
+
+ /**
+ * Unsigned shift the the given value numBits right. If the given value is a long value,
+ * it will return a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftRightUnsigned(e: Column, numBits: Int): Column =
+ ShiftRightUnsigned(e.expr, lit(numBits).expr)
+
+ /**
* Shift the the given value numBits right. If the given value is a long value, it will return
* a long value else it will return an integer value.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index dc8f994adb..24bef21b99 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -304,6 +304,23 @@ class MathExpressionsSuite extends QueryTest {
Row(21.toLong, 21, 21.toShort, 21.toByte, null))
}
+ test("shift right unsigned") {
+ val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null))
+ .toDF("a", "b", "c", "d", "e", "f")
+
+ checkAnswer(
+ df.select(
+ shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1),
+ shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)),
+ Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null))
+
+ checkAnswer(
+ df.selectExpr(
+ "shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)",
+ "shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"),
+ Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null))
+ }
+
test("binary log") {
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
checkAnswer(