path: root/core/src/main/scala/spark/ParallelCollection.scala
blob: 9b57ae3b4f2bc26782627f66ad0c49387f0efffe
package spark

import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer

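/**
 * A Split holding the actual values for one slice of a parallelized collection. The split is
 * Serializable, so its values travel along with whichever task computes it.
 */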
private[spark] class ParallelCollectionSplit[T: ClassManifest](
    val rddId: Long,
    val slice: Int,
    values: Seq[T])
  extends Split with Serializable {
  
  def iterator(): Iterator[T] = values.iterator

  override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt

  override def equals(other: Any): Boolean = other match {
    case that: ParallelCollectionSplit[_] => (this.rddId == that.rddId && this.slice == that.slice)
    case _ => false
  }

  override val index: Int = slice
}

private[spark] class ParallelCollection[T: ClassManifest](
    sc: SparkContext, 
    @transient data: Seq[T],
    numSlices: Int)
  extends RDD[T](sc) {
  // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
  // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
  // instead.

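  // Eagerly slice the data into numSlices pieces, wrapping each one in a ParallelCollectionSplit.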
  @transient
  val splits_ = {
    val slices = ParallelCollection.slice(data, numSlices).toArray
    slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
  }

  override def splits = splits_.asInstanceOf[Array[Split]]

  override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator
  
  override def preferredLocations(s: Split): Seq[String] = Nil
  
  override val dependencies: List[Dependency[_]] = Nil
}
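
// Minimal usage sketch (illustrative only, not part of the original file; the object name is
// hypothetical). An RDD backed by ParallelCollection is normally obtained through
// SparkContext.parallelize rather than constructed directly; the exact SparkContext constructor
// arguments below are an assumption for this version of Spark.
private object ParallelCollectionUsageSketch {
  def example(): Unit = {
    val sc = new SparkContext("local", "ParallelCollection example")
    val rdd = sc.parallelize(1 to 100, 4) // backed by a ParallelCollection with 4 splits
    println(rdd.splits.length)            // expected: 4
    sc.stop()
  }
}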

private object ParallelCollection {
  /**
   * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
   * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes 
   * it efficient to run Spark over RDDs representing large sets of numbers.
   */
  def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
    if (numSlices < 1) {
      throw new IllegalArgumentException("Positive number of slices required")
    }
    seq match {
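      // Inclusive ranges are rewritten as equivalent exclusive Ranges (the end is pushed one unit
      // further in the direction of the step) and re-sliced through the plain Range case below.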
      case r: Range.Inclusive => {
        val sign = if (r.step < 0) {
          -1 
        } else {
          1
        }
        slice(new Range(
            r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
      }
      case r: Range => {
        (0 until numSlices).map(i => {
          val start = ((i * r.length.toLong) / numSlices).toInt
          val end = (((i+1) * r.length.toLong) / numSlices).toInt
          new Range(r.start + start * r.step, r.start + end * r.step, r.step)
        }).asInstanceOf[Seq[Seq[T]]]
      }
      case nr: NumericRange[_] => { // For ranges of Long, Double, BigInteger, etc
        val slices = new ArrayBuffer[Seq[T]](numSlices)
        val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything
        var r = nr
        for (i <- 0 until numSlices) {
          slices += r.take(sliceSize).asInstanceOf[Seq[T]]
          r = r.drop(sliceSize)
        }
        slices
      }
      case _ => {
        val array = seq.toArray  // To prevent O(n^2) operations for List etc
        (0 until numSlices).map(i => {
          val start = ((i * array.length.toLong) / numSlices).toInt
          val end = (((i+1) * array.length.toLong) / numSlices).toInt
          array.slice(start, end).toSeq
        })
      }
    }
  }
}
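
// Illustrative sketch (not part of the original file; the object name is hypothetical). It
// exercises the Range special-casing documented on slice() above: Range inputs come back as
// sub-Ranges with no elements materialized, while other Seqs fall through to the array-copying
// default case.
private object SliceExample {
  def main(args: Array[String]): Unit = {
    val rangeSlices = ParallelCollection.slice(1 to 10, 3)
    rangeSlices.foreach(s => println(s.mkString(", ")))
    // expected output:
    //   1, 2, 3
    //   4, 5, 6
    //   7, 8, 9, 10

    val listSlices = ParallelCollection.slice(List("a", "b", "c", "d"), 2)
    listSlices.foreach(s => println(s.mkString(", ")))
    // expected output:
    //   a, b
    //   c, d
  }
}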