From 5b3338130dfd9db92c4894a348839a62ebb57ef3 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Thu, 2 Jul 2015 10:02:19 -0700 Subject: [SPARK-8223] [SPARK-8224] [SQL] shift left and shift right Jira: https://issues.apache.org/jira/browse/SPARK-8223 https://issues.apache.org/jira/browse/SPARK-8224 ~~I am aware of #7174 and will update this pr, if it's merged.~~ Done I don't know if #7034 can simplify this, but we can have a look at it, if it gets merged rxin In the Jira ticket the function has no second argument. I added a `numBits` argument that allows specifying the number of bits. I guess this improves the usability. I wanted to add `shiftleft(value)` as well, but the `selectExpr` dataframe tests crash, if I have both. In order to do this, I added the following to the functions.scala `def shiftRight(e: Column): Column = ShiftRight(e.expr, lit(1).expr)`, but as I mentioned this doesn't pass tests like `df.selectExpr("shiftRight(a)", ...` (not enough arguments exception). If we need the bitwise shift in order to be hive compatible, I suggest adding `shiftLeft` and something like `shiftLeftX` Author: Tarek Auel Closes #7178 from tarekauel/8223 and squashes the following commits: 8023bb5 [Tarek Auel] [SPARK-8223][SPARK-8224] fixed test f3f64e6 [Tarek Auel] [SPARK-8223][SPARK-8224] Integer -> Int f628706 [Tarek Auel] [SPARK-8223][SPARK-8224] removed toString; updated function description 3b56f2a [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223 5189690 [Tarek Auel] [SPARK-8223][SPARK-8224] minor fix and style fix 9434a28 [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223 44ee324 [Tarek Auel] [SPARK-8223][SPARK-8224] docu fix ac7fe9d [Tarek Auel] [SPARK-8223][SPARK-8224] right and left bit shift --- .../scala/org/apache/spark/sql/functions.scala | 38 ++++++++++++++++++++++ .../apache/spark/sql/MathExpressionsSuite.scala | 34 +++++++++++++++++++ 2 files changed, 72 insertions(+) (limited to 'sql/core') diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e6f623bdf3..a5b6828685 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1298,6 +1298,44 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Shift the given value numBits left. If the given value is a long value, this function + * will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr) + + /** + * Shift the given value numBits left. If the given value is a long value, this function + * will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftLeft(columnName: String, numBits: Int): Column = + shiftLeft(Column(columnName), numBits) + + /** + * Shift the given value numBits right. If the given value is a long value, it will return + * a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) + + /** + * Shift the given value numBits right. If the given value is a long value, it will return + * a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRight(columnName: String, numBits: Int): Column = + shiftRight(Column(columnName), numBits) + /** * Computes the signum of the given value. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index c03cde38d7..4c5696deaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -259,6 +259,40 @@ class MathExpressionsSuite extends QueryTest { testOneToOneNonNegativeMathFunction(log1p, math.log1p) } + test("shift left") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1), + shiftLeft('f, 1)), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(c, 1)", "shiftLeft(d, 1)", + "shiftLeft(f, 1)"), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) + } + + test("shift right") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1), + shiftRight('f, 1)), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)", + "shiftRight(f, 1)"), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) + } + test("binary log") { val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") checkAnswer( -- cgit v1.2.3