aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author0x0FFF <programmerag@gmail.com>2015-09-11 15:19:04 -0700
committerDavies Liu <davies.liu@gmail.com>2015-09-11 15:19:04 -0700
commitc34fc19765bdf55365cdce78d9ba11b220b73bb6 (patch)
tree0562a581bc5d7d9f79ea174754065dfdcbefafdc
parentd74c6a143cbd060c25bf14a8d306841b3ec55d03 (diff)
downloadspark-c34fc19765bdf55365cdce78d9ba11b220b73bb6.tar.gz
spark-c34fc19765bdf55365cdce78d9ba11b220b73bb6.tar.bz2
spark-c34fc19765bdf55365cdce78d9ba11b220b73bb6.zip
[SPARK-9014] [SQL] Allow Python spark API to use built-in exponential operator
This PR addresses (SPARK-9014)[https://issues.apache.org/jira/browse/SPARK-9014] Added functionality: `Column` object in Python now supports exponential operator `**` Example: ``` from pyspark.sql import * df = sqlContext.createDataFrame([Row(a=2)]) df.select(3**df.a,df.a**3,df.a**df.a).collect() ``` Outputs: ``` [Row(POWER(3.0, a)=9.0, POWER(a, 3.0)=8.0, POWER(a, a)=4.0)] ``` Author: 0x0FFF <programmerag@gmail.com> Closes #8658 from 0x0FFF/SPARK-9014.
-rw-r--r--python/pyspark/sql/column.py13
-rw-r--r--python/pyspark/sql/tests.py2
2 files changed, 14 insertions, 1 deletions
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 573f65f5bf..9ca8e1f264 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -91,6 +91,17 @@ def _func_op(name, doc=''):
return _
+def _bin_func_op(name, reverse=False, doc="binary function"):
+ def _(self, other):
+ sc = SparkContext._active_spark_context
+ fn = getattr(sc._jvm.functions, name)
+ jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
+ njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
+ return Column(njc)
+ _.__doc__ = doc
+ return _
+
+
def _bin_op(name, doc="binary operator"):
""" Create a method for given binary operator
"""
@@ -151,6 +162,8 @@ class Column(object):
__rdiv__ = _reverse_op("divide")
__rtruediv__ = _reverse_op("divide")
__rmod__ = _reverse_op("mod")
+ __pow__ = _bin_func_op("pow")
+ __rpow__ = _bin_func_op("pow", reverse=True)
# logistic operators
__eq__ = _bin_op("equalTo")
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index eb449e8679..f2172b7a27 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -568,7 +568,7 @@ class SQLTests(ReusedPySparkTestCase):
cs = self.df.value
c = ci == cs
self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
- rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+ rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
self.assertTrue(all(isinstance(c, Column) for c in rcc))
cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
self.assertTrue(all(isinstance(c, Column) for c in cb))