Diffstat (limited to 'python/pyspark/mllib/common.py')
-rw-r--r--  python/pyspark/mllib/common.py  135
1 file changed, 135 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
new file mode 100644
index 0000000000..76864d8163
--- /dev/null
+++ b/python/pyspark/mllib/common.py
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import py4j.protocol
+from py4j.protocol import Py4JJavaError
+from py4j.java_gateway import JavaObject
+from py4j.java_collections import MapConverter, ListConverter, JavaArray, JavaList
+
+from pyspark import RDD, SparkContext
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+
+
+# Hack to support float('inf') in Py4J
+_old_smart_decode = py4j.protocol.smart_decode
+
+_float_str_mapping = {
+ 'nan': 'NaN',
+ 'inf': 'Infinity',
+ '-inf': '-Infinity',
+}
+
+
+def _new_smart_decode(obj):
+ if isinstance(obj, float):
+ s = unicode(obj)
+ return _float_str_mapping.get(s, s)
+ return _old_smart_decode(obj)
+
+py4j.protocol.smart_decode = _new_smart_decode
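+# Illustration (not part of the original file): Java's Double.parseDouble
+# accepts "Infinity", "-Infinity" and "NaN", but unicode(float('inf')) yields
+# u'inf', which it rejects.  With the patch in place:
+#
+#   >>> _new_smart_decode(float('inf'))
+#   'Infinity'
+#   >>> _new_smart_decode(float('nan'))
+#   'NaN'
+#
+# Anything that is not a float still goes through the original smart_decode.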
+
+
+_picklable_classes = [
+ 'LinkedList',
+ 'SparseVector',
+ 'DenseVector',
+ 'DenseMatrix',
+ 'Rating',
+ 'LabeledPoint',
+]
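+# These names are assumed to match the Java-side classes for which SerDe
+# registers custom Pyrolite picklers, so instances of them can cross the
+# Py4J boundary as pickled bytes rather than as opaque Java object handles.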
+
+
+# This calls the MLlib version of pythonToJava().
+def _to_java_object_rdd(rdd, cache=False):
+    """ Return a JavaRDD of Object by unpickling
+
+    It will convert each Python object into a Java object by Pyrolite,
+    whether the RDD is serialized in batches or not.
+    """
+ rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
+ if cache:
+ rdd.cache()
+ return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
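+# A minimal usage sketch (illustrative only, not executed here):
+#
+#   >>> rdd = sc.parallelize([1.0, 2.0])
+#   >>> jrdd = _to_java_object_rdd(rdd)   # a JavaRDD[Object] handle in the JVM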
+
+
+def _py2java(sc, obj):
+    """ Convert a Python object into a Java object """
+ if isinstance(obj, RDD):
+ obj = _to_java_object_rdd(obj)
+ elif isinstance(obj, SparkContext):
+ obj = obj._jsc
+ elif isinstance(obj, dict):
+ obj = MapConverter().convert(obj, sc._gateway._gateway_client)
+ elif isinstance(obj, (list, tuple)):
+ obj = ListConverter().convert(obj, sc._gateway._gateway_client)
+ elif isinstance(obj, JavaObject):
+ pass
+ elif isinstance(obj, (int, long, float, bool, basestring)):
+ pass
+ else:
+        data = bytearray(PickleSerializer().dumps(obj))
+        obj = sc._jvm.SerDe.loads(data)
+ return obj
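+# Examples of the conversions above (illustrative only):
+#
+#   >>> _py2java(sc, {'a': 1})   # becomes a java.util.Map via MapConverter
+#   >>> _py2java(sc, [1, 2])     # becomes a java.util.List via ListConverter
+#   >>> _py2java(sc, 42)         # primitives pass through unchanged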
+
+
+def _java2py(sc, r):
+    """ Convert a Java object back into a Python object """
+ if isinstance(r, JavaObject):
+ clsName = r.getClass().getSimpleName()
+ # convert RDD into JavaRDD
+ if clsName != 'JavaRDD' and clsName.endswith("RDD"):
+ r = r.toJavaRDD()
+ clsName = 'JavaRDD'
+
+ if clsName == 'JavaRDD':
+ jrdd = sc._jvm.SerDe.javaToPython(r)
+ return RDD(jrdd, sc, AutoBatchedSerializer(PickleSerializer()))
+
+ elif isinstance(r, (JavaArray, JavaList)) or clsName in _picklable_classes:
+ r = sc._jvm.SerDe.dumps(r)
+
+ if isinstance(r, bytearray):
+ r = PickleSerializer().loads(str(r))
+ return r
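+# Round-trip sketch (illustrative only): a JavaRDD coming back from the JVM
+# is re-wrapped as a PySpark RDD, and pickled byte arrays are unpickled:
+#
+#   >>> _java2py(sc, _py2java(sc, sc.parallelize([1, 2]))).collect()
+#   [1, 2]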
+
+
+def callJavaFunc(sc, func, *args):
+    """ Call a Java function, converting the arguments to Java and the result back to Python """
+ args = [_py2java(sc, a) for a in args]
+ return _java2py(sc, func(*args))
+
+
+def callMLlibFunc(name, *args):
+    """ Call an API method of PythonMLLibAPI """
+ sc = SparkContext._active_spark_context
+ api = getattr(sc._jvm.PythonMLLibAPI(), name)
+ return callJavaFunc(sc, api, *args)
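+# Usage sketch (the method name below is hypothetical; any public method of
+# the Scala-side PythonMLLibAPI class can be dispatched this way):
+#
+#   >>> model = callMLlibFunc("trainSomeModel", rdd, param)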
+
+
+class JavaModelWrapper(object):
+ """
+    Wrapper for a model object that lives in the JVM
+ """
+ def __init__(self, java_model):
+ self._sc = SparkContext._active_spark_context
+ self._java_model = java_model
+
+ def __del__(self):
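+        # Drop the Py4J server-side reference so the Java model can be
+        # garbage collected once this wrapper goes away.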
+ self._sc._gateway.detach(self._java_model)
+
+ def call(self, name, *a):
+        """Call a method of the wrapped Java model"""
+ return callJavaFunc(self._sc, getattr(self._java_model, name), *a)
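+
+# Sketch of how a concrete model might build on JavaModelWrapper (the class
+# and method names are hypothetical, for illustration only):
+#
+#   class MyModel(JavaModelWrapper):
+#       def predict(self, x):
+#           return self.call("predict", x)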