From ab535b9a1dab40ea7335ff9abb9b522fc2b5ed66 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Fri, 3 Jul 2015 15:39:16 -0700 Subject: [SPARK-8226] [SQL] Add function shiftrightunsigned Author: zhichao.li Closes #7035 from zhichao-li/shiftRightUnsigned and squashes the following commits: 6bcca5a [zhichao.li] change coding style 3e9f5ae [zhichao.li] python style d85ae0b [zhichao.li] add shiftrightunsigned --- .../main/scala/org/apache/spark/sql/functions.scala | 20 ++++++++++++++++++++ .../org/apache/spark/sql/MathExpressionsSuite.scala | 17 +++++++++++++++++ 2 files changed, 37 insertions(+) (limited to 'sql/core') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d5d49c3dd..4b70dc5fdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1343,6 +1343,26 @@ object functions { */ def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) + /** + * Unsigned shift the the given value numBits right. If the given value is a long value, + * it will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRightUnsigned(columnName: String, numBits: Int): Column = + shiftRightUnsigned(Column(columnName), numBits) + + /** + * Unsigned shift the the given value numBits right. If the given value is a long value, + * it will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRightUnsigned(e: Column, numBits: Int): Column = + ShiftRightUnsigned(e.expr, lit(numBits).expr) + /** * Shift the the given value numBits right. If the given value is a long value, it will return * a long value else it will return an integer value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index dc8f994adb..24bef21b99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -304,6 +304,23 @@ class MathExpressionsSuite extends QueryTest { Row(21.toLong, 21, 21.toShort, 21.toByte, null)) } + test("shift right unsigned") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1), + shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)), + Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)", + "shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"), + Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)) + } + test("binary log") { val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") checkAnswer( -- cgit v1.2.3