author    Kan Zhang <kzhang@apache.org>    2014-06-03 22:47:18 -0700
committer Reynold Xin <rxin@apache.org>    2014-06-03 22:47:18 -0700
commit    c402a4a685721d05932bbc578d997f330ff65a49 (patch)
tree      b5a985c8b01dcf028d9ac7d7bb5e1b153d18b5ee
parent    4ca06256690c5e03058dd179c2fc6437a917cfee (diff)
[SPARK-1817] RDD.zip() should verify partition sizes for each partition
RDD.zip() will throw an exception if it finds partition sizes are not the same.

Author: Kan Zhang <kzhang@apache.org>

Closes #944 from kanzhang/SPARK-1817 and squashes the following commits:

c073848 [Kan Zhang] [SPARK-1817] Cosmetic updates
524c670 [Kan Zhang] [SPARK-1817] RDD.zip() should verify partition sizes for each partition
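A minimal sketch of the new behavior (not part of the commit; the local SparkContext setup and object name are assumptions for illustration). Both RDDs below have two partitions, but the second partition of b holds one more element than the second partition of a, so collect() now fails with a SparkException from the stricter zip(), rather than truncating to the shorter iterator as the removed ZippedRDD did.

import org.apache.spark.{SparkContext, SparkException}

// Hypothetical driver program illustrating the stricter RDD.zip() check.
object ZipSizeCheckExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "zip-size-check")

    // Same number of partitions (2), different element counts per partition:
    // 1 to 4 splits as [1, 2] / [3, 4]; 1 to 5 splits as [1, 2] / [3, 4, 5].
    val a = sc.parallelize(1 to 4, 2)
    val b = sc.parallelize(1 to 5, 2)

    try {
      a.zip(b).collect()  // the per-partition size check fails while the job runs
    } catch {
      case e: SparkException =>
        println(s"zip failed as expected: ${e.getMessage}")
    } finally {
      sc.stop()
    }
  }
}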
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala          | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala    | 87
-rw-r--r--  core/src/test/scala/org/apache/spark/CheckpointSuite.scala  | 26
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala     |  4
-rw-r--r--  project/MimaExcludes.scala                                  |  2
5 files changed, 33 insertions, 100 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 585b2f76af..54bdc3e7cb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -654,7 +654,19 @@ abstract class RDD[T: ClassTag](
* partitions* and the *same number of elements in each partition* (e.g. one was made through
* a map on the other).
*/
- def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+ def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = {
+ zipPartitions(other, true) { (thisIter, otherIter) =>
+ new Iterator[(T, U)] {
+ def hasNext = (thisIter.hasNext, otherIter.hasNext) match {
+ case (true, true) => true
+ case (false, false) => false
+ case _ => throw new SparkException("Can only zip RDDs with " +
+ "same number of elements in each partition")
+ }
+ def next = (thisIter.next, otherIter.next)
+ }
+ }
+ }
/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
deleted file mode 100644
index b8110ffc42..0000000000
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import java.io.{IOException, ObjectOutputStream}
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
-
-private[spark] class ZippedPartition[T: ClassTag, U: ClassTag](
- idx: Int,
- @transient rdd1: RDD[T],
- @transient rdd2: RDD[U]
- ) extends Partition {
-
- var partition1 = rdd1.partitions(idx)
- var partition2 = rdd2.partitions(idx)
- override val index: Int = idx
-
- def partitions = (partition1, partition2)
-
- @throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
- // Update the reference to parent partition at the time of task serialization
- partition1 = rdd1.partitions(idx)
- partition2 = rdd2.partitions(idx)
- oos.defaultWriteObject()
- }
-}
-
-private[spark] class ZippedRDD[T: ClassTag, U: ClassTag](
- sc: SparkContext,
- var rdd1: RDD[T],
- var rdd2: RDD[U])
- extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) {
-
- override def getPartitions: Array[Partition] = {
- if (rdd1.partitions.size != rdd2.partitions.size) {
- throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
- }
- val array = new Array[Partition](rdd1.partitions.size)
- for (i <- 0 until rdd1.partitions.size) {
- array(i) = new ZippedPartition(i, rdd1, rdd2)
- }
- array
- }
-
- override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = {
- val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
- rdd1.iterator(partition1, context).zip(rdd2.iterator(partition2, context))
- }
-
- override def getPreferredLocations(s: Partition): Seq[String] = {
- val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
- val pref1 = rdd1.preferredLocations(partition1)
- val pref2 = rdd2.preferredLocations(partition2)
- // Check whether there are any hosts that match both RDDs; otherwise return the union
- val exactMatchLocations = pref1.intersect(pref2)
- if (!exactMatchLocations.isEmpty) {
- exactMatchLocations
- } else {
- (pref1 ++ pref2).distinct
- }
- }
-
- override def clearDependencies() {
- super.clearDependencies()
- rdd1 = null
- rdd2 = null
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 64933f4b10..f64f3c9036 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -167,26 +167,28 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
})
}
- test("ZippedRDD") {
- testRDD(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
- testRDDPartitions(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
+ test("ZippedPartitionsRDD") {
+ testRDD(rdd => rdd.zip(rdd.map(x => x)))
+ testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)))
- // Test that the ZippedPartition updates parent partitions
- // after the parent RDD has been checkpointed and parent partitions have been changed.
- // Note that this test is very specific to the current implementation of ZippedRDD.
+ // Test that ZippedPartitionsRDD updates parent partitions after parent RDDs have
+ // been checkpointed and parent partitions have been changed.
+ // Note that this test is very specific to the implementation of ZippedPartitionsRDD.
val rdd = generateFatRDD()
- val zippedRDD = new ZippedRDD(sc, rdd, rdd.map(x => x))
+ val zippedRDD = rdd.zip(rdd.map(x => x)).asInstanceOf[ZippedPartitionsRDD2[_, _, _]]
zippedRDD.rdd1.checkpoint()
zippedRDD.rdd2.checkpoint()
val partitionBeforeCheckpoint =
- serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+ serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
zippedRDD.count()
val partitionAfterCheckpoint =
- serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+ serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
assert(
- partitionAfterCheckpoint.partition1.getClass != partitionBeforeCheckpoint.partition1.getClass &&
- partitionAfterCheckpoint.partition2.getClass != partitionBeforeCheckpoint.partition2.getClass,
- "ZippedRDD.partition1 and ZippedRDD.partition2 not updated after parent RDD is checkpointed"
+ partitionAfterCheckpoint.partitions(0).getClass !=
+ partitionBeforeCheckpoint.partitions(0).getClass &&
+ partitionAfterCheckpoint.partitions(1).getClass !=
+ partitionBeforeCheckpoint.partitions(1).getClass,
+ "ZippedPartitionsRDD partition 0 (or 1) not updated after parent RDDs are checkpointed"
)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index bbd0c14178..286e221e33 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -350,6 +350,10 @@ class RDDSuite extends FunSuite with SharedSparkContext {
intercept[IllegalArgumentException] {
nums.zip(sc.parallelize(1 to 4, 1)).collect()
}
+
+ intercept[SparkException] {
+ nums.zip(sc.parallelize(1 to 5, 2)).collect()
+ }
}
test("partition pruning") {
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index fadf6a4d8b..dd7efceb23 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -54,6 +54,8 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1")
) ++
+ MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
+ MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
MimaBuild.excludeSparkClass("util.SerializableHyperLogLog")
case v if v.startsWith("1.0") =>
Seq(