From c34fc19765bdf55365cdce78d9ba11b220b73bb6 Mon Sep 17 00:00:00 2001 From: 0x0FFF Date: Fri, 11 Sep 2015 15:19:04 -0700 Subject: [SPARK-9014] [SQL] Allow Python spark API to use built-in exponential operator This PR addresses (SPARK-9014)[https://issues.apache.org/jira/browse/SPARK-9014] Added functionality: `Column` object in Python now supports exponential operator `**` Example: ``` from pyspark.sql import * df = sqlContext.createDataFrame([Row(a=2)]) df.select(3**df.a,df.a**3,df.a**df.a).collect() ``` Outputs: ``` [Row(POWER(3.0, a)=9.0, POWER(a, 3.0)=8.0, POWER(a, a)=4.0)] ``` Author: 0x0FFF Closes #8658 from 0x0FFF/SPARK-9014. --- python/pyspark/sql/column.py | 13 +++++++++++++ python/pyspark/sql/tests.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 573f65f5bf..9ca8e1f264 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -91,6 +91,17 @@ def _func_op(name, doc=''): return _ +def _bin_func_op(name, reverse=False, doc="binary function"): + def _(self, other): + sc = SparkContext._active_spark_context + fn = getattr(sc._jvm.functions, name) + jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other) + njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc) + return Column(njc) + _.__doc__ = doc + return _ + + def _bin_op(name, doc="binary operator"): """ Create a method for given binary operator """ @@ -151,6 +162,8 @@ class Column(object): __rdiv__ = _reverse_op("divide") __rtruediv__ = _reverse_op("divide") __rmod__ = _reverse_op("mod") + __pow__ = _bin_func_op("pow") + __rpow__ = _bin_func_op("pow", reverse=True) # logistic operators __eq__ = _bin_op("equalTo") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index eb449e8679..f2172b7a27 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -568,7 +568,7 @@ class SQLTests(ReusedPySparkTestCase): cs = self.df.value c = ci == cs self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) self.assertTrue(all(isinstance(c, Column) for c in rcc)) cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] self.assertTrue(all(isinstance(c, Column) for c in cb)) -- cgit v1.2.3