-rw-r--r--  python/pyspark/sql/context.py  |  68
-rw-r--r--  python/pyspark/sql/tests.py    |  40
-rw-r--r--  python/pyspark/sql/types.py    | 129
3 files changed, 214 insertions, 23 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 8e324169d8..9c2f6a3c56 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -31,8 +31,8 @@ from py4j.protocol import Py4JError
from pyspark import since
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
- _infer_schema, _has_nulltype, _merge_type, _create_converter
+from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
+ _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.utils import install_exception_handler
@@ -301,11 +301,6 @@ class SQLContext(object):
Create an RDD for DataFrame from a list or pandas.DataFrame, returns
the RDD and schema.
"""
- if has_pandas and isinstance(data, pandas.DataFrame):
- if schema is None:
- schema = [str(x) for x in data.columns]
- data = [r.tolist() for r in data.to_records(index=False)]
-
# make sure data can be consumed multiple times
if not isinstance(data, list):
data = list(data)
@@ -333,8 +328,7 @@ class SQLContext(object):
@ignore_unicode_prefix
def createDataFrame(self, data, schema=None, samplingRatio=None):
"""
- Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`,
- list or :class:`pandas.DataFrame`.
+ Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
When ``schema`` is a list of column names, the type of each column
will be inferred from ``data``.
@@ -343,15 +337,29 @@ class SQLContext(object):
from ``data``, which should be an RDD of :class:`Row`,
or :class:`namedtuple`, or :class:`dict`.
+ When ``schema`` is a :class:`DataType` or a datatype string, it must match the real data, or
+ an exception will be thrown at runtime. If the given schema is not a :class:`StructType`, it
+ will be wrapped into a :class:`StructType` with a single field named "value", and each record
+ will also be wrapped into a tuple, which can be converted to a row later.
+
If schema inference is needed, ``samplingRatio`` is used to determine the ratio of
rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``.
- :param data: an RDD of :class:`Row`/:class:`tuple`/:class:`list`/:class:`dict`,
- :class:`list`, or :class:`pandas.DataFrame`.
- :param schema: a :class:`StructType` or list of column names. default None.
+ :param data: an RDD of any kind of SQL data representation (e.g. row, tuple, int, boolean,
+ etc.), or :class:`list`, or :class:`pandas.DataFrame`.
+ :param schema: a :class:`DataType` or a datatype string or a list of column names, default
+ is None. The data type string format equals `DataType.simpleString`, except that the
+ top-level struct type can omit the `struct<>` and atomic types use `typeName()` as
+ their format, e.g. use `byte` instead of `tinyint` for ByteType. We can also use `int`
+ as a short name for IntegerType.
:param samplingRatio: the sample ratio of rows used for inferring
:return: :class:`DataFrame`
+ .. versionchanged:: 2.0
+ The schema parameter can be a DataType or a datatype string after 2.0. If it's not a
+ StructType, it will be wrapped into a StructType and each record will also be wrapped
+ into a tuple.
+
>>> l = [('Alice', 1)]
>>> sqlContext.createDataFrame(l).collect()
[Row(_1=u'Alice', _2=1)]
@@ -388,14 +396,46 @@ class SQLContext(object):
[Row(name=u'Alice', age=1)]
>>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect() # doctest: +SKIP
[Row(0=1, 1=2)]
+
+ >>> sqlContext.createDataFrame(rdd, "a: string, b: int").collect()
+ [Row(a=u'Alice', b=1)]
+ >>> rdd = rdd.map(lambda row: row[1])
+ >>> sqlContext.createDataFrame(rdd, "int").collect()
+ [Row(value=1)]
+ >>> sqlContext.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ Py4JJavaError:...
"""
if isinstance(data, DataFrame):
raise TypeError("data is already a DataFrame")
+ if isinstance(schema, basestring):
+ schema = _parse_datatype_string(schema)
+
+ if has_pandas and isinstance(data, pandas.DataFrame):
+ if schema is None:
+ schema = [str(x) for x in data.columns]
+ data = [r.tolist() for r in data.to_records(index=False)]
+
+ if isinstance(schema, StructType):
+ def prepare(obj):
+ _verify_type(obj, schema)
+ return obj
+ elif isinstance(schema, DataType):
+ datatype = schema
+
+ def prepare(obj):
+ _verify_type(obj, datatype)
+ return (obj, )
+ schema = StructType().add("value", datatype)
+ else:
+ prepare = lambda obj: obj
+
if isinstance(data, RDD):
- rdd, schema = self._createFromRDD(data, schema, samplingRatio)
+ rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
else:
- rdd, schema = self._createFromLocal(data, schema)
+ rdd, schema = self._createFromLocal(map(prepare, data), schema)
jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
df = DataFrame(jdf, self)
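For reference, a minimal usage sketch of the schema handling added above, assuming a pyspark shell where `sc` and `sqlContext` are already defined (as in the doctests); it only restates what the docstring describes:

from pyspark.sql.types import IntegerType

rdd = sc.parallelize([('Alice', 1), ('Bob', 2)])

# A datatype string: the top-level struct<> wrapper can be omitted.
df1 = sqlContext.createDataFrame(rdd, "name: string, age: int")
# [Row(name=u'Alice', age=1), Row(name=u'Bob', age=2)]

# A non-struct schema (a DataType or its string form): each record is wrapped
# into a tuple and the single field is named "value".
ints = rdd.map(lambda r: r[1])
df2 = sqlContext.createDataFrame(ints, "int")
df3 = sqlContext.createDataFrame(ints, IntegerType())
# both: [Row(value=1), Row(value=2)]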
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b5b848e1db..9722e9e9ca 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -369,9 +369,7 @@ class SQLTests(ReusedPySparkTestCase):
rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
df = self.sqlCtx.createDataFrame(rdd, schema)
- message = ".*Input row doesn't have expected number of values required by the schema.*"
- with self.assertRaisesRegexp(Exception, message):
- df.show()
+ self.assertRaises(Exception, lambda: df.show())
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
@@ -1178,6 +1176,42 @@ class SQLTests(ReusedPySparkTestCase):
# planner should not crash without a join
broadcast(df1)._jdf.queryExecution().executedPlan()
+ def test_toDF_with_schema_string(self):
+ data = [Row(key=i, value=str(i)) for i in range(100)]
+ rdd = self.sc.parallelize(data, 5)
+
+ df = rdd.toDF("key: int, value: string")
+ self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
+ self.assertEqual(df.collect(), data)
+
+ # different but compatible field types can be used.
+ df = rdd.toDF("key: string, value: string")
+ self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
+ self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
+
+ # field names can differ.
+ df = rdd.toDF(" a: int, b: string ")
+ self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
+ self.assertEqual(df.collect(), data)
+
+ # number of fields must match.
+ self.assertRaisesRegexp(Exception, "Length of object",
+ lambda: rdd.toDF("key: int").collect())
+
+ # a field type mismatch will cause an exception at runtime.
+ self.assertRaisesRegexp(Exception, "FloatType can not accept",
+ lambda: rdd.toDF("key: float, value: string").collect())
+
+ # flat schema values will be wrapped into a row.
+ df = rdd.map(lambda row: row.key).toDF("int")
+ self.assertEqual(df.schema.simpleString(), "struct<value:int>")
+ self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+
+ # users can use DataType directly instead of a data type string.
+ df = rdd.map(lambda row: row.key).toDF(IntegerType())
+ self.assertEqual(df.schema.simpleString(), "struct<value:int>")
+ self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+
class HiveContextSQLTests(ReusedPySparkTestCase):
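The last two test cases above rely on the wrapping behaviour added in context.py: a schema that is not a StructType becomes the single field "value" of a StructType, and every record is wrapped into a one-element tuple before conversion. A rough sketch of that equivalence, again assuming `sc` and `sqlContext` from a pyspark shell:

from pyspark.sql.types import IntegerType, StructType

ints = sc.parallelize([1, 2, 3])

# What the user writes with the new API:
df_short = ints.toDF(IntegerType())

# Roughly what the prepare() branch in createDataFrame does internally:
wrapped = ints.map(lambda x: (x,))                   # each record becomes a tuple
schema = StructType().add("value", IntegerType())    # single field named "value"
df_long = sqlContext.createDataFrame(wrapped, schema)

assert df_short.schema.simpleString() == df_long.schema.simpleString()
# both schemas are struct<value:int>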
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c71adfb58f..734c1533a2 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -681,6 +681,129 @@ _all_complex_types = dict((v.typeName(), v)
for v in [ArrayType, MapType, StructType])
+_FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)")
+
+
+_BRACKETS = {'(': ')', '[': ']', '{': '}'}
+
+
+def _parse_basic_datatype_string(s):
+ if s in _all_atomic_types.keys():
+ return _all_atomic_types[s]()
+ elif s == "int":
+ return IntegerType()
+ elif _FIXED_DECIMAL.match(s):
+ m = _FIXED_DECIMAL.match(s)
+ return DecimalType(int(m.group(1)), int(m.group(2)))
+ else:
+ raise ValueError("Could not parse datatype: %s" % s)
+
+
+def _ignore_brackets_split(s, separator):
+ """
+ Splits the given string by the given separator, but ignores separators inside bracket pairs,
+ e.g. given "a,b" and separator ",", it will return ["a", "b"], but given "a<b,c>, d", it will
+ return ["a<b,c>", "d"].
+ """
+ parts = []
+ buf = ""
+ level = 0
+ for c in s:
+ if c in _BRACKETS.keys():
+ level += 1
+ buf += c
+ elif c in _BRACKETS.values():
+ if level == 0:
+ raise ValueError("Brackets are not correctly paired: %s" % s)
+ level -= 1
+ buf += c
+ elif c == separator and level > 0:
+ buf += c
+ elif c == separator:
+ parts.append(buf)
+ buf = ""
+ else:
+ buf += c
+
+ if len(buf) == 0:
+ raise ValueError("The %s cannot be the last char: %s" % (separator, s))
+ parts.append(buf)
+ return parts
+
+
+def _parse_struct_fields_string(s):
+ parts = _ignore_brackets_split(s, ",")
+ fields = []
+ for part in parts:
+ name_and_type = _ignore_brackets_split(part, ":")
+ if len(name_and_type) != 2:
+ raise ValueError("The strcut field string format is: 'field_name:field_type', " +
+ "but got: %s" % part)
+ field_name = name_and_type[0].strip()
+ field_type = _parse_datatype_string(name_and_type[1])
+ fields.append(StructField(field_name, field_type))
+ return StructType(fields)
+
+
+def _parse_datatype_string(s):
+ """
+ Parses the given data type string to a :class:`DataType`. The data type string format equals
+ `DataType.simpleString`, except that the top-level struct type can omit the `struct<>` and
+ atomic types use `typeName()` as their format, e.g. use `byte` instead of `tinyint` for
+ ByteType. We can also use `int` as a short name for IntegerType.
+
+ >>> _parse_datatype_string("int ")
+ IntegerType
+ >>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
+ StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
+ >>> _parse_datatype_string("a: array< short>")
+ StructType(List(StructField(a,ArrayType(ShortType,true),true)))
+ >>> _parse_datatype_string(" map<string , string > ")
+ MapType(StringType,StringType,true)
+
+ >>> # Error cases
+ >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ """
+ s = s.strip()
+ if s.startswith("array<"):
+ if s[-1] != ">":
+ raise ValueError("'>' should be the last char, but got: %s" % s)
+ return ArrayType(_parse_datatype_string(s[6:-1]))
+ elif s.startswith("map<"):
+ if s[-1] != ">":
+ raise ValueError("'>' should be the last char, but got: %s" % s)
+ parts = _ignore_brackets_split(s[4:-1], ",")
+ if len(parts) != 2:
+ raise ValueError("The map type string format is: 'map<key_type,value_type>', " +
+ "but got: %s" % s)
+ kt = _parse_datatype_string(parts[0])
+ vt = _parse_datatype_string(parts[1])
+ return MapType(kt, vt)
+ elif s.startswith("struct<"):
+ if s[-1] != ">":
+ raise ValueError("'>' should be the last char, but got: %s" % s)
+ return _parse_struct_fields_string(s[7:-1])
+ elif ":" in s:
+ return _parse_struct_fields_string(s)
+ else:
+ return _parse_basic_datatype_string(s)
+
+
def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
>>> import pickle
@@ -730,9 +853,6 @@ def _parse_datatype_json_string(json_string):
return _parse_datatype_json_value(json.loads(json_string))
-_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
-
-
def _parse_datatype_json_value(json_value):
if not isinstance(json_value, dict):
if json_value in _all_atomic_types.keys():
@@ -940,9 +1060,6 @@ def _create_converter(dataType):
return convert_struct
-_BRACKETS = {'(': ')', '[': ']', '{': '}'}
-
-
def _split_schema_abstract(s):
"""
split the schema abstract into fields
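As a closing note on the parsing helpers added to types.py: the core of _ignore_brackets_split is a single pass that tracks a nesting level and splits on the separator only at level zero. Below is a minimal standalone sketch of that technique; the names split_top_level, OPEN and CLOSE are invented for this illustration, and the bracket table deliberately differs from the _BRACKETS constant in the patch:

# Level-tracked splitting: split only when not inside any bracket pair.
OPEN = {'(': ')', '<': '>'}
CLOSE = {v: k for k, v in OPEN.items()}

def split_top_level(s, separator):
    parts, buf, level = [], "", 0
    for c in s:
        if c in OPEN:
            level += 1
        elif c in CLOSE:
            level -= 1
        if c == separator and level == 0:
            parts.append(buf)  # separator seen outside all brackets
            buf = ""
        else:
            buf += c
    parts.append(buf)
    return parts

print(split_top_level("a, decimal(10, 2), map<int, string>", ","))
# ['a', ' decimal(10, 2)', ' map<int, string>']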