diff options
author | Andrew Ash <andrew@andrewash.com> | 2014-02-06 22:38:36 -0800 |
---|---|---|
committer | Patrick Wendell <pwendell@gmail.com> | 2014-02-06 22:39:08 -0800 |
commit | 3a9d82cc9e85accb5c1577cf4718aa44c8d5038c (patch) | |
tree | 6bbb87e3997b4f90b6e1927405104d360e78a4e7 | |
parent | 1896c6e7c9f5c29284a045128b4aca0d5a6e7220 (diff) | |
download | spark-3a9d82cc9e85accb5c1577cf4718aa44c8d5038c.tar.gz spark-3a9d82cc9e85accb5c1577cf4718aa44c8d5038c.tar.bz2 spark-3a9d82cc9e85accb5c1577cf4718aa44c8d5038c.zip |
Merge pull request #506 from ash211/intersection. Closes #506.
SPARK-1062 Add rdd.intersection(otherRdd) method
Author: Andrew Ash <andrew@andrewash.com>
== Merge branch commits ==
commit 5d9982b171b9572649e9828f37ef0b43f0242912
Author: Andrew Ash <andrew@andrewash.com>
Date: Thu Feb 6 18:11:45 2014 -0800
Minor fixes
- style: (v,null) => (v, null)
- mention the shuffle in Javadoc
commit b86d02f14e810902719cef893cf6bfa18ff9acb0
Author: Andrew Ash <andrew@andrewash.com>
Date: Sun Feb 2 13:17:40 2014 -0800
Overload .intersection() for numPartitions and custom Partitioner
commit bcaa34911fcc6bb5bc5e4f9fe46d1df73cb71c09
Author: Andrew Ash <andrew@andrewash.com>
Date: Sun Feb 2 13:05:40 2014 -0800
Better naming of parameters in intersection's filter
commit b10a6af2d793ec6e9a06c798007fac3f6b860d89
Author: Andrew Ash <andrew@andrewash.com>
Date: Sat Jan 25 23:06:26 2014 -0800
Follow spark code format conventions of tab => 2 spaces
commit 965256e4304cca514bb36a1a36087711dec535ec
Author: Andrew Ash <andrew@andrewash.com>
Date: Fri Jan 24 00:28:01 2014 -0800
Add rdd.intersection(otherRdd) method
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/RDD.scala | 37
-rw-r--r-- | core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 23
2 files changed, 58 insertions(+), 2 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 033d334079..8010bb68e3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -394,6 +394,43 @@ abstract class RDD[T: ClassTag]( def ++(other: RDD[T]): RDD[T] = this.union(other) /** + * Return the intersection of this RDD and another one. The output will not contain any duplicate + * elements, even if the input RDDs did. + * + * Note that this method performs a shuffle internally. + */ + def intersection(other: RDD[T]): RDD[T] = + this.map(v => (v, null)).cogroup(other.map(v => (v, null))) + .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty } + .keys + + /** + * Return the intersection of this RDD and another one. The output will not contain any duplicate + * elements, even if the input RDDs did. + * + * Note that this method performs a shuffle internally. + * + * @param partitioner Partitioner to use for the resulting RDD + */ + def intersection(other: RDD[T], partitioner: Partitioner): RDD[T] = + this.map(v => (v, null)).cogroup(other.map(v => (v, null)), partitioner) + .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty } + .keys + + /** + * Return the intersection of this RDD and another one. The output will not contain any duplicate + * elements, even if the input RDDs did. Performs a hash partition across the cluster + * + * Note that this method performs a shuffle internally.
+ * + * @param numPartitions How many partitions to use in the resulting RDD + */ + def intersection(other: RDD[T], numPartitions: Int): RDD[T] = + this.map(v => (v, null)).cogroup(other.map(v => (v, null)), new HashPartitioner(numPartitions)) + .filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty } + .keys + + /** * Return an RDD created by coalescing all elements within each partition into an array. */ def glom(): RDD[Array[T]] = new GlommedRDD(this) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 223ebec5fa..879c4e5f17 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -373,8 +373,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { val prng42 = new Random(42) val prng43 = new Random(43) Array(1, 2, 3, 4, 5, 6).filter{i => - if (i < 4) 0 == prng42.nextInt(3) - else 0 == prng43.nextInt(3)} + if (i < 4) 0 == prng42.nextInt(3) + else 0 == prng43.nextInt(3)} } assert(sample.size === checkSample.size) for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) @@ -506,4 +506,23 @@ class RDDSuite extends FunSuite with SharedSparkContext { sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) } } + + test("intersection") { + val all = sc.parallelize(1 to 10) + val evens = sc.parallelize(2 to 10 by 2) + val intersection = Array(2, 4, 6, 8, 10) + + // intersection is commutative + assert(all.intersection(evens).collect.sorted === intersection) + assert(evens.intersection(all).collect.sorted === intersection) + } + + test("intersection strips duplicates in an input") { + val a = sc.parallelize(Seq(1,2,3,3)) + val b = sc.parallelize(Seq(1,1,2,3)) + val intersection = Array(1,2,3) + + assert(a.intersection(b).collect.sorted === intersection) + assert(b.intersection(a).collect.sorted === intersection) + } } |