aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-02-13 23:03:22 -0800
committerReynold Xin <rxin@databricks.com>2015-02-13 23:03:22 -0800
commite98dfe627c5d0201464cdd0f363f391ea84c389a (patch)
tree794beea739eb04bf2e0926f9b0e19ffacb94ba08 /python/pyspark/sql/tests.py
parent0ce4e430a81532dc317136f968f28742e087d840 (diff)
downloadspark-e98dfe627c5d0201464cdd0f363f391ea84c389a.tar.gz
spark-e98dfe627c5d0201464cdd0f363f391ea84c389a.tar.bz2
spark-e98dfe627c5d0201464cdd0f363f391ea84c389a.zip
[SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
- The old implicit would convert RDDs directly to DataFrames, and that added too many methods. - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed Python changes: - toDataFrame -> toDF - Dsl -> functions package - addColumn -> withColumn - renameColumn -> withColumnRenamed - add toDF functions to RDD on SQLContext init - add flatMap to DataFrame Author: Reynold Xin <rxin@databricks.com> Author: Davies Liu <davies@databricks.com> Closes #4556 from rxin/SPARK-5752 and squashes the following commits: 5ef9910 [Reynold Xin] More fix 61d3fca [Reynold Xin] Merge branch 'df5' of github.com:davies/spark into SPARK-5752 ff5832c [Reynold Xin] Fix python 749c675 [Reynold Xin] count(*) fixes. 5806df0 [Reynold Xin] Fix build break again. d941f3d [Reynold Xin] Fixed explode compilation break. fe1267a [Davies Liu] flatMap c4afb8e [Reynold Xin] style d9de47f [Davies Liu] add comment b783994 [Davies Liu] add comment for toDF e2154e5 [Davies Liu] schema() -> schema 3a1004f [Davies Liu] Dsl -> functions, toDF() fb256af [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed 0dd74eb [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames 97dd47c [Davies Liu] fix mistake 6168f74 [Davies Liu] fix test 1fc0199 [Davies Liu] fix test a075cd5 [Davies Liu] clean up, toPandas 663d314 [Davies Liu] add test for agg('*') 9e214d5 [Reynold Xin] count(*) fixes. 1ed7136 [Reynold Xin] Fix build break again. 921b2e3 [Reynold Xin] Fixed explode compilation break. 14698d4 [Davies Liu] flatMap ba3e12d [Reynold Xin] style d08c92d [Davies Liu] add comment 5c8b524 [Davies Liu] add comment for toDF a4e5e66 [Davies Liu] schema() -> schema d377fc9 [Davies Liu] Dsl -> functions, toDF() 6b3086c [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed 807e8b1 [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--python/pyspark/sql/tests.py38
1 files changed, 17 insertions, 21 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 43e5c3a1b0..aa80bca346 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -96,7 +96,7 @@ class SQLTests(ReusedPySparkTestCase):
cls.sqlCtx = SQLContext(cls.sc)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
rdd = cls.sc.parallelize(cls.testData)
- cls.df = cls.sqlCtx.createDataFrame(rdd)
+ cls.df = rdd.toDF()
@classmethod
def tearDownClass(cls):
@@ -138,7 +138,7 @@ class SQLTests(ReusedPySparkTestCase):
df = self.sqlCtx.jsonRDD(rdd)
df.count()
df.collect()
- df.schema()
+ df.schema
# cache and checkpoint
self.assertFalse(df.is_cached)
@@ -155,11 +155,11 @@ class SQLTests(ReusedPySparkTestCase):
def test_apply_schema_to_row(self):
df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
- df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema())
+ df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema)
self.assertEqual(df.collect(), df2.collect())
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- df3 = self.sqlCtx.createDataFrame(rdd, df.schema())
+ df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
def test_serialize_nested_array_and_map(self):
@@ -195,7 +195,7 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(1, result.head()[0])
df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
- self.assertEqual(df.schema(), df2.schema())
+ self.assertEqual(df.schema, df2.schema)
self.assertEqual({}, df2.map(lambda r: r.d).first())
self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
df2.registerTempTable("test2")
@@ -204,8 +204,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
- rdd = self.sc.parallelize(d)
- df = self.sqlCtx.createDataFrame(rdd)
+ df = self.sc.parallelize(d).toDF()
k, v = df.head().m.items()[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
@@ -213,8 +212,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
- rdd = self.sc.parallelize([row])
- df = self.sqlCtx.createDataFrame(rdd)
+ df = self.sc.parallelize([row]).toDF()
df.registerTempTable("test")
row = self.sqlCtx.sql("select l, d from test").head()
self.assertEqual(1, row.asDict()["l"][0].a)
@@ -223,9 +221,8 @@ class SQLTests(ReusedPySparkTestCase):
def test_infer_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
- df = self.sqlCtx.createDataFrame(rdd)
- schema = df.schema()
+ df = self.sc.parallelize([row]).toDF()
+ schema = df.schema
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
df.registerTempTable("labeled_point")
@@ -238,15 +235,14 @@ class SQLTests(ReusedPySparkTestCase):
rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- df = self.sqlCtx.createDataFrame(rdd, schema)
+ df = rdd.toDF(schema)
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_parquet_with_udt(self):
from pyspark.sql.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
- df0 = self.sqlCtx.createDataFrame(rdd)
+ df0 = self.sc.parallelize([row]).toDF()
output_dir = os.path.join(self.tempdir.name, "labeled_point")
df0.saveAsParquetFile(output_dir)
df1 = self.sqlCtx.parquetFile(output_dir)
@@ -280,10 +276,11 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
- from pyspark.sql import Dsl
- self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first()))
- self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
- self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+ from pyspark.sql import functions
+ self.assertEqual((0, u'99'),
+ tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
+ self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
+ self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
def test_save_and_load(self):
df = self.df
@@ -339,8 +336,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
- rdd = cls.sc.parallelize(cls.testData)
- cls.df = cls.sqlCtx.inferSchema(rdd)
+ cls.df = cls.sc.parallelize(cls.testData).toDF()
@classmethod
def tearDownClass(cls):