From db0b06c6ea7412266158b1c710bdc8ca30e26430 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Wed, 6 Apr 2016 11:24:11 -0700
Subject: [SPARK-13786][ML][PYSPARK] Add save/load for pyspark.ml.tuning

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13786

Add save/load for Python CrossValidator/Model and TrainValidationSplit/Model.

## How was this patch tested?

Test with Python doctest.

Author: Xusen Yin

Closes #12020 from yinxusen/SPARK-13786.
---
 .../scala/org/apache/spark/ml/param/params.scala   |  11 +
 .../apache/spark/ml/tuning/CrossValidator.scala    |   9 +
 .../spark/ml/tuning/TrainValidationSplit.scala     |   9 +
 python/pyspark/ml/tests.py                         |  56 ++-
 python/pyspark/ml/tuning.py                        | 407 +++++++++++++++------
 python/pyspark/ml/wrapper.py                       |  23 ++
 6 files changed, 404 insertions(+), 111 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index d7837b6730..c368aadd23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.param
 
 import java.lang.reflect.Modifier
+import java.util.{List => JList}
 import java.util.NoSuchElementException
 
 import scala.annotation.varargs
@@ -833,6 +834,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
     this
   }
 
+  /** Put param pairs with a [[java.util.List]] of values for Python. */
+  private[ml] def put(paramPairs: JList[ParamPair[_]]): this.type = {
+    put(paramPairs.asScala: _*)
+  }
+
   /**
    * Optionally returns the value associated with a param.
    */
@@ -932,6 +938,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
     }
   }
 
+  /** Java-friendly method for Python API */
+  private[ml] def toList: java.util.List[ParamPair[_]] = {
+    this.toSeq.asJava
+  }
+
   /**
    * Number of param pairs in this map.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 040b0093b9..4d9d4d472e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -17,6 +17,10 @@
 package org.apache.spark.ml.tuning
 
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+
 import com.github.fommil.netlib.F2jBLAS
 import org.apache.hadoop.fs.Path
 import org.json4s.DefaultFormats
@@ -200,6 +204,11 @@ class CrossValidatorModel private[ml] (
     @Since("1.5.0") val avgMetrics: Array[Double])
   extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
 
+  /** A Python-friendly auxiliary constructor. */
+  private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = {
+    this(uid, bestModel, avgMetrics.asScala.toArray)
+  }
+
   @Since("1.4.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 07330bb6b0..0f2179c2a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,6 +17,10 @@
 package org.apache.spark.ml.tuning
 
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+
 import org.apache.hadoop.fs.Path
 import org.json4s.DefaultFormats
@@ -198,6 +202,11 @@ class TrainValidationSplitModel private[ml] (
     @Since("1.5.0") val validationMetrics: Array[Double])
   extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable {
 
+  /** A Python-friendly auxiliary constructor. */
+  private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = {
+    this(uid, bestModel, validationMetrics.asScala.toArray)
+  }
+
   @Since("1.5.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index f6159b2c95..e3f873e3a7 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -44,7 +44,7 @@ import numpy as np
 from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
 from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
 from pyspark.ml.clustering import KMeans
-from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
 from pyspark.ml.feature import *
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
@@ -53,7 +53,7 @@ from pyspark.ml.tuning import *
 from pyspark.ml.util import keyword_only
 from pyspark.ml.util import MLWritable, MLWriter
 from pyspark.ml.wrapper import JavaWrapper
-from pyspark.mllib.linalg import DenseVector, SparseVector
+from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
 from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -479,6 +479,32 @@ class CrossValidatorTests(PySparkTestCase):
                          "Best model should have zero induced error")
         self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
 
+    def test_save_load(self):
+        temp_path = tempfile.mkdtemp()
+        sqlContext = SQLContext(self.sc)
+        dataset = sqlContext.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+        cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+        cvModel = cv.fit(dataset)
+        cvPath = temp_path + "/cv"
+        cv.save(cvPath)
+        loadedCV = CrossValidator.load(cvPath)
+        self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+        self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+        self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+        cvModelPath = temp_path + "/cvModel"
+        cvModel.save(cvModelPath)
+        loadedModel = CrossValidatorModel.load(cvModelPath)
+        self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
 
 class TrainValidationSplitTests(PySparkTestCase):
@@ -530,6 +556,32 @@ class TrainValidationSplitTests(PySparkTestCase):
                          "Best model should have zero induced error")
         self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
 
+    def test_save_load(self):
+        temp_path = tempfile.mkdtemp()
+        sqlContext = SQLContext(self.sc)
+        dataset = sqlContext.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+        tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+        tvsModel = tvs.fit(dataset)
+        tvsPath = temp_path + "/tvs"
+        tvs.save(tvsPath)
+        loadedTvs = TrainValidationSplit.load(tvsPath)
+        self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+        self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+        self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+        tvsModelPath = temp_path + "/tvsModel"
+        tvsModel.save(tvsModelPath)
+        loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+        self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
 
 class PersistenceTest(PySparkTestCase):
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index a528d22e18..da00f317b3 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -18,12 +18,15 @@
 import itertools
 import numpy as np
 
+from pyspark import SparkContext
 from pyspark import since
 from pyspark.ml import Estimator, Model
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
+from pyspark.ml.wrapper import JavaWrapper
 from pyspark.sql.functions import rand
+from pyspark.mllib.common import inherit_doc, _py2java
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
            'TrainValidationSplitModel']
