aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorBurak Yavuz <brkyvz@gmail.com>2015-04-29 00:09:24 -0700
committerReynold Xin <rxin@databricks.com>2015-04-29 00:09:24 -0700
commitfe917f5ec9be8c8424416f7b5423ddb4318e03a0 (patch)
treef92902ee2d78db76c98a40313a971d497134b69f /python
parent8dee2746b5857bb4c32b96f38840dd1b63574ab2 (diff)
downloadspark-fe917f5ec9be8c8424416f7b5423ddb4318e03a0.tar.gz
spark-fe917f5ec9be8c8424416f7b5423ddb4318e03a0.tar.bz2
spark-fe917f5ec9be8c8424416f7b5423ddb4318e03a0.zip
[SPARK-7188] added python support for math DataFrame functions
Adds support for the math functions for DataFrames in PySpark. rxin I love Davies. Author: Burak Yavuz <brkyvz@gmail.com> Closes #5750 from brkyvz/python-math-udfs and squashes the following commits: 7c4f563 [Burak Yavuz] removed is_math 3c4adde [Burak Yavuz] cleanup imports d5dca3f [Burak Yavuz] moved math functions to mathfunctions 25e6534 [Burak Yavuz] addressed comments v2.0 d3f7e0f [Burak Yavuz] addressed comments and added tests 7b7d7c4 [Burak Yavuz] remove tests for removed methods 33c2c15 [Burak Yavuz] fixed python style 3ee0c05 [Burak Yavuz] added python functions
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/functions.py2
-rw-r--r--python/pyspark/sql/mathfunctions.py101
-rw-r--r--python/pyspark/sql/tests.py29
3 files changed, 131 insertions, 1 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 7b86655d9c..555c2fa5e7 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -54,7 +54,7 @@ _functions = {
'upper': 'Converts a string expression to upper case.',
'lower': 'Converts a string expression to lower case.',
'sqrt': 'Computes the square root of the specified float value.',
- 'abs': 'Computes the absolutle value.',
+ 'abs': 'Computes the absolute value.',
'max': 'Aggregate function: returns the maximum value of the expression in a group.',
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
diff --git a/python/pyspark/sql/mathfunctions.py b/python/pyspark/sql/mathfunctions.py
new file mode 100644
index 0000000000..7dbcab8694
--- /dev/null
+++ b/python/pyspark/sql/mathfunctions.py
@@ -0,0 +1,101 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A collection of builtin math functions
+"""
+
+from pyspark import SparkContext
+from pyspark.sql.dataframe import Column
+
+__all__ = []
+
+
+def _create_unary_mathfunction(name, doc=""):
+ """ Create a unary mathfunction by name"""
+ def _(col):
+ sc = SparkContext._active_spark_context
+ jc = getattr(sc._jvm.mathfunctions, name)(col._jc if isinstance(col, Column) else col)
+ return Column(jc)
+ _.__name__ = name
+ _.__doc__ = doc
+ return _
+
+
+def _create_binary_mathfunction(name, doc=""):
+ """ Create a binary mathfunction by name"""
+ def _(col1, col2):
+ sc = SparkContext._active_spark_context
+ # users might write ints for simplicity. This would throw an error on the JVM side.
+ if type(col1) is int:
+ col1 = col1 * 1.0
+ if type(col2) is int:
+ col2 = col2 * 1.0
+ jc = getattr(sc._jvm.mathfunctions, name)(col1._jc if isinstance(col1, Column) else col1,
+ col2._jc if isinstance(col2, Column) else col2)
+ return Column(jc)
+ _.__name__ = name
+ _.__doc__ = doc
+ return _
+
+
+# math functions are found under another object therefore, they need to be handled separately
+_mathfunctions = {
+ 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' +
+ '0.0 through pi.',
+ 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' +
+ '-pi/2 through pi/2.',
+ 'atan': 'Computes the tangent inverse of the given value.',
+ 'cbrt': 'Computes the cube-root of the given value.',
+ 'ceil': 'Computes the ceiling of the given value.',
+ 'cos': 'Computes the cosine of the given value.',
+ 'cosh': 'Computes the hyperbolic cosine of the given value.',
+ 'exp': 'Computes the exponential of the given value.',
+ 'expm1': 'Computes the exponential of the given value minus one.',
+ 'floor': 'Computes the floor of the given value.',
+ 'log': 'Computes the natural logarithm of the given value.',
+ 'log10': 'Computes the logarithm of the given value in Base 10.',
+ 'log1p': 'Computes the natural logarithm of the given value plus one.',
+ 'rint': 'Returns the double value that is closest in value to the argument and' +
+ ' is equal to a mathematical integer.',
+ 'signum': 'Computes the signum of the given value.',
+ 'sin': 'Computes the sine of the given value.',
+ 'sinh': 'Computes the hyperbolic sine of the given value.',
+ 'tan': 'Computes the tangent of the given value.',
+ 'tanh': 'Computes the hyperbolic tangent of the given value.',
+ 'toDeg': 'Converts an angle measured in radians to an approximately equivalent angle ' +
+ 'measured in degrees.',
+ 'toRad': 'Converts an angle measured in degrees to an approximately equivalent angle ' +
+ 'measured in radians.'
+}
+
+# math functions that take two arguments as input
+_binary_mathfunctions = {
+ 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
+ 'polar coordinates (r, theta).',
+ 'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
+ 'pow': 'Returns the value of the first argument raised to the power of the second argument.'
+}
+
# Materialize the public wrappers: one module-level function per entry in
# the doc tables above, then expose every generated name (sorted) in __all__.
for _table, _factory in ((_mathfunctions, _create_unary_mathfunction),
                         (_binary_mathfunctions, _create_binary_mathfunction)):
    for _name, _doc in _table.items():
        globals()[_name] = _factory(_name, _doc)
    __all__ += _table.keys()
# Drop the loop temporaries so they are not importable module attributes.
del _table, _factory, _name, _doc
__all__.sort()
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index fe43c374f1..2ffd18ebd7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -387,6 +387,35 @@ class SQLTests(ReusedPySparkTestCase):
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
+ def test_math_functions(self):
+ df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
+ from pyspark.sql import mathfunctions as functions
+ import math
+
+ def get_values(l):
+ return [j[0] for j in l]
+
+ def assert_close(a, b):
+ c = get_values(b)
+ diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
+ return sum(diff) == len(a)
+ assert_close([math.cos(i) for i in range(10)],
+ df.select(functions.cos(df.a)).collect())
+ assert_close([math.cos(i) for i in range(10)],
+ df.select(functions.cos("a")).collect())
+ assert_close([math.sin(i) for i in range(10)],
+ df.select(functions.sin(df.a)).collect())
+ assert_close([math.sin(i) for i in range(10)],
+ df.select(functions.sin(df['a'])).collect())
+ assert_close([math.pow(i, 2 * i) for i in range(10)],
+ df.select(functions.pow(df.a, df.b)).collect())
+ assert_close([math.pow(i, 2) for i in range(10)],
+ df.select(functions.pow(df.a, 2)).collect())
+ assert_close([math.pow(i, 2) for i in range(10)],
+ df.select(functions.pow(df.a, 2.0)).collect())
+ assert_close([math.hypot(i, 2 * i) for i in range(10)],
+ df.select(functions.hypot(df.a, df.b)).collect())
+
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()