Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/context.py   8
-rw-r--r--  python/pyspark/sql/session.py  29
-rw-r--r--  python/pyspark/sql/tests.py    16
-rw-r--r--  python/pyspark/sql/types.py    37
4 files changed, 62 insertions, 28 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 4085f165f4..7482be8bda 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -215,7 +215,7 @@ class SQLContext(object):
@since(1.3)
@ignore_unicode_prefix
- def createDataFrame(self, data, schema=None, samplingRatio=None):
+ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
"""
Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
@@ -245,6 +245,7 @@ class SQLContext(object):
``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`.
We can also use ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`.
:param samplingRatio: the sample ratio of rows used for inferring
+ :param verifySchema: verify data types of every row against schema.
:return: :class:`DataFrame`
.. versionchanged:: 2.0
@@ -253,6 +254,9 @@ class SQLContext(object):
If it's not a :class:`pyspark.sql.types.StructType`, it will be wrapped into a
:class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
+ .. versionchanged:: 2.1
+ Added verifySchema.
+
>>> l = [('Alice', 1)]
>>> sqlContext.createDataFrame(l).collect()
[Row(_1=u'Alice', _2=1)]
@@ -300,7 +304,7 @@ class SQLContext(object):
...
Py4JJavaError: ...
"""
- return self.sparkSession.createDataFrame(data, schema, samplingRatio)
+ return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema)
@since(1.3)
def registerDataFrameAsTable(self, df, tableName):
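For reference, a minimal sketch of how the new verifySchema flag is used from SQLContext, assuming a live SparkContext named sc (as in the pyspark shell); the schema and values below are illustrative, not part of this patch:

    from pyspark.sql import SQLContext
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType

    sqlContext = SQLContext(sc)  # sc: an existing SparkContext
    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), True)])

    # Default behaviour: every row is verified against the schema up front.
    ok = sqlContext.createDataFrame([("Alice", 1)], schema)

    # verifySchema=False skips the per-row check; a mismatched record such as
    # ("Alice", "1") is not rejected here, and any error only surfaces later,
    # during conversion or on the JVM side.
    loose = sqlContext.createDataFrame([("Alice", "1")], schema, verifySchema=False)
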
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 2dacf483fc..61fa107497 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -384,17 +384,15 @@ class SparkSession(object):
if schema is None or isinstance(schema, (list, tuple)):
struct = self._inferSchemaFromList(data)
+ converter = _create_converter(struct)
+ data = map(converter, data)
if isinstance(schema, (list, tuple)):
for i, name in enumerate(schema):
struct.fields[i].name = name
struct.names[i] = name
schema = struct
- elif isinstance(schema, StructType):
- for row in data:
- _verify_type(row, schema)
-
- else:
+ elif not isinstance(schema, StructType):
raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
# convert python objects to sql data
@@ -403,7 +401,7 @@ class SparkSession(object):
@since(2.0)
@ignore_unicode_prefix
- def createDataFrame(self, data, schema=None, samplingRatio=None):
+ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
"""
Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
@@ -432,13 +430,11 @@ class SparkSession(object):
``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
``int`` as a short name for ``IntegerType``.
:param samplingRatio: the sample ratio of rows used for inferring
+ :param verifySchema: verify data types of every row against schema.
:return: :class:`DataFrame`
- .. versionchanged:: 2.0
- The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a
- datatype string after 2.0. If it's not a
- :class:`pyspark.sql.types.StructType`, it will be wrapped into a
- :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
+ .. versionchanged:: 2.1
+ Added verifySchema.
>>> l = [('Alice', 1)]
>>> spark.createDataFrame(l).collect()
@@ -503,17 +499,18 @@ class SparkSession(object):
schema = [str(x) for x in data.columns]
data = [r.tolist() for r in data.to_records(index=False)]
+ verify_func = _verify_type if verifySchema else lambda _, t: True
if isinstance(schema, StructType):
def prepare(obj):
- _verify_type(obj, schema)
+ verify_func(obj, schema)
return obj
elif isinstance(schema, DataType):
- datatype = schema
+ dataType = schema
+ schema = StructType().add("value", schema)
def prepare(obj):
- _verify_type(obj, datatype)
- return (obj, )
- schema = StructType().add("value", datatype)
+ verify_func(obj, dataType)
+ return obj,
else:
if isinstance(schema, list):
schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
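A side note on the DataType branch above: when schema is a plain DataType rather than a StructType, it is wrapped into a single-field struct named "value" and each record is wrapped into a one-element tuple by prepare(). A minimal sketch, assuming an active SparkSession named spark:

    from pyspark.sql.types import IntegerType

    # The DataType is wrapped into StructType().add("value", ...), each record
    # into a one-element tuple.
    df = spark.createDataFrame([1, 2, 3], IntegerType())
    df.printSchema()
    # root
    #  |-- value: integer (nullable = true)

    # With the default verifySchema=True, a bad record is rejected up front:
    #   spark.createDataFrame(["not an int"], IntegerType())   # raises TypeError
    # With verifySchema=False, verify_func is a no-op lambda, so the same record
    # is accepted here and any failure only shows up when the data is used.
    df2 = spark.createDataFrame(["not an int"], IntegerType(), verifySchema=False)
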
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 87dbb50495..520b09d9c6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -411,6 +411,22 @@ class SQLTests(ReusedPySparkTestCase):
df3 = self.spark.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
+ def test_apply_schema_to_dict_and_rows(self):
+ schema = StructType().add("b", StringType()).add("a", IntegerType())
+ input = [{"a": 1}, {"b": "coffee"}]
+ rdd = self.sc.parallelize(input)
+ for verify in [False, True]:
+ df = self.spark.createDataFrame(input, schema, verifySchema=verify)
+ df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+ self.assertEqual(df.schema, df2.schema)
+
+ rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+ df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+ self.assertEqual(10, df3.count())
+ input = [Row(a=x, b=str(x)) for x in range(10)]
+ df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
+ self.assertEqual(10, df4.count())
+
def test_create_dataframe_schema_mismatch(self):
input = [Row(a=1)]
rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
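The new test exercises name-based matching of dict and kwargs-Row records against a schema. A standalone sketch of the same behaviour, assuming an active SparkSession named spark (data invented for illustration):

    from pyspark.sql import Row
    from pyspark.sql.types import StructType, StringType, IntegerType

    schema = StructType().add("b", StringType()).add("a", IntegerType())

    # dict records are matched to fields by name; missing keys become null
    # (both fields here are nullable).
    spark.createDataFrame([{"a": 1}, {"b": "coffee"}], schema).show()

    # Rows built from kwargs sort their fields alphabetically, but the new
    # __from_dict__ marker lets them be matched by name as well.
    spark.createDataFrame([Row(a=1, b="coffee")], schema).show()
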
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 1ca4bbc379..b765472d6e 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -582,6 +582,8 @@ class StructType(DataType):
else:
if isinstance(obj, dict):
return tuple(obj.get(n) for n in self.names)
+ elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+ return tuple(obj[n] for n in self.names)
elif isinstance(obj, (list, tuple)):
return tuple(obj)
elif hasattr(obj, "__dict__"):
@@ -1243,7 +1245,7 @@ _acceptable_types = {
TimestampType: (datetime.datetime,),
ArrayType: (list, tuple, array),
MapType: (dict,),
- StructType: (tuple, list),
+ StructType: (tuple, list, dict),
}
@@ -1314,10 +1316,10 @@ def _verify_type(obj, dataType, nullable=True):
assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj)
if _type is StructType:
- if not isinstance(obj, (tuple, list)):
- raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
+ # check the type and fields later
+ pass
else:
- # subclass of them can not be fromInternald in JVM
+ # subclass of them can not be fromInternal in JVM
if type(obj) not in _acceptable_types[_type]:
raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
@@ -1343,11 +1345,25 @@ def _verify_type(obj, dataType, nullable=True):
_verify_type(v, dataType.valueType, dataType.valueContainsNull)
elif isinstance(dataType, StructType):
- if len(obj) != len(dataType.fields):
- raise ValueError("Length of object (%d) does not match with "
- "length of fields (%d)" % (len(obj), len(dataType.fields)))
- for v, f in zip(obj, dataType.fields):
- _verify_type(v, f.dataType, f.nullable)
+ if isinstance(obj, dict):
+ for f in dataType.fields:
+ _verify_type(obj.get(f.name), f.dataType, f.nullable)
+ elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+ # the order in obj could be different than dataType.fields
+ for f in dataType.fields:
+ _verify_type(obj[f.name], f.dataType, f.nullable)
+ elif isinstance(obj, (tuple, list)):
+ if len(obj) != len(dataType.fields):
+ raise ValueError("Length of object (%d) does not match with "
+ "length of fields (%d)" % (len(obj), len(dataType.fields)))
+ for v, f in zip(obj, dataType.fields):
+ _verify_type(v, f.dataType, f.nullable)
+ elif hasattr(obj, "__dict__"):
+ d = obj.__dict__
+ for f in dataType.fields:
+ _verify_type(d.get(f.name), f.dataType, f.nullable)
+ else:
+ raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
# This is used to unpickle a Row from JVM
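For illustration only, roughly what the reworked check accepts and rejects when the private helper _verify_type is called directly (a sketch against a build that includes this patch):

    from pyspark.sql import Row
    from pyspark.sql.types import StructType, StringType, IntegerType, _verify_type

    schema = StructType().add("b", StringType()).add("a", IntegerType())

    _verify_type({"a": 1}, schema)          # ok: dict matched by name, "b" is None/null
    _verify_type(Row(a=1, b="x"), schema)   # ok: kwargs Row matched by name
    _verify_type(("x", 1), schema)          # ok: tuple/list matched by position

    try:
        _verify_type(("x",), schema)        # positional matching still requires full length
    except ValueError as e:
        print(e)                            # Length of object (1) does not match ...

    try:
        _verify_type(1, schema)             # neither dict, Row, sequence, nor object
    except TypeError as e:
        print(e)                            # StructType can not accept object 1 ...
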
@@ -1410,6 +1426,7 @@ class Row(tuple):
names = sorted(kwargs.keys())
row = tuple.__new__(self, [kwargs[n] for n in names])
row.__fields__ = names
+ row.__from_dict__ = True
return row
else:
@@ -1485,7 +1502,7 @@ class Row(tuple):
raise AttributeError(item)
def __setattr__(self, key, value):
- if key != '__fields__':
+ if key != '__fields__' and key != "__from_dict__":
raise Exception("Row is read-only")
self.__dict__[key] = value
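Finally, the Row change in context: kwargs-built Rows sort their fields alphabetically, which is why name-based matching is needed at all. A quick illustration on a build that includes this patch:

    from pyspark.sql import Row

    r = Row(b="coffee", a=1)
    print(r)                                   # Row(a=1, b='coffee') -- fields sorted by name
    print(r.__fields__)                        # ['a', 'b']
    # The new marker tells _verify_type and StructType.toInternal to match this
    # Row against a schema by field name instead of by position.
    print(getattr(r, "__from_dict__", False))  # True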