diff options
Diffstat (limited to 'core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala new file mode 100644 index 0000000000..1e8d89e912 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.util.Random + +import cern.jet.random.Poisson +import cern.jet.random.engine.DRand + +import org.apache.spark.{RDD, Partition, TaskContext} + +private[spark] +class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable { + override val index: Int = prev.index +} + +class SampledRDD[T: ClassManifest]( + prev: RDD[T], + withReplacement: Boolean, + frac: Double, + seed: Int) + extends RDD[T](prev) { + + override def getPartitions: Array[Partition] = { + val rg = new Random(seed) + firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt)) + } + + override def getPreferredLocations(split: Partition): Seq[String] = + firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev) + + override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = { + val split = splitIn.asInstanceOf[SampledRDDPartition] + if (withReplacement) { + // For large datasets, the expected number of occurrences of each element in a sample with + // replacement is Poisson(frac). We use that to get a count for each element. + val poisson = new Poisson(frac, new DRand(split.seed)) + firstParent[T].iterator(split.prev, context).flatMap { element => + val count = poisson.nextInt() + if (count == 0) { + Iterator.empty // Avoid object allocation when we return 0 items, which is quite often + } else { + Iterator.fill(count)(element) + } + } + } else { // Sampling without replacement + val rand = new Random(split.seed) + firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac)) + } + } +} |