author     Xiangrui Meng <meng@databricks.com>   2015-02-19 18:06:16 -0800
committer  Xiangrui Meng <meng@databricks.com>   2015-02-19 18:06:16 -0800
commit     0cfd2cebde0b7fac3779eda80d6e42223f8a3d9f (patch)
tree       36bdfdec69a205b85f7b85697c36abf2044d9ff5
parent     6bddc40353057a562c78e75c5549c79a0d7d5f8b (diff)
[SPARK-5900][MLLIB] make PIC and FPGrowth Java-friendly
In the previous version, PIC stored clustering assignments as an `RDD[(Long, Int)]`. This is mapped to `RDD<Tuple2<Object, Object>>` in Java, so Java users have to cast types manually. We could either add a new method `javaAssignments` that returns `JavaRDD[(java.lang.Long, java.lang.Integer)]`, or wrap the result pair in a class. This PR takes the latter approach: assignments are now stored as an `RDD[Assignment]`, where `Assignment` is a class with `id` and `cluster` fields.

Similarly, FPGrowth stored frequent itemsets as an `RDD[(Array[Item], Long)]`, which is also mapped to `RDD<Tuple2<Object, Object>>` in Java. We did provide a "Java-friendly" method `javaFreqItemsets` that returns `JavaRDD[(Array[Item], java.lang.Long)]`, but it doesn't really help because `Array[Item]` is mapped to `Object` in Java. So this PR adds a `FreqItemset` class to wrap the results. It has `items` and `freq` fields, as well as a `javaItems` method that returns `List<Item>` in Java.

I'm not certain the chosen names are the best: `Assignment`/`id`/`cluster` and `FreqItemset`/`items`/`freq`. Please let me know if there are better suggestions.

CC: jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #4695 from mengxr/SPARK-5900 and squashes the following commits:

865b5ca [Xiangrui Meng] make Assignment serializable
cffa96e [Xiangrui Meng] fix test
9c0e590 [Xiangrui Meng] remove unused Tuple2
1b9db3d [Xiangrui Meng] make PIC and FPGrowth Java-friendly
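For reference, a minimal, untested sketch of how the two new wrapper classes are meant to be consumed from Java, combining the snippets updated in the docs below. The placeholder RDDs `similarities` and `transactions` are assumed to be built as in the existing examples; everything else (class names, setters, `assignments()`, `freqItemsets()`, `javaItems()`, `freq()`) comes from this patch.

{% highlight java %}
import java.util.List;

import scala.Tuple3;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.clustering.PowerIterationClustering;
import org.apache.spark.mllib.clustering.PowerIterationClusteringModel;
import org.apache.spark.mllib.fpm.FPGrowth;
import org.apache.spark.mllib.fpm.FPGrowthModel;

// Assumed inputs, built as in the existing examples (not shown here).
JavaRDD<Tuple3<Long, Long, Double>> similarities = ...;  // (srcId, dstId, similarity)
JavaRDD<List<String>> transactions = ...;                // one list of items per transaction

// PIC: assignments are now typed; no casting of Tuple2<Object, Object> is needed.
PowerIterationClusteringModel picModel = new PowerIterationClustering()
  .setK(3)
  .setMaxIterations(10)
  .run(similarities);
for (PowerIterationClustering.Assignment a : picModel.assignments().toJavaRDD().collect()) {
  System.out.println(a.id() + " -> " + a.cluster());
}

// FPGrowth: each result is a FreqItemset with javaItems() and freq().
FPGrowthModel<String> fpModel = new FPGrowth()
  .setMinSupport(0.3)
  .run(transactions);
for (FPGrowth.FreqItemset<String> itemset : fpModel.freqItemsets().toJavaRDD().collect()) {
  System.out.println(itemset.javaItems() + ", " + itemset.freq());
}
{% endhighlight %}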
-rw-r--r--  docs/mllib-clustering.md                                                                      8
-rw-r--r--  docs/mllib-frequent-pattern-mining.md                                                        12
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java               8
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java   5
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala                 4
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala     8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala        33
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala                               41
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java                        30
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala    8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala                          10
11 files changed, 93 insertions(+), 74 deletions(-)
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 6e46a47338..0b6db4fcb7 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -314,8 +314,8 @@ val pic = new PowerIterationClustering()
.setMaxIterations(20)
val model = pic.run(similarities)
-model.assignments.foreach { case (vertexId, clusterId) =>
- println(s"$vertexId -> $clusterId")
+model.assignments.foreach { a =>
+ println(s"${a.id} -> ${a.cluster}")
}
{% endhighlight %}
@@ -349,8 +349,8 @@ PowerIterationClustering pic = new PowerIterationClustering()
.setMaxIterations(10);
PowerIterationClusteringModel model = pic.run(similarities);
-for (Tuple2<Object, Object> assignment: model.assignments().toJavaRDD().collect()) {
- System.out.println(assignment._1() + " -> " + assignment._2());
+for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) {
+ System.out.println(a.id() + " -> " + a.cluster());
}
{% endhighlight %}
</div>
diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md
index 0ff9738768..9fd9be0dd0 100644
--- a/docs/mllib-frequent-pattern-mining.md
+++ b/docs/mllib-frequent-pattern-mining.md
@@ -57,8 +57,8 @@ val fpg = new FPGrowth()
.setNumPartitions(10)
val model = fpg.run(transactions)
-model.freqItemsets.collect().foreach { case (itemset, freq) =>
- println(itemset.mkString("[", ",", "]") + ", " + freq)
+model.freqItemsets.collect().foreach { itemset =>
+ println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
}
{% endhighlight %}
@@ -74,10 +74,9 @@ Calling `FPGrowth.run` with transactions returns an
that stores the frequent itemsets with their frequencies.
{% highlight java %}
-import java.util.Arrays;
import java.util.List;
-import scala.Tuple2;
+import com.google.common.base.Joiner;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.fpm.FPGrowth;
@@ -88,11 +87,10 @@ JavaRDD<List<String>> transactions = ...
FPGrowth fpg = new FPGrowth()
.setMinSupport(0.2)
.setNumPartitions(10);
-
FPGrowthModel<String> model = fpg.run(transactions);
-for (Tuple2<Object, Long> s: model.javaFreqItemsets().collect()) {
- System.out.println("(" + Arrays.toString((Object[]) s._1()) + "): " + s._2());
+for (FPGrowth.FreqItemset<String> itemset: model.freqItemsets().toJavaRDD().collect()) {
+ System.out.println("[" + Joiner.on(",").join(itemset.javaItems()) + "], " + itemset.freq());
}
{% endhighlight %}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java
index 0db572d760..f50e802cf6 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java
@@ -18,10 +18,8 @@
package org.apache.spark.examples.mllib;
import java.util.ArrayList;
-import java.util.Arrays;
-
-import scala.Tuple2;
+import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import org.apache.spark.SparkConf;
@@ -54,8 +52,8 @@ public class JavaFPGrowthExample {
.setMinSupport(0.3);
FPGrowthModel<String> model = fpg.run(transactions);
- for (Tuple2<Object, Long> s: model.javaFreqItemsets().collect()) {
- System.out.println(Arrays.toString((Object[]) s._1()) + ", " + s._2());
+ for (FPGrowth.FreqItemset<String> s: model.freqItemsets().toJavaRDD().collect()) {
+ System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq());
}
sc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java
index e9371de39f..6c6f9768f0 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java
@@ -17,7 +17,6 @@
package org.apache.spark.examples.mllib;
-import scala.Tuple2;
import scala.Tuple3;
import com.google.common.collect.Lists;
@@ -49,8 +48,8 @@ public class JavaPowerIterationClusteringExample {
.setMaxIterations(10);
PowerIterationClusteringModel model = pic.run(similarities);
- for (Tuple2<Object, Object> assignment: model.assignments().toJavaRDD().collect()) {
- System.out.println(assignment._1() + " -> " + assignment._2());
+ for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) {
+ System.out.println(a.id() + " -> " + a.cluster());
}
sc.stop();
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala
index ae66107d70..aaae275ec5 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala
@@ -42,8 +42,8 @@ object FPGrowthExample {
.setMinSupport(0.3)
val model = fpg.run(transactions)
- model.freqItemsets.collect().foreach { case (itemset, freq) =>
- println(itemset.mkString("[", ",", "]") + ", " + freq)
+ model.freqItemsets.collect().foreach { itemset =>
+ println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
}
sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
index b2373adba1..91c9772744 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
@@ -44,8 +44,7 @@ import org.apache.spark.{SparkConf, SparkContext}
*
* Here is a sample run and output:
*
- * ./bin/run-example mllib.PowerIterationClusteringExample
- * -k 3 --n 30 --maxIterations 15
+ * ./bin/run-example mllib.PowerIterationClusteringExample -k 3 --n 30 --maxIterations 15
*
* Cluster assignments: 1 -> [0,1,2,3,4],2 -> [5,6,7,8,9,10,11,12,13,14],
* 0 -> [15,16,17,18,19,20,21,22,23,24,25,26,27,28,29]
@@ -103,7 +102,7 @@ object PowerIterationClusteringExample {
.setMaxIterations(params.maxIterations)
.run(circlesRdd)
- val clusters = model.assignments.collect.groupBy(_._2).mapValues(_.map(_._1))
+ val clusters = model.assignments.collect().groupBy(_.cluster).mapValues(_.map(_.id))
val assignments = clusters.toList.sortBy { case (k, v) => v.length}
val assignmentsStr = assignments
.map { case (k, v) =>
@@ -153,8 +152,5 @@ object PowerIterationClusteringExample {
val expCoeff = -1.0 / 2.0 * math.pow(sigma, 2.0)
val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2)
coeff * math.exp(expCoeff * ssquares)
- // math.exp((p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2))
}
-
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 63d03347f4..180023922a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -17,9 +17,9 @@
package org.apache.spark.mllib.clustering
-import org.apache.spark.api.java.JavaRDD
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.linalg.Vectors
@@ -33,12 +33,12 @@ import org.apache.spark.util.random.XORShiftRandom
* Model produced by [[PowerIterationClustering]].
*
* @param k number of clusters
- * @param assignments an RDD of (vertexID, clusterID) pairs
+ * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s
*/
@Experimental
class PowerIterationClusteringModel(
val k: Int,
- val assignments: RDD[(Long, Int)]) extends Serializable
+ val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable
/**
* :: Experimental ::
@@ -133,16 +133,33 @@ class PowerIterationClustering private[clustering] (
*/
private def pic(w: Graph[Double, Double]): PowerIterationClusteringModel = {
val v = powerIter(w, maxIterations)
- val assignments = kMeans(v, k)
+ val assignments = kMeans(v, k).mapPartitions({ iter =>
+ iter.map { case (id, cluster) =>
+ new Assignment(id, cluster)
+ }
+ }, preservesPartitioning = true)
new PowerIterationClusteringModel(k, assignments)
}
}
-private[clustering] object PowerIterationClustering extends Logging {
+@Experimental
+object PowerIterationClustering extends Logging {
+
+ /**
+ * :: Experimental ::
+ * Cluster assignment.
+ * @param id node id
+ * @param cluster assigned cluster id
+ */
+ @Experimental
+ class Assignment(val id: Long, val cluster: Int) extends Serializable
+
/**
* Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
*/
- def normalize(similarities: RDD[(Long, Long, Double)]): Graph[Double, Double] = {
+ private[clustering]
+ def normalize(similarities: RDD[(Long, Long, Double)])
+ : Graph[Double, Double] = {
val edges = similarities.flatMap { case (i, j, s) =>
if (s < 0.0) {
throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.")
@@ -173,6 +190,7 @@ private[clustering] object PowerIterationClustering extends Logging {
* @return a graph with edges representing W and vertices representing a random vector
* with unit 1-norm
*/
+ private[clustering]
def randomInit(g: Graph[Double, Double]): Graph[Double, Double] = {
val r = g.vertices.mapPartitionsWithIndex(
(part, iter) => {
@@ -194,6 +212,7 @@ private[clustering] object PowerIterationClustering extends Logging {
* @param g a graph representing the normalized affinity matrix (W)
* @return a graph with edges representing W and vertices representing the degree vector
*/
+ private[clustering]
def initDegreeVector(g: Graph[Double, Double]): Graph[Double, Double] = {
val sum = g.vertices.values.sum()
val v0 = g.vertices.mapValues(_ / sum)
@@ -207,6 +226,7 @@ private[clustering] object PowerIterationClustering extends Logging {
* @param maxIterations maximum number of iterations
* @return a [[VertexRDD]] representing the pseudo-eigenvector
*/
+ private[clustering]
def powerIter(
g: Graph[Double, Double],
maxIterations: Int): VertexRDD[Double] = {
@@ -246,6 +266,7 @@ private[clustering] object PowerIterationClustering extends Logging {
* @param k number of clusters
* @return a [[VertexRDD]] representing the clustering assignments
*/
+ private[clustering]
def kMeans(v: VertexRDD[Double], k: Int): VertexRDD[Int] = {
val points = v.mapValues(x => Vectors.dense(x)).cache()
val model = new KMeans()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 3168d608c9..efa8459d3c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -26,8 +26,9 @@ import scala.reflect.ClassTag
import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -35,18 +36,11 @@ import org.apache.spark.storage.StorageLevel
* :: Experimental ::
*
* Model trained by [[FPGrowth]], which holds frequent itemsets.
- * @param freqItemsets frequent itemset, which is an RDD of (itemset, frequency) pairs
+ * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]]
* @tparam Item item type
*/
@Experimental
-class FPGrowthModel[Item: ClassTag](
- val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
-
- /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */
- def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = {
- JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]]
- }
-}
+class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable
/**
* :: Experimental ::
@@ -151,7 +145,7 @@ class FPGrowth private (
data: RDD[Array[Item]],
minCount: Long,
freqItems: Array[Item],
- partitioner: Partitioner): RDD[(Array[Item], Long)] = {
+ partitioner: Partitioner): RDD[FreqItemset[Item]] = {
val itemToRank = freqItems.zipWithIndex.toMap
data.flatMap { transaction =>
genCondTransactions(transaction, itemToRank, partitioner)
@@ -161,7 +155,7 @@ class FPGrowth private (
.flatMap { case (part, tree) =>
tree.extract(minCount, x => partitioner.getPartition(x) == part)
}.map { case (ranks, count) =>
- (ranks.map(i => freqItems(i)).toArray, count)
+ new FreqItemset(ranks.map(i => freqItems(i)).toArray, count)
}
}
@@ -193,3 +187,26 @@ class FPGrowth private (
output
}
}
+
+/**
+ * :: Experimental ::
+ */
+@Experimental
+object FPGrowth {
+
+ /**
+ * Frequent itemset.
+ * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead.
+ * @param freq frequency
+ * @tparam Item item type
+ */
+ class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable {
+
+ /**
+ * Returns items in a Java List.
+ */
+ def javaItems: java.util.List[Item] = {
+ items.toList.asJava
+ }
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 851707c8a1..bd0edf2b9e 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.fpm;
import java.io.Serializable;
import java.util.ArrayList;
+import java.util.List;
import org.junit.After;
import org.junit.Before;
@@ -28,6 +29,7 @@ import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
public class JavaFPGrowthSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -55,30 +57,18 @@ public class JavaFPGrowthSuite implements Serializable {
Lists.newArrayList("z".split(" ")),
Lists.newArrayList("x z y r q t p".split(" "))), 2);
- FPGrowth fpg = new FPGrowth();
-
- FPGrowthModel<String> model6 = fpg
- .setMinSupport(0.9)
- .setNumPartitions(1)
- .run(rdd);
- assertEquals(0, model6.javaFreqItemsets().count());
-
- FPGrowthModel<String> model3 = fpg
+ FPGrowthModel<String> model = new FPGrowth()
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd);
- assertEquals(18, model3.javaFreqItemsets().count());
- FPGrowthModel<String> model2 = fpg
- .setMinSupport(0.3)
- .setNumPartitions(4)
- .run(rdd);
- assertEquals(54, model2.javaFreqItemsets().count());
+ List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
+ assertEquals(18, freqItemsets.size());
- FPGrowthModel<String> model1 = fpg
- .setMinSupport(0.1)
- .setNumPartitions(8)
- .run(rdd);
- assertEquals(625, model1.javaFreqItemsets().count());
+ for (FreqItemset<String> itemset: freqItemsets) {
+ // Test return types.
+ List<String> items = itemset.javaItems();
+ long freq = itemset.freq();
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
index 03ecd9ca73..6315c03a70 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -51,8 +51,8 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
.setK(2)
.run(sc.parallelize(similarities, 2))
val predictions = Array.fill(2)(mutable.Set.empty[Long])
- model.assignments.collect().foreach { case (i, c) =>
- predictions(c) += i
+ model.assignments.collect().foreach { a =>
+ predictions(a.cluster) += a.id
}
assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
@@ -61,8 +61,8 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
.setInitializationMode("degree")
.run(sc.parallelize(similarities, 2))
val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
- model2.assignments.collect().foreach { case (i, c) =>
- predictions2(c) += i
+ model2.assignments.collect().foreach { a =>
+ predictions2(a.cluster) += a.id
}
assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 68128284b8..bd5b9cc3af 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -46,8 +46,8 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd)
- val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
- (items.toSet, count)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
}
val expected = Set(
(Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
@@ -96,10 +96,10 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd)
- assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass,
+ assert(model3.freqItemsets.first().items.getClass === Array(1).getClass,
"frequent itemsets should use primitive arrays")
- val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
- (items.toSet, count)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
}
val expected = Set(
(Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),