Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r--  python/pyspark/sql.py  567
1 file changed, 553 insertions(+), 14 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index a6b3277db3..13f0ed4e35 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -20,7 +20,451 @@ from pyspark.serializers import BatchedSerializer, PickleSerializer
from py4j.protocol import Py4JError
-__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]
+__all__ = [
+ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
+ "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
+ "ShortType", "ArrayType", "MapType", "StructField", "StructType",
+ "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]
+
+
+class PrimitiveTypeSingleton(type):
+ _instances = {}
+
+ def __call__(cls):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__()
+ return cls._instances[cls]
+
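
Note on the metaclass above: PrimitiveTypeSingleton caches one instance per class, so repeated construction of a primitive type returns the same object and the default identity-based equality suffices for these types. A minimal sketch of the intended behavior, assuming Python 2 semantics (the __metaclass__ attribute used below is ignored by Python 3):

    >>> StringType() is StringType()
    True
    >>> IntegerType() == IntegerType()
    True
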
+
+class StringType(object):
+ """Spark SQL StringType
+
+ The data type representing string values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "StringType"
+
+
+class BinaryType(object):
+ """Spark SQL BinaryType
+
+ The data type representing bytearray values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "BinaryType"
+
+
+class BooleanType(object):
+ """Spark SQL BooleanType
+
+ The data type representing bool values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "BooleanType"
+
+
+class TimestampType(object):
+ """Spark SQL TimestampType
+
+ The data type representing datetime.datetime values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "TimestampType"
+
+
+class DecimalType(object):
+ """Spark SQL DecimalType
+
+ The data type representing decimal.Decimal values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "DecimalType"
+
+
+class DoubleType(object):
+ """Spark SQL DoubleType
+
+ The data type representing float values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "DoubleType"
+
+
+class FloatType(object):
+ """Spark SQL FloatType
+
+ The data type representing single precision floating-point values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "FloatType"
+
+
+class ByteType(object):
+ """Spark SQL ByteType
+
+    The data type representing int values with a single signed byte.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "ByteType"
+
+
+class IntegerType(object):
+ """Spark SQL IntegerType
+
+ The data type representing int values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "IntegerType"
+
+
+class LongType(object):
+ """Spark SQL LongType
+
+    The data type representing long values. If any value is beyond the range of
+ [-9223372036854775808, 9223372036854775807], please use DecimalType.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "LongType"
+
+
+class ShortType(object):
+ """Spark SQL ShortType
+
+ The data type representing int values with 2 signed bytes.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "ShortType"
+
+
+class ArrayType(object):
+ """Spark SQL ArrayType
+
+ The data type representing list values.
+ An ArrayType object comprises two fields, elementType (a DataType) and containsNull (a bool).
+ The field of elementType is used to specify the type of array elements.
+ The field of containsNull is used to specify if the array has None values.
+
+ """
+ def __init__(self, elementType, containsNull=False):
+ """Creates an ArrayType
+
+ :param elementType: the data type of elements.
+ :param containsNull: indicates whether the list contains None values.
+
+ >>> ArrayType(StringType) == ArrayType(StringType, False)
+ True
+ >>> ArrayType(StringType, True) == ArrayType(StringType)
+ False
+ """
+ self.elementType = elementType
+ self.containsNull = containsNull
+
+ def __repr__(self):
+ return "ArrayType(" + self.elementType.__repr__() + "," + \
+ str(self.containsNull).lower() + ")"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.elementType == other.elementType and
+ self.containsNull == other.containsNull)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class MapType(object):
+ """Spark SQL MapType
+
+ The data type representing dict values.
+ A MapType object comprises three fields,
+ keyType (a DataType), valueType (a DataType) and valueContainsNull (a bool).
+ The field of keyType is used to specify the type of keys in the map.
+ The field of valueType is used to specify the type of values in the map.
+    The field of valueContainsNull is used to specify if values of this map can have None values.
+    Keys of a MapType are not allowed to be None.
+
+ """
+ def __init__(self, keyType, valueType, valueContainsNull=True):
+ """Creates a MapType
+ :param keyType: the data type of keys.
+ :param valueType: the data type of values.
+        :param valueContainsNull: indicates whether values can contain None values.
+
+ >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True)
+ True
+ >>> MapType(StringType, IntegerType, False) == MapType(StringType, FloatType)
+ False
+ """
+ self.keyType = keyType
+ self.valueType = valueType
+ self.valueContainsNull = valueContainsNull
+
+ def __repr__(self):
+ return "MapType(" + self.keyType.__repr__() + "," + \
+ self.valueType.__repr__() + "," + \
+ str(self.valueContainsNull).lower() + ")"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.keyType == other.keyType and
+ self.valueType == other.valueType and
+ self.valueContainsNull == other.valueContainsNull)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class StructField(object):
+ """Spark SQL StructField
+
+ Represents a field in a StructType.
+ A StructField object comprises three fields, name (a string), dataType (a DataType),
+ and nullable (a bool). The field of name is the name of a StructField. The field of
+ dataType specifies the data type of a StructField.
+ The field of nullable specifies if values of a StructField can contain None values.
+
+ """
+ def __init__(self, name, dataType, nullable):
+ """Creates a StructField
+ :param name: the name of this field.
+ :param dataType: the data type of this field.
+ :param nullable: indicates whether values of this field can be null.
+
+ >>> StructField("f1", StringType, True) == StructField("f1", StringType, True)
+ True
+ >>> StructField("f1", StringType, True) == StructField("f2", StringType, True)
+ False
+ """
+ self.name = name
+ self.dataType = dataType
+ self.nullable = nullable
+
+ def __repr__(self):
+ return "StructField(" + self.name + "," + \
+ self.dataType.__repr__() + "," + \
+ str(self.nullable).lower() + ")"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.name == other.name and
+ self.dataType == other.dataType and
+ self.nullable == other.nullable)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class StructType(object):
+ """Spark SQL StructType
+
+ The data type representing namedtuple values.
+ A StructType object comprises a list of L{StructField}s.
+
+ """
+ def __init__(self, fields):
+ """Creates a StructType
+
+ >>> struct1 = StructType([StructField("f1", StringType, True)])
+ >>> struct2 = StructType([StructField("f1", StringType, True)])
+ >>> struct1 == struct2
+ True
+ >>> struct1 = StructType([StructField("f1", StringType, True)])
+ >>> struct2 = StructType([StructField("f1", StringType, True),
+        ...     StructField("f2", IntegerType, False)])
+ >>> struct1 == struct2
+ False
+ """
+ self.fields = fields
+
+ def __repr__(self):
+ return "StructType(List(" + \
+ ",".join([field.__repr__() for field in self.fields]) + "))"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.fields == other.fields)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
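
Note: the __repr__ strings produced by these classes double as the wire format handed to the JVM (parseDataType / applySchemaToPythonRDD below), and _parse_datatype_string reads the same format back. An illustrative sketch of the nesting, assuming the singleton equality of primitive types described above:

    >>> struct = StructType([StructField("f1", IntegerType(), False)])
    >>> struct.__repr__()
    'StructType(List(StructField(f1,IntegerType,false)))'
    >>> _parse_datatype_string(struct.__repr__()) == struct
    True
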
+
+def _parse_datatype_list(datatype_list_string):
+ """Parses a list of comma separated data types."""
+ index = 0
+ datatype_list = []
+ start = 0
+ depth = 0
+ while index < len(datatype_list_string):
+ if depth == 0 and datatype_list_string[index] == ",":
+ datatype_string = datatype_list_string[start:index].strip()
+ datatype_list.append(_parse_datatype_string(datatype_string))
+ start = index + 1
+ elif datatype_list_string[index] == "(":
+ depth += 1
+ elif datatype_list_string[index] == ")":
+ depth -= 1
+
+ index += 1
+
+ # Handle the last data type
+ datatype_string = datatype_list_string[start:index].strip()
+ datatype_list.append(_parse_datatype_string(datatype_string))
+ return datatype_list
+
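
Note: the depth counter is what keeps commas nested inside parentheses from being treated as separators. A small sketch of the expected behavior (illustrative, not part of the patch):

    >>> parsed = _parse_datatype_list("StringType,ArrayType(IntegerType,true)")
    >>> parsed == [StringType(), ArrayType(IntegerType(), True)]
    True
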
+
+def _parse_datatype_string(datatype_string):
+ """Parses the given data type string.
+
+ >>> def check_datatype(datatype):
+ ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__())
+ ... python_datatype = _parse_datatype_string(scala_datatype.toString())
+ ... return datatype == python_datatype
+ >>> check_datatype(StringType())
+ True
+ >>> check_datatype(BinaryType())
+ True
+ >>> check_datatype(BooleanType())
+ True
+ >>> check_datatype(TimestampType())
+ True
+ >>> check_datatype(DecimalType())
+ True
+ >>> check_datatype(DoubleType())
+ True
+ >>> check_datatype(FloatType())
+ True
+ >>> check_datatype(ByteType())
+ True
+ >>> check_datatype(IntegerType())
+ True
+ >>> check_datatype(LongType())
+ True
+ >>> check_datatype(ShortType())
+ True
+ >>> # Simple ArrayType.
+ >>> simple_arraytype = ArrayType(StringType(), True)
+ >>> check_datatype(simple_arraytype)
+ True
+ >>> # Simple MapType.
+ >>> simple_maptype = MapType(StringType(), LongType())
+ >>> check_datatype(simple_maptype)
+ True
+ >>> # Simple StructType.
+ >>> simple_structtype = StructType([
+ ... StructField("a", DecimalType(), False),
+ ... StructField("b", BooleanType(), True),
+ ... StructField("c", LongType(), True),
+ ... StructField("d", BinaryType(), False)])
+ >>> check_datatype(simple_structtype)
+ True
+ >>> # Complex StructType.
+ >>> complex_structtype = StructType([
+ ... StructField("simpleArray", simple_arraytype, True),
+ ... StructField("simpleMap", simple_maptype, True),
+ ... StructField("simpleStruct", simple_structtype, True),
+ ... StructField("boolean", BooleanType(), False)])
+ >>> check_datatype(complex_structtype)
+ True
+ >>> # Complex ArrayType.
+ >>> complex_arraytype = ArrayType(complex_structtype, True)
+ >>> check_datatype(complex_arraytype)
+ True
+ >>> # Complex MapType.
+ >>> complex_maptype = MapType(complex_structtype, complex_arraytype, False)
+ >>> check_datatype(complex_maptype)
+ True
+ """
+ left_bracket_index = datatype_string.find("(")
+ if left_bracket_index == -1:
+ # It is a primitive type.
+ left_bracket_index = len(datatype_string)
+ type_or_field = datatype_string[:left_bracket_index]
+ rest_part = datatype_string[left_bracket_index+1:len(datatype_string)-1].strip()
+ if type_or_field == "StringType":
+ return StringType()
+ elif type_or_field == "BinaryType":
+ return BinaryType()
+ elif type_or_field == "BooleanType":
+ return BooleanType()
+ elif type_or_field == "TimestampType":
+ return TimestampType()
+ elif type_or_field == "DecimalType":
+ return DecimalType()
+ elif type_or_field == "DoubleType":
+ return DoubleType()
+ elif type_or_field == "FloatType":
+ return FloatType()
+ elif type_or_field == "ByteType":
+ return ByteType()
+ elif type_or_field == "IntegerType":
+ return IntegerType()
+ elif type_or_field == "LongType":
+ return LongType()
+ elif type_or_field == "ShortType":
+ return ShortType()
+ elif type_or_field == "ArrayType":
+ last_comma_index = rest_part.rfind(",")
+ containsNull = True
+ if rest_part[last_comma_index+1:].strip().lower() == "false":
+ containsNull = False
+ elementType = _parse_datatype_string(rest_part[:last_comma_index].strip())
+ return ArrayType(elementType, containsNull)
+ elif type_or_field == "MapType":
+ last_comma_index = rest_part.rfind(",")
+ valueContainsNull = True
+ if rest_part[last_comma_index+1:].strip().lower() == "false":
+ valueContainsNull = False
+ keyType, valueType = _parse_datatype_list(rest_part[:last_comma_index].strip())
+ return MapType(keyType, valueType, valueContainsNull)
+ elif type_or_field == "StructField":
+ first_comma_index = rest_part.find(",")
+ name = rest_part[:first_comma_index].strip()
+ last_comma_index = rest_part.rfind(",")
+ nullable = True
+ if rest_part[last_comma_index+1:].strip().lower() == "false":
+ nullable = False
+ dataType = _parse_datatype_string(
+ rest_part[first_comma_index+1:last_comma_index].strip())
+ return StructField(name, dataType, nullable)
+ elif type_or_field == "StructType":
+ # rest_part should be in the format like
+ # List(StructField(field1,IntegerType,false)).
+ field_list_string = rest_part[rest_part.find("(")+1:-1]
+ fields = _parse_datatype_list(field_list_string)
+ return StructType(fields)
class SQLContext:
@@ -109,6 +553,40 @@ class SQLContext:
srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
return SchemaRDD(srdd, self)
+ def applySchema(self, rdd, schema):
+ """Applies the given schema to the given RDD of L{dict}s.
+
+ >>> schema = StructType([StructField("field1", IntegerType(), False),
+ ... StructField("field2", StringType(), False)])
+ >>> srdd = sqlCtx.applySchema(rdd, schema)
+ >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> srdd2 = sqlCtx.sql("SELECT * from table1")
+ >>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
+ ... {"field1" : 3, "field2": "row3"}]
+ True
+ >>> from datetime import datetime
+ >>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0,
+ ... "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2},
+ ... "list": [1, 2, 3]}])
+ >>> schema = StructType([
+ ... StructField("byte", ByteType(), False),
+ ... StructField("short", ShortType(), False),
+ ... StructField("float", FloatType(), False),
+ ... StructField("time", TimestampType(), False),
+ ... StructField("map", MapType(StringType(), IntegerType(), False), False),
+ ... StructField("struct", StructType([StructField("b", ShortType(), False)]), False),
+ ... StructField("list", ArrayType(ByteType(), False), False),
+ ... StructField("null", DoubleType(), True)])
+ >>> srdd = sqlCtx.applySchema(rdd, schema).map(
+ ... lambda x: (
+ ... x.byte, x.short, x.float, x.time, x.map["a"], x.struct["b"], x.list, x.null))
+ >>> srdd.collect()[0]
+ (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+ """
+ jrdd = self._pythonToJavaMap(rdd._jrdd)
+ srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.__repr__())
+ return SchemaRDD(srdd, self)
+
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
@@ -139,10 +617,11 @@ class SQLContext:
jschema_rdd = self._ssql_ctx.parquetFile(path)
return SchemaRDD(jschema_rdd, self)
- def jsonFile(self, path):
- """Loads a text file storing one JSON object per line,
- returning the result as a L{SchemaRDD}.
- It goes through the entire dataset once to determine the schema.
+ def jsonFile(self, path, schema=None):
+ """Loads a text file storing one JSON object per line as a L{SchemaRDD}.
+
+ If the schema is provided, applies the given schema to this JSON dataset.
+ Otherwise, it goes through the entire dataset once to determine the schema.
>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
@@ -151,8 +630,8 @@ class SQLContext:
>>> for json in jsonStrings:
... print>>ofn, json
>>> ofn.close()
- >>> srdd = sqlCtx.jsonFile(jsonFile)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> srdd1 = sqlCtx.jsonFile(jsonFile)
+ >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
>>> srdd2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
>>> srdd2.collect() == [
@@ -160,16 +639,45 @@ class SQLContext:
... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
True
+ >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema())
+ >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
+ >>> srdd4 = sqlCtx.sql(
+ ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2")
+ >>> srdd4.collect() == [
+ ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
+ ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
+ ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
+ True
+ >>> schema = StructType([
+ ... StructField("field2", StringType(), True),
+ ... StructField("field3",
+ ... StructType([
+ ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)])
+ >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
+ >>> srdd6 = sqlCtx.sql(
+ ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3")
+ >>> srdd6.collect() == [
+ ... {"f1": "row1", "f2": None, "f3": None},
+ ... {"f1": None, "f2": [10, 11], "f3": 10},
+ ... {"f1": "row3", "f2": [], "f3": None}]
+ True
"""
- jschema_rdd = self._ssql_ctx.jsonFile(path)
+ if schema is None:
+ jschema_rdd = self._ssql_ctx.jsonFile(path)
+ else:
+ scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__())
+ jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(jschema_rdd, self)
- def jsonRDD(self, rdd):
- """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}.
- It goes through the entire dataset once to determine the schema.
+ def jsonRDD(self, rdd, schema=None):
+ """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
- >>> srdd = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ If the schema is provided, applies the given schema to this JSON dataset.
+ Otherwise, it goes through the entire dataset once to determine the schema.
+
+ >>> srdd1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
>>> srdd2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
>>> srdd2.collect() == [
@@ -177,6 +685,29 @@ class SQLContext:
... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
True
+ >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema())
+ >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
+ >>> srdd4 = sqlCtx.sql(
+ ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2")
+ >>> srdd4.collect() == [
+ ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
+ ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
+ ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
+ True
+ >>> schema = StructType([
+ ... StructField("field2", StringType(), True),
+ ... StructField("field3",
+ ... StructType([
+ ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)])
+ >>> srdd5 = sqlCtx.jsonRDD(json, schema)
+ >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
+ >>> srdd6 = sqlCtx.sql(
+ ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3")
+ >>> srdd6.collect() == [
+ ... {"f1": "row1", "f2": None, "f3": None},
+ ... {"f1": None, "f2": [10, 11], "f3": 10},
+ ... {"f1": "row3", "f2": [], "f3": None}]
+ True
"""
def func(split, iterator):
for x in iterator:
@@ -186,7 +717,11 @@ class SQLContext:
keyed = PipelinedRDD(rdd, func)
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
- jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+ if schema is None:
+ jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+ else:
+ scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__())
+ jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return SchemaRDD(jschema_rdd, self)
def sql(self, sqlQuery):
@@ -389,6 +924,10 @@ class SchemaRDD(RDD):
"""Creates a new table with the contents of this SchemaRDD."""
self._jschema_rdd.saveAsTable(tableName)
+ def schema(self):
+ """Returns the schema of this SchemaRDD (represented by a L{StructType})."""
+ return _parse_datatype_string(self._jschema_rdd.schema().toString())
+
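
Note: because schema() parses the JVM schema string back into the StructType/StructField classes added above, its result can be passed straight back to applySchema, jsonFile, or jsonRDD, as the doctests earlier in this change do. A minimal sketch, assuming the sqlCtx and rdd doctest fixtures this module sets up for its tests:

    >>> srdd = sqlCtx.inferSchema(rdd)
    >>> isinstance(srdd.schema(), StructType)
    True
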
def schemaString(self):
"""Returns the output schema in the tree format."""
return self._jschema_rdd.schemaString()