aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql/context.py10
-rw-r--r--python/pyspark/sql/tests.py11
2 files changed, 14 insertions, 7 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index b05aa2f5c4..ba6915a123 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -18,6 +18,7 @@
import sys
import warnings
import json
+from functools import reduce
if sys.version >= '3':
basestring = unicode = str
@@ -236,14 +237,9 @@ class SQLContext(object):
if type(first) is dict:
warnings.warn("inferring schema from dict is deprecated,"
"please use pyspark.sql.Row instead")
- schema = _infer_schema(first)
+ schema = reduce(_merge_type, map(_infer_schema, data))
if _has_nulltype(schema):
- for r in data:
- schema = _merge_type(schema, _infer_schema(r))
- if not _has_nulltype(schema):
- break
- else:
- raise ValueError("Some of types cannot be determined after inferring")
+ raise ValueError("Some of types cannot be determined after inferring")
return schema
def _inferSchema(self, rdd, samplingRatio=None):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 9f5f7cfdf7..10b99175ad 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -353,6 +353,17 @@ class SQLTests(ReusedPySparkTestCase):
df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
+ def test_infer_schema_to_local(self):
+ input = [{"a": 1}, {"b": "coffee"}]
+ rdd = self.sc.parallelize(input)
+ df = self.sqlCtx.createDataFrame(input)
+ df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
+ self.assertEqual(df.schema, df2.schema)
+
+ rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
+ df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
+ self.assertEqual(10, df3.count())
+
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
rdd = self.sc.parallelize(d)