aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala9
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala9
3 files changed, 29 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index d7837b6730..c368aadd23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.param
import java.lang.reflect.Modifier
+import java.util.{List => JList}
import java.util.NoSuchElementException
import scala.annotation.varargs
@@ -833,6 +834,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
this
}
+ /** Put param pairs with a [[java.util.List]] of values for Python. */
+ private[ml] def put(paramPairs: JList[ParamPair[_]]): this.type = {
+ put(paramPairs.asScala: _*)
+ }
+
/**
* Optionally returns the value associated with a param.
*/
@@ -932,6 +938,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
}
}
+ /** Java-friendly method for Python API */
+ private[ml] def toList: java.util.List[ParamPair[_]] = {
+ this.toSeq.asJava
+ }
+
/**
* Number of param pairs in this map.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 040b0093b9..4d9d4d472e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -17,6 +17,10 @@
package org.apache.spark.ml.tuning
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+
import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
@@ -200,6 +204,11 @@ class CrossValidatorModel private[ml] (
@Since("1.5.0") val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
+ /** A Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = {
+ this(uid, bestModel, avgMetrics.asScala.toArray)
+ }
+
@Since("1.4.0")
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 07330bb6b0..0f2179c2a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,6 +17,10 @@
package org.apache.spark.ml.tuning
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
@@ -198,6 +202,11 @@ class TrainValidationSplitModel private[ml] (
@Since("1.5.0") val validationMetrics: Array[Double])
extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable {
+ /** A Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = {
+ this(uid, bestModel, validationMetrics.asScala.toArray)
+ }
+
@Since("1.5.0")
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)