diff options
author | Bryan Cutler <cutlerb@gmail.com> | 2017-01-31 15:42:36 -0800 |
---|---|---|
committer | Holden Karau <holden@us.ibm.com> | 2017-01-31 15:42:36 -0800 |
commit | 57d70d26c88819360cdc806e7124aa2cc1b9e4c5 (patch) | |
tree | 989a46211f9f6e7069dd77a41bf3f805716f863d /python | |
parent | ce112cec4f9bff222aa256893f94c316662a2a7e (diff) | |
download | spark-57d70d26c88819360cdc806e7124aa2cc1b9e4c5.tar.gz spark-57d70d26c88819360cdc806e7124aa2cc1b9e4c5.tar.bz2 spark-57d70d26c88819360cdc806e7124aa2cc1b9e4c5.zip |
[SPARK-17161][PYSPARK][ML] Add PySpark-ML JavaWrapper convenience function to create Py4J JavaArrays
## What changes were proposed in this pull request?
This adds a convenience function to the Python `JavaWrapper` that makes it easy to create a Py4J JavaArray compatible with existing class constructors that take a Scala `Array` as input, removing the need for a separate Java/Python-friendly constructor. The function takes a Java class as input, which Py4J uses to create a Java array of that class. As an example, `OneVsRest` has been updated to use this, and its alternate constructor has been removed.
## How was this patch tested?
Added unit tests for the new convenience function and updated `OneVsRest` doctests which use this to persist the model.
Author: Bryan Cutler <cutlerb@gmail.com>
Closes #14725 from BryanCutler/pyspark-new_java_array-CountVectorizer-SPARK-17161.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/ml/classification.py | 11 | ||||
-rwxr-xr-x | python/pyspark/ml/tests.py | 40 | ||||
-rw-r--r-- | python/pyspark/ml/wrapper.py | 29 |
3 files changed, 77 insertions, 3 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index f10556ca92..d41fc81fd7 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1517,6 +1517,11 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF() >>> model.transform(test2).head().prediction 2.0 + >>> model_path = temp_path + "/ovr_model" + >>> model.save(model_path) + >>> model2 = OneVsRestModel.load(model_path) + >>> model2.transform(test0).head().prediction + 1.0 .. versionadded:: 2.0.0 """ @@ -1759,9 +1764,13 @@ class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable): :return: Java object equivalent to this instance. """ + sc = SparkContext._active_spark_context java_models = [model._to_java() for model in self.models] + java_models_array = JavaWrapper._new_java_array( + java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel) + metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata") _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel", - self.uid, java_models) + self.uid, metadata.empty(), java_models_array) _java_obj.set("classifier", self.getClassifier()._to_java()) _java_obj.set("featuresCol", self.getFeaturesCol()) _java_obj.set("labelCol", self.getLabelCol()) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 68f5bc30ac..53204cde29 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -60,8 +60,8 @@ from pyspark.ml.recommendation import ALS from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \ GeneralizedLinearRegression from pyspark.ml.tuning import * -from pyspark.ml.wrapper import JavaParams -from pyspark.ml.common import _java2py +from pyspark.ml.wrapper import JavaParams, JavaWrapper +from pyspark.ml.common import _java2py, _py2java from pyspark.serializers import 
PickleSerializer from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand @@ -1620,6 +1620,42 @@ class MatrixUDTTests(MLlibTestCase): raise ValueError("Expected a matrix but got type %r" % type(m)) +class WrapperTests(MLlibTestCase): + + def test_new_java_array(self): + # test array of strings + str_list = ["a", "b", "c"] + java_class = self.sc._gateway.jvm.java.lang.String + java_array = JavaWrapper._new_java_array(str_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), str_list) + # test array of integers + int_list = [1, 2, 3] + java_class = self.sc._gateway.jvm.java.lang.Integer + java_array = JavaWrapper._new_java_array(int_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), int_list) + # test array of floats + float_list = [0.1, 0.2, 0.3] + java_class = self.sc._gateway.jvm.java.lang.Double + java_array = JavaWrapper._new_java_array(float_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), float_list) + # test array of bools + bool_list = [False, True, True] + java_class = self.sc._gateway.jvm.java.lang.Boolean + java_array = JavaWrapper._new_java_array(bool_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), bool_list) + # test array of Java DenseVectors + v1 = DenseVector([0.0, 1.0]) + v2 = DenseVector([1.0, 0.0]) + vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)] + java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector + java_array = JavaWrapper._new_java_array(vec_java_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), [v1, v2]) + # test empty array + java_class = self.sc._gateway.jvm.java.lang.Integer + java_array = JavaWrapper._new_java_array([], java_class) + self.assertEqual(_java2py(self.sc, java_array), []) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 13b75e9919..80a0b31cd8 100644 --- 
a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -16,6 +16,9 @@ # from abc import ABCMeta, abstractmethod +import sys +if sys.version >= '3': + xrange = range from pyspark import SparkContext from pyspark.sql import DataFrame @@ -59,6 +62,32 @@ class JavaWrapper(object): java_args = [_py2java(sc, arg) for arg in args] return java_obj(*java_args) + @staticmethod + def _new_java_array(pylist, java_class): + """ + Create a Java array of given java_class type. Useful for + calling a method with a Scala Array from Python with Py4J. + + :param pylist: + Python list to convert to a Java Array. + :param java_class: + Java class to specify the type of Array. Should be in the + form of sc._gateway.jvm.* (sc is a valid Spark Context). + :return: + Java Array of converted pylist. + + Example primitive Java classes: + - basestring -> sc._gateway.jvm.java.lang.String + - int -> sc._gateway.jvm.java.lang.Integer + - float -> sc._gateway.jvm.java.lang.Double + - bool -> sc._gateway.jvm.java.lang.Boolean + """ + sc = SparkContext._active_spark_context + java_array = sc._gateway.new_array(java_class, len(pylist)) + for i in xrange(len(pylist)): + java_array[i] = pylist[i] + return java_array + @inherit_doc class JavaParams(JavaWrapper, Params): |