aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala9
1 files changed, 6 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index 1cd6f2a896..377326f873 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -35,6 +35,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
+import org.apache.spark.RangePartitioner
/**
* Regression model for isotonic regression.
@@ -408,9 +409,11 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
*/
private def parallelPoolAdjacentViolators(
input: RDD[(Double, Double, Double)]): Array[(Double, Double, Double)] = {
- val parallelStepResult = input
- .sortBy(x => (x._2, x._1))
- .glom()
+ val keyedInput = input.keyBy(_._2)
+ val parallelStepResult = keyedInput
+ .partitionBy(new RangePartitioner(keyedInput.getNumPartitions, keyedInput))
+ .values
+ .mapPartitions(p => Iterator(p.toArray.sortBy(x => (x._2, x._1))))
.flatMap(poolAdjacentViolators)
.collect()
.sortBy(x => (x._2, x._1)) // Sort again because collect() doesn't promise ordering.