diff options
Diffstat (limited to 'mllib/src/main')
3 files changed, 29 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d7837b6730..c368aadd23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.param import java.lang.reflect.Modifier +import java.util.{List => JList} import java.util.NoSuchElementException import scala.annotation.varargs @@ -833,6 +834,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) this } + /** Put param pairs with a [[java.util.List]] of values for Python. */ + private[ml] def put(paramPairs: JList[ParamPair[_]]): this.type = { + put(paramPairs.asScala: _*) + } + /** * Optionally returns the value associated with a param. */ @@ -932,6 +938,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) } } + /** Java-friendly method for Python API */ + private[ml] def toList: java.util.List[ParamPair[_]] = { + this.toSeq.asJava + } + /** * Number of param pairs in this map. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 040b0093b9..4d9d4d472e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,6 +17,10 @@ package org.apache.spark.ml.tuning +import java.util.{List => JList} + +import scala.collection.JavaConverters._ + import com.github.fommil.netlib.F2jBLAS import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -200,6 +204,11 @@ class CrossValidatorModel private[ml] ( @Since("1.5.0") val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { + /** A Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = { + this(uid, bestModel, avgMetrics.asScala.toArray) + } + @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 07330bb6b0..0f2179c2a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -17,6 +17,10 @@ package org.apache.spark.ml.tuning +import java.util.{List => JList} + +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -198,6 +202,11 @@ class TrainValidationSplitModel private[ml] ( @Since("1.5.0") val validationMetrics: Array[Double]) extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable { + /** A Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = { + this(uid, bestModel, validationMetrics.asScala.toArray) + } + @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) |