diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/java_gateway.py | 1 | ||||
-rw-r--r-- | python/pyspark/ml/base.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/classification.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/clustering.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/common.py | 137 | ||||
-rw-r--r-- | python/pyspark/ml/evaluation.py | 2 | ||||
-rwxr-xr-x | python/pyspark/ml/feature.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/pipeline.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/recommendation.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/regression.py | 2 | ||||
-rwxr-xr-x | python/pyspark/ml/tests.py | 10 | ||||
-rw-r--r-- | python/pyspark/ml/tuning.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/util.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/wrapper.py | 2 |
14 files changed, 154 insertions, 16 deletions
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index cd4c55f79f..527ca82d31 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -116,6 +116,7 @@ def launch_gateway(): java_import(gateway.jvm, "org.apache.spark.SparkConf") java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") + java_import(gateway.jvm, "org.apache.spark.ml.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") # TODO(davies): move into sql java_import(gateway.jvm, "org.apache.spark.sql.*") diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py index a7a58e17a4..339e5d6af5 100644 --- a/python/pyspark/ml/base.py +++ b/python/pyspark/ml/base.py @@ -19,7 +19,7 @@ from abc import ABCMeta, abstractmethod from pyspark import since from pyspark.ml.param import Params -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc @inherit_doc diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 77badebeb4..121b9262dd 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -26,7 +26,7 @@ from pyspark.ml.regression import ( from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams from pyspark.ml.wrapper import JavaWrapper -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc from pyspark.sql import DataFrame from pyspark.sql.functions import udf, when from pyspark.sql.types import ArrayType, DoubleType diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 92df19e804..75d9a0e8ca 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -19,7 +19,7 @@ from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc __all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'KMeans', 'KMeansModel', diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py new file mode 100644 index 0000000000..256e91e141 --- /dev/null +++ b/python/pyspark/ml/common.py @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version >= '3': + long = int + unicode = str + +import py4j.protocol +from py4j.protocol import Py4JJavaError +from py4j.java_gateway import JavaObject +from py4j.java_collections import ListConverter, JavaArray, JavaList + +from pyspark import RDD, SparkContext +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.sql import DataFrame, SQLContext + +# Hack for support float('inf') in Py4j +_old_smart_decode = py4j.protocol.smart_decode + +_float_str_mapping = { + 'nan': 'NaN', + 'inf': 'Infinity', + '-inf': '-Infinity', +} + + +def _new_smart_decode(obj): + if isinstance(obj, float): + s = str(obj) + return _float_str_mapping.get(s, s) + return _old_smart_decode(obj) + +py4j.protocol.smart_decode = _new_smart_decode + + +_picklable_classes = [ + 'SparseVector', + 'DenseVector', + 'DenseMatrix', +] + + +# this will call the ML version of pythonToJava() +def _to_java_object_rdd(rdd): + """ Return an JavaRDD of Object by unpickling + + It will convert each Python object into Java object by Pyrolite, whenever the + RDD is serialized in batch or not. + """ + rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) + return rdd.ctx._jvm.MLSerDe.pythonToJava(rdd._jrdd, True) + + +def _py2java(sc, obj): + """ Convert Python object into Java """ + if isinstance(obj, RDD): + obj = _to_java_object_rdd(obj) + elif isinstance(obj, DataFrame): + obj = obj._jdf + elif isinstance(obj, SparkContext): + obj = obj._jsc + elif isinstance(obj, list): + obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) + elif isinstance(obj, JavaObject): + pass + elif isinstance(obj, (int, long, float, bool, bytes, unicode)): + pass + else: + data = bytearray(PickleSerializer().dumps(obj)) + obj = sc._jvm.MLSerDe.loads(data) + return obj + + +def _java2py(sc, r, encoding="bytes"): + if isinstance(r, JavaObject): + clsName = r.getClass().getSimpleName() + # convert RDD into JavaRDD + if clsName != 'JavaRDD' and clsName.endswith("RDD"): + r = r.toJavaRDD() + clsName = 'JavaRDD' + + if clsName == 'JavaRDD': + jrdd = sc._jvm.MLSerDe.javaToPython(r) + return RDD(jrdd, sc) + + if clsName == 'Dataset': + return DataFrame(r, SQLContext.getOrCreate(sc)) + + if clsName in _picklable_classes: + r = sc._jvm.MLSerDe.dumps(r) + elif isinstance(r, (JavaArray, JavaList)): + try: + r = sc._jvm.MLSerDe.dumps(r) + except Py4JJavaError: + pass # not pickable + + if isinstance(r, (bytearray, bytes)): + r = PickleSerializer().loads(bytes(r), encoding=encoding) + return r + + +def callJavaFunc(sc, func, *args): + """ Call Java Function """ + args = [_py2java(sc, a) for a in args] + return _java2py(sc, func(*args)) + + +def inherit_doc(cls): + """ + A decorator that makes a class inherit documentation from its parents. + """ + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index cd071f1b7c..1fe8772da7 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -21,7 +21,7 @@ from pyspark import since, keyword_only from pyspark.ml.wrapper import JavaParams from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', 'MulticlassClassificationEvaluator'] diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index ca77ac395d..a28764a752 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -25,7 +25,7 @@ from pyspark.ml.linalg import _convert_to_vector from pyspark.ml.param.shared import * from pyspark.ml.util import JavaMLReadable, JavaMLWritable from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc __all__ = ['Binarizer', 'Bucketizer', diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 0777527134..a48f4bb2ad 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -25,7 +25,7 @@ from pyspark.ml import Estimator, Model, Transformer from pyspark.ml.param import Param, Params from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable from pyspark.ml.wrapper import JavaParams -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc @inherit_doc diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 1778bfe938..0a7096794d 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -19,7 +19,7 @@ from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc __all__ = ['ALS', 'ALSModel'] diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 7c79ab73c7..db31993f0f 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -21,7 +21,7 @@ from pyspark import since, keyword_only from pyspark.ml.param.shared import * from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc from pyspark.sql import DataFrame diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 4358175a57..981ed9dda0 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -61,7 +61,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \ GeneralizedLinearRegression from pyspark.ml.tuning import * from pyspark.ml.wrapper import JavaParams -from pyspark.mllib.common import _java2py +from pyspark.ml.common import _java2py from pyspark.serializers import PickleSerializer from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand @@ -1195,12 +1195,12 @@ class VectorTests(MLlibTestCase): def _test_serialize(self, v): self.assertEqual(v, ser.loads(ser.dumps(v))) - jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec))) + jvec = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v))) + nv = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvec))) self.assertEqual(v, nv) vs = [v] * 100 - jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs))) + jvecs = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(vs))) + nvs = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvecs))) self.assertEqual(vs, nvs) def test_serialize(self): diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index fe87b6cdb9..f857c5e8c8 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -25,7 +25,7 @@ from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand -from pyspark.mllib.common import inherit_doc, _py2java +from pyspark.ml.common import inherit_doc, _py2java __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit', 'TrainValidationSplitModel'] diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 9d28823196..4a31a29809 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -23,7 +23,7 @@ if sys.version > '3': unicode = str from pyspark import SparkContext, since -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc def _jvm(): diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index fef0040faf..25c44b7533 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -22,7 +22,7 @@ from pyspark.sql import DataFrame from pyspark.ml import Estimator, Transformer, Model from pyspark.ml.param import Params from pyspark.ml.util import _jvm -from pyspark.mllib.common import inherit_doc, _java2py, _py2java +from pyspark.ml.common import inherit_doc, _java2py, _py2java class JavaWrapper(object): |