aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorTarek Auel <tarek.auel@googlemail.com>2015-07-02 10:02:19 -0700
committerDavies Liu <davies@databricks.com>2015-07-02 10:02:19 -0700
commit5b3338130dfd9db92c4894a348839a62ebb57ef3 (patch)
tree13b6289381caca923444107940da2ed5abba7628 /sql/core
parent0a468a46bf5b905e9b0205e98b862570b2ac556e (diff)
downloadspark-5b3338130dfd9db92c4894a348839a62ebb57ef3.tar.gz
spark-5b3338130dfd9db92c4894a348839a62ebb57ef3.tar.bz2
spark-5b3338130dfd9db92c4894a348839a62ebb57ef3.zip
[SPARK-8223] [SPARK-8224] [SQL] shift left and shift right
Jira: https://issues.apache.org/jira/browse/SPARK-8223 https://issues.apache.org/jira/browse/SPARK-8224 ~~I am aware of #7174 and will update this pr, if it's merged.~~ Done I don't know if #7034 can simplify this, but we can have a look at it, if it gets merged rxin In the Jira ticket the function has no second argument. I added a `numBits` argument that allows specifying the number of bits. I guess this improves the usability. I wanted to add `shiftleft(value)` as well, but the `selectExpr` dataframe tests crash, if I have both. In order to do this, I added the following to the functions.scala `def shiftRight(e: Column): Column = ShiftRight(e.expr, lit(1).expr)`, but as I mentioned this doesn't pass tests like `df.selectExpr("shiftRight(a)", ...` (not enough arguments exception). If we need the bitwise shift in order to be hive compatible, I suggest adding `shiftLeft` and something like `shiftLeftX` Author: Tarek Auel <tarek.auel@googlemail.com> Closes #7178 from tarekauel/8223 and squashes the following commits: 8023bb5 [Tarek Auel] [SPARK-8223][SPARK-8224] fixed test f3f64e6 [Tarek Auel] [SPARK-8223][SPARK-8224] Integer -> Int f628706 [Tarek Auel] [SPARK-8223][SPARK-8224] removed toString; updated function description 3b56f2a [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223 5189690 [Tarek Auel] [SPARK-8223][SPARK-8224] minor fix and style fix 9434a28 [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223 44ee324 [Tarek Auel] [SPARK-8223][SPARK-8224] docu fix ac7fe9d [Tarek Auel] [SPARK-8223][SPARK-8224] right and left bit shift
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala38
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala34
2 files changed, 72 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e6f623bdf3..a5b6828685 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1299,6 +1299,44 @@ object functions {
def rint(columnName: String): Column = rint(Column(columnName))
/**
+ * Shift the given value numBits left. If the given value is a long value, this function
+ * will return a long value; otherwise it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftLeft(e: Column, numBits: Int): Column = {
+   // Take the operand's expression first, then wrap the shift distance via lit().
+   val operand = e.expr
+   val distance = lit(numBits).expr
+   ShiftLeft(operand, distance)
+ }
+
+ /**
+ * Shift the given value numBits left. If the given value is a long value, this function
+ * will return a long value; otherwise it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftLeft(columnName: String, numBits: Int): Column = {
+   // Build a Column from the name and delegate to the Column-based overload.
+   val column = Column(columnName)
+   shiftLeft(column, numBits)
+ }
+
+ /**
+ * Shift the given value numBits right. If the given value is a long value, it will return
+ * a long value; otherwise it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftRight(e: Column, numBits: Int): Column = {
+   // Take the operand's expression first, then wrap the shift distance via lit().
+   val operand = e.expr
+   val distance = lit(numBits).expr
+   ShiftRight(operand, distance)
+ }
+
+ /**
+ * Shift the given value numBits right. If the given value is a long value, it will return
+ * a long value; otherwise it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftRight(columnName: String, numBits: Int): Column = {
+   // Build a Column from the name and delegate to the Column-based overload.
+   val column = Column(columnName)
+   shiftRight(column, numBits)
+ }
+
+ /**
* Computes the signum of the given value.
*
* @group math_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index c03cde38d7..4c5696deaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -259,6 +259,40 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneNonNegativeMathFunction(log1p, math.log1p)
}
+ // Checks shiftLeft by one bit (21 -> 42) on Long, Integer, Short and Byte columns plus a
+ // null Integer column, once through the Column API and once through selectExpr.
+ test("shift left") {
+   val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null))
+     .toDF("a", "b", "c", "d", "e", "f")
+
+   // Column-based API.
+   checkAnswer(
+     df.select(
+       shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1),
+       shiftLeft('f, 1)),
+     Row(42.toLong, 42, 42.toShort, 42.toByte, null))
+
+   // Expression-string API. The third expression must target the Short column `c`
+   // (it previously repeated `b`), so that it lines up with the 42.toShort expectation.
+   checkAnswer(
+     df.selectExpr(
+       "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(c, 1)", "shiftLeft(d, 1)",
+       "shiftLeft(f, 1)"),
+     Row(42.toLong, 42, 42.toShort, 42.toByte, null))
+ }
+
+ // Checks shiftRight by one bit (42 -> 21) on Long, Integer, Short and Byte columns plus a
+ // null Integer column, once through the Column API and once through selectExpr.
+ test("shift right") {
+   val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null))
+     .toDF("a", "b", "c", "d", "e", "f")
+
+   // Both code paths are expected to produce the same single row.
+   val expected = Row(21.toLong, 21, 21.toShort, 21.toByte, null)
+
+   // Column-based API.
+   val viaColumns = df.select(
+     shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1),
+     shiftRight('f, 1))
+   checkAnswer(viaColumns, expected)
+
+   // Expression-string API.
+   val viaExprs = df.selectExpr(
+     "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)",
+     "shiftRight(f, 1)")
+   checkAnswer(viaExprs, expected)
+ }
+
test("binary log") {
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
checkAnswer(