@@ -91,7 +94,84 @@ class ParamGridBuilder(object):
         return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
 
 
-class CrossValidator(Estimator, HasSeed):
+class ValidatorParams(HasSeed):
+    """
+    Common params for TrainValidationSplit and CrossValidator.
+    """
+
+    estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
+    estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
+    evaluator = Param(
+        Params._dummy(), "evaluator",
+        "evaluator used to select hyper-parameters that maximize the validator metric")
+
+    def setEstimator(self, value):
+        """
+        Sets the value of :py:attr:`estimator`.
+        """
+        return self._set(estimator=value)
+
+    def getEstimator(self):
+        """
+        Gets the value of estimator or its default value.
+        """
+        return self.getOrDefault(self.estimator)
+
+    def setEstimatorParamMaps(self, value):
+        """
+        Sets the value of :py:attr:`estimatorParamMaps`.
+        """
+        return self._set(estimatorParamMaps=value)
+
+    def getEstimatorParamMaps(self):
+        """
+        Gets the value of estimatorParamMaps or its default value.
+        """
+        return self.getOrDefault(self.estimatorParamMaps)
+
+    def setEvaluator(self, value):
+        """
+        Sets the value of :py:attr:`evaluator`.
+        """
+        return self._set(evaluator=value)
+
+    def getEvaluator(self):
+        """
+        Gets the value of evaluator or its default value.
+        """
+        return self.getOrDefault(self.evaluator)
+
+    @classmethod
+    def _from_java_impl(cls, java_stage):
+        """
+        Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.
+        """
+
+        # Load information from java_stage to the instance.
+        estimator = JavaWrapper._from_java(java_stage.getEstimator())
+        evaluator = JavaWrapper._from_java(java_stage.getEvaluator())
+        epms = [estimator._transfer_param_map_from_java(epm)
+                for epm in java_stage.getEstimatorParamMaps()]
+        return estimator, epms, evaluator
+
+    def _to_java_impl(self):
+        """
+        Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.
+        """
+
+        gateway = SparkContext._gateway
+        cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
+
+        java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
+        for idx, epm in enumerate(self.getEstimatorParamMaps()):
+            java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
+
+        java_estimator = self.getEstimator()._to_java()
+        java_evaluator = self.getEvaluator()._to_java()
+        return java_estimator, java_epms, java_evaluator
+
+
+class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
     """
     K-fold cross validation.
 
@@ -116,11 +196,6 @@ class CrossValidator(Estimator, HasSeed):
     .. versionadded:: 1.4.0
     """
 
-    estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
-    estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
-    evaluator = Param(
-        Params._dummy(), "evaluator",
-        "evaluator used to select hyper-parameters that maximize the cross-validated metric")
     numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation",
                      typeConverter=TypeConverters.toInt)
