aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-03-08 14:00:03 -0800
committerDavies Liu <davies.liu@gmail.com>2016-03-08 14:00:03 -0800
commitd57daf1f7732a7ac54a91fe112deeda0a254f9ef (patch)
treeab88e42e961763cadd75d95b9c9989b3c460bde6 /python
parentd5ce61722f2aa4167a8344e1664b000c70d5a3f8 (diff)
downloadspark-d57daf1f7732a7ac54a91fe112deeda0a254f9ef.tar.gz
spark-d57daf1f7732a7ac54a91fe112deeda0a254f9ef.tar.bz2
spark-d57daf1f7732a7ac54a91fe112deeda0a254f9ef.zip
[SPARK-13593] [SQL] improve the `createDataFrame` to accept data type string and verify the data
## What changes were proposed in this pull request? This PR improves the `createDataFrame` method to make it also accept datatype string, then users can convert python RDD to DataFrame easily, for example, `df = rdd.toDF("a: int, b: string")`. It also supports flat schema so users can convert an RDD of int to DataFrame directly, we will automatically wrap int to row for users. If schema is given, now we checks if the real data matches the given schema, and throw error if it doesn't. ## How was this patch tested? new tests in `test.py` and doc test in `types.py` Author: Wenchen Fan <wenchen@databricks.com> Closes #11444 from cloud-fan/pyrdd.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/context.py68
-rw-r--r--python/pyspark/sql/tests.py40
-rw-r--r--python/pyspark/sql/types.py129
3 files changed, 214 insertions, 23 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 8e324169d8..9c2f6a3c56 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -31,8 +31,8 @@ from py4j.protocol import Py4JError
from pyspark import since
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
- _infer_schema, _has_nulltype, _merge_type, _create_converter
+from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
+ _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.utils import install_exception_handler
@@ -301,11 +301,6 @@ class SQLContext(object):
Create an RDD for DataFrame from an list or pandas.DataFrame, returns
the RDD and schema.
"""
- if has_pandas and isinstance(data, pandas.DataFrame):
- if schema is None:
- schema = [str(x) for x in data.columns]
- data = [r.tolist() for r in data.to_records(index=False)]
-
# make sure data could consumed multiple times
if not isinstance(data, list):
data = list(data)
@@ -333,8 +328,7 @@ class SQLContext(object):
@ignore_unicode_prefix
def createDataFrame(self, data, schema=None, samplingRatio=None):
"""
- Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`,
- list or :class:`pandas.DataFrame`.
+ Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
When ``schema`` is a list of column names, the type of each column
will be inferred from ``data``.
@@ -343,15 +337,29 @@ class SQLContext(object):
from ``data``, which should be an RDD of :class:`Row`,
or :class:`namedtuple`, or :class:`dict`.
+ When ``schema`` is :class:`DataType` or datatype string, it must match the real data, or
+ exception will be thrown at runtime. If the given schema is not StructType, it will be
+ wrapped into a StructType as its only field, and the field name will be "value", each record
+ will also be wrapped into a tuple, which can be converted to row later.
+
If schema inference is needed, ``samplingRatio`` is used to determined the ratio of
rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``.
- :param data: an RDD of :class:`Row`/:class:`tuple`/:class:`list`/:class:`dict`,
- :class:`list`, or :class:`pandas.DataFrame`.
- :param schema: a :class:`StructType` or list of column names. default None.
+ :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean,
+ etc.), or :class:`list`, or :class:`pandas.DataFrame`.
+ :param schema: a :class:`DataType` or a datatype string or a list of column names, default
+ is None. The data type string format equals to `DataType.simpleString`, except that
+ top level struct type can omit the `struct<>` and atomic types use `typeName()` as
+ their format, e.g. use `byte` instead of `tinyint` for ByteType. We can also use `int`
+ as a short name for IntegerType.
:param samplingRatio: the sample ratio of rows used for inferring
:return: :class:`DataFrame`
+ .. versionchanged:: 2.0
+ The schema parameter can be a DataType or a datatype string after 2.0. If it's not a
+ StructType, it will be wrapped into a StructType and each record will also be wrapped
+ into a tuple.
+
>>> l = [('Alice', 1)]
>>> sqlContext.createDataFrame(l).collect()
[Row(_1=u'Alice', _2=1)]
@@ -388,14 +396,46 @@ class SQLContext(object):
[Row(name=u'Alice', age=1)]
>>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect() # doctest: +SKIP
[Row(0=1, 1=2)]
+
+ >>> sqlContext.createDataFrame(rdd, "a: string, b: int").collect()
+ [Row(a=u'Alice', b=1)]
+ >>> rdd = rdd.map(lambda row: row[1])
+ >>> sqlContext.createDataFrame(rdd, "int").collect()
+ [Row(value=1)]
+ >>> sqlContext.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ Py4JJavaError:...
"""
if isinstance(data, DataFrame):
raise TypeError("data is already a DataFrame")
+ if isinstance(schema, basestring):
+ schema = _parse_datatype_string(schema)
+
+ if has_pandas and isinstance(data, pandas.DataFrame):
+ if schema is None:
+ schema = [str(x) for x in data.columns]
+ data = [r.tolist() for r in data.to_records(index=False)]
+
+ if isinstance(schema, StructType):
+ def prepare(obj):
+ _verify_type(obj, schema)
+ return obj
+ elif isinstance(schema, DataType):
+ datatype = schema
+
+ def prepare(obj):
+ _verify_type(obj, datatype)
+ return (obj, )
+ schema = StructType().add("value", datatype)
+ else:
+ prepare = lambda obj: obj
+
if isinstance(data, RDD):
- rdd, schema = self._createFromRDD(data, schema, samplingRatio)
+ rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
else:
- rdd, schema = self._createFromLocal(data, schema)
+ rdd, schema = self._createFromLocal(map(prepare, data), schema)
jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
df = DataFrame(jdf, self)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b5b848e1db..9722e9e9ca 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -369,9 +369,7 @@ class SQLTests(ReusedPySparkTestCase):
rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
df = self.sqlCtx.createDataFrame(rdd, schema)
- message = ".*Input row doesn't have expected number of values required by the schema.*"
- with self.assertRaisesRegexp(Exception, message):
- df.show()
+ self.assertRaises(Exception, lambda: df.show())
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
@@ -1178,6 +1176,42 @@ class SQLTests(ReusedPySparkTestCase):
# planner should not crash without a join
broadcast(df1)._jdf.queryExecution().executedPlan()
+ def test_toDF_with_schema_string(self):
+ data = [Row(key=i, value=str(i)) for i in range(100)]
+ rdd = self.sc.parallelize(data, 5)
+
+ df = rdd.toDF("key: int, value: string")
+ self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
+ self.assertEqual(df.collect(), data)
+
+ # different but compatible field types can be used.
+ df = rdd.toDF("key: string, value: string")
+ self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
+ self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
+
+ # field names can differ.
+ df = rdd.toDF(" a: int, b: string ")
+ self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
+ self.assertEqual(df.collect(), data)
+
+ # number of fields must match.
+ self.assertRaisesRegexp(Exception, "Length of object",
+ lambda: rdd.toDF("key: int").collect())
+
+ # field types mismatch will cause exception at runtime.
+ self.assertRaisesRegexp(Exception, "FloatType can not accept",
+ lambda: rdd.toDF("key: float, value: string").collect())
+
+ # flat schema values will be wrapped into row.
+ df = rdd.map(lambda row: row.key).toDF("int")
+ self.assertEqual(df.schema.simpleString(), "struct<value:int>")
+ self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+
+ # users can use DataType directly instead of data type string.
+ df = rdd.map(lambda row: row.key).toDF(IntegerType())
+ self.assertEqual(df.schema.simpleString(), "struct<value:int>")
+ self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+
class HiveContextSQLTests(ReusedPySparkTestCase):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c71adfb58f..734c1533a2 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -681,6 +681,129 @@ _all_complex_types = dict((v.typeName(), v)
for v in [ArrayType, MapType, StructType])
+_FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)")
+
+
+_BRACKETS = {'(': ')', '[': ']', '{': '}'}
+
+
+def _parse_basic_datatype_string(s):
+ if s in _all_atomic_types.keys():
+ return _all_atomic_types[s]()
+ elif s == "int":
+ return IntegerType()
+ elif _FIXED_DECIMAL.match(s):
+ m = _FIXED_DECIMAL.match(s)
+ return DecimalType(int(m.group(1)), int(m.group(2)))
+ else:
+ raise ValueError("Could not parse datatype: %s" % s)
+
+
+def _ignore_brackets_split(s, separator):
+ """
+ Splits the given string by given separator, but ignore separators inside brackets pairs, e.g.
+ given "a,b" and separator ",", it will return ["a", "b"], but given "a<b,c>, d", it will return
+ ["a<b,c>", "d"].
+ """
+ parts = []
+ buf = ""
+ level = 0
+ for c in s:
+ if c in _BRACKETS.keys():
+ level += 1
+ buf += c
+ elif c in _BRACKETS.values():
+ if level == 0:
+ raise ValueError("Brackets are not correctly paired: %s" % s)
+ level -= 1
+ buf += c
+ elif c == separator and level > 0:
+ buf += c
+ elif c == separator:
+ parts.append(buf)
+ buf = ""
+ else:
+ buf += c
+
+ if len(buf) == 0:
+ raise ValueError("The %s cannot be the last char: %s" % (separator, s))
+ parts.append(buf)
+ return parts
+
+
+def _parse_struct_fields_string(s):
+ parts = _ignore_brackets_split(s, ",")
+ fields = []
+ for part in parts:
+ name_and_type = _ignore_brackets_split(part, ":")
+ if len(name_and_type) != 2:
+ raise ValueError("The strcut field string format is: 'field_name:field_type', " +
+ "but got: %s" % part)
+ field_name = name_and_type[0].strip()
+ field_type = _parse_datatype_string(name_and_type[1])
+ fields.append(StructField(field_name, field_type))
+ return StructType(fields)
+
+
+def _parse_datatype_string(s):
+ """
+ Parses the given data type string to a :class:`DataType`. The data type string format equals
+ to `DataType.simpleString`, except that top level struct type can omit the `struct<>` and
+ atomic types use `typeName()` as their format, e.g. use `byte` instead of `tinyint` for
+ ByteType. We can also use `int` as a short name for IntegerType.
+
+ >>> _parse_datatype_string("int ")
+ IntegerType
+ >>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
+ StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
+ >>> _parse_datatype_string("a: array< short>")
+ StructType(List(StructField(a,ArrayType(ShortType,true),true)))
+ >>> _parse_datatype_string(" map<string , string > ")
+ MapType(StringType,StringType,true)
+
+ >>> # Error cases
+ >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ """
+ s = s.strip()
+ if s.startswith("array<"):
+ if s[-1] != ">":
+ raise ValueError("'>' should be the last char, but got: %s" % s)
+ return ArrayType(_parse_datatype_string(s[6:-1]))
+ elif s.startswith("map<"):
+ if s[-1] != ">":
+ raise ValueError("'>' should be the last char, but got: %s" % s)
+ parts = _ignore_brackets_split(s[4:-1], ",")
+ if len(parts) != 2:
+ raise ValueError("The map type string format is: 'map<key_type,value_type>', " +
+ "but got: %s" % s)
+ kt = _parse_datatype_string(parts[0])
+ vt = _parse_datatype_string(parts[1])
+ return MapType(kt, vt)
+ elif s.startswith("struct<"):
+ if s[-1] != ">":
+ raise ValueError("'>' should be the last char, but got: %s" % s)
+ return _parse_struct_fields_string(s[7:-1])
+ elif ":" in s:
+ return _parse_struct_fields_string(s)
+ else:
+ return _parse_basic_datatype_string(s)
+
+
def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
>>> import pickle
@@ -730,9 +853,6 @@ def _parse_datatype_json_string(json_string):
return _parse_datatype_json_value(json.loads(json_string))
-_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
-
-
def _parse_datatype_json_value(json_value):
if not isinstance(json_value, dict):
if json_value in _all_atomic_types.keys():
@@ -940,9 +1060,6 @@ def _create_converter(dataType):
return convert_struct
-_BRACKETS = {'(': ')', '[': ']', '{': '}'}
-
-
def _split_schema_abstract(s):
"""
split the schema abstract into fields