author    云峤 <chensong.cs@alibaba-inc.com>    2015-05-05 13:23:53 -0700
committer Reynold Xin <rxin@databricks.com>    2015-05-05 13:23:53 -0700
commit    735bc3d042b1e3e12de57b66f166af14254ad314 (patch)
tree      853b8c985baf71c78ab51ddc7724dfad630523c6
parent    489700c809a7c0a836538f3d0bd58bed609e8768 (diff)
[SPARK-7294][SQL] ADD BETWEEN
Author: 云峤 <chensong.cs@alibaba-inc.com>
Author: kaka1992 <kaka_1992@163.com>

Closes #5839 from kaka1992/master and squashes the following commits:

b15360d [kaka1992] Fix python unit test in sql/test. =_= I forget to commit this file last time.
f928816 [kaka1992] Fix python style in sql/test.
d2e7f72 [kaka1992] Fix python style in sql/test.
c54d904 [kaka1992] Fix empty map bug.
7e64d1e [云峤] Update
7b9b858 [云峤] undo
f080f8d [云峤] update pep8
76f0c51 [云峤] Merge remote-tracking branch 'remotes/upstream/master'
7d62368 [云峤] [SPARK-7294] ADD BETWEEN
baf839b [云峤] [SPARK-7294] ADD BETWEEN
d11d5b9 [云峤] [SPARK-7294] ADD BETWEEN
-rw-r--r--  python/pyspark/sql/dataframe.py                                            7
-rw-r--r--  python/pyspark/sql/tests.py                                                8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala  14
4 files changed, 38 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 17448b38c3..24f370543d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1405,6 +1405,13 @@ class Column(object):
             raise TypeError("unexpected type: %s" % type(dataType))
         return Column(jc)
 
+    @ignore_unicode_prefix
+    def between(self, lowerBound, upperBound):
+        """ A boolean expression that is evaluated to true if the value of this
+        expression is between the given columns.
+        """
+        return (self >= lowerBound) & (self <= upperBound)
+
     def __repr__(self):
         return 'Column<%s>' % self._jc.toString().encode('utf8')
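
For context, a minimal sketch of how the new Python between() reads in practice. The SparkSession entry point used below is an assumption for the sketch (this patch targets the SQLContext era) and is not part of the change:

    # Sketch only: exercises the between() added above. SparkSession is an
    # assumption here; this patch predates it (Spark 1.4 used SQLContext).
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").appName("between-sketch").getOrCreate()
    df = spark.createDataFrame([(1,), (3,), (5,)], ["age"])
    # between(2, 4) desugars to (df.age >= 2) & (df.age <= 4), per the diff above.
    df.filter(df.age.between(2, 4)).show()  # expect a single row: age == 3
    spark.stop()
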
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7ea6656d31..46c4c88e98 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -453,6 +453,14 @@ class SQLTests(ReusedPySparkTestCase):
         for row in rndn:
             assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
 
+    def test_between_function(self):
+        df = self.sc.parallelize([
+            Row(a=1, b=2, c=3),
+            Row(a=2, b=1, c=3),
+            Row(a=4, b=1, c=4)]).toDF()
+        self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
+                         df.filter(df.a.between(df.b, df.c)).collect())
+
     def test_save_and_load(self):
         df = self.df
         tmpPath = tempfile.mkdtemp()
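
As the test above exercises, the bounds need not be literals: they may themselves be columns, in which case the predicate is evaluated row by row. A standalone sketch over the same data (SparkSession setup again assumed, not part of the patch):

    # Sketch only: per-row bounds drawn from sibling columns, mirroring the test.
    from pyspark.sql import Row, SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame(
        [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)])
    # Row(a=2, b=1, c=3) passes because 1 <= 2 <= 3; Row(a=1, b=2, c=3) fails (1 < 2).
    print(df.filter(df.a.between(df.b, df.c)).collect())
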
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 33f9d0b37d..c0503bf047 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -296,6 +296,15 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def eqNullSafe(other: Any): Column = this <=> other
 
   /**
+   * True if the current column is between the lower bound and upper bound, inclusive.
+   *
+   * @group java_expr_ops
+   */
+  def between(lowerBound: Any, upperBound: Any): Column = {
+    (this >= lowerBound) && (this <= upperBound)
+  }
+
+  /**
    * True if the current expression is null.
    *
    * @group expr_ops
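
The Scala method above is plain sugar over the existing Column operators: it builds the conjunction of a >= and a <= expression. Keeping these sketches to one language, the same equivalence checked from Python (session setup assumed, not part of the patch):

    # Sketch only: between(l, u) is equivalent to the explicit conjunction.
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(0,), (2,), (5,)], ["x"])
    manual = df.filter((col("x") >= 1) & (col("x") <= 3))   # explicit form
    sugared = df.filter(col("x").between(1, 3))             # sugar added by this patch
    assert manual.collect() == sugared.collect()
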
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 6322faf4d9..3c1ad656fc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -208,6 +208,20 @@ class ColumnExpressionSuite extends QueryTest {
       testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1)))
   }
 
+  test("between") {
+    val testData = TestSQLContext.sparkContext.parallelize(
+      (0, 1, 2) ::
+      (1, 2, 3) ::
+      (2, 1, 0) ::
+      (2, 2, 4) ::
+      (3, 1, 6) ::
+      (3, 2, 0) :: Nil).toDF("a", "b", "c")
+    val expectAnswer = testData.collect().toSeq.
+      filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2))
+
+    checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer)
+  }
+
   val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
     Row(false, false) ::
     Row(false, true) ::
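
One behavioral point the test pins down: both endpoints are inclusive, e.g. the tuple (2, 2, 4) is kept because a == b. A boundary-value sketch of the same semantics (setup assumed as in the sketches above):

    # Sketch only: endpoints are inclusive on both sides.
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1,), (2,), (3,), (4,), (5,)], ["v"])
    kept = sorted(r.v for r in df.filter(col("v").between(2, 4)).collect())
    assert kept == [2, 3, 4]  # 2 and 4 survive; strict bounds would drop them
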