-rw-r--r--   core/src/main/scala/spark/RDD.scala                              66
-rw-r--r--   core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala     4
2 files changed, 37 insertions, 33 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index f4288a9661..6270e018b3 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -1,17 +1,17 @@
package spark
import java.io.EOFException
-import java.net.URL
import java.io.ObjectInputStream
-import java.util.concurrent.atomic.AtomicLong
+import java.net.URL
import java.util.Random
import java.util.Date
import java.util.{HashMap => JHashMap}
+import java.util.concurrent.atomic.AtomicLong
-import scala.collection.mutable.ArrayBuffer
import scala.collection.Map
-import scala.collection.mutable.HashMap
import scala.collection.JavaConversions.mapAsScalaMap
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
@@ -48,7 +48,7 @@ import spark.storage.StorageLevel
import SparkContext._
/**
- * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
+ * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
* partitioned collection of elements that can be operated on in parallel. This class contains the
* basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
* [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such
@@ -87,28 +87,28 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
@transient val dependencies: List[Dependency[_]]
// Methods available on all RDDs:
-
+
/** Record user function generating this RDD. */
private[spark] val origin = Utils.getSparkCallSite
-
+
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
/** Optionally overridden by subclasses to specify placement preferences. */
def preferredLocations(split: Split): Seq[String] = Nil
-
+
/** The [[spark.SparkContext]] that this RDD was created on. */
def context = sc
private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
-
+
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
-
+
// Variables relating to persistence
private var storageLevel: StorageLevel = StorageLevel.NONE
-
- /**
+
+ /**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
@@ -124,32 +124,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
-
+
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): RDD[T] = persist()
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
-
+
private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
if (!level.useDisk && level.replication < 2) {
throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
- }
-
+ }
+
// This is a hack. Ideally this should re-use the code used by the CacheTracker
// to generate the key.
def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
-
+
persist(level)
sc.runJob(this, (iter: Iterator[T]) => {} )
-
+
val p = this.partitioner
-
+
new BlockRDD[T](sc, splits.map(getSplitKey).toArray) {
- override val partitioner = p
+ override val partitioner = p
}
}
-
+
/**
* Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
* This should ''not'' be called by users directly, but is available for implementors of custom
@@ -162,9 +162,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
compute(split)
}
}
-
+
// Transformations (return a new RDD)
-
+
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
@@ -200,13 +200,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
var multiplier = 3.0
var initialCount = count()
var maxSelected = 0
-
+
if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - 1
} else {
maxSelected = initialCount.toInt
}
-
+
if (num > initialCount) {
total = maxSelected
fraction = math.min(multiplier * (maxSelected + 1) / initialCount, 1.0)
@@ -216,14 +216,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
fraction = math.min(multiplier * (num + 1) / initialCount, 1.0)
total = num
}
-
+
val rand = new Random(seed)
var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
-
+
while (samples.length < total) {
samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
}
-
+
Utils.randomizeInPlace(samples, rand).take(total)
}
@@ -291,8 +291,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
- def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
- new MapPartitionsWithSplitRDD(this, sc.clean(f))
+ def mapPartitionsWithSplit[U: ClassManifest](
+ f: (Int, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] =
+ new MapPartitionsWithSplitRDD(this, sc.clean(f), preservesPartitioning)
/**
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
@@ -351,7 +353,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
- * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
+ * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
* modify t1 and return it as its result value to avoid object allocation; however, it should not
* modify t2.
*/
@@ -452,7 +454,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
sc.runApproximateJob(this, countPartition, evaluator, timeout)
}
-
+
/**
* Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
* it will be slow if a lot of partitions are required. In that case, use collect() to get the
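
A minimal usage sketch of the new preservesPartitioning flag (not part of this patch; it assumes the pre-1.0 spark.* API shown above, a SparkContext named sc, and the usual import SparkContext._ implicits for PairRDDFunctions):

    // Hypothetical example: hash-partition a pair RDD, then tag each element with
    // its split index. The per-partition function leaves the keys untouched, so
    // preservesPartitioning = true lets the result keep the HashPartitioner
    // instead of falling back to None.
    val pairs = sc.parallelize(1 to 100).map(x => (x % 10, x))
                  .partitionBy(new HashPartitioner(4))

    val tagged = pairs.mapPartitionsWithSplit(
      (index, iter) => iter.map { case (k, v) => (k, (index, v)) },
      preservesPartitioning = true)

    // tagged.partitioner is Some(HashPartitioner(4)), so a later join or
    // reduceByKey against pairs can avoid a shuffle.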
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index adc541694e..14e390c43b 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -12,9 +12,11 @@ import spark.Split
private[spark]
class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
- f: (Int, Iterator[T]) => Iterator[U])
+ f: (Int, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean)
extends RDD[U](prev.context) {
+ override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(split.index, prev.iterator(split))
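
A hedged sanity check of the new default behaviour (illustrative only; sc and the RDDs below are assumptions, not part of the patch):

    // The flag defaults to false, so the partitioner is dropped unless the caller
    // explicitly asserts that the function preserves the keying.
    val pairs   = sc.parallelize(Seq((1, "a"), (2, "b"))).partitionBy(new HashPartitioner(2))
    val kept    = pairs.mapPartitionsWithSplit((i, it) => it, preservesPartitioning = true)
    val dropped = pairs.mapPartitionsWithSplit((i, it) => it)

    assert(kept.partitioner == pairs.partitioner)  // inherited from prev via the new override
    assert(dropped.partitioner == None)            // conservative default: partitioner lost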