diff options
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala | 9 |
1 file changed, 6 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index 1cd6f2a896..377326f873 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -35,6 +35,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.RangePartitioner
 
 /**
  * Regression model for isotonic regression.
@@ -408,9 +409,11 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
    */
   private def parallelPoolAdjacentViolators(
       input: RDD[(Double, Double, Double)]): Array[(Double, Double, Double)] = {
-    val parallelStepResult = input
-      .sortBy(x => (x._2, x._1))
-      .glom()
+    val keyedInput = input.keyBy(_._2)
+    val parallelStepResult = keyedInput
+      .partitionBy(new RangePartitioner(keyedInput.getNumPartitions, keyedInput))
+      .values
+      .mapPartitions(p => Iterator(p.toArray.sortBy(x => (x._2, x._1))))
       .flatMap(poolAdjacentViolators)
       .collect()
       .sortBy(x => (x._2, x._1)) // Sort again because collect() doesn't promise ordering.