-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala              |  8
-rw-r--r--  python/pyspark/rdd.py                                           |  5
3 files changed, 16 insertions(+), 9 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index b62f3fbdc4..ede5568493 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -78,16 +78,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
- // If we didn't find any rows after the first iteration, just try all partitions next.
+ // If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
- // by 50%.
+ // by 50%. We also cap the estimate at 4x the number of partitions already scanned.
if (results.size == 0) {
- numPartsToTry = totalParts - 1
+ numPartsToTry = partsScanned * 4
} else {
- numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
+ // the second argument of max is >= 1 whenever partsScanned >= 2
+ numPartsToTry = Math.max(1,
+ (1.5 * num * partsScanned / results.size).toInt - partsScanned)
+ numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
}
}
- numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
val left = num - results.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
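
The comments above state the new rule, but the clamping is easy to miss when read as a diff. The following standalone Scala sketch (not part of the patch; the function and parameter names are mine) spells out the sizing logic this hunk implements: quadruple when nothing has been found yet, otherwise interpolate from the observed row density with a 50% overestimate and clamp the result to [1, 4 * partsScanned].

// Minimal sketch of the partition-count estimator, assuming take(num) has
// already scanned `partsScanned` partitions and collected `collectedSoFar` rows.
def estimateNumPartsToTry(num: Int, collectedSoFar: Int, partsScanned: Int): Int = {
  if (partsScanned == 0) {
    1                        // first round: try a single partition
  } else if (collectedSoFar == 0) {
    partsScanned * 4         // nothing found yet: quadruple and retry
  } else {
    // interpolate from the observed density, overestimate by 50%, then clamp
    // to [1, 4 * partsScanned] so the next job neither stalls nor balloons
    val interpolated = (1.5 * num * partsScanned / collectedSoFar).toInt - partsScanned
    math.min(math.max(interpolated, 1), partsScanned * 4)
  }
}
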
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 2aba40d152..71cabf61d4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1079,15 +1079,17 @@ abstract class RDD[T: ClassTag](
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
- // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
+ // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
// interpolate the number of partitions we need to try, but overestimate it by 50%.
+ // We also cap the estimate at 4x the number of partitions already scanned.
if (buf.size == 0) {
numPartsToTry = partsScanned * 4
} else {
- numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+ // the left side of max is >=1 whenever partsScanned >= 2
+ numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
+ numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
}
}
- numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
val left = num - buf.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
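
RDD.take already quadrupled on empty rounds; this hunk adds the same upper clamp. As a rough driver-side trace (plain Scala, no Spark jobs; it reuses the estimateNumPartsToTry sketch above and the numbers are made up), a take(100) on a completely empty RDD with 1000 partitions now widens the scan geometrically instead of the old AsyncRDDActions behaviour of rescanning totalParts - 1 partitions after one empty round:

// Assumes estimateNumPartsToTry from the sketch above is in scope.
object TakeScanTrace extends App {
  val num = 100          // rows requested by take(num)
  val totalParts = 1000  // partitions in a hypothetical, completely empty RDD
  var collected = 0      // stays 0 here because every partition is empty
  var partsScanned = 0
  while (collected < num && partsScanned < totalParts) {
    val tryNow = estimateNumPartsToTry(num, collected, partsScanned)
    val p = partsScanned until math.min(partsScanned + tryNow, totalParts)
    partsScanned += p.size
    println(s"tried ${p.size} partition(s), $partsScanned scanned in total")
  }
}

With these numbers the successive rounds try 1, 4, 20, 100, 500 and finally the remaining 375 partitions.
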
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index e13bab946c..15be4bfec9 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1070,10 +1070,13 @@ class RDD(object):
# If we didn't find any rows after the previous iteration,
# quadruple and retry. Otherwise, interpolate the number of
# partitions we need to try, but overestimate it by 50%.
+ # We also cap the estimate at 4x the number of partitions already scanned.
if len(items) == 0:
numPartsToTry = partsScanned * 4
else:
- numPartsToTry = int(1.5 * num * partsScanned / len(items))
+ # the first parameter of max is >= 1 whenever partsScanned >= 2
+ numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned
+ numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4)
left = num - len(items)
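
Taken together, the Scala and Python paths now share one growth rule. A hypothetical end-to-end check (Scala, local mode; the master URL, partition count and filter are illustrative only): take(5) on a sparse, heavily partitioned RDD should now run a few increasingly wide jobs rather than rescanning nearly every partition after the first empty round.

// Illustrative only: a sparse RDD with 1000 partitions and just 5 matching rows.
val sc = new org.apache.spark.SparkContext("local[4]", "take-growth-demo")
val sparse = sc.parallelize(1 to 1000000, 1000).filter(_ % 200000 == 0)
println(sparse.take(5).mkString(", "))   // driven by the incremental scan above
sc.stop()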