diff options
author | Joseph K. Bradley <joseph.kurata.bradley@gmail.com> | 2015-03-13 10:26:09 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-13 10:26:09 -0700 |
commit | dc4abd4dc40deacab39bfa9572b06bf0ea6daa6d (patch) | |
tree | 3c342af22d24d2dc4f87d7b9f9fe5b2377702230 | |
parent | ea3d2eed9b0a94b34543d9a9df87dc63a279deb1 (diff) | |
download | spark-dc4abd4dc40deacab39bfa9572b06bf0ea6daa6d.tar.gz spark-dc4abd4dc40deacab39bfa9572b06bf0ea6daa6d.tar.bz2 spark-dc4abd4dc40deacab39bfa9572b06bf0ea6daa6d.zip |
[SPARK-6252] [mllib] Added getLambda to Scala NaiveBayes
Note: not relevant for Python API since it only has a static train method
Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #4969 from jkbradley/SPARK-6252 and squashes the following commits:
a471d90 [Joseph K. Bradley] small edits from review
63eff48 [Joseph K. Bradley] Added getLambda to Scala NaiveBayes
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 3 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala | 8 |
2 files changed, 11 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index b11fd4f128..2ebc7fa5d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -166,6 +166,9 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with this } + /** Get the smoothing parameter. Default: 1.0. */ + def getLambda: Double = lambda + /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. * diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 64dcc0fb9f..5a27c7d230 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -85,6 +85,14 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(numOfPredictions < input.length / 5) } + test("get, set params") { + val nb = new NaiveBayes() + nb.setLambda(2.0) + assert(nb.getLambda === 2.0) + nb.setLambda(3.0) + assert(nb.getLambda === 3.0) + } + test("Naive Bayes") { val nPoints = 10000 |