aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2015-03-20 14:53:59 -0400
committerXiangrui Meng <meng@databricks.com>2015-03-20 14:53:59 -0400
commit25636d9867c6bc901463b6b227cb444d701cfdd1 (patch)
tree0decdcff3c8d20c399d792cfde375f9c737acd8d /mllib/src
parent5e6ad24ff645a9b0f63d9c0f17193550963aa0a7 (diff)
downloadspark-25636d9867c6bc901463b6b227cb444d701cfdd1.tar.gz
spark-25636d9867c6bc901463b6b227cb444d701cfdd1.tar.bz2
spark-25636d9867c6bc901463b6b227cb444d701cfdd1.zip
[Spark 6096][MLlib] Add Naive Bayes load save methods in Python
See [SPARK-6096](https://issues.apache.org/jira/browse/SPARK-6096). Author: Xusen Yin <yinxusen@gmail.com> Closes #5090 from yinxusen/SPARK-6096 and squashes the following commits: bd0fea5 [Xusen Yin] fix style problem, etc. 3fd41f2 [Xusen Yin] use hanging indent in Python style e83803d [Xusen Yin] fix Python style d6dbde5 [Xusen Yin] fix python call java error a054bb3 [Xusen Yin] add save load for NaiveBayes python
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala11
1 files changed, 11 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 068449aa1d..d60e82c410 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -17,6 +17,10 @@
package org.apache.spark.mllib.classification
+import java.lang.{Iterable => JIterable}
+
+import scala.collection.JavaConverters._
+
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
@@ -41,6 +45,13 @@ class NaiveBayesModel private[mllib] (
val pi: Array[Double],
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {
+ /** A Java-friendly constructor that takes three Iterable parameters. */
+ private[mllib] def this(
+ labels: JIterable[Double],
+ pi: JIterable[Double],
+ theta: JIterable[JIterable[Double]]) =
+ this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
+
private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM[Double](theta.length, theta(0).length)