diff options
author | Davies Liu <davies@databricks.com> | 2015-06-29 13:20:55 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-06-29 13:20:55 -0700 |
commit | afae9766f28d2e58297405c39862d20a04267b62 (patch) | |
tree | a05a7678489832bec046703367ac3446cb31c4f5 /python/pyspark/sql/context.py | |
parent | be7ef067620408859144e0244b0f1b8eb56faa86 (diff) | |
download | spark-afae9766f28d2e58297405c39862d20a04267b62.tar.gz spark-afae9766f28d2e58297405c39862d20a04267b62.tar.bz2 spark-afae9766f28d2e58297405c39862d20a04267b62.zip |
[SPARK-8070] [SQL] [PYSPARK] avoid spark jobs in createDataFrame
Avoid the unnecessary Spark jobs when inferring the schema from a list.
cc yhuai mengxr
Author: Davies Liu <davies@databricks.com>
Closes #6606 from davies/improve_create and squashes the following commits:
a5928bf [Davies Liu] Update MimaExcludes.scala
62da911 [Davies Liu] fix mima
bab4d7d [Davies Liu] Merge branch 'improve_create' of github.com:davies/spark into improve_create
eee44a8 [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create
8d9292d [Davies Liu] Update context.py
eb24531 [Davies Liu] Update context.py
c969997 [Davies Liu] bug fix
d5a8ab0 [Davies Liu] fix tests
8c3f10d [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create
6ea5925 [Davies Liu] address comments
6ceaeff [Davies Liu] avoid spark jobs in createDataFrame
Diffstat (limited to 'python/pyspark/sql/context.py')
-rw-r--r-- | python/pyspark/sql/context.py | 64 |
1 files changed, 47 insertions, 17 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index dc239226e6..4dda3b430c 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -203,7 +203,37 @@ class SQLContext(object): self._sc._javaAccumulator, returnType.json()) + def _inferSchemaFromList(self, data): + """ + Infer schema from list of Row or tuple. + + :param data: list of Row or tuple + :return: StructType + """ + if not data: + raise ValueError("can not infer schema from empty dataset") + first = data[0] + if type(first) is dict: + warnings.warn("inferring schema from dict is deprecated," + "please use pyspark.sql.Row instead") + schema = _infer_schema(first) + if _has_nulltype(schema): + for r in data: + schema = _merge_type(schema, _infer_schema(r)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined after inferring") + return schema + def _inferSchema(self, rdd, samplingRatio=None): + """ + Infer schema from an RDD of Row or tuple. + + :param rdd: an RDD of Row or tuple + :param samplingRatio: sampling ratio, or no sampling (default) + :return: StructType + """ first = rdd.first() if not first: raise ValueError("The first row in RDD is empty, " @@ -322,6 +352,8 @@ class SQLContext(object): data = [r.tolist() for r in data.to_records(index=False)] if not isinstance(data, RDD): + if not isinstance(data, list): + data = list(data) try: # data could be list, tuple, generator ... 
rdd = self._sc.parallelize(data) @@ -330,28 +362,26 @@ class SQLContext(object): else: rdd = data - if schema is None: - schema = self._inferSchema(rdd, samplingRatio) + if schema is None or isinstance(schema, (list, tuple)): + if isinstance(data, RDD): + struct = self._inferSchema(rdd, samplingRatio) + else: + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + schema = struct converter = _create_converter(schema) rdd = rdd.map(converter) - if isinstance(schema, (list, tuple)): - first = rdd.first() - if not isinstance(first, (list, tuple)): - raise TypeError("each row in `rdd` should be list or tuple, " - "but got %r" % type(first)) - row_cls = Row(*schema) - schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) + elif isinstance(schema, StructType): + # take the first few rows to verify schema rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) - for row in rows: - _verify_type(row, schema) + else: + raise TypeError("schema should be StructType or list or None") # convert python objects to sql data converter = _python_to_sql_converter(schema) |