author      Liang-Chi Hsieh <simonh@tw.ibm.com>   2016-06-13 19:59:53 -0700
committer   Xiangrui Meng <meng@databricks.com>   2016-06-13 19:59:53 -0700
commit      baa3e633e18c47b12e79fe3ddc01fc8ec010f096 (patch)
tree        83c91014b9d46fc9efc4bf1f5dcef5cee1fe184a /python/pyspark/ml/common.py
parent      5827b65e28da168286c771c53a38620d79f5e74f (diff)
download    spark-baa3e633e18c47b12e79fe3ddc01fc8ec010f096.tar.gz
            spark-baa3e633e18c47b12e79fe3ddc01fc8ec010f096.tar.bz2
            spark-baa3e633e18c47b12e79fe3ddc01fc8ec010f096.zip
[SPARK-15364][ML][PYSPARK] Implement PySpark picklers for ml.Vector and ml.Matrix under spark.ml.python
## What changes were proposed in this pull request?

We currently have PySpark picklers for both the new and the old vector/matrix types, but they are all implemented under `PythonMLlibAPI`. To separate spark.mllib from spark.ml, this patch implements the picklers for the new vector/matrix types under `spark.ml.python` instead.

## How was this patch tested?

Existing tests.

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #13219 from viirya/pyspark-pickler-ml.
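For context, a minimal sketch of the round trip these picklers enable (assuming a live SparkContext; `_py2java`/`_java2py` are the private helpers added below):

```python
from pyspark import SparkContext
from pyspark.ml.linalg import DenseVector
from pyspark.ml.common import _py2java, _java2py

sc = SparkContext.getOrCreate()
v = DenseVector([1.0, 2.0, 3.0])
jv = _py2java(sc, v)    # pickled in Python, rebuilt on the JVM by MLSerDe.loads
v2 = _java2py(sc, jv)   # dumped by MLSerDe.dumps, unpickled back into Python
assert v == v2
```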
Diffstat (limited to 'python/pyspark/ml/common.py')
-rw-r--r--  python/pyspark/ml/common.py  137
1 file changed, 137 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py
new file mode 100644
index 0000000000..256e91e141
--- /dev/null
+++ b/python/pyspark/ml/common.py
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+if sys.version >= '3':
+ long = int
+ unicode = str
+
+import py4j.protocol
+from py4j.protocol import Py4JJavaError
+from py4j.java_gateway import JavaObject
+from py4j.java_collections import ListConverter, JavaArray, JavaList
+
+from pyspark import RDD, SparkContext
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.sql import DataFrame, SQLContext
+
+# Hack to support float('inf') in Py4J
+_old_smart_decode = py4j.protocol.smart_decode
+
+_float_str_mapping = {
+ 'nan': 'NaN',
+ 'inf': 'Infinity',
+ '-inf': '-Infinity',
+}
+
+
+def _new_smart_decode(obj):
+ if isinstance(obj, float):
+ s = str(obj)
+ return _float_str_mapping.get(s, s)
+ return _old_smart_decode(obj)
+
+py4j.protocol.smart_decode = _new_smart_decode
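+
+# For example: str(float('inf')) is 'inf' in Python, but Java's
+# Double.parseDouble() expects 'Infinity', so the patched smart_decode
+# rewrites 'inf', '-inf' and 'nan' before they cross the Py4J bridge.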
+
+
+_picklable_classes = [
+ 'SparseVector',
+ 'DenseVector',
+ 'DenseMatrix',
+]
+
+
+# this will call the ML version of pythonToJava()
+def _to_java_object_rdd(rdd):
+ """ Return an JavaRDD of Object by unpickling
+
+ It will convert each Python object into Java object by Pyrolite, whenever the
+ RDD is serialized in batch or not.
+ """
+ rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
+ return rdd.ctx._jvm.MLSerDe.pythonToJava(rdd._jrdd, True)
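+
+# For example (input hypothetical): calling _to_java_object_rdd on
+# sc.parallelize([DenseVector([1.0])]) yields a JavaRDD of JVM
+# ml.linalg.DenseVector objects, whether or not the source RDD was batched.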
+
+
+def _py2java(sc, obj):
+ """ Convert Python object into Java """
+ if isinstance(obj, RDD):
+ obj = _to_java_object_rdd(obj)
+ elif isinstance(obj, DataFrame):
+ obj = obj._jdf
+ elif isinstance(obj, SparkContext):
+ obj = obj._jsc
+ elif isinstance(obj, list):
+ obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
+ elif isinstance(obj, JavaObject):
+ pass
+ elif isinstance(obj, (int, long, float, bool, bytes, unicode)):
+ pass
+ else:
+ data = bytearray(PickleSerializer().dumps(obj))
+ obj = sc._jvm.MLSerDe.loads(data)
+ return obj
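+
+# For example, _py2java(sc, DenseVector([1.0, 2.0])) matches none of the
+# branches above, so it takes the pickle path: the vector is pickled in
+# Python and rebuilt on the JVM by MLSerDe.loads.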
+
+
+def _java2py(sc, r, encoding="bytes"):
+ if isinstance(r, JavaObject):
+ clsName = r.getClass().getSimpleName()
+ # convert RDD into JavaRDD
+ if clsName != 'JavaRDD' and clsName.endswith("RDD"):
+ r = r.toJavaRDD()
+ clsName = 'JavaRDD'
+
+ if clsName == 'JavaRDD':
+ jrdd = sc._jvm.MLSerDe.javaToPython(r)
+ return RDD(jrdd, sc)
+
+ if clsName == 'Dataset':
+ return DataFrame(r, SQLContext.getOrCreate(sc))
+
+ if clsName in _picklable_classes:
+ r = sc._jvm.MLSerDe.dumps(r)
+ elif isinstance(r, (JavaArray, JavaList)):
+ try:
+ r = sc._jvm.MLSerDe.dumps(r)
+ except Py4JJavaError:
+ pass # not picklable
+
+ if isinstance(r, (bytearray, bytes)):
+ r = PickleSerializer().loads(bytes(r), encoding=encoding)
+ return r
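+
+# _java2py dispatches on the JVM class name: RDD-like results are re-wrapped
+# as Python RDDs, a Dataset becomes a DataFrame, and anything in
+# _picklable_classes (e.g. a JVM DenseVector) is round-tripped back to
+# Python through MLSerDe.dumps and pickle.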
+
+
+def callJavaFunc(sc, func, *args):
+ """ Call Java Function """
+ args = [_py2java(sc, a) for a in args]
+ return _java2py(sc, func(*args))
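+
+# For example (helper name hypothetical):
+#   callJavaFunc(sc, sc._jvm.SomeHelper.transform, rdd, 0.5)
+# converts each argument with _py2java and maps the JVM result back with
+# _java2py.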
+
+
+def inherit_doc(cls):
+ """
+ A decorator that makes a class inherit documentation from its parents.
+ """
+ for name, func in vars(cls).items():
+ # only inherit docstring for public functions
+ if name.startswith("_"):
+ continue
+ if not func.__doc__:
+ for parent in cls.__bases__:
+ parent_func = getattr(parent, name, None)
+ if parent_func and getattr(parent_func, "__doc__", None):
+ func.__doc__ = parent_func.__doc__
+ break
+ return cls
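+
+# For example (classes hypothetical): given a Parent whose transform() has a
+# docstring,
+#
+#   @inherit_doc
+#   class Child(Parent):
+#       def transform(self, dataset):  # no docstring of its own
+#           return dataset
+#
+# Child.transform inherits Parent.transform.__doc__.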