aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/context.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/context.py')
-rw-r--r--python/pyspark/sql/context.py64
1 files changed, 47 insertions, 17 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index dc239226e6..4dda3b430c 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -203,7 +203,37 @@ class SQLContext(object):
self._sc._javaAccumulator,
returnType.json())
+ def _inferSchemaFromList(self, data):
+ """
+ Infer schema from list of Row or tuple.
+
+ :param data: list of Row or tuple
+ :return: StructType
+ """
+ if not data:
+ raise ValueError("can not infer schema from empty dataset")
+ first = data[0]
+ if type(first) is dict:
+ warnings.warn("inferring schema from dict is deprecated,"
+ "please use pyspark.sql.Row instead")
+ schema = _infer_schema(first)
+ if _has_nulltype(schema):
+ for r in data:
+ schema = _merge_type(schema, _infer_schema(r))
+ if not _has_nulltype(schema):
+ break
+ else:
+ raise ValueError("Some of types cannot be determined after inferring")
+ return schema
+
def _inferSchema(self, rdd, samplingRatio=None):
+ """
+ Infer schema from an RDD of Row or tuple.
+
+ :param rdd: an RDD of Row or tuple
+ :param samplingRatio: sampling ratio, or no sampling (default)
+ :return: StructType
+ """
first = rdd.first()
if not first:
raise ValueError("The first row in RDD is empty, "
@@ -322,6 +352,8 @@ class SQLContext(object):
data = [r.tolist() for r in data.to_records(index=False)]
if not isinstance(data, RDD):
+ if not isinstance(data, list):
+ data = list(data)
try:
# data could be list, tuple, generator ...
rdd = self._sc.parallelize(data)
@@ -330,28 +362,26 @@ class SQLContext(object):
else:
rdd = data
- if schema is None:
- schema = self._inferSchema(rdd, samplingRatio)
+ if schema is None or isinstance(schema, (list, tuple)):
+ if isinstance(data, RDD):
+ struct = self._inferSchema(rdd, samplingRatio)
+ else:
+ struct = self._inferSchemaFromList(data)
+ if isinstance(schema, (list, tuple)):
+ for i, name in enumerate(schema):
+ struct.fields[i].name = name
+ schema = struct
converter = _create_converter(schema)
rdd = rdd.map(converter)
- if isinstance(schema, (list, tuple)):
- first = rdd.first()
- if not isinstance(first, (list, tuple)):
- raise TypeError("each row in `rdd` should be list or tuple, "
- "but got %r" % type(first))
- row_cls = Row(*schema)
- schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
-
- # take the first few rows to verify schema
- rows = rdd.take(10)
- # Row() cannot been deserialized by Pyrolite
- if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
- rdd = rdd.map(tuple)
+ elif isinstance(schema, StructType):
+ # take the first few rows to verify schema
rows = rdd.take(10)
+ for row in rows:
+ _verify_type(row, schema)
- for row in rows:
- _verify_type(row, schema)
+ else:
+ raise TypeError("schema should be StructType or list or None")
# convert python objects to sql data
converter = _python_to_sql_converter(schema)