aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorAlain <aihe@usc.edu>2015-05-05 16:47:34 +0100
committerSean Owen <sowen@cloudera.com>2015-05-05 16:47:34 +0100
commitd4cb38aeb7412a353c6cbca2a9b8f9729afbaba7 (patch)
tree57def81c93b434a5ef1ad42c44192074435aa52a /mllib
parent9d250e64dac263bcbbad6b023382ac7b5b592408 (diff)
downloadspark-d4cb38aeb7412a353c6cbca2a9b8f9729afbaba7.tar.gz
spark-d4cb38aeb7412a353c6cbca2a9b8f9729afbaba7.tar.bz2
spark-d4cb38aeb7412a353c6cbca2a9b8f9729afbaba7.zip
[MLLIB] [TREE] Verify size of input rdd > 0 when building meta data
Require non empty input rdd such that we can take the first labeledpoint and get the feature size Author: Alain <aihe@usc.edu> Author: aihe@usc.edu <aihe@usc.edu> Closes #5810 from AiHe/decisiontree-issue and squashes the following commits: 3b1d08a [aihe@usc.edu] [MLLIB][tree] merge the assertion into the evaluation of numFeatures cf2e567 [Alain] [MLLIB][tree] Use a rdd api to verify size of input rdd > 0 when building meta data b448f47 [Alain] [MLLIB][tree] Verify size of input rdd > 0 when building meta data
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala5
1 files changed, 4 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index f1a6ed2301..f73896e37c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -107,7 +107,10 @@ private[tree] object DecisionTreeMetadata extends Logging {
numTrees: Int,
featureSubsetStrategy: String): DecisionTreeMetadata = {
- val numFeatures = input.take(1)(0).features.size
+ val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse {
+ throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
+ s"but was given by empty one.")
+ }
val numExamples = input.count()
val numClasses = strategy.algo match {
case Classification => strategy.numClasses