diff options
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala | 22 | ||||
-rw-r--r-- | core/src/test/java/org/apache/spark/JavaAPISuite.java | 193 |
2 files changed, 125 insertions, 90 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index dc698dea75..23d1371079 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -108,6 +108,28 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) + + /** + * Randomly splits this RDD with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1 + * + * @return split RDDs in an array + */ + def randomSplit(weights: Array[Double]): Array[JavaRDD[T]] = + randomSplit(weights, Utils.random.nextLong) + + /** + * Randomly splits this RDD with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1 + * @param seed random seed + * + * @return split RDDs in an array + */ + def randomSplit(weights: Array[Double], seed: Long): Array[JavaRDD[T]] = + rdd.randomSplit(weights, seed).map(wrapRDD) + /** * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index b78309f81c..50a6212911 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -23,6 +23,7 @@ import java.util.*; import scala.Tuple2; import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; import com.google.common.collect.Lists; import com.google.common.base.Optional; import com.google.common.base.Charsets; @@ -48,7 +49,6 @@ import org.apache.spark.partial.BoundedDouble; import org.apache.spark.partial.PartialResult; import org.apache.spark.storage.StorageLevel; import org.apache.spark.util.StatCounter; -import org.apache.spark.util.Utils; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -70,16 +70,6 @@ public class JavaAPISuite implements Serializable { sc = null; } - static class ReverseIntComparator implements Comparator<Integer>, Serializable { - - @Override - public int compare(Integer a, Integer b) { - if (a > b) return -1; - else if (a < b) return 1; - else return 0; - } - } - @SuppressWarnings("unchecked") @Test public void sparkContextUnion() { @@ -124,7 +114,7 @@ public class JavaAPISuite implements Serializable { JavaRDD<Integer> intersections = s1.intersection(s2); Assert.assertEquals(3, intersections.count()); - ArrayList<Integer> list = new ArrayList<Integer>(); + List<Integer> list = new ArrayList<Integer>(); JavaRDD<Integer> empty = sc.parallelize(list); JavaRDD<Integer> emptyIntersection = empty.intersection(s2); Assert.assertEquals(0, emptyIntersection.count()); @@ -145,6 +135,28 @@ public class JavaAPISuite implements Serializable { } @Test + public void sample() { + List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + JavaRDD<Integer> rdd = sc.parallelize(ints); + JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 11); + // expected 2 but of course result varies randomly a bit + Assert.assertEquals(3, sample20.count()); + JavaRDD<Integer> sample20NoReplacement = rdd.sample(false, 0.2, 11); + Assert.assertEquals(2, sample20NoReplacement.count()); + } + + @Test + public void randomSplit() { + List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + JavaRDD<Integer> rdd = sc.parallelize(ints); + JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11); + Assert.assertEquals(3, splits.length); + Assert.assertEquals(2, splits[0].count()); + Assert.assertEquals(3, splits[1].count()); + Assert.assertEquals(5, splits[2].count()); + } + + @Test public void sortByKey() { List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>(); pairs.add(new Tuple2<Integer, Integer>(0, 4)); @@ -161,26 +173,24 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2)); // Custom comparator - sortedRDD = rdd.sortByKey(new ReverseIntComparator(), false); + sortedRDD = rdd.sortByKey(Collections.<Integer>reverseOrder(), false); Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(1)); Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2)); } - static int foreachCalls = 0; - @Test public void foreach() { - foreachCalls = 0; + final Accumulator<Integer> accum = sc.accumulator(0); JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreach(new VoidFunction<String>() { @Override - public void call(String s) { - foreachCalls++; + public void call(String s) throws IOException { + accum.add(1); } }); - Assert.assertEquals(2, foreachCalls); + Assert.assertEquals(2, accum.value().intValue()); } @Test @@ -188,7 +198,7 @@ public class JavaAPISuite implements Serializable { List<Integer> correct = Arrays.asList(1, 2, 3, 4); JavaRDD<Integer> rdd = sc.parallelize(correct); List<Integer> result = Lists.newArrayList(rdd.toLocalIterator()); - Assert.assertTrue(correct.equals(result)); + Assert.assertEquals(correct, result); } @Test @@ -196,7 +206,7 @@ public class JavaAPISuite implements Serializable { List<Integer> dataArray = Arrays.asList(1, 2, 3, 4); JavaPairRDD<Integer, Long> zip = sc.parallelize(dataArray).zipWithUniqueId(); JavaRDD<Long> indexes = zip.values(); - Assert.assertTrue(new HashSet<Long>(indexes.collect()).size() == 4); + Assert.assertEquals(4, new HashSet<Long>(indexes.collect()).size()); } @Test @@ -205,7 +215,7 @@ public class JavaAPISuite implements Serializable { JavaPairRDD<Integer, Long> zip = sc.parallelize(dataArray).zipWithIndex(); JavaRDD<Long> indexes = zip.values(); List<Long> correctIndexes = Arrays.asList(0L, 1L, 2L, 3L); - Assert.assertTrue(indexes.collect().equals(correctIndexes)); + Assert.assertEquals(correctIndexes, indexes.collect()); } @SuppressWarnings("unchecked") @@ -252,8 +262,10 @@ public class JavaAPISuite implements Serializable { new Tuple2<String, Integer>("Oranges", 2), new Tuple2<String, Integer>("Apples", 3) )); - JavaPairRDD<String, Tuple2<Iterable<String>, Iterable<Integer>>> cogrouped = categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + JavaPairRDD<String, Tuple2<Iterable<String>, Iterable<Integer>>> cogrouped = + categories.cogroup(prices); + Assert.assertEquals("[Fruit, Citrus]", + Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); cogrouped.collect(); @@ -281,8 +293,7 @@ public class JavaAPISuite implements Serializable { rdd1.leftOuterJoin(rdd2).filter( new Function<Tuple2<Integer, Tuple2<Integer, Optional<Character>>>, Boolean>() { @Override - public Boolean call(Tuple2<Integer, Tuple2<Integer, Optional<Character>>> tup) - throws Exception { + public Boolean call(Tuple2<Integer, Tuple2<Integer, Optional<Character>>> tup) { return !tup._2()._2().isPresent(); } }).first(); @@ -356,8 +367,7 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(2, localCounts.get(2).intValue()); Assert.assertEquals(3, localCounts.get(3).intValue()); - localCounts = rdd.reduceByKeyLocally(new Function2<Integer, Integer, - Integer>() { + localCounts = rdd.reduceByKeyLocally(new Function2<Integer, Integer, Integer>() { @Override public Integer call(Integer a, Integer b) { return a + b; @@ -448,16 +458,17 @@ public class JavaAPISuite implements Serializable { JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction<Integer>() { @Override public double call(Integer x) { - return 1.0 * x; + return x.doubleValue(); } }).cache(); doubles.collect(); - JavaPairRDD<Integer, Integer> pairs = rdd.mapToPair(new PairFunction<Integer, Integer, Integer>() { - @Override - public Tuple2<Integer, Integer> call(Integer x) { - return new Tuple2<Integer, Integer>(x, x); - } - }).cache(); + JavaPairRDD<Integer, Integer> pairs = rdd.mapToPair( + new PairFunction<Integer, Integer, Integer>() { + @Override + public Tuple2<Integer, Integer> call(Integer x) { + return new Tuple2<Integer, Integer>(x, x); + } + }).cache(); pairs.collect(); JavaRDD<String> strings = rdd.map(new Function<Integer, String>() { @Override @@ -487,7 +498,9 @@ public class JavaAPISuite implements Serializable { @Override public Iterable<Tuple2<String, String>> call(String s) { List<Tuple2<String, String>> pairs = new LinkedList<Tuple2<String, String>>(); - for (String word : s.split(" ")) pairs.add(new Tuple2<String, String>(word, word)); + for (String word : s.split(" ")) { + pairs.add(new Tuple2<String, String>(word, word)); + } return pairs; } } @@ -499,7 +512,9 @@ public class JavaAPISuite implements Serializable { @Override public Iterable<Double> call(String s) { List<Double> lengths = new LinkedList<Double>(); - for (String word : s.split(" ")) lengths.add(word.length() * 1.0); + for (String word : s.split(" ")) { + lengths.add((double) word.length()); + } return lengths; } }); @@ -521,7 +536,7 @@ public class JavaAPISuite implements Serializable { JavaPairRDD<String, Integer> swapped = pairRDD.flatMapToPair( new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() { @Override - public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception { + public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) { return Collections.singletonList(item.swap()); } }); @@ -530,7 +545,7 @@ public class JavaAPISuite implements Serializable { // There was never a bug here, but it's worth testing: pairRDD.mapToPair(new PairFunction<Tuple2<Integer, String>, String, Integer>() { @Override - public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception { + public Tuple2<String, Integer> call(Tuple2<Integer, String> item) { return item.swap(); } }).collect(); @@ -631,14 +646,10 @@ public class JavaAPISuite implements Serializable { byte[] content2 = "spark is also easy to use.\n".getBytes("utf-8"); String tempDirName = tempDir.getAbsolutePath(); - DataOutputStream ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00000")); - ds.write(content1); - ds.close(); - ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00001")); - ds.write(content2); - ds.close(); - - HashMap<String, String> container = new HashMap<String, String>(); + Files.write(content1, new File(tempDirName + "/part-00000")); + Files.write(content2, new File(tempDirName + "/part-00001")); + + Map<String, String> container = new HashMap<String, String>(); container.put(tempDirName+"/part-00000", new Text(content1).toString()); container.put(tempDirName+"/part-00001", new Text(content2).toString()); @@ -844,7 +855,7 @@ public class JavaAPISuite implements Serializable { JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction<Integer>() { @Override public double call(Integer x) { - return 1.0 * x; + return x.doubleValue(); } }); JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles); @@ -859,17 +870,7 @@ public class JavaAPISuite implements Serializable { new FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer>() { @Override public Iterable<Integer> call(Iterator<Integer> i, Iterator<String> s) { - int sizeI = 0; - int sizeS = 0; - while (i.hasNext()) { - sizeI += 1; - i.next(); - } - while (s.hasNext()) { - sizeS += 1; - s.next(); - } - return Arrays.asList(sizeI, sizeS); + return Arrays.asList(Iterators.size(i), Iterators.size(s)); } }; @@ -883,6 +884,7 @@ public class JavaAPISuite implements Serializable { final Accumulator<Integer> intAccum = sc.intAccumulator(10); rdd.foreach(new VoidFunction<Integer>() { + @Override public void call(Integer x) { intAccum.add(x); } @@ -891,6 +893,7 @@ public class JavaAPISuite implements Serializable { final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0); rdd.foreach(new VoidFunction<Integer>() { + @Override public void call(Integer x) { doubleAccum.add((double) x); } @@ -899,14 +902,17 @@ public class JavaAPISuite implements Serializable { // Try a custom accumulator type AccumulatorParam<Float> floatAccumulatorParam = new AccumulatorParam<Float>() { + @Override public Float addInPlace(Float r, Float t) { return r + t; } + @Override public Float addAccumulator(Float r, Float t) { return r + t; } + @Override public Float zero(Float initialValue) { return 0.0f; } @@ -914,6 +920,7 @@ public class JavaAPISuite implements Serializable { final Accumulator<Float> floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); rdd.foreach(new VoidFunction<Integer>() { + @Override public void call(Integer x) { floatAccum.add((float) x); } @@ -929,7 +936,8 @@ public class JavaAPISuite implements Serializable { public void keyBy() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2)); List<Tuple2<String, Integer>> s = rdd.keyBy(new Function<Integer, String>() { - public String call(Integer t) throws Exception { + @Override + public String call(Integer t) { return t.toString(); } }).collect(); @@ -941,10 +949,10 @@ public class JavaAPISuite implements Serializable { public void checkpointAndComputation() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); sc.setCheckpointDir(tempDir.getAbsolutePath()); - Assert.assertEquals(false, rdd.isCheckpointed()); + Assert.assertFalse(rdd.isCheckpointed()); rdd.checkpoint(); rdd.count(); // Forces the DAG to cause a checkpoint - Assert.assertEquals(true, rdd.isCheckpointed()); + Assert.assertTrue(rdd.isCheckpointed()); Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect()); } @@ -952,10 +960,10 @@ public class JavaAPISuite implements Serializable { public void checkpointAndRestore() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); sc.setCheckpointDir(tempDir.getAbsolutePath()); - Assert.assertEquals(false, rdd.isCheckpointed()); + Assert.assertFalse(rdd.isCheckpointed()); rdd.checkpoint(); rdd.count(); // Forces the DAG to cause a checkpoint - Assert.assertEquals(true, rdd.isCheckpointed()); + Assert.assertTrue(rdd.isCheckpointed()); Assert.assertTrue(rdd.getCheckpointFile().isPresent()); JavaRDD<Integer> recovered = sc.checkpointFile(rdd.getCheckpointFile().get()); @@ -966,16 +974,17 @@ public class JavaAPISuite implements Serializable { @Test public void mapOnPairRDD() { JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1,2,3,4)); - JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(new PairFunction<Integer, Integer, Integer>() { - @Override - public Tuple2<Integer, Integer> call(Integer i) throws Exception { - return new Tuple2<Integer, Integer>(i, i % 2); - } - }); + JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair( + new PairFunction<Integer, Integer, Integer>() { + @Override + public Tuple2<Integer, Integer> call(Integer i) { + return new Tuple2<Integer, Integer>(i, i % 2); + } + }); JavaPairRDD<Integer, Integer> rdd3 = rdd2.mapToPair( new PairFunction<Tuple2<Integer, Integer>, Integer, Integer>() { @Override - public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> in) throws Exception { + public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> in) { return new Tuple2<Integer, Integer>(in._2(), in._1()); } }); @@ -992,14 +1001,15 @@ public class JavaAPISuite implements Serializable { public void collectPartitions() { JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3); - JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(new PairFunction<Integer, Integer, Integer>() { - @Override - public Tuple2<Integer, Integer> call(Integer i) throws Exception { - return new Tuple2<Integer, Integer>(i, i % 2); - } - }); + JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair( + new PairFunction<Integer, Integer, Integer>() { + @Override + public Tuple2<Integer, Integer> call(Integer i) { + return new Tuple2<Integer, Integer>(i, i % 2); + } + }); - List[] parts = rdd1.collectPartitions(new int[] {0}); + List<Integer>[] parts = rdd1.collectPartitions(new int[] {0}); Assert.assertEquals(Arrays.asList(1, 2), parts[0]); parts = rdd1.collectPartitions(new int[] {1, 2}); @@ -1010,14 +1020,14 @@ public class JavaAPISuite implements Serializable { new Tuple2<Integer, Integer>(2, 0)), rdd2.collectPartitions(new int[] {0})[0]); - parts = rdd2.collectPartitions(new int[] {1, 2}); + List<Tuple2<Integer,Integer>>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1), new Tuple2<Integer, Integer>(4, 0)), - parts[0]); + parts2[0]); Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1), new Tuple2<Integer, Integer>(6, 0), new Tuple2<Integer, Integer>(7, 1)), - parts[1]); + parts2[1]); } @Test @@ -1034,10 +1044,12 @@ public class JavaAPISuite implements Serializable { @Test public void countApproxDistinctByKey() { List<Tuple2<Integer, Integer>> arrayData = new ArrayList<Tuple2<Integer, Integer>>(); - for (int i = 10; i < 100; i++) - for (int j = 0; j < i; j++) + for (int i = 10; i < 100; i++) { + for (int j = 0; j < i; j++) { arrayData.add(new Tuple2<Integer, Integer>(i, j)); - + } + } + double relativeSD = 0.001; JavaPairRDD<Integer, Integer> pairRdd = sc.parallelizePairs(arrayData); List<Tuple2<Integer, Object>> res = pairRdd.countApproxDistinctByKey(8, 0).collect(); for (Tuple2<Integer, Object> resItem : res) { @@ -1053,12 +1065,13 @@ public class JavaAPISuite implements Serializable { public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1)); - JavaPairRDD<Integer, int[]> pairRDD = rdd.mapToPair(new PairFunction<Integer, Integer, int[]>() { - @Override - public Tuple2<Integer, int[]> call(Integer x) throws Exception { - return new Tuple2<Integer, int[]>(x, new int[] { x }); - } - }); + JavaPairRDD<Integer, int[]> pairRDD = rdd.mapToPair( + new PairFunction<Integer, Integer, int[]>() { + @Override + public Tuple2<Integer, int[]> call(Integer x) { + return new Tuple2<Integer, int[]>(x, new int[] { x }); + } + }); pairRDD.collect(); // Works fine pairRDD.collectAsMap(); // Used to crash with ClassCastException } |