aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-06-29 13:20:55 -0700
committerReynold Xin <rxin@databricks.com>2015-06-29 13:20:55 -0700
commitafae9766f28d2e58297405c39862d20a04267b62 (patch)
treea05a7678489832bec046703367ac3446cb31c4f5 /python/pyspark
parentbe7ef067620408859144e0244b0f1b8eb56faa86 (diff)
downloadspark-afae9766f28d2e58297405c39862d20a04267b62.tar.gz
spark-afae9766f28d2e58297405c39862d20a04267b62.tar.bz2
spark-afae9766f28d2e58297405c39862d20a04267b62.zip
[SPARK-8070] [SQL] [PYSPARK] avoid spark jobs in createDataFrame
Avoid the unnecessary jobs when infer schema from list. cc yhuai mengxr Author: Davies Liu <davies@databricks.com> Closes #6606 from davies/improve_create and squashes the following commits: a5928bf [Davies Liu] Update MimaExcludes.scala 62da911 [Davies Liu] fix mima bab4d7d [Davies Liu] Merge branch 'improve_create' of github.com:davies/spark into improve_create eee44a8 [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create 8d9292d [Davies Liu] Update context.py eb24531 [Davies Liu] Update context.py c969997 [Davies Liu] bug fix d5a8ab0 [Davies Liu] fix tests 8c3f10d [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create 6ea5925 [Davies Liu] address comments 6ceaeff [Davies Liu] avoid spark jobs in createDataFrame
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/sql/context.py64
-rw-r--r--python/pyspark/sql/types.py48
2 files changed, 75 insertions, 37 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index dc239226e6..4dda3b430c 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -203,7 +203,37 @@ class SQLContext(object):
self._sc._javaAccumulator,
returnType.json())
+ def _inferSchemaFromList(self, data):
+ """
+ Infer schema from list of Row or tuple.
+
+ :param data: list of Row or tuple
+ :return: StructType
+ """
+ if not data:
+ raise ValueError("can not infer schema from empty dataset")
+ first = data[0]
+ if type(first) is dict:
+ warnings.warn("inferring schema from dict is deprecated,"
+ "please use pyspark.sql.Row instead")
+ schema = _infer_schema(first)
+ if _has_nulltype(schema):
+ for r in data:
+ schema = _merge_type(schema, _infer_schema(r))
+ if not _has_nulltype(schema):
+ break
+ else:
+ raise ValueError("Some of types cannot be determined after inferring")
+ return schema
+
def _inferSchema(self, rdd, samplingRatio=None):
+ """
+ Infer schema from an RDD of Row or tuple.
+
+ :param rdd: an RDD of Row or tuple
+ :param samplingRatio: sampling ratio, or no sampling (default)
+ :return: StructType
+ """
first = rdd.first()
if not first:
raise ValueError("The first row in RDD is empty, "
@@ -322,6 +352,8 @@ class SQLContext(object):
data = [r.tolist() for r in data.to_records(index=False)]
if not isinstance(data, RDD):
+ if not isinstance(data, list):
+ data = list(data)
try:
# data could be list, tuple, generator ...
rdd = self._sc.parallelize(data)
@@ -330,28 +362,26 @@ class SQLContext(object):
else:
rdd = data
- if schema is None:
- schema = self._inferSchema(rdd, samplingRatio)
+ if schema is None or isinstance(schema, (list, tuple)):
+ if isinstance(data, RDD):
+ struct = self._inferSchema(rdd, samplingRatio)
+ else:
+ struct = self._inferSchemaFromList(data)
+ if isinstance(schema, (list, tuple)):
+ for i, name in enumerate(schema):
+ struct.fields[i].name = name
+ schema = struct
converter = _create_converter(schema)
rdd = rdd.map(converter)
- if isinstance(schema, (list, tuple)):
- first = rdd.first()
- if not isinstance(first, (list, tuple)):
- raise TypeError("each row in `rdd` should be list or tuple, "
- "but got %r" % type(first))
- row_cls = Row(*schema)
- schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
-
- # take the first few rows to verify schema
- rows = rdd.take(10)
- # Row() cannot been deserialized by Pyrolite
- if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
- rdd = rdd.map(tuple)
+ elif isinstance(schema, StructType):
+ # take the first few rows to verify schema
rows = rdd.take(10)
+ for row in rows:
+ _verify_type(row, schema)
- for row in rows:
- _verify_type(row, schema)
+ else:
+ raise TypeError("schema should be StructType or list or None")
# convert python objects to sql data
converter = _python_to_sql_converter(schema)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 23d9adb0da..932686e5e4 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -635,7 +635,7 @@ def _need_python_to_sql_conversion(dataType):
>>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
... StructField("values", ArrayType(DoubleType(), False), False)])
>>> _need_python_to_sql_conversion(schema0)
- False
+ True
>>> _need_python_to_sql_conversion(ExamplePointUDT())
True
>>> schema1 = ArrayType(ExamplePointUDT(), False)
@@ -647,7 +647,8 @@ def _need_python_to_sql_conversion(dataType):
True
"""
if isinstance(dataType, StructType):
- return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
+ # convert namedtuple or Row into tuple
+ return True
elif isinstance(dataType, ArrayType):
return _need_python_to_sql_conversion(dataType.elementType)
elif isinstance(dataType, MapType):
@@ -688,21 +689,25 @@ def _python_to_sql_converter(dataType):
if isinstance(dataType, StructType):
names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
- converters = [_python_to_sql_converter(t) for t in types]
-
- def converter(obj):
- if isinstance(obj, dict):
- return tuple(c(obj.get(n)) for n, c in zip(names, converters))
- elif isinstance(obj, tuple):
- if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
- return tuple(c(v) for c, v in zip(converters, obj))
- elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
- d = dict(obj)
- return tuple(c(d.get(n)) for n, c in zip(names, converters))
+ if any(_need_python_to_sql_conversion(t) for t in types):
+ converters = [_python_to_sql_converter(t) for t in types]
+
+ def converter(obj):
+ if isinstance(obj, dict):
+ return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+ elif isinstance(obj, tuple):
+ if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
+ return tuple(c(v) for c, v in zip(converters, obj))
+ else:
+ return tuple(c(v) for c, v in zip(converters, obj))
+ elif obj is not None:
+ raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+ else:
+ def converter(obj):
+ if isinstance(obj, dict):
+ return tuple(obj.get(n) for n in names)
else:
- return tuple(c(v) for c, v in zip(converters, obj))
- elif obj is not None:
- raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+ return tuple(obj)
return converter
elif isinstance(dataType, ArrayType):
element_converter = _python_to_sql_converter(dataType.elementType)
@@ -1027,10 +1032,13 @@ def _verify_type(obj, dataType):
_type = type(dataType)
assert _type in _acceptable_types, "unknown datatype: %s" % dataType
- # subclass of them can not be deserialized in JVM
- if type(obj) not in _acceptable_types[_type]:
- raise TypeError("%s can not accept object in type %s"
- % (dataType, type(obj)))
+ if _type is StructType:
+ if not isinstance(obj, (tuple, list)):
+ raise TypeError("StructType can not accept object in type %s" % type(obj))
+ else:
+ # subclass of them can not be deserialized in JVM
+ if type(obj) not in _acceptable_types[_type]:
+ raise TypeError("%s can not accept object in type %s" % (dataType, type(obj)))
if isinstance(dataType, ArrayType):
for i in obj: