aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-05-05 18:32:54 -0700
committerMatei Zaharia <matei@databricks.com>2014-05-05 18:32:54 -0700
commit98750a74daf7e2b873da85d2d5067f47e3bbdc4e (patch)
tree7751cfc30345957b4ee65bde5a0a91fe57a984e3 /mllib/src/test
parentea10b3126167af3f50f7c2a70e1d942e839fcb66 (diff)
downloadspark-98750a74daf7e2b873da85d2d5067f47e3bbdc4e.tar.gz
spark-98750a74daf7e2b873da85d2d5067f47e3bbdc4e.tar.bz2
spark-98750a74daf7e2b873da85d2d5067f47e3bbdc4e.zip
[SPARK-1594][MLLIB] Cleaning up MLlib APIs and guide
Final pass before the v1.0 release. * Remove `VectorRDDs` * Move `BinaryClassificationMetrics` from `evaluation.binary` to `evaluation` * Change default value of `addIntercept` to false and allow to add intercept in Ridge and Lasso. * Clean `DecisionTree` package doc and test suite. * Mark model constructors `private[spark]` * Rename `loadLibSVMData` to `loadLibSVMFile` and hide `LabelParser` from users. * Add `saveAsLibSVMFile`. * Add `appendBias` to `MLUtils`. Author: Xiangrui Meng <meng@databricks.com> Closes #524 from mengxr/mllib-cleaning and squashes the following commits: 295dc8b [Xiangrui Meng] update loadLibSVMFile doc 1977ac1 [Xiangrui Meng] fix doc of appendBias 649fcf0 [Xiangrui Meng] rename loadLibSVMData to loadLibSVMFile; hide LabelParser from user APIs 54b812c [Xiangrui Meng] add appendBias a71e7d0 [Xiangrui Meng] add saveAsLibSVMFile d976295 [Xiangrui Meng] Merge branch 'master' into mllib-cleaning b7e5cec [Xiangrui Meng] remove some experimental annotations and make model constructors private[mllib] 9b02b93 [Xiangrui Meng] minor code style update a593ddc [Xiangrui Meng] fix python tests fc28c18 [Xiangrui Meng] mark more classes experimental f6cbbff [Xiangrui Meng] fix Java tests 0af70b0 [Xiangrui Meng] minor 6e139ef [Xiangrui Meng] Merge branch 'master' into mllib-cleaning 94e6dce [Xiangrui Meng] move BinaryLabelCounter and BinaryConfusionMatrixImpl to evaluation.binary df34907 [Xiangrui Meng] clean DecisionTreeSuite to use LocalSparkContext c81807f [Xiangrui Meng] set the default value of AddIntercept to false 03389c0 [Xiangrui Meng] allow to add intercept in Ridge and Lasso c66c56f [Xiangrui Meng] move tree md to package object doc a2695df [Xiangrui Meng] update guide for BinaryClassificationMetrics 9194f4c [Xiangrui Meng] move BinaryClassificationMetrics one level up 1c1a0e3 [Xiangrui Meng] remove VectorRDDs because it only contains one function that is not necessary for us to maintain
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java6
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java3
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java3
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala20
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala)3
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDsSuite.scala33
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala1
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala16
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala66
13 files changed, 71 insertions, 99 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index e18e3bc6a8..d75d3a6b26 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -68,6 +68,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
+ lrImpl.setIntercept(true);
lrImpl.optimizer().setStepSize(1.0)
.setRegParam(1.0)
.setNumIterations(100);
@@ -80,8 +81,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
@Test
public void runLRUsingStaticMethods() {
int nPoints = 10000;
- double A = 2.0;
- double B = -1.5;
+ double A = 0.0;
+ double B = -2.5;
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
@@ -92,6 +93,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
testRDD.rdd(), 100, 1.0, 1.0);
int numAccurate = validatePrediction(validationData, model);
+ System.out.println(numAccurate);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
index 4701a5e545..667f76a1bd 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -67,6 +67,7 @@ public class JavaSVMSuite implements Serializable {
SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
SVMWithSGD svmSGDImpl = new SVMWithSGD();
+ svmSGDImpl.setIntercept(true);
svmSGDImpl.optimizer().setStepSize(1.0)
.setRegParam(1.0)
.setNumIterations(100);
@@ -79,7 +80,7 @@ public class JavaSVMSuite implements Serializable {
@Test
public void runSVMUsingStaticMethods() {
int nPoints = 10000;
- double A = 2.0;
+ double A = 0.0;
double[] weights = {-1.5, 1.0};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
index 5a4410a632..7151e55351 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -68,6 +68,7 @@ public class JavaLinearRegressionSuite implements Serializable {
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
+ linSGDImpl.setIntercept(true);
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@@ -77,7 +78,7 @@ public class JavaLinearRegressionSuite implements Serializable {
@Test
public void runLinearRegressionUsingStaticMethods() {
int nPoints = 100;
- double A = 3.0;
+ double A = 0.0;
double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 1e03c9df82..4d7b984e3e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -46,24 +46,14 @@ object LogisticRegressionSuite {
val rnd = new Random(seed)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
- // NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1)
- val unifRand = new scala.util.Random(45)
- val rLogis = (0 until nPoints).map { i =>
- val u = unifRand.nextDouble()
- math.log(u) - math.log(1.0-u)
- }
-
- // y <- A + B*x + rLogis()
- // y <- as.numeric(y > 0)
- val y: Seq[Int] = (0 until nPoints).map { i =>
- val yVal = offset + scale * x1(i) + rLogis(i)
- if (yVal > 0) 1 else 0
+ val y = (0 until nPoints).map { i =>
+ val p = 1.0 / (1.0 + math.exp(-(offset + scale * x1(i))))
+ if (rnd.nextDouble() < p) 1.0 else 0.0
}
val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i)))))
testData
}
-
}
class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
@@ -85,7 +75,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Shoul
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val lr = new LogisticRegressionWithSGD()
+ val lr = new LogisticRegressionWithSGD().setIntercept(true)
lr.optimizer.setStepSize(10.0).setNumIterations(20)
val model = lr.run(testRDD)
@@ -118,7 +108,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Shoul
testRDD.cache()
// Use half as many iterations as the previous test.
- val lr = new LogisticRegressionWithSGD()
+ val lr = new LogisticRegressionWithSGD().setIntercept(true)
lr.optimizer.setStepSize(10.0).setNumIterations(10)
val model = lr.run(testRDD, initialWeights)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index dfacbfeee6..77d6f04b32 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -69,7 +69,6 @@ class SVMSuite extends FunSuite with LocalSparkContext {
assert(numOffPredictions < input.length / 5)
}
-
test("SVM using local random SGD") {
val nPoints = 10000
@@ -83,7 +82,7 @@ class SVMSuite extends FunSuite with LocalSparkContext {
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val svm = new SVMWithSGD()
+ val svm = new SVMWithSGD().setIntercept(true)
svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
val model = svm.run(testRDD)
@@ -115,7 +114,7 @@ class SVMSuite extends FunSuite with LocalSparkContext {
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val svm = new SVMWithSGD()
+ val svm = new SVMWithSGD().setIntercept(true)
svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
val model = svm.run(testRDD, initialWeights)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index 173fdaefab..9d16182f9d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -15,12 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.evaluation.binary
+package org.apache.spark.mllib.evaluation
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.mllib.evaluation.AreaUnderCurve
class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
test("binary evaluation metrics") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDsSuite.scala
deleted file mode 100644
index 692f025e95..0000000000
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDsSuite.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.rdd
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.LocalSparkContext
-
-class VectorRDDsSuite extends FunSuite with LocalSparkContext {
-
- test("from array rdd") {
- val data = Seq(Array(1.0, 2.0), Array(3.0, 4.0))
- val arrayRdd = sc.parallelize(data, 2)
- val vectorRdd = VectorRDDs.fromArrayRDD(arrayRdd)
- assert(arrayRdd.collect().map(v => Vectors.dense(v)) === vectorRdd.collect())
- }
-}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 4dfcd4b52e..2d944f3eb7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -27,7 +27,6 @@ import org.jblas.DoubleMatrix
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.SparkContext._
-import org.apache.spark.Partitioner
object ALSSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 6aad9eb84e..bfa42959c8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -112,10 +112,4 @@ class LassoSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
-
- test("do not support intercept") {
- intercept[UnsupportedOperationException] {
- new LassoWithSGD().setIntercept(true)
- }
- }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 2f7d30708c..7aaad7d7a3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -37,7 +37,7 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
test("linear regression") {
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
3.0, Array(10.0, 10.0), 100, 42), 2).cache()
- val linReg = new LinearRegressionWithSGD()
+ val linReg = new LinearRegressionWithSGD().setIntercept(true)
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
val model = linReg.run(testRDD)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index f66fc6ea6c..67768e17fb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -72,10 +72,4 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
assert(ridgeErr < linearErr,
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
-
- test("do not support intercept") {
- intercept[UnsupportedOperationException] {
- new RidgeRegressionWithSGD().setIntercept(true)
- }
- }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 350130c914..be383aab71 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -17,10 +17,8 @@
package org.apache.spark.mllib.tree
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
-import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.Filter
@@ -28,19 +26,9 @@ import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.LocalSparkContext
-class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
-
- @transient private var sc: SparkContext = _
-
- override def beforeAll() {
- sc = new SparkContext("local", "test")
- }
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
- }
+class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 674378a34c..3f64baf6fe 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.mllib.util
import java.io.File
+import scala.io.Source
import scala.math
-import scala.util.Random
import org.scalatest.FunSuite
@@ -29,7 +29,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNor
import com.google.common.base.Charsets
import com.google.common.io.Files
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
class MLUtilsSuite extends FunSuite with LocalSparkContext {
@@ -58,7 +59,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
}
}
- test("loadLibSVMData") {
+ test("loadLibSVMFile") {
val lines =
"""
|+1 1:1.0 3:2.0 5:3.0
@@ -70,8 +71,8 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString
- val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser, 6).collect()
- val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()
+ val pointsWithNumFeatures = loadLibSVMFile(sc, path, multiclass = false, 6).collect()
+ val pointsWithoutNumFeatures = loadLibSVMFile(sc, path).collect()
for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
assert(points.length === 3)
@@ -83,29 +84,54 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
}
- val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser).collect()
+ val multiclassPoints = loadLibSVMFile(sc, path, multiclass = true).collect()
assert(multiclassPoints.length === 3)
assert(multiclassPoints(0).label === 1.0)
assert(multiclassPoints(1).label === -1.0)
assert(multiclassPoints(2).label === -1.0)
- try {
- file.delete()
- tempDir.delete()
- } catch {
- case t: Throwable =>
- }
+ deleteQuietly(tempDir)
+ }
+
+ test("saveAsLibSVMFile") {
+ val examples = sc.parallelize(Seq(
+ LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),
+ LabeledPoint(0.0, Vectors.dense(1.01, 2.02, 3.03))
+ ), 2)
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output")
+ MLUtils.saveAsLibSVMFile(examples, outputDir.toURI.toString)
+ val lines = outputDir.listFiles()
+ .filter(_.getName.startsWith("part-"))
+ .flatMap(Source.fromFile(_).getLines())
+ .toSet
+ val expected = Set("1.1 1:1.23 3:4.56", "0.0 1:1.01 2:2.02 3:3.03")
+ assert(lines === expected)
+ deleteQuietly(tempDir)
+ }
+
+ test("appendBias") {
+ val sv = Vectors.sparse(3, Seq((0, 1.0), (2, 3.0)))
+ val sv1 = appendBias(sv).asInstanceOf[SparseVector]
+ assert(sv1.size === 4)
+ assert(sv1.indices === Array(0, 2, 3))
+ assert(sv1.values === Array(1.0, 3.0, 1.0))
+
+ val dv = Vectors.dense(1.0, 0.0, 3.0)
+ val dv1 = appendBias(dv).asInstanceOf[DenseVector]
+ assert(dv1.size === 4)
+ assert(dv1.values === Array(1.0, 0.0, 3.0, 1.0))
}
test("kFold") {
val data = sc.parallelize(1 to 100, 2)
val collectedData = data.collect().sorted
- val twoFoldedRdd = MLUtils.kFold(data, 2, 1)
+ val twoFoldedRdd = kFold(data, 2, 1)
assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted)
assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted)
for (folds <- 2 to 10) {
for (seed <- 1 to 5) {
- val foldedRdds = MLUtils.kFold(data, folds, seed)
+ val foldedRdds = kFold(data, folds, seed)
assert(foldedRdds.size === folds)
foldedRdds.map { case (training, validation) =>
val result = validation.union(training).collect().sorted
@@ -132,4 +158,16 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
}
}
+ /** Delete a file/directory quietly. */
+ def deleteQuietly(f: File) {
+ if (f.isDirectory) {
+ f.listFiles().foreach(deleteQuietly)
+ }
+ try {
+ f.delete()
+ } catch {
+ case _: Throwable =>
+ }
+ }
}
+