aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2014-09-06 14:49:25 -0700
committerJosh Rosen <joshrosen@apache.org>2014-09-06 14:49:25 -0700
commitda35330e830a85008c0bf9f0725418e4dfe7ac66 (patch)
tree4b9e0a37919e9e0daed35a2ade99a5bc3159437b
parentbaff7e936101635d9bd4245e45335878bafb75e0 (diff)
downloadspark-da35330e830a85008c0bf9f0725418e4dfe7ac66.tar.gz
spark-da35330e830a85008c0bf9f0725418e4dfe7ac66.tar.bz2
spark-da35330e830a85008c0bf9f0725418e4dfe7ac66.zip
SPARK-3406: Add a default storage level to the Python RDD persist API
Author: Holden Karau <holden@pigscanfly.ca> Closes #2280 from holdenk/SPARK-3406-Python-RDD-persist-api-does-not-have-default-storage-level and squashes the following commits: 33eaade [Holden Karau] As Josh pointed out, sql also overrides persist; make persist behave the same as in the underlying RDD as well. e658227 [Holden Karau] Fix the test I added. e95a6c5 [Holden Karau] The Python persist function did not have a default storageLevel, unlike the Scala API. Noticed this issue because we got a bug report back from the book, where we had documented it as if it were the same as the Scala API.
-rw-r--r--python/pyspark/rdd.py7
-rw-r--r--python/pyspark/sql.py3
2 files changed, 8 insertions, 2 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 04f13523b4..aa90297855 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -212,11 +212,16 @@ class RDD(object):
self.persist(StorageLevel.MEMORY_ONLY_SER)
return self
- def persist(self, storageLevel):
+ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
"""
Set this RDD's storage level to persist its values across operations
after the first time it is computed. This can only be used to assign
a new storage level if the RDD does not have a storage level set yet.
+ If no storage level is specified defaults to (C{MEMORY_ONLY_SER}).
+
+ >>> rdd = sc.parallelize(["b", "a", "c"])
+ >>> rdd.persist().is_cached
+ True
"""
self.is_cached = True
javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index e7f573cf6d..97a51b9f8a 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -29,6 +29,7 @@ from operator import itemgetter
from pyspark.rdd import RDD, PipelinedRDD
from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
+from pyspark.storagelevel import StorageLevel
from itertools import chain, ifilter, imap
@@ -1665,7 +1666,7 @@ class SchemaRDD(RDD):
self._jschema_rdd.cache()
return self
- def persist(self, storageLevel):
+ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
self.is_cached = True
javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
self._jschema_rdd.persist(javaStorageLevel)