Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r--  python/pyspark/sql.py  196
1 file changed, 128 insertions(+), 68 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 98e41f8575..675df084bf 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -109,6 +109,15 @@ class PrimitiveType(DataType):
return self is other
+class NullType(PrimitiveType):
+
+ """Spark SQL NullType
+
+ The data type representing None, used for types that have not
+ been inferred yet.
+ """
+
+
class StringType(PrimitiveType):
"""Spark SQL StringType
@@ -331,7 +340,7 @@ class StructField(DataType):
"""
- def __init__(self, name, dataType, nullable, metadata=None):
+ def __init__(self, name, dataType, nullable=True, metadata=None):
"""Creates a StructField
:param name: the name of this field.
:param dataType: the data type of this field.
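
Since `nullable` now defaults to True, hand-built schemas no longer have to spell it out. A minimal sketch of the difference (the field names here are made up for illustration):

    from pyspark.sql import StructType, StructField, StringType, IntegerType

    # Before this change, nullable was a required argument:
    #   StructField("name", StringType(), True)
    # With nullable=True as the default, it can simply be omitted:
    schema = StructType([StructField("name", StringType()),
                         StructField("age", IntegerType())])
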
@@ -484,6 +493,7 @@ def _parse_datatype_json_value(json_value):
# Mapping Python types to Spark SQL DataType
_type_mappings = {
+ type(None): NullType,
bool: BooleanType,
int: IntegerType,
long: LongType,
@@ -500,22 +510,22 @@ _type_mappings = {
def _infer_type(obj):
"""Infer the DataType from obj"""
- if obj is None:
- raise ValueError("Can not infer type for None")
-
dataType = _type_mappings.get(type(obj))
if dataType is not None:
return dataType()
if isinstance(obj, dict):
- if not obj:
- raise ValueError("Can not infer type for empty dict")
- key, value = obj.iteritems().next()
- return MapType(_infer_type(key), _infer_type(value), True)
+ for key, value in obj.iteritems():
+ if key is not None and value is not None:
+ return MapType(_infer_type(key), _infer_type(value), True)
+        else:  # empty dict, or every key/value seen was None
+ return MapType(NullType(), NullType(), True)
elif isinstance(obj, (list, array)):
- if not obj:
- raise ValueError("Can not infer type for empty list/array")
- return ArrayType(_infer_type(obj[0]), True)
+ for v in obj:
+ if v is not None:
+                return ArrayType(_infer_type(v), True)
+        else:  # empty list/array, or every element was None
+ return ArrayType(NullType(), True)
else:
try:
return _infer_schema(obj)
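
With type(None) now mapped to NullType, _infer_type no longer raises on missing data: None, empty containers and all-None containers infer placeholder NullType(s) that can be refined later. A small sanity sketch (these helpers are module-private, so this is illustrative rather than public API):

    from pyspark.sql import _infer_type, NullType, IntegerType

    assert isinstance(_infer_type(None), NullType)             # previously raised ValueError
    assert isinstance(_infer_type([]).elementType, NullType)   # empty list -> ArrayType(NullType)
    assert isinstance(_infer_type({}).keyType, NullType)       # empty dict -> MapType(NullType, NullType)
    assert isinstance(_infer_type([1, 2]).elementType, IntegerType)
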
@@ -548,60 +558,93 @@ def _infer_schema(row):
return StructType(fields)
-def _create_converter(obj, dataType):
+def _has_nulltype(dt):
+ """ Return whether there is NullType in `dt` or not """
+ if isinstance(dt, StructType):
+ return any(_has_nulltype(f.dataType) for f in dt.fields)
+ elif isinstance(dt, ArrayType):
+        return _has_nulltype(dt.elementType)
+ elif isinstance(dt, MapType):
+ return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
+ else:
+ return isinstance(dt, NullType)
+
+
+def _merge_type(a, b):
+ if isinstance(a, NullType):
+ return b
+ elif isinstance(b, NullType):
+ return a
+ elif type(a) is not type(b):
+ # TODO: type cast (such as int -> long)
+ raise TypeError("Can not merge type %s and %s" % (a, b))
+
+ # same type
+ if isinstance(a, StructType):
+ nfs = dict((f.name, f.dataType) for f in b.fields)
+ fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
+ for f in a.fields]
+ names = set([f.name for f in fields])
+ for n in nfs:
+ if n not in names:
+ fields.append(StructField(n, nfs[n]))
+ return StructType(fields)
+
+ elif isinstance(a, ArrayType):
+ return ArrayType(_merge_type(a.elementType, b.elementType), True)
+
+ elif isinstance(a, MapType):
+ return MapType(_merge_type(a.keyType, b.keyType),
+ _merge_type(a.valueType, b.valueType),
+ True)
+ else:
+ return a
+
+
+def _create_converter(dataType):
"""Create an converter to drop the names of fields in obj """
if isinstance(dataType, ArrayType):
- conv = _create_converter(obj[0], dataType.elementType)
+ conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)
elif isinstance(dataType, MapType):
- value = obj.values()[0]
- conv = _create_converter(value, dataType.valueType)
+ conv = _create_converter(dataType.valueType)
return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
+ elif isinstance(dataType, NullType):
+ return lambda x: None
+
elif not isinstance(dataType, StructType):
return lambda x: x
# dataType must be StructType
names = [f.name for f in dataType.fields]
+ converters = [_create_converter(f.dataType) for f in dataType.fields]
+
+ def convert_struct(obj):
+ if obj is None:
+ return
+
+ if isinstance(obj, tuple):
+ if hasattr(obj, "fields"):
+ d = dict(zip(obj.fields, obj))
+            elif hasattr(obj, "__FIELDS__"):
+ d = dict(zip(obj.__FIELDS__, obj))
+ elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
+ d = dict(obj)
+ else:
+ raise ValueError("unexpected tuple: %s" % obj)
- if isinstance(obj, dict):
- conv = lambda o: tuple(o.get(n) for n in names)
-
- elif isinstance(obj, tuple):
- if hasattr(obj, "_fields"): # namedtuple
- conv = tuple
- elif hasattr(obj, "__FIELDS__"):
- conv = tuple
- elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
- conv = lambda o: tuple(v for k, v in o)
+ elif isinstance(obj, dict):
+ d = obj
+ elif hasattr(obj, "__dict__"): # object
+ d = obj.__dict__
else:
- raise ValueError("unexpected tuple")
+ raise ValueError("Unexpected obj: %s" % obj)
- elif hasattr(obj, "__dict__"): # object
- conv = lambda o: [o.__dict__.get(n, None) for n in names]
+ return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
- if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields):
- return conv
-
- row = conv(obj)
- convs = [_create_converter(v, f.dataType)
- for v, f in zip(row, dataType.fields)]
-
- def nested_conv(row):
- return tuple(f(v) for f, v in zip(convs, conv(row)))
-
- return nested_conv
-
-
-def _drop_schema(rows, schema):
- """ all the names of fields, becoming tuples"""
- iterator = iter(rows)
- row = iterator.next()
- converter = _create_converter(row, schema)
- yield converter(row)
- for i in iterator:
- yield converter(i)
+ return convert_struct
_BRACKETS = {'(': ')', '[': ']', '{': '}'}
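
To see how the new pieces fit together: _infer_schema on a single row can leave NullType "holes", _has_nulltype detects them, and _merge_type fills them in from other rows. A rough sketch of that interplay, again using the module-private helpers purely for illustration:

    from pyspark.sql import Row, _infer_schema, _merge_type, _has_nulltype

    a = _infer_schema(Row(name=u"Alice", age=None))   # age comes out as NullType here
    b = _infer_schema(Row(name=u"Bob", age=30))       # age comes out as IntegerType here

    assert _has_nulltype(a) and not _has_nulltype(b)
    merged = _merge_type(a, b)                        # NullType gives way to IntegerType
    assert not _has_nulltype(merged)
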
@@ -713,7 +756,7 @@ def _infer_schema_type(obj, dataType):
return _infer_type(obj)
if not obj:
- raise ValueError("Can not infer type from empty value")
+ return NullType()
if isinstance(dataType, ArrayType):
eType = _infer_schema_type(obj[0], dataType.elementType)
@@ -1049,18 +1092,20 @@ class SQLContext(object):
self._sc._javaAccumulator,
returnType.json())
- def inferSchema(self, rdd):
+ def inferSchema(self, rdd, samplingRatio=None):
"""Infer and apply a schema to an RDD of L{Row}.
- We peek at the first row of the RDD to determine the fields' names
- and types. Nested collections are supported, which include array,
- dict, list, Row, tuple, namedtuple, or object.
+ When samplingRatio is specified, the schema is inferred by looking
+ at the types of each row in the sampled dataset. Otherwise, the
+ first 100 rows of the RDD are inspected. Nested collections are
+ supported, which can include array, dict, list, Row, tuple,
+ namedtuple, or object.
- All the rows in `rdd` should have the same type with the first one,
- or it will cause runtime exceptions.
+ Each row could be an L{pyspark.sql.Row} object, a namedtuple, or a plain Python object.
+ Using top level dicts is deprecated, as dict is used to represent Maps.
- Each row could be L{pyspark.sql.Row} object or namedtuple or objects,
- using dict is deprecated.
+ If a single column has multiple distinct inferred types, it may cause
+ runtime exceptions.
>>> rdd = sc.parallelize(
... [Row(field1=1, field2="row1"),
@@ -1097,8 +1142,23 @@ class SQLContext(object):
warnings.warn("Using RDD of dict to inferSchema is deprecated,"
"please use pyspark.sql.Row instead")
- schema = _infer_schema(first)
- rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))
+ if samplingRatio is None:
+ schema = _infer_schema(first)
+ if _has_nulltype(schema):
+ for row in rdd.take(100)[1:]:
+ schema = _merge_type(schema, _infer_schema(row))
+ if not _has_nulltype(schema):
+ break
+ else:
+                    warnings.warn("Some of the types cannot be determined by the "
+ "first 100 rows, please try again with sampling")
+ else:
+            if samplingRatio < 0.99:  # no need to sample when the ratio is (close to) 1.0
+ rdd = rdd.sample(False, float(samplingRatio))
+ schema = rdd.map(_infer_schema).reduce(_merge_type)
+
+ converter = _create_converter(schema)
+ rdd = rdd.map(converter)
return self.applySchema(rdd, schema)
def applySchema(self, rdd, schema):
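
Putting the inferSchema changes together, a hedged usage sketch (it assumes the running `sc` and `sqlCtx` from the doctests above; the sample data is made up):

    from pyspark.sql import Row

    people = sc.parallelize([Row(name=u"Alice", age=None)] * 50 +
                            [Row(name=u"Bob", age=30)] * 50)

    # Default: peek at up to the first 100 rows until no NullType is left in the schema.
    srdd1 = sqlCtx.inferSchema(people)

    # With samplingRatio, a schema is inferred per row over a sampled subset
    # and reduced with _merge_type.
    srdd2 = sqlCtx.inferSchema(people, samplingRatio=0.5)
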
@@ -1219,7 +1279,7 @@ class SQLContext(object):
jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
return SchemaRDD(jschema_rdd, self)
- def jsonFile(self, path, schema=None):
+ def jsonFile(self, path, schema=None, samplingRatio=1.0):
"""
Loads a text file storing one JSON object per line as a
L{SchemaRDD}.
@@ -1227,8 +1287,8 @@ class SQLContext(object):
If the schema is provided, applies the given schema to this
JSON dataset.
- Otherwise, it goes through the entire dataset once to determine
- the schema.
+ Otherwise, it samples the dataset with ratio `samplingRatio` to
+ determine the schema.
>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
@@ -1274,20 +1334,20 @@ class SQLContext(object):
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
"""
if schema is None:
- srdd = self._ssql_ctx.jsonFile(path)
+ srdd = self._ssql_ctx.jsonFile(path, samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
- def jsonRDD(self, rdd, schema=None):
+ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
"""Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
If the schema is provided, applies the given schema to this
JSON dataset.
- Otherwise, it goes through the entire dataset once to determine
- the schema.
+ Otherwise, it samples the dataset with ratio `samplingRatio` to
+ determine the schema.
>>> srdd1 = sqlCtx.jsonRDD(json)
>>> sqlCtx.registerRDDAsTable(srdd1, "table1")
@@ -1344,7 +1404,7 @@ class SQLContext(object):
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
if schema is None:
- srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+ srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
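
Finally, the new samplingRatio argument on the JSON entry points is simply forwarded to the Scala side; a brief sketch using the `json` RDD that the doctests above rely on:

    # Scan the whole dataset while inferring the schema (the default, samplingRatio=1.0):
    srdd_full = sqlCtx.jsonRDD(json)

    # Only look at roughly half of the records when inferring the schema:
    srdd_sampled = sqlCtx.jsonRDD(json, samplingRatio=0.5)
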