author | Yanbo Liang <ybliang8@gmail.com> | 2016-10-05 23:03:09 -0700
---|---|---
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-10-05 23:03:09 -0700
commit | 7aeb20be7e999523784aca7be1a7c9c99dec125e |
tree | 90015f378a4281a504eb36e12e90cd701f70bcb5 /mllib/src/main/scala |
parent | b678e465afa417780b54db0fbbaa311621311f15 |
[MINOR][ML] Avoid 2D array flatten in NB training.
## What changes were proposed in this pull request?
Avoid the 2D array `flatten` in `NaiveBayes` training, since `flatten` can be expensive: it allocates a new array and copies all the data into it.
## How was this patch tested?
Existing tests.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #15359 from yanboliang/nb-theta.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala | 8
1 file changed, 4 insertions(+), 4 deletions(-)
```diff
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 6775745167..e565a6fd3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -176,8 +176,8 @@ class NaiveBayes @Since("1.5.0") (
     val numLabels = aggregated.length
     val numDocuments = aggregated.map(_._2._1).sum

-    val piArray = Array.fill[Double](numLabels)(0.0)
-    val thetaArrays = Array.fill[Double](numLabels, numFeatures)(0.0)
+    val piArray = new Array[Double](numLabels)
+    val thetaArray = new Array[Double](numLabels * numFeatures)

     val lambda = $(smoothing)
     val piLogDenom = math.log(numDocuments + numLabels * lambda)
@@ -193,14 +193,14 @@ class NaiveBayes @Since("1.5.0") (
       }
       var j = 0
       while (j < numFeatures) {
-        thetaArrays(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
+        thetaArray(i * numFeatures + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
         j += 1
       }
       i += 1
     }

     val pi = Vectors.dense(piArray)
-    val theta = new DenseMatrix(numLabels, thetaArrays(0).length, thetaArrays.flatten, true)
+    val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
     new NaiveBayesModel(uid, pi, theta)
   }
```
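The key idea of the patch is that writing `theta(i)(j)` into a 2D array and later calling `flatten` produces the same row-major layout as writing directly into a flat array at index `i * numFeatures + j`, minus one extra allocation and copy. A minimal sketch of that equivalence (the sizes and values here are illustrative, not taken from the patch):

```scala
object FlatIndexDemo {
  def main(args: Array[String]): Unit = {
    // Hypothetical dimensions for illustration only.
    val numLabels = 3
    val numFeatures = 4

    // Old approach: 2D array, flattened at the end (extra array + copy).
    val theta2d = Array.fill[Double](numLabels, numFeatures)(0.0)
    // New approach: one flat array, filled in row-major order.
    val thetaFlat = new Array[Double](numLabels * numFeatures)

    for (i <- 0 until numLabels; j <- 0 until numFeatures) {
      val v = (i * 10 + j).toDouble  // arbitrary placeholder value
      theta2d(i)(j) = v
      thetaFlat(i * numFeatures + j) = v
    }

    // flatten yields the same row-major sequence, via a fresh allocation.
    assert(theta2d.flatten.sameElements(thetaFlat))
    println("flat indexing matches flatten")
  }
}
```

This is also why the `DenseMatrix` constructor call can pass `thetaArray` directly with `isTransposed = true`: the flat array is already in the row-major order that `flatten` would have produced.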