aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/java/org')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java30
1 files changed, 10 insertions, 20 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 851707c8a1..bd0edf2b9e 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.fpm;
import java.io.Serializable;
import java.util.ArrayList;
+import java.util.List;
import org.junit.After;
import org.junit.Before;
@@ -28,6 +29,7 @@ import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
public class JavaFPGrowthSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -55,30 +57,18 @@ public class JavaFPGrowthSuite implements Serializable {
Lists.newArrayList("z".split(" ")),
Lists.newArrayList("x z y r q t p".split(" "))), 2);
- FPGrowth fpg = new FPGrowth();
-
- FPGrowthModel<String> model6 = fpg
- .setMinSupport(0.9)
- .setNumPartitions(1)
- .run(rdd);
- assertEquals(0, model6.javaFreqItemsets().count());
-
- FPGrowthModel<String> model3 = fpg
+ FPGrowthModel<String> model = new FPGrowth()
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd);
- assertEquals(18, model3.javaFreqItemsets().count());
- FPGrowthModel<String> model2 = fpg
- .setMinSupport(0.3)
- .setNumPartitions(4)
- .run(rdd);
- assertEquals(54, model2.javaFreqItemsets().count());
+ List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
+ assertEquals(18, freqItemsets.size());
- FPGrowthModel<String> model1 = fpg
- .setMinSupport(0.1)
- .setNumPartitions(8)
- .run(rdd);
- assertEquals(625, model1.javaFreqItemsets().count());
+ for (FreqItemset<String> itemset: freqItemsets) {
+ // Test return types.
+ List<String> items = itemset.javaItems();
+ long freq = itemset.freq();
+ }
}
}