Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r--  python/pyspark/sql.py  363
1 file changed, 363 insertions, 0 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
new file mode 100644
index 0000000000..67e6eee3f4
--- /dev/null
+++ b/python/pyspark/sql.py
@@ -0,0 +1,363 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.rdd import RDD
+
+from py4j.protocol import Py4JError
+
+__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]
+
+
+class SQLContext:
+ """
+    Main entry point for SparkSQL functionality. A SQLContext can be used to create
+    L{SchemaRDD}s, register L{SchemaRDD}s as tables, execute SQL over tables, cache tables, and
+    read Parquet files.
+ """
+
+ def __init__(self, sparkContext):
+ """
+ Create a new SQLContext.
+
+ @param sparkContext: The SparkContext to wrap.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+
+ >>> bad_rdd = sc.parallelize([1,2,3])
+ >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+
+ >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
+ ... "boolean" : True}])
+ >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
+ ... x.boolean))
+ >>> srdd.collect()[0]
+ (1, u'string', 1.0, 1, True)
+ """
+ self._sc = sparkContext
+ self._jsc = self._sc._jsc
+ self._jvm = self._sc._jvm
+ self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap
+
+ @property
+ def _ssql_ctx(self):
+ """
+        Accessor for the JVM SparkSQL context. Subclasses can override this property to provide
+        their own JVM contexts.
+ """
+ if not hasattr(self, '_scala_SQLContext'):
+ self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
+ return self._scala_SQLContext
+
+ def inferSchema(self, rdd):
+ """
+        Infer and apply a schema to an RDD of L{dict}s. We peek at the first row of the RDD to
+        determine the field names and types, and then use that to extract all the dictionaries.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
+ ... {"field1" : 3, "field2": "row3"}]
+ True
+ """
+ if (rdd.__class__ is SchemaRDD):
+ raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
+ elif not isinstance(rdd.first(), dict):
+ raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
+ (SchemaRDD.__name__, rdd.first()))
+
+ jrdd = self._pythonToJavaMap(rdd._jrdd)
+ srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
+ return SchemaRDD(srdd, self)
+
+ def registerRDDAsTable(self, rdd, tableName):
+ """
+ Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
+ during the lifetime of this instance of SQLContext.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ """
+ if (rdd.__class__ is SchemaRDD):
+ jschema_rdd = rdd._jschema_rdd
+ self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
+ else:
+ raise ValueError("Can only register SchemaRDD as table")
+
+ def parquetFile(self, path):
+ """
+ Loads a Parquet file, returning the result as a L{SchemaRDD}.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.saveAsParquetFile("/tmp/tmp.parquet")
+ >>> srdd2 = sqlCtx.parquetFile("/tmp/tmp.parquet")
+ >>> srdd.collect() == srdd2.collect()
+ True
+ """
+ jschema_rdd = self._ssql_ctx.parquetFile(path)
+ return SchemaRDD(jschema_rdd, self)
+
+ def sql(self, sqlQuery):
+ """
+ Executes a SQL query using Spark, returning the result as a L{SchemaRDD}.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+ >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
+ ... {"f1" : 3, "f2": "row3"}]
+ True
+ """
+ return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+
+ def table(self, tableName):
+ """
+ Returns the specified table as a L{SchemaRDD}.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> srdd2 = sqlCtx.table("table1")
+ >>> srdd.collect() == srdd2.collect()
+ True
+ """
+ return SchemaRDD(self._ssql_ctx.table(tableName), self)
+
+    def cacheTable(self, tableName):
+ """
+ Caches the specified table in-memory.
+ """
+ self._ssql_ctx.cacheTable(tableName)
+
+    def uncacheTable(self, tableName):
+ """
+ Removes the specified table from the in-memory cache.
+ """
+ self._ssql_ctx.uncacheTable(tableName)
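+
+    # Usage sketch (illustrative, not part of the doctest suite): cache a registered table and
+    # release it again.
+    #   srdd = sqlCtx.inferSchema(rdd)
+    #   sqlCtx.registerRDDAsTable(srdd, "table1")
+    #   sqlCtx.cacheTable("table1")
+    #   sqlCtx.sql("SELECT * FROM table1").count()  # subsequent scans can be served from memory
+    #   sqlCtx.uncacheTable("table1")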
+
+
+class HiveContext(SQLContext):
+ """
+ An instance of the Spark SQL execution engine that integrates with data stored in Hive.
+ Configuration for Hive is read from hive-site.xml on the classpath. It supports running both SQL
+ and HiveQL commands.
+ """
+
+ @property
+ def _ssql_ctx(self):
+ try:
+ if not hasattr(self, '_scala_HiveContext'):
+ self._scala_HiveContext = self._get_hive_ctx()
+ return self._scala_HiveContext
+ except Py4JError as e:
+            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
+                            "sbt/sbt assembly", e)
+
+ def _get_hive_ctx(self):
+ return self._jvm.HiveContext(self._jsc.sc())
+
+ def hiveql(self, hqlQuery):
+ """
+ Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
+ """
+ return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)
+
+ def hql(self, hqlQuery):
+ """
+ Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
+ """
+ return self.hiveql(hqlQuery)
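+
+    # Usage sketch (illustrative, not a doctest; assumes a Hive-enabled build and an existing
+    # Hive table named 'src', like the one created in the LocalHiveContext example below):
+    #   hiveCtx = HiveContext(sc)
+    #   rows = hiveCtx.hql("SELECT key, value FROM src")
+    #   rows.count()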
+
+
+class LocalHiveContext(HiveContext):
+ """
+    Starts up an instance of Hive where metadata is stored locally. An in-process metadata
+    database is created, with its data stored in ./metadata. Warehouse data is stored in
+    ./warehouse.
+
+ >>> import os
+ >>> hiveCtx = LocalHiveContext(sc)
+ >>> try:
+    ...     suppress = hiveCtx.hql("DROP TABLE src")
+ ... except Exception:
+ ... pass
+ >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
+    >>> suppress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+    >>> suppress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
+ >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
+ >>> num = results.count()
+ >>> reduce_sum = results.reduce(lambda x, y: x + y)
+ >>> num
+ 500
+ >>> reduce_sum
+ 130091
+ """
+
+ def _get_hive_ctx(self):
+ return self._jvm.LocalHiveContext(self._jsc.sc())
+
+
+class TestHiveContext(HiveContext):
+
+ def _get_hive_ctx(self):
+ return self._jvm.TestHiveContext(self._jsc.sc())
+
+
+# TODO: Investigate if it is more efficient to use a namedtuple. One problem is that named tuples
+# are custom classes that must be generated per Schema.
+class Row(dict):
+ """
+ An extended L{dict} that takes a L{dict} in its constructor, and exposes those items as fields.
+
+ >>> r = Row({"hello" : "world", "foo" : "bar"})
+ >>> r.hello
+ 'world'
+ >>> r.foo
+ 'bar'
+ """
+
+ def __init__(self, d):
+ d.update(self.__dict__)
+ self.__dict__ = d
+ dict.__init__(self, d)
+
+
+class SchemaRDD(RDD):
+ """
+    An RDD of L{Row} objects that has an associated schema. The underlying JVM object is a
+    SchemaRDD, not a PythonRDD, so we can utilize the relational query API exposed by SparkSQL.
+
+    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the L{SchemaRDD} is not operated
+    on directly, as its underlying implementation is an RDD composed of Java objects. Instead it
+    is converted to a PythonRDD in the JVM, on which Python operations can be done.
+ """
+
+ def __init__(self, jschema_rdd, sql_ctx):
+ self.sql_ctx = sql_ctx
+ self._sc = sql_ctx._sc
+ self._jschema_rdd = jschema_rdd
+
+ self.is_cached = False
+ self.is_checkpointed = False
+ self.ctx = self.sql_ctx._sc
+ self._jrdd_deserializer = self.ctx.serializer
+
+ @property
+ def _jrdd(self):
+ """
+        Lazy evaluation of the PythonRDD object. It is only created when a user calls methods
+        defined by the L{pyspark.rdd.RDD} super class (map, count, etc.).
+ """
+ if not hasattr(self, '_lazy_jrdd'):
+ self._lazy_jrdd = self._toPython()._jrdd
+ return self._lazy_jrdd
+
+ @property
+ def _id(self):
+ return self._jrdd.id()
+
+ def saveAsParquetFile(self, path):
+ """
+        Saves the contents of this L{SchemaRDD} as a Parquet file, preserving the schema. Files
+ that are written out using this method can be read back in as a SchemaRDD using the
+ L{SQLContext.parquetFile} method.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.saveAsParquetFile("/tmp/test.parquet")
+ >>> srdd2 = sqlCtx.parquetFile("/tmp/test.parquet")
+ >>> srdd2.collect() == srdd.collect()
+ True
+ """
+ self._jschema_rdd.saveAsParquetFile(path)
+
+ def registerAsTable(self, name):
+ """
+ Registers this RDD as a temporary table using the given name. The lifetime of this temporary
+ table is tied to the L{SQLContext} that was used to create this SchemaRDD.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.registerAsTable("test")
+ >>> srdd2 = sqlCtx.sql("select * from test")
+ >>> srdd.collect() == srdd2.collect()
+ True
+ """
+ self._jschema_rdd.registerAsTable(name)
+
+ def _toPython(self):
+        # We have to import the Row class explicitly, so that the reference the Pickler holds is
+        # pyspark.sql.Row instead of __main__.Row
+ from pyspark.sql import Row
+ jrdd = self._jschema_rdd.javaToPython()
+ # TODO: This is inefficient, we should construct the Python Row object
+ # in Java land in the javaToPython function. May require a custom
+ # pickle serializer in Pyrolite
+ return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d))
+
+ # We override the default cache/persist/checkpoint behavior as we want to cache the underlying
+ # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class
+ def cache(self):
+ self.is_cached = True
+ self._jschema_rdd.cache()
+ return self
+
+ def persist(self, storageLevel):
+ self.is_cached = True
+ javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
+ self._jschema_rdd.persist(javaStorageLevel)
+ return self
+
+ def unpersist(self):
+ self.is_cached = False
+ self._jschema_rdd.unpersist()
+ return self
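+
+    # Illustrative usage (not a doctest): caching operates on the JVM-side SchemaRDD, e.g.
+    #   srdd = sqlCtx.inferSchema(rdd).cache()
+    #   srdd.count()  # the first action materializes and caches the underlying SchemaRDD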
+
+ def checkpoint(self):
+ self.is_checkpointed = True
+ self._jschema_rdd.checkpoint()
+
+ def isCheckpointed(self):
+ return self._jschema_rdd.isCheckpointed()
+
+ def getCheckpointFile(self):
+ checkpointFile = self._jschema_rdd.getCheckpointFile()
+ if checkpointFile.isDefined():
+ return checkpointFile.get()
+ else:
+ return None
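+
+    # Illustrative usage (not a doctest; assumes a checkpoint directory has been set on the
+    # SparkContext):
+    #   sc.setCheckpointDir("/tmp/spark-checkpoints")
+    #   srdd = sqlCtx.inferSchema(rdd)
+    #   srdd.checkpoint()
+    #   srdd.count()               # an action triggers the actual checkpoint
+    #   srdd.isCheckpointed()      # -> True
+    #   srdd.getCheckpointFile()   # -> the checkpoint path, or None if never checkpointed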
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ globs = globals().copy()
+ # The small batch size here ensures that we see multiple batches,
+ # even in these small test examples:
+ sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['sc'] = sc
+ globs['sqlCtx'] = SQLContext(sc)
+ globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
+ {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
+