Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--  python/pyspark/tests.py  85
1 file changed, 76 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bb84ebe72c..2e7c2750a8 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,6 +31,7 @@ import tempfile
import time
import zipfile
import random
+from platform import python_implementation
if sys.version_info[:2] <= (2, 6):
import unittest2 as unittest
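
The new platform import feeds the PyPy skip added near the end of this diff. As a minimal illustrative sketch (not part of the patch), the check works like this:

from platform import python_implementation

# python_implementation() names the running interpreter:
# "CPython", "PyPy", "Jython", or "IronPython".
if python_implementation() == "PyPy":
    # Mirrors the skip condition used below for SPARK-2951.
    print("PyPy detected; array-writable SequenceFile tests will be skipped")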
@@ -41,7 +42,8 @@ else:
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
-from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
+from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
+ CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType
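
CloudPickleSerializer, newly imported above, exposes the bundled cloudpickle module through the same dumps/loads pair the tests below rely on. A minimal round-trip sketch, assuming pyspark is importable:

from pyspark.serializers import CloudPickleSerializer

ser = CloudPickleSerializer()
# cloudpickle handles objects the stdlib pickler rejects, e.g. lambdas.
add_one = ser.loads(ser.dumps(lambda x: x + 1))
assert add_one(1) == 2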
@@ -168,15 +170,46 @@ class SerializationTestCase(unittest.TestCase):
p2 = loads(dumps(p1, 2))
self.assertEquals(p1, p2)
-
-# Regression test for SPARK-3415
-class CloudPickleTest(unittest.TestCase):
+ def test_itemgetter(self):
+ from operator import itemgetter
+ ser = CloudPickleSerializer()
+ d = range(10)
+ getter = itemgetter(1)
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ getter = itemgetter(0, 3)
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ def test_attrgetter(self):
+ from operator import attrgetter
+ ser = CloudPickleSerializer()
+
+ class C(object):
+ def __getattr__(self, item):
+ return item
+ d = C()
+ getter = attrgetter("a")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+ getter = attrgetter("a", "b")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ d.e = C()
+ getter = attrgetter("e.a")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+ getter = attrgetter("e.a", "e.b")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ # Regression test for SPARK-3415
def test_pickling_file_handles(self):
- from pyspark.cloudpickle import dumps
- from StringIO import StringIO
- from pickle import load
+ ser = CloudPickleSerializer()
out1 = sys.stderr
- out2 = load(StringIO(dumps(out1)))
+ out2 = ser.loads(ser.dumps(out1))
self.assertEquals(out1, out2)
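
The itemgetter/attrgetter cases above guard cloudpickle support that the stdlib pickler lacks; a short standalone sketch of the failure mode they test against (plain pickle, no Spark needed):

import pickle
from operator import itemgetter

try:
    pickle.dumps(itemgetter(1))
except TypeError as exc:
    # The stdlib pickler has no reducer for operator.itemgetter;
    # cloudpickle instead reconstructs the getter from its index arguments.
    print("stdlib pickle fails: %s" % exc)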
@@ -861,9 +894,43 @@ class TestOutputFormat(PySparkTestCase):
conf=input_conf).collect())
self.assertEqual(old_dataset, dict_data)
- @unittest.skipIf(sys.version_info[:2] <= (2, 6), "Skipped on 2.6 until SPARK-2951 is fixed")
def test_newhadoop(self):
basepath = self.tempdir.name
+ data = [(1, ""),
+ (1, "a"),
+ (2, "bcdf")]
+ self.sc.parallelize(data).saveAsNewAPIHadoopFile(
+ basepath + "/newhadoop/",
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text")
+ result = sorted(self.sc.newAPIHadoopFile(
+ basepath + "/newhadoop/",
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text").collect())
+ self.assertEqual(result, data)
+
+ conf = {
+ "mapreduce.outputformat.class":
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+ "mapred.output.key.class": "org.apache.hadoop.io.IntWritable",
+ "mapred.output.value.class": "org.apache.hadoop.io.Text",
+ "mapred.output.dir": basepath + "/newdataset/"
+ }
+ self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf)
+ input_conf = {"mapred.input.dir": basepath + "/newdataset/"}
+ new_dataset = sorted(self.sc.newAPIHadoopRDD(
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text",
+ conf=input_conf).collect())
+ self.assertEqual(new_dataset, data)
+
+ @unittest.skipIf(sys.version_info[:2] <= (2, 6) or python_implementation() == "PyPy",
+ "Skipped on 2.6 and PyPy until SPARK-2951 is fixed")
+ def test_newhadoop_with_array(self):
+ basepath = self.tempdir.name
# use custom ArrayWritable types and converters to handle arrays
array_data = [(1, array('d')),
(1, array('d', [1.0, 2.0, 3.0])),