@@ -148,51 +223,6 @@ class CrossValidator(Estimator, HasSeed):
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 
-    @since("1.4.0")
-    def setEstimator(self, value):
-        """
-        Sets the value of :py:attr:`estimator`.
-        """
-        self._paramMap[self.estimator] = value
-        return self
-
-    @since("1.4.0")
-    def getEstimator(self):
-        """
-        Gets the value of estimator or its default value.
-        """
-        return self.getOrDefault(self.estimator)
-
-    @since("1.4.0")
-    def setEstimatorParamMaps(self, value):
-        """
-        Sets the value of :py:attr:`estimatorParamMaps`.
-        """
-        self._paramMap[self.estimatorParamMaps] = value
-        return self
-
-    @since("1.4.0")
-    def getEstimatorParamMaps(self):
-        """
-        Gets the value of estimatorParamMaps or its default value.
-        """
-        return self.getOrDefault(self.estimatorParamMaps)
-
-    @since("1.4.0")
-    def setEvaluator(self, value):
-        """
-        Sets the value of :py:attr:`evaluator`.
-        """
-        self._paramMap[self.evaluator] = value
-        return self
-
-    @since("1.4.0")
-    def getEvaluator(self):
-        """
-        Gets the value of evaluator or its default value.
-        """
-        return self.getOrDefault(self.evaluator)
-
     @since("1.4.0")
     def setNumFolds(self, value):
         """
@@ -236,7 +266,7 @@ class CrossValidator(Estimator, HasSeed):
         else:
             bestIndex = np.argmin(metrics)
         bestModel = est.fit(dataset, epm[bestIndex])
-        return CrossValidatorModel(bestModel)
+        return self._copyValues(CrossValidatorModel(bestModel))
 
     @since("1.4.0")
     def copy(self, extra=None):
@@ -258,8 +288,58 @@ class CrossValidator(Estimator, HasSeed):
         newCV.setEvaluator(self.getEvaluator().copy(extra))
         return newCV
 
+    @since("2.0.0")
+    def write(self):
+        """Returns an MLWriter instance for this ML instance."""
+        return JavaMLWriter(self)
+
+    @since("2.0.0")
+    def save(self, path):
+        """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+        self.write().save(path)
+
+    @classmethod
+    @since("2.0.0")
+    def read(cls):
+        """Returns an MLReader instance for this class."""
+        return JavaMLReader(cls)
+
+    @classmethod
+    def _from_java(cls, java_stage):
+        """
+        Given a Java CrossValidator, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
+
+        estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
+        numFolds = java_stage.getNumFolds()
+        seed = java_stage.getSeed()
+        # Create a new instance of this stage.
+        py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+                       numFolds=numFolds, seed=seed)
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
 
-class CrossValidatorModel(Model):
+    def _to_java(self):
+        """
+        Transfer this instance to a Java CrossValidator. Used for ML persistence.
+
+        :return: Java object equivalent to this instance.
+        """
+
+        estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
+
+        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+        _java_obj.setEstimatorParamMaps(epms)
+        _java_obj.setEvaluator(evaluator)
+        _java_obj.setEstimator(estimator)
+        _java_obj.setSeed(self.getSeed())
+        _java_obj.setNumFolds(self.getNumFolds())
+
+        return _java_obj
+
+
+class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
     """
     Model from k-fold cross validation.
 
@@ -289,8 +369,60 @@ class CrossValidatorModel(Model):
         extra = dict()
         return CrossValidatorModel(self.bestModel.copy(extra))
 
