aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql/dataframe.py5
-rw-r--r--python/pyspark/sql/functions.py2
-rw-r--r--python/pyspark/sql/tests.py13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala31
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala33
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala7
7 files changed, 97 insertions, 2 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 24f370543d..cee804f5cc 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1277,6 +1277,11 @@ class Column(object):
__contains__ = _bin_op("contains")
__getitem__ = _bin_op("getItem")
+ # bitwise operators
+ bitwiseOR = _bin_op("bitwiseOR")
+ bitwiseAND = _bin_op("bitwiseAND")
+ bitwiseXOR = _bin_op("bitwiseXOR")
+
def getItem(self, key):
"""An expression that gets an item at position `ordinal` out of a list,
or gets an item by key out of a dict.
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 692af868dd..274c410a1e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -104,6 +104,8 @@ _functions = {
'toRadians': 'Converts an angle measured in degrees to an approximately equivalent angle ' +
'measured in radians.',
+ 'bitwiseNOT': 'Computes bitwise not.',
+
'max': 'Aggregate function: returns the maximum value of the expression in a group.',
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
'first': 'Aggregate function: returns the first value in a group.',
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b232f3a965..45dfedce22 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -645,6 +645,19 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(row.age, None)
self.assertEqual(row.height, None)
+ def test_bitwise_operations(self):
+ from pyspark.sql import functions
+ row = Row(a=170, b=75)
+ df = self.sqlCtx.createDataFrame([row])
+ result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
+ self.assertEqual(170 & 75, result['(a & b)'])
+ result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
+ self.assertEqual(170 | 75, result['(a | b)'])
+ result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict()
+ self.assertEqual(170 ^ 75, result['(a ^ b)'])
+ result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
+ self.assertEqual(~75, result['~b'])
+
class HiveContextSQLTests(ReusedPySparkTestCase):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8eb632d3d6..8bbe11b412 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -698,6 +698,37 @@ class Column(protected[sql] val expr: Expression) extends Logging {
println(expr.prettyString)
}
}
+
+ /**
+ * Compute bitwise OR of this expression with another expression.
+ * {{{
+ * df.select($"colA".bitwiseOR($"colB"))
+ * }}}
+ *
+ * @group expr_ops
+ */
+ def bitwiseOR(other: Any): Column = BitwiseOr(expr, lit(other).expr)
+
+ /**
+ * Compute bitwise AND of this expression with another expression.
+ * {{{
+ * df.select($"colA".bitwiseAND($"colB"))
+ * }}}
+ *
+ * @group expr_ops
+ */
+ def bitwiseAND(other: Any): Column = BitwiseAnd(expr, lit(other).expr)
+
+ /**
+ * Compute bitwise XOR of this expression with another expression.
+ * {{{
+ * df.select($"colA".bitwiseXOR($"colB"))
+ * }}}
+ *
+ * @group expr_ops
+ */
+ def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr)
+
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 830b501771..1728b0b8c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -438,6 +438,14 @@ object functions {
*/
def upper(e: Column): Column = Upper(e.expr)
+
+ /**
+ * Computes bitwise NOT.
+ *
+ * @group normal_funcs
+ */
+ def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr)
+
//////////////////////////////////////////////////////////////////////////////////////////////
// Math Functions
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 3c1ad656fc..d96186c268 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -27,8 +27,6 @@ import org.apache.spark.sql.types._
class ColumnExpressionSuite extends QueryTest {
import org.apache.spark.sql.TestData._
- // TODO: Add test cases for bitwise operations.
-
test("collect on column produced by a binary operator") {
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
@@ -385,4 +383,35 @@ class ColumnExpressionSuite extends QueryTest {
assert(row.getDouble(1) >= -4.0)
}
}
+
+ test("bitwiseAND") {
+ checkAnswer(
+ testData2.select($"a".bitwiseAND(75)),
+ testData2.collect().toSeq.map(r => Row(r.getInt(0) & 75)))
+
+ checkAnswer(
+ testData2.select($"a".bitwiseAND($"b").bitwiseAND(22)),
+ testData2.collect().toSeq.map(r => Row(r.getInt(0) & r.getInt(1) & 22)))
+ }
+
+ test("bitwiseOR") {
+ checkAnswer(
+ testData2.select($"a".bitwiseOR(170)),
+ testData2.collect().toSeq.map(r => Row(r.getInt(0) | 170)))
+
+ checkAnswer(
+ testData2.select($"a".bitwiseOR($"b").bitwiseOR(42)),
+ testData2.collect().toSeq.map(r => Row(r.getInt(0) | r.getInt(1) | 42)))
+ }
+
+ test("bitwiseXOR") {
+ checkAnswer(
+ testData2.select($"a".bitwiseXOR(112)),
+ testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ 112)))
+
+ checkAnswer(
+ testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)),
+ testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
+ }
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index ca03713ef4..b1e0faa310 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
@@ -81,4 +82,10 @@ class DataFrameFunctionsSuite extends QueryTest {
struct(col("a") * 2)
}
}
+
+ test("bitwiseNOT") {
+ checkAnswer(
+ testData2.select(bitwiseNOT($"a")),
+ testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
+ }
}