aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: wm624@hotmail.com <wm624@hotmail.com> 2016-05-14 09:45:56 +0100
committer: Sean Owen <sowen@cloudera.com> 2016-05-14 09:45:56 +0100
commit: 354f8f11bd4b20fa99bd67a98da3525fd3d75c81 (patch)
tree: 91811f926f3b46d39efe908f1468e4c4bc74bc28 /mllib
parent: 0f1f31d3a6669fbac474518cf2a871485e202bdc (diff)
downloadspark-354f8f11bd4b20fa99bd67a98da3525fd3d75c81.tar.gz
spark-354f8f11bd4b20fa99bd67a98da3525fd3d75c81.tar.bz2
spark-354f8f11bd4b20fa99bd67a98da3525fd3d75c81.zip
[SPARK-15096][ML] LogisticRegression MultiClassSummarizer numClasses can fail if no valid labels are found
## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) Make `MultiClassSummarizer.numClasses` return 0 when no valid labels have been seen (i.e. `distinctMap` is empty), instead of failing with the exception that calling `max` on an empty key set throws. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) Added a new unit test that calls `histogram` and `numClasses` on an empty `MultiClassSummarizer` and checks that the histogram is empty and `numClasses` is 0. Author: wm624@hotmail.com <wm624@hotmail.com> Closes #12969 from wangmiao1981/logisticR.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 4
2 files changed, 5 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 3e8040d3e9..ffd03e55b5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -745,7 +745,7 @@ private[classification] class MultiClassSummarizer extends Serializable {
def countInvalid: Long = totalInvalidCnt
/** @return The number of distinct labels in the input dataset. */
- def numClasses: Int = distinctMap.keySet.max + 1
+ def numClasses: Int = if (distinctMap.isEmpty) 0 else distinctMap.keySet.max + 1
/** @return The weightSum of each label in the input dataset. */
def histogram: Array[Double] = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index f127aa217c..69650ebb36 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -256,6 +256,10 @@ class LogisticRegressionSuite
assert(summarizer4.countInvalid === 2)
assert(summarizer4.numClasses === 4)
+ val summarizer5 = new MultiClassSummarizer
+ assert(summarizer5.histogram.isEmpty)
+ assert(summarizer5.numClasses === 0)
+
// small map merges large one
val summarizerA = summarizer1.merge(summarizer2)
assert(summarizerA.hashCode() === summarizer2.hashCode())