aboutsummaryrefslogtreecommitdiff
path: root/graphx
diff options
context:
space:
mode:
author: Ankur Dave <ankurdave@gmail.com> 2014-07-23 20:11:28 -0700
committer: Reynold Xin <rxin@apache.org> 2014-07-23 20:11:28 -0700
commit: 2d25e34814f81f288587f3277324cb655a5fb38d (patch)
tree: 0969341520725401d4d7309135b16ba5a499461e /graphx
parent: 60f0ae3d87c84fd96e1f4d0abf5be1f51870e7ab (diff)
downloadspark-2d25e34814f81f288587f3277324cb655a5fb38d.tar.gz
spark-2d25e34814f81f288587f3277324cb655a5fb38d.tar.bz2
spark-2d25e34814f81f288587f3277324cb655a5fb38d.zip
Replace RoutingTableMessage with pair
RoutingTableMessage was used to construct routing tables to enable joining VertexRDDs with partitioned edges. It stored three elements: the destination vertex ID, the source edge partition, and a byte specifying the position in which the edge partition referenced the vertex to enable join elimination. However, this was incompatible with sort-based shuffle (SPARK-2045). It was also slightly wasteful, because partition IDs are usually much smaller than 2^32, though this was mitigated by a custom serializer that used variable-length encoding. This commit replaces RoutingTableMessage with a pair of (VertexId, Int) where the Int encodes both the source partition ID (in the lower 30 bits) and the position (in the top 2 bits). Author: Ankur Dave <ankurdave@gmail.com> Closes #1553 from ankurdave/remove-RoutingTableMessage and squashes the following commits: 697e17b [Ankur Dave] Replace RoutingTableMessage with pair
Diffstat (limited to 'graphx')
-rw-r--r-- graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala 1
-rw-r--r-- graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala 47
-rw-r--r-- graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala 16
-rw-r--r-- graphx/src/main/scala/org/apache/spark/graphx/package.scala 2
4 files changed, 36 insertions, 30 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
index eea9fe9520..1948c978c3 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
@@ -35,7 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator {
def registerClasses(kryo: Kryo) {
kryo.register(classOf[Edge[Object]])
- kryo.register(classOf[RoutingTableMessage])
kryo.register(classOf[(VertexId, Object)])
kryo.register(classOf[EdgePartition[Object, Object]])
kryo.register(classOf[BitSet])
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index 502b112d31..a565d3b28b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -27,26 +27,13 @@ import org.apache.spark.util.collection.{BitSet, PrimitiveVector}
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
-/**
- * A message from the edge partition `pid` to the vertex partition containing `vid` specifying that
- * the edge partition references `vid` in the specified `position` (src, dst, or both).
-*/
-private[graphx]
-class RoutingTableMessage(
- var vid: VertexId,
- var pid: PartitionID,
- var position: Byte)
- extends Product2[VertexId, (PartitionID, Byte)] with Serializable {
- override def _1 = vid
- override def _2 = (pid, position)
- override def canEqual(that: Any): Boolean = that.isInstanceOf[RoutingTableMessage]
-}
+import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
private[graphx]
class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
/** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
- new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage](
+ new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage](
self, partitioner).setSerializer(new RoutingTableMessageSerializer)
}
}
@@ -62,6 +49,23 @@ object RoutingTableMessageRDDFunctions {
private[graphx]
object RoutingTablePartition {
+ /**
+ * A message from an edge partition to a vertex specifying the position in which the edge
+ * partition references the vertex (src, dst, or both). The edge partition is encoded in the lower
+ * 30 bits of the Int, and the position is encoded in the upper 2 bits of the Int.
+ */
+ type RoutingTableMessage = (VertexId, Int)
+
+ private def toMessage(vid: VertexId, pid: PartitionID, position: Byte): RoutingTableMessage = {
+ val positionUpper2 = position << 30
+ val pidLower30 = pid & 0x3FFFFFFF
+ (vid, positionUpper2 | pidLower30)
+ }
+
+ private def vidFromMessage(msg: RoutingTableMessage): VertexId = msg._1
+ private def pidFromMessage(msg: RoutingTableMessage): PartitionID = msg._2 & 0x3FFFFFFF
+ private def positionFromMessage(msg: RoutingTableMessage): Byte = (msg._2 >> 30).toByte
+
val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty)
/** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */
@@ -77,7 +81,9 @@ object RoutingTablePartition {
map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
}
map.iterator.map { vidAndPosition =>
- new RoutingTableMessage(vidAndPosition._1, pid, vidAndPosition._2)
+ val vid = vidAndPosition._1
+ val position = vidAndPosition._2
+ toMessage(vid, pid, position)
}
}
@@ -88,9 +94,12 @@ object RoutingTablePartition {
val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
for (msg <- iter) {
- pid2vid(msg.pid) += msg.vid
- srcFlags(msg.pid) += (msg.position & 0x1) != 0
- dstFlags(msg.pid) += (msg.position & 0x2) != 0
+ val vid = vidFromMessage(msg)
+ val pid = pidFromMessage(msg)
+ val position = positionFromMessage(msg)
+ pid2vid(pid) += vid
+ srcFlags(pid) += (position & 0x1) != 0
+ dstFlags(pid) += (position & 0x2) != 0
}
new RoutingTablePartition(pid2vid.zipWithIndex.map {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
index 2d98c24d69..3909efcdfc 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
@@ -24,9 +24,11 @@ import java.nio.ByteBuffer
import scala.reflect.ClassTag
-import org.apache.spark.graphx._
import org.apache.spark.serializer._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
+
private[graphx]
class RoutingTableMessageSerializer extends Serializer with Serializable {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
@@ -35,10 +37,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
new ShuffleSerializationStream(s) {
def writeObject[T: ClassTag](t: T): SerializationStream = {
val msg = t.asInstanceOf[RoutingTableMessage]
- writeVarLong(msg.vid, optimizePositive = false)
- writeUnsignedVarInt(msg.pid)
- // TODO: Write only the bottom two bits of msg.position
- s.write(msg.position)
+ writeVarLong(msg._1, optimizePositive = false)
+ writeInt(msg._2)
this
}
}
@@ -47,10 +47,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
new ShuffleDeserializationStream(s) {
override def readObject[T: ClassTag](): T = {
val a = readVarLong(optimizePositive = false)
- val b = readUnsignedVarInt()
- val c = s.read()
- if (c == -1) throw new EOFException
- new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T]
+ val b = readInt()
+ (a, b).asInstanceOf[T]
}
}
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/package.scala
index ff17edeaf8..6aab28ff05 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/package.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/package.scala
@@ -30,7 +30,7 @@ package object graphx {
*/
type VertexId = Long
- /** Integer identifer of a graph partition. */
+ /** Integer identifier of a graph partition. Must be less than 2^30. */
// TODO: Consider using Char.
type PartitionID = Int