aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
author: Burak Yavuz <brkyvz@gmail.com> 2015-04-29 00:09:24 -0700
committer: Reynold Xin <rxin@databricks.com> 2015-04-29 00:09:24 -0700
commit: fe917f5ec9be8c8424416f7b5423ddb4318e03a0 (patch)
tree: f92902ee2d78db76c98a40313a971d497134b69f /python/pyspark/sql/tests.py
parent: 8dee2746b5857bb4c32b96f38840dd1b63574ab2 (diff)
download: spark-fe917f5ec9be8c8424416f7b5423ddb4318e03a0.tar.gz
spark-fe917f5ec9be8c8424416f7b5423ddb4318e03a0.tar.bz2
spark-fe917f5ec9be8c8424416f7b5423ddb4318e03a0.zip
[SPARK-7188] added python support for math DataFrame functions
Adds support for the math functions for DataFrames in PySpark.

rxin I love Davies.

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #5750 from brkyvz/python-math-udfs and squashes the following commits:

7c4f563 [Burak Yavuz] removed is_math
3c4adde [Burak Yavuz] cleanup imports
d5dca3f [Burak Yavuz] moved math functions to mathfunctions
25e6534 [Burak Yavuz] addressed comments v2.0
d3f7e0f [Burak Yavuz] addressed comments and added tests
7b7d7c4 [Burak Yavuz] remove tests for removed methods
33c2c15 [Burak Yavuz] fixed python style
3ee0c05 [Burak Yavuz] added python functions
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- python/pyspark/sql/tests.py 29
1 files changed, 29 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index fe43c374f1..2ffd18ebd7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -387,6 +387,35 @@ class SQLTests(ReusedPySparkTestCase):
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
+ def test_math_functions(self):
+ df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
+ from pyspark.sql import mathfunctions as functions
+ import math
+
+ def get_values(l):
+ return [j[0] for j in l]
+
+ def assert_close(a, b):
+ c = get_values(b)
+ diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
+ return sum(diff) == len(a)
+ assert_close([math.cos(i) for i in range(10)],
+ df.select(functions.cos(df.a)).collect())
+ assert_close([math.cos(i) for i in range(10)],
+ df.select(functions.cos("a")).collect())
+ assert_close([math.sin(i) for i in range(10)],
+ df.select(functions.sin(df.a)).collect())
+ assert_close([math.sin(i) for i in range(10)],
+ df.select(functions.sin(df['a'])).collect())
+ assert_close([math.pow(i, 2 * i) for i in range(10)],
+ df.select(functions.pow(df.a, df.b)).collect())
+ assert_close([math.pow(i, 2) for i in range(10)],
+ df.select(functions.pow(df.a, 2)).collect())
+ assert_close([math.pow(i, 2) for i in range(10)],
+ df.select(functions.pow(df.a, 2.0)).collect())
+ assert_close([math.hypot(i, 2 * i) for i in range(10)],
+ df.select(functions.hypot(df.a, df.b)).collect())
+
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()