author    Tarek Auel <tarek.auel@googlemail.com>    2015-07-02 10:02:19 -0700
committer Davies Liu <davies@databricks.com>        2015-07-02 10:02:19 -0700
commit    5b3338130dfd9db92c4894a348839a62ebb57ef3 (patch)
tree      13b6289381caca923444107940da2ed5abba7628 /sql
parent    0a468a46bf5b905e9b0205e98b862570b2ac556e (diff)
[SPARK-8223] [SPARK-8224] [SQL] shift left and shift right
Jira: https://issues.apache.org/jira/browse/SPARK-8223 https://issues.apache.org/jira/browse/SPARK-8224

~~I am aware of #7174 and will update this pr, if it's merged.~~ Done

I don't know if #7034 can simplify this, but we can have a look at it if it gets merged. rxin

In the Jira ticket the function has no second argument. I added a `numBits` argument that allows specifying the number of bits; I think this improves usability. I wanted to add `shiftleft(value)` as well, but the `selectExpr` dataframe tests crash if I have both. In order to do this, I added the following to functions.scala: `def shiftRight(e: Column): Column = ShiftRight(e.expr, lit(1).expr)`, but as I mentioned, this doesn't pass tests like `df.selectExpr("shiftRight(a)", ...` (not enough arguments exception). If we need the bitwise shift in order to be Hive compatible, I suggest adding `shiftLeft` and something like `shiftLeftX`.

Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #7178 from tarekauel/8223 and squashes the following commits:

8023bb5 [Tarek Auel] [SPARK-8223][SPARK-8224] fixed test
f3f64e6 [Tarek Auel] [SPARK-8223][SPARK-8224] Integer -> Int
f628706 [Tarek Auel] [SPARK-8223][SPARK-8224] removed toString; updated function description
3b56f2a [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223
5189690 [Tarek Auel] [SPARK-8223][SPARK-8224] minor fix and style fix
9434a28 [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223
44ee324 [Tarek Auel] [SPARK-8223][SPARK-8224] docu fix
ac7fe9d [Tarek Auel] [SPARK-8223][SPARK-8224] right and left bit shift
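A minimal usage sketch of the new API (not part of the patch; the SQLContext named sqlContext and the column names "a"/"b" are illustrative assumptions):

    import org.apache.spark.sql.functions.{shiftLeft, shiftRight}
    import sqlContext.implicits._

    val df = Seq((21L, 42L)).toDF("a", "b")

    // DataFrame API: numBits is the explicit second argument.
    df.select(shiftLeft($"a", 1), shiftRight($"b", 1)).show()

    // SQL/expression form, using the names registered in FunctionRegistry.
    df.selectExpr("shiftleft(a, 1)", "shiftright(b, 1)").show()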
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala  |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala  | 98
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala  | 28
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala  | 38
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala  | 34
5 files changed, 199 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6f04298d47..aa051b1633 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -125,6 +125,8 @@ object FunctionRegistry {
expression[Pow]("power"),
expression[UnaryPositive]("positive"),
expression[Rint]("rint"),
+ expression[ShiftLeft]("shiftleft"),
+ expression[ShiftRight]("shiftright"),
expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 8633eb06ff..7504c6a066 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -351,6 +351,104 @@ case class Pow(left: Expression, right: Expression)
}
}
+case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression {
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
+ case (_, IntegerType) => left.dataType match {
+ case LongType | IntegerType | ShortType | ByteType =>
+ return TypeCheckResult.TypeCheckSuccess
+ case _ => // failed
+ }
+ case _ => // failed
+ }
+ TypeCheckResult.TypeCheckFailure(
+ s"ShiftLeft expects long, integer, short or byte value as first argument and an " +
+ s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val valueLeft = left.eval(input)
+ if (valueLeft != null) {
+ val valueRight = right.eval(input)
+ if (valueRight != null) {
+ valueLeft match {
+ case l: Long => l << valueRight.asInstanceOf[Integer]
+ case i: Integer => i << valueRight.asInstanceOf[Integer]
+ case s: Short => s << valueRight.asInstanceOf[Integer]
+ case b: Byte => b << valueRight.asInstanceOf[Integer]
+ }
+ } else {
+ null
+ }
+ } else {
+ null
+ }
+ }
+
+ override def dataType: DataType = {
+ left.dataType match {
+ case LongType => LongType
+ case IntegerType | ShortType | ByteType => IntegerType
+ case _ => NullType
+ }
+ }
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;")
+ }
+}
+
+case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression {
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
+ case (_, IntegerType) => left.dataType match {
+ case LongType | IntegerType | ShortType | ByteType =>
+ return TypeCheckResult.TypeCheckSuccess
+ case _ => // failed
+ }
+ case _ => // failed
+ }
+ TypeCheckResult.TypeCheckFailure(
+ s"ShiftRight expects long, integer, short or byte value as first argument and an " +
+ s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val valueLeft = left.eval(input)
+ if (valueLeft != null) {
+ val valueRight = right.eval(input)
+ if (valueRight != null) {
+ valueLeft match {
+ case l: Long => l >> valueRight.asInstanceOf[Integer]
+ case i: Integer => i >> valueRight.asInstanceOf[Integer]
+ case s: Short => s >> valueRight.asInstanceOf[Integer]
+ case b: Byte => b >> valueRight.asInstanceOf[Integer]
+ }
+ } else {
+ null
+ }
+ } else {
+ null
+ }
+ }
+
+ override def dataType: DataType = {
+ left.dataType match {
+ case LongType => LongType
+ case IntegerType | ShortType | ByteType => IntegerType
+ case _ => NullType
+ }
+ }
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;")
+ }
+}
+
/**
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.
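The dataType method in ShiftLeft/ShiftRight above mirrors the JVM's shift semantics: shifting a Byte, Short, or Int yields an Int, while shifting a Long stays a Long, and `>>` is an arithmetic (sign-preserving) shift, which the negative-value test cases below rely on. A standalone Scala sketch (not from the patch):

    val b: Byte = 21
    b << 1      // Int result: Byte/Short/Int operands are promoted to Int (42)
    21L << 1    // Long result: Long operands stay Long (42L)
    -42L >> 1   // arithmetic shift keeps the sign bit (-21L)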
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index b3345d7069..aa27fe3cd5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{DataType, DoubleType, LongType}
+import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType}
class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -225,6 +225,32 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
}
+ test("shift left") {
+ checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null)
+ checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null)
+ checkEvaluation(
+ ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
+ checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42)
+ checkEvaluation(ShiftLeft(Literal(21.toByte), Literal(1)), 42)
+ checkEvaluation(ShiftLeft(Literal(21.toShort), Literal(1)), 42)
+ checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong)
+
+ checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong)
+ }
+
+ test("shift right") {
+ checkEvaluation(ShiftRight(Literal.create(null, IntegerType), Literal(1)), null)
+ checkEvaluation(ShiftRight(Literal(42), Literal.create(null, IntegerType)), null)
+ checkEvaluation(
+ ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
+ checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21)
+ checkEvaluation(ShiftRight(Literal(42.toByte), Literal(1)), 21)
+ checkEvaluation(ShiftRight(Literal(42.toShort), Literal(1)), 21)
+ checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong)
+
+ checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
+ }
+
test("hex") {
checkEvaluation(Hex(Literal(28)), "1C")
checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e6f623bdf3..a5b6828685 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1299,6 +1299,44 @@ object functions {
def rint(columnName: String): Column = rint(Column(columnName))
/**
+ * Shift the given value numBits left. If the given value is a long value, this function
+ * will return a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr)
+
+ /**
+ * Shift the given value numBits left. If the given value is a long value, this function
+ * will return a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftLeft(columnName: String, numBits: Int): Column =
+ shiftLeft(Column(columnName), numBits)
+
+ /**
+ * Shift the given value numBits right. If the given value is a long value, it will return
+ * a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)
+
+ /**
+ * Shift the given value numBits right. If the given value is a long value, it will return
+ * a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftRight(columnName: String, numBits: Int): Column =
+ shiftRight(Column(columnName), numBits)
+
+ /**
* Computes the signum of the given value.
*
* @group math_funcs
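The scaladoc above pins down the return types: a long input column yields a long result, other integral columns yield an integer result. A brief sketch of checking that (same sqlContext/implicits assumptions as the earlier example; column names "l" and "i" are illustrative):

    import org.apache.spark.sql.functions.shiftLeft
    import sqlContext.implicits._

    val typed = Seq((8L, 8)).toDF("l", "i")
    typed.select(shiftLeft($"l", 1)).printSchema()   // result column is LongType
    typed.select(shiftLeft($"i", 1)).printSchema()   // result column is IntegerType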
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index c03cde38d7..4c5696deaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -259,6 +259,40 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneNonNegativeMathFunction(log1p, math.log1p)
}
+ test("shift left") {
+ val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null))
+ .toDF("a", "b", "c", "d", "e", "f")
+
+ checkAnswer(
+ df.select(
+ shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1),
+ shiftLeft('f, 1)),
+ Row(42.toLong, 42, 42.toShort, 42.toByte, null))
+
+ checkAnswer(
+ df.selectExpr(
+ "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(b, 1)", "shiftLeft(d, 1)",
+ "shiftLeft(f, 1)"),
+ Row(42.toLong, 42, 42.toShort, 42.toByte, null))
+ }
+
+ test("shift right") {
+ val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null))
+ .toDF("a", "b", "c", "d", "e", "f")
+
+ checkAnswer(
+ df.select(
+ shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1),
+ shiftRight('f, 1)),
+ Row(21.toLong, 21, 21.toShort, 21.toByte, null))
+
+ checkAnswer(
+ df.selectExpr(
+ "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)",
+ "shiftRight(f, 1)"),
+ Row(21.toLong, 21, 21.toShort, 21.toByte, null))
+ }
+
test("binary log") {
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
checkAnswer(