aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala17
2 files changed, 37 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0d5d49c3dd..4b70dc5fdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1344,6 +1344,26 @@ object functions {
def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)
/**
+ * Unsigned shift the the given value numBits right. If the given value is a long value,
+ * it will return a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftRightUnsigned(columnName: String, numBits: Int): Column =
+ shiftRightUnsigned(Column(columnName), numBits)
+
+ /**
+ * Unsigned shift the the given value numBits right. If the given value is a long value,
+ * it will return a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftRightUnsigned(e: Column, numBits: Int): Column =
+ ShiftRightUnsigned(e.expr, lit(numBits).expr)
+
+ /**
* Shift the the given value numBits right. If the given value is a long value, it will return
* a long value else it will return an integer value.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index dc8f994adb..24bef21b99 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -304,6 +304,23 @@ class MathExpressionsSuite extends QueryTest {
Row(21.toLong, 21, 21.toShort, 21.toByte, null))
}
+ test("shift right unsigned") {
+ val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null))
+ .toDF("a", "b", "c", "d", "e", "f")
+
+ checkAnswer(
+ df.select(
+ shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1),
+ shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)),
+ Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null))
+
+ checkAnswer(
+ df.selectExpr(
+ "shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)",
+ "shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"),
+ Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null))
+ }
+
test("binary log") {
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
checkAnswer(