diff options
-rw-r--r-- | python/pyspark/sql/column.py | 13 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 2 |
2 files changed, 14 insertions, 1 deletions
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 573f65f5bf..9ca8e1f264 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -91,6 +91,17 @@ def _func_op(name, doc=''): return _ +def _bin_func_op(name, reverse=False, doc="binary function"): + def _(self, other): + sc = SparkContext._active_spark_context + fn = getattr(sc._jvm.functions, name) + jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other) + njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc) + return Column(njc) + _.__doc__ = doc + return _ + + def _bin_op(name, doc="binary operator"): """ Create a method for given binary operator """ @@ -151,6 +162,8 @@ class Column(object): __rdiv__ = _reverse_op("divide") __rtruediv__ = _reverse_op("divide") __rmod__ = _reverse_op("mod") + __pow__ = _bin_func_op("pow") + __rpow__ = _bin_func_op("pow", reverse=True) # logistic operators __eq__ = _bin_op("equalTo") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index eb449e8679..f2172b7a27 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -568,7 +568,7 @@ class SQLTests(ReusedPySparkTestCase): cs = self.df.value c = ci == cs self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) self.assertTrue(all(isinstance(c, Column) for c in rcc)) cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] self.assertTrue(all(isinstance(c, Column) for c in cb)) |