+    @since("2.0.0")
+    def write(self):
+        """Returns an MLWriter instance for this ML instance."""
+        return JavaMLWriter(self)
+
+    @since("2.0.0")
+    def save(self, path):
+        """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+        self.write().save(path)
 
-class TrainValidationSplit(Estimator, HasSeed):
+    @classmethod
+    @since("2.0.0")
+    def read(cls):
+        """Returns an MLReader instance for this class."""
+        return JavaMLReader(cls)
+
+    @classmethod
+    def _from_java(cls, java_stage):
+        """
+        Given a Java CrossValidatorModel, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
+
+        # Load information from java_stage to the instance.
+        bestModel = JavaWrapper._from_java(java_stage.bestModel())
+        estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
+        # Create a new instance of this stage.
+        py_stage = cls(bestModel=bestModel)\
+            .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator)
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
+
+    def _to_java(self):
+        """
+        Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.
+
+        :return: Java object equivalent to this instance.
+        """
+
+        sc = SparkContext._active_spark_context
+
+        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+                                              self.uid,
+                                              self.bestModel._to_java(),
+                                              _py2java(sc, []))
+        estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
+
+        _java_obj.set("evaluator", evaluator)
+        _java_obj.set("estimator", estimator)
+        _java_obj.set("estimatorParamMaps", epms)
+        return _java_obj
+
+
+class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
     """
     Train-Validation-Split.
 
@@ -315,11 +447,6 @@ class TrainValidationSplit(Estimator, HasSeed):
     .. versionadded:: 2.0.0
     """
 
-    estimator = Param(Params._dummy(), "estimator", "estimator to be tested")
-    estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
-    evaluator = Param(
-        Params._dummy(), "evaluator",
-        "evaluator used to select hyper-parameters that maximize the validated metric")
     trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\
         validation data. Must be between 0 and 1.")
@@ -347,51 +474,6 @@ class TrainValidationSplit(Estimator, HasSeed):
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 
-    @since("2.0.0")
-    def setEstimator(self, value):
-        """
-        Sets the value of :py:attr:`estimator`.
-        """
-        self._paramMap[self.estimator] = value
-        return self
-
-    @since("2.0.0")
-    def getEstimator(self):
-        """
-        Gets the value of estimator or its default value.
-        """
-        return self.getOrDefault(self.estimator)
-
-    @since("2.0.0")
-    def setEstimatorParamMaps(self, value):
-        """
-        Sets the value of :py:attr:`estimatorParamMaps`.
-        """
-        self._paramMap[self.estimatorParamMaps] = value
-        return self
-
-    @since("2.0.0")
-    def getEstimatorParamMaps(self):
-        """
-        Gets the value of estimatorParamMaps or its default value.
-        """
-        return self.getOrDefault(self.estimatorParamMaps)
-
-    @since("2.0.0")
-    def setEvaluator(self, value):
-        """
-        Sets the value of :py:attr:`evaluator`.
-        """
-        self._paramMap[self.evaluator] = value
-        return self
-
-    @since("2.0.0")
-    def getEvaluator(self):
-        """
-        Gets the value of evaluator or its default value.
-        """
-        return self.getOrDefault(self.evaluator)
-
     @since("2.0.0")
     def setTrainRatio(self, value):
         """
@@ -429,7 +511,7 @@ class TrainValidationSplit(Estimator, HasSeed):
         else:
             bestIndex = np.argmin(metrics)
         bestModel = est.fit(dataset, epm[bestIndex])
-        return TrainValidationSplitModel(bestModel)
+        return self._copyValues(TrainValidationSplitModel(bestModel))
 
     @since("2.0.0")
     def copy(self, extra=None):
@@ -451,8 +533,59 @@ class TrainValidationSplit(Estimator, HasSeed):
         newTVS.setEvaluator(self.getEvaluator().copy(extra))
         return newTVS
 
+    @since("2.0.0")
+    def write(self):
+        """Returns an MLWriter instance for this ML instance."""
+        return JavaMLWriter(self)
+
+    @since("2.0.0")
+    def save(self, path):
+        """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+        self.write().save(path)
+
+    @classmethod
+    @since("2.0.0")
+    def read(cls):
+        """Returns an MLReader instance for this class."""
+        return JavaMLReader(cls)
+
+    @classmethod
+    def _from_java(cls, java_stage):
+        """
+        Given a Java TrainValidationSplit, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
+
+        estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
+        trainRatio = java_stage.getTrainRatio()
+        seed = java_stage.getSeed()
+        # Create a new instance of this stage.
+        py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+                       trainRatio=trainRatio, seed=seed)
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
+
+    def _to_java(self):
+        """
+        Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.
 
