author | Alain <aihe@usc.edu> | 2015-05-05 16:47:34 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2015-05-05 16:47:34 +0100 |
commit | d4cb38aeb7412a353c6cbca2a9b8f9729afbaba7 | |
tree | 57def81c93b434a5ef1ad42c44192074435aa52a /mllib | |
parent | 9d250e64dac263bcbbad6b023382ac7b5b592408 | |
[MLLIB] [TREE] Verify size of input rdd > 0 when building meta data
Require a non-empty input RDD so that we can take the first LabeledPoint and read its feature size.
Author: Alain <aihe@usc.edu>
Author: aihe@usc.edu <aihe@usc.edu>
Closes #5810 from AiHe/decisiontree-issue and squashes the following commits:
3b1d08a [aihe@usc.edu] [MLLIB][tree] merge the assertion into the evaluation of numFeatures
cf2e567 [Alain] [MLLIB][tree] Use a rdd api to verify size of input rdd > 0 when building meta data
b448f47 [Alain] [MLLIB][tree] Verify size of input rdd > 0 when building meta data
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index f1a6ed2301..f73896e37c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -107,7 +107,10 @@ private[tree] object DecisionTreeMetadata extends Logging {
       numTrees: Int,
       featureSubsetStrategy: String): DecisionTreeMetadata = {
 
-    val numFeatures = input.take(1)(0).features.size
+    val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse {
+      throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
+        s"but was given by empty one.")
+    }
     val numExamples = input.count()
     val numClasses = strategy.algo match {
       case Classification => strategy.numClasses
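
The behavioural difference introduced by this patch can be sketched without Spark. The example below is an illustration only: `Point` and `FirstFeatureSize` are hypothetical names, a plain `Seq` stands in for the input RDD, and the error message is a placeholder rather than the one in the patch. It contrasts indexing the result of `take(1)` directly with going through `headOption.getOrElse`.

```scala
// Hypothetical stand-in for MLlib's LabeledPoint so the sketch runs without Spark.
final case class Point(label: Double, features: Vector[Double])

object FirstFeatureSize {
  // Pre-patch shape: indexing into the result of take(1) fails with an
  // index-out-of-bounds error when the input is empty.
  def numFeaturesUnsafe(input: Seq[Point]): Int =
    input.take(1)(0).features.size

  // Post-patch shape: map to the feature size, take at most one element,
  // and raise a descriptive IllegalArgumentException if nothing came back.
  // (A Seq maps eagerly; this sketch only mirrors the call chain.)
  def numFeatures(input: Seq[Point]): Int =
    input.map(_.features.size).take(1).headOption.getOrElse {
      throw new IllegalArgumentException("input must contain at least one point")
    }
}
```

With a single three-feature point both functions return 3. On an empty input the first fails with an index error, while the second reports the actual problem to the caller.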