diff options
Diffstat (limited to 'python/pyspark/sql/session.py')
-rw-r--r-- | python/pyspark/sql/session.py | 29 |
1 files changed, 13 insertions, 16 deletions
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 2dacf483fc..61fa107497 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -384,17 +384,15 @@ class SparkSession(object):
 
         if schema is None or isinstance(schema, (list, tuple)):
             struct = self._inferSchemaFromList(data)
+            converter = _create_converter(struct)
+            data = map(converter, data)
             if isinstance(schema, (list, tuple)):
                 for i, name in enumerate(schema):
                     struct.fields[i].name = name
                     struct.names[i] = name
             schema = struct
 
-        elif isinstance(schema, StructType):
-            for row in data:
-                _verify_type(row, schema)
-
-        else:
+        elif not isinstance(schema, StructType):
             raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
 
         # convert python objects to sql data
@@ -403,7 +401,7 @@ class SparkSession(object):
 
     @since(2.0)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
 
@@ -432,13 +430,11 @@ class SparkSession(object):
             ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
             ``int`` as a short name for ``IntegerType``.
         :param samplingRatio: the sample ratio of rows used for inferring
+        :param verifySchema: verify data types of every row against schema.
         :return: :class:`DataFrame`
 
-        .. versionchanged:: 2.0
-           The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a
-           datatype string after 2.0. If it's not a
-           :class:`pyspark.sql.types.StructType`, it will be wrapped into a
-           :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
+        .. versionchanged:: 2.1
+           Added verifySchema.
 
         >>> l = [('Alice', 1)]
         >>> spark.createDataFrame(l).collect()
@@ -503,17 +499,18 @@ class SparkSession(object):
                 schema = [str(x) for x in data.columns]
             data = [r.tolist() for r in data.to_records(index=False)]
 
+        verify_func = _verify_type if verifySchema else lambda _, t: True
         if isinstance(schema, StructType):
             def prepare(obj):
-                _verify_type(obj, schema)
+                verify_func(obj, schema)
                 return obj
         elif isinstance(schema, DataType):
-            datatype = schema
+            dataType = schema
+            schema = StructType().add("value", schema)
 
             def prepare(obj):
-                _verify_type(obj, datatype)
-                return (obj, )
-            schema = StructType().add("value", datatype)
+                verify_func(obj, dataType)
+                return obj,
         else:
             if isinstance(schema, list):
                 schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]