aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Andrew Ash <andrew@andrewash.com> 2014-06-17 11:47:48 -0700
committer: Reynold Xin <rxin@apache.org> 2014-06-17 11:47:48 -0700
commit: b92d16b114fd49e881d09e7974ad57b2a0df2906 (patch)
tree: 4bede6fbb3f5c230bc545a7464d0c1805b199b08
parent: e243c5ffacd70ecadaf5c91668955dcc8141e060 (diff)
download: spark-b92d16b114fd49e881d09e7974ad57b2a0df2906.tar.gz
spark-b92d16b114fd49e881d09e7974ad57b2a0df2906.tar.bz2
spark-b92d16b114fd49e881d09e7974ad57b2a0df2906.zip
SPARK-1063 Add .sortBy(f) method on RDD
This never got merged from the apache/incubator-spark repo (which is now deleted) but there had been several rounds of code review on this PR there. I think this is ready for merging.

Author: Andrew Ash <andrew@andrewash.com>

This patch had conflicts when merged, resolved by
Committer: Reynold Xin <rxin@apache.org>

Closes #369 from ash211/sortby and squashes the following commits:

d09147a [Andrew Ash] Fix Ordering import
43d0a53 [Andrew Ash] Fix missing .collect()
29a54ed [Andrew Ash] Re-enable test by converting to a closure
5a95348 [Andrew Ash] Add license for RDDSuiteUtils
64ed6e3 [Andrew Ash] Remove leaked diff
d4de69a [Andrew Ash] Remove scar tissue
63638b5 [Andrew Ash] Add Python version of .sortBy()
45e0fde [Andrew Ash] Add Java version of .sortBy()
adf84c5 [Andrew Ash] Re-indent to keep line lengths under 100 chars
9d9b9d8 [Andrew Ash] Use parentheses on .collect() calls
0457b69 [Andrew Ash] Ignore failing test
99f0baf [Andrew Ash] Merge branch 'master' into sortby
222ae97 [Andrew Ash] Try moving Ordering objects out to a different class
3fd0dd3 [Andrew Ash] Add (failing) test for sortByKey with explicit Ordering
b8b5bbc [Andrew Ash] Align remove extra spaces that were used to align ='s in test code
8c53298 [Andrew Ash] Actually use ascending and numPartitions parameters
381eef2 [Andrew Ash] Correct silly typo
7db3e84 [Andrew Ash] Support ascending and numPartitions params in sortBy()
0f685fd [Andrew Ash] Merge remote-tracking branch 'origin/master' into sortby
ca4490d [Andrew Ash] Add .sortBy(f) method on RDD
-rw-r--r-- core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala | 16
-rw-r--r-- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 12
-rw-r--r-- core/src/test/java/org/apache/spark/JavaAPISuite.java | 33
-rw-r--r-- core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 59
-rw-r--r-- core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala | 31
-rw-r--r-- python/pyspark/rdd.py | 12
6 files changed, 159 insertions, 4 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 23d1371079..86fb374bef 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -17,10 +17,13 @@
package org.apache.spark.api.java
+import java.util.Comparator
+
import scala.language.implicitConversions
import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -172,6 +175,19 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
rdd.setName(name)
this
}
+
+  /**
+   * Return this RDD sorted by the key that `f` extracts from each element. Keys are compared
+   * with Guava's natural ordering, so `S` is presumably required to be `Comparable`.
+   */
+  def sortBy[S](f: JFunction[T, S], ascending: Boolean, numPartitions: Int): JavaRDD[T] = {
+    val fn = (x: T) => f.call(x)
+    import com.google.common.collect.Ordering  // shadows scala.math.Ordering
+    implicit val ordering: Ordering[S] = Ordering.natural().asInstanceOf[Ordering[S]]
+    implicit val ctag: ClassTag[S] = fakeClassTag
+    wrapRDD(rdd.sortBy(fn, ascending, numPartitions))
+  }
+
}
object JavaRDD {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 27cc60d775..cf915b870e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -443,6 +443,18 @@ abstract class RDD[T: ClassTag](
def ++(other: RDD[T]): RDD[T] = this.union(other)
/**
+   * Return this RDD sorted by the key that `f` computes for each element. The elements are
+   * keyed with `keyBy`, sorted with `sortByKey` (which receives `ascending` and
+   * `numPartitions` unchanged), and the keys are then dropped again.
+   */
+  def sortBy[K](
+      f: (T) => K,
+      ascending: Boolean = true,
+      numPartitions: Int = this.partitions.size)
+      (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] =
+    this.keyBy[K](f).sortByKey(ascending, numPartitions).values
+
+ /**
* Return the intersection of this RDD and another one. The output will not contain any duplicate
* elements, even if the input RDDs did.
*
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index ef41bfb88d..e46298c6a9 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -181,6 +181,39 @@ public class JavaAPISuite implements Serializable {
}
@Test
+  public void sortBy() {
+    // Key extractors, one per tuple field, shared by the two sort passes below.
+    Function<Tuple2<Integer, Integer>, Integer> byFirst =
+      new Function<Tuple2<Integer, Integer>, Integer>() {
+        public Integer call(Tuple2<Integer, Integer> pair) throws Exception { return pair._1(); }
+      };
+    Function<Tuple2<Integer, Integer>, Integer> bySecond =
+      new Function<Tuple2<Integer, Integer>, Integer>() {
+        public Integer call(Tuple2<Integer, Integer> pair) throws Exception { return pair._2(); }
+      };
+
+    List<Tuple2<Integer, Integer>> input = new ArrayList<Tuple2<Integer, Integer>>();
+    input.add(new Tuple2<Integer, Integer>(0, 4));
+    input.add(new Tuple2<Integer, Integer>(3, 2));
+    input.add(new Tuple2<Integer, Integer>(-1, 1));
+    JavaRDD<Tuple2<Integer, Integer>> rdd = sc.parallelize(input);
+
+    // Ascending on the first field, requesting 2 partitions: expect keys -1, 0, 3.
+    JavaRDD<Tuple2<Integer, Integer>> byFirstRDD = rdd.sortBy(byFirst, true, 2);
+    Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), byFirstRDD.first());
+    List<Tuple2<Integer, Integer>> firstSorted = byFirstRDD.collect();
+    Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), firstSorted.get(1));
+    Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), firstSorted.get(2));
+
+    // Ascending on the second field: expect keys 1, 2, 4.
+    JavaRDD<Tuple2<Integer, Integer>> bySecondRDD = rdd.sortBy(bySecond, true, 2);
+    Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), bySecondRDD.first());
+    List<Tuple2<Integer, Integer>> secondSorted = bySecondRDD.collect();
+    Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), secondSorted.get(1));
+    Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), secondSorted.get(2));
+  }
+
+ @Test
public void foreach() {
final Accumulator<Integer> accum = sc.accumulator(0);
JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello", "World"));
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e94a1e76d4..0e5625b764 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -26,6 +26,8 @@ import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.util.Utils
+import org.apache.spark.rdd.RDDSuiteUtils._
+
class RDDSuite extends FunSuite with SharedSparkContext {
test("basic operations") {
@@ -585,14 +587,63 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}
+  test("sortBy") {
+    // Each row is "col0|col1|col2"; sort on every column in turn, comparing keys as strings.
+    val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))
+    val col1 = Array("4|60|C", "5|50|A", "6|40|B")
+    val col2 = Array("6|40|B", "5|50|A", "4|60|C")
+    val col3 = Array("5|50|A", "6|40|B", "4|60|C")
+
+    assert(data.sortBy(_.split("\\|")(0)).collect() === col1)
+    assert(data.sortBy(_.split("\\|")(1)).collect() === col2)
+    assert(data.sortBy(_.split("\\|")(2)).collect() === col3)
+  }
+
+  test("sortBy ascending parameter") {
+    // Same rows sorted on the first column, once in each direction.
+    val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))
+    val asc = Array("4|60|C", "5|50|A", "6|40|B")
+    val desc = Array("6|40|B", "5|50|A", "4|60|C")
+
+    assert(data.sortBy(_.split("\\|")(0), ascending = true).collect() === asc)
+    assert(data.sortBy(_.split("\\|")(0), ascending = false).collect() === desc)
+  }
+
+  test("sortBy with explicit ordering") {
+    val data = sc.parallelize(Seq("Bob|Smith|50",
+                                  "Jane|Smith|40",
+                                  "Thomas|Williams|30",
+                                  "Karen|Williams|60"))
+
+    val ageOrdered = Array("Thomas|Williams|30",
+                           "Jane|Smith|40",
+                           "Bob|Smith|50",
+                           "Karen|Williams|60")
+
+    // last name, then first name
+    val nameOrdered = Array("Bob|Smith|50",
+                            "Jane|Smith|40",
+                            "Karen|Williams|60",
+                            "Thomas|Williams|30")
+    // Key function: parse a "first|last|age" row into a Person.
+    val parse = (s: String) => {
+      val split = s.split("\\|")
+      Person(split(0), split(1), split(2).toInt)
+    }
+
+    import scala.reflect.classTag
+    assert(data.sortBy(parse, true, 2)(AgeOrdering, classTag[Person]).collect() === ageOrdered)
+    assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered)
+  }
+
test("intersection") {
val all = sc.parallelize(1 to 10)
val evens = sc.parallelize(2 to 10 by 2)
val intersection = Array(2, 4, 6, 8, 10)
// intersection is commutative
- assert(all.intersection(evens).collect.sorted === intersection)
- assert(evens.intersection(all).collect.sorted === intersection)
+ assert(all.intersection(evens).collect().sorted === intersection)
+ assert(evens.intersection(all).collect().sorted === intersection)
}
test("intersection strips duplicates in an input") {
@@ -600,8 +651,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
val b = sc.parallelize(Seq(1,1,2,3))
val intersection = Array(1,2,3)
- assert(a.intersection(b).collect.sorted === intersection)
- assert(b.intersection(a).collect.sorted === intersection)
+ assert(a.intersection(b).collect().sorted === intersection)
+ assert(b.intersection(a).collect().sorted === intersection)
}
test("zipWithIndex") {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala
new file mode 100644
index 0000000000..4762fc1785
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+object RDDSuiteUtils {
+  case class Person(first: String, last: String, age: Int)
+
+  object AgeOrdering extends Ordering[Person] { // compares by age alone
+    def compare(a: Person, b: Person): Int = Ordering.Int.compare(a.age, b.age)
+  }
+
+  object NameOrdering extends Ordering[Person] { // compares by last name, then first name
+    def compare(a: Person, b: Person): Int =
+      Ordering.Tuple2[String, String].compare((a.last, a.first), (b.last, b.first))
+  }
+}
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index bb4d035edc..65f63153cd 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -549,6 +549,18 @@ class RDD(object):
.mapPartitions(mapFunc,preservesPartitioning=True)
.flatMap(lambda x: x, preservesPartitioning=True))
+    def sortBy(self, keyfunc, ascending=True, numPartitions=None):
+        """
+        Return this RDD's elements ordered by the key that keyfunc computes for each one.
+
+        >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortBy(lambda x: x[0]).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
+        >>> sc.parallelize(tmp).sortBy(lambda x: x[1]).collect()
+        [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        """
+        keyed = self.keyBy(keyfunc)
+        return keyed.sortByKey(ascending, numPartitions).values()
+
def glom(self):
"""
Return an RDD created by coalescing all elements within each partition