Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/context.py |  68
-rw-r--r-- | python/pyspark/sql/tests.py   |  40
-rw-r--r-- | python/pyspark/sql/types.py   | 129
3 files changed, 214 insertions, 23 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 8e324169d8..9c2f6a3c56 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -31,8 +31,8 @@ from py4j.protocol import Py4JError
 from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
-    _infer_schema, _has_nulltype, _merge_type, _create_converter
+from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
+    _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.utils import install_exception_handler
@@ -301,11 +301,6 @@ class SQLContext(object):
         Create an RDD for DataFrame from an list or pandas.DataFrame, returns
         the RDD and schema.
         """
-        if has_pandas and isinstance(data, pandas.DataFrame):
-            if schema is None:
-                schema = [str(x) for x in data.columns]
-            data = [r.tolist() for r in data.to_records(index=False)]
-
         # make sure data could consumed multiple times
         if not isinstance(data, list):
             data = list(data)
@@ -333,8 +328,7 @@ class SQLContext(object):
     @ignore_unicode_prefix
     def createDataFrame(self, data, schema=None, samplingRatio=None):
         """
-        Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`,
-        list or :class:`pandas.DataFrame`.
+        Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
 
         When ``schema`` is a list of column names, the type of each column
         will be inferred from ``data``.
@@ -343,15 +337,29 @@ class SQLContext(object):
         from ``data``, which should be an RDD of :class:`Row`,
         or :class:`namedtuple`, or :class:`dict`.
 
+        When ``schema`` is :class:`DataType` or datatype string, it must match the real data, or
+        exception will be thrown at runtime. If the given schema is not StructType, it will be
+        wrapped into a StructType as its only field, and the field name will be "value", each
+        record will also be wrapped into a tuple, which can be converted to row later.
+
         If schema inference is needed, ``samplingRatio`` is used to determined the ratio of
         rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``.
 
-        :param data: an RDD of :class:`Row`/:class:`tuple`/:class:`list`/:class:`dict`,
-            :class:`list`, or :class:`pandas.DataFrame`.
-        :param schema: a :class:`StructType` or list of column names. default None.
+        :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean,
+            etc.), or :class:`list`, or :class:`pandas.DataFrame`.
+        :param schema: a :class:`DataType` or a datatype string or a list of column names, default
+            is None. The data type string format equals to `DataType.simpleString`, except that
+            top level struct type can omit the `struct<>` and atomic types use `typeName()` as
+            their format, e.g. use `byte` instead of `tinyint` for ByteType. We can also use `int`
+            as a short name for IntegerType.
         :param samplingRatio: the sample ratio of rows used for inferring
         :return: :class:`DataFrame`
 
+        .. versionchanged:: 2.0
+           The schema parameter can be a DataType or a datatype string after 2.0. If it's not a
+           StructType, it will be wrapped into a StructType and each record will also be wrapped
+           into a tuple.
+
         >>> l = [('Alice', 1)]
         >>> sqlContext.createDataFrame(l).collect()
         [Row(_1=u'Alice', _2=1)]
@@ -388,14 +396,46 @@
         [Row(name=u'Alice', age=1)]
         >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
         [Row(0=1, 1=2)]
+
+        >>> sqlContext.createDataFrame(rdd, "a: string, b: int").collect()
+        [Row(a=u'Alice', b=1)]
+        >>> rdd = rdd.map(lambda row: row[1])
+        >>> sqlContext.createDataFrame(rdd, "int").collect()
+        [Row(value=1)]
+        >>> sqlContext.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        Py4JJavaError:...
         """
         if isinstance(data, DataFrame):
             raise TypeError("data is already a DataFrame")
 
+        if isinstance(schema, basestring):
+            schema = _parse_datatype_string(schema)
+
+        if has_pandas and isinstance(data, pandas.DataFrame):
+            if schema is None:
+                schema = [str(x) for x in data.columns]
+            data = [r.tolist() for r in data.to_records(index=False)]
+
+        if isinstance(schema, StructType):
+            def prepare(obj):
+                _verify_type(obj, schema)
+                return obj
+        elif isinstance(schema, DataType):
+            datatype = schema
+
+            def prepare(obj):
+                _verify_type(obj, datatype)
+                return (obj, )
+            schema = StructType().add("value", datatype)
+        else:
+            prepare = lambda obj: obj
+
         if isinstance(data, RDD):
-            rdd, schema = self._createFromRDD(data, schema, samplingRatio)
+            rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
         else:
-            rdd, schema = self._createFromLocal(data, schema)
+            rdd, schema = self._createFromLocal(map(prepare, data), schema)
         jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
         jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
         df = DataFrame(jdf, self)
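For reference, a short doctest-style sketch of the new ``schema`` argument (illustrative only, not part of the commit; it assumes a live `sc` and `sqlContext` as in the doctests above, plus `from pyspark.sql.types import IntegerType`). A datatype string supplies both field names and types, while a bare DataType is wrapped into a struct with a single column named "value":

    >>> rdd = sc.parallelize([('Alice', 1), ('Bob', 2)])
    >>> sqlContext.createDataFrame(rdd, "name: string, age: int").collect()
    [Row(name=u'Alice', age=1), Row(name=u'Bob', age=2)]
    >>> ages = rdd.map(lambda row: row[1])
    >>> sqlContext.createDataFrame(ages, IntegerType()).collect()
    [Row(value=1), Row(value=2)]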
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b5b848e1db..9722e9e9ca 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -369,9 +369,7 @@ class SQLTests(ReusedPySparkTestCase):
         rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
         schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
         df = self.sqlCtx.createDataFrame(rdd, schema)
-        message = ".*Input row doesn't have expected number of values required by the schema.*"
-        with self.assertRaisesRegexp(Exception, message):
-            df.show()
+        self.assertRaises(Exception, lambda: df.show())
 
     def test_serialize_nested_array_and_map(self):
         d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
@@ -1178,6 +1176,42 @@
         # planner should not crash without a join
         broadcast(df1)._jdf.queryExecution().executedPlan()
 
+    def test_toDF_with_schema_string(self):
+        data = [Row(key=i, value=str(i)) for i in range(100)]
+        rdd = self.sc.parallelize(data, 5)
+
+        df = rdd.toDF("key: int, value: string")
+        self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
+        self.assertEqual(df.collect(), data)
+
+        # different but compatible field types can be used.
+        df = rdd.toDF("key: string, value: string")
+        self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
+        self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
+
+        # field names can differ.
+        df = rdd.toDF(" a: int, b: string ")
+        self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
+        self.assertEqual(df.collect(), data)
+
+        # number of fields must match.
+        self.assertRaisesRegexp(Exception, "Length of object",
+                                lambda: rdd.toDF("key: int").collect())
+
+        # field types mismatch will cause exception at runtime.
+        self.assertRaisesRegexp(Exception, "FloatType can not accept",
+                                lambda: rdd.toDF("key: float, value: string").collect())
+
+        # flat schema values will be wrapped into row.
+        df = rdd.map(lambda row: row.key).toDF("int")
+        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
+        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+
+        # users can use DataType directly instead of data type string.
+        df = rdd.map(lambda row: row.key).toDF(IntegerType())
+        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
+        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+
 
 class HiveContextSQLTests(ReusedPySparkTestCase):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c71adfb58f..734c1533a2 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -681,6 +681,129 @@ _all_complex_types = dict((v.typeName(), v)
                           for v in [ArrayType, MapType, StructType])
 
 
+_FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)")
+
+
+_BRACKETS = {'(': ')', '[': ']', '{': '}'}
+
+
+def _parse_basic_datatype_string(s):
+    if s in _all_atomic_types.keys():
+        return _all_atomic_types[s]()
+    elif s == "int":
+        return IntegerType()
+    elif _FIXED_DECIMAL.match(s):
+        m = _FIXED_DECIMAL.match(s)
+        return DecimalType(int(m.group(1)), int(m.group(2)))
+    else:
+        raise ValueError("Could not parse datatype: %s" % s)
+
+
+def _ignore_brackets_split(s, separator):
+    """
+    Splits the given string by given separator, but ignore separators inside brackets pairs, e.g.
+    given "a,b" and separator ",", it will return ["a", "b"], but given "a<b,c>, d", it will return
+    ["a<b,c>", "d"].
+    """
+    parts = []
+    buf = ""
+    level = 0
+    for c in s:
+        if c in _BRACKETS.keys():
+            level += 1
+            buf += c
+        elif c in _BRACKETS.values():
+            if level == 0:
+                raise ValueError("Brackets are not correctly paired: %s" % s)
+            level -= 1
+            buf += c
+        elif c == separator and level > 0:
+            buf += c
+        elif c == separator:
+            parts.append(buf)
+            buf = ""
+        else:
+            buf += c
+
+    if len(buf) == 0:
+        raise ValueError("The %s cannot be the last char: %s" % (separator, s))
+    parts.append(buf)
+    return parts
+
+
+def _parse_struct_fields_string(s):
+    parts = _ignore_brackets_split(s, ",")
+    fields = []
+    for part in parts:
+        name_and_type = _ignore_brackets_split(part, ":")
+        if len(name_and_type) != 2:
+            raise ValueError("The strcut field string format is: 'field_name:field_type', " +
+                             "but got: %s" % part)
+        field_name = name_and_type[0].strip()
+        field_type = _parse_datatype_string(name_and_type[1])
+        fields.append(StructField(field_name, field_type))
+    return StructType(fields)
+
+
+def _parse_datatype_string(s):
+    """
+    Parses the given data type string to a :class:`DataType`. The data type string format equals
+    to `DataType.simpleString`, except that top level struct type can omit the `struct<>` and
+    atomic types use `typeName()` as their format, e.g. use `byte` instead of `tinyint` for
+    ByteType. We can also use `int` as a short name for IntegerType.
+
+    >>> _parse_datatype_string("int ")
+    IntegerType
+    >>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
+    StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
+    >>> _parse_datatype_string("a: array< short>")
+    StructType(List(StructField(a,ArrayType(ShortType,true),true)))
+    >>> _parse_datatype_string(" map<string , string > ")
+    MapType(StringType,StringType,true)
+
+    >>> # Error cases
+    >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ValueError:...
+    >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ValueError:...
+    >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ValueError:...
+    >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ValueError:...
+    """
+    s = s.strip()
+    if s.startswith("array<"):
+        if s[-1] != ">":
+            raise ValueError("'>' should be the last char, but got: %s" % s)
+        return ArrayType(_parse_datatype_string(s[6:-1]))
+    elif s.startswith("map<"):
+        if s[-1] != ">":
+            raise ValueError("'>' should be the last char, but got: %s" % s)
+        parts = _ignore_brackets_split(s[4:-1], ",")
+        if len(parts) != 2:
+            raise ValueError("The map type string format is: 'map<key_type,value_type>', " +
+                             "but got: %s" % s)
+        kt = _parse_datatype_string(parts[0])
+        vt = _parse_datatype_string(parts[1])
+        return MapType(kt, vt)
+    elif s.startswith("struct<"):
+        if s[-1] != ">":
+            raise ValueError("'>' should be the last char, but got: %s" % s)
+        return _parse_struct_fields_string(s[7:-1])
+    elif ":" in s:
+        return _parse_struct_fields_string(s)
+    else:
+        return _parse_basic_datatype_string(s)
+
+
 def _parse_datatype_json_string(json_string):
     """Parses the given data type JSON string.
     >>> import pickle
@@ -730,9 +853,6 @@ def _parse_datatype_json_string(json_string):
     return _parse_datatype_json_value(json.loads(json_string))
 
 
-_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
-
-
 def _parse_datatype_json_value(json_value):
     if not isinstance(json_value, dict):
         if json_value in _all_atomic_types.keys():
@@ -940,9 +1060,6 @@ def _create_converter(dataType):
     return convert_struct
 
 
-_BRACKETS = {'(': ')', '[': ']', '{': '}'}
-
-
 def _split_schema_abstract(s):
     """
     split the schema abstract into fields
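For reference, a short doctest-style sketch of the string format accepted by the parser added above (illustrative only, not part of the commit; `_parse_datatype_string` is an internal helper of `pyspark.sql.types`). The input mirrors `DataType.simpleString()`, with the top-level `struct<>` optional and `typeName()`-style atomics such as `byte` accepted alongside the `int` shorthand:

    >>> from pyspark.sql.types import _parse_datatype_string
    >>> _parse_datatype_string("a: int, b: array<string>").simpleString()
    'struct<a:int,b:array<string>>'
    >>> _parse_datatype_string("byte").simpleString()
    'tinyint'
    >>> _parse_datatype_string("int").simpleString()
    'int'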