aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com> 2016-10-05 23:03:09 -0700
committer: Yanbo Liang <ybliang8@gmail.com> 2016-10-05 23:03:09 -0700
commit: 7aeb20be7e999523784aca7be1a7c9c99dec125e (patch)
tree: 90015f378a4281a504eb36e12e90cd701f70bcb5
parent: b678e465afa417780b54db0fbbaa311621311f15 (diff)
download: spark-7aeb20be7e999523784aca7be1a7c9c99dec125e.tar.gz
spark-7aeb20be7e999523784aca7be1a7c9c99dec125e.tar.bz2
spark-7aeb20be7e999523784aca7be1a7c9c99dec125e.zip
[MINOR][ML] Avoid 2D array flatten in NB training.
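## What changes were proposed in this pull request? Avoid flattening a 2D array in ```NaiveBayes``` training, since the `flatten` method can be expensive (it creates another array and copies the data into it). Instead, the theta values are written directly into a single flat array of length `numLabels * numFeatures`, indexed as `i * numFeatures + j`, which is passed to `DenseMatrix` as-is. ## How was this patch tested? Existing tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #15359 from yanboliang/nb-theta.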
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala 8
1 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 6775745167..e565a6fd3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -176,8 +176,8 @@ class NaiveBayes @Since("1.5.0") (
val numLabels = aggregated.length
val numDocuments = aggregated.map(_._2._1).sum
- val piArray = Array.fill[Double](numLabels)(0.0)
- val thetaArrays = Array.fill[Double](numLabels, numFeatures)(0.0)
+ val piArray = new Array[Double](numLabels)
+ val thetaArray = new Array[Double](numLabels * numFeatures)
val lambda = $(smoothing)
val piLogDenom = math.log(numDocuments + numLabels * lambda)
@@ -193,14 +193,14 @@ class NaiveBayes @Since("1.5.0") (
}
var j = 0
while (j < numFeatures) {
- thetaArrays(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
+ thetaArray(i * numFeatures + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
j += 1
}
i += 1
}
val pi = Vectors.dense(piArray)
- val theta = new DenseMatrix(numLabels, thetaArrays(0).length, thetaArrays.flatten, true)
+ val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
new NaiveBayesModel(uid, pi, theta)
}