From db0b06c6ea7412266158b1c710bdc8ca30e26430 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 6 Apr 2016 11:24:11 -0700 Subject: [SPARK-13786][ML][PYSPARK] Add save/load for pyspark.ml.tuning ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-13786 Add save/load for Python CrossValidator/Model and TrainValidationSplit/Model. ## How was this patch tested? Test with Python doctest. Author: Xusen Yin Closes #12020 from yinxusen/SPARK-13786. --- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 11 +++++++++++ .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 9 +++++++++ .../org/apache/spark/ml/tuning/TrainValidationSplit.scala | 9 +++++++++ 3 files changed, 29 insertions(+) (limited to 'mllib/src/main') diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d7837b6730..c368aadd23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.param import java.lang.reflect.Modifier +import java.util.{List => JList} import java.util.NoSuchElementException import scala.annotation.varargs @@ -833,6 +834,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) this } + /** Put param pairs with a [[java.util.List]] of values for Python. */ + private[ml] def put(paramPairs: JList[ParamPair[_]]): this.type = { + put(paramPairs.asScala: _*) + } + /** * Optionally returns the value associated with a param. */ @@ -932,6 +938,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) } } + /** Java-friendly method for Python API */ + private[ml] def toList: java.util.List[ParamPair[_]] = { + this.toSeq.asJava + } + /** * Number of param pairs in this map. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 040b0093b9..4d9d4d472e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,6 +17,10 @@ package org.apache.spark.ml.tuning +import java.util.{List => JList} + +import scala.collection.JavaConverters._ + import com.github.fommil.netlib.F2jBLAS import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -200,6 +204,11 @@ class CrossValidatorModel private[ml] ( @Since("1.5.0") val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { + /** A Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = { + this(uid, bestModel, avgMetrics.asScala.toArray) + } + @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 07330bb6b0..0f2179c2a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -17,6 +17,10 @@ package org.apache.spark.ml.tuning +import java.util.{List => JList} + +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -198,6 +202,11 @@ class TrainValidationSplitModel private[ml] ( @Since("1.5.0") val validationMetrics: Array[Double]) extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable { + /** A Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = { + this(uid, bestModel, validationMetrics.asScala.toArray) + } + @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) -- cgit v1.2.3