aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-04-18 11:52:29 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-18 11:52:29 -0700
commitb64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a (patch)
tree18131b3a63a970be653d9350785dc0ab0bcbbfff /mllib
parent775cf17eaaae1a38efe47b282b1d6bbdb99bd759 (diff)
downloadspark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.gz
spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.bz2
spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.zip
[SPARK-14306][ML][PYSPARK] PySpark ml.classification OneVsRest support export/import
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14306 Add PySpark OneVsRest save/load supports. ## How was this patch tested? Test with Python unit test. Author: Xusen Yin <yinxusen@gmail.com> Closes #12439 from yinxusen/SPARK-14306-0415.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala7
1 files changed, 7 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 4de1b877b0..f10c60a78d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -17,8 +17,10 @@
package org.apache.spark.ml.classification
+import java.util.{List => JList}
import java.util.UUID
+import scala.collection.JavaConverters._
import scala.language.existentials
import org.apache.hadoop.fs.Path
@@ -135,6 +137,11 @@ final class OneVsRestModel private[ml] (
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
+ /** A Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, models: JList[_ <: ClassificationModel[_, _]]) = {
+ this(uid, Metadata.empty, models.asScala.toArray)
+ }
+
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)