Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r--  python/pyspark/sql.py  66
1 file changed, 45 insertions(+), 21 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index adc56e7ec0..950e275adb 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -45,6 +45,7 @@ __all__ = [
class DataType(object):
+
"""Spark SQL DataType"""
def __repr__(self):
@@ -62,6 +63,7 @@ class DataType(object):
class PrimitiveTypeSingleton(type):
+
"""Metaclass for PrimitiveType"""
_instances = {}
@@ -73,6 +75,7 @@ class PrimitiveTypeSingleton(type):
class PrimitiveType(DataType):
+
"""Spark SQL PrimitiveType"""
__metaclass__ = PrimitiveTypeSingleton
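Because PrimitiveType uses the PrimitiveTypeSingleton metaclass, each primitive type should behave as a cached singleton. A minimal sketch, assuming the Spark 1.x layout in which these classes are exposed directly from pyspark.sql:

    from pyspark.sql import StringType, IntegerType

    # repeated construction returns the same cached instance
    print(StringType() is StringType())    # expected: True
    print(StringType() == IntegerType())   # expected: False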
@@ -83,6 +86,7 @@ class PrimitiveType(DataType):
class StringType(PrimitiveType):
+
"""Spark SQL StringType
The data type representing string values.
@@ -90,6 +94,7 @@ class StringType(PrimitiveType):
class BinaryType(PrimitiveType):
+
"""Spark SQL BinaryType
The data type representing bytearray values.
@@ -97,6 +102,7 @@ class BinaryType(PrimitiveType):
class BooleanType(PrimitiveType):
+
"""Spark SQL BooleanType
The data type representing bool values.
@@ -104,6 +110,7 @@ class BooleanType(PrimitiveType):
class TimestampType(PrimitiveType):
+
"""Spark SQL TimestampType
The data type representing datetime.datetime values.
@@ -111,6 +118,7 @@ class TimestampType(PrimitiveType):
class DecimalType(PrimitiveType):
+
"""Spark SQL DecimalType
The data type representing decimal.Decimal values.
@@ -118,6 +126,7 @@ class DecimalType(PrimitiveType):
class DoubleType(PrimitiveType):
+
"""Spark SQL DoubleType
The data type representing float values.
@@ -125,6 +134,7 @@ class DoubleType(PrimitiveType):
class FloatType(PrimitiveType):
+
"""Spark SQL FloatType
The data type representing single precision floating-point values.
@@ -132,6 +142,7 @@ class FloatType(PrimitiveType):
class ByteType(PrimitiveType):
+
"""Spark SQL ByteType
The data type representing int values with 1 signed byte.
@@ -139,6 +150,7 @@ class ByteType(PrimitiveType):
class IntegerType(PrimitiveType):
+
"""Spark SQL IntegerType
The data type representing int values.
@@ -146,6 +158,7 @@ class IntegerType(PrimitiveType):
class LongType(PrimitiveType):
+
"""Spark SQL LongType
The data type representing long values. If any value is
@@ -155,6 +168,7 @@ class LongType(PrimitiveType):
class ShortType(PrimitiveType):
+
"""Spark SQL ShortType
The data type representing int values with 2 signed bytes.
@@ -162,6 +176,7 @@ class ShortType(PrimitiveType):
class ArrayType(DataType):
+
"""Spark SQL ArrayType
The data type representing list values. An ArrayType object
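A minimal construction sketch for ArrayType, assuming the same pyspark.sql import layout; the second argument states whether elements may be null:

    from pyspark.sql import ArrayType, StringType

    # an array of strings whose elements may contain nulls
    tags = ArrayType(StringType(), True)
    print(tags)    # roughly: ArrayType(StringType,true)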
@@ -187,10 +202,11 @@ class ArrayType(DataType):
def __str__(self):
return "ArrayType(%s,%s)" % (self.elementType,
- str(self.containsNull).lower())
+ str(self.containsNull).lower())
class MapType(DataType):
+
"""Spark SQL MapType
The data type representing dict values. A MapType object comprises
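A hedged MapType example in the same spirit; the third argument states whether values may be null:

    from pyspark.sql import MapType, StringType, IntegerType

    # a map from string keys to integer values, with nullable values
    scores = MapType(StringType(), IntegerType(), True)
    print(scores)  # roughly: MapType(StringType,IntegerType,true)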
@@ -226,10 +242,11 @@ class MapType(DataType):
def __repr__(self):
return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
- str(self.valueContainsNull).lower())
+ str(self.valueContainsNull).lower())
class StructField(DataType):
+
"""Spark SQL StructField
Represents a field in a StructType.
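A small StructField sketch; the field name and nullability flag here are illustrative:

    from pyspark.sql import StructField, StringType

    # a nullable string field named "name"
    name_field = StructField("name", StringType(), True)
    print(name_field)  # roughly: StructField(name,StringType,true)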
@@ -263,10 +280,11 @@ class StructField(DataType):
def __repr__(self):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
- str(self.nullable).lower())
+ str(self.nullable).lower())
class StructType(DataType):
+
"""Spark SQL StructType
The data type representing rows.
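A hedged sketch of assembling a row schema from StructFields (the field names are made up):

    from pyspark.sql import StructType, StructField, StringType, IntegerType

    person_schema = StructType([
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), False),
    ])
    print(person_schema)  # roughly: StructType(List(StructField(name,StringType,true),...))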
@@ -291,7 +309,7 @@ class StructType(DataType):
def __repr__(self):
return ("StructType(List(%s))" %
- ",".join(str(field) for field in self.fields))
+ ",".join(str(field) for field in self.fields))
def _parse_datatype_list(datatype_list_string):
@@ -319,7 +337,7 @@ def _parse_datatype_list(datatype_list_string):
_all_primitive_types = dict((k, v) for k, v in globals().iteritems()
- if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType)
+ if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType)
def _parse_datatype_string(datatype_string):
@@ -459,16 +477,16 @@ def _infer_schema(row):
items = sorted(row.items())
elif isinstance(row, tuple):
- if hasattr(row, "_fields"): # namedtuple
+ if hasattr(row, "_fields"): # namedtuple
items = zip(row._fields, tuple(row))
- elif hasattr(row, "__FIELDS__"): # Row
+ elif hasattr(row, "__FIELDS__"): # Row
items = zip(row.__FIELDS__, tuple(row))
elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
items = row
else:
raise ValueError("Can't infer schema from tuple")
- elif hasattr(row, "__dict__"): # object
+ elif hasattr(row, "__dict__"): # object
items = sorted(row.__dict__.items())
else:
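The branches above decide how (name, value) pairs are pulled out of a row before a schema is inferred. A hypothetical, pyspark-free sketch of just that dispatch (the helper name _items_from_row is made up):

    from collections import namedtuple

    def _items_from_row(row):
        # mirrors the dispatch shown in the hunk above
        if isinstance(row, dict):
            return sorted(row.items())
        elif isinstance(row, tuple):
            if hasattr(row, "_fields"):          # namedtuple
                return zip(row._fields, tuple(row))
            elif hasattr(row, "__FIELDS__"):     # Row
                return zip(row.__FIELDS__, tuple(row))
            elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
                return row
            raise ValueError("Can't infer schema from tuple")
        elif hasattr(row, "__dict__"):           # object
            return sorted(row.__dict__.items())
        raise ValueError("Can't infer schema")

    Person = namedtuple("Person", ["name", "age"])
    print(list(_items_from_row(Person("Alice", 1))))   # [('name', 'Alice'), ('age', 1)]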
@@ -499,7 +517,7 @@ def _create_converter(obj, dataType):
conv = lambda o: tuple(o.get(n) for n in names)
elif isinstance(obj, tuple):
- if hasattr(obj, "_fields"): # namedtuple
+ if hasattr(obj, "_fields"): # namedtuple
conv = tuple
elif hasattr(obj, "__FIELDS__"):
conv = tuple
@@ -508,7 +526,7 @@ def _create_converter(obj, dataType):
else:
raise ValueError("unexpected tuple")
- elif hasattr(obj, "__dict__"): # object
+ elif hasattr(obj, "__dict__"): # object
conv = lambda o: [o.__dict__.get(n, None) for n in names]
nested = any(_has_struct(f.dataType) for f in dataType.fields)
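For a dict row, the converter above reorders the values into a tuple that follows the schema's field names; a tiny standalone illustration (the names list is assumed to come from the StructType):

    names = ["name", "age"]
    conv = lambda o: tuple(o.get(n) for n in names)
    print(conv({"age": 1, "name": "Alice"}))   # ('Alice', 1)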
@@ -660,7 +678,7 @@ def _infer_schema_type(obj, dataType):
assert len(fs) == len(obj), \
"Obj(%s) have different length with fields(%s)" % (obj, fs)
fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True)
- for o, f in zip(obj, fs)]
+ for o, f in zip(obj, fs)]
return StructType(fields)
else:
@@ -683,6 +701,7 @@ _acceptable_types = {
StructType: (tuple, list),
}
+
def _verify_type(obj, dataType):
"""
Verify the type of obj against dataType, raise an exception if
@@ -728,7 +747,7 @@ def _verify_type(obj, dataType):
elif isinstance(dataType, StructType):
if len(obj) != len(dataType.fields):
raise ValueError("Length of object (%d) does not match with"
- "length of fields (%d)" % (len(obj), len(dataType.fields)))
+ "length of fields (%d)" % (len(obj), len(dataType.fields)))
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType)
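The StructType branch above rejects rows whose length disagrees with the schema; a hypothetical standalone sketch of that check (the helper name is made up):

    def _check_length(obj, fields):
        if len(obj) != len(fields):
            raise ValueError("Length of object (%d) does not match with "
                             "length of fields (%d)" % (len(obj), len(fields)))

    _check_length(("Alice", 1), ["name", "age"])     # passes silently
    # _check_length(("Alice",), ["name", "age"])     # would raise ValueError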
@@ -861,6 +880,7 @@ def _create_cls(dataType):
raise Exception("unexpected data type: %s" % dataType)
class Row(tuple):
+
""" Row in SchemaRDD """
__DATATYPE__ = dataType
__FIELDS__ = tuple(f.name for f in dataType.fields)
@@ -872,7 +892,7 @@ def _create_cls(dataType):
def __repr__(self):
# call collect __repr__ for nested objects
return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
- for n in self.__FIELDS__))
+ for n in self.__FIELDS__))
def __reduce__(self):
return (_restore_object, (self.__DATATYPE__, tuple(self)))
@@ -881,6 +901,7 @@ def _create_cls(dataType):
class SQLContext:
+
"""Main entry point for SparkSQL functionality.
A SQLContext can be used to create L{SchemaRDD}s, register L{SchemaRDD}s as
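A hedged end-to-end sketch of this entry point, assuming the Spark 1.x API in this file (inferSchema plus SchemaRDD.registerTempTable; the table and column names are illustrative):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local", "sql-example")
    sqlCtx = SQLContext(sc)

    rdd = sc.parallelize([Row(name="Alice", age=1), Row(name="Bob", age=5)])
    srdd = sqlCtx.inferSchema(rdd)
    srdd.registerTempTable("people")
    print(sqlCtx.sql("SELECT name FROM people WHERE age > 3").collect())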
@@ -960,7 +981,7 @@ class SQLContext:
env = MapConverter().convert(self._sc.environment,
self._sc._gateway._gateway_client)
includes = ListConverter().convert(self._sc._python_includes,
- self._sc._gateway._gateway_client)
+ self._sc._gateway._gateway_client)
self._ssql_ctx.registerPython(name,
bytearray(CloudPickleSerializer().dumps(command)),
env,
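This call appears to be the back end of UDF registration; a hedged sketch of how the public method is typically used (the UDF name and return type are illustrative, and sqlCtx is assumed to exist):

    from pyspark.sql import IntegerType

    sqlCtx.registerFunction("strLen", lambda s: len(s), IntegerType())
    # the UDF is then callable from SQL, e.g.
    # sqlCtx.sql("SELECT strLen(name) FROM people")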
@@ -1012,7 +1033,7 @@ class SQLContext:
first = rdd.first()
if not first:
raise ValueError("The first row in RDD is empty, "
- "can not infer schema")
+ "can not infer schema")
if type(first) is dict:
warnings.warn("Using RDD of dict to inferSchema is deprecated")
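As the warning above notes, dicts are a deprecated row shape for inferSchema; a hedged sketch of the preferred namedtuple form (sc and sqlCtx are assumed to exist):

    from collections import namedtuple

    Person = namedtuple("Person", ["name", "age"])
    rdd = sc.parallelize([Person("Alice", 1), Person("Bob", 5)])
    srdd = sqlCtx.inferSchema(rdd)
    print(srdd.first())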
@@ -1287,6 +1308,7 @@ class SQLContext:
class HiveContext(SQLContext):
+
"""A variant of Spark SQL that integrates with data stored in Hive.
Configuration for Hive is read from hive-site.xml on the classpath.
@@ -1327,6 +1349,7 @@ class HiveContext(SQLContext):
class LocalHiveContext(HiveContext):
+
"""Starts up an instance of hive where metadata is stored locally.
An in-process metadata database is created with data stored in ./metadata.
@@ -1357,7 +1380,7 @@ class LocalHiveContext(HiveContext):
def __init__(self, sparkContext, sqlContext=None):
HiveContext.__init__(self, sparkContext, sqlContext)
warnings.warn("LocalHiveContext is deprecated. "
- "Use HiveContext instead.", DeprecationWarning)
+ "Use HiveContext instead.", DeprecationWarning)
def _get_hive_ctx(self):
return self._jvm.LocalHiveContext(self._jsc.sc())
@@ -1376,6 +1399,7 @@ def _create_row(fields, values):
class Row(tuple):
+
"""
A row in L{SchemaRDD}. The fields in it can be accessed like attributes.
@@ -1417,7 +1441,6 @@ class Row(tuple):
else:
raise ValueError("No args or kwargs")
-
# let the object act like a class
def __call__(self, *args):
"""create new Row object"""
@@ -1443,12 +1466,13 @@ class Row(tuple):
def __repr__(self):
if hasattr(self, "__FIELDS__"):
return "Row(%s)" % ", ".join("%s=%r" % (k, v)
- for k, v in zip(self.__FIELDS__, self))
+ for k, v in zip(self.__FIELDS__, self))
else:
return "<Row(%s)>" % ", ".join(self)
class SchemaRDD(RDD):
+
"""An RDD of L{Row} objects that has an associated schema.
The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
@@ -1659,7 +1683,7 @@ class SchemaRDD(RDD):
rdd = self._jschema_rdd.subtract(other._jschema_rdd)
else:
rdd = self._jschema_rdd.subtract(other._jschema_rdd,
- numPartitions)
+ numPartitions)
return SchemaRDD(rdd, self.sql_ctx)
else:
raise ValueError("Can only subtract another SchemaRDD")
@@ -1686,9 +1710,9 @@ def _test():
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
- '"field6":[{"field7": "row2"}]}',
+ '"field6":[{"field7": "row2"}]}',
'{"field1" : null, "field2": "row3", '
- '"field3":{"field4":33, "field5": []}}'
+ '"field3":{"field4":33, "field5": []}}'
]
globs['jsonStrings'] = jsonStrings
globs['json'] = sc.parallelize(jsonStrings)
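A hedged sketch of how fixtures like these are typically consumed, assuming the Spark 1.x jsonRDD API and an existing sqlCtx:

    json_rdd = sc.parallelize(jsonStrings)
    srdd = sqlCtx.jsonRDD(json_rdd)
    srdd.registerTempTable("json_table")
    print(sqlCtx.sql("SELECT field2 FROM json_table").collect())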