diff options
author | Xusen Yin <yinxusen@gmail.com> | 2016-04-18 11:52:29 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-18 11:52:29 -0700 |
commit | b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a (patch) | |
tree | 18131b3a63a970be653d9350785dc0ab0bcbbfff /mllib | |
parent | 775cf17eaaae1a38efe47b282b1d6bbdb99bd759 (diff) | |
download | spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.gz spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.bz2 spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.zip |
[SPARK-14306][ML][PYSPARK] PySpark ml.classification OneVsRest support export/import
## What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-14306
Add PySpark OneVsRest save/load supports.
## How was this patch tested?
Test with Python unit test.
Author: Xusen Yin <yinxusen@gmail.com>
Closes #12439 from yinxusen/SPARK-14306-0415.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 4de1b877b0..f10c60a78d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -17,8 +17,10 @@ package org.apache.spark.ml.classification +import java.util.{List => JList} import java.util.UUID +import scala.collection.JavaConverters._ import scala.language.existentials import org.apache.hadoop.fs.Path @@ -135,6 +137,11 @@ final class OneVsRestModel private[ml] ( @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { + /** A Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, models: JList[_ <: ClassificationModel[_, _]]) = { + this(uid, Metadata.empty, models.asScala.toArray) + } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) |