Diffstat (limited to 'python/pyspark/sql/session.py')
-rw-r--r--  python/pyspark/sql/session.py | 29
1 file changed, 13 insertions(+), 16 deletions(-)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 2dacf483fc..61fa107497 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -384,17 +384,15 @@ class SparkSession(object):
 
         if schema is None or isinstance(schema, (list, tuple)):
             struct = self._inferSchemaFromList(data)
+            converter = _create_converter(struct)
+            data = map(converter, data)
             if isinstance(schema, (list, tuple)):
                 for i, name in enumerate(schema):
                     struct.fields[i].name = name
                     struct.names[i] = name
             schema = struct
 
-        elif isinstance(schema, StructType):
-            for row in data:
-                _verify_type(row, schema)
-
-        else:
+        elif not isinstance(schema, StructType):
             raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
 
         # convert python objects to sql data
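With this hunk, an inferred schema is paired with an eager `_create_converter` pass over the rows, and the old per-row `_verify_type` loop for a caller-supplied StructType disappears from this path; that check moves behind the new `verifySchema` flag in the hunks below. A minimal caller-side sketch of the inferred-schema branch, assuming a local SparkSession (app name and column names are illustrative):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").appName("sketch").getOrCreate()

    # schema passed as a list of names: field types are inferred from the
    # data, then the inferred struct's fields are renamed by the loop above
    df = spark.createDataFrame([("Alice", 1), ("Bob", 2)], ["name", "age"])
    df.printSchema()  # name: string, age: long (both inferred)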
@@ -403,7 +401,7 @@ class SparkSession(object):
 
     @since(2.0)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
 
@@ -432,13 +430,11 @@ class SparkSession(object):
             ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
             ``int`` as a short name for ``IntegerType``.
         :param samplingRatio: the sample ratio of rows used for inferring
+        :param verifySchema: verify data types of every row against schema.
         :return: :class:`DataFrame`
 
-        .. versionchanged:: 2.0
-           The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a
-           datatype string after 2.0. If it's not a
-           :class:`pyspark.sql.types.StructType`, it will be wrapped into a
-           :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
+        .. versionchanged:: 2.1
+           Added verifySchema.
 
         >>> l = [('Alice', 1)]
         >>> spark.createDataFrame(l).collect()
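A short usage sketch of the new flag, assuming a local SparkSession (schema and rows are illustrative):

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType

    spark = SparkSession.builder.master("local[1]").appName("sketch").getOrCreate()
    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), True)])

    # verifySchema=True (the default): every row is checked against the
    # schema, so a row such as ("Alice", "one") raises TypeError up front
    df = spark.createDataFrame([("Alice", 1)], schema)

    # verifySchema=False skips the per-row check; mismatches are not caught
    # here and may only surface later, e.g. at serialization time
    df_fast = spark.createDataFrame([("Alice", 1)], schema, verifySchema=False)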
@@ -503,17 +499,18 @@ class SparkSession(object):
             schema = [str(x) for x in data.columns]
             data = [r.tolist() for r in data.to_records(index=False)]
 
+        verify_func = _verify_type if verifySchema else lambda _, t: True
         if isinstance(schema, StructType):
             def prepare(obj):
-                _verify_type(obj, schema)
+                verify_func(obj, schema)
                 return obj
         elif isinstance(schema, DataType):
-            datatype = schema
+            dataType = schema
+            schema = StructType().add("value", schema)
 
             def prepare(obj):
-                _verify_type(obj, datatype)
-                return (obj, )
-            schema = StructType().add("value", datatype)
+                verify_func(obj, dataType)
+                return obj,
         else:
             if isinstance(schema, list):
                 schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
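The hunk above routes every row check through `verify_func`, which is `_verify_type` when `verifySchema` is set and a same-signature no-op otherwise, and wraps a bare DataType into a single-field struct named "value". A self-contained sketch of that pattern, with a caller-supplied `verify` standing in for pyspark's private `_verify_type` (the helper name `make_prepare` is hypothetical):

    from pyspark.sql.types import DataType, StructType

    def make_prepare(schema, verifySchema, verify):
        # pick the real verifier or a no-op with the same signature
        verify_func = verify if verifySchema else (lambda obj, dt: True)
        if isinstance(schema, StructType):
            def prepare(obj):
                verify_func(obj, schema)
                return obj
            return prepare, schema
        elif isinstance(schema, DataType):
            # a bare DataType becomes a one-field struct named "value";
            # each record is wrapped into a one-element tuple to match
            dataType = schema
            schema = StructType().add("value", schema)
            def prepare(obj):
                verify_func(obj, dataType)
                return (obj,)
            return prepare, schema
        raise TypeError("schema should be StructType or DataType")

For example, with an IntegerType() schema the returned prepare turns 42 into (42,) so each record fits the wrapped one-field struct.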