aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-04-06 11:24:11 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-06 11:24:11 -0700
commitdb0b06c6ea7412266158b1c710bdc8ca30e26430 (patch)
tree58c218ecdbe61927b7f9c3addf11b0bf245ffb2a /mllib
parent3c8d8821654e3d82ef927c55272348e1bcc34a79 (diff)
downloadspark-db0b06c6ea7412266158b1c710bdc8ca30e26430.tar.gz
spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.tar.bz2
spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.zip
[SPARK-13786][ML][PYSPARK] Add save/load for pyspark.ml.tuning
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-13786 Add save/load for Python CrossValidator/Model and TrainValidationSplit/Model. ## How was this patch tested? Test with Python doctest. Author: Xusen Yin <yinxusen@gmail.com> Closes #12020 from yinxusen/SPARK-13786.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala9
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala9
3 files changed, 29 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index d7837b6730..c368aadd23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.param
import java.lang.reflect.Modifier
+import java.util.{List => JList}
import java.util.NoSuchElementException
import scala.annotation.varargs
@@ -833,6 +834,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
this
}
+ /** Put param pairs with a [[java.util.List]] of values for Python. */
+ private[ml] def put(paramPairs: JList[ParamPair[_]]): this.type = {
+ put(paramPairs.asScala: _*)
+ }
+
/**
* Optionally returns the value associated with a param.
*/
@@ -932,6 +938,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
}
}
+ /** Java-friendly method for Python API */
+ private[ml] def toList: java.util.List[ParamPair[_]] = {
+ this.toSeq.asJava
+ }
+
/**
* Number of param pairs in this map.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 040b0093b9..4d9d4d472e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -17,6 +17,10 @@
package org.apache.spark.ml.tuning
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+
import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
@@ -200,6 +204,11 @@ class CrossValidatorModel private[ml] (
@Since("1.5.0") val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
+ /** A Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = {
+ this(uid, bestModel, avgMetrics.asScala.toArray)
+ }
+
@Since("1.4.0")
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 07330bb6b0..0f2179c2a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,6 +17,10 @@
package org.apache.spark.ml.tuning
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
@@ -198,6 +202,11 @@ class TrainValidationSplitModel private[ml] (
@Since("1.5.0") val validationMetrics: Array[Double])
extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable {
+ /** A Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = {
+ this(uid, bestModel, validationMetrics.asScala.toArray)
+ }
+
@Since("1.5.0")
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)