#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.rdd import RDD

from py4j.protocol import Py4JError

__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]


class SQLContext:
    """
    Main entry point for SparkSQL functionality. A SQLContext can be used to create L{SchemaRDD}s,
    register L{SchemaRDD}s as tables, execute SQL over tables, cache tables, and read Parquet files.
    """

    def __init__(self, sparkContext):
        """
        Create a new SQLContext.

        @param sparkContext: The SparkContext to wrap.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> bad_rdd = sc.parallelize([1,2,3])
        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
        ... "boolean" : True}])
        >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
        ... x.boolean))
        >>> srdd.collect()[0]
        (1, u'string', 1.0, 1, True)
        """
        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
        self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap

    @property
    def _ssql_ctx(self):
        """
        Accessor for the JVM SparkSQL context.  Subclasses can override this property to provide
        their own JVM Contexts.
        """
        if not hasattr(self, '_scala_SQLContext'):
            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
        return self._scala_SQLContext

    def inferSchema(self, rdd):
        """
        Infer and apply a schema to an RDD of L{dict}s. We peek at the first row of the RDD to
        determine the field names and types, and then use that to extract all the dictionaries.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
        ...                    {"field1" : 3, "field2": "row3"}]
        True
        """
        if (rdd.__class__ is SchemaRDD):
            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
        elif not isinstance(rdd.first(), dict):
            raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
                             (SchemaRDD.__name__, rdd.first()))

        jrdd = self._pythonToJavaMap(rdd._jrdd)
        srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
        return SchemaRDD(srdd, self)

    def registerRDDAsTable(self, rdd, tableName):
        """
        Registers the given RDD as a temporary table in the catalog.  Temporary tables exist only
        during the lifetime of this instance of SQLContext.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        """
        if (rdd.__class__ is SchemaRDD):
            jschema_rdd = rdd._jschema_rdd
            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
        else:
            raise ValueError("Can only register SchemaRDD as table")

    def parquetFile(self, path):
        """
        Loads a Parquet file, returning the result as a L{SchemaRDD}.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> srdd.collect() == srdd2.collect()
        True
        """
        jschema_rdd = self._ssql_ctx.parquetFile(path)
        return SchemaRDD(jschema_rdd, self)

    def sql(self, sqlQuery):
        """
        Executes a SQL query using Spark, returning the result as a L{SchemaRDD}.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
        ...                     {"f1" : 3, "f2": "row3"}]
        True
        """
        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)

    def table(self, tableName):
        """
        Returns the specified table as a L{SchemaRDD}.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.table("table1")
        >>> srdd.collect() == srdd2.collect()
        True
        """
        return SchemaRDD(self._ssql_ctx.table(tableName), self)

    def cacheTable(self, tableName):
        """
        Caches the specified table in-memory.
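
        A minimal usage sketch (illustrative only, so it is skipped by doctest; assumes a table
        named "table1" has already been registered via L{registerRDDAsTable}):

        >>> sqlCtx.cacheTable("table1")  # doctest: +SKIP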
        """
        self._ssql_ctx.cacheTable(tableName)

    def uncacheTable(self, tableName):
        """
        Removes the specified table from the in-memory cache.
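
        A minimal usage sketch (illustrative only, so it is skipped by doctest; assumes "table1"
        was previously cached):

        >>> sqlCtx.uncacheTable("table1")  # doctest: +SKIP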
        """
        self._ssql_ctx.uncacheTable(tableName)


class HiveContext(SQLContext):
    """
    An instance of the Spark SQL execution engine that integrates with data stored in Hive.
    Configuration for Hive is read from hive-site.xml on the classpath. It supports running both SQL
    and HiveQL commands.
    """

    @property
    def _ssql_ctx(self):
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
                            "sbt/sbt assembly", e)

    def _get_hive_ctx(self):
        return self._jvm.HiveContext(self._jsc.sc())

    def hiveql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
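
        A minimal usage sketch (illustrative only, so it is skipped by doctest; assumes an
        existing L{HiveContext} named hiveCtx and a Hive table named src, as in the
        L{LocalHiveContext} example):

        >>> srdd = hiveCtx.hiveql("SELECT key, value FROM src")  # doctest: +SKIP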
        """
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)

    def hql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return self.hiveql(hqlQuery)


class LocalHiveContext(HiveContext):
    """
    Starts up an instance of Hive where metadata is stored locally. An in-process metadata store is
    created with data stored in ./metadata.  Warehouse data is stored in ./warehouse.

    >>> import os
    >>> hiveCtx = LocalHiveContext(sc)
    >>> try:
    ...     suppress = hiveCtx.hql("DROP TABLE src")
    ... except Exception:
    ...     pass
    >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
    >>> suppress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    >>> suppress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
    >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
    >>> num = results.count()
    >>> reduce_sum = results.reduce(lambda x, y: x + y)
    >>> num
    500
    >>> reduce_sum
    130091
    """

    def _get_hive_ctx(self):
        return self._jvm.LocalHiveContext(self._jsc.sc())


class TestHiveContext(HiveContext):
    """
    An in-process Hive context intended for testing, backed by the JVM TestHiveContext.
    """

    def _get_hive_ctx(self):
        return self._jvm.TestHiveContext(self._jsc.sc())


# TODO: Investigate if it is more efficient to use a namedtuple. One problem is that named tuples
# are custom classes that must be generated per Schema.
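# A hedged sketch of that alternative (not used here): a namedtuple class would be generated
# once per schema and then instantiated for each row, e.g.
#
#     from collections import namedtuple
#     PersonRow = namedtuple("PersonRow", ["field1", "field2"])  # hypothetical schema
#     PersonRow(1, "row1").field1  # -> 1
#
# The cost is creating (and keeping track of) one such class per distinct schema.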
class Row(dict):
    """
    An extended L{dict} that takes a L{dict} in its constructor, and exposes those items as fields.

    >>> r = Row({"hello" : "world", "foo" : "bar"})
    >>> r.hello
    'world'
    >>> r.foo
    'bar'
    """

    def __init__(self, d):
        d.update(self.__dict__)
        self.__dict__ = d
        dict.__init__(self, d)


class SchemaRDD(RDD):
    """
    An RDD of L{Row} objects that has an associated schema. The underlying JVM object is a SchemaRDD,
    not a PythonRDD, so we can utilize the relational query API exposed by SparkSQL.

    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the L{SchemaRDD} is not operated on
    directly, as its underlying implementation is an RDD composed of Java objects. Instead it is
    converted to a PythonRDD in the JVM, on which Python operations can be done.
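
    A minimal usage sketch (illustrative only, so it is skipped by doctest):

    >>> srdd = sqlCtx.inferSchema(rdd)                        # doctest: +SKIP
    >>> srdd.map(lambda row: row.field1).collect()            # doctest: +SKIP
    [1, 2, 3]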
    """

    def __init__(self, jschema_rdd, sql_ctx):
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx._sc
        self._jschema_rdd = jschema_rdd

        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = self.sql_ctx._sc
        self._jrdd_deserializer = self.ctx.serializer

    @property
    def _jrdd(self):
        """
        Lazy evaluation of the PythonRDD object. Only done when a user calls methods defined by the
        L{pyspark.rdd.RDD} super class (map, count, etc.).
        """
        if not hasattr(self, '_lazy_jrdd'):
            self._lazy_jrdd = self._toPython()._jrdd
        return self._lazy_jrdd

    @property
    def _id(self):
        return self._jrdd.id()

    def saveAsParquetFile(self, path):
        """
        Saves the contents of this L{SchemaRDD} as a Parquet file, preserving the schema.  Files
        that are written out using this method can be read back in as a SchemaRDD using the
        L{SQLContext.parquetFile} method.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> srdd2.collect() == srdd.collect()
        True
        """
        self._jschema_rdd.saveAsParquetFile(path)

    def registerAsTable(self, name):
        """
        Registers this RDD as a temporary table using the given name.  The lifetime of this temporary
        table is tied to the L{SQLContext} that was used to create this SchemaRDD.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.registerAsTable("test")
        >>> srdd2 = sqlCtx.sql("select * from test")
        >>> srdd.collect() == srdd2.collect()
        True
        """
        self._jschema_rdd.registerAsTable(name)

    def insertInto(self, tableName, overwrite=False):
        """
        Inserts the contents of this SchemaRDD into the specified table,
        optionally overwriting any existing data.
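
        A minimal usage sketch (illustrative only, so it is skipped by doctest; the table name
        "people" is hypothetical and must already exist with a compatible schema):

        >>> srdd.insertInto("people", overwrite=True)  # doctest: +SKIP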
        """
        self._jschema_rdd.insertInto(tableName, overwrite)

    def saveAsTable(self, tableName):
        """
        Creates a new table with the contents of this SchemaRDD.
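
        A minimal usage sketch (illustrative only, so it is skipped by doctest; the table name
        "people_copy" is hypothetical):

        >>> srdd.saveAsTable("people_copy")  # doctest: +SKIP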
        """
        self._jschema_rdd.saveAsTable(tableName)

    def _toPython(self):
        # We have to import the Row class explicitly, so that the reference the Pickler holds is
        # pyspark.sql.Row instead of __main__.Row
        from pyspark.sql import Row
        jrdd = self._jschema_rdd.javaToPython()
        # TODO: This is inefficient, we should construct the Python Row object
        # in Java land in the javaToPython function. May require a custom
        # pickle serializer in Pyrolite
        return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d))

    # We override the default cache/persist/checkpoint behavior as we want to cache the underlying
    # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class
    def cache(self):
        self.is_cached = True
        self._jschema_rdd.cache()
        return self

    def persist(self, storageLevel):
        self.is_cached = True
        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
        self._jschema_rdd.persist(javaStorageLevel)
        return self

    def unpersist(self):
        self.is_cached = False
        self._jschema_rdd.unpersist()
        return self

    def checkpoint(self):
        self.is_checkpointed = True
        self._jschema_rdd.checkpoint()

    def isCheckpointed(self):
        return self._jschema_rdd.isCheckpointed()

    def getCheckpointFile(self):
        checkpointFile = self._jschema_rdd.getCheckpointFile()
        if checkpointFile.isDefined():
            return checkpointFile.get()
        else:
            return None

def _test():
    import doctest
    from pyspark.context import SparkContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['sc'] = sc
    globs['sqlCtx'] = SQLContext(sc)
    globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
        {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()