about summary refs log tree commit diff
diff options
context:
space:
mode:
author Lian, Cheng <rhythm.mail@gmail.com> 2013-12-29 17:13:01 +0800
committer Lian, Cheng <rhythm.mail@gmail.com> 2013-12-29 17:13:01 +0800
commit f150b6e76c56ed6f604e6dbda7bce6b6278929fb (patch)
tree b52cba135c173d57203b60e7b5258e839533dcc7
parent d7086dc28a856ec8856278be108310ec8264a115 (diff)
download spark-f150b6e76c56ed6f604e6dbda7bce6b6278929fb.tar.gz
spark-f150b6e76c56ed6f604e6dbda7bce6b6278929fb.tar.bz2
spark-f150b6e76c56ed6f604e6dbda7bce6b6278929fb.zip
Response to Reynold's comments
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala 26
1 files changed, 16 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 2bc4c5afc0..d0f3a368e8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -17,20 +17,22 @@
package org.apache.spark.mllib.classification
+import org.jblas.DoubleMatrix
+
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
-import org.jblas.DoubleMatrix
/**
* Model for Naive Bayes Classifiers.
*
- * @param weightPerLabel Weights computed for every label, which's dimension is C.
- * @param weightMatrix Weights computed for every label and feature, which's dimension is CXD
+ * @param weightPerLabel Weights computed for every label, whose dimension is C.
+ * @param weightMatrix Weights computed for every label and feature, whose dimension is CXD
*/
-class NaiveBayesModel(val weightPerLabel: Array[Double],
- val weightMatrix: Array[Array[Double]])
+class NaiveBayesModel(
+ @transient val weightPerLabel: Array[Double],
+ @transient val weightMatrix: Array[Array[Double]])
extends ClassificationModel with Serializable {
// Create a column vector that can be used for predictions
@@ -50,7 +52,12 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter
extends Serializable with Logging {
private def vectorAdd(v1: Array[Double], v2: Array[Double]) = {
- v1.zip(v2).map(pair => pair._1 + pair._2)
+ var i = 0
+ while (i < v1.length) {
+ v1(i) += v2(i)
+ i += 1
+ }
+ v1
}
/**
@@ -79,8 +86,8 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter
// considerably expensive.
val N = collected.values.map(_._1).sum
val logDenom = math.log(N + C * lambda)
- val weightPerLabel = Array.fill[Double](C)(0)
- val weightMatrix = Array.fill[Array[Double]](C)(null)
+ val weightPerLabel = new Array[Double](C)
+ val weightMatrix = new Array[Array[Double]](C)
for ((label, (_, labelWeight, weights)) <- collected) {
weightPerLabel(label) = labelWeight - logDenom
@@ -100,8 +107,7 @@ object NaiveBayes {
* @param input RDD of (label, array of features) pairs.
* @param lambda smooth parameter
*/
- def train(C: Int, D: Int, input: RDD[LabeledPoint],
- lambda: Double = 1.0): NaiveBayesModel = {
+ def train(C: Int, D: Int, input: RDD[LabeledPoint], lambda: Double = 1.0): NaiveBayesModel = {
new NaiveBayes(lambda).run(C, D, input)
}
}