author     Wenchen Fan <wenchen@databricks.com>  2016-03-08 14:00:03 -0800
committer  Davies Liu <davies.liu@gmail.com>     2016-03-08 14:00:03 -0800
commit     d57daf1f7732a7ac54a91fe112deeda0a254f9ef (patch)
tree       ab88e42e961763cadd75d95b9c9989b3c460bde6 /python/pyspark/sql/types.py
parent     d5ce61722f2aa4167a8344e1664b000c70d5a3f8 (diff)
[SPARK-13593] [SQL] improve the `createDataFrame` to accept data type string and verify the data
## What changes were proposed in this pull request?

This PR improves the `createDataFrame` method to also accept a data type string, so users can convert a Python RDD to a DataFrame easily, for example: `df = rdd.toDF("a: int, b: string")`. It also supports flat schemas, so users can convert an RDD of ints to a DataFrame directly; we automatically wrap each int in a row for them. If a schema is given, we now check whether the real data matches it, and throw an error if it doesn't.

## How was this patch tested?

New tests in `test.py` and a doc test in `types.py`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11444 from cloud-fan/pyrdd.
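A minimal sketch of the usage this patch enables, assuming a pyspark shell of this era with a live `SparkContext` bound to `sc`; the variable names and sample data are illustrative:

```python
# Sketch of the new API; assumes `sc` is a SparkContext (as in the pyspark shell).
rdd = sc.parallelize([(1, "alice"), (2, "bob")])

# The schema can now be given as a data type string instead of a StructType:
df = rdd.toDF("a: int, b: string")

# Flat schema: an RDD of plain ints is wrapped into rows automatically.
df2 = sc.parallelize([1, 2, 3]).toDF("int")
```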
Diffstat (limited to 'python/pyspark/sql/types.py')
-rw-r--r--  python/pyspark/sql/types.py  129
1 file changed, 123 insertions(+), 6 deletions(-)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c71adfb58f..734c1533a2 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -681,6 +681,129 @@ _all_complex_types = dict((v.typeName(), v)
                           for v in [ArrayType, MapType, StructType])
+_FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)")
+
+
+_BRACKETS = {'(': ')', '[': ']', '{': '}'}
+
+
+def _parse_basic_datatype_string(s):
+    if s in _all_atomic_types.keys():
+        return _all_atomic_types[s]()
+    elif s == "int":
+        return IntegerType()
+    elif _FIXED_DECIMAL.match(s):
+        m = _FIXED_DECIMAL.match(s)
+        return DecimalType(int(m.group(1)), int(m.group(2)))
+    else:
+        raise ValueError("Could not parse datatype: %s" % s)
+
+
+def _ignore_brackets_split(s, separator):
+ """
+ Splits the given string by given separator, but ignore separators inside brackets pairs, e.g.
+ given "a,b" and separator ",", it will return ["a", "b"], but given "a<b,c>, d", it will return
+ ["a<b,c>", "d"].
+ """
+ parts = []
+ buf = ""
+ level = 0
+ for c in s:
+ if c in _BRACKETS.keys():
+ level += 1
+ buf += c
+ elif c in _BRACKETS.values():
+ if level == 0:
+ raise ValueError("Brackets are not correctly paired: %s" % s)
+ level -= 1
+ buf += c
+ elif c == separator and level > 0:
+ buf += c
+ elif c == separator:
+ parts.append(buf)
+ buf = ""
+ else:
+ buf += c
+
+ if len(buf) == 0:
+ raise ValueError("The %s cannot be the last char: %s" % (separator, s))
+ parts.append(buf)
+ return parts
+
+
+def _parse_struct_fields_string(s):
+    parts = _ignore_brackets_split(s, ",")
+    fields = []
+    for part in parts:
+        name_and_type = _ignore_brackets_split(part, ":")
+        if len(name_and_type) != 2:
+            raise ValueError("The struct field string format is: 'field_name:field_type', " +
+                             "but got: %s" % part)
+        field_name = name_and_type[0].strip()
+        field_type = _parse_datatype_string(name_and_type[1])
+        fields.append(StructField(field_name, field_type))
+    return StructType(fields)
+
+
+def _parse_datatype_string(s):
+ """
+ Parses the given data type string to a :class:`DataType`. The data type string format equals
+ to `DataType.simpleString`, except that top level struct type can omit the `struct<>` and
+ atomic types use `typeName()` as their format, e.g. use `byte` instead of `tinyint` for
+ ByteType. We can also use `int` as a short name for IntegerType.
+
+ >>> _parse_datatype_string("int ")
+ IntegerType
+ >>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
+ StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
+ >>> _parse_datatype_string("a: array< short>")
+ StructType(List(StructField(a,ArrayType(ShortType,true),true)))
+ >>> _parse_datatype_string(" map<string , string > ")
+ MapType(StringType,StringType,true)
+
+ >>> # Error cases
+ >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ """
+ s = s.strip()
+ if s.startswith("array<"):
+ if s[-1] != ">":
+ raise ValueError("'>' should be the last char, but got: %s" % s)
+ return ArrayType(_parse_datatype_string(s[6:-1]))
+ elif s.startswith("map<"):
+ if s[-1] != ">":
+ raise ValueError("'>' should be the last char, but got: %s" % s)
+ parts = _ignore_brackets_split(s[4:-1], ",")
+ if len(parts) != 2:
+ raise ValueError("The map type string format is: 'map<key_type,value_type>', " +
+ "but got: %s" % s)
+ kt = _parse_datatype_string(parts[0])
+ vt = _parse_datatype_string(parts[1])
+ return MapType(kt, vt)
+ elif s.startswith("struct<"):
+ if s[-1] != ">":
+ raise ValueError("'>' should be the last char, but got: %s" % s)
+ return _parse_struct_fields_string(s[7:-1])
+ elif ":" in s:
+ return _parse_struct_fields_string(s)
+ else:
+ return _parse_basic_datatype_string(s)
+
+
def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
>>> import pickle
@@ -730,9 +853,6 @@ def _parse_datatype_json_string(json_string):
    return _parse_datatype_json_value(json.loads(json_string))
-_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
-
-
def _parse_datatype_json_value(json_value):
    if not isinstance(json_value, dict):
        if json_value in _all_atomic_types.keys():
@@ -940,9 +1060,6 @@ def _create_converter(dataType):
    return convert_struct
-_BRACKETS = {'(': ')', '[': ']', '{': '}'}
-
-
def _split_schema_abstract(s):
"""
split the schema abstract into fields
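For reference, a hedged sketch of exercising the new parser directly; `_parse_datatype_string` is a private helper added by this patch, so calling it as below is for illustration only, not a public API:

```python
from pyspark.sql.types import _parse_datatype_string  # private helper added above

# A top-level struct may omit the struct<> wrapper:
_parse_datatype_string("a: int, b: string")
# StructType(List(StructField(a,IntegerType,true),StructField(b,StringType,true)))

# Nested complex types use the simpleString syntax with typeName() spellings
# (e.g. "long" rather than "bigint"):
_parse_datatype_string("map<string, array<long>>")
# MapType(StringType,ArrayType(LongType,true),true)
```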