path: root/python/pyspark/sql.py
author:    Ahir Reddy <ahirreddy@gmail.com>  2014-04-15 00:07:55 -0700
committer: Patrick Wendell <pwendell@gmail.com>  2014-04-15 00:07:55 -0700
commit:    c99bcb7feaa761c5826f2e1d844d0502a3b79538 (patch)
tree:      cb136b0fbeaac6268eea2782f5ac9d615aafdb5b /python/pyspark/sql.py
parent:    0247b5c5467ca1b0d03ba929a78fa4d805582d84 (diff)
SPARK-1374: PySpark API for SparkSQL
An initial API that exposes SparkSQL functionality in PySpark. A PythonRDD composed of dictionaries, with string keys and primitive values (boolean, float, int, long, string) can be converted into a SchemaRDD that supports sql queries.

```
from pyspark.context import SQLContext
sqlCtx = SQLContext(sc)
rdd = sc.parallelize([{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
srdd = sqlCtx.applySchema(rdd)
sqlCtx.registerRDDAsTable(srdd, "table1")
srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
srdd2.collect()
```

The last line yields ```[{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"}, {"f1" : 3, "f2": "row3"}]```

Author: Ahir Reddy <ahirreddy@gmail.com>
Author: Michael Armbrust <michael@databricks.com>

Closes #363 from ahirreddy/pysql and squashes the following commits:

0294497 [Ahir Reddy] Updated log4j properties to supress Hive Warns
307d6e0 [Ahir Reddy] Style fix
6f7b8f6 [Ahir Reddy] Temporary fix MIMA checker. Since we now assemble Spark jar with Hive, we don't want to check the interfaces of all of our hive dependencies
3ef074a [Ahir Reddy] Updated documentation because classes moved to sql.py
29245bf [Ahir Reddy] Cache underlying SchemaRDD instead of generating and caching PythonRDD
f2312c7 [Ahir Reddy] Moved everything into sql.py
a19afe4 [Ahir Reddy] Doc fixes
6d658ba [Ahir Reddy] Remove the metastore directory created by the HiveContext tests in SparkSQL
521ff6d [Ahir Reddy] Trying to get spark to build with hive
ab95eba [Ahir Reddy] Set SPARK_HIVE=true on jenkins
ded03e7 [Ahir Reddy] Added doc test for HiveContext
22de1d4 [Ahir Reddy] Fixed maven pyrolite dependency
e4da06c [Ahir Reddy] Display message if hive is not built into spark
227a0be [Michael Armbrust] Update API links. Fix Hive example.
58e2aa9 [Michael Armbrust] Build Docs for pyspark SQL Api. Minor fixes.
4285340 [Michael Armbrust] Fix building of Hive API Docs.
38a92b0 [Michael Armbrust] Add note to future non-python developers about python docs.
337b201 [Ahir Reddy] Changed com.clearspring.analytics stream version from 2.4.0 to 2.5.1 to match SBT build, and added pyrolite to maven build
40491c9 [Ahir Reddy] PR Changes + Method Visibility
1836944 [Michael Armbrust] Fix comments.
e00980f [Michael Armbrust] First draft of python sql programming guide.
b0192d3 [Ahir Reddy] Added Long, Double and Boolean as usable types + unit test
f98a422 [Ahir Reddy] HiveContexts
79621cf [Ahir Reddy] cleaning up cruft
b406ba0 [Ahir Reddy] doctest formatting
20936a5 [Ahir Reddy] Added tests and documentation
e4d21b4 [Ahir Reddy] Added pyrolite dependency
79f739d [Ahir Reddy] added more tests
7515ba0 [Ahir Reddy] added more tests :)
d26ec5e [Ahir Reddy] added test
e9f5b8d [Ahir Reddy] adding tests
906d180 [Ahir Reddy] added todo explaining cost of creating Row object in python
251f99d [Ahir Reddy] for now only allow dictionaries as input
09b9980 [Ahir Reddy] made jrdd explicitly lazy
c608947 [Ahir Reddy] SchemaRDD now has all RDD operations
725c91e [Ahir Reddy] awesome row objects
55d1c76 [Ahir Reddy] return row objects
4fe1319 [Ahir Reddy] output dictionaries correctly
be079de [Ahir Reddy] returning dictionaries works
cd5f79f [Ahir Reddy] Switched to using Scala SQLContext
e948bd9 [Ahir Reddy] yippie
4886052 [Ahir Reddy] even better
c0fb1c6 [Ahir Reddy] more working
043ca85 [Ahir Reddy] working
5496f9f [Ahir Reddy] doesn't crash
b8b904b [Ahir Reddy] Added schema rdd class
67ba875 [Ahir Reddy] java to python, and python to java
bcc0f23 [Ahir Reddy] Java to python
ab6025d [Ahir Reddy] compiling
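The patch also adds HiveContext, LocalHiveContext, and TestHiveContext wrappers for querying data stored in Hive. As a minimal sketch (not part of the patch itself), based on the LocalHiveContext doctest below and assuming a Hive-enabled assembly, an existing SparkContext `sc`, and the bundled examples/src/main/resources/kv1.txt sample file, the HiveQL entry point could be exercised like this:

```
from pyspark.sql import LocalHiveContext

hiveCtx = LocalHiveContext(sc)  # keeps its metastore and warehouse data in the working directory
hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
hiveCtx.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
rows = hiveCtx.hql("SELECT key, value FROM src").collect()  # a list of Row objects
```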
Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r--  python/pyspark/sql.py  363
1 file changed, 363 insertions, 0 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
new file mode 100644
index 0000000000..67e6eee3f4
--- /dev/null
+++ b/python/pyspark/sql.py
@@ -0,0 +1,363 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.rdd import RDD
+
+from py4j.protocol import Py4JError
+
+__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]
+
+
+class SQLContext:
+ """
+ Main entry point for SparkSQL functionality. A SQLContext can be used to create L{SchemaRDD}s,
+ register L{SchemaRDD}s as tables, execute SQL over tables, cache tables, and read Parquet files.
+ """
+
+ def __init__(self, sparkContext):
+ """
+ Create a new SQLContext.
+
+ @param sparkContext: The SparkContext to wrap.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+
+ >>> bad_rdd = sc.parallelize([1,2,3])
+ >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+
+ >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
+ ... "boolean" : True}])
+ >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
+ ... x.boolean))
+ >>> srdd.collect()[0]
+ (1, u'string', 1.0, 1, True)
+ """
+ self._sc = sparkContext
+ self._jsc = self._sc._jsc
+ self._jvm = self._sc._jvm
+ self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap
+
+ @property
+ def _ssql_ctx(self):
+ """
+ Accessor for the JVM SparkSQL context. Subclasses can override this property to provide
+ their own JVM contexts.
+ """
+ if not hasattr(self, '_scala_SQLContext'):
+ self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
+ return self._scala_SQLContext
+
+ def inferSchema(self, rdd):
+ """
+ Infer and apply a schema to an RDD of L{dict}s. We peek at the first row of the RDD to
+ determine the field names and types, and then use that to extract all the dictionaries.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
+ ... {"field1" : 3, "field2": "row3"}]
+ True
+ """
+ if (rdd.__class__ is SchemaRDD):
+ raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
+ elif not isinstance(rdd.first(), dict):
+ raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
+ (SchemaRDD.__name__, rdd.first()))
+
+ jrdd = self._pythonToJavaMap(rdd._jrdd)
+ srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
+ return SchemaRDD(srdd, self)
+
+ def registerRDDAsTable(self, rdd, tableName):
+ """
+ Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
+ during the lifetime of this instance of SQLContext.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ """
+ if (rdd.__class__ is SchemaRDD):
+ jschema_rdd = rdd._jschema_rdd
+ self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
+ else:
+ raise ValueError("Can only register SchemaRDD as table")
+
+ def parquetFile(self, path):
+ """
+ Loads a Parquet file, returning the result as a L{SchemaRDD}.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.saveAsParquetFile("/tmp/tmp.parquet")
+ >>> srdd2 = sqlCtx.parquetFile("/tmp/tmp.parquet")
+ >>> srdd.collect() == srdd2.collect()
+ True
+ """
+ jschema_rdd = self._ssql_ctx.parquetFile(path)
+ return SchemaRDD(jschema_rdd, self)
+
+ def sql(self, sqlQuery):
+ """
+ Executes a SQL query using Spark, returning the result as a L{SchemaRDD}.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+ >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
+ ... {"f1" : 3, "f2": "row3"}]
+ True
+ """
+ return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+
+ def table(self, tableName):
+ """
+ Returns the specified table as a L{SchemaRDD}.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> srdd2 = sqlCtx.table("table1")
+ >>> srdd.collect() == srdd2.collect()
+ True
+ """
+ return SchemaRDD(self._ssql_ctx.table(tableName), self)
+
+ def cacheTable(self, tableName):
+ """
+ Caches the specified table in-memory.
+ """
+ self._ssql_ctx.cacheTable(tableName)
+
+ def uncacheTable(self, tableName):
+ """
+ Removes the specified table from the in-memory cache.
+ """
+ self._ssql_ctx.uncacheTable(tableName)
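+ # Illustrative usage of the two caching helpers above (not an executed doctest);
+ # it assumes "table1" has already been registered via registerRDDAsTable:
+ #   sqlCtx.cacheTable("table1")
+ #   sqlCtx.uncacheTable("table1")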
+
+
+class HiveContext(SQLContext):
+ """
+ An instance of the Spark SQL execution engine that integrates with data stored in Hive.
+ Configuration for Hive is read from hive-site.xml on the classpath. It supports running both SQL
+ and HiveQL commands.
+ """
+
+ @property
+ def _ssql_ctx(self):
+ try:
+ if not hasattr(self, '_scala_HiveContext'):
+ self._scala_HiveContext = self._get_hive_ctx()
+ return self._scala_HiveContext
+ except Py4JError as e:
+ raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " \
+ "sbt/sbt assembly" , e)
+
+ def _get_hive_ctx(self):
+ return self._jvm.HiveContext(self._jsc.sc())
+
+ def hiveql(self, hqlQuery):
+ """
+ Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
+ """
+ return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)
+
+ def hql(self, hqlQuery):
+ """
+ Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
+ """
+ return self.hiveql(hqlQuery)
+
+
+class LocalHiveContext(HiveContext):
+ """
+ Starts up an instance of Hive where metadata is stored locally. An in-process metadata database is
+ created with data stored in ./metadata. Warehouse data is stored in ./warehouse.
+
+ >>> import os
+ >>> hiveCtx = LocalHiveContext(sc)
+ >>> try:
+ ... suppress = hiveCtx.hql("DROP TABLE src")
+ ... except Exception:
+ ... pass
+ >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
+ >>> supress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+ >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
+ >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
+ >>> num = results.count()
+ >>> reduce_sum = results.reduce(lambda x, y: x + y)
+ >>> num
+ 500
+ >>> reduce_sum
+ 130091
+ """
+
+ def _get_hive_ctx(self):
+ return self._jvm.LocalHiveContext(self._jsc.sc())
+
+
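+# A thin wrapper around SparkSQL's Scala TestHiveContext; like LocalHiveContext
+# above, it is intended mainly for tests and experimentation.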
+class TestHiveContext(HiveContext):
+
+ def _get_hive_ctx(self):
+ return self._jvm.TestHiveContext(self._jsc.sc())
+
+
+# TODO: Investigate if it is more efficient to use a namedtuple. One problem is that named tuples
+# are custom classes that must be generated per Schema.
+class Row(dict):
+ """
+ An extended L{dict} that takes a L{dict} in its constructor, and exposes those items as fields.
+
+ >>> r = Row({"hello" : "world", "foo" : "bar"})
+ >>> r.hello
+ 'world'
+ >>> r.foo
+ 'bar'
+ """
+
+ def __init__(self, d):
+ d.update(self.__dict__)
+ self.__dict__ = d
+ dict.__init__(self, d)
+
+
+class SchemaRDD(RDD):
+ """
+ An RDD of L{Row} objects that has an associated schema. The underlying JVM object is a SchemaRDD,
+ not a PythonRDD, so we can utilize the relational query API exposed by SparkSQL.
+
+ For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the L{SchemaRDD} is not operated on
+ directly, as its underlying implementation is an RDD composed of Java objects. Instead it is
+ converted to a PythonRDD in the JVM, on which Python operations can be done.
+ """
+
+ def __init__(self, jschema_rdd, sql_ctx):
+ self.sql_ctx = sql_ctx
+ self._sc = sql_ctx._sc
+ self._jschema_rdd = jschema_rdd
+
+ self.is_cached = False
+ self.is_checkpointed = False
+ self.ctx = self.sql_ctx._sc
+ self._jrdd_deserializer = self.ctx.serializer
+
+ @property
+ def _jrdd(self):
+ """
+ Lazy evaluation of PythonRDD object. Only done when a user calls methods defined by the
+ L{pyspark.rdd.RDD} super class (map, count, etc.).
+ """
+ if not hasattr(self, '_lazy_jrdd'):
+ self._lazy_jrdd = self._toPython()._jrdd
+ return self._lazy_jrdd
+
+ @property
+ def _id(self):
+ return self._jrdd.id()
+
+ def saveAsParquetFile(self, path):
+ """
+ Saves the contents of this L{SchemaRDD} as a parquet file, preserving the schema. Files
+ that are written out using this method can be read back in as a SchemaRDD using the
+ L{SQLContext.parquetFile} method.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.saveAsParquetFile("/tmp/test.parquet")
+ >>> srdd2 = sqlCtx.parquetFile("/tmp/test.parquet")
+ >>> srdd2.collect() == srdd.collect()
+ True
+ """
+ self._jschema_rdd.saveAsParquetFile(path)
+
+ def registerAsTable(self, name):
+ """
+ Registers this RDD as a temporary table using the given name. The lifetime of this temporary
+ table is tied to the L{SQLContext} that was used to create this SchemaRDD.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.registerAsTable("test")
+ >>> srdd2 = sqlCtx.sql("select * from test")
+ >>> srdd.collect() == srdd2.collect()
+ True
+ """
+ self._jschema_rdd.registerAsTable(name)
+
+ def _toPython(self):
+ # We have to import the Row class explicitly, so that the reference the Pickler
+ # holds is pyspark.sql.Row instead of __main__.Row
+ from pyspark.sql import Row
+ jrdd = self._jschema_rdd.javaToPython()
+ # TODO: This is inefficient, we should construct the Python Row object
+ # in Java land in the javaToPython function. May require a custom
+ # pickle serializer in Pyrolite
+ return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d))
+
+ # We override the default cache/persist/checkpoint behavior as we want to cache the underlying
+ # SchemaRDD object in the JVM, not the PythonRDD cached/checkpointed by the super class
+ def cache(self):
+ self.is_cached = True
+ self._jschema_rdd.cache()
+ return self
+
+ def persist(self, storageLevel):
+ self.is_cached = True
+ javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
+ self._jschema_rdd.persist(javaStorageLevel)
+ return self
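+ # Illustrative (not an executed doctest): an explicit storage level is required,
+ # e.g.
+ #   from pyspark import StorageLevel
+ #   srdd.persist(StorageLevel.MEMORY_AND_DISK)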
+
+ def unpersist(self):
+ self.is_cached = False
+ self._jschema_rdd.unpersist()
+ return self
+
+ def checkpoint(self):
+ self.is_checkpointed = True
+ self._jschema_rdd.checkpoint()
+
+ def isCheckpointed(self):
+ return self._jschema_rdd.isCheckpointed()
+
+ def getCheckpointFile(self):
+ checkpointFile = self._jschema_rdd.getCheckpointFile()
+ if checkpointFile.isDefined():
+ return checkpointFile.get()
+ else:
+ return None
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ globs = globals().copy()
+ # The small batch size here ensures that we see multiple batches,
+ # even in these small test examples:
+ sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['sc'] = sc
+ globs['sqlCtx'] = SQLContext(sc)
+ globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
+ {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
+