author    Davies Liu <davies@databricks.com>    2016-08-15 12:41:27 -0700
committer Josh Rosen <joshrosen@databricks.com> 2016-08-15 12:41:27 -0700
commit    fffb0c0d19a2444e7554dfe6b27de0c086112b17 (patch)
tree      dc2fc14a9820672633b61b6acdf4a3d76985caf1 /python
parent    5da6c4b24f512b63cd4e6ba7dd8968066a9396f5 (diff)
[SPARK-16700][PYSPARK][SQL] create DataFrame from dict/Row with schema
## What changes were proposed in this pull request?

In 2.0, we verify the data type of every row against the schema for safety, but at a performance cost; this PR makes that verification optional. Verification for StructType also does not support all of the types that schema inference supports (for example, dict); this PR fixes that so the two are consistent. For a Row object created from named arguments, the fields are sorted by name and so may not match the order of the provided schema; this PR fixes that by ignoring field order and matching fields by name in that case.

## How was this patch tested?

Added regression tests for these cases.

Author: Davies Liu <davies@databricks.com>

Closes #14469 from davies/py_dict.
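A quick illustrative sketch of the behavior this patch enables (not part of the commit; assumes an active SparkSession bound to `spark`, as in the doctests, with hypothetical column names and values):

    from pyspark.sql import Row
    from pyspark.sql.types import StructType, StringType, IntegerType

    schema = StructType().add("b", StringType()).add("a", IntegerType())

    # dicts are now accepted for StructType; missing keys become null
    df1 = spark.createDataFrame([{"a": 1}, {"b": "coffee"}], schema)

    # Rows built from named arguments are matched to the schema by field
    # name, so their sorted field order need not match the schema order
    df2 = spark.createDataFrame([Row(a=1, b="coffee")], schema)

    # per-row type verification can now be skipped for speed
    df3 = spark.createDataFrame([("coffee", 1)], schema, verifySchema=False)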
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/context.py |  8
-rw-r--r--  python/pyspark/sql/session.py | 29
-rw-r--r--  python/pyspark/sql/tests.py   | 16
-rw-r--r--  python/pyspark/sql/types.py   | 37
4 files changed, 62 insertions(+), 28 deletions(-)
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 4085f165f4..7482be8bda 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -215,7 +215,7 @@ class SQLContext(object):
@since(1.3)
@ignore_unicode_prefix
- def createDataFrame(self, data, schema=None, samplingRatio=None):
+ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
"""
Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
@@ -245,6 +245,7 @@ class SQLContext(object):
``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`.
We can also use ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`.
:param samplingRatio: the sample ratio of rows used for inferring
+ :param verifySchema: verify data types of every row against schema.
:return: :class:`DataFrame`
.. versionchanged:: 2.0
@@ -253,6 +254,9 @@ class SQLContext(object):
If it's not a :class:`pyspark.sql.types.StructType`, it will be wrapped into a
:class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
+ .. versionchanged:: 2.1
+ Added verifySchema.
+
>>> l = [('Alice', 1)]
>>> sqlContext.createDataFrame(l).collect()
[Row(_1=u'Alice', _2=1)]
@@ -300,7 +304,7 @@ class SQLContext(object):
...
Py4JJavaError: ...
"""
- return self.sparkSession.createDataFrame(data, schema, samplingRatio)
+ return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema)
@since(1.3)
def registerDataFrameAsTable(self, df, tableName):
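For reference, a minimal sketch (not from the commit; assumes an existing SparkContext `sc`) of how the SQLContext entry point simply forwards the new flag to the session:

    from pyspark.sql import SQLContext
    from pyspark.sql.types import StructType, IntegerType

    sqlContext = SQLContext(sc)
    schema = StructType().add("a", IntegerType())
    # verifySchema is passed straight through to SparkSession.createDataFrame
    df = sqlContext.createDataFrame([{"a": 1}], schema, verifySchema=False)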
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 2dacf483fc..61fa107497 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -384,17 +384,15 @@ class SparkSession(object):
if schema is None or isinstance(schema, (list, tuple)):
struct = self._inferSchemaFromList(data)
+ converter = _create_converter(struct)
+ data = map(converter, data)
if isinstance(schema, (list, tuple)):
for i, name in enumerate(schema):
struct.fields[i].name = name
struct.names[i] = name
schema = struct
- elif isinstance(schema, StructType):
- for row in data:
- _verify_type(row, schema)
-
- else:
+ elif not isinstance(schema, StructType):
raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
# convert python objects to sql data
@@ -403,7 +401,7 @@ class SparkSession(object):
@since(2.0)
@ignore_unicode_prefix
- def createDataFrame(self, data, schema=None, samplingRatio=None):
+ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
"""
Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
@@ -432,13 +430,11 @@ class SparkSession(object):
``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
``int`` as a short name for ``IntegerType``.
:param samplingRatio: the sample ratio of rows used for inferring
+ :param verifySchema: verify data types of every row against schema.
:return: :class:`DataFrame`
- .. versionchanged:: 2.0
- The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a
- datatype string after 2.0. If it's not a
- :class:`pyspark.sql.types.StructType`, it will be wrapped into a
- :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
+ .. versionchanged:: 2.1
+ Added verifySchema.
>>> l = [('Alice', 1)]
>>> spark.createDataFrame(l).collect()
@@ -503,17 +499,18 @@ class SparkSession(object):
schema = [str(x) for x in data.columns]
data = [r.tolist() for r in data.to_records(index=False)]
+ verify_func = _verify_type if verifySchema else lambda _, t: True
if isinstance(schema, StructType):
def prepare(obj):
- _verify_type(obj, schema)
+ verify_func(obj, schema)
return obj
elif isinstance(schema, DataType):
- datatype = schema
+ dataType = schema
+ schema = StructType().add("value", schema)
def prepare(obj):
- _verify_type(obj, datatype)
- return (obj, )
- schema = StructType().add("value", datatype)
+ verify_func(obj, dataType)
+ return obj,
else:
if isinstance(schema, list):
schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
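The hunk above picks the verifier once, outside the per-row path: the real `_verify_type` when `verifySchema` is true, otherwise a do-nothing lambda. A standalone sketch of that pattern (names here are illustrative, not the Spark internals):

    def build_prepare(schema, verify_schema=True):
        from pyspark.sql.types import _verify_type  # private helper in 2.x

        # choose the verifier up front, so the choice is not re-made per row
        verify = _verify_type if verify_schema else (lambda _, t: True)

        def prepare(row):
            verify(row, schema)
            return row

        return prepare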
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 87dbb50495..520b09d9c6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -411,6 +411,22 @@ class SQLTests(ReusedPySparkTestCase):
df3 = self.spark.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
+ def test_apply_schema_to_dict_and_rows(self):
+ schema = StructType().add("b", StringType()).add("a", IntegerType())
+ input = [{"a": 1}, {"b": "coffee"}]
+ rdd = self.sc.parallelize(input)
+ for verify in [False, True]:
+ df = self.spark.createDataFrame(input, schema, verifySchema=verify)
+ df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+ self.assertEqual(df.schema, df2.schema)
+
+ rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+ df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+ self.assertEqual(10, df3.count())
+ input = [Row(a=x, b=str(x)) for x in range(10)]
+ df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
+ self.assertEqual(10, df4.count())
+
def test_create_dataframe_schema_mismatch(self):
input = [Row(a=1)]
rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 1ca4bbc379..b765472d6e 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -582,6 +582,8 @@ class StructType(DataType):
else:
if isinstance(obj, dict):
return tuple(obj.get(n) for n in self.names)
+ elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+ return tuple(obj[n] for n in self.names)
elif isinstance(obj, (list, tuple)):
return tuple(obj)
elif hasattr(obj, "__dict__"):
@@ -1243,7 +1245,7 @@ _acceptable_types = {
TimestampType: (datetime.datetime,),
ArrayType: (list, tuple, array),
MapType: (dict,),
- StructType: (tuple, list),
+ StructType: (tuple, list, dict),
}
@@ -1314,10 +1316,10 @@ def _verify_type(obj, dataType, nullable=True):
assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj)
if _type is StructType:
- if not isinstance(obj, (tuple, list)):
- raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
+ # check the type and fields later
+ pass
else:
- # subclass of them can not be fromInternald in JVM
+ # subclass of them can not be fromInternal in JVM
if type(obj) not in _acceptable_types[_type]:
raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
@@ -1343,11 +1345,25 @@ def _verify_type(obj, dataType, nullable=True):
_verify_type(v, dataType.valueType, dataType.valueContainsNull)
elif isinstance(dataType, StructType):
- if len(obj) != len(dataType.fields):
- raise ValueError("Length of object (%d) does not match with "
- "length of fields (%d)" % (len(obj), len(dataType.fields)))
- for v, f in zip(obj, dataType.fields):
- _verify_type(v, f.dataType, f.nullable)
+ if isinstance(obj, dict):
+ for f in dataType.fields:
+ _verify_type(obj.get(f.name), f.dataType, f.nullable)
+ elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+ # the order in obj could be different than dataType.fields
+ for f in dataType.fields:
+ _verify_type(obj[f.name], f.dataType, f.nullable)
+ elif isinstance(obj, (tuple, list)):
+ if len(obj) != len(dataType.fields):
+ raise ValueError("Length of object (%d) does not match with "
+ "length of fields (%d)" % (len(obj), len(dataType.fields)))
+ for v, f in zip(obj, dataType.fields):
+ _verify_type(v, f.dataType, f.nullable)
+ elif hasattr(obj, "__dict__"):
+ d = obj.__dict__
+ for f in dataType.fields:
+ _verify_type(d.get(f.name), f.dataType, f.nullable)
+ else:
+ raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
# This is used to unpickle a Row from JVM
@@ -1410,6 +1426,7 @@ class Row(tuple):
names = sorted(kwargs.keys())
row = tuple.__new__(self, [kwargs[n] for n in names])
row.__fields__ = names
+ row.__from_dict__ = True
return row
else:
@@ -1485,7 +1502,7 @@ class Row(tuple):
raise AttributeError(item)
def __setattr__(self, key, value):
- if key != '__fields__':
+ if key != '__fields__' and key != "__from_dict__":
raise Exception("Row is read-only")
self.__dict__[key] = value
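For context on the `__from_dict__` flag: a Row built from keyword arguments stores its fields sorted by name, so positional matching against a schema declared in a different order would pair values with the wrong columns; the patched `_verify_type` therefore matches such Rows by field name. A quick illustration of the standard pyspark.sql.Row behavior:

    from pyspark.sql import Row

    r = Row(b="coffee", a=1)
    print(r)             # Row(a=1, b='coffee') -- fields sorted by name
    print(r.__fields__)  # ['a', 'b']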