diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala | 5 | ||||
-rw-r--r-- | project/MimaExcludes.scala | 5 | ||||
-rw-r--r-- | python/pyspark/ml/classification.py | 11 | ||||
-rwxr-xr-x | python/pyspark/ml/tests.py | 40 | ||||
-rw-r--r-- | python/pyspark/ml/wrapper.py | 29 |
5 files changed, 81 insertions, 9 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index cbd508ae79..7cbcccf272 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -135,11 +135,6 @@ final class OneVsRestModel private[ml] ( @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { - /** A Python-friendly auxiliary constructor. */ - private[ml] def this(uid: String, models: JList[_ <: ClassificationModel[_, _]]) = { - this(uid, Metadata.empty, models.asScala.toArray) - } - /** @group setParam */ @Since("2.1.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7e6e143523..9d359427f2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -54,7 +54,10 @@ object MimaExcludes { // [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$10"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11"), + + // [SPARK-17161] Removing Python-friendly constructors not needed + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this") ) // Exclude rules for 2.1.x diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index f10556ca92..d41fc81fd7 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1517,6 +1517,11 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF() >>> model.transform(test2).head().prediction 2.0 + >>> model_path = temp_path + "/ovr_model" + >>> model.save(model_path) + >>> model2 = OneVsRestModel.load(model_path) + >>> model2.transform(test0).head().prediction + 1.0 .. versionadded:: 2.0.0 """ @@ -1759,9 +1764,13 @@ class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable): :return: Java object equivalent to this instance. """ + sc = SparkContext._active_spark_context java_models = [model._to_java() for model in self.models] + java_models_array = JavaWrapper._new_java_array( + java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel) + metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata") _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel", - self.uid, java_models) + self.uid, metadata.empty(), java_models_array) _java_obj.set("classifier", self.getClassifier()._to_java()) _java_obj.set("featuresCol", self.getFeaturesCol()) _java_obj.set("labelCol", self.getLabelCol()) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 68f5bc30ac..53204cde29 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -60,8 +60,8 @@ from pyspark.ml.recommendation import ALS from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \ GeneralizedLinearRegression from pyspark.ml.tuning import * -from pyspark.ml.wrapper import JavaParams -from pyspark.ml.common import _java2py +from pyspark.ml.wrapper import JavaParams, JavaWrapper +from pyspark.ml.common import _java2py, _py2java from pyspark.serializers import PickleSerializer from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand @@ -1620,6 +1620,42 @@ class MatrixUDTTests(MLlibTestCase): raise ValueError("Expected a matrix but got type %r" % type(m)) +class WrapperTests(MLlibTestCase): + + def test_new_java_array(self): + # test array of strings + str_list = ["a", "b", "c"] + java_class = self.sc._gateway.jvm.java.lang.String + java_array = JavaWrapper._new_java_array(str_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), str_list) + # test array of integers + int_list = [1, 2, 3] + java_class = self.sc._gateway.jvm.java.lang.Integer + java_array = JavaWrapper._new_java_array(int_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), int_list) + # test array of floats + float_list = [0.1, 0.2, 0.3] + java_class = self.sc._gateway.jvm.java.lang.Double + java_array = JavaWrapper._new_java_array(float_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), float_list) + # test array of bools + bool_list = [False, True, True] + java_class = self.sc._gateway.jvm.java.lang.Boolean + java_array = JavaWrapper._new_java_array(bool_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), bool_list) + # test array of Java DenseVectors + v1 = DenseVector([0.0, 1.0]) + v2 = DenseVector([1.0, 0.0]) + vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)] + java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector + java_array = JavaWrapper._new_java_array(vec_java_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), [v1, v2]) + # test empty array + java_class = self.sc._gateway.jvm.java.lang.Integer + java_array = JavaWrapper._new_java_array([], java_class) + self.assertEqual(_java2py(self.sc, java_array), []) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 13b75e9919..80a0b31cd8 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -16,6 +16,9 @@ # from abc import ABCMeta, abstractmethod +import sys +if sys.version >= '3': + xrange = range from pyspark import SparkContext from pyspark.sql import DataFrame @@ -59,6 +62,32 @@ class JavaWrapper(object): java_args = [_py2java(sc, arg) for arg in args] return java_obj(*java_args) + @staticmethod + def _new_java_array(pylist, java_class): + """ + Create a Java array of given java_class type. Useful for + calling a method with a Scala Array from Python with Py4J. + + :param pylist: + Python list to convert to a Java Array. + :param java_class: + Java class to specify the type of Array. Should be in the + form of sc._gateway.jvm.* (sc is a valid Spark Context). + :return: + Java Array of converted pylist. + + Example primitive Java classes: + - basestring -> sc._gateway.jvm.java.lang.String + - int -> sc._gateway.jvm.java.lang.Integer + - float -> sc._gateway.jvm.java.lang.Double + - bool -> sc._gateway.jvm.java.lang.Boolean + """ + sc = SparkContext._active_spark_context + java_array = sc._gateway.new_array(java_class, len(pylist)) + for i in xrange(len(pylist)): + java_array[i] = pylist[i] + return java_array + @inherit_doc class JavaParams(JavaWrapper, Params): |