Diffstat (limited to 'core/src/main')
3 files changed, 106 insertions, 12 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 8010bb68e3..ec8e311aff 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -775,18 +775,7 @@ abstract class RDD[T: ClassTag](
   /**
    * Return the number of elements in the RDD.
    */
-  def count(): Long = {
-    sc.runJob(this, (iter: Iterator[T]) => {
-      // Use a while loop to count the number of elements rather than iter.size because
-      // iter.size uses a for loop, which is slightly slower in current version of Scala.
-      var result = 0L
-      while (iter.hasNext) {
-        result += 1L
-        iter.next()
-      }
-      result
-    }).sum
-  }
+  def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum
 
   /**
    * (Experimental) Approximate version of count() that returns a potentially incomplete result
@@ -870,6 +859,29 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
+   * Zips this RDD with its element indices. The ordering is first based on the partition index
+   * and then the ordering of items within each partition. So the first item in the first
+   * partition gets index 0, and the last item in the last partition receives the largest index.
+   * This is similar to Scala's zipWithIndex, but it uses Long instead of Int as the index type.
+   * This method needs to trigger a Spark job when this RDD contains more than one partition.
+   */
+  def zipWithIndex(): RDD[(T, Long)] = new ZippedWithIndexRDD(this)
+
+  /**
+   * Zips this RDD with generated unique Long ids. Items in the kth partition will get ids k, n+k,
+   * 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method
+   * won't trigger a Spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]].
+   */
+  def zipWithUniqueId(): RDD[(T, Long)] = {
+    val n = this.partitions.size
+    this.mapPartitionsWithIndex { case (k, iter) =>
+      iter.zipWithIndex.map { case (item, i) =>
+        (item, i * n + k)
+      }
+    }
+  }
+
+  /**
    * Take the first num elements of the RDD. It works by first scanning one partition, and uses the
    * results from that partition to estimate the number of additional partitions needed to satisfy
    * the limit.
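A usage sketch, not part of the patch, showing how the two new methods assign indices differently; it assumes an active SparkContext named sc, and the variable names are illustrative:

    // Usage sketch (assumes an active SparkContext `sc`).
    val rdd = sc.parallelize(Seq("a", "b", "c", "d", "e"), 2)  // partitions: [a, b] and [c, d, e]

    // zipWithIndex: contiguous indices, ordered by partition and then by position
    // within each partition; it runs a job to count all but the last partition.
    rdd.zipWithIndex().collect()
    // => Array((a,0), (b,1), (c,2), (d,3), (e,4))

    // zipWithUniqueId: item i of partition k gets id i * n + k (here n = 2),
    // so ids are unique but may have gaps; no job is triggered.
    rdd.zipWithUniqueId().collect()
    // => Array((a,0), (b,2), (c,1), (d,3), (e,5))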
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
new file mode 100644
index 0000000000..5e08a469ee
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{TaskContext, Partition}
+import org.apache.spark.util.Utils
+
+private[spark]
+class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
+  extends Partition with Serializable {
+  override val index: Int = prev.index
+}
+
+/**
+ * Represents an RDD zipped with its element indices. The ordering is first based on the partition
+ * index and then the ordering of items within each partition. So the first item in the first
+ * partition gets index 0, and the last item in the last partition receives the largest index.
+ *
+ * @param prev parent RDD
+ * @tparam T parent RDD item type
+ */
+private[spark]
+class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev) {
+
+  override def getPartitions: Array[Partition] = {
+    val n = prev.partitions.size
+    val startIndices: Array[Long] =
+      if (n == 0) {
+        Array[Long]()
+      } else if (n == 1) {
+        Array(0L)
+      } else {
+        prev.context.runJob(
+          prev,
+          Utils.getIteratorSize _,
+          0 until n - 1, // do not need to count the last partition
+          false
+        ).scanLeft(0L)(_ + _)
+      }
+    firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
+  }
+
+  override def getPreferredLocations(split: Partition): Seq[String] =
+    firstParent[T].preferredLocations(split.asInstanceOf[ZippedWithIndexRDDPartition].prev)
+
+  override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = {
+    val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition]
+    firstParent[T].iterator(split.prev, context).zipWithIndex.map { x =>
+      (x._1, split.startIndex + x._2)
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index c201d0a33f..8749ab7875 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -855,4 +855,17 @@ private[spark] object Utils extends Logging {
     System.currentTimeMillis - start
   }
 
+  /**
+   * Counts the number of elements of an iterator using a while loop rather than calling
+   * [[scala.collection.Iterator#size]], because the latter uses a for loop, which is slightly
+   * slower in the current version of Scala.
+   */
+  def getIteratorSize[T](iterator: Iterator[T]): Long = {
+    var count = 0L
+    while (iterator.hasNext) {
+      count += 1L
+      iterator.next()
+    }
+    count
+  }
 }
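The start-index bookkeeping in getPartitions above can be checked in isolation. A minimal sketch in plain Scala (no Spark needed), using hypothetical per-partition counts 3, 4, 2 for a 3-partition RDD:

    // Hypothetical sizes of partitions 0 and 1; only the first n - 1 partitions
    // are counted, since the last partition's start never depends on its own size.
    val counts = Array(3L, 4L)
    val startIndices = counts.scanLeft(0L)(_ + _)  // => Array(0, 3, 7)
    // scanLeft turns n - 1 counts into n start indices: partition 0 starts at 0,
    // partition 1 at 3, partition 2 at 7 — which is why the runJob above only
    // covers 0 until n - 1, and each count comes from Utils.getIteratorSize.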