diff options
author | Bryan Cutler <bjcutler@us.ibm.com> | 2015-12-20 09:08:23 +0000 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2015-12-20 09:08:23 +0000 |
commit | ce1798b3af8de326bf955b51ed955a924b019b4e (patch) | |
tree | d7f5a9b81218cf50109ad05b7adf0dc68979ad6f /python | |
parent | 284e29a870bbb62f59988a5d88cd12f1b0b6f9d3 (diff) | |
download | spark-ce1798b3af8de326bf955b51ed955a924b019b4e.tar.gz spark-ce1798b3af8de326bf955b51ed955a924b019b4e.tar.bz2 spark-ce1798b3af8de326bf955b51ed955a924b019b4e.zip |
[SPARK-10158][PYSPARK][MLLIB] ALS better error message when using Long IDs
Added catch for casting Long to Int exception when PySpark ALS Ratings are serialized. It is easy to accidentally use Long IDs for user/product and before, it would fail with a somewhat cryptic "ClassCastException: java.lang.Long cannot be cast to java.lang.Integer." Now if this is done, a more descriptive error is shown, e.g. "PickleException: Ratings id 1205640308657491975 exceeds max integer value of 2147483647."
Author: Bryan Cutler <bjcutler@us.ibm.com>
Closes #9361 from BryanCutler/als-pyspark-long-id-error-SPARK-10158.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/tests.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f8e8e0e0ad..6ed03e3582 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -54,6 +54,7 @@ from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics @@ -1539,6 +1540,22 @@ class MLUtilsTests(MLlibTestCase): shutil.rmtree(load_vectors_path) +class ALSTests(MLlibTestCase): + + def test_als_ratings_serialize(self): + r = Rating(7, 1123, 3.14) + jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r))) + nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr))) + self.assertEqual(r.user, nr.user) + self.assertEqual(r.product, nr.product) + self.assertAlmostEqual(r.rating, nr.rating, 2) + + def test_als_ratings_id_long_error(self): + r = Rating(1205640308657491975, 50233468418, 1.0) + # rating user id exceeds max int value, should fail when pickled + self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r))) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") |