aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorYin Huai <huai@cse.ohio-state.edu>2014-07-30 00:15:31 -0700
committerMichael Armbrust <michael@databricks.com>2014-07-30 00:15:31 -0700
commit7003c163dbb46bb7313aab130a33486a356435a8 (patch)
treed125fa76f05683c209bd7f60a64da3f73d3c82ca /python
parent4ce92ccaf761e48a10fc4fe4927dbfca858ca22b (diff)
downloadspark-7003c163dbb46bb7313aab130a33486a356435a8.tar.gz
spark-7003c163dbb46bb7313aab130a33486a356435a8.tar.bz2
spark-7003c163dbb46bb7313aab130a33486a356435a8.zip
[SPARK-2179][SQL] Public API for DataTypes and Schema
The current PR contains the following changes: * Expose `DataType`s in the sql package (internal details are private to sql). * Users can create Rows. * Introduce `applySchema` to create a `SchemaRDD` by applying a `schema: StructType` to an `RDD[Row]`. * Add a function `simpleString` to every `DataType`. Also, the schema represented by a `StructType` can be visualized by `printSchema`. * `ScalaReflection.typeOfObject` provides a way to infer the Catalyst data type based on an object. Also, we can compose `typeOfObject` with some custom logics to form a new function to infer the data type (for different use cases). * `JsonRDD` has been refactored to use changes introduced by this PR. * Add a field `containsNull` to `ArrayType`. So, we can explicitly mark if an `ArrayType` can contain null values. The default value of `containsNull` is `false`. New APIs are introduced in the sql package object and SQLContext. You can find the scaladoc at [sql package object](http://yhuai.github.io/site/api/scala/index.html#org.apache.spark.sql.package) and [SQLContext](http://yhuai.github.io/site/api/scala/index.html#org.apache.spark.sql.SQLContext). An example of using `applySchema` is shown below. ```scala import org.apache.spark.sql._ val sqlContext = new org.apache.spark.sql.SQLContext(sc) val schema = StructType( StructField("name", StringType, false) :: StructField("age", IntegerType, true) :: Nil) val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Row(p(0), p(1).trim.toInt)) val peopleSchemaRDD = sqlContext. applySchema(people, schema) peopleSchemaRDD.printSchema // root // |-- name: string (nullable = false) // |-- age: integer (nullable = true) peopleSchemaRDD.registerAsTable("people") sqlContext.sql("select name from people").collect.foreach(println) ``` I will add new contents to the SQL programming guide later. JIRA: https://issues.apache.org/jira/browse/SPARK-2179 Author: Yin Huai <huai@cse.ohio-state.edu> Closes #1346 from yhuai/dataTypeAndSchema and squashes the following commits: 1d45977 [Yin Huai] Clean up. a6e08b4 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema c712fbf [Yin Huai] Converts types of values based on defined schema. 4ceeb66 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema e5f8df5 [Yin Huai] Scaladoc. 122d1e7 [Yin Huai] Address comments. 03bfd95 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema 2476ed0 [Yin Huai] Minor updates. ab71f21 [Yin Huai] Format. fc2bed1 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema bd40a33 [Yin Huai] Address comments. 991f860 [Yin Huai] Move "asJavaDataType" and "asScalaDataType" to DataTypeConversions.scala. 1cb35fe [Yin Huai] Add "valueContainsNull" to MapType. 3edb3ae [Yin Huai] Python doc. 692c0b9 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema 1d93395 [Yin Huai] Python APIs. 246da96 [Yin Huai] Add java data type APIs to javadoc index. 1db9531 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema d48fc7b [Yin Huai] Minor updates. 33c4fec [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema b9f3071 [Yin Huai] Java API for applySchema. 1c9f33c [Yin Huai] Java APIs for DataTypes and Row. 624765c [Yin Huai] Tests for applySchema. aa92e84 [Yin Huai] Update data type tests. 8da1a17 [Yin Huai] Add Row.fromSeq. 9c99bc0 [Yin Huai] Several minor updates. 1d9c13a [Yin Huai] Update applySchema API. 85e9b51 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema e495e4e [Yin Huai] More comments. 42d47a3 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema c3f4a02 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema 2e58dbd [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema b8b7db4 [Yin Huai] 1. Move sql package object and package-info to sql-core. 2. Minor updates on APIs. 3. Update scala doc. 68525a2 [Yin Huai] Update JSON unit test. 3209108 [Yin Huai] Add unit tests. dcaf22f [Yin Huai] Add a field containsNull to ArrayType to indicate if an array can contain null values or not. If an ArrayType is constructed by "ArrayType(elementType)" (the existing constructor), the value of containsNull is false. 9168b83 [Yin Huai] Update comments. fc649d7 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema eca7d04 [Yin Huai] Add two apply methods which will be used to extract StructField(s) from a StructType. 949d6bb [Yin Huai] When creating a SchemaRDD for a JSON dataset, users can apply an existing schema. 7a6a7e5 [Yin Huai] Fix bug introduced by the change made on SQLContext.inferSchema. 43a45e1 [Yin Huai] Remove sql.util.package introduced in a previous commit. 0266761 [Yin Huai] Format 03eec4c [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema 90460ac [Yin Huai] Infer the Catalyst data type from an object and cast a data value to the expected type. 3fa0df5 [Yin Huai] Provide easier ways to construct a StructType. 16be3e5 [Yin Huai] This commit contains three changes: * Expose `DataType`s in the sql package (internal details are private to sql). * Introduce `createSchemaRDD` to create a `SchemaRDD` from an `RDD` with a provided schema (represented by a `StructType`) and a provided function to construct `Row`, * Add a function `simpleString` to every `DataType`. Also, the schema represented by a `StructType` can be visualized by `printSchema`.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql.py567
1 files changed, 553 insertions, 14 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index a6b3277db3..13f0ed4e35 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -20,7 +20,451 @@ from pyspark.serializers import BatchedSerializer, PickleSerializer
from py4j.protocol import Py4JError
-__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]
+__all__ = [
+ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
+ "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
+ "ShortType", "ArrayType", "MapType", "StructField", "StructType",
+ "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]
+
+
+class PrimitiveTypeSingleton(type):
+ _instances = {}
+
+ def __call__(cls):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__()
+ return cls._instances[cls]
+
+
+class StringType(object):
+ """Spark SQL StringType
+
+ The data type representing string values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "StringType"
+
+
+class BinaryType(object):
+ """Spark SQL BinaryType
+
+ The data type representing bytearray values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "BinaryType"
+
+
+class BooleanType(object):
+ """Spark SQL BooleanType
+
+ The data type representing bool values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "BooleanType"
+
+
+class TimestampType(object):
+ """Spark SQL TimestampType
+
+ The data type representing datetime.datetime values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "TimestampType"
+
+
+class DecimalType(object):
+ """Spark SQL DecimalType
+
+ The data type representing decimal.Decimal values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "DecimalType"
+
+
+class DoubleType(object):
+ """Spark SQL DoubleType
+
+ The data type representing float values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "DoubleType"
+
+
+class FloatType(object):
+ """Spark SQL FloatType
+
+ The data type representing single precision floating-point values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "FloatType"
+
+
+class ByteType(object):
+ """Spark SQL ByteType
+
+ The data type representing int values with 1 singed byte.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "ByteType"
+
+
+class IntegerType(object):
+ """Spark SQL IntegerType
+
+ The data type representing int values.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "IntegerType"
+
+
+class LongType(object):
+ """Spark SQL LongType
+
+ The data type representing long values. If the any value is beyond the range of
+ [-9223372036854775808, 9223372036854775807], please use DecimalType.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "LongType"
+
+
+class ShortType(object):
+ """Spark SQL ShortType
+
+ The data type representing int values with 2 signed bytes.
+
+ """
+ __metaclass__ = PrimitiveTypeSingleton
+
+ def __repr__(self):
+ return "ShortType"
+
+
+class ArrayType(object):
+ """Spark SQL ArrayType
+
+ The data type representing list values.
+ An ArrayType object comprises two fields, elementType (a DataType) and containsNull (a bool).
+ The field of elementType is used to specify the type of array elements.
+ The field of containsNull is used to specify if the array has None values.
+
+ """
+ def __init__(self, elementType, containsNull=False):
+ """Creates an ArrayType
+
+ :param elementType: the data type of elements.
+ :param containsNull: indicates whether the list contains None values.
+
+ >>> ArrayType(StringType) == ArrayType(StringType, False)
+ True
+ >>> ArrayType(StringType, True) == ArrayType(StringType)
+ False
+ """
+ self.elementType = elementType
+ self.containsNull = containsNull
+
+ def __repr__(self):
+ return "ArrayType(" + self.elementType.__repr__() + "," + \
+ str(self.containsNull).lower() + ")"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.elementType == other.elementType and
+ self.containsNull == other.containsNull)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class MapType(object):
+ """Spark SQL MapType
+
+ The data type representing dict values.
+ A MapType object comprises three fields,
+ keyType (a DataType), valueType (a DataType) and valueContainsNull (a bool).
+ The field of keyType is used to specify the type of keys in the map.
+ The field of valueType is used to specify the type of values in the map.
+ The field of valueContainsNull is used to specify if values of this map has None values.
+ For values of a MapType column, keys are not allowed to have None values.
+
+ """
+ def __init__(self, keyType, valueType, valueContainsNull=True):
+ """Creates a MapType
+ :param keyType: the data type of keys.
+ :param valueType: the data type of values.
+ :param valueContainsNull: indicates whether values contains null values.
+
+ >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True)
+ True
+ >>> MapType(StringType, IntegerType, False) == MapType(StringType, FloatType)
+ False
+ """
+ self.keyType = keyType
+ self.valueType = valueType
+ self.valueContainsNull = valueContainsNull
+
+ def __repr__(self):
+ return "MapType(" + self.keyType.__repr__() + "," + \
+ self.valueType.__repr__() + "," + \
+ str(self.valueContainsNull).lower() + ")"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.keyType == other.keyType and
+ self.valueType == other.valueType and
+ self.valueContainsNull == other.valueContainsNull)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class StructField(object):
+ """Spark SQL StructField
+
+ Represents a field in a StructType.
+ A StructField object comprises three fields, name (a string), dataType (a DataType),
+ and nullable (a bool). The field of name is the name of a StructField. The field of
+ dataType specifies the data type of a StructField.
+ The field of nullable specifies if values of a StructField can contain None values.
+
+ """
+ def __init__(self, name, dataType, nullable):
+ """Creates a StructField
+ :param name: the name of this field.
+ :param dataType: the data type of this field.
+ :param nullable: indicates whether values of this field can be null.
+
+ >>> StructField("f1", StringType, True) == StructField("f1", StringType, True)
+ True
+ >>> StructField("f1", StringType, True) == StructField("f2", StringType, True)
+ False
+ """
+ self.name = name
+ self.dataType = dataType
+ self.nullable = nullable
+
+ def __repr__(self):
+ return "StructField(" + self.name + "," + \
+ self.dataType.__repr__() + "," + \
+ str(self.nullable).lower() + ")"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.name == other.name and
+ self.dataType == other.dataType and
+ self.nullable == other.nullable)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class StructType(object):
+ """Spark SQL StructType
+
+ The data type representing namedtuple values.
+ A StructType object comprises a list of L{StructField}s.
+
+ """
+ def __init__(self, fields):
+ """Creates a StructType
+
+ >>> struct1 = StructType([StructField("f1", StringType, True)])
+ >>> struct2 = StructType([StructField("f1", StringType, True)])
+ >>> struct1 == struct2
+ True
+ >>> struct1 = StructType([StructField("f1", StringType, True)])
+ >>> struct2 = StructType([StructField("f1", StringType, True),
+ ... [StructField("f2", IntegerType, False)]])
+ >>> struct1 == struct2
+ False
+ """
+ self.fields = fields
+
+ def __repr__(self):
+ return "StructType(List(" + \
+ ",".join([field.__repr__() for field in self.fields]) + "))"
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.fields == other.fields)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+def _parse_datatype_list(datatype_list_string):
+ """Parses a list of comma separated data types."""
+ index = 0
+ datatype_list = []
+ start = 0
+ depth = 0
+ while index < len(datatype_list_string):
+ if depth == 0 and datatype_list_string[index] == ",":
+ datatype_string = datatype_list_string[start:index].strip()
+ datatype_list.append(_parse_datatype_string(datatype_string))
+ start = index + 1
+ elif datatype_list_string[index] == "(":
+ depth += 1
+ elif datatype_list_string[index] == ")":
+ depth -= 1
+
+ index += 1
+
+ # Handle the last data type
+ datatype_string = datatype_list_string[start:index].strip()
+ datatype_list.append(_parse_datatype_string(datatype_string))
+ return datatype_list
+
+
+def _parse_datatype_string(datatype_string):
+ """Parses the given data type string.
+
+ >>> def check_datatype(datatype):
+ ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__())
+ ... python_datatype = _parse_datatype_string(scala_datatype.toString())
+ ... return datatype == python_datatype
+ >>> check_datatype(StringType())
+ True
+ >>> check_datatype(BinaryType())
+ True
+ >>> check_datatype(BooleanType())
+ True
+ >>> check_datatype(TimestampType())
+ True
+ >>> check_datatype(DecimalType())
+ True
+ >>> check_datatype(DoubleType())
+ True
+ >>> check_datatype(FloatType())
+ True
+ >>> check_datatype(ByteType())
+ True
+ >>> check_datatype(IntegerType())
+ True
+ >>> check_datatype(LongType())
+ True
+ >>> check_datatype(ShortType())
+ True
+ >>> # Simple ArrayType.
+ >>> simple_arraytype = ArrayType(StringType(), True)
+ >>> check_datatype(simple_arraytype)
+ True
+ >>> # Simple MapType.
+ >>> simple_maptype = MapType(StringType(), LongType())
+ >>> check_datatype(simple_maptype)
+ True
+ >>> # Simple StructType.
+ >>> simple_structtype = StructType([
+ ... StructField("a", DecimalType(), False),
+ ... StructField("b", BooleanType(), True),
+ ... StructField("c", LongType(), True),
+ ... StructField("d", BinaryType(), False)])
+ >>> check_datatype(simple_structtype)
+ True
+ >>> # Complex StructType.
+ >>> complex_structtype = StructType([
+ ... StructField("simpleArray", simple_arraytype, True),
+ ... StructField("simpleMap", simple_maptype, True),
+ ... StructField("simpleStruct", simple_structtype, True),
+ ... StructField("boolean", BooleanType(), False)])
+ >>> check_datatype(complex_structtype)
+ True
+ >>> # Complex ArrayType.
+ >>> complex_arraytype = ArrayType(complex_structtype, True)
+ >>> check_datatype(complex_arraytype)
+ True
+ >>> # Complex MapType.
+ >>> complex_maptype = MapType(complex_structtype, complex_arraytype, False)
+ >>> check_datatype(complex_maptype)
+ True
+ """
+ left_bracket_index = datatype_string.find("(")
+ if left_bracket_index == -1:
+ # It is a primitive type.
+ left_bracket_index = len(datatype_string)
+ type_or_field = datatype_string[:left_bracket_index]
+ rest_part = datatype_string[left_bracket_index+1:len(datatype_string)-1].strip()
+ if type_or_field == "StringType":
+ return StringType()
+ elif type_or_field == "BinaryType":
+ return BinaryType()
+ elif type_or_field == "BooleanType":
+ return BooleanType()
+ elif type_or_field == "TimestampType":
+ return TimestampType()
+ elif type_or_field == "DecimalType":
+ return DecimalType()
+ elif type_or_field == "DoubleType":
+ return DoubleType()
+ elif type_or_field == "FloatType":
+ return FloatType()
+ elif type_or_field == "ByteType":
+ return ByteType()
+ elif type_or_field == "IntegerType":
+ return IntegerType()
+ elif type_or_field == "LongType":
+ return LongType()
+ elif type_or_field == "ShortType":
+ return ShortType()
+ elif type_or_field == "ArrayType":
+ last_comma_index = rest_part.rfind(",")
+ containsNull = True
+ if rest_part[last_comma_index+1:].strip().lower() == "false":
+ containsNull = False
+ elementType = _parse_datatype_string(rest_part[:last_comma_index].strip())
+ return ArrayType(elementType, containsNull)
+ elif type_or_field == "MapType":
+ last_comma_index = rest_part.rfind(",")
+ valueContainsNull = True
+ if rest_part[last_comma_index+1:].strip().lower() == "false":
+ valueContainsNull = False
+ keyType, valueType = _parse_datatype_list(rest_part[:last_comma_index].strip())
+ return MapType(keyType, valueType, valueContainsNull)
+ elif type_or_field == "StructField":
+ first_comma_index = rest_part.find(",")
+ name = rest_part[:first_comma_index].strip()
+ last_comma_index = rest_part.rfind(",")
+ nullable = True
+ if rest_part[last_comma_index+1:].strip().lower() == "false":
+ nullable = False
+ dataType = _parse_datatype_string(
+ rest_part[first_comma_index+1:last_comma_index].strip())
+ return StructField(name, dataType, nullable)
+ elif type_or_field == "StructType":
+ # rest_part should be in the format like
+ # List(StructField(field1,IntegerType,false)).
+ field_list_string = rest_part[rest_part.find("(")+1:-1]
+ fields = _parse_datatype_list(field_list_string)
+ return StructType(fields)
class SQLContext:
@@ -109,6 +553,40 @@ class SQLContext:
srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
return SchemaRDD(srdd, self)
+ def applySchema(self, rdd, schema):
+ """Applies the given schema to the given RDD of L{dict}s.
+
+ >>> schema = StructType([StructField("field1", IntegerType(), False),
+ ... StructField("field2", StringType(), False)])
+ >>> srdd = sqlCtx.applySchema(rdd, schema)
+ >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> srdd2 = sqlCtx.sql("SELECT * from table1")
+ >>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
+ ... {"field1" : 3, "field2": "row3"}]
+ True
+ >>> from datetime import datetime
+ >>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0,
+ ... "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2},
+ ... "list": [1, 2, 3]}])
+ >>> schema = StructType([
+ ... StructField("byte", ByteType(), False),
+ ... StructField("short", ShortType(), False),
+ ... StructField("float", FloatType(), False),
+ ... StructField("time", TimestampType(), False),
+ ... StructField("map", MapType(StringType(), IntegerType(), False), False),
+ ... StructField("struct", StructType([StructField("b", ShortType(), False)]), False),
+ ... StructField("list", ArrayType(ByteType(), False), False),
+ ... StructField("null", DoubleType(), True)])
+ >>> srdd = sqlCtx.applySchema(rdd, schema).map(
+ ... lambda x: (
+ ... x.byte, x.short, x.float, x.time, x.map["a"], x.struct["b"], x.list, x.null))
+ >>> srdd.collect()[0]
+ (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+ """
+ jrdd = self._pythonToJavaMap(rdd._jrdd)
+ srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.__repr__())
+ return SchemaRDD(srdd, self)
+
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
@@ -139,10 +617,11 @@ class SQLContext:
jschema_rdd = self._ssql_ctx.parquetFile(path)
return SchemaRDD(jschema_rdd, self)
- def jsonFile(self, path):
- """Loads a text file storing one JSON object per line,
- returning the result as a L{SchemaRDD}.
- It goes through the entire dataset once to determine the schema.
+ def jsonFile(self, path, schema=None):
+ """Loads a text file storing one JSON object per line as a L{SchemaRDD}.
+
+ If the schema is provided, applies the given schema to this JSON dataset.
+ Otherwise, it goes through the entire dataset once to determine the schema.
>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
@@ -151,8 +630,8 @@ class SQLContext:
>>> for json in jsonStrings:
... print>>ofn, json
>>> ofn.close()
- >>> srdd = sqlCtx.jsonFile(jsonFile)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> srdd1 = sqlCtx.jsonFile(jsonFile)
+ >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
>>> srdd2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
>>> srdd2.collect() == [
@@ -160,16 +639,45 @@ class SQLContext:
... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
True
+ >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema())
+ >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
+ >>> srdd4 = sqlCtx.sql(
+ ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2")
+ >>> srdd4.collect() == [
+ ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
+ ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
+ ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
+ True
+ >>> schema = StructType([
+ ... StructField("field2", StringType(), True),
+ ... StructField("field3",
+ ... StructType([
+ ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)])
+ >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
+ >>> srdd6 = sqlCtx.sql(
+ ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3")
+ >>> srdd6.collect() == [
+ ... {"f1": "row1", "f2": None, "f3": None},
+ ... {"f1": None, "f2": [10, 11], "f3": 10},
+ ... {"f1": "row3", "f2": [], "f3": None}]
+ True
"""
- jschema_rdd = self._ssql_ctx.jsonFile(path)
+ if schema is None:
+ jschema_rdd = self._ssql_ctx.jsonFile(path)
+ else:
+ scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__())
+ jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(jschema_rdd, self)
- def jsonRDD(self, rdd):
- """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}.
- It goes through the entire dataset once to determine the schema.
+ def jsonRDD(self, rdd, schema=None):
+ """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
- >>> srdd = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ If the schema is provided, applies the given schema to this JSON dataset.
+ Otherwise, it goes through the entire dataset once to determine the schema.
+
+ >>> srdd1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
>>> srdd2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
>>> srdd2.collect() == [
@@ -177,6 +685,29 @@ class SQLContext:
... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
True
+ >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema())
+ >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
+ >>> srdd4 = sqlCtx.sql(
+ ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2")
+ >>> srdd4.collect() == [
+ ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
+ ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
+ ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
+ True
+ >>> schema = StructType([
+ ... StructField("field2", StringType(), True),
+ ... StructField("field3",
+ ... StructType([
+ ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)])
+ >>> srdd5 = sqlCtx.jsonRDD(json, schema)
+ >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
+ >>> srdd6 = sqlCtx.sql(
+ ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3")
+ >>> srdd6.collect() == [
+ ... {"f1": "row1", "f2": None, "f3": None},
+ ... {"f1": None, "f2": [10, 11], "f3": 10},
+ ... {"f1": "row3", "f2": [], "f3": None}]
+ True
"""
def func(split, iterator):
for x in iterator:
@@ -186,7 +717,11 @@ class SQLContext:
keyed = PipelinedRDD(rdd, func)
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
- jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+ if schema is None:
+ jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+ else:
+ scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__())
+ jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return SchemaRDD(jschema_rdd, self)
def sql(self, sqlQuery):
@@ -389,6 +924,10 @@ class SchemaRDD(RDD):
"""Creates a new table with the contents of this SchemaRDD."""
self._jschema_rdd.saveAsTable(tableName)
+ def schema(self):
+ """Returns the schema of this SchemaRDD (represented by a L{StructType})."""
+ return _parse_datatype_string(self._jschema_rdd.schema().toString())
+
def schemaString(self):
"""Returns the output schema in the tree format."""
return self._jschema_rdd.schemaString()