From c1cc8c4da239965e8ad478089b27e9c694088978 Mon Sep 17 00:00:00 2001
From: Aaron Davidson
Date: Sat, 7 Sep 2013 14:41:31 -0700
Subject: Export StorageLevel and refactor

---
 python/pyspark/__init__.py     |  5 ++++-
 python/pyspark/context.py      | 35 ++++++++++++----------------------
 python/pyspark/rdd.py          |  3 ++-
 python/pyspark/shell.py        |  2 +-
 python/pyspark/storagelevel.py | 43 ++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 62 insertions(+), 26 deletions(-)
 create mode 100644 python/pyspark/storagelevel.py

diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index fd5972d381..1f35f6f939 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -30,6 +30,8 @@ Public classes:
       An "add-only" shared variable that tasks can only add values to.
     - L{SparkFiles}
       Access files shipped with jobs.
+    - L{StorageLevel}
+      Finer-grained cache persistence levels.
 """
 import sys
 import os
@@ -39,6 +41,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
+from pyspark.storagelevel import StorageLevel
 
-__all__ = ["SparkContext", "RDD", "SparkFiles"]
+__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel"]

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 4c48cd3f37..efd7828df6 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -27,6 +27,7 @@ from pyspark.broadcast import Broadcast
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 
 from py4j.java_collections import ListConverter
@@ -119,29 +120,6 @@ class SparkContext(object):
         self._temp_dir = \
             self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
 
-        self._initStorageLevel()
-
-    def _initStorageLevel(self):
-        """
-        Initializes the StorageLevel object, which mimics the behavior of the scala object
-        by the same name. e.g., StorageLevel.DISK_ONLY returns the equivalent Java StorageLevel.
-        """
-        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
-        levels = {
-            'NONE': newStorageLevel(False, False, False, 1),
-            'DISK_ONLY': newStorageLevel(True, False, False, 1),
-            'DISK_ONLY_2': newStorageLevel(True, False, False, 2),
-            'MEMORY_ONLY': newStorageLevel(False, True, True, 1),
-            'MEMORY_ONLY_2': newStorageLevel(False, True, True, 2),
-            'MEMORY_ONLY_SER': newStorageLevel(False, True, False, 1),
-            'MEMORY_ONLY_SER_2': newStorageLevel(False, True, False, 2),
-            'MEMORY_AND_DISK': newStorageLevel(True, True, True, 1),
-            'MEMORY_AND_DISK_2': newStorageLevel(True, True, True, 2),
-            'MEMORY_AND_DISK_SER': newStorageLevel(True, True, False, 1),
-            'MEMORY_AND_DISK_SER_2': newStorageLevel(True, True, False, 2),
-        }
-        self.StorageLevel = type('StorageLevel', (), levels)
-
     @property
     def defaultParallelism(self):
         """
@@ -303,6 +281,17 @@ class SparkContext(object):
         """
         self._jsc.sc().setCheckpointDir(dirName, useExisting)
 
+    def _getJavaStorageLevel(self, storageLevel):
+        """
+        Returns a Java StorageLevel based on a pyspark.StorageLevel.
+        """
+        if not isinstance(storageLevel, StorageLevel):
+            raise Exception("storageLevel must be of type pyspark.StorageLevel")
+
+        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
+        return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
+                               storageLevel.deserialized, storageLevel.replication)
+
 def _test():
     import atexit
     import doctest

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 332258f5d1..58e1849cad 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -77,7 +77,8 @@ class RDD(object):
         have a storage level set yet.
         """
         self.is_cached = True
-        self._jrdd.persist(storageLevel)
+        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
+        self._jrdd.persist(javaStorageLevel)
         return self
 
     def unpersist(self):

diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index e374ca4ee4..dc205b306f 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -24,12 +24,12 @@ import os
 import platform
 import pyspark
 from pyspark.context import SparkContext
+from pyspark.storagelevel import StorageLevel
 
 # this is the equivalent of ADD_JARS
 add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None
 
 sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell", pyFiles=add_files)
-StorageLevel = sc.StorageLevel # alias StorageLevel to global scope
 
 print """Welcome to
       ____              __

diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py
new file mode 100644
index 0000000000..b31f4762e6
--- /dev/null
+++ b/python/pyspark/storagelevel.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+__all__ = ["StorageLevel"]
+
+class StorageLevel:
+    """
+    Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
+    whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
+    in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
+    Also contains static constants for some commonly used storage levels, such as MEMORY_ONLY.
+    """
+
+    def __init__(self, useDisk, useMemory, deserialized, replication = 1):
+        self.useDisk = useDisk
+        self.useMemory = useMemory
+        self.deserialized = deserialized
+        self.replication = replication
+
+StorageLevel.DISK_ONLY = StorageLevel(True, False, False)
+StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, 2)
+StorageLevel.MEMORY_ONLY = StorageLevel(False, True, True)
+StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, True, 2)
+StorageLevel.MEMORY_ONLY_SER = StorageLevel(False, True, False)
+StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel(False, True, False, 2)
+StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, True)
+StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, True, 2)
+StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False)
+StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, 2)