-class TrainValidationSplitModel(Model):
+        :return: Java object equivalent to this instance.
+        """
+
+        estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
+
+        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+                                              self.uid)
+        _java_obj.setEstimatorParamMaps(epms)
+        _java_obj.setEvaluator(evaluator)
+        _java_obj.setEstimator(estimator)
+        _java_obj.setTrainRatio(self.getTrainRatio())
+        _java_obj.setSeed(self.getSeed())
+
+        return _java_obj
+
+
+class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
     """
     Model from train validation split.
     """
@@ -480,19 +613,75 @@ class TrainValidationSplitModel(Model):
         extra = dict()
         return TrainValidationSplitModel(self.bestModel.copy(extra))
 
+    @since("2.0.0")
+    def write(self):
+        """Returns an MLWriter instance for this ML instance."""
+        return JavaMLWriter(self)
+
+    @since("2.0.0")
+    def save(self, path):
+        """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+        self.write().save(path)
+
+    @classmethod
+    @since("2.0.0")
+    def read(cls):
+        """Returns an MLReader instance for this class."""
+        return JavaMLReader(cls)
+
+    @classmethod
+    def _from_java(cls, java_stage):
+        """
+        Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
+
+        # Load information from java_stage to the instance.
+        bestModel = JavaWrapper._from_java(java_stage.bestModel())
+        estimator, epms, evaluator = \
+            super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
+        # Create a new instance of this stage.
+        py_stage = cls(bestModel=bestModel)\
+            .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator)
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
+
+    def _to_java(self):
+        """
+        Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.
+
+        :return: Java object equivalent to this instance.
+        """
+
+        sc = SparkContext._active_spark_context
+
+        _java_obj = JavaWrapper._new_java_obj(
+            "org.apache.spark.ml.tuning.TrainValidationSplitModel",
+            self.uid,
+            self.bestModel._to_java(),
+            _py2java(sc, []))
+        estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
+
+        _java_obj.set("evaluator", evaluator)
+        _java_obj.set("estimator", estimator)
+        _java_obj.set("estimatorParamMaps", epms)
+        return _java_obj
+
+
 if __name__ == "__main__":
     import doctest
+    from pyspark.context import SparkContext
     from pyspark.sql import SQLContext
     globs = globals().copy()
+
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
     sc = SparkContext("local[2]", "ml.tuning tests")
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    (failure_count, test_count) = doctest.testmod(
-        globs=globs, optionflags=doctest.ELLIPSIS)
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     sc.stop()
     if failure_count:
         exit(-1)
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 35b0eba926..ca93bf7d7d 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -76,6 +76,17 @@ class JavaWrapper(Params):
                 pair = self._make_java_param_pair(param, paramMap[param])
                 self._java_obj.set(pair)
 
+    def _transfer_param_map_to_java(self, pyParamMap):
+        """
+        Transforms a Python ParamMap into a Java ParamMap.
+        """
+        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+        for param in self.params:
+            if param in pyParamMap:
+                pair = self._make_java_param_pair(param, pyParamMap[param])
+                paramMap.put([pair])
+        return paramMap
+
     def _transfer_params_from_java(self):
         """
         Transforms the embedded params from the companion Java object.
@@ -88,6 +99,18 @@ class JavaWrapper(Params):
                 value = _java2py(sc, self._java_obj.getOrDefault(java_param))
                 self._paramMap[param] = value
 
+    def _transfer_param_map_from_java(self, javaParamMap):
+        """
+        Transforms a Java ParamMap into a Python ParamMap.
+        """
+        sc = SparkContext._active_spark_context
+        paramMap = dict()
+        for pair in javaParamMap.toList():
+            param = pair.param()
+            if self.hasParam(str(param.name())):
+                paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
+        return paramMap
+
     @staticmethod
     def _empty_java_param_map():
         """
-- 
cgit v1.2.3