aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala12
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala35
2 files changed, 47 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 541f3288b6..52d6468a72 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -82,6 +82,18 @@ object MLUtils {
val value = indexAndValue(1).toDouble
(index, value)
}.unzip
+
+ // check if indices are one-based and in ascending order
+ var previous = -1
+ var i = 0
+ val indicesLength = indices.length
+ while (i < indicesLength) {
+ val current = indices(i)
+ require(current > previous, "indices should be one-based and in ascending order" )
+ previous = current
+ i += 1
+ }
+
(label, indices.toArray, values.toArray)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 734b7babec..70219e9ad9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -25,6 +25,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance}
import com.google.common.base.Charsets
import com.google.common.io.Files
+import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -108,6 +109,40 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
Utils.deleteRecursively(tempDir)
}
+ test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") {
+ val lines =
+ """
+ |0
+ |0 0:4.0 4:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ intercept[SparkException] {
+ loadLibSVMFile(sc, path).collect()
+ }
+ Utils.deleteRecursively(tempDir)
+ }
+
+ test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") {
+ val lines =
+ """
+ |0
+ |0 3:4.0 2:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ intercept[SparkException] {
+ loadLibSVMFile(sc, path).collect()
+ }
+ Utils.deleteRecursively(tempDir)
+ }
+
test("saveAsLibSVMFile") {
val examples = sc.parallelize(Seq(
LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),