diff options
author | Xusen Yin <yinxusen@gmail.com> | 2015-03-20 14:53:59 -0400 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-20 14:53:59 -0400 |
commit | 25636d9867c6bc901463b6b227cb444d701cfdd1 (patch) | |
tree | 0decdcff3c8d20c399d792cfde375f9c737acd8d /mllib/src | |
parent | 5e6ad24ff645a9b0f63d9c0f17193550963aa0a7 (diff) | |
download | spark-25636d9867c6bc901463b6b227cb444d701cfdd1.tar.gz spark-25636d9867c6bc901463b6b227cb444d701cfdd1.tar.bz2 spark-25636d9867c6bc901463b6b227cb444d701cfdd1.zip |
[Spark 6096][MLlib] Add Naive Bayes load save methods in Python
See [SPARK-6096](https://issues.apache.org/jira/browse/SPARK-6096).
Author: Xusen Yin <yinxusen@gmail.com>
Closes #5090 from yinxusen/SPARK-6096 and squashes the following commits:
bd0fea5 [Xusen Yin] fix style problem, etc.
3fd41f2 [Xusen Yin] use hanging indent in Python style
e83803d [Xusen Yin] fix Python style
d6dbde5 [Xusen Yin] fix python call java error
a054bb3 [Xusen Yin] add save load for NaiveBayes python
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 068449aa1d..d60e82c410 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -17,6 +17,10 @@ package org.apache.spark.mllib.classification +import java.lang.{Iterable => JIterable} + +import scala.collection.JavaConverters._ + import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -41,6 +45,13 @@ class NaiveBayesModel private[mllib] ( val pi: Array[Double], val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable { + /** A Java-friendly constructor that takes three Iterable parameters. */ + private[mllib] def this( + labels: JIterable[Double], + pi: JIterable[Double], + theta: JIterable[JIterable[Double]]) = + this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) + private val brzPi = new BDV[Double](pi) private val brzTheta = new BDM[Double](theta.length, theta(0).length) |