aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 12
-rw-r--r-- python/pyspark/mllib/tests.py | 17
2 files changed, 28 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 29160a10e1..f6826ddbfa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1438,9 +1438,19 @@ private[spark] object SerDe extends Serializable {
if (args.length != 3) {
throw new PickleException("should be 3")
}
- new Rating(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int],
+ new Rating(ratingsIdCheckLong(args(0)), ratingsIdCheckLong(args(1)),
args(2).asInstanceOf[Double])
}
+
+ private def ratingsIdCheckLong(obj: Object): Int = { // narrows an unpickled rating id to Int, rethrowing cast failures as PickleException
+ try {
+ obj.asInstanceOf[Int] // throws ClassCastException when obj is not a boxed Int (presumably a wider integer type for ids above Int.MaxValue -- confirm against the unpickler)
+ } catch {
+ case ex: ClassCastException => // NOTE(review): any non-Int obj, not only an oversized id, is reported with the "exceeds" message below
+ throw new PickleException(s"Ratings id ${obj.toString} exceeds " +
+ s"max integer value of ${Int.MaxValue}", ex) // the original ClassCastException is kept as the cause
+ }
+ }
}
var initialized = false
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f8e8e0e0ad..6ed03e3582 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -54,6 +54,7 @@ from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
+from pyspark.mllib.recommendation import Rating
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
@@ -1539,6 +1540,22 @@ class MLUtilsTests(MLlibTestCase):
shutil.rmtree(load_vectors_path)
+class ALSTests(MLlibTestCase):  # serialization checks for Rating across the Python <-> JVM SerDe boundary
+
+ def test_als_ratings_serialize(self):  # a Rating with small ids should survive a full round trip
+ r = Rating(7, 1123, 3.14)
+ jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r)))  # Python-side dumps -> JVM-side loads
+ nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr)))  # JVM-side dumps -> Python-side loads
+ self.assertEqual(r.user, nr.user)
+ self.assertEqual(r.product, nr.product)
+ self.assertAlmostEqual(r.rating, nr.rating, 2)  # compared to 2 decimal places
+
+ def test_als_ratings_id_long_error(self):  # ids beyond the 32-bit int range must be rejected on the JVM side
+ r = Rating(1205640308657491975, 50233468418, 1.0)
+ # both the user id and the product id exceed the max int value; the error is expected from the JVM-side SerDe.loads (unpickling), not from ser.dumps
+ self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))
+
+
if __name__ == "__main__":
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")