Diffstat (limited to 'core/src/main')
-rw-r--r--   core/src/main/scala/org/apache/spark/rdd/RDD.scala                  36
-rw-r--r--   core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala   69
-rw-r--r--   core/src/main/scala/org/apache/spark/util/Utils.scala               13
3 files changed, 106 insertions(+), 12 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 8010bb68e3..ec8e311aff 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -775,18 +775,7 @@ abstract class RDD[T: ClassTag](
/**
* Return the number of elements in the RDD.
*/
- def count(): Long = {
- sc.runJob(this, (iter: Iterator[T]) => {
- // Use a while loop to count the number of elements rather than iter.size because
- // iter.size uses a for loop, which is slightly slower in current version of Scala.
- var result = 0L
- while (iter.hasNext) {
- result += 1L
- iter.next()
- }
- result
- }).sum
- }
+ def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum
/**
* (Experimental) Approximate version of count() that returns a potentially incomplete result
@@ -870,6 +859,29 @@ abstract class RDD[T: ClassTag](
}
/**
+ * Zips this RDD with its element indices. The ordering is first based on the partition index
+ * and then the ordering of items within each partition. So the first item in the first
+ * partition gets index 0, and the last item in the last partition receives the largest index.
+ * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type.
+ * This method needs to trigger a Spark job when this RDD contains more than one partition.
+ */
+ def zipWithIndex(): RDD[(T, Long)] = new ZippedWithIndexRDD(this)
+
+ /**
+ * Zips this RDD with generated unique Long ids. Items in the kth partition will get ids k, n+k,
+ * 2*n+k, ..., where n is the number of partitions, so the ids are unique but may have gaps.
+ * Unlike [[org.apache.spark.rdd.RDD#zipWithIndex]], this method won't trigger a Spark job.
+ */
+ def zipWithUniqueId(): RDD[(T, Long)] = {
+ val n = this.partitions.size
+ this.mapPartitionsWithIndex { case (k, iter) =>
+ iter.zipWithIndex.map { case (item, i) =>
+ (item, i.toLong * n + k) // use Long arithmetic so the id doesn't overflow Int for large partitions
+ }
+ }
+ }
+
+ /**
* Take the first num elements of the RDD. It works by first scanning one partition, then using the
* results from that partition to estimate the number of additional partitions needed to satisfy
* the limit.
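As a quick sketch of how the two new methods behave (not part of this commit; assumes an existing SparkContext named sc):

    // Hypothetical example: 4 elements over 2 partitions, partition 0 = (a, b), partition 1 = (c, d).
    val rdd = sc.parallelize(Seq("a", "b", "c", "d"), 2)

    // zipWithIndex assigns contiguous global indices; it runs a job to count the earlier partitions.
    rdd.zipWithIndex().collect()      // => Array((a,0), (b,1), (c,2), (d,3))

    // zipWithUniqueId assigns ids k, n+k, 2*n+k, ... within partition k (n = 2 here) without running
    // a job; the ids are unique but gaps appear when partition sizes differ.
    rdd.zipWithUniqueId().collect()   // => Array((a,0), (b,2), (c,1), (d,3))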
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
new file mode 100644
index 0000000000..5e08a469ee
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.util.Utils
+
+private[spark]
+class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
+ extends Partition with Serializable {
+ override val index: Int = prev.index
+}
+
+/**
+ * Represents an RDD zipped with its element indices. The ordering is first based on the partition
+ * index and then the ordering of items within each partition. So the first item in the first
+ * partition gets index 0, and the last item in the last partition receives the largest index.
+ *
+ * @param prev parent RDD
+ * @tparam T parent RDD item type
+ */
+private[spark]
+class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev) {
+
+ override def getPartitions: Array[Partition] = {
+ val n = prev.partitions.size
+ val startIndices: Array[Long] =
+ if (n == 0) {
+ Array[Long]()
+ } else if (n == 1) {
+ Array(0L)
+ } else {
+ prev.context.runJob(
+ prev,
+ Utils.getIteratorSize _,
+ 0 until n - 1, // do not need to count the last partition
+ false
+ ).scanLeft(0L)(_ + _)
+ }
+ firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] =
+ firstParent[T].preferredLocations(split.asInstanceOf[ZippedWithIndexRDDPartition].prev)
+
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
+ val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
+ firstParent[T].iterator(split.prev, context).zipWithIndex.map { x =>
+ (x._1, split.startIndex + x._2)
+ }
+ }
+}
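To illustrate the scanLeft in getPartitions above, assume a hypothetical parent RDD whose first two partitions hold 3 and 5 elements (the last partition never needs to be counted):

    val counts = Array(3L, 5L)                      // element counts of partitions 0 and 1, from runJob
    val startIndices = counts.scanLeft(0L)(_ + _)   // => Array(0, 3, 8)
    // Partition 0 starts numbering its elements at 0, partition 1 at 3, and partition 2 at 8.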
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index c201d0a33f..8749ab7875 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -855,4 +855,17 @@ private[spark] object Utils extends Logging {
System.currentTimeMillis - start
}
+ /**
+ * Counts the number of elements in an iterator using a while loop rather than calling
+ * [[scala.collection.Iterator#size]], because the latter uses a for loop, which is slightly
+ * slower in the current version of Scala.
+ */
+ def getIteratorSize[T](iterator: Iterator[T]): Long = {
+ var count = 0L
+ while (iterator.hasNext) {
+ count += 1L
+ iterator.next()
+ }
+ count
+ }
}
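A minimal sketch of the new helper on its own (hypothetical caller inside the org.apache.spark package, since Utils is private[spark]):

    import org.apache.spark.util.Utils

    val it = Iterator.fill(1000)("x")
    val size = Utils.getIteratorSize(it)   // => 1000L; note the iterator is fully consumed afterwards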