aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-02-19 18:06:16 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-19 18:06:16 -0800
commit0cfd2cebde0b7fac3779eda80d6e42223f8a3d9f (patch)
tree36bdfdec69a205b85f7b85697c36abf2044d9ff5 /mllib/src/test
parent6bddc40353057a562c78e75c5549c79a0d7d5f8b (diff)
downloadspark-0cfd2cebde0b7fac3779eda80d6e42223f8a3d9f.tar.gz
spark-0cfd2cebde0b7fac3779eda80d6e42223f8a3d9f.tar.bz2
spark-0cfd2cebde0b7fac3779eda80d6e42223f8a3d9f.zip
[SPARK-5900][MLLIB] make PIC and FPGrowth Java-friendly
In the previous version, PIC stores clustering assignments as an `RDD[(Long, Int)]`. This is mapped to `RDD<Tuple2<Object, Object>>` in Java and hence Java users have to cast types manually. We should either create a new method called `javaAssignments` that returns `JavaRDD[(java.lang.Long, java.lang.Integer)]` or wrap the result pair in a class. I chose the latter approach in this PR. Now assignments are stored as an `RDD[Assignment]`, where `Assignment` is a class with `id` and `cluster`. Similarly, in FPGrowth, the frequent itemsets are stored as an `RDD[(Array[Item], Long)]`, which is mapped to `RDD<Tuple2<Object, Object>>`. Though we provide a "Java-friendly" method `javaFreqItemsets` that returns `JavaRDD[(Array[Item], java.lang.Long)]`, it doesn't really work because `Array[Item]` is mapped to `Object` in Java. So in this PR I created a class `FreqItemset` to wrap the results. It has `items` and `freq`, as well as a `javaItems` method that returns `List<Item>` in Java. I'm not certain that the names I chose are proper: `Assignment`/`id`/`cluster` and `FreqItemset`/`items`/`freq`. Please let me know if there are better suggestions. CC: jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #4695 from mengxr/SPARK-5900 and squashes the following commits: 865b5ca [Xiangrui Meng] make Assignment serializable cffa96e [Xiangrui Meng] fix test 9c0e590 [Xiangrui Meng] remove unused Tuple2 1b9db3d [Xiangrui Meng] make PIC and FPGrowth Java-friendly
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java30
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala10
3 files changed, 19 insertions, 29 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 851707c8a1..bd0edf2b9e 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.fpm;
import java.io.Serializable;
import java.util.ArrayList;
+import java.util.List;
import org.junit.After;
import org.junit.Before;
@@ -28,6 +29,7 @@ import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
public class JavaFPGrowthSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -55,30 +57,18 @@ public class JavaFPGrowthSuite implements Serializable {
Lists.newArrayList("z".split(" ")),
Lists.newArrayList("x z y r q t p".split(" "))), 2);
- FPGrowth fpg = new FPGrowth();
-
- FPGrowthModel<String> model6 = fpg
- .setMinSupport(0.9)
- .setNumPartitions(1)
- .run(rdd);
- assertEquals(0, model6.javaFreqItemsets().count());
-
- FPGrowthModel<String> model3 = fpg
+ FPGrowthModel<String> model = new FPGrowth()
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd);
- assertEquals(18, model3.javaFreqItemsets().count());
- FPGrowthModel<String> model2 = fpg
- .setMinSupport(0.3)
- .setNumPartitions(4)
- .run(rdd);
- assertEquals(54, model2.javaFreqItemsets().count());
+ List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
+ assertEquals(18, freqItemsets.size());
- FPGrowthModel<String> model1 = fpg
- .setMinSupport(0.1)
- .setNumPartitions(8)
- .run(rdd);
- assertEquals(625, model1.javaFreqItemsets().count());
+ for (FreqItemset<String> itemset: freqItemsets) {
+ // Test return types.
+ List<String> items = itemset.javaItems();
+ long freq = itemset.freq();
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
index 03ecd9ca73..6315c03a70 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -51,8 +51,8 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
.setK(2)
.run(sc.parallelize(similarities, 2))
val predictions = Array.fill(2)(mutable.Set.empty[Long])
- model.assignments.collect().foreach { case (i, c) =>
- predictions(c) += i
+ model.assignments.collect().foreach { a =>
+ predictions(a.cluster) += a.id
}
assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
@@ -61,8 +61,8 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
.setInitializationMode("degree")
.run(sc.parallelize(similarities, 2))
val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
- model2.assignments.collect().foreach { case (i, c) =>
- predictions2(c) += i
+ model2.assignments.collect().foreach { a =>
+ predictions2(a.cluster) += a.id
}
assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 68128284b8..bd5b9cc3af 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -46,8 +46,8 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd)
- val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
- (items.toSet, count)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
}
val expected = Set(
(Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
@@ -96,10 +96,10 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd)
- assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass,
+ assert(model3.freqItemsets.first().items.getClass === Array(1).getClass,
"frequent itemsets should use primitive arrays")
- val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
- (items.toSet, count)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
}
val expected = Set(
(Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),