Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/sql/context.py |  8
-rw-r--r--  python/pyspark/sql/session.py | 29
-rw-r--r--  python/pyspark/sql/tests.py   | 16
-rw-r--r--  python/pyspark/sql/types.py   | 37
4 files changed, 62 insertions, 28 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 4085f165f4..7482be8bda 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -215,7 +215,7 @@ class SQLContext(object):
     @since(1.3)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
@@ -245,6 +245,7 @@ class SQLContext(object):
             ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
             ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`.
         :param samplingRatio: the sample ratio of rows used for inferring
+        :param verifySchema: verify data types of every row against schema.
         :return: :class:`DataFrame`
         .. versionchanged:: 2.0
@@ -253,6 +254,9 @@ class SQLContext(object):
            If it's not a :class:`pyspark.sql.types.StructType`, it will be wrapped into a
            :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
+        .. versionchanged:: 2.1
+           Added verifySchema.
+
         >>> l = [('Alice', 1)]
         >>> sqlContext.createDataFrame(l).collect()
         [Row(_1=u'Alice', _2=1)]
@@ -300,7 +304,7 @@ class SQLContext(object):
             ...
         Py4JJavaError: ...
         """
-        return self.sparkSession.createDataFrame(data, schema, samplingRatio)
+        return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema)
     @since(1.3)
     def registerDataFrameAsTable(self, df, tableName):
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 2dacf483fc..61fa107497 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -384,17 +384,15 @@ class SparkSession(object):
         if schema is None or isinstance(schema, (list, tuple)):
             struct = self._inferSchemaFromList(data)
+            converter = _create_converter(struct)
+            data = map(converter, data)
             if isinstance(schema, (list, tuple)):
                 for i, name in enumerate(schema):
                     struct.fields[i].name = name
                     struct.names[i] = name
             schema = struct
-        elif isinstance(schema, StructType):
-            for row in data:
-                _verify_type(row, schema)
-
-        else:
+        elif not isinstance(schema, StructType):
             raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
         # convert python objects to sql data
@@ -403,7 +401,7 @@ class SparkSession(object):
     @since(2.0)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
@@ -432,13 +430,11 @@ class SparkSession(object):
             ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
             ``int`` as a short name for ``IntegerType``.
         :param samplingRatio: the sample ratio of rows used for inferring
+        :param verifySchema: verify data types of every row against schema.
         :return: :class:`DataFrame`
-        .. versionchanged:: 2.0
-           The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a
-           datatype string after 2.0. If it's not a
-           :class:`pyspark.sql.types.StructType`, it will be wrapped into a
-           :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
+        .. versionchanged:: 2.1
+           Added verifySchema.
         >>> l = [('Alice', 1)]
         >>> spark.createDataFrame(l).collect()
@@ -503,17 +499,18 @@ class SparkSession(object):
             schema = [str(x) for x in data.columns]
             data = [r.tolist() for r in data.to_records(index=False)]
+        verify_func = _verify_type if verifySchema else lambda _, t: True
         if isinstance(schema, StructType):
             def prepare(obj):
-                _verify_type(obj, schema)
+                verify_func(obj, schema)
                 return obj
         elif isinstance(schema, DataType):
-            datatype = schema
+            dataType = schema
+            schema = StructType().add("value", schema)
             def prepare(obj):
-                _verify_type(obj, datatype)
-                return (obj, )
-            schema = StructType().add("value", datatype)
+                verify_func(obj, dataType)
+                return obj,
         else:
             if isinstance(schema, list):
                 schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 87dbb50495..520b09d9c6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -411,6 +411,22 @@ class SQLTests(ReusedPySparkTestCase):
         df3 = self.spark.createDataFrame(rdd, df.schema)
         self.assertEqual(10, df3.count())
+    def test_apply_schema_to_dict_and_rows(self):
+        schema = StructType().add("b", StringType()).add("a", IntegerType())
+        input = [{"a": 1}, {"b": "coffee"}]
+        rdd = self.sc.parallelize(input)
+        for verify in [False, True]:
+            df = self.spark.createDataFrame(input, schema, verifySchema=verify)
+            df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+            self.assertEqual(df.schema, df2.schema)
+
+            rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+            df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+            self.assertEqual(10, df3.count())
+            input = [Row(a=x, b=str(x)) for x in range(10)]
+            df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
+            self.assertEqual(10, df4.count())
+
     def test_create_dataframe_schema_mismatch(self):
         input = [Row(a=1)]
         rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 1ca4bbc379..b765472d6e 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -582,6 +582,8 @@ class StructType(DataType):
         else:
             if isinstance(obj, dict):
                 return tuple(obj.get(n) for n in self.names)
+            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+                return tuple(obj[n] for n in self.names)
             elif isinstance(obj, (list, tuple)):
                 return tuple(obj)
             elif hasattr(obj, "__dict__"):
@@ -1243,7 +1245,7 @@ _acceptable_types = {
     TimestampType: (datetime.datetime,),
     ArrayType: (list, tuple, array),
     MapType: (dict,),
-    StructType: (tuple, list),
+    StructType: (tuple, list, dict),
 }
@@ -1314,10 +1316,10 @@ def _verify_type(obj, dataType, nullable=True):
     assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj)
     if _type is StructType:
-        if not isinstance(obj, (tuple, list)):
-            raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
+        # check the type and fields later
+        pass
     else:
-        # subclass of them can not be fromInternald in JVM
+        # subclass of them can not be fromInternal in JVM
         if type(obj) not in _acceptable_types[_type]:
             raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
@@ -1343,11 +1345,25 @@ def _verify_type(obj, dataType, nullable=True):
             _verify_type(v, dataType.valueType, dataType.valueContainsNull)
     elif isinstance(dataType, StructType):
-        if len(obj) != len(dataType.fields):
-            raise ValueError("Length of object (%d) does not match with "
-                             "length of fields (%d)" % (len(obj), len(dataType.fields)))
-        for v, f in zip(obj, dataType.fields):
-            _verify_type(v, f.dataType, f.nullable)
+        if isinstance(obj, dict):
+            for f in dataType.fields:
+                _verify_type(obj.get(f.name), f.dataType, f.nullable)
+        elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+            # the order in obj could be different than dataType.fields
+            for f in dataType.fields:
+                _verify_type(obj[f.name], f.dataType, f.nullable)
+        elif isinstance(obj, (tuple, list)):
+            if len(obj) != len(dataType.fields):
+                raise ValueError("Length of object (%d) does not match with "
+                                 "length of fields (%d)" % (len(obj), len(dataType.fields)))
+            for v, f in zip(obj, dataType.fields):
+                _verify_type(v, f.dataType, f.nullable)
+        elif hasattr(obj, "__dict__"):
+            d = obj.__dict__
+            for f in dataType.fields:
+                _verify_type(d.get(f.name), f.dataType, f.nullable)
+        else:
+            raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
 # This is used to unpickle a Row from JVM
@@ -1410,6 +1426,7 @@ class Row(tuple):
             names = sorted(kwargs.keys())
             row = tuple.__new__(self, [kwargs[n] for n in names])
             row.__fields__ = names
+            row.__from_dict__ = True
             return row
         else:
@@ -1485,7 +1502,7 @@ class Row(tuple):
             raise AttributeError(item)
     def __setattr__(self, key, value):
-        if key != '__fields__':
+        if key != '__fields__' and key != "__from_dict__":
             raise Exception("Row is read-only")
         self.__dict__[key] = value