diff options
author | Yong Tang <yong.tang.github@outlook.com> | 2016-04-05 12:19:20 +0900 |
---|---|---|
committer | Kousuke Saruta <sarutak@oss.nttdata.co.jp> | 2016-04-05 12:19:20 +0900 |
commit | 7db56244fa3dba92246bad6694f31bbf68ea47ec (patch) | |
tree | 703af71f0b5e6d4c66f2a32c7e8f4b529405fd23 | |
parent | 8f50574ab4021b9984b0017cd47ba012a894c19a (diff) | |
download | spark-7db56244fa3dba92246bad6694f31bbf68ea47ec.tar.gz spark-7db56244fa3dba92246bad6694f31bbf68ea47ec.tar.bz2 spark-7db56244fa3dba92246bad6694f31bbf68ea47ec.zip |
[SPARK-14368][PYSPARK] Support python.spark.worker.memory with upper-case unit.
## What changes were proposed in this pull request?
This fix tries to address the issue in PySpark where `spark.python.worker.memory`
could only be configured with a lower case unit (`k`, `m`, `g`, `t`). This fix
allows the upper case unit (`K`, `M`, `G`, `T`) to be used as well. This is to
conform to the JVM memory string as is specified in the documentation .
## How was this patch tested?
This fix adds additional test to cover the changes.
Author: Yong Tang <yong.tang.github@outlook.com>
Closes #12163 from yongtang/SPARK-14368.
-rw-r--r-- | python/pyspark/rdd.py | 2 | ||||
-rw-r--r-- | python/pyspark/tests.py | 12 |
2 files changed, 13 insertions, 1 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cd1f64e8aa..8978f028c5 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -115,7 +115,7 @@ def _parse_memory(s): 2048 """ units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024} - if s[-1] not in units: + if s[-1].lower() not in units: raise ValueError("invalid format: " + s) return int(float(s[:-1]) * units[s[-1].lower()]) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index a5a83c7e38..40fccb8c00 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1966,6 +1966,18 @@ class ContextTests(unittest.TestCase): self.assertGreater(sc.startTime, 0) +class ConfTests(unittest.TestCase): + def test_memory_conf(self): + memoryList = ["1T", "1G", "1M", "1024K"] + for memory in memoryList: + sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory)) + l = list(range(1024)) + random.shuffle(l) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) + sc.stop() + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): |