path: root/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
Diffstat (limited to 'core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala  90
1 file changed, 90 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
new file mode 100644
index 0000000000..9b0c882481
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.{ObjectOutputStream, IOException}
+import org.apache.spark._
+
+
+private[spark]
+class CartesianPartition(
+ idx: Int,
+ @transient rdd1: RDD[_],
+ @transient rdd2: RDD[_],
+ s1Index: Int,
+ s2Index: Int
+ ) extends Partition {
+ var s1 = rdd1.partitions(s1Index)
+ var s2 = rdd2.partitions(s2Index)
+ override val index: Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ s1 = rdd1.partitions(s1Index)
+ s2 = rdd2.partitions(s2Index)
+ oos.defaultWriteObject()
+ }
+}
+
+private[spark]
+class CartesianRDD[T: ClassManifest, U: ClassManifest](
+    sc: SparkContext,
+    var rdd1: RDD[T],
+    var rdd2: RDD[U])
+ extends RDD[Pair[T, U]](sc, Nil)
+ with Serializable {
+
+ val numPartitionsInRdd2 = rdd2.partitions.size
+
+ override def getPartitions: Array[Partition] = {
+ // create the cross product split
+ val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size)
+ for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) {
+ val idx = s1.index * numPartitionsInRdd2 + s2.index
+ array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index)
+ }
+ array
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ val currSplit = split.asInstanceOf[CartesianPartition]
+ (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct
+ }
+
+ override def compute(split: Partition, context: TaskContext) = {
+ val currSplit = split.asInstanceOf[CartesianPartition]
+ for (x <- rdd1.iterator(currSplit.s1, context);
+ y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
+ }
+
+ override def getDependencies: Seq[Dependency[_]] = List(
+ new NarrowDependency(rdd1) {
+ def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2)
+ },
+ new NarrowDependency(rdd2) {
+ def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2)
+ }
+ )
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ }
+}
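
A note on the partition indexing above: getPartitions encodes each pair of parent partitions into a single child index (s1.index * numPartitionsInRdd2 + s2.index), and the two NarrowDependency instances in getDependencies invert that encoding with integer division and modulo. The standalone sketch below shows the round trip; the object and helper names are illustrative only and are not part of Spark.

// Illustrative sketch of the index arithmetic used by CartesianRDD; not Spark code.
object CartesianIndexMath {
  // Encoding used in getPartitions: one child partition per pair of parent partitions.
  def childIndex(s1Index: Int, s2Index: Int, numPartitionsInRdd2: Int): Int =
    s1Index * numPartitionsInRdd2 + s2Index

  // Decoding used in getDependencies: recover the parent partition indices.
  def parentInRdd1(id: Int, numPartitionsInRdd2: Int): Int = id / numPartitionsInRdd2
  def parentInRdd2(id: Int, numPartitionsInRdd2: Int): Int = id % numPartitionsInRdd2

  def main(args: Array[String]): Unit = {
    val numPartitionsInRdd2 = 4
    val id = childIndex(2, 1, numPartitionsInRdd2)          // 2 * 4 + 1 = 9
    assert(parentInRdd1(id, numPartitionsInRdd2) == 2)
    assert(parentInRdd2(id, numPartitionsInRdd2) == 1)
    println(s"child partition $id depends on rdd1 partition 2 and rdd2 partition 1")
  }
}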
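
For context, CartesianRDD is normally constructed through RDD.cartesian rather than instantiated directly. A minimal local-mode usage sketch follows, assuming a driver program with Spark on the classpath; the master URL and application name are placeholders.

// Illustrative usage sketch; master URL and app name are placeholders.
import org.apache.spark.SparkContext

object CartesianUsageExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "CartesianUsageExample")
    val left  = sc.parallelize(Seq(1, 2, 3), 3)    // 3 partitions
    val right = sc.parallelize(Seq("a", "b"), 2)   // 2 partitions
    // The result has 3 * 2 = 6 partitions, one per pair of parent partitions.
    val pairs = left.cartesian(right)
    println(pairs.collect().mkString(", "))        // (1,a), (1,b), (2,a), ...
    sc.stop()
  }
}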