From a63d4c7dc2970900b116f7287e3d6b302d9d5698 Mon Sep 17 00:00:00 2001
From: Aaron Davidson
Date: Thu, 5 Sep 2013 23:36:27 -0700
Subject: SPARK-660: Add StorageLevel support in Python

It uses reflection... I am not proud of that fact, but it at least ensures
compatibility (sans refactoring of the StorageLevel stuff).
---
 .../scala/org/apache/spark/api/python/PythonRDD.scala | 11 +++++++++++
 python/pyspark/context.py                             | 14 ++++++++++++++
 python/pyspark/rdd.py                                 | 18 ++++++++++++++++++
 python/pyspark/shell.py                               |  3 ++-
 4 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ccd3833964..6ca56b3af6 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -28,6 +28,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.PipedRDD
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
 
@@ -270,6 +271,16 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
+  /**
+   * Returns the StorageLevel with the given string name.
+   * Throws an exception if the name is not a valid StorageLevel.
+   */
+  def getStorageLevel(name: String) : StorageLevel = {
+    // In Scala, "val MEMORY_ONLY" produces a public getter by the same name.
+    val storageLevelGetter = StorageLevel.getClass().getDeclaredMethod(name)
+    return storageLevelGetter.invoke(StorageLevel).asInstanceOf[StorageLevel]
+  }
+
   def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
     import scala.collection.JavaConverters._
     writeIteratorToPickleFile(items.asScala, filename)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 8fbf296509..49f9b4610d 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -279,6 +279,20 @@ class SparkContext(object):
         """
         self._jsc.sc().setCheckpointDir(dirName, useExisting)
 
+class StorageLevelReader:
+    """
+    Mimics the Scala StorageLevel by directing all attribute requests
+    (e.g., StorageLevel.DISK_ONLY) to the JVM for reflection.
+    """
+
+    def __init__(self, sc):
+        self.sc = sc
+
+    def __getattr__(self, name):
+        try:
+            return self.sc._jvm.PythonRDD.getStorageLevel(name)
+        except:
+            print "Failed to find StorageLevel:", name
 
 def _test():
     import atexit
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 914118ccdd..332258f5d1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -70,6 +70,24 @@ class RDD(object):
         self._jrdd.cache()
         return self
 
+    def persist(self, storageLevel):
+        """
+        Set this RDD's storage level to persist its values across operations after the first time
+        it is computed. This can only be used to assign a new storage level if the RDD does not
+        have a storage level set yet.
+        """
+        self.is_cached = True
+        self._jrdd.persist(storageLevel)
+        return self
+
+    def unpersist(self):
+        """
+        Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+        """
+        self.is_cached = False
+        self._jrdd.unpersist()
+        return self
+
     def checkpoint(self):
         """
         Mark this RDD for checkpointing. It will be saved to a file inside the
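For illustration, the StorageLevelReader added to python/pyspark/context.py above works by
forwarding attribute access (e.g. StorageLevel.DISK_ONLY) through __getattr__ to the JVM's
reflective lookup. A minimal, JVM-free Python sketch of that pattern follows; every name in
it is illustrative and not part of this patch:

    # Stand-in for the JVM side of the lookup (PythonRDD.getStorageLevel).
    class FakeStorageLevels(object):
        _levels = {
            "MEMORY_ONLY": "StorageLevel(False, True, True, 1)",
            "DISK_ONLY": "StorageLevel(True, False, True, 1)",
        }

        def getStorageLevel(self, name):
            # Unknown names raise KeyError here, much as the reflective lookup
            # on the Scala side fails for names that are not valid levels.
            return self._levels[name]

    class StorageLevelReaderSketch(object):
        """Forwards attribute access to a by-name lookup, like StorageLevelReader."""

        def __init__(self, levels):
            self._levels = levels

        def __getattr__(self, name):
            return self._levels.getStorageLevel(name)

    reader = StorageLevelReaderSketch(FakeStorageLevels())
    print(reader.DISK_ONLY)  # -> StorageLevel(True, False, True, 1)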
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 54823f8037..9acc176d55 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -23,12 +23,13 @@ This file is designed to be launched as a PYTHONSTARTUP script.
 import os
 import platform
 import pyspark
-from pyspark.context import SparkContext
+from pyspark.context import SparkContext, StorageLevelReader
 
 # this is the equivalent of ADD_JARS
 add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None
 
 sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell", pyFiles=add_files)
+StorageLevel = StorageLevelReader(sc)
 
 print """Welcome to
       ____              __
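With the patch applied, persisting an RDD from Python would look roughly like the sketch
below (written against the 2013-era PySpark API; the app name and the MEMORY_AND_DISK
level are just examples):

    from pyspark.context import SparkContext, StorageLevelReader

    sc = SparkContext("local", "StorageLevelExample")
    StorageLevel = StorageLevelReader(sc)      # what shell.py now does at startup

    rdd = sc.parallelize(range(1000))
    rdd.persist(StorageLevel.MEMORY_AND_DISK)  # JVM StorageLevel resolved by name
    print(rdd.count())                         # first action computes and persists the blocks
    print(rdd.count())                         # later actions reuse the persisted blocks
    rdd.unpersist()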