aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala36
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala2
2 files changed, 22 insertions, 16 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 9fab1d78ab..b073eba8a1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -35,11 +35,10 @@ import org.apache.spark.util.Utils
* @param preferredLocation the preferred location for this partition
*/
private[spark] case class CoalescedRDDPartition(
- index: Int,
- @transient rdd: RDD[_],
- parentsIndices: Array[Int],
- @transient preferredLocation: String = ""
- ) extends Partition {
+ index: Int,
+ @transient rdd: RDD[_],
+ parentsIndices: Array[Int],
+ @transient preferredLocation: Option[String] = None) extends Partition {
var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_))
@throws(classOf[IOException])
@@ -55,9 +54,10 @@ private[spark] case class CoalescedRDDPartition(
* @return locality of this coalesced partition between 0 and 1
*/
def localFraction: Double = {
- val loc = parents.count(p =>
- rdd.context.getPreferredLocs(rdd, p.index).map(tl => tl.host).contains(preferredLocation))
-
+ val loc = parents.count { p =>
+ val parentPreferredLocations = rdd.context.getPreferredLocs(rdd, p.index).map(_.host)
+ preferredLocation.exists(parentPreferredLocations.contains)
+ }
if (parents.size == 0) 0.0 else (loc.toDouble / parents.size.toDouble)
}
}
@@ -73,9 +73,9 @@ private[spark] case class CoalescedRDDPartition(
* @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance
*/
private[spark] class CoalescedRDD[T: ClassTag](
- @transient var prev: RDD[T],
- maxPartitions: Int,
- balanceSlack: Double = 0.10)
+ @transient var prev: RDD[T],
+ maxPartitions: Int,
+ balanceSlack: Double = 0.10)
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
override def getPartitions: Array[Partition] = {
@@ -113,7 +113,7 @@ private[spark] class CoalescedRDD[T: ClassTag](
* @return the machine most preferred by split
*/
override def getPreferredLocations(partition: Partition): Seq[String] = {
- List(partition.asInstanceOf[CoalescedRDDPartition].preferredLocation)
+ partition.asInstanceOf[CoalescedRDDPartition].preferredLocation.toSeq
}
}
@@ -147,7 +147,7 @@ private[spark] class CoalescedRDD[T: ClassTag](
*
*/
-private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
+private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.size < o2.size
def compare(o1: Option[PartitionGroup], o2: Option[PartitionGroup]): Boolean =
@@ -341,8 +341,14 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc
}
}
-private[spark] case class PartitionGroup(prefLoc: String = "") {
+private case class PartitionGroup(prefLoc: Option[String] = None) {
var arr = mutable.ArrayBuffer[Partition]()
-
def size = arr.size
}
+
+private object PartitionGroup {
+ def apply(prefLoc: String): PartitionGroup = {
+ require(prefLoc != "", "Preferred location must not be empty")
+ PartitionGroup(Some(prefLoc))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e079ca3b1e..8de634a676 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -297,7 +297,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
test("coalesced RDDs with locality") {
val data3 = sc.makeRDD(List((1,List("a","c")), (2,List("a","b","c")), (3,List("b"))))
val coal3 = data3.coalesce(3)
- val list3 = coal3.partitions.map(p => p.asInstanceOf[CoalescedRDDPartition].preferredLocation)
+ val list3 = coal3.partitions.flatMap(_.asInstanceOf[CoalescedRDDPartition].preferredLocation)
assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped")
// RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5