diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2016-07-05 17:00:24 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-07-05 17:00:24 -0700 |
commit | fdde7d0aa0ef69d0e9a88cf712601bba1d5b0706 (patch) | |
tree | 8f1c6be72c453dbd33d864c6ca54e2b2d28a480c /python/pyspark/ml | |
parent | 59f9c1bd1adfea7069e769fb68351c228c37c8fc (diff) | |
download | spark-fdde7d0aa0ef69d0e9a88cf712601bba1d5b0706.tar.gz spark-fdde7d0aa0ef69d0e9a88cf712601bba1d5b0706.tar.bz2 spark-fdde7d0aa0ef69d0e9a88cf712601bba1d5b0706.zip |
[SPARK-16348][ML][MLLIB][PYTHON] Use full classpaths for pyspark ML JVM calls
## What changes were proposed in this pull request?
Issue: Omitting the full classpath can cause problems when calling JVM methods or classes from pyspark.
This PR: Changed all uses of jvm.X in pyspark.ml and pyspark.mllib to use full classpath for X
## How was this patch tested?
Existing unit tests. Manual testing in an environment where this was an issue.
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #14023 from jkbradley/SPARK-16348.
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r-- | python/pyspark/ml/common.py | 10 | ||||
-rwxr-xr-x | python/pyspark/ml/tests.py | 8 |
2 files changed, 9 insertions, 9 deletions
diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py index 256e91e141..7d449aaccb 100644 --- a/python/pyspark/ml/common.py +++ b/python/pyspark/ml/common.py @@ -63,7 +63,7 @@ def _to_java_object_rdd(rdd): RDD is serialized in batch or not. """ rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) - return rdd.ctx._jvm.MLSerDe.pythonToJava(rdd._jrdd, True) + return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True) def _py2java(sc, obj): @@ -82,7 +82,7 @@ def _py2java(sc, obj): pass else: data = bytearray(PickleSerializer().dumps(obj)) - obj = sc._jvm.MLSerDe.loads(data) + obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data) return obj @@ -95,17 +95,17 @@ def _java2py(sc, r, encoding="bytes"): clsName = 'JavaRDD' if clsName == 'JavaRDD': - jrdd = sc._jvm.MLSerDe.javaToPython(r) + jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r) return RDD(jrdd, sc) if clsName == 'Dataset': return DataFrame(r, SQLContext.getOrCreate(sc)) if clsName in _picklable_classes: - r = sc._jvm.MLSerDe.dumps(r) + r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r) elif isinstance(r, (JavaArray, JavaList)): try: - r = sc._jvm.MLSerDe.dumps(r) + r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r) except Py4JJavaError: pass # not pickable
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 981ed9dda0..24efce812b 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1195,12 +1195,12 @@ class VectorTests(MLlibTestCase): def _test_serialize(self, v): self.assertEqual(v, ser.loads(ser.dumps(v))) - jvec = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvec))) + jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) + nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) self.assertEqual(v, nv) vs = [v] * 100 - jvecs = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvecs))) + jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs))) + nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs))) self.assertEqual(vs, nvs) def test_serialize(self): |