author     Matei Zaharia <matei@eecs.berkeley.edu>  2013-08-13 18:09:40 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>  2013-10-08 23:14:38 -0700
commit     b535db7d89abd59713ce83ae937d06193a04441e (patch)
tree       c871c5023b8b6802d3a30f4179e7e66a8b983fb0 /core
parent     e67d5b962a2adddc073cfc9c99be9012fbb69838 (diff)
Added a fast and low-memory append-only map implementation for cogroup
and parallel reduce operations
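
The core of the change is the changeValue(key, updateFunc) primitive on the new map: it finds a key and writes its new value in a single probe sequence, where the JHashMap-based code replaced below paid for a get() followed by a put(). A minimal sketch of the intended usage (word counting is illustrative, not part of this commit):

    import org.apache.spark.util.AppendOnlyMap

    // One probe sequence per record: updateFunc receives whether the key was
    // already present and its old value, and returns the value to store.
    val counts = new AppendOnlyMap[String, Int]()
    for (word <- Seq("the", "quick", "the")) {
      counts.changeValue(word, (hadValue, oldValue) =>
        if (hadValue) oldValue + 1 else 1)
    }
    assert(counts("the") == 2 && counts("quick") == 1)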
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/Aggregator.scala                    |  38
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala              |  16
-rw-r--r--  core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala            | 241
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala  |   4
-rw-r--r--  core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala       | 141
5 files changed, 411 insertions(+), 29 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 3ef402926e..fa1419df18 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -21,8 +21,10 @@ import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions._
+import org.apache.spark.util.AppendOnlyMap
+
/** A set of functions used to aggregate data.
- *
+ *
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue function.
@@ -33,27 +35,29 @@ case class Aggregator[K, V, C] (
mergeCombiners: (C, C) => C) {
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
- val combiners = new JHashMap[K, C]
- for (kv <- iter) {
- val oldC = combiners.get(kv._1)
- if (oldC == null) {
- combiners.put(kv._1, createCombiner(kv._2))
- } else {
- combiners.put(kv._1, mergeValue(oldC, kv._2))
- }
+ val combiners = new AppendOnlyMap[K, C]
+ for ((k, v) <- iter) {
+ combiners.changeValue(k, (hadValue, oldValue) => {
+ if (hadValue) {
+ mergeValue(oldValue, v)
+ } else {
+ createCombiner(v)
+ }
+ })
}
combiners.iterator
}
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
- val combiners = new JHashMap[K, C]
- iter.foreach { case(k, c) =>
- val oldC = combiners.get(k)
- if (oldC == null) {
- combiners.put(k, c)
- } else {
- combiners.put(k, mergeCombiners(oldC, c))
- }
+ val combiners = new AppendOnlyMap[K, C]
+ for ((k, c) <- iter) {
+ combiners.changeValue(k, (hadValue, oldValue) => {
+ if (hadValue) {
+ mergeCombiners(oldValue, c)
+ } else {
+ c
+ }
+ })
}
combiners.iterator
}
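
With the Aggregator rewrite above, both combine paths funnel every record through changeValue, so a key that already has a combiner is found and updated in one lookup instead of two. A self-contained sketch of the rewritten path, using the case class modified above (the sum aggregator is illustrative):

    import org.apache.spark.Aggregator

    // A sum aggregator: values and combiners are both Ints.
    val agg = Aggregator[String, Int, Int](
      (v: Int) => v,                 // createCombiner
      (c: Int, v: Int) => c + v,     // mergeValue
      (c1: Int, c2: Int) => c1 + c2) // mergeCombiners

    val combined = agg.combineValuesByKey(Iterator(("x", 1), ("y", 2), ("x", 3))).toMap
    assert(combined == Map("x" -> 4, "y" -> 2))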
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 0187256a8e..f6dd8a65cb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark.util.AppendOnlyMap
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -105,17 +106,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
// e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
- val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
+ val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
- val seq = map.get(k)
- if (seq != null) {
- seq
- } else {
- val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
- map.put(k, seq)
- seq
- }
+ map.changeValue(k, (hadValue, oldValue) => {
+ if (hadValue) oldValue else Array.fill(numRdds)(new ArrayBuffer[Any])
+ })
}
val ser = SparkEnv.get.serializerManager.get(serializerClass)
@@ -134,7 +130,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
}
}
}
- JavaConversions.mapAsScalaMap(map).iterator
+ map.iterator
}
override def clearDependencies() {
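
In CoGroupedRDD, each key maps to one ArrayBuffer per parent RDD, and getSeq now creates that array of buffers lazily inside a single changeValue call rather than with a get-then-put. A sketch of the same access pattern outside the RDD machinery (numRdds and the sample records are stand-ins):

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.util.AppendOnlyMap

    val numRdds = 2 // e.g. `a cogroup b` has two parent RDDs
    val map = new AppendOnlyMap[String, Seq[ArrayBuffer[Any]]]

    def getSeq(k: String): Seq[ArrayBuffer[Any]] =
      map.changeValue(k, (hadValue, oldValue) =>
        if (hadValue) oldValue else Array.fill(numRdds)(new ArrayBuffer[Any]))

    getSeq("k")(0) += "a1" // record for key "k" from parent RDD 0
    getSeq("k")(1) += "b1" // record for key "k" from parent RDD 1
    getSeq("k")(0) += "a2"
    assert(getSeq("k") == Seq(ArrayBuffer("a1", "a2"), ArrayBuffer("b1")))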
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
new file mode 100644
index 0000000000..416b93ea41
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * A simple open hash table optimized for the append-only use case, where keys
+ * are never removed, but the value for each key may be changed.
+ *
+ * This implementation uses quadratic probing with a power-of-2 hash table
+ * size, which is guaranteed to explore all spaces for each key (see
+ * http://en.wikipedia.org/wiki/Quadratic_probing).
+ *
+ * TODO: Cache the hash values of each key? java.util.HashMap does that.
+ */
+private[spark]
+class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable {
+ if (!isPowerOf2(initialCapacity)) {
+ throw new IllegalArgumentException("Initial capacity must be power of 2")
+ }
+ if (initialCapacity >= (1 << 30)) {
+ throw new IllegalArgumentException("Can't make capacity bigger than 2^29 elements")
+ }
+
+ private var capacity = initialCapacity
+ private var curSize = 0
+
+ // Holds keys and values in the same array for memory locality; specifically, the order of
+ // elements is key0, value0, key1, value1, key2, value2, etc.
+ private var data = new Array[AnyRef](2 * capacity)
+
+ // Treat the null key differently so we can use nulls in "data" to represent empty items.
+ private var haveNullValue = false
+ private var nullValue: V = null.asInstanceOf[V]
+
+ private val LOAD_FACTOR = 0.7
+
+ /** Get the value for a given key */
+ def apply(key: K): V = {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ return nullValue
+ }
+ val mask = capacity - 1
+ var pos = rehash(k.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (curKey.eq(k) || curKey.eq(null) || curKey == k) {
+ return data(2 * pos + 1).asInstanceOf[V]
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ return null.asInstanceOf[V]
+ }
+
+ /** Set the value for a key */
+ def update(key: K, value: V) {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ if (!haveNullValue) {
+ incrementSize()
+ }
+ nullValue = value
+ haveNullValue = true
+ return
+ }
+ val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
+ if (isNewEntry) {
+ incrementSize()
+ }
+ }
+
+ /**
+ * Set the value for key to updateFunc(hadValue, oldValue), where oldValue will be the old value
+ * for key, if any, or null otherwise. Returns the newly updated value.
+ */
+ def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ if (!haveNullValue) {
+ incrementSize()
+ }
+ nullValue = updateFunc(haveNullValue, nullValue)
+ haveNullValue = true
+ return nullValue
+ }
+ val mask = capacity - 1
+ var pos = rehash(k.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (curKey.eq(null)) {
+ val newValue = updateFunc(false, null.asInstanceOf[V])
+ data(2 * pos) = k
+ data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+ incrementSize()
+ return newValue
+ } else if (curKey.eq(k) || curKey == k) {
+ val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
+ data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+ return newValue
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ null.asInstanceOf[V] // Never reached but needed to keep compiler happy
+ }
+
+ /** Iterator method from Iterable */
+ override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
+ var pos = -1
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def nextValue(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
+ }
+ while (pos < capacity) {
+ if (!data(2 * pos).eq(null)) {
+ return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ }
+ pos += 1
+ }
+ null
+ }
+
+ override def hasNext: Boolean = nextValue() != null
+
+ override def next(): (K, V) = {
+ val value = nextValue()
+ if (value == null) {
+ throw new NoSuchElementException("End of iterator")
+ }
+ pos += 1
+ value
+ }
+ }
+
+ override def size: Int = curSize
+
+ /** Increase table size by 1, rehashing if necessary */
+ private def incrementSize() {
+ curSize += 1
+ if (curSize > LOAD_FACTOR * capacity) {
+ growTable()
+ }
+ }
+
+ /**
+ * Re-hash a value to deal better with hash functions that don't differ
+ * in the lower bits, similar to java.util.HashMap
+ */
+ private def rehash(h: Int): Int = {
+ val r = h ^ (h >>> 20) ^ (h >>> 12)
+ r ^ (r >>> 7) ^ (r >>> 4)
+ }
+
+ /**
+ * Put an entry into a table represented by data, returning true if
+ * this increases the size of the table or false otherwise. Assumes
+ * that "data" has at least one empty slot.
+ */
+ private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
+ val mask = (data.length / 2) - 1
+ var pos = rehash(key.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (curKey.eq(null)) {
+ data(2 * pos) = key
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return true
+ } else if (curKey.eq(key) || curKey == key) {
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return false
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ return false // Never reached but needed to keep compiler happy
+ }
+
+ /** Double the table's size and re-hash everything */
+ private def growTable() {
+ val newCapacity = capacity * 2
+ if (newCapacity >= (1 << 30)) {
+ // We can't make the table this big because we want an array of 2x
+ // that size for our data, but array sizes are at most Int.MaxValue
+ throw new Exception("Can't make capacity bigger than 2^29 elements")
+ }
+ val newData = new Array[AnyRef](2 * newCapacity)
+ var pos = 0
+ while (pos < capacity) {
+ if (!data(2 * pos).eq(null)) {
+ putInto(newData, data(2 * pos), data(2 * pos + 1))
+ }
+ pos += 1
+ }
+ data = newData
+ capacity = newCapacity
+ }
+
+ private def isPowerOf2(num: Int): Boolean = {
+ var n = num
+ while (n > 0) {
+ if (n == 1) {
+ return true
+ } else if (n % 2 == 1) {
+ return false
+ } else {
+ n /= 2
+ }
+ }
+ return false
+ }
+}
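
Why the power-of-2 restriction: each miss advances the probe by i on the i-th step, so the k-th probe lands at home + k(k+1)/2 mod capacity. For a power-of-2 capacity these triangular offsets visit every slot before any repeat, which is the guarantee the class comment cites; for other sizes the probe can cycle without covering the table. A quick self-contained check of that property (not part of the commit):

    // Does probing with increments 1, 2, 3, ... visit all `capacity` slots
    // before revisiting one? True for every power-of-2 capacity.
    def probeCoversTable(capacity: Int): Boolean = {
      val mask = capacity - 1
      val seen = new Array[Boolean](capacity)
      var pos = 0
      var i = 1
      var visited = 0
      while (visited < capacity && !seen(pos)) {
        seen(pos) = true
        visited += 1
        pos = (pos + i) & mask
        i += 1
      }
      visited == capacity
    }

    assert((0 to 16).forall(k => probeCoversTable(1 << k)))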
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 41a161e08a..794c3e8f8f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -44,7 +44,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
- d.count
+ d.count()
Thread.sleep(1000)
listener.stageInfos.size should be (1)
@@ -55,7 +55,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)}
d4.setName("A Cogroup")
- d4.collectAsMap
+ d4.collectAsMap()
Thread.sleep(1000)
listener.stageInfos.size should be (4)
diff --git a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala
new file mode 100644
index 0000000000..d1e36781ef
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.mutable.HashSet
+
+import org.scalatest.FunSuite
+
+class AppendOnlyMapSuite extends FunSuite {
+ test("initialization") {
+ val goodMap1 = new AppendOnlyMap[Int, Int](1)
+ assert(goodMap1.size === 0)
+ val goodMap2 = new AppendOnlyMap[Int, Int](256)
+ assert(goodMap2.size === 0)
+ intercept[IllegalArgumentException] {
+ new AppendOnlyMap[Int, Int](255) // Invalid map size: not power of 2
+ }
+ intercept[IllegalArgumentException] {
+ new AppendOnlyMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29
+ }
+ intercept[IllegalArgumentException] {
+ new AppendOnlyMap[Int, Int](-1) // Invalid map size: not power of 2
+ }
+ }
+
+ test("object keys and values") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map("" + i) === "" + i)
+ }
+ assert(map("0") === null)
+ assert(map("101") === null)
+ assert(map(null) === null)
+ val set = new HashSet[(String, String)]
+ for ((k, v) <- map) { // Test the foreach method
+ set += ((k, v))
+ }
+ assert(set === (1 to 100).map(_.toString).map(x => (x, x)).toSet)
+ }
+
+ test("primitive keys and values") {
+ val map = new AppendOnlyMap[Int, Int]()
+ for (i <- 1 to 100) {
+ map(i) = i
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map(i) === i)
+ }
+ assert(map(0) === null)
+ assert(map(101) === null)
+ val set = new HashSet[(Int, Int)]
+ for ((k, v) <- map) { // Test the foreach method
+ set += ((k, v))
+ }
+ assert(set === (1 to 100).map(x => (x, x)).toSet)
+ }
+
+ test("null keys") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ assert(map.size === 100)
+ assert(map(null) === null)
+ map(null) = "hello"
+ assert(map.size === 101)
+ assert(map(null) === "hello")
+ }
+
+ test("null values") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = null
+ }
+ assert(map.size === 100)
+ assert(map("1") === null)
+ assert(map(null) === null)
+ assert(map.size === 100)
+ map(null) = null
+ assert(map.size === 101)
+ assert(map(null) === null)
+ }
+
+ test("changeValue") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ val res = map.changeValue("" + i, (hadValue, oldValue) => {
+ assert(hadValue === true)
+ assert(oldValue === "" + i)
+ oldValue + "!"
+ })
+ assert(res === i + "!")
+ }
+ // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
+ // bug where changeValue would return the wrong result when the map grew on that insert
+ for (i <- 101 to 400) {
+ val res = map.changeValue("" + i, (hadValue, oldValue) => {
+ assert(hadValue === false)
+ i + "!"
+ })
+ assert(res === i + "!")
+ }
+ assert(map.size === 400)
+ assert(map(null) === null)
+ map.changeValue(null, (hadValue, oldValue) => {
+ assert(hadValue === false)
+ "null!"
+ })
+ assert(map.size === 401)
+ map.changeValue(null, (hadValue, oldValue) => {
+ assert(hadValue === true)
+ assert(oldValue === "null!")
+ "null!!"
+ })
+ assert(map.size === 401)
+ }
+}
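
One subtlety the null-key and null-value tests pin down: an empty slot in data is marked by a null key, so a genuinely null key cannot live in the array and is tracked by the separate haveNullValue/nullValue fields instead. That is how the map distinguishes "no entry" from "an entry whose value is null", as in this short sketch (assuming the class above is on the classpath):

    import org.apache.spark.util.AppendOnlyMap

    val m = new AppendOnlyMap[String, String]()
    assert(m(null) == null) // no null-key entry yet; apply returns the default
    assert(m.size == 0)
    m(null) = null          // now there IS an entry, whose value happens to be null
    assert(m.size == 1)     // only the haveNullValue flag tells the cases apart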