Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala |  22
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java       | 193
2 files changed, 125 insertions(+), 90 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index dc698dea75..23d1371079 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -108,6 +108,28 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] =
wrapRDD(rdd.sample(withReplacement, fraction, seed))
+
+ /**
+ * Randomly splits this RDD with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1
+ *
+ * @return split RDDs in an array
+ */
+ def randomSplit(weights: Array[Double]): Array[JavaRDD[T]] =
+ randomSplit(weights, Utils.random.nextLong)
+
+ /**
+ * Randomly splits this RDD with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1
+ * @param seed random seed
+ *
+ * @return split RDDs in an array
+ */
+ def randomSplit(weights: Array[Double], seed: Long): Array[JavaRDD[T]] =
+ rdd.randomSplit(weights, seed).map(wrapRDD)
+
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
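
Editorial aside, not part of the patch: a minimal usage sketch of the new Java `randomSplit` API added above. The class name, local context setup, and the 70/30 weights are illustrative assumptions. Weights that do not sum to 1 are normalized, e.g. {0.4, 0.6, 1.0} behaves like {0.2, 0.3, 0.5}.

// Illustrative sketch only -- not part of this diff.
import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class RandomSplitExample {
  public static void main(String[] args) {
    // Hypothetical local context, for demonstration only
    JavaSparkContext jsc = new JavaSparkContext("local", "RandomSplitExample");
    JavaRDD<Integer> data = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));

    // Split roughly 70/30 with a fixed seed so the result is reproducible
    JavaRDD<Integer>[] parts = data.randomSplit(new double[] {0.7, 0.3}, 42L);
    JavaRDD<Integer> training = parts[0];
    JavaRDD<Integer> test = parts[1];

    System.out.println(training.count() + " / " + test.count());
    jsc.stop();
  }
}
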
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index b78309f81c..50a6212911 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -23,6 +23,7 @@ import java.util.*;
import scala.Tuple2;
import com.google.common.collect.Iterables;
+import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.base.Optional;
import com.google.common.base.Charsets;
@@ -48,7 +49,6 @@ import org.apache.spark.partial.BoundedDouble;
import org.apache.spark.partial.PartialResult;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.util.StatCounter;
-import org.apache.spark.util.Utils;
// The test suite itself is Serializable so that anonymous Function implementations can be
// serialized, as an alternative to converting these anonymous classes to static inner classes;
@@ -70,16 +70,6 @@ public class JavaAPISuite implements Serializable {
sc = null;
}
- static class ReverseIntComparator implements Comparator<Integer>, Serializable {
-
- @Override
- public int compare(Integer a, Integer b) {
- if (a > b) return -1;
- else if (a < b) return 1;
- else return 0;
- }
- }
-
@SuppressWarnings("unchecked")
@Test
public void sparkContextUnion() {
@@ -124,7 +114,7 @@ public class JavaAPISuite implements Serializable {
JavaRDD<Integer> intersections = s1.intersection(s2);
Assert.assertEquals(3, intersections.count());
- ArrayList<Integer> list = new ArrayList<Integer>();
+ List<Integer> list = new ArrayList<Integer>();
JavaRDD<Integer> empty = sc.parallelize(list);
JavaRDD<Integer> emptyIntersection = empty.intersection(s2);
Assert.assertEquals(0, emptyIntersection.count());
@@ -145,6 +135,28 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void sample() {
+ List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
+ JavaRDD<Integer> rdd = sc.parallelize(ints);
+ JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 11);
+ // expected 2 but of course result varies randomly a bit
+ Assert.assertEquals(3, sample20.count());
+ JavaRDD<Integer> sample20NoReplacement = rdd.sample(false, 0.2, 11);
+ Assert.assertEquals(2, sample20NoReplacement.count());
+ }
+
+ @Test
+ public void randomSplit() {
+ List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
+ JavaRDD<Integer> rdd = sc.parallelize(ints);
+ JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11);
+ Assert.assertEquals(3, splits.length);
+ Assert.assertEquals(2, splits[0].count());
+ Assert.assertEquals(3, splits[1].count());
+ Assert.assertEquals(5, splits[2].count());
+ }
+
+ @Test
public void sortByKey() {
List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
pairs.add(new Tuple2<Integer, Integer>(0, 4));
@@ -161,26 +173,24 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
// Custom comparator
- sortedRDD = rdd.sortByKey(new ReverseIntComparator(), false);
+ sortedRDD = rdd.sortByKey(Collections.<Integer>reverseOrder(), false);
Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first());
sortedPairs = sortedRDD.collect();
Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(1));
Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
}
- static int foreachCalls = 0;
-
@Test
public void foreach() {
- foreachCalls = 0;
+ final Accumulator<Integer> accum = sc.accumulator(0);
JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello", "World"));
rdd.foreach(new VoidFunction<String>() {
@Override
- public void call(String s) {
- foreachCalls++;
+ public void call(String s) throws IOException {
+ accum.add(1);
}
});
- Assert.assertEquals(2, foreachCalls);
+ Assert.assertEquals(2, accum.value().intValue());
}
@Test
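
Editorial aside on the foreach change above, not part of the patch: a static counter mutated inside foreach() only appears to work because this suite runs in local mode; on a real cluster the closure executes in executor JVMs and the driver never sees the increments, whereas Accumulator updates are merged back to the driver when the action completes. A minimal sketch under those assumptions (class name and context setup are illustrative):

// Illustrative sketch only -- not part of this diff.
import java.util.Arrays;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.VoidFunction;

public class ForeachAccumulatorExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "ForeachAccumulatorExample");
    final Accumulator<Integer> count = sc.accumulator(0);
    sc.parallelize(Arrays.asList("Hello", "World")).foreach(new VoidFunction<String>() {
      @Override
      public void call(String s) {
        count.add(1);  // accumulated per task, merged on the driver after the job
      }
    });
    System.out.println(count.value());  // prints 2
    sc.stop();
  }
}
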
@@ -188,7 +198,7 @@ public class JavaAPISuite implements Serializable {
List<Integer> correct = Arrays.asList(1, 2, 3, 4);
JavaRDD<Integer> rdd = sc.parallelize(correct);
List<Integer> result = Lists.newArrayList(rdd.toLocalIterator());
- Assert.assertTrue(correct.equals(result));
+ Assert.assertEquals(correct, result);
}
@Test
@@ -196,7 +206,7 @@ public class JavaAPISuite implements Serializable {
List<Integer> dataArray = Arrays.asList(1, 2, 3, 4);
JavaPairRDD<Integer, Long> zip = sc.parallelize(dataArray).zipWithUniqueId();
JavaRDD<Long> indexes = zip.values();
- Assert.assertTrue(new HashSet<Long>(indexes.collect()).size() == 4);
+ Assert.assertEquals(4, new HashSet<Long>(indexes.collect()).size());
}
@Test
@@ -205,7 +215,7 @@ public class JavaAPISuite implements Serializable {
JavaPairRDD<Integer, Long> zip = sc.parallelize(dataArray).zipWithIndex();
JavaRDD<Long> indexes = zip.values();
List<Long> correctIndexes = Arrays.asList(0L, 1L, 2L, 3L);
- Assert.assertTrue(indexes.collect().equals(correctIndexes));
+ Assert.assertEquals(correctIndexes, indexes.collect());
}
@SuppressWarnings("unchecked")
@@ -252,8 +262,10 @@ public class JavaAPISuite implements Serializable {
new Tuple2<String, Integer>("Oranges", 2),
new Tuple2<String, Integer>("Apples", 3)
));
- JavaPairRDD<String, Tuple2<Iterable<String>, Iterable<Integer>>> cogrouped = categories.cogroup(prices);
- Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
+ JavaPairRDD<String, Tuple2<Iterable<String>, Iterable<Integer>>> cogrouped =
+ categories.cogroup(prices);
+ Assert.assertEquals("[Fruit, Citrus]",
+ Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
cogrouped.collect();
@@ -281,8 +293,7 @@ public class JavaAPISuite implements Serializable {
rdd1.leftOuterJoin(rdd2).filter(
new Function<Tuple2<Integer, Tuple2<Integer, Optional<Character>>>, Boolean>() {
@Override
- public Boolean call(Tuple2<Integer, Tuple2<Integer, Optional<Character>>> tup)
- throws Exception {
+ public Boolean call(Tuple2<Integer, Tuple2<Integer, Optional<Character>>> tup) {
return !tup._2()._2().isPresent();
}
}).first();
@@ -356,8 +367,7 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(2, localCounts.get(2).intValue());
Assert.assertEquals(3, localCounts.get(3).intValue());
- localCounts = rdd.reduceByKeyLocally(new Function2<Integer, Integer,
- Integer>() {
+ localCounts = rdd.reduceByKeyLocally(new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer a, Integer b) {
return a + b;
@@ -448,16 +458,17 @@ public class JavaAPISuite implements Serializable {
JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction<Integer>() {
@Override
public double call(Integer x) {
- return 1.0 * x;
+ return x.doubleValue();
}
}).cache();
doubles.collect();
- JavaPairRDD<Integer, Integer> pairs = rdd.mapToPair(new PairFunction<Integer, Integer, Integer>() {
- @Override
- public Tuple2<Integer, Integer> call(Integer x) {
- return new Tuple2<Integer, Integer>(x, x);
- }
- }).cache();
+ JavaPairRDD<Integer, Integer> pairs = rdd.mapToPair(
+ new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer x) {
+ return new Tuple2<Integer, Integer>(x, x);
+ }
+ }).cache();
pairs.collect();
JavaRDD<String> strings = rdd.map(new Function<Integer, String>() {
@Override
@@ -487,7 +498,9 @@ public class JavaAPISuite implements Serializable {
@Override
public Iterable<Tuple2<String, String>> call(String s) {
List<Tuple2<String, String>> pairs = new LinkedList<Tuple2<String, String>>();
- for (String word : s.split(" ")) pairs.add(new Tuple2<String, String>(word, word));
+ for (String word : s.split(" ")) {
+ pairs.add(new Tuple2<String, String>(word, word));
+ }
return pairs;
}
}
@@ -499,7 +512,9 @@ public class JavaAPISuite implements Serializable {
@Override
public Iterable<Double> call(String s) {
List<Double> lengths = new LinkedList<Double>();
- for (String word : s.split(" ")) lengths.add(word.length() * 1.0);
+ for (String word : s.split(" ")) {
+ lengths.add((double) word.length());
+ }
return lengths;
}
});
@@ -521,7 +536,7 @@ public class JavaAPISuite implements Serializable {
JavaPairRDD<String, Integer> swapped = pairRDD.flatMapToPair(
new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() {
@Override
- public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception {
+ public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) {
return Collections.singletonList(item.swap());
}
});
@@ -530,7 +545,7 @@ public class JavaAPISuite implements Serializable {
// There was never a bug here, but it's worth testing:
pairRDD.mapToPair(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
@Override
- public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception {
+ public Tuple2<String, Integer> call(Tuple2<Integer, String> item) {
return item.swap();
}
}).collect();
@@ -631,14 +646,10 @@ public class JavaAPISuite implements Serializable {
byte[] content2 = "spark is also easy to use.\n".getBytes("utf-8");
String tempDirName = tempDir.getAbsolutePath();
- DataOutputStream ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00000"));
- ds.write(content1);
- ds.close();
- ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00001"));
- ds.write(content2);
- ds.close();
-
- HashMap<String, String> container = new HashMap<String, String>();
+ Files.write(content1, new File(tempDirName + "/part-00000"));
+ Files.write(content2, new File(tempDirName + "/part-00001"));
+
+ Map<String, String> container = new HashMap<String, String>();
container.put(tempDirName+"/part-00000", new Text(content1).toString());
container.put(tempDirName+"/part-00001", new Text(content2).toString());
@@ -844,7 +855,7 @@ public class JavaAPISuite implements Serializable {
JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction<Integer>() {
@Override
public double call(Integer x) {
- return 1.0 * x;
+ return x.doubleValue();
}
});
JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles);
@@ -859,17 +870,7 @@ public class JavaAPISuite implements Serializable {
new FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer>() {
@Override
public Iterable<Integer> call(Iterator<Integer> i, Iterator<String> s) {
- int sizeI = 0;
- int sizeS = 0;
- while (i.hasNext()) {
- sizeI += 1;
- i.next();
- }
- while (s.hasNext()) {
- sizeS += 1;
- s.next();
- }
- return Arrays.asList(sizeI, sizeS);
+ return Arrays.asList(Iterators.size(i), Iterators.size(s));
}
};
@@ -883,6 +884,7 @@ public class JavaAPISuite implements Serializable {
final Accumulator<Integer> intAccum = sc.intAccumulator(10);
rdd.foreach(new VoidFunction<Integer>() {
+ @Override
public void call(Integer x) {
intAccum.add(x);
}
@@ -891,6 +893,7 @@ public class JavaAPISuite implements Serializable {
final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0);
rdd.foreach(new VoidFunction<Integer>() {
+ @Override
public void call(Integer x) {
doubleAccum.add((double) x);
}
@@ -899,14 +902,17 @@ public class JavaAPISuite implements Serializable {
// Try a custom accumulator type
AccumulatorParam<Float> floatAccumulatorParam = new AccumulatorParam<Float>() {
+ @Override
public Float addInPlace(Float r, Float t) {
return r + t;
}
+ @Override
public Float addAccumulator(Float r, Float t) {
return r + t;
}
+ @Override
public Float zero(Float initialValue) {
return 0.0f;
}
@@ -914,6 +920,7 @@ public class JavaAPISuite implements Serializable {
final Accumulator<Float> floatAccum = sc.accumulator(10.0f, floatAccumulatorParam);
rdd.foreach(new VoidFunction<Integer>() {
+ @Override
public void call(Integer x) {
floatAccum.add((float) x);
}
@@ -929,7 +936,8 @@ public class JavaAPISuite implements Serializable {
public void keyBy() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2));
List<Tuple2<String, Integer>> s = rdd.keyBy(new Function<Integer, String>() {
- public String call(Integer t) throws Exception {
+ @Override
+ public String call(Integer t) {
return t.toString();
}
}).collect();
@@ -941,10 +949,10 @@ public class JavaAPISuite implements Serializable {
public void checkpointAndComputation() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
sc.setCheckpointDir(tempDir.getAbsolutePath());
- Assert.assertEquals(false, rdd.isCheckpointed());
+ Assert.assertFalse(rdd.isCheckpointed());
rdd.checkpoint();
rdd.count(); // Forces the DAG to cause a checkpoint
- Assert.assertEquals(true, rdd.isCheckpointed());
+ Assert.assertTrue(rdd.isCheckpointed());
Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect());
}
@@ -952,10 +960,10 @@ public class JavaAPISuite implements Serializable {
public void checkpointAndRestore() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
sc.setCheckpointDir(tempDir.getAbsolutePath());
- Assert.assertEquals(false, rdd.isCheckpointed());
+ Assert.assertFalse(rdd.isCheckpointed());
rdd.checkpoint();
rdd.count(); // Forces the DAG to cause a checkpoint
- Assert.assertEquals(true, rdd.isCheckpointed());
+ Assert.assertTrue(rdd.isCheckpointed());
Assert.assertTrue(rdd.getCheckpointFile().isPresent());
JavaRDD<Integer> recovered = sc.checkpointFile(rdd.getCheckpointFile().get());
@@ -966,16 +974,17 @@ public class JavaAPISuite implements Serializable {
@Test
public void mapOnPairRDD() {
JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1,2,3,4));
- JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(new PairFunction<Integer, Integer, Integer>() {
- @Override
- public Tuple2<Integer, Integer> call(Integer i) throws Exception {
- return new Tuple2<Integer, Integer>(i, i % 2);
- }
- });
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(
+ new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) {
+ return new Tuple2<Integer, Integer>(i, i % 2);
+ }
+ });
JavaPairRDD<Integer, Integer> rdd3 = rdd2.mapToPair(
new PairFunction<Tuple2<Integer, Integer>, Integer, Integer>() {
@Override
- public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> in) throws Exception {
+ public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> in) {
return new Tuple2<Integer, Integer>(in._2(), in._1());
}
});
@@ -992,14 +1001,15 @@ public class JavaAPISuite implements Serializable {
public void collectPartitions() {
JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
- JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(new PairFunction<Integer, Integer, Integer>() {
- @Override
- public Tuple2<Integer, Integer> call(Integer i) throws Exception {
- return new Tuple2<Integer, Integer>(i, i % 2);
- }
- });
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(
+ new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) {
+ return new Tuple2<Integer, Integer>(i, i % 2);
+ }
+ });
- List[] parts = rdd1.collectPartitions(new int[] {0});
+ List<Integer>[] parts = rdd1.collectPartitions(new int[] {0});
Assert.assertEquals(Arrays.asList(1, 2), parts[0]);
parts = rdd1.collectPartitions(new int[] {1, 2});
@@ -1010,14 +1020,14 @@ public class JavaAPISuite implements Serializable {
new Tuple2<Integer, Integer>(2, 0)),
rdd2.collectPartitions(new int[] {0})[0]);
- parts = rdd2.collectPartitions(new int[] {1, 2});
+ List<Tuple2<Integer,Integer>>[] parts2 = rdd2.collectPartitions(new int[] {1, 2});
Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
new Tuple2<Integer, Integer>(4, 0)),
- parts[0]);
+ parts2[0]);
Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
new Tuple2<Integer, Integer>(6, 0),
new Tuple2<Integer, Integer>(7, 1)),
- parts[1]);
+ parts2[1]);
}
@Test
@@ -1034,10 +1044,12 @@ public class JavaAPISuite implements Serializable {
@Test
public void countApproxDistinctByKey() {
List<Tuple2<Integer, Integer>> arrayData = new ArrayList<Tuple2<Integer, Integer>>();
- for (int i = 10; i < 100; i++)
- for (int j = 0; j < i; j++)
+ for (int i = 10; i < 100; i++) {
+ for (int j = 0; j < i; j++) {
arrayData.add(new Tuple2<Integer, Integer>(i, j));
-
+ }
+ }
+ double relativeSD = 0.001;
JavaPairRDD<Integer, Integer> pairRdd = sc.parallelizePairs(arrayData);
List<Tuple2<Integer, Object>> res = pairRdd.countApproxDistinctByKey(8, 0).collect();
for (Tuple2<Integer, Object> resItem : res) {
@@ -1053,12 +1065,13 @@ public class JavaAPISuite implements Serializable {
public void collectAsMapWithIntArrayValues() {
// Regression test for SPARK-1040
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1));
- JavaPairRDD<Integer, int[]> pairRDD = rdd.mapToPair(new PairFunction<Integer, Integer, int[]>() {
- @Override
- public Tuple2<Integer, int[]> call(Integer x) throws Exception {
- return new Tuple2<Integer, int[]>(x, new int[] { x });
- }
- });
+ JavaPairRDD<Integer, int[]> pairRDD = rdd.mapToPair(
+ new PairFunction<Integer, Integer, int[]>() {
+ @Override
+ public Tuple2<Integer, int[]> call(Integer x) {
+ return new Tuple2<Integer, int[]>(x, new int[] { x });
+ }
+ });
pairRDD.collect(); // Works fine
pairRDD.collectAsMap(); // Used to crash with ClassCastException
}