aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib
diff options
context:
space:
mode:
authorBryan Cutler <bjcutler@us.ibm.com>2015-12-20 09:08:23 +0000
committerSean Owen <sowen@cloudera.com>2015-12-20 09:08:23 +0000
commitce1798b3af8de326bf955b51ed955a924b019b4e (patch)
treed7f5a9b81218cf50109ad05b7adf0dc68979ad6f /python/pyspark/mllib
parent284e29a870bbb62f59988a5d88cd12f1b0b6f9d3 (diff)
downloadspark-ce1798b3af8de326bf955b51ed955a924b019b4e.tar.gz
spark-ce1798b3af8de326bf955b51ed955a924b019b4e.tar.bz2
spark-ce1798b3af8de326bf955b51ed955a924b019b4e.zip
[SPARK-10158][PYSPARK][MLLIB] ALS better error message when using Long IDs
Added catch for casting Long to Int exception when PySpark ALS Ratings are serialized. It is easy to accidentally use Long IDs for user/product and before, it would fail with a somewhat cryptic "ClassCastException: java.lang.Long cannot be cast to java.lang.Integer." Now if this is done, a more descriptive error is shown, e.g. "PickleException: Ratings id 1205640308657491975 exceeds max integer value of 2147483647." Author: Bryan Cutler <bjcutler@us.ibm.com> Closes #9361 from BryanCutler/als-pyspark-long-id-error-SPARK-10158.
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r--python/pyspark/mllib/tests.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f8e8e0e0ad..6ed03e3582 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -54,6 +54,7 @@ from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
+from pyspark.mllib.recommendation import Rating
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
@@ -1539,6 +1540,22 @@ class MLUtilsTests(MLlibTestCase):
shutil.rmtree(load_vectors_path)
+class ALSTests(MLlibTestCase):
+
+ def test_als_ratings_serialize(self):
+ r = Rating(7, 1123, 3.14)
+ jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r)))
+ nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr)))
+ self.assertEqual(r.user, nr.user)
+ self.assertEqual(r.product, nr.product)
+ self.assertAlmostEqual(r.rating, nr.rating, 2)
+
+ def test_als_ratings_id_long_error(self):
+ r = Rating(1205640308657491975, 50233468418, 1.0)
+ # rating user id exceeds max int value, should fail when pickled
+ self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))
+
+
if __name__ == "__main__":
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")