aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala8
1 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 6775745167..e565a6fd3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -176,8 +176,8 @@ class NaiveBayes @Since("1.5.0") (
val numLabels = aggregated.length
val numDocuments = aggregated.map(_._2._1).sum
- val piArray = Array.fill[Double](numLabels)(0.0)
- val thetaArrays = Array.fill[Double](numLabels, numFeatures)(0.0)
+ val piArray = new Array[Double](numLabels)
+ val thetaArray = new Array[Double](numLabels * numFeatures)
val lambda = $(smoothing)
val piLogDenom = math.log(numDocuments + numLabels * lambda)
@@ -193,14 +193,14 @@ class NaiveBayes @Since("1.5.0") (
}
var j = 0
while (j < numFeatures) {
- thetaArrays(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
+ thetaArray(i * numFeatures + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
j += 1
}
i += 1
}
val pi = Vectors.dense(piArray)
- val theta = new DenseMatrix(numLabels, thetaArrays(0).length, thetaArrays.flatten, true)
+ val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
new NaiveBayesModel(uid, pi, theta)
}