aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala35
1 files changed, 35 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 734b7babec..70219e9ad9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -25,6 +25,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance}
import com.google.common.base.Charsets
import com.google.common.io.Files
+import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -108,6 +109,40 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
Utils.deleteRecursively(tempDir)
}
+ test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") {
+ val lines =
+ """
+ |0
+ |0 0:4.0 4:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ intercept[SparkException] {
+ loadLibSVMFile(sc, path).collect()
+ }
+ Utils.deleteRecursively(tempDir)
+ }
+
+ test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") {
+ val lines =
+ """
+ |0
+ |0 3:4.0 2:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ intercept[SparkException] {
+ loadLibSVMFile(sc, path).collect()
+ }
+ Utils.deleteRecursively(tempDir)
+ }
+
test("saveAsLibSVMFile") {
val examples = sc.parallelize(Seq(
LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),