From c1cc8c4da239965e8ad478089b27e9c694088978 Mon Sep 17 00:00:00 2001
From: Aaron Davidson
Date: Sat, 7 Sep 2013 14:41:31 -0700
Subject: Export StorageLevel and refactor

---
 python/pyspark/context.py | 35 ++++++++++++-----------------------
 1 file changed, 12 insertions(+), 23 deletions(-)

(limited to 'python/pyspark/context.py')

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 4c48cd3f37..efd7828df6 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -27,6 +27,7 @@ from pyspark.broadcast import Broadcast
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 
 from py4j.java_collections import ListConverter
@@ -119,29 +120,6 @@ class SparkContext(object):
         self._temp_dir = \
             self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
 
-        self._initStorageLevel()
-
-    def _initStorageLevel(self):
-        """
-        Initializes the StorageLevel object, which mimics the behavior of the scala object
-        by the same name. e.g., StorageLevel.DISK_ONLY returns the equivalent Java StorageLevel.
-        """
-        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
-        levels = {
-            'NONE': newStorageLevel(False, False, False, 1),
-            'DISK_ONLY': newStorageLevel(True, False, False, 1),
-            'DISK_ONLY_2': newStorageLevel(True, False, False, 2),
-            'MEMORY_ONLY': newStorageLevel(False, True, True, 1),
-            'MEMORY_ONLY_2': newStorageLevel(False, True, True, 2),
-            'MEMORY_ONLY_SER': newStorageLevel(False, True, False, 1),
-            'MEMORY_ONLY_SER_2': newStorageLevel(False, True, False, 2),
-            'MEMORY_AND_DISK': newStorageLevel(True, True, True, 1),
-            'MEMORY_AND_DISK_2': newStorageLevel(True, True, True, 2),
-            'MEMORY_AND_DISK_SER': newStorageLevel(True, True, False, 1),
-            'MEMORY_AND_DISK_SER_2': newStorageLevel(True, True, False, 2),
-        }
-        self.StorageLevel = type('StorageLevel', (), levels)
-
     @property
     def defaultParallelism(self):
         """
@@ -303,6 +281,17 @@ class SparkContext(object):
         """
         self._jsc.sc().setCheckpointDir(dirName, useExisting)
 
+    def _getJavaStorageLevel(self, storageLevel):
+        """
+        Returns a Java StorageLevel based on a pyspark.StorageLevel.
+        """
+        if not isinstance(storageLevel, StorageLevel):
+            raise Exception("storageLevel must be of type pyspark.StorageLevel")
+
+        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
+        return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
+            storageLevel.deserialized, storageLevel.replication)
+
 def _test():
     import atexit
     import doctest
-- 
cgit v1.2.3