author    Matei Zaharia <matei.zaharia@gmail.com>    2013-08-09 20:41:13 -0700
committer Matei Zaharia <matei.zaharia@gmail.com>    2013-08-09 20:41:13 -0700
commit    cd247ba5bb54afa332519826028ab68a4f73849e (patch)
tree      14a5862c20ed86c019ec727914ac0e1af4732264 /mllib
parent    b09d4b79e83330c96c161ea4eb9af284f0a835e6 (diff)
parent    e1a209f791a29225c7c75861aa4a18b14739fcc4 (diff)
Merge pull request #786 from shivaram/mllib-java
Java fixes, tests and examples for ALS, KMeans
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/pom.xml                                                        5
-rw-r--r--  mllib/src/main/scala/spark/mllib/clustering/KMeans.scala            10
-rw-r--r--  mllib/src/main/scala/spark/mllib/recommendation/ALS.scala           21
-rw-r--r--  mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java   115
-rw-r--r--  mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala      54
-rw-r--r--  mllib/src/test/scala/spark/mllib/recommendation/JavaALSSuite.java  110
6 files changed, 285 insertions(+), 30 deletions(-)
diff --git a/mllib/pom.xml b/mllib/pom.xml
index f3928cc73d..a07480fbe2 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -52,6 +52,11 @@
<artifactId>scalacheck_${scala.version}</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>com.novocode</groupId>
+ <artifactId>junit-interface</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
index b402c71ed2..97e3d110ae 100644
--- a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
@@ -112,7 +112,7 @@ class KMeans private (
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
*/
- def train(data: RDD[Array[Double]]): KMeansModel = {
+ def run(data: RDD[Array[Double]]): KMeansModel = {
// TODO: check whether data is persistent; this needs RDD.storageLevel to be publicly readable
val sc = data.sparkContext
@@ -194,8 +194,8 @@ class KMeans private (
*/
private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = {
// Sample all the cluster centers in one pass to avoid repeated scans
- val sample = data.takeSample(true, runs * k, new Random().nextInt())
- Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k))
+ val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq
+ Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray)
}
/**
@@ -210,7 +210,7 @@ class KMeans private (
private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = {
// Initialize each run's center to a random point
val seed = new Random().nextInt()
- val sample = data.takeSample(true, runs, seed)
+ val sample = data.takeSample(true, runs, seed).toSeq
val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r)))
// On each step, sample 2 * k points on average for each run with probability proportional
@@ -271,7 +271,7 @@ object KMeans {
.setMaxIterations(maxIterations)
.setRuns(runs)
.setInitializationMode(initializationMode)
- .train(data)
+ .run(data)
}
def train(data: RDD[Array[Double]], k: Int, maxIterations: Int, runs: Int): KMeansModel = {
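The rename above leaves the companion object's train() helpers in place while instance-configured jobs now end in run(). A minimal sketch of both call paths after this change; the data and parameters below are illustrative, not from the patch:

    import spark.SparkContext
    import spark.mllib.clustering.KMeans

    val sc = new SparkContext("local", "kmeans-sketch")
    // Cache the input, as the scaladoc above recommends for this iterative algorithm.
    val points = sc.parallelize(Seq(
      Array(1.0, 2.0), Array(1.1, 2.1),
      Array(9.0, 8.0), Array(9.1, 8.1)
    )).cache()

    // Builder-style configuration now finishes with run() instead of train().
    val model = new KMeans().setK(2).setMaxIterations(10).run(points)

    // The static helper keeps the train() name.
    val sameModel = KMeans.train(points, 2, 10)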
diff --git a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
index 6ecf0151a1..6c71dc1f32 100644
--- a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
@@ -55,8 +55,7 @@ private[recommendation] case class InLinkBlock(
/**
* A more compact class to represent a rating than Tuple3[Int, Int, Double].
*/
-private[recommendation] case class Rating(user: Int, product: Int, rating: Double)
-
+case class Rating(val user: Int, val product: Int, val rating: Double)
/**
* Alternating Least Squares matrix factorization.
@@ -107,7 +106,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
*/
- def train(ratings: RDD[(Int, Int, Double)]): MatrixFactorizationModel = {
+ def run(ratings: RDD[Rating]): MatrixFactorizationModel = {
val numBlocks = if (this.numBlocks == -1) {
math.max(ratings.context.defaultParallelism, ratings.partitions.size / 2)
} else {
@@ -116,8 +115,10 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
val partitioner = new HashPartitioner(numBlocks)
- val ratingsByUserBlock = ratings.map{ case (u, p, r) => (u % numBlocks, Rating(u, p, r)) }
- val ratingsByProductBlock = ratings.map{ case (u, p, r) => (p % numBlocks, Rating(p, u, r)) }
+ val ratingsByUserBlock = ratings.map{ rating => (rating.user % numBlocks, rating) }
+ val ratingsByProductBlock = ratings.map{ rating =>
+ (rating.product % numBlocks, Rating(rating.product, rating.user, rating.rating))
+ }
val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock)
val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)
@@ -356,14 +357,14 @@ object ALS {
* @param blocks level of parallelism to split computation into
*/
def train(
- ratings: RDD[(Int, Int, Double)],
+ ratings: RDD[Rating],
rank: Int,
iterations: Int,
lambda: Double,
blocks: Int)
: MatrixFactorizationModel =
{
- new ALS(blocks, rank, iterations, lambda).train(ratings)
+ new ALS(blocks, rank, iterations, lambda).run(ratings)
}
/**
@@ -378,7 +379,7 @@ object ALS {
* @param iterations number of iterations of ALS (recommended: 10-20)
* @param lambda regularization factor (recommended: 0.01)
*/
- def train(ratings: RDD[(Int, Int, Double)], rank: Int, iterations: Int, lambda: Double)
+ def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double)
: MatrixFactorizationModel =
{
train(ratings, rank, iterations, lambda, -1)
@@ -395,7 +396,7 @@ object ALS {
* @param rank number of features to use
* @param iterations number of iterations of ALS (recommended: 10-20)
*/
- def train(ratings: RDD[(Int, Int, Double)], rank: Int, iterations: Int)
+ def train(ratings: RDD[Rating], rank: Int, iterations: Int)
: MatrixFactorizationModel =
{
train(ratings, rank, iterations, 0.01, -1)
@@ -423,7 +424,7 @@ object ALS {
val sc = new SparkContext(master, "ALS")
val ratings = sc.textFile(ratingsFile).map { line =>
val fields = line.split(',')
- (fields(0).toInt, fields(1).toInt, fields(2).toDouble)
+ Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
}
val model = ALS.train(ratings, rank, iters, 0.01, blocks)
model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
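With Rating now public, callers construct it directly instead of passing (Int, Int, Double) triples. A minimal sketch of the updated API from Scala; the ratings below are illustrative, not from the patch:

    import spark.SparkContext
    import spark.mllib.recommendation.{ALS, Rating}

    val sc = new SparkContext("local", "als-sketch")
    val ratings = sc.parallelize(Seq(
      Rating(0, 0, 5.0), Rating(0, 1, 1.0),
      Rating(1, 0, 4.0), Rating(1, 1, 2.0)
    ))

    // rank 2, 10 iterations, lambda 0.01, -1 blocks (auto parallelism),
    // matching the defaults used by main() above.
    val model = ALS.train(ratings, 2, 10, 0.01, -1)

    // Feature vectors come back keyed by user/product id, as in main() above.
    model.userFeatures.collect().foreach { case (id, vec) =>
      println(id + "," + vec.mkString(" "))
    }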
diff --git a/mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java
new file mode 100644
index 0000000000..3f2d82bfb4
--- /dev/null
+++ b/mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.clustering;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+
+public class JavaKMeansSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaKMeans");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ // Chebyshev (L-infinity) distance between two points
+ double distance1(double[] v1, double[] v2) {
+ double distance = 0.0;
+ for (int i = 0; i < v1.length; ++i) {
+ distance = Math.max(distance, Math.abs(v1[i] - v2[i]));
+ }
+ return distance;
+ }
+
+ // Assert that two sets of points are equal, within EPSILON tolerance
+ void assertSetsEqual(double[][] v1, double[][] v2) {
+ double EPSILON = 1e-4;
+ Assert.assertTrue(v1.length == v2.length);
+ for (int i = 0; i < v1.length; ++i) {
+ double minDistance = Double.MAX_VALUE;
+ for (int j = 0; j < v2.length; ++j) {
+ minDistance = Math.min(minDistance, distance1(v1[i], v2[j]));
+ }
+ Assert.assertTrue(minDistance <= EPSILON);
+ }
+
+ for (int i = 0; i < v2.length; ++i) {
+ double minDistance = Double.MAX_VALUE;
+ for (int j = 0; j < v1.length; ++j) {
+ minDistance = Math.min(minDistance, distance1(v2[i], v1[j]));
+ }
+ Assert.assertTrue(minDistance <= EPSILON);
+ }
+ }
+
+
+ @Test
+ public void runKMeansUsingStaticMethods() {
+ List<double[]> points = new ArrayList();
+ points.add(new double[]{1.0, 2.0, 6.0});
+ points.add(new double[]{1.0, 3.0, 0.0});
+ points.add(new double[]{1.0, 4.0, 6.0});
+
+ double[][] expectedCenter = { {1.0, 3.0, 4.0} };
+
+ JavaRDD<double[]> data = sc.parallelize(points, 2);
+ KMeansModel model = KMeans.train(data.rdd(), 1, 1);
+ assertSetsEqual(model.clusterCenters(), expectedCenter);
+
+ model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.RANDOM());
+ assertSetsEqual(model.clusterCenters(), expectedCenter);
+ }
+
+ @Test
+ public void runKMeansUsingConstructor() {
+ List<double[]> points = new ArrayList();
+ points.add(new double[]{1.0, 2.0, 6.0});
+ points.add(new double[]{1.0, 3.0, 0.0});
+ points.add(new double[]{1.0, 4.0, 6.0});
+
+ double[][] expectedCenter = { {1.0, 3.0, 4.0} };
+
+ JavaRDD<double[]> data = sc.parallelize(points, 2);
+ KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
+ assertSetsEqual(model.clusterCenters(), expectedCenter);
+
+ model = new KMeans().setK(1)
+ .setMaxIterations(1)
+ .setRuns(1)
+ .setInitializationMode(KMeans.RANDOM())
+ .run(data.rdd());
+ assertSetsEqual(model.clusterCenters(), expectedCenter);
+ }
+}
diff --git a/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala
index f98590b8d9..3a556fdc29 100644
--- a/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala
@@ -17,6 +17,7 @@
package spark.mllib.recommendation
+import scala.collection.JavaConversions._
import scala.util.Random
import org.scalatest.BeforeAndAfterAll
@@ -27,6 +28,42 @@ import spark.SparkContext._
import org.jblas._
+object ALSSuite {
+
+ def generateRatingsAsJavaList(
+ users: Int,
+ products: Int,
+ features: Int,
+ samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = {
+ val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate)
+ (seqAsJavaList(sampledRatings), trueRatings)
+ }
+
+ def generateRatings(
+ users: Int,
+ products: Int,
+ features: Int,
+ samplingRate: Double): (Seq[Rating], DoubleMatrix) = {
+ val rand = new Random(42)
+
+ // Create a random matrix with uniform values from -1 to 1
+ def randomMatrix(m: Int, n: Int) =
+ new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*)
+
+ val userMatrix = randomMatrix(users, features)
+ val productMatrix = randomMatrix(features, products)
+ val trueRatings = userMatrix.mmul(productMatrix)
+
+ val sampledRatings = {
+ for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate)
+ yield Rating(u, p, trueRatings.get(u, p))
+ }
+
+ (sampledRatings, trueRatings)
+ }
+
+}
+
class ALSSuite extends FunSuite with BeforeAndAfterAll {
val sc = new SparkContext("local", "test")
@@ -57,21 +94,8 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
def testALS(users: Int, products: Int, features: Int, iterations: Int,
samplingRate: Double, matchThreshold: Double)
{
- val rand = new Random(42)
-
- // Create a random matrix with uniform values from -1 to 1
- def randomMatrix(m: Int, n: Int) =
- new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*)
-
- val userMatrix = randomMatrix(users, features)
- val productMatrix = randomMatrix(features, products)
- val trueRatings = userMatrix.mmul(productMatrix)
-
- val sampledRatings = {
- for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate)
- yield (u, p, trueRatings.get(u, p))
- }
-
+ val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products,
+ features, samplingRate)
val model = ALS.train(sc.parallelize(sampledRatings), features, iterations)
val predictedU = new DoubleMatrix(users, features)
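The extraction above moves data generation into the ALSSuite companion object so the Scala and Java suites share one generator. A sketch of the two entry points as the diff defines them; the parameter values are illustrative:

    import org.jblas.DoubleMatrix
    import spark.mllib.recommendation.{ALSSuite, Rating}

    // Scala callers get a Seq[Rating] plus the dense ground-truth matrix.
    val (ratings, truth): (Seq[Rating], DoubleMatrix) =
      ALSSuite.generateRatings(50, 100, 3, 0.7)

    // Java callers use the wrapper, which only converts the Seq via
    // scala.collection.JavaConversions.seqAsJavaList.
    val (javaRatings, sameTruth) = ALSSuite.generateRatingsAsJavaList(50, 100, 3, 0.7)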
diff --git a/mllib/src/test/scala/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/scala/spark/mllib/recommendation/JavaALSSuite.java
new file mode 100644
index 0000000000..7993629a6d
--- /dev/null
+++ b/mllib/src/test/scala/spark/mllib/recommendation/JavaALSSuite.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.recommendation;
+
+import java.io.Serializable;
+import java.util.List;
+
+import scala.Tuple2;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+
+import org.jblas.DoubleMatrix;
+
+public class JavaALSSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaALS");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
+ DoubleMatrix trueRatings, double matchThreshold) {
+ DoubleMatrix predictedU = new DoubleMatrix(users, features);
+ List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
+ for (int i = 0; i < features; ++i) {
+ for (scala.Tuple2<Object, double[]> userFeature : userFeatures) {
+ predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]);
+ }
+ }
+ DoubleMatrix predictedP = new DoubleMatrix(products, features);
+
+ List<scala.Tuple2<Object, double[]>> productFeatures =
+ model.productFeatures().toJavaRDD().collect();
+ for (int i = 0; i < features; ++i) {
+ for (scala.Tuple2<Object, double[]> productFeature : productFeatures) {
+ predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]);
+ }
+ }
+
+ DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose());
+
+ for (int u = 0; u < users; ++u) {
+ for (int p = 0; p < products; ++p) {
+ double prediction = predictedRatings.get(u, p);
+ double correct = trueRatings.get(u, p);
+ Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold);
+ }
+ }
+ }
+
+ @Test
+ public void runALSUsingStaticMethods() {
+ int features = 1;
+ int iterations = 15;
+ int users = 10;
+ int products = 10;
+ scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7);
+
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+ MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
+ validatePrediction(model, users, products, features, testData._2(), 0.3);
+ }
+
+ @Test
+ public void runALSUsingConstructor() {
+ int features = 2;
+ int iterations = 15;
+ int users = 20;
+ int products = 30;
+ scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7);
+
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+
+ MatrixFactorizationModel model = new ALS().setRank(features)
+ .setIterations(iterations)
+ .run(data.rdd());
+ validatePrediction(model, users, products, features, testData._2(), 0.3);
+ }
+}
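The validation above reconstructs the full rating matrix as the product of the user factors and the transposed product factors, then checks entrywise error against the ground truth. The same check condensed into Scala, assuming jblas matrices shaped as in the test (users x features and products x features):

    import org.jblas.DoubleMatrix

    def maxError(predictedU: DoubleMatrix, predictedP: DoubleMatrix,
                 trueRatings: DoubleMatrix): Double = {
      // Predicted ratings are U * P^T, as in validatePrediction above.
      val predicted = predictedU.mmul(predictedP.transpose())
      var worst = 0.0
      for (u <- 0 until predicted.rows; p <- 0 until predicted.columns) {
        worst = math.max(worst, math.abs(predicted.get(u, p) - trueRatings.get(u, p)))
      }
      worst
    }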