diff options
Diffstat (limited to 'core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala new file mode 100644 index 0000000000..9b0c882481 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.io.{ObjectOutputStream, IOException} +import org.apache.spark._ + + +private[spark] +class CartesianPartition( + idx: Int, + @transient rdd1: RDD[_], + @transient rdd2: RDD[_], + s1Index: Int, + s2Index: Int + ) extends Partition { + var s1 = rdd1.partitions(s1Index) + var s2 = rdd2.partitions(s2Index) + override val index: Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + s1 = rdd1.partitions(s1Index) + s2 = rdd2.partitions(s2Index) + oos.defaultWriteObject() + } +} + +private[spark] +class CartesianRDD[T: ClassManifest, U:ClassManifest]( + sc: SparkContext, + var rdd1 : RDD[T], + var rdd2 : RDD[U]) + extends RDD[Pair[T, U]](sc, Nil) + with Serializable { + + val numPartitionsInRdd2 = rdd2.partitions.size + + override def getPartitions: Array[Partition] = { + // create the cross product split + val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) + for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { + val idx = s1.index * numPartitionsInRdd2 + s2.index + array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) + } + array + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + val currSplit = split.asInstanceOf[CartesianPartition] + (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct + } + + override def compute(split: Partition, context: TaskContext) = { + val currSplit = split.asInstanceOf[CartesianPartition] + for (x <- rdd1.iterator(currSplit.s1, context); + y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) + } + + override def getDependencies: Seq[Dependency[_]] = List( + new NarrowDependency(rdd1) { + def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2) + }, + new NarrowDependency(rdd2) { + def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2) + } + ) + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + } +} |