author    Davies Liu <davies@databricks.com>  2015-06-29 13:20:55 -0700
committer Reynold Xin <rxin@databricks.com>   2015-06-29 13:20:55 -0700
commit    afae9766f28d2e58297405c39862d20a04267b62 (patch)
tree      a05a7678489832bec046703367ac3446cb31c4f5 /python/pyspark/sql/context.py
parent    be7ef067620408859144e0244b0f1b8eb56faa86 (diff)
[SPARK-8070] [SQL] [PYSPARK] avoid spark jobs in createDataFrame
Avoid unnecessary jobs when inferring the schema from a list.

cc yhuai mengxr

Author: Davies Liu <davies@databricks.com>

Closes #6606 from davies/improve_create and squashes the following commits:

a5928bf [Davies Liu] Update MimaExcludes.scala
62da911 [Davies Liu] fix mima
bab4d7d [Davies Liu] Merge branch 'improve_create' of github.com:davies/spark into improve_create
eee44a8 [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create
8d9292d [Davies Liu] Update context.py
eb24531 [Davies Liu] Update context.py
c969997 [Davies Liu] bug fix
d5a8ab0 [Davies Liu] fix tests
8c3f10d [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create
6ea5925 [Davies Liu] address comments
6ceaeff [Davies Liu] avoid spark jobs in createDataFrame
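For context, a minimal sketch of the call this commit optimizes (not part of the patch; it assumes a running local SparkContext, and the names `sc` and `sqlContext` are illustrative). Before this change, createDataFrame on a plain Python list launched Spark jobs just to infer the schema; afterwards the schema is inferred driver-side from the list itself:

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local", "example")   # assumed; any running SparkContext works
    sqlContext = SQLContext(sc)

    # A local list of Rows: schema inference now happens on the driver,
    # without submitting a job to sample the data.
    df = sqlContext.createDataFrame([Row(name="Alice", age=1),
                                     Row(name="Bob", age=2)])
    df.printSchema()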
Diffstat (limited to 'python/pyspark/sql/context.py')
-rw-r--r--  python/pyspark/sql/context.py  64
1 file changed, 47 insertions(+), 17 deletions(-)
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index dc239226e6..4dda3b430c 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -203,7 +203,37 @@ class SQLContext(object):
                                             self._sc._javaAccumulator,
                                             returnType.json())
 
+    def _inferSchemaFromList(self, data):
+        """
+        Infer schema from list of Row or tuple.
+
+        :param data: list of Row or tuple
+        :return: StructType
+        """
+        if not data:
+            raise ValueError("can not infer schema from empty dataset")
+        first = data[0]
+        if type(first) is dict:
+            warnings.warn("inferring schema from dict is deprecated,"
+                          "please use pyspark.sql.Row instead")
+        schema = _infer_schema(first)
+        if _has_nulltype(schema):
+            for r in data:
+                schema = _merge_type(schema, _infer_schema(r))
+                if not _has_nulltype(schema):
+                    break
+            else:
+                raise ValueError("Some of types cannot be determined after inferring")
+        return schema
+
     def _inferSchema(self, rdd, samplingRatio=None):
+        """
+        Infer schema from an RDD of Row or tuple.
+
+        :param rdd: an RDD of Row or tuple
+        :param samplingRatio: sampling ratio, or no sampling (default)
+        :return: StructType
+        """
         first = rdd.first()
         if not first:
             raise ValueError("The first row in RDD is empty, "
@@ -322,6 +352,8 @@ class SQLContext(object):
             data = [r.tolist() for r in data.to_records(index=False)]
 
         if not isinstance(data, RDD):
+            if not isinstance(data, list):
+                data = list(data)
             try:
                 # data could be list, tuple, generator ...
                 rdd = self._sc.parallelize(data)
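The two added lines let non-list iterables take the driver-side inference path: a generator is materialized once into a list, so the same data can be both inspected for schema inference and parallelized. A small sketch (illustrative, with the same assumed `sqlContext`):

    # The generator is consumed once into a list; schema inference and
    # parallelize() both see the same data.
    rows = ((i, str(i)) for i in range(3))
    df = sqlContext.createDataFrame(rows, ["key", "value"])
    df.collect()   # three rows with columns `key` and `value`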
@@ -330,28 +362,26 @@ class SQLContext(object):
         else:
             rdd = data
 
-        if schema is None:
-            schema = self._inferSchema(rdd, samplingRatio)
+        if schema is None or isinstance(schema, (list, tuple)):
+            if isinstance(data, RDD):
+                struct = self._inferSchema(rdd, samplingRatio)
+            else:
+                struct = self._inferSchemaFromList(data)
+            if isinstance(schema, (list, tuple)):
+                for i, name in enumerate(schema):
+                    struct.fields[i].name = name
+            schema = struct
             converter = _create_converter(schema)
             rdd = rdd.map(converter)
 
-        if isinstance(schema, (list, tuple)):
-            first = rdd.first()
-            if not isinstance(first, (list, tuple)):
-                raise TypeError("each row in `rdd` should be list or tuple, "
-                                "but got %r" % type(first))
-            row_cls = Row(*schema)
-            schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
-
-        # take the first few rows to verify schema
-        rows = rdd.take(10)
-        # Row() cannot been deserialized by Pyrolite
-        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
-            rdd = rdd.map(tuple)
+        elif isinstance(schema, StructType):
+            # take the first few rows to verify schema
             rows = rdd.take(10)
+            for row in rows:
+                _verify_type(row, schema)
 
-        for row in rows:
-            _verify_type(row, schema)
+        else:
+            raise TypeError("schema should be StructType or list or None")
 
         # convert python objects to sql data
         converter = _python_to_sql_converter(schema)
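Taken together, the rewritten branch admits exactly three forms of `schema`. A brief sketch of each (not from the patch; same assumed `sqlContext`):

    from pyspark.sql.types import StructType, StructField, IntegerType, StringType

    data = [(1, "a"), (2, "b")]

    # schema=None: infer a StructType from the list itself.
    df1 = sqlContext.createDataFrame(data)

    # schema as a list of names: infer the types, then rename the fields.
    df2 = sqlContext.createDataFrame(data, ["id", "label"])

    # schema as a StructType: skip inference; only the first rows are verified.
    st = StructType([StructField("id", IntegerType(), True),
                     StructField("label", StringType(), True)])
    df3 = sqlContext.createDataFrame(data, st)

    # Anything else now raises TypeError("schema should be StructType or list or None").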