aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala14
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala50
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala18
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala369
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala122
5 files changed, 6 insertions, 567 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
index 2c8b245955..12216d9d33 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
@@ -27,8 +27,6 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.graphx.impl.RoutingTablePartition
import org.apache.spark.graphx.impl.ShippableVertexPartition
import org.apache.spark.graphx.impl.VertexAttributeBlock
-import org.apache.spark.graphx.impl.RoutingTableMessageRDDFunctions._
-import org.apache.spark.graphx.impl.VertexRDDFunctions._
/**
* Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by
@@ -233,7 +231,7 @@ class VertexRDD[@specialized VD: ClassTag](
case _ =>
this.withPartitionsRDD[VD3](
partitionsRDD.zipPartitions(
- other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) {
+ other.partitionBy(this.partitioner.get), preservesPartitioning = true) {
(partIter, msgs) => partIter.map(_.leftJoin(msgs)(f))
}
)
@@ -277,7 +275,7 @@ class VertexRDD[@specialized VD: ClassTag](
case _ =>
this.withPartitionsRDD(
partitionsRDD.zipPartitions(
- other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) {
+ other.partitionBy(this.partitioner.get), preservesPartitioning = true) {
(partIter, msgs) => partIter.map(_.innerJoin(msgs)(f))
}
)
@@ -297,7 +295,7 @@ class VertexRDD[@specialized VD: ClassTag](
*/
def aggregateUsingIndex[VD2: ClassTag](
messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = {
- val shuffled = messages.copartitionWithVertices(this.partitioner.get)
+ val shuffled = messages.partitionBy(this.partitioner.get)
val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) =>
thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc))
}
@@ -371,7 +369,7 @@ object VertexRDD {
def apply[VD: ClassTag](vertices: RDD[(VertexId, VD)]): VertexRDD[VD] = {
val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match {
case Some(p) => vertices
- case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size))
+ case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size))
}
val vertexPartitions = vPartitioned.mapPartitions(
iter => Iterator(ShippableVertexPartition(iter)),
@@ -412,7 +410,7 @@ object VertexRDD {
): VertexRDD[VD] = {
val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match {
case Some(p) => vertices
- case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size))
+ case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size))
}
val routingTables = createRoutingTables(edges, vPartitioned.partitioner.get)
val vertexPartitions = vPartitioned.zipPartitions(routingTables, preservesPartitioning = true) {
@@ -454,7 +452,7 @@ object VertexRDD {
.setName("VertexRDD.createRoutingTables - vid2pid (aggregation)")
val numEdgePartitions = edges.partitions.size
- vid2pid.copartitionWithVertices(vertexPartitioner).mapPartitions(
+ vid2pid.partitionBy(vertexPartitioner).mapPartitions(
iter => Iterator(RoutingTablePartition.fromMsgs(numEdgePartitions, iter)),
preservesPartitioning = true)
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
deleted file mode 100644
index 714f3b81c9..0000000000
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.language.implicitConversions
-import scala.reflect.{classTag, ClassTag}
-
-import org.apache.spark.Partitioner
-import org.apache.spark.graphx.{PartitionID, VertexId}
-import org.apache.spark.rdd.{ShuffledRDD, RDD}
-
-
-private[graphx]
-class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) {
- def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = {
- val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner)
-
- // Set a custom serializer if the data is of int or double type.
- if (classTag[VD] == ClassTag.Int) {
- rdd.setSerializer(new IntAggMsgSerializer)
- } else if (classTag[VD] == ClassTag.Long) {
- rdd.setSerializer(new LongAggMsgSerializer)
- } else if (classTag[VD] == ClassTag.Double) {
- rdd.setSerializer(new DoubleAggMsgSerializer)
- }
- rdd
- }
-}
-
-private[graphx]
-object VertexRDDFunctions {
- implicit def rdd2VertexRDDFunctions[VD: ClassTag](rdd: RDD[(VertexId, VD)]) = {
- new VertexRDDFunctions(rdd)
- }
-}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index b27485953f..7a7fa91aad 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -30,24 +30,6 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
private[graphx]
-class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
- /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
- def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
- new ShuffledRDD[VertexId, Int, Int](
- self, partitioner).setSerializer(new RoutingTableMessageSerializer)
- }
-}
-
-private[graphx]
-object RoutingTableMessageRDDFunctions {
- import scala.language.implicitConversions
-
- implicit def rdd2RoutingTableMessageRDDFunctions(rdd: RDD[RoutingTableMessage]) = {
- new RoutingTableMessageRDDFunctions(rdd)
- }
-}
-
-private[graphx]
object RoutingTablePartition {
/**
* A message from an edge partition to a vertex specifying the position in which the edge
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
deleted file mode 100644
index 3909efcdfc..0000000000
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
+++ /dev/null
@@ -1,369 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.language.existentials
-
-import java.io.{EOFException, InputStream, OutputStream}
-import java.nio.ByteBuffer
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.serializer._
-
-import org.apache.spark.graphx._
-import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
-
-private[graphx]
-class RoutingTableMessageSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream): SerializationStream =
- new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T): SerializationStream = {
- val msg = t.asInstanceOf[RoutingTableMessage]
- writeVarLong(msg._1, optimizePositive = false)
- writeInt(msg._2)
- this
- }
- }
-
- override def deserializeStream(s: InputStream): DeserializationStream =
- new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readInt()
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-private[graphx]
-class VertexIdMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, _)]
- writeVarLong(msg._1, optimizePositive = false)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- (readVarLong(optimizePositive = false), null).asInstanceOf[T]
- }
- }
- }
-}
-
-/** A special shuffle serializer for AggregationMessage[Int]. */
-private[graphx]
-class IntAggMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, Int)]
- writeVarLong(msg._1, optimizePositive = false)
- writeUnsignedVarInt(msg._2)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readUnsignedVarInt()
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-/** A special shuffle serializer for AggregationMessage[Long]. */
-private[graphx]
-class LongAggMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, Long)]
- writeVarLong(msg._1, optimizePositive = false)
- writeVarLong(msg._2, optimizePositive = true)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readVarLong(optimizePositive = true)
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-/** A special shuffle serializer for AggregationMessage[Double]. */
-private[graphx]
-class DoubleAggMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, Double)]
- writeVarLong(msg._1, optimizePositive = false)
- writeDouble(msg._2)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readDouble()
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Helper classes to shorten the implementation of those special serializers.
-////////////////////////////////////////////////////////////////////////////////
-
-private[graphx]
-abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
- // The implementation should override this one.
- def writeObject[T: ClassTag](t: T): SerializationStream
-
- def writeInt(v: Int) {
- s.write(v >> 24)
- s.write(v >> 16)
- s.write(v >> 8)
- s.write(v)
- }
-
- def writeUnsignedVarInt(value: Int) {
- if ((value >>> 7) == 0) {
- s.write(value.toInt)
- } else if ((value >>> 14) == 0) {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7)
- } else if ((value >>> 21) == 0) {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7 | 0x80)
- s.write(value >>> 14)
- } else if ((value >>> 28) == 0) {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7 | 0x80)
- s.write(value >>> 14 | 0x80)
- s.write(value >>> 21)
- } else {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7 | 0x80)
- s.write(value >>> 14 | 0x80)
- s.write(value >>> 21 | 0x80)
- s.write(value >>> 28)
- }
- }
-
- def writeVarLong(value: Long, optimizePositive: Boolean) {
- val v = if (!optimizePositive) (value << 1) ^ (value >> 63) else value
- if ((v >>> 7) == 0) {
- s.write(v.toInt)
- } else if ((v >>> 14) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7).toInt)
- } else if ((v >>> 21) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14).toInt)
- } else if ((v >>> 28) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21).toInt)
- } else if ((v >>> 35) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28).toInt)
- } else if ((v >>> 42) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35).toInt)
- } else if ((v >>> 49) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35 | 0x80).toInt)
- s.write((v >>> 42).toInt)
- } else if ((v >>> 56) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35 | 0x80).toInt)
- s.write((v >>> 42 | 0x80).toInt)
- s.write((v >>> 49).toInt)
- } else {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35 | 0x80).toInt)
- s.write((v >>> 42 | 0x80).toInt)
- s.write((v >>> 49 | 0x80).toInt)
- s.write((v >>> 56).toInt)
- }
- }
-
- def writeLong(v: Long) {
- s.write((v >>> 56).toInt)
- s.write((v >>> 48).toInt)
- s.write((v >>> 40).toInt)
- s.write((v >>> 32).toInt)
- s.write((v >>> 24).toInt)
- s.write((v >>> 16).toInt)
- s.write((v >>> 8).toInt)
- s.write(v.toInt)
- }
-
- def writeDouble(v: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(v))
-
- override def flush(): Unit = s.flush()
-
- override def close(): Unit = s.close()
-}
-
-private[graphx]
-abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
- // The implementation should override this one.
- def readObject[T: ClassTag](): T
-
- def readInt(): Int = {
- val first = s.read()
- if (first < 0) throw new EOFException
- (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
- }
-
- def readUnsignedVarInt(): Int = {
- var value: Int = 0
- var i: Int = 0
- def readOrThrow(): Int = {
- val in = s.read()
- if (in < 0) throw new EOFException
- in & 0xFF
- }
- var b: Int = readOrThrow()
- while ((b & 0x80) != 0) {
- value |= (b & 0x7F) << i
- i += 7
- if (i > 35) throw new IllegalArgumentException("Variable length quantity is too long")
- b = readOrThrow()
- }
- value | (b << i)
- }
-
- def readVarLong(optimizePositive: Boolean): Long = {
- def readOrThrow(): Int = {
- val in = s.read()
- if (in < 0) throw new EOFException
- in & 0xFF
- }
- var b = readOrThrow()
- var ret: Long = b & 0x7F
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F) << 7
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F) << 14
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F) << 21
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 28
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 35
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 42
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 49
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= b.toLong << 56
- }
- }
- }
- }
- }
- }
- }
- }
- if (!optimizePositive) (ret >>> 1) ^ -(ret & 1) else ret
- }
-
- def readLong(): Long = {
- val first = s.read()
- if (first < 0) throw new EOFException()
- (first.toLong << 56) |
- (s.read() & 0xFF).toLong << 48 |
- (s.read() & 0xFF).toLong << 40 |
- (s.read() & 0xFF).toLong << 32 |
- (s.read() & 0xFF).toLong << 24 |
- (s.read() & 0xFF) << 16 |
- (s.read() & 0xFF) << 8 |
- (s.read() & 0xFF)
- }
-
- def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong())
-
- override def close(): Unit = s.close()
-}
-
-private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance {
-
- override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
-
- override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
- throw new UnsupportedOperationException
-
- override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
- throw new UnsupportedOperationException
-
- // The implementation should override the following two.
- override def serializeStream(s: OutputStream): SerializationStream
- override def deserializeStream(s: InputStream): DeserializationStream
-}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
deleted file mode 100644
index 864cb1fdf0..0000000000
--- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx
-
-import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
-
-import scala.util.Random
-import scala.reflect.ClassTag
-
-import org.scalatest.FunSuite
-
-import org.apache.spark._
-import org.apache.spark.graphx.impl._
-import org.apache.spark.serializer.SerializationStream
-
-
-class SerializerSuite extends FunSuite with LocalSparkContext {
-
- test("IntAggMsgSerializer") {
- val outMsg = (4: VertexId, 5)
- val bout = new ByteArrayOutputStream
- val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: (VertexId, Int) = inStrm.readObject()
- val inMsg2: (VertexId, Int) = inStrm.readObject()
- assert(outMsg === inMsg1)
- assert(outMsg === inMsg2)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
- test("LongAggMsgSerializer") {
- val outMsg = (4: VertexId, 1L << 32)
- val bout = new ByteArrayOutputStream
- val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: (VertexId, Long) = inStrm.readObject()
- val inMsg2: (VertexId, Long) = inStrm.readObject()
- assert(outMsg === inMsg1)
- assert(outMsg === inMsg2)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
- test("DoubleAggMsgSerializer") {
- val outMsg = (4: VertexId, 5.0)
- val bout = new ByteArrayOutputStream
- val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: (VertexId, Double) = inStrm.readObject()
- val inMsg2: (VertexId, Double) = inStrm.readObject()
- assert(outMsg === inMsg1)
- assert(outMsg === inMsg2)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
- test("variable long encoding") {
- def testVarLongEncoding(v: Long, optimizePositive: Boolean) {
- val bout = new ByteArrayOutputStream
- val stream = new ShuffleSerializationStream(bout) {
- def writeObject[T: ClassTag](t: T): SerializationStream = {
- writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive)
- this
- }
- }
- stream.writeObject(v)
-
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val dstream = new ShuffleDeserializationStream(bin) {
- def readObject[T: ClassTag](): T = {
- readVarLong(optimizePositive).asInstanceOf[T]
- }
- }
- val read = dstream.readObject[Long]()
- assert(read === v)
- }
-
- // Test all variable encoding code path (each branch uses 7 bits, i.e. 1L << 7 difference)
- val d = Random.nextLong() % 128
- Seq[Long](0, 1L << 0 + d, 1L << 7 + d, 1L << 14 + d, 1L << 21 + d, 1L << 28 + d, 1L << 35 + d,
- 1L << 42 + d, 1L << 49 + d, 1L << 56 + d, 1L << 63 + d).foreach { number =>
- testVarLongEncoding(number, optimizePositive = false)
- testVarLongEncoding(number, optimizePositive = true)
- testVarLongEncoding(-number, optimizePositive = false)
- testVarLongEncoding(-number, optimizePositive = true)
- }
- }
-}