From c99bcb7feaa761c5826f2e1d844d0502a3b79538 Mon Sep 17 00:00:00 2001
From: Ahir Reddy
Date: Tue, 15 Apr 2014 00:07:55 -0700
Subject: SPARK-1374: PySpark API for SparkSQL

An initial API that exposes SparkSQL functionality in PySpark. A PythonRDD
composed of dictionaries, with string keys and primitive values (boolean,
float, int, long, string), can be converted into a SchemaRDD that supports
SQL queries.

```
from pyspark.sql import SQLContext
sqlCtx = SQLContext(sc)
rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
srdd = sqlCtx.inferSchema(rdd)
sqlCtx.registerRDDAsTable(srdd, "table1")
srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
srdd2.collect()
```

The last line yields:

```
[{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"}, {"f1" : 3, "f2": "row3"}]
```
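The patch also adds a HiveContext with hql()/hiveql() methods for querying data already stored in Hive. A minimal sketch of that path, assuming Spark was assembled with Hive support (SPARK_HIVE=true sbt/sbt assembly), that `sc` is an existing SparkContext, and that a Hive table named `src` (key INT, value STRING) already exists in the metastore:

```
from pyspark.sql import HiveContext

hiveCtx = HiveContext(sc)
# hql() parses a HiveQL query and returns a SchemaRDD of Row objects;
# columns are accessible as attributes on each Row.
rows = hiveCtx.hql("SELECT key, value FROM src")
rows.map(lambda r: r.value).take(5)
```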
Author: Ahir Reddy
Author: Michael Armbrust

Closes #363 from ahirreddy/pysql and squashes the following commits:

0294497 [Ahir Reddy] Updated log4j properties to supress Hive Warns
307d6e0 [Ahir Reddy] Style fix
6f7b8f6 [Ahir Reddy] Temporary fix MIMA checker. Since we now assemble Spark jar with Hive, we don't want to check the interfaces of all of our hive dependencies
3ef074a [Ahir Reddy] Updated documentation because classes moved to sql.py
29245bf [Ahir Reddy] Cache underlying SchemaRDD instead of generating and caching PythonRDD
f2312c7 [Ahir Reddy] Moved everything into sql.py
a19afe4 [Ahir Reddy] Doc fixes
6d658ba [Ahir Reddy] Remove the metastore directory created by the HiveContext tests in SparkSQL
521ff6d [Ahir Reddy] Trying to get spark to build with hive
ab95eba [Ahir Reddy] Set SPARK_HIVE=true on jenkins
ded03e7 [Ahir Reddy] Added doc test for HiveContext
22de1d4 [Ahir Reddy] Fixed maven pyrolite dependency
e4da06c [Ahir Reddy] Display message if hive is not built into spark
227a0be [Michael Armbrust] Update API links. Fix Hive example.
58e2aa9 [Michael Armbrust] Build Docs for pyspark SQL Api. Minor fixes.
4285340 [Michael Armbrust] Fix building of Hive API Docs.
38a92b0 [Michael Armbrust] Add note to future non-python developers about python docs.
337b201 [Ahir Reddy] Changed com.clearspring.analytics stream version from 2.4.0 to 2.5.1 to match SBT build, and added pyrolite to maven build
40491c9 [Ahir Reddy] PR Changes + Method Visibility
1836944 [Michael Armbrust] Fix comments.
e00980f [Michael Armbrust] First draft of python sql programming guide.
b0192d3 [Ahir Reddy] Added Long, Double and Boolean as usable types + unit test
f98a422 [Ahir Reddy] HiveContexts
79621cf [Ahir Reddy] cleaning up cruft
b406ba0 [Ahir Reddy] doctest formatting
20936a5 [Ahir Reddy] Added tests and documentation
e4d21b4 [Ahir Reddy] Added pyrolite dependency
79f739d [Ahir Reddy] added more tests
7515ba0 [Ahir Reddy] added more tests :)
d26ec5e [Ahir Reddy] added test
e9f5b8d [Ahir Reddy] adding tests
906d180 [Ahir Reddy] added todo explaining cost of creating Row object in python
251f99d [Ahir Reddy] for now only allow dictionaries as input
09b9980 [Ahir Reddy] made jrdd explicitly lazy
c608947 [Ahir Reddy] SchemaRDD now has all RDD operations
725c91e [Ahir Reddy] awesome row objects
55d1c76 [Ahir Reddy] return row objects
4fe1319 [Ahir Reddy] output dictionaries correctly
be079de [Ahir Reddy] returning dictionaries works
cd5f79f [Ahir Reddy] Switched to using Scala SQLContext
e948bd9 [Ahir Reddy] yippie
4886052 [Ahir Reddy] even better
c0fb1c6 [Ahir Reddy] more working
043ca85 [Ahir Reddy] working
5496f9f [Ahir Reddy] doesn't crash
b8b904b [Ahir Reddy] Added schema rdd class
67ba875 [Ahir Reddy] java to python, and python to java
bcc0f23 [Ahir Reddy] Java to python
ab6025d [Ahir Reddy] compiling
---
 python/pyspark/sql.py | 363 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 363 insertions(+)
 create mode 100644 python/pyspark/sql.py

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
new file mode 100644
index 0000000000..67e6eee3f4
--- /dev/null
+++ b/python/pyspark/sql.py
@@ -0,0 +1,363 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.rdd import RDD
+
+from py4j.protocol import Py4JError
+
+__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]
+
+
+class SQLContext:
+    """
+    Main entry point for SparkSQL functionality. A SQLContext can be used to create L{SchemaRDD}s,
+    register L{SchemaRDD}s as tables, execute SQL over tables, cache tables, and read Parquet files.
+    """
+
+    def __init__(self, sparkContext):
+        """
+        Create a new SQLContext.
+
+        @param sparkContext: The SparkContext to wrap.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        ValueError:...
+
+        >>> bad_rdd = sc.parallelize([1,2,3])
+        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        ValueError:...
+
+        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
+        ... "boolean" : True}])
+        >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
+        ... x.boolean))
+        >>> srdd.collect()[0]
+        (1, u'string', 1.0, 1, True)
+        """
+        self._sc = sparkContext
+        self._jsc = self._sc._jsc
+        self._jvm = self._sc._jvm
+        self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap
+
+    @property
+    def _ssql_ctx(self):
+        """
+        Accessor for the JVM SparkSQL context. Subclasses can override this property to provide
+        their own JVM Contexts.
+        """
+        if not hasattr(self, '_scala_SQLContext'):
+            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
+        return self._scala_SQLContext
+
+    def inferSchema(self, rdd):
+        """
+        Infer and apply a schema to an RDD of L{dict}s. We peek at the first row of the RDD to
+        determine the field names and types, and then use that to extract all the dictionaries.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
+        ...                    {"field1" : 3, "field2": "row3"}]
+        True
+        """
+        if (rdd.__class__ is SchemaRDD):
+            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
+        elif not isinstance(rdd.first(), dict):
+            raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
+                             (SchemaRDD.__name__, rdd.first()))
+
+        jrdd = self._pythonToJavaMap(rdd._jrdd)
+        srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
+        return SchemaRDD(srdd, self)
+
+    def registerRDDAsTable(self, rdd, tableName):
+        """
+        Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
+        during the lifetime of this instance of SQLContext.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+        """
+        if (rdd.__class__ is SchemaRDD):
+            jschema_rdd = rdd._jschema_rdd
+            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
+        else:
+            raise ValueError("Can only register SchemaRDD as table")
+
+    def parquetFile(self, path):
+        """
+        Loads a Parquet file, returning the result as a L{SchemaRDD}.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.saveAsParquetFile("/tmp/tmp.parquet")
+        >>> srdd2 = sqlCtx.parquetFile("/tmp/tmp.parquet")
+        >>> srdd.collect() == srdd2.collect()
+        True
+        """
+        jschema_rdd = self._ssql_ctx.parquetFile(path)
+        return SchemaRDD(jschema_rdd, self)
+
+    def sql(self, sqlQuery):
+        """
+        Executes a SQL query using Spark, returning the result as a L{SchemaRDD}.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
+        ...                     {"f1" : 3, "f2": "row3"}]
+        True
+        """
+        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+
+    def table(self, tableName):
+        """
+        Returns the specified table as a L{SchemaRDD}.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+        >>> srdd2 = sqlCtx.table("table1")
+        >>> srdd.collect() == srdd2.collect()
+        True
+        """
+        return SchemaRDD(self._ssql_ctx.table(tableName), self)
+
+    def cacheTable(self, tableName):
+        """
+        Caches the specified table in-memory.
+        """
+        self._ssql_ctx.cacheTable(tableName)
+
+    def uncacheTable(self, tableName):
+        """
+        Removes the specified table from the in-memory cache.
+        """
+        self._ssql_ctx.uncacheTable(tableName)
+
+
+class HiveContext(SQLContext):
+    """
+    An instance of the Spark SQL execution engine that integrates with data stored in Hive.
+    Configuration for Hive is read from hive-site.xml on the classpath. It supports running both SQL
+    and HiveQL commands.
+    """
+
+    @property
+    def _ssql_ctx(self):
+        try:
+            if not hasattr(self, '_scala_HiveContext'):
+                self._scala_HiveContext = self._get_hive_ctx()
+            return self._scala_HiveContext
+        except Py4JError as e:
+            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " \
+                            "sbt/sbt assembly", e)
+
+    def _get_hive_ctx(self):
+        return self._jvm.HiveContext(self._jsc.sc())
+
+    def hiveql(self, hqlQuery):
+        """
+        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
+        """
+        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)
+
+    def hql(self, hqlQuery):
+        """
+        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
+        """
+        return self.hiveql(hqlQuery)
+
+
+class LocalHiveContext(HiveContext):
+    """
+    Starts up an instance of Hive where metadata is stored locally. An in-process metastore is
+    created, with its data stored in ./metadata. Warehouse data is stored in ./warehouse.
+
+    >>> import os
+    >>> hiveCtx = LocalHiveContext(sc)
+    >>> try:
+    ...     suppress = hiveCtx.hql("DROP TABLE src")
+    ... except Exception:
+    ...     pass
+    >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
+    >>> suppress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+    >>> suppress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
+    >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
+    >>> num = results.count()
+    >>> reduce_sum = results.reduce(lambda x, y: x + y)
+    >>> num
+    500
+    >>> reduce_sum
+    130091
+    """
+
+    def _get_hive_ctx(self):
+        return self._jvm.LocalHiveContext(self._jsc.sc())
+
+
+class TestHiveContext(HiveContext):
+
+    def _get_hive_ctx(self):
+        return self._jvm.TestHiveContext(self._jsc.sc())
+
+
+# TODO: Investigate if it is more efficient to use a namedtuple. One problem is that named tuples
+# are custom classes that must be generated per Schema.
+class Row(dict):
+    """
+    An extended L{dict} that takes a L{dict} in its constructor, and exposes those items as fields.
+
+    >>> r = Row({"hello" : "world", "foo" : "bar"})
+    >>> r.hello
+    'world'
+    >>> r.foo
+    'bar'
+    """
+
+    def __init__(self, d):
+        d.update(self.__dict__)
+        self.__dict__ = d
+        dict.__init__(self, d)
+
+
+class SchemaRDD(RDD):
+    """
+    An RDD of L{Row} objects that has an associated schema. The underlying JVM object is a SchemaRDD,
+    not a PythonRDD, so we can utilize the relational query API exposed by SparkSQL.
+
+    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the L{SchemaRDD} is not operated on
+    directly, as its underlying implementation is an RDD composed of Java objects. Instead it is
+    converted to a PythonRDD in the JVM, on which Python operations can be done.
+    """
+
+    def __init__(self, jschema_rdd, sql_ctx):
+        self.sql_ctx = sql_ctx
+        self._sc = sql_ctx._sc
+        self._jschema_rdd = jschema_rdd
+
+        self.is_cached = False
+        self.is_checkpointed = False
+        self.ctx = self.sql_ctx._sc
+        self._jrdd_deserializer = self.ctx.serializer
+
+    @property
+    def _jrdd(self):
+        """
+        Lazy evaluation of PythonRDD object. Only done when a user calls methods defined by the
+        L{pyspark.rdd.RDD} super class (map, count, etc.).
+        """
+        if not hasattr(self, '_lazy_jrdd'):
+            self._lazy_jrdd = self._toPython()._jrdd
+        return self._lazy_jrdd
+
+    @property
+    def _id(self):
+        return self._jrdd.id()
+
+    def saveAsParquetFile(self, path):
+        """
+        Saves the contents of this L{SchemaRDD} as a Parquet file, preserving the schema. Files
+        that are written out using this method can be read back in as a SchemaRDD using the
+        L{SQLContext.parquetFile} method.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.saveAsParquetFile("/tmp/test.parquet")
+        >>> srdd2 = sqlCtx.parquetFile("/tmp/test.parquet")
+        >>> srdd2.collect() == srdd.collect()
+        True
+        """
+        self._jschema_rdd.saveAsParquetFile(path)
+
+    def registerAsTable(self, name):
+        """
+        Registers this RDD as a temporary table using the given name. The lifetime of this temporary
+        table is tied to the L{SQLContext} that was used to create this SchemaRDD.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.registerAsTable("test")
+        >>> srdd2 = sqlCtx.sql("select * from test")
+        >>> srdd.collect() == srdd2.collect()
+        True
+        """
+        self._jschema_rdd.registerAsTable(name)
+
+    def _toPython(self):
+        # We have to import the Row class explicitly, so that the reference the Pickler holds is
+        # pyspark.sql.Row instead of __main__.Row
+        from pyspark.sql import Row
+        jrdd = self._jschema_rdd.javaToPython()
+        # TODO: This is inefficient, we should construct the Python Row object
+        # in Java land in the javaToPython function. May require a custom
+        # pickle serializer in Pyrolite
+        return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d))
+
+    # We override the default cache/persist/checkpoint behavior as we want to cache the underlying
+    # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class
+    def cache(self):
+        self.is_cached = True
+        self._jschema_rdd.cache()
+        return self
+
+    def persist(self, storageLevel):
+        self.is_cached = True
+        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
+        self._jschema_rdd.persist(javaStorageLevel)
+        return self
+
+    def unpersist(self):
+        self.is_cached = False
+        self._jschema_rdd.unpersist()
+        return self
+
+    def checkpoint(self):
+        self.is_checkpointed = True
+        self._jschema_rdd.checkpoint()
+
+    def isCheckpointed(self):
+        return self._jschema_rdd.isCheckpointed()
+
+    def getCheckpointFile(self):
+        checkpointFile = self._jschema_rdd.getCheckpointFile()
+        if checkpointFile.isDefined():
+            return checkpointFile.get()
+        else:
+            return None
+
+def _test():
+    import doctest
+    from pyspark.context import SparkContext
+    globs = globals().copy()
+    # The small batch size here ensures that we see multiple batches,
+    # even in these small test examples:
+    sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    globs['sc'] = sc
+    globs['sqlCtx'] = SQLContext(sc)
+    globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
+        {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
+
-- cgit v1.2.3
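Closing usage note: because SchemaRDD subclasses the regular RDD, the result of a relational query can be post-processed with ordinary PySpark transformations. A minimal end-to-end sketch against the API added by this patch, assuming an existing SparkContext `sc`; the `people` data and column names are illustrative only:

```
from pyspark.sql import SQLContext

sqlCtx = SQLContext(sc)
# Any RDD of flat dictionaries (string keys, primitive values) can be given a schema.
people = sc.parallelize([{"name": "alice", "age": 30},
                         {"name": "bob", "age": 17}])
srdd = sqlCtx.inferSchema(people)
sqlCtx.registerRDDAsTable(srdd, "people")

# The query result is a SchemaRDD of Row objects, so normal RDD
# operations (map, count, ...) compose with the SQL above.
adults = sqlCtx.sql("SELECT name, age FROM people WHERE age >= 18")
adults.map(lambda row: row.name).collect()
```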