diff options
author     uncleGen <hustyugm@gmail.com>   2015-12-09 15:09:40 +0000
committer  Sean Owen <sowen@cloudera.com>  2015-12-09 15:09:40 +0000
commit     a113216865fd45ea39ae8f104e784af2cf667dcf (patch)
tree       6ee66e4d7832b96000df6c25429bb393c13a1072 /core
parent     f6883bb7afa7d5df480e1c2b3db6cb77198550be (diff)
download   spark-a113216865fd45ea39ae8f104e784af2cf667dcf.tar.gz
           spark-a113216865fd45ea39ae8f104e784af2cf667dcf.tar.bz2
           spark-a113216865fd45ea39ae8f104e784af2cf667dcf.zip
[SPARK-12031][CORE][BUG] Integer overflow when doing sampling
Author: uncleGen <hustyugm@gmail.com>
Closes #10023 from uncleGen/1.6-bugfix.
Diffstat (limited to 'core')
 core/src/main/scala/org/apache/spark/Partitioner.scala               |  4 ++--
 core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala | 11 ++++++-----
 2 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index e4df7af81a..ef9a2dab1c 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -253,7 +253,7 @@ private[spark] object RangePartitioner {
    */
   def sketch[K : ClassTag](
       rdd: RDD[K],
-      sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
+      sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
     val shift = rdd.id
     // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
     val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
@@ -262,7 +262,7 @@ private[spark] object RangePartitioner {
         iter, sampleSizePerPartition, seed)
       Iterator((idx, n, sample))
     }.collect()
-    val numItems = sketched.map(_._2.toLong).sum
+    val numItems = sketched.map(_._2).sum
     (numItems, sketched)
   }
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index c9a864ae62..f98932a470 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -34,7 +34,7 @@ private[spark] object SamplingUtils {
       input: Iterator[T],
       k: Int,
       seed: Long = Random.nextLong())
-    : (Array[T], Int) = {
+    : (Array[T], Long) = {
     val reservoir = new Array[T](k)
     // Put the first k elements in the reservoir.
     var i = 0
@@ -52,16 +52,17 @@ private[spark] object SamplingUtils {
       (trimReservoir, i)
     } else {
       // If input size > k, continue the sampling process.
+      var l = i.toLong
       val rand = new XORShiftRandom(seed)
       while (input.hasNext) {
         val item = input.next()
-        val replacementIndex = rand.nextInt(i)
+        val replacementIndex = (rand.nextDouble() * l).toLong
         if (replacementIndex < k) {
-          reservoir(replacementIndex) = item
+          reservoir(replacementIndex.toInt) = item
         }
-        i += 1
+        l += 1
       }
-      (reservoir, i)
+      (reservoir, l)
     }
   }