author     Yuhao Yang <hhbyyh@gmail.com>    2015-06-03 13:15:57 +0200
committer  Sean Owen <sowen@cloudera.com>   2015-06-03 13:15:57 +0200
commit     28dbde3874ccdd44b73675938719b69336d23dac (patch)
tree       448b051d7a8fd5be4256b71d35203d9d64aca831 /mllib
parent     d38cf217e0c6bfbf451c659675280b43a08bc70f (diff)
[SPARK-7983] [MLLIB] Add require for one-based indices in loadLibSVMFile
jira: https://issues.apache.org/jira/browse/SPARK-7983

Customers frequently use zero-based indices in their LIBSVM files. Spark reports no warning or error during the subsequent computation, and the silent misalignment usually leads to weird results for many algorithms (such as GBDT). This adds a quick check.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #6538 from hhbyyh/loadSVM and squashes the following commits:

79d9c11 [Yuhao Yang] optimization as respond to comments
4310710 [Yuhao Yang] merge conflict
96460f1 [Yuhao Yang] merge conflict
20a2811 [Yuhao Yang] use require
6e4f8ca [Yuhao Yang] add check for ascending order
9956365 [Yuhao Yang] add ut for 0-based loadlibsvm exception
5bd1f9a [Yuhao Yang] add require for one-based in loadLIBSVM
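
For context, a minimal sketch of how the new check surfaces to callers. The local SparkContext setup and file paths below are hypothetical and only for illustration; the failure mode matches the tests added in this patch (the require failure inside a task is reported as a SparkException when the RDD is materialized).

    import org.apache.spark.{SparkConf, SparkContext, SparkException}
    import org.apache.spark.mllib.util.MLUtils

    // Hypothetical local setup and file paths, for illustration only.
    val sc = new SparkContext(new SparkConf().setAppName("libsvm-check").setMaster("local[2]"))

    // Valid LIBSVM records use one-based, strictly ascending indices, e.g.
    //   1 1:1.0 3:2.0 5:3.0
    val valid = MLUtils.loadLibSVMFile(sc, "data/one_based.txt")
    valid.count()  // materializes the RDD; the new require passes

    // A zero-based record such as "0 0:4.0 4:5.0 6:6.0" now fails fast: the
    // IllegalArgumentException thrown by require inside the task surfaces as a
    // SparkException when the RDD is evaluated.
    try {
      MLUtils.loadLibSVMFile(sc, "data/zero_based.txt").count()
    } catch {
      case e: SparkException => println("rejected zero-based indices: " + e.getMessage)
    }

    sc.stop()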
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala       | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala  | 35
2 files changed, 47 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 541f3288b6..52d6468a72 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -82,6 +82,18 @@ object MLUtils {
val value = indexAndValue(1).toDouble
(index, value)
}.unzip
+
+ // check if indices are one-based and in ascending order
+ var previous = -1
+ var i = 0
+ val indicesLength = indices.length
+ while (i < indicesLength) {
+ val current = indices(i)
+ require(current > previous, "indices should be one-based and in ascending order" )
+ previous = current
+ i += 1
+ }
+
(label, indices.toArray, values.toArray)
}
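
As a side note, a standalone sketch of the check added above, runnable without Spark. loadLibSVMFile has already shifted the file's one-based indices down by one before this loop runs, so an index of 0 in the file becomes -1 internally and trips the first require; the example arrays below are made up for illustration.

    object IndexCheckSketch {
      // Mirrors the validation added above: internal (already decremented)
      // indices must start at 0 or higher and be strictly ascending.
      def checkIndices(indices: Array[Int]): Unit = {
        var previous = -1
        var i = 0
        while (i < indices.length) {
          val current = indices(i)
          require(current > previous, "indices should be one-based and in ascending order")
          previous = current
          i += 1
        }
      }

      def main(args: Array[String]): Unit = {
        checkIndices(Array(0, 2, 4))    // shifted form of "1:... 3:... 5:..." -> accepted
        // checkIndices(Array(-1, 3, 5)) // shifted form of "0:... 4:... 6:..." -> throws
        // checkIndices(Array(2, 1, 5))  // shifted form of "3:... 2:... 6:..." -> throws
        println("ascending one-based indices accepted")
      }
    }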
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 734b7babec..70219e9ad9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -25,6 +25,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance}
import com.google.common.base.Charsets
import com.google.common.io.Files
+import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -108,6 +109,40 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
Utils.deleteRecursively(tempDir)
}
+ test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") {
+ val lines =
+ """
+ |0
+ |0 0:4.0 4:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ intercept[SparkException] {
+ loadLibSVMFile(sc, path).collect()
+ }
+ Utils.deleteRecursively(tempDir)
+ }
+
+ test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") {
+ val lines =
+ """
+ |0
+ |0 3:4.0 2:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ intercept[SparkException] {
+ loadLibSVMFile(sc, path).collect()
+ }
+ Utils.deleteRecursively(tempDir)
+ }
+
test("saveAsLibSVMFile") {
val examples = sc.parallelize(Seq(
LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),