aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authoruncleGen <hustyugm@gmail.com>2015-12-09 15:09:40 +0000
committerSean Owen <sowen@cloudera.com>2015-12-09 15:09:40 +0000
commita113216865fd45ea39ae8f104e784af2cf667dcf (patch)
tree6ee66e4d7832b96000df6c25429bb393c13a1072 /core
parentf6883bb7afa7d5df480e1c2b3db6cb77198550be (diff)
downloadspark-a113216865fd45ea39ae8f104e784af2cf667dcf.tar.gz
spark-a113216865fd45ea39ae8f104e784af2cf667dcf.tar.bz2
spark-a113216865fd45ea39ae8f104e784af2cf667dcf.zip
[SPARK-12031][CORE][BUG] Integer overflow when doing sampling
Author: uncleGen <hustyugm@gmail.com> Closes #10023 from uncleGen/1.6-bugfix.
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/Partitioner.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala11
2 files changed, 8 insertions, 7 deletions
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index e4df7af81a..ef9a2dab1c 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -253,7 +253,7 @@ private[spark] object RangePartitioner {
*/
def sketch[K : ClassTag](
rdd: RDD[K],
- sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
+ sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
val shift = rdd.id
// val classTagK = classTag[K] // to avoid serializing the entire partitioner object
val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
@@ -262,7 +262,7 @@ private[spark] object RangePartitioner {
iter, sampleSizePerPartition, seed)
Iterator((idx, n, sample))
}.collect()
- val numItems = sketched.map(_._2.toLong).sum
+ val numItems = sketched.map(_._2).sum
(numItems, sketched)
}
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index c9a864ae62..f98932a470 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -34,7 +34,7 @@ private[spark] object SamplingUtils {
input: Iterator[T],
k: Int,
seed: Long = Random.nextLong())
- : (Array[T], Int) = {
+ : (Array[T], Long) = {
val reservoir = new Array[T](k)
// Put the first k elements in the reservoir.
var i = 0
@@ -52,16 +52,17 @@ private[spark] object SamplingUtils {
(trimReservoir, i)
} else {
// If input size > k, continue the sampling process.
+ var l = i.toLong
val rand = new XORShiftRandom(seed)
while (input.hasNext) {
val item = input.next()
- val replacementIndex = rand.nextInt(i)
+ val replacementIndex = (rand.nextDouble() * l).toLong
if (replacementIndex < k) {
- reservoir(replacementIndex) = item
+ reservoir(replacementIndex.toInt) = item
}
- i += 1
+ l += 1
}
- (reservoir, i)
+ (reservoir, l)
}
}