author    Liang-Chi Hsieh <simonh@tw.ibm.com>    2016-06-13 19:59:53 -0700
committer Xiangrui Meng <meng@databricks.com>   2016-06-13 19:59:53 -0700
commit    baa3e633e18c47b12e79fe3ddc01fc8ec010f096 (patch)
tree      83c91014b9d46fc9efc4bf1f5dcef5cee1fe184a /python/pyspark/ml
parent    5827b65e28da168286c771c53a38620d79f5e74f (diff)
[SPARK-15364][ML][PYSPARK] Implement PySpark picklers for ml.Vector and ml.Matrix under spark.ml.python
## What changes were proposed in this pull request?

We currently have PySpark picklers for both the new and the old vector/matrix types, implemented separately, but they all live under `PythonMLlibAPI`. To keep spark.mllib separate from spark.ml, the picklers for the new vector/matrix types should be implemented under `spark.ml.python` instead.

## How was this patch tested?

Existing tests.

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #13219 from viirya/pyspark-pickler-ml.
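As a quick illustration (not part of the patch), the sketch below round-trips a `spark.ml` vector through the new `MLSerDe` pickler, mirroring the updated `VectorTests` in this diff. It assumes a local SparkContext and uses internal, non-public entry points (`sc._jvm.MLSerDe`):

```python
# Sketch only: exercises the new spark.ml pickler via internal APIs.
from pyspark import SparkContext
from pyspark.ml.linalg import Vectors
from pyspark.serializers import PickleSerializer

sc = SparkContext("local", "MLSerDe example")
ser = PickleSerializer()
v = Vectors.dense([1.0, 2.0, 3.0])

# Python -> JVM: unpickle the bytes into an org.apache.spark.ml.linalg.DenseVector
jvec = sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v)))
# JVM -> Python: pickle the Java vector back into a Python ml.linalg vector
nv = ser.loads(bytes(sc._jvm.MLSerDe.dumps(jvec)))
assert v == nv

sc.stop()
```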
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r--  python/pyspark/ml/base.py             2
-rw-r--r--  python/pyspark/ml/classification.py   2
-rw-r--r--  python/pyspark/ml/clustering.py       2
-rw-r--r--  python/pyspark/ml/common.py         137
-rw-r--r--  python/pyspark/ml/evaluation.py       2
-rwxr-xr-x  python/pyspark/ml/feature.py          2
-rw-r--r--  python/pyspark/ml/pipeline.py         2
-rw-r--r--  python/pyspark/ml/recommendation.py   2
-rw-r--r--  python/pyspark/ml/regression.py       2
-rwxr-xr-x  python/pyspark/ml/tests.py           10
-rw-r--r--  python/pyspark/ml/tuning.py           2
-rw-r--r--  python/pyspark/ml/util.py             2
-rw-r--r--  python/pyspark/ml/wrapper.py          2
13 files changed, 153 insertions(+), 16 deletions(-)
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index a7a58e17a4..339e5d6af5 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -19,7 +19,7 @@ from abc import ABCMeta, abstractmethod
from pyspark import since
from pyspark.ml.param import Params
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
@inherit_doc
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 77badebeb4..121b9262dd 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -26,7 +26,7 @@ from pyspark.ml.regression import (
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.wrapper import JavaWrapper
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 92df19e804..75d9a0e8ca 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -19,7 +19,7 @@ from pyspark import since, keyword_only
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
__all__ = ['BisectingKMeans', 'BisectingKMeansModel',
'KMeans', 'KMeansModel',
diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py
new file mode 100644
index 0000000000..256e91e141
--- /dev/null
+++ b/python/pyspark/ml/common.py
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+if sys.version >= '3':
+    long = int
+    unicode = str
+
+import py4j.protocol
+from py4j.protocol import Py4JJavaError
+from py4j.java_gateway import JavaObject
+from py4j.java_collections import ListConverter, JavaArray, JavaList
+
+from pyspark import RDD, SparkContext
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.sql import DataFrame, SQLContext
+
+# Hack to support float('inf') in Py4j
+_old_smart_decode = py4j.protocol.smart_decode
+
+_float_str_mapping = {
+ 'nan': 'NaN',
+ 'inf': 'Infinity',
+ '-inf': '-Infinity',
+}
+
+
+def _new_smart_decode(obj):
+    if isinstance(obj, float):
+        s = str(obj)
+        return _float_str_mapping.get(s, s)
+    return _old_smart_decode(obj)
+
+py4j.protocol.smart_decode = _new_smart_decode
+
+
+_picklable_classes = [
+    'SparseVector',
+    'DenseVector',
+    'DenseMatrix',
+]
+
+
+# this will call the ML version of pythonToJava()
+def _to_java_object_rdd(rdd):
+    """Return a JavaRDD of Object by unpickling.
+
+    It will convert each Python object into a Java object via Pyrolite, whether
+    or not the RDD is serialized in batches.
+    """
+    rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
+    return rdd.ctx._jvm.MLSerDe.pythonToJava(rdd._jrdd, True)
+
+
+def _py2java(sc, obj):
+    """ Convert Python object into Java """
+    if isinstance(obj, RDD):
+        obj = _to_java_object_rdd(obj)
+    elif isinstance(obj, DataFrame):
+        obj = obj._jdf
+    elif isinstance(obj, SparkContext):
+        obj = obj._jsc
+    elif isinstance(obj, list):
+        obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
+    elif isinstance(obj, JavaObject):
+        pass
+    elif isinstance(obj, (int, long, float, bool, bytes, unicode)):
+        pass
+    else:
+        data = bytearray(PickleSerializer().dumps(obj))
+        obj = sc._jvm.MLSerDe.loads(data)
+    return obj
+
+
+def _java2py(sc, r, encoding="bytes"):
+    if isinstance(r, JavaObject):
+        clsName = r.getClass().getSimpleName()
+        # convert RDD into JavaRDD
+        if clsName != 'JavaRDD' and clsName.endswith("RDD"):
+            r = r.toJavaRDD()
+            clsName = 'JavaRDD'
+
+        if clsName == 'JavaRDD':
+            jrdd = sc._jvm.MLSerDe.javaToPython(r)
+            return RDD(jrdd, sc)
+
+        if clsName == 'Dataset':
+            return DataFrame(r, SQLContext.getOrCreate(sc))
+
+        if clsName in _picklable_classes:
+            r = sc._jvm.MLSerDe.dumps(r)
+        elif isinstance(r, (JavaArray, JavaList)):
+            try:
+                r = sc._jvm.MLSerDe.dumps(r)
+            except Py4JJavaError:
+                pass  # not picklable
+
+    if isinstance(r, (bytearray, bytes)):
+        r = PickleSerializer().loads(bytes(r), encoding=encoding)
+    return r
+
+
+def callJavaFunc(sc, func, *args):
+    """ Call Java Function """
+    args = [_py2java(sc, a) for a in args]
+    return _java2py(sc, func(*args))
+
+
+def inherit_doc(cls):
+    """
+    A decorator that makes a class inherit documentation from its parents.
+    """
+    for name, func in vars(cls).items():
+        # only inherit docstring for public functions
+        if name.startswith("_"):
+            continue
+        if not func.__doc__:
+            for parent in cls.__bases__:
+                parent_func = getattr(parent, name, None)
+                if parent_func and getattr(parent_func, "__doc__", None):
+                    func.__doc__ = parent_func.__doc__
+                    break
+    return cls
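For context (again not part of the patch), a minimal sketch of what the `inherit_doc` decorator above does: a public method without a docstring picks up the docstring of the same-named method on a parent class. `Base` and `Child` are hypothetical classes, for illustration only:

```python
# Hypothetical classes, for illustration only.
from pyspark.ml.common import inherit_doc

class Base(object):
    def transform(self, dataset):
        """Transform the input dataset."""
        return dataset

@inherit_doc
class Child(Base):
    def transform(self, dataset):  # no docstring of its own
        return dataset

# The decorator copied the parent's docstring onto Child.transform.
assert Child.transform.__doc__ == "Transform the input dataset."
```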
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index cd071f1b7c..1fe8772da7 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -21,7 +21,7 @@ from pyspark import since, keyword_only
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
'MulticlassClassificationEvaluator']
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index ca77ac395d..a28764a752 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -25,7 +25,7 @@ from pyspark.ml.linalg import _convert_to_vector
from pyspark.ml.param.shared import *
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
__all__ = ['Binarizer',
'Bucketizer',
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 0777527134..a48f4bb2ad 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
from pyspark.ml.wrapper import JavaParams
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
@inherit_doc
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 1778bfe938..0a7096794d 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -19,7 +19,7 @@ from pyspark import since, keyword_only
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
__all__ = ['ALS', 'ALSModel']
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 7c79ab73c7..db31993f0f 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -21,7 +21,7 @@ from pyspark import since, keyword_only
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
from pyspark.sql import DataFrame
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4358175a57..981ed9dda0 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -61,7 +61,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
GeneralizedLinearRegression
from pyspark.ml.tuning import *
from pyspark.ml.wrapper import JavaParams
-from pyspark.mllib.common import _java2py
+from pyspark.ml.common import _java2py
from pyspark.serializers import PickleSerializer
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import rand
@@ -1195,12 +1195,12 @@ class VectorTests(MLlibTestCase):
     def _test_serialize(self, v):
         self.assertEqual(v, ser.loads(ser.dumps(v)))
-        jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
-        nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec)))
+        jvec = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v)))
+        nv = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvec)))
         self.assertEqual(v, nv)
         vs = [v] * 100
-        jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs)))
-        nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs)))
+        jvecs = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(vs)))
+        nvs = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvecs)))
         self.assertEqual(vs, nvs)
def test_serialize(self):
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index fe87b6cdb9..f857c5e8c8 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -25,7 +25,7 @@ from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand
-from pyspark.mllib.common import inherit_doc, _py2java
+from pyspark.ml.common import inherit_doc, _py2java
__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
'TrainValidationSplitModel']
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 9d28823196..4a31a29809 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -23,7 +23,7 @@ if sys.version > '3':
unicode = str
from pyspark import SparkContext, since
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
def _jvm():
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index fef0040faf..25c44b7533 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -22,7 +22,7 @@ from pyspark.sql import DataFrame
from pyspark.ml import Estimator, Transformer, Model
from pyspark.ml.param import Params
from pyspark.ml.util import _jvm
-from pyspark.mllib.common import inherit_doc, _java2py, _py2java
+from pyspark.ml.common import inherit_doc, _java2py, _py2java
class JavaWrapper(object):