aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorlewuathe <lewuathe@me.com>2015-01-26 18:03:21 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-26 18:03:21 -0800
commitf2ba5c6fc3dde81a4d234c75dae2d4e3b46512d1 (patch)
treed07a22d3c559f1e3d5006ba8df227d224da3622d /mllib
parent661e0fca5d5d86efab5fb26da600ac2ac96b09ec (diff)
downloadspark-f2ba5c6fc3dde81a4d234c75dae2d4e3b46512d1.tar.gz
spark-f2ba5c6fc3dde81a4d234c75dae2d4e3b46512d1.tar.bz2
spark-f2ba5c6fc3dde81a4d234c75dae2d4e3b46512d1.zip
[SPARK-5119] java.lang.ArrayIndexOutOfBoundsException on trying to train...
... decision tree model Labels loaded from libsvm files are mapped to 0.0 if they are negative labels because they should be nonnegative value. Author: lewuathe <lewuathe@me.com> Closes #3975 from Lewuathe/map-negative-label-to-positive and squashes the following commits: 12d1d59 [lewuathe] [SPARK-5119] Fix code styles 6d9a18a [lewuathe] [SPARK-5119] Organize test codes 62a150c [lewuathe] [SPARK-5119] Modify Impurities throw exceptions with negatie labels 3336c21 [lewuathe] [SPARK-5119] java.lang.ArrayIndexOutOfBoundsException on trying to train decision tree model
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala42
3 files changed, 52 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 0e02345aa3..b7950e0078 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int)
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
}
+ if (label < 0) {
+ throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
+ s"but requires label is non-negative.")
+ }
allStats(offset + label.toInt) += instanceWeight
}
@@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc
val lbl = label.toInt
require(lbl < stats.length,
s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ require(lbl >= 0, "Entropy does not support negative labels")
val cnt = count
if (cnt == 0) {
0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 7c83cd48e1..c946db9c0d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int)
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
}
+ if (label < 0) {
+ throw new IllegalArgumentException(s"GiniAggregator given label $label" +
+ s"but requires label is non-negative.")
+ }
allStats(offset + label.toInt) += instanceWeight
}
@@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula
val lbl = label.toInt
require(lbl < stats.length,
s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ require(lbl >= 0, "GiniImpurity does not support negative labels")
val cnt = count
if (cnt == 0) {
0
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
new file mode 100644
index 0000000000..92b498580a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/**
+ * Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
+ */
+class ImpuritySuite extends FunSuite with MLlibTestSparkContext {
+ test("Gini impurity does not support negative labels") {
+ val gini = new GiniAggregator(2)
+ intercept[IllegalArgumentException] {
+ gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ }
+ }
+
+ test("Entropy does not support negative labels") {
+ val entropy = new EntropyAggregator(2)
+ intercept[IllegalArgumentException] {
+ entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ }
+ }
+}