author    Patrick Wendell <pwendell@gmail.com>  2014-01-13 22:58:38 -0800
committer Patrick Wendell <pwendell@gmail.com>  2014-01-13 22:58:38 -0800
commit    4a805aff5e381752afb2bfd579af908d623743ed (patch)
tree      78f81dfdf6bcaa47ff87f8c882f829eae59c2bdb /graphx
parent    945fe7a37ea3189b5a9f8a74e5c2fa9c1088ebfc (diff)
parent    80e73ed0004cceb47a450c79aa4faa598502fa45 (diff)
Merge pull request #367 from ankurdave/graphx
GraphX: Unifying Graphs and Tables

GraphX extends Spark's distributed fault-tolerant collections API and interactive console with a new graph API which leverages recent advances in graph systems (e.g., [GraphLab](http://graphlab.org)) to enable users to easily and interactively build, transform, and reason about graph-structured data at scale. See http://amplab.github.io/graphx/.

Thanks to @jegonzal, @rxin, @ankurdave, @dcrankshaw, @jianpingjwang, @amatsukawa, @kellrott, and @adamnovak.

Tasks left:

- [x] Graph-level uncache
- [x] Uncache previous iterations in Pregel
- [x] ~~Uncache previous iterations in GraphLab~~ (postponed to post-release)
- [x] Describe GC issue with GraphLab
- [ ] Write `docs/graphx-programming-guide.md`
- [x] Mention future Bagel support in docs
- [ ] Section on caching/uncaching in docs: As with Spark, cache something that is used more than once. In an iterative algorithm, try to cache and force (i.e., materialize) something every iteration, then uncache the cached things that depended on the newly materialized RDD but that won't be referenced again (see the sketch after this list).
- [x] Undo modifications to core collections and instead copy them to org.apache.spark.graphx
- [x] Make Graph serializable to work around capture in the Spark shell
- [x] Rename graph -> graphx in package name and subproject
- [x] Remove standalone PageRank
- [x] ~~Fix amplab/graphx#52 by checking `iter.hasNext`~~
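The caching/uncaching guidance in the task list above can be made concrete with a minimal sketch. Nothing below is part of the patch: `initialGraph`, `vprog`, `sendMsg`, `mergeMsg`, and `maxIterations` are assumed to be defined by the surrounding program; only the cache/force/uncache pattern matters.

{{{
// Minimal sketch of the per-iteration cache, force, uncache pattern described above.
var g = initialGraph.cache()
var i = 0
while (i < maxIterations) {
  // Cache and force (materialize) the messages for this iteration.
  val messages = g.mapReduceTriplets(sendMsg, mergeMsg).cache()
  messages.count()
  val prevG = g
  // Build and force the next graph before dropping the structures it was built from.
  g = g.outerJoinVertices(messages) { (vid, attr, msgOpt) =>
    msgOpt.map(msg => vprog(vid, attr, msg)).getOrElse(attr)
  }.cache()
  g.vertices.count()
  // Uncache things that depended on the newly materialized RDDs but won't be referenced again.
  messages.unpersist(blocking = false)
  prevG.unpersistVertices(blocking = false)
  i += 1
}
}}}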
Diffstat (limited to 'graphx')
-rw-r--r--  graphx/data/followers.txt  8
-rw-r--r--  graphx/data/users.txt  7
-rw-r--r--  graphx/pom.xml  67
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/Edge.scala  45
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala  44
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala  102
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala  49
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/Graph.scala  405
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala  31
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala  72
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala  301
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala  103
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala  139
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala  347
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala  220
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala  45
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala  42
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala  379
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala  98
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala  195
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala  65
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala  395
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala  261
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/impl/package.scala  7
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala  136
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala  38
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala  147
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala  138
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala  94
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala  76
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/package.scala  18
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala  117
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala  218
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala  153
-rw-r--r--  graphx/src/test/resources/log4j.properties  28
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala  66
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala  273
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala  28
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala  41
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala  183
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala  85
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala  76
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala  113
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala  113
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala  119
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala  31
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala  57
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala  70
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala  93
49 files changed, 5938 insertions, 0 deletions
diff --git a/graphx/data/followers.txt b/graphx/data/followers.txt
new file mode 100644
index 0000000000..7bb8e900e2
--- /dev/null
+++ b/graphx/data/followers.txt
@@ -0,0 +1,8 @@
+2 1
+4 1
+1 2
+6 3
+7 3
+7 6
+6 7
+3 7
diff --git a/graphx/data/users.txt b/graphx/data/users.txt
new file mode 100644
index 0000000000..982d19d50b
--- /dev/null
+++ b/graphx/data/users.txt
@@ -0,0 +1,7 @@
+1,BarackObama,Barack Obama
+2,ladygaga,Goddess of Love
+3,jeresig,John Resig
+4,justinbieber,Justin Bieber
+6,matei_zaharia,Matei Zaharia
+7,odersky,Martin Odersky
+8,anonsys
diff --git a/graphx/pom.xml b/graphx/pom.xml
new file mode 100644
index 0000000000..3e5faf230d
--- /dev/null
+++ b/graphx/pom.xml
@@ -0,0 +1,67 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>0.9.0-incubating-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-graphx_2.10</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project GraphX</name>
+ <url>http://spark-project.org/</url>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-server</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scalacheck</groupId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala b/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala
new file mode 100644
index 0000000000..738a38b27f
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala
@@ -0,0 +1,45 @@
+package org.apache.spark.graphx
+
+/**
+ * A single directed edge consisting of a source id, target id,
+ * and the data associated with the edge.
+ *
+ * @tparam ED type of the edge attribute
+ *
+ * @param srcId The vertex id of the source vertex
+ * @param dstId The vertex id of the target vertex
+ * @param attr The attribute associated with the edge
+ */
+case class Edge[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED] (
+ var srcId: VertexID = 0,
+ var dstId: VertexID = 0,
+ var attr: ED = null.asInstanceOf[ED])
+ extends Serializable {
+
+ /**
+ * Given one vertex in the edge return the other vertex.
+ *
+ * @param vid the id of one of the two vertices on the edge.
+ * @return the id of the other vertex on the edge.
+ */
+ def otherVertexId(vid: VertexID): VertexID =
+ if (srcId == vid) dstId else { assert(dstId == vid); srcId }
+
+ /**
+ * Return the relative direction of the edge to the corresponding
+ * vertex.
+ *
+ * @param vid the id of one of the two vertices in the edge.
+ * @return the relative direction of the edge to the corresponding
+ * vertex.
+ */
+ def relativeDirection(vid: VertexID): EdgeDirection =
+ if (vid == srcId) EdgeDirection.Out else { assert(vid == dstId); EdgeDirection.In }
+}
+
+object Edge {
+ private[graphx] def lexicographicOrdering[ED] = new Ordering[Edge[ED]] {
+ override def compare(a: Edge[ED], b: Edge[ED]): Int =
+ (if (a.srcId != b.srcId) a.srcId - b.srcId else a.dstId - b.dstId).toInt
+ }
+}
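A quick illustrative use of the Edge helpers above (a sketch, not part of the patch; the values are made up):

{{{
val e = Edge(srcId = 1L, dstId = 2L, attr = "follows")
e.otherVertexId(1L)       // returns 2L
e.relativeDirection(2L)   // returns EdgeDirection.In
}}}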
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala
new file mode 100644
index 0000000000..f265764006
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala
@@ -0,0 +1,44 @@
+package org.apache.spark.graphx
+
+/**
+ * The direction of a directed edge relative to a vertex.
+ */
+class EdgeDirection private (private val name: String) extends Serializable {
+ /**
+ * Reverse the direction of an edge: In becomes Out,
+ * Out becomes In, and Both and Either remain the same.
+ */
+ def reverse: EdgeDirection = this match {
+ case EdgeDirection.In => EdgeDirection.Out
+ case EdgeDirection.Out => EdgeDirection.In
+ case EdgeDirection.Either => EdgeDirection.Either
+ case EdgeDirection.Both => EdgeDirection.Both
+ }
+
+ override def toString: String = "EdgeDirection." + name
+
+ override def equals(o: Any) = o match {
+ case other: EdgeDirection => other.name == name
+ case _ => false
+ }
+
+ override def hashCode = name.hashCode
+}
+
+
+/**
+ * A set of [[EdgeDirection]]s.
+ */
+object EdgeDirection {
+ /** Edges arriving at a vertex. */
+ final val In = new EdgeDirection("In")
+
+ /** Edges originating from a vertex. */
+ final val Out = new EdgeDirection("Out")
+
+ /** Edges originating from *or* arriving at a vertex of interest. */
+ final val Either = new EdgeDirection("Either")
+
+ /** Edges originating from *and* arriving at a vertex of interest. */
+ final val Both = new EdgeDirection("Both")
+}
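The reverse semantics described above can be checked directly (illustrative only):

{{{
assert(EdgeDirection.In.reverse == EdgeDirection.Out)
assert(EdgeDirection.Out.reverse == EdgeDirection.In)
assert(EdgeDirection.Either.reverse == EdgeDirection.Either)
}}}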
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
new file mode 100644
index 0000000000..832b7816fe
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
@@ -0,0 +1,102 @@
+package org.apache.spark.graphx
+
+import scala.reflect.{classTag, ClassTag}
+
+import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext}
+import org.apache.spark.graphx.impl.EdgePartition
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * `EdgeRDD[ED]` extends `RDD[Edge[ED]]` by storing the edges in columnar format on each partition
+ * for performance.
+ */
+class EdgeRDD[@specialized ED: ClassTag](
+ val partitionsRDD: RDD[(PartitionID, EdgePartition[ED])])
+ extends RDD[Edge[ED]](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) {
+
+ partitionsRDD.setName("EdgeRDD")
+
+ override protected def getPartitions: Array[Partition] = partitionsRDD.partitions
+
+ /**
+ * If `partitionsRDD` already has a partitioner, use it. Otherwise assume that the
+ * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new
+ * partitioner that allows co-partitioning with `partitionsRDD`.
+ */
+ override val partitioner =
+ partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD)))
+
+ override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = {
+ firstParent[(PartitionID, EdgePartition[ED])].iterator(part, context).next._2.iterator
+ }
+
+ override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect()
+
+ override def persist(newLevel: StorageLevel): EdgeRDD[ED] = {
+ partitionsRDD.persist(newLevel)
+ this
+ }
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ override def persist(): EdgeRDD[ED] = persist(StorageLevel.MEMORY_ONLY)
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ override def cache(): EdgeRDD[ED] = persist()
+
+ override def unpersist(blocking: Boolean = true): EdgeRDD[ED] = {
+ partitionsRDD.unpersist(blocking)
+ this
+ }
+
+ private[graphx] def mapEdgePartitions[ED2: ClassTag](f: (PartitionID, EdgePartition[ED]) => EdgePartition[ED2])
+ : EdgeRDD[ED2] = {
+ new EdgeRDD[ED2](partitionsRDD.mapPartitions({ iter =>
+ val (pid, ep) = iter.next()
+ Iterator(Tuple2(pid, f(pid, ep)))
+ }, preservesPartitioning = true))
+ }
+
+ /**
+ * Map the values in this EdgeRDD, preserving the structure but changing the values.
+ *
+ * @tparam ED2 the new edge value type
+ * @param f the function from an edge to a new edge value
+ * @return a new EdgeRDD containing the new edge values
+ */
+ def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2] =
+ mapEdgePartitions((pid, part) => part.map(f))
+
+ /**
+ * Reverse all the edges in this RDD.
+ *
+ * @return a new EdgeRDD containing all the edges reversed
+ */
+ def reverse: EdgeRDD[ED] = mapEdgePartitions((pid, part) => part.reverse)
+
+ /**
+ * Inner joins this EdgeRDD with another EdgeRDD, assuming both are partitioned using the same
+ * [[PartitionStrategy]].
+ *
+ * @param other the EdgeRDD to join with
+ * @param f the join function applied to corresponding values of `this` and `other`
+ * @return a new EdgeRDD containing only edges that appear in both `this` and `other`, with values
+ * supplied by `f`
+ */
+ def innerJoin[ED2: ClassTag, ED3: ClassTag]
+ (other: EdgeRDD[ED2])
+ (f: (VertexID, VertexID, ED, ED2) => ED3): EdgeRDD[ED3] = {
+ val ed2Tag = classTag[ED2]
+ val ed3Tag = classTag[ED3]
+ new EdgeRDD[ED3](partitionsRDD.zipPartitions(other.partitionsRDD, true) {
+ (thisIter, otherIter) =>
+ val (pid, thisEPart) = thisIter.next()
+ val (_, otherEPart) = otherIter.next()
+ Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag)))
+ })
+ }
+
+ private[graphx] def collectVertexIDs(): RDD[VertexID] = {
+ partitionsRDD.flatMap { case (_, p) => Array.concat(p.srcIds, p.dstIds) }
+ }
+}
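The EdgeRDD operations above are typically reached through a graph's `edges` member. A minimal sketch, assuming `graph` is an existing Graph[Int, Int]:

{{{
val edges: EdgeRDD[Int] = graph.edges
val doubled: EdgeRDD[Int] = edges.mapValues(e => e.attr * 2)   // same structure, new values
val reversed: EdgeRDD[Int] = edges.reverse                     // all edges flipped
}}}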
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala
new file mode 100644
index 0000000000..4253b24b5a
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala
@@ -0,0 +1,49 @@
+package org.apache.spark.graphx
+
+/**
+ * An edge triplet represents an edge along with the vertex attributes of its neighboring vertices.
+ *
+ * @tparam VD the type of the vertex attribute.
+ * @tparam ED the type of the edge attribute
+ */
+class EdgeTriplet[VD, ED] extends Edge[ED] {
+ /**
+ * The source vertex attribute
+ */
+ var srcAttr: VD = _ //nullValue[VD]
+
+ /**
+ * The destination vertex attribute
+ */
+ var dstAttr: VD = _ //nullValue[VD]
+
+ /**
+ * Set the edge properties of this triplet.
+ */
+ protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD,ED] = {
+ srcId = other.srcId
+ dstId = other.dstId
+ attr = other.attr
+ this
+ }
+
+ /**
+ * Given one vertex in the edge return the other vertex.
+ *
+ * @param vid the id of one of the two vertices on the edge
+ * @return the attribute for the other vertex on the edge
+ */
+ def otherVertexAttr(vid: VertexID): VD =
+ if (srcId == vid) dstAttr else { assert(dstId == vid); srcAttr }
+
+ /**
+ * Get the vertex object for the given vertex in the edge.
+ *
+ * @param vid the id of one of the two vertices on the edge
+ * @return the attr for the vertex with that id
+ */
+ def vertexAttr(vid: VertexID): VD =
+ if (srcId == vid) srcAttr else { assert(dstId == vid); dstAttr }
+
+ override def toString = ((srcId, srcAttr), (dstId, dstAttr), attr).toString()
+}
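Because a triplet exposes both endpoint attributes alongside the edge attribute, edges whose endpoints carry equal attributes can be counted directly (a sketch; `graph` is assumed to exist):

{{{
val sameAttrEdges = graph.triplets.filter(t => t.srcAttr == t.dstAttr).count()
}}}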
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
new file mode 100644
index 0000000000..9dd05ade0a
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -0,0 +1,405 @@
+package org.apache.spark.graphx
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx.impl._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * The Graph abstractly represents a graph with arbitrary objects
+ * associated with vertices and edges. The graph provides basic
+ * operations to access and manipulate the data associated with
+ * vertices and edges as well as the underlying structure. Like Spark
+ * RDDs, the graph is a functional data-structure in which mutating
+ * operations return new graphs.
+ *
+ * @note [[GraphOps]] contains additional convenience operations and graph algorithms.
+ *
+ * @tparam VD the vertex attribute type
+ * @tparam ED the edge attribute type
+ */
+abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializable {
+
+ /**
+ * An RDD containing the vertices and their associated attributes.
+ *
+ * @note vertex ids are unique.
+ * @return an RDD containing the vertices in this graph
+ */
+ val vertices: VertexRDD[VD]
+
+ /**
+ * An RDD containing the edges and their associated attributes. The entries in the RDD contain
+ * just the source id and target id along with the edge data.
+ *
+ * @return an RDD containing the edges in this graph
+ *
+ * @see [[Edge]] for the edge type.
+ * @see [[triplets]] to get an RDD which contains all the edges
+ * along with their vertex data.
+ *
+ */
+ val edges: EdgeRDD[ED]
+
+ /**
+ * An RDD containing the edge triplets, which are edges along with the vertex data associated with
+ * the adjacent vertices. The caller should use [[edges]] if the vertex data are not needed, i.e.
+ * if only the edge data and adjacent vertex ids are needed.
+ *
+ * @return an RDD containing edge triplets
+ *
+ * @example This operation might be used to evaluate a graph
+ * coloring where we would like to check that both vertices are a
+ * different color.
+ * {{{
+ * type Color = Int
+ * val graph: Graph[Color, Int] = GraphLoader.edgeListFile("hdfs://file.tsv")
+ * val numInvalid = graph.triplets.map(e => if (e.srcAttr == e.dstAttr) 1 else 0).sum
+ * }}}
+ */
+ val triplets: RDD[EdgeTriplet[VD, ED]]
+
+ /**
+ * Caches the vertices and edges associated with this graph at the specified storage level.
+ *
+ * @param newLevel the level at which to cache the graph.
+ *
+ * @return A reference to this graph for convenience.
+ */
+ def persist(newLevel: StorageLevel = StorageLevel.MEMORY_ONLY): Graph[VD, ED]
+
+ /**
+ * Caches the vertices and edges associated with this graph. This is used to
+ * pin a graph in memory enabling multiple queries to reuse the same
+ * construction process.
+ */
+ def cache(): Graph[VD, ED]
+
+ /**
+ * Uncaches only the vertices of this graph, leaving the edges alone. This is useful in iterative
+ * algorithms that modify the vertex attributes but reuse the edges. This method can be used to
+ * uncache the vertex attributes of previous iterations once they are no longer needed, improving
+ * GC performance.
+ */
+ def unpersistVertices(blocking: Boolean = true): Graph[VD, ED]
+
+ /**
+ * Repartitions the edges in the graph according to `partitionStrategy`.
+ */
+ def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED]
+
+ /**
+ * Transforms each vertex attribute in the graph using the map function.
+ *
+ * @note The new graph has the same structure. As a consequence the underlying index structures
+ * can be reused.
+ *
+ * @param map the function from a vertex object to a new vertex value
+ *
+ * @tparam VD2 the new vertex data type
+ *
+ * @example We might use this operation to change the vertex values
+ * from one type to another to initialize an algorithm.
+ * {{{
+ * val rawGraph: Graph[(), ()] = Graph.textFile("hdfs://file")
+ * val root = 42
+ * var bfsGraph = rawGraph.mapVertices[Int]((vid, data) => if (vid == root) 0 else Int.MaxValue)
+ * }}}
+ *
+ */
+ def mapVertices[VD2: ClassTag](map: (VertexID, VD) => VD2): Graph[VD2, ED]
+
+ /**
+ * Transforms each edge attribute in the graph using the map function. The map function is not
+ * passed the vertex value for the vertices adjacent to the edge. If vertex values are desired,
+ * use `mapTriplets`.
+ *
+ * @note This graph is not changed and that the new graph has the
+ * same structure. As a consequence the underlying index structures
+ * can be reused.
+ *
+ * @param map the function from an edge object to a new edge value.
+ *
+ * @tparam ED2 the new edge data type
+ *
+ * @example This function might be used to initialize edge
+ * attributes.
+ *
+ */
+ def mapEdges[ED2: ClassTag](map: Edge[ED] => ED2): Graph[VD, ED2] = {
+ mapEdges((pid, iter) => iter.map(map))
+ }
+
+ /**
+ * Transforms each edge attribute using the map function, passing it a whole partition at a
+ * time. The map function is given an iterator over edges within a logical partition as well as
+ * the partition's ID, and it should return a new iterator over the new values of each edge. The
+ * new iterator's elements must correspond one-to-one with the old iterator's elements. If
+ * adjacent vertex values are desired, use `mapTriplets`.
+ *
+ * @note This does not change the structure of the
+ * graph or modify the values of this graph. As a consequence
+ * the underlying index structures can be reused.
+ *
+ * @param map a function that takes a partition id and an iterator
+ * over all the edges in the partition, and must return an iterator over
+ * the new values for each edge in the order of the input iterator
+ *
+ * @tparam ED2 the new edge data type
+ *
+ */
+ def mapEdges[ED2: ClassTag](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2])
+ : Graph[VD, ED2]
+
+ /**
+ * Transforms each edge attribute using the map function, passing it the adjacent vertex attributes
+ * as well. If adjacent vertex values are not required, consider using `mapEdges` instead.
+ *
+ * @note This does not change the structure of the
+ * graph or modify the values of this graph. As a consequence
+ * the underlying index structures can be reused.
+ *
+ * @param map the function from an edge object to a new edge value.
+ *
+ * @tparam ED2 the new edge data type
+ *
+ * @example This function might be used to initialize edge
+ * attributes based on the attributes associated with each vertex.
+ * {{{
+ * val rawGraph: Graph[Int, Int] = someLoadFunction()
+ * val graph = rawGraph.mapTriplets[Int](edge =>
+ * edge.srcAttr - edge.dstAttr)
+ * }}}
+ *
+ */
+ def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
+ mapTriplets((pid, iter) => iter.map(map))
+ }
+
+ /**
+ * Transforms each edge attribute a partition at a time using the map function, passing it the
+ * adjacent vertex attributes as well. The map function is given an iterator over edge triplets
+ * within a logical partition and should yield a new iterator over the new values of each edge in
+ * the order in which they are provided. If adjacent vertex values are not required, consider
+ * using `mapEdges` instead.
+ *
+ * @note This does not change the structure of the
+ * graph or modify the values of this graph. As a consequence
+ * the underlying index structures can be reused.
+ *
+ * @param map the iterator transform
+ *
+ * @tparam ED2 the new edge data type
+ *
+ */
+ def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2])
+ : Graph[VD, ED2]
+
+ /**
+ * Reverses all edges in the graph. If this graph contains an edge from a to b then the returned
+ * graph contains an edge from b to a.
+ */
+ def reverse: Graph[VD, ED]
+
+ /**
+ * Restricts the graph to only the vertices and edges satisfying the predicates. The resulting
+ * subgraph satisfies
+ *
+ * {{{
+ * V' = {v : for all v in V where vpred(v)}
+ * E' = {(u,v): for all (u,v) in E where epred((u,v)) && vpred(u) && vpred(v)}
+ * }}}
+ *
+ * @param epred the edge predicate, which takes a triplet and
+ * evaluates to true if the edge is to remain in the subgraph. Note
+ * that only edges where both vertices satisfy the vertex
+ * predicate are considered.
+ *
+ * @param vpred the vertex predicate, which takes a vertex object and
+ * evaluates to true if the vertex is to be included in the subgraph
+ *
+ * @return the subgraph containing only the vertices and edges that
+ * satisfy the predicates
+ */
+ def subgraph(
+ epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
+ vpred: (VertexID, VD) => Boolean = ((v, d) => true))
+ : Graph[VD, ED]
+
+ /**
+ * Restricts the graph to only the vertices and edges that are also in `other`, but keeps the
+ * attributes from this graph.
+ * @param other the graph to project this graph onto
+ * @return a graph with vertices and edges that exist in both the current graph and `other`,
+ * with vertex and edge data from the current graph
+ */
+ def mask[VD2: ClassTag, ED2: ClassTag](other: Graph[VD2, ED2]): Graph[VD, ED]
+
+ /**
+ * Merges multiple edges between two vertices into a single edge. For correct results, the graph
+ * must have been partitioned using [[partitionBy]].
+ *
+ * @param merge the user-supplied commutative associative function to merge edge attributes
+ * for duplicate edges.
+ *
+ * @return The resulting graph with a single edge for each (source, dest) vertex pair.
+ */
+ def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED]
+
+ /**
+ * Aggregates values from the neighboring edges and vertices of each vertex. The user supplied
+ * `mapFunc` function is invoked on each edge of the graph, generating 0 or more "messages" to be
+ * "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of
+ * the map phase destined to each vertex.
+ *
+ * @tparam A the type of "message" to be sent to each vertex
+ *
+ * @param mapFunc the user defined map function which returns 0 or
+ * more messages to neighboring vertices
+ *
+ * @param reduceFunc the user defined reduce function which should
+ * be commutative and associative and is used to combine the output
+ * of the map phase
+ *
+ * @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to consider
+ * when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on edges with
+ * destination in the active set. If the direction is `Out`, `mapFunc` will only be run on edges
+ * originating from vertices in the active set. If the direction is `Either`, `mapFunc` will be
+ * run on edges with *either* vertex in the active set. If the direction is `Both`, `mapFunc` will
+ * be run on edges with *both* vertices in the active set. The active set must have the same index
+ * as the graph's vertices.
+ *
+ * @example We can use this function to compute the in-degree of each
+ * vertex
+ * {{{
+ * val rawGraph: Graph[(),()] = Graph.textFile("twittergraph")
+ * val inDeg: RDD[(VertexID, Int)] =
+ * rawGraph.mapReduceTriplets[Int](et => Iterator((et.dstId, 1)), _ + _)
+ * }}}
+ *
+ * @note By expressing computation at the edge level we achieve
+ * maximum parallelism. This is one of the core functions in the
+ * Graph API in that it enables neighborhood-level computation. For
+ * example this function can be used to count neighbors satisfying a
+ * predicate or implement PageRank.
+ *
+ */
+ def mapReduceTriplets[A: ClassTag](
+ mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)],
+ reduceFunc: (A, A) => A,
+ activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
+ : VertexRDD[A]
+
+ /**
+ * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The
+ * input table should contain at most one entry for each vertex. If no entry in `other` is
+ * provided for a particular vertex in the graph, the map function receives `None`.
+ *
+ * @tparam U the type of entry in the table of updates
+ * @tparam VD2 the new vertex value type
+ *
+ * @param other the table to join with the vertices in the graph.
+ * The table should contain at most one entry for each vertex.
+ * @param mapFunc the function used to compute the new vertex values.
+ * The map function is invoked for all vertices, even those
+ * that do not have a corresponding entry in the table.
+ *
+ * @example This function is used to update the vertices with new values based on external data.
+ * For example we could add the out-degree to each vertex record:
+ *
+ * {{{
+ * val rawGraph: Graph[_, _] = Graph.textFile("webgraph")
+ * val outDeg: RDD[(VertexID, Int)] = rawGraph.outDegrees
+ * val graph = rawGraph.outerJoinVertices(outDeg) {
+ * (vid, data, optDeg) => optDeg.getOrElse(0)
+ * }
+ * }}}
+ */
+ def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexID, U)])
+ (mapFunc: (VertexID, VD, Option[U]) => VD2)
+ : Graph[VD2, ED]
+
+ /**
+ * The associated [[GraphOps]] object.
+ */
+ // Save a copy of the GraphOps object so there is always one unique GraphOps object
+ // for a given Graph object, and thus the lazy vals in GraphOps would work as intended.
+ val ops = new GraphOps(this)
+} // end of Graph
+
+
+/**
+ * The Graph object contains a collection of routines used to construct graphs from RDDs.
+ */
+object Graph {
+
+ /**
+ * Construct a graph from a collection of edges encoded as vertex id pairs.
+ *
+ * @param rawEdges a collection of edges in (src, dst) form
+ * @param uniqueEdges if multiple identical edges are found they are combined and the edge
+ * attribute is set to the sum. Otherwise duplicate edges are treated as separate. To enable
+ * `uniqueEdges`, a [[PartitionStrategy]] must be provided.
+ *
+ * @return a graph with edge attributes containing either the count of duplicate edges or 1
+ * (if `uniqueEdges` is `None`) and vertex attributes containing the total degree of each vertex.
+ */
+ def fromEdgeTuples[VD: ClassTag](
+ rawEdges: RDD[(VertexID, VertexID)],
+ defaultValue: VD,
+ uniqueEdges: Option[PartitionStrategy] = None): Graph[VD, Int] =
+ {
+ val edges = rawEdges.map(p => Edge(p._1, p._2, 1))
+ val graph = GraphImpl(edges, defaultValue)
+ uniqueEdges match {
+ case Some(p) => graph.partitionBy(p).groupEdges((a, b) => a + b)
+ case None => graph
+ }
+ }
+
+ /**
+ * Construct a graph from a collection of edges.
+ *
+ * @param edges the RDD containing the set of edges in the graph
+ * @param defaultValue the default vertex attribute to use for each vertex
+ *
+ * @return a graph with edge attributes described by `edges` and vertices
+ * given by all vertices in `edges` with value `defaultValue`
+ */
+ def fromEdges[VD: ClassTag, ED: ClassTag](
+ edges: RDD[Edge[ED]],
+ defaultValue: VD): Graph[VD, ED] = {
+ GraphImpl(edges, defaultValue)
+ }
+
+ /**
+ * Construct a graph from a collection of vertices and
+ * edges with attributes. Duplicate vertices are picked arbitrarily and
+ * vertices found in the edge collection but not in the input
+ * vertices are assigned the default attribute.
+ *
+ * @tparam VD the vertex attribute type
+ * @tparam ED the edge attribute type
+ * @param vertices the "set" of vertices and their attributes
+ * @param edges the collection of edges in the graph
+ * @param defaultVertexAttr the default vertex attribute to use for vertices that are
+ * mentioned in edges but not in vertices
+ */
+ def apply[VD: ClassTag, ED: ClassTag](
+ vertices: RDD[(VertexID, VD)],
+ edges: RDD[Edge[ED]],
+ defaultVertexAttr: VD = null.asInstanceOf[VD]): Graph[VD, ED] = {
+ GraphImpl(vertices, edges, defaultVertexAttr)
+ }
+
+ /**
+ * Implicitly extracts the [[GraphOps]] member from a graph.
+ *
+ * To improve modularity the Graph type only contains a small set of basic operations.
+ * All the convenience operations are defined in the [[GraphOps]] class which may be
+ * shared across multiple graph implementations.
+ */
+ implicit def graphToGraphOps[VD: ClassTag, ED: ClassTag](g: Graph[VD, ED]) = g.ops
+} // end of Graph object
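A hedged construction example using the factory methods above (assumes an existing SparkContext `sc`; the data is made up):

{{{
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

val edgeTuples: RDD[(VertexID, VertexID)] = sc.parallelize(Seq((1L, 2L), (2L, 3L), (3L, 1L)))
val g1: Graph[Int, Int] = Graph.fromEdgeTuples(edgeTuples, defaultValue = 1)

val users: RDD[(VertexID, String)] = sc.parallelize(Seq((1L, "alice"), (2L, "bob")))
val follows: RDD[Edge[String]] = sc.parallelize(Seq(Edge(1L, 2L, "follows")))
val g2: Graph[String, String] = Graph(users, follows, defaultVertexAttr = "unknown")
}}}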
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
new file mode 100644
index 0000000000..d79bdf9618
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
@@ -0,0 +1,31 @@
+package org.apache.spark.graphx
+
+import com.esotericsoftware.kryo.Kryo
+
+import org.apache.spark.graphx.impl._
+import org.apache.spark.serializer.KryoRegistrator
+import org.apache.spark.util.collection.BitSet
+import org.apache.spark.util.BoundedPriorityQueue
+
+/**
+ * Registers GraphX classes with Kryo for improved performance.
+ */
+class GraphKryoRegistrator extends KryoRegistrator {
+
+ def registerClasses(kryo: Kryo) {
+ kryo.register(classOf[Edge[Object]])
+ kryo.register(classOf[MessageToPartition[Object]])
+ kryo.register(classOf[VertexBroadcastMsg[Object]])
+ kryo.register(classOf[(VertexID, Object)])
+ kryo.register(classOf[EdgePartition[Object]])
+ kryo.register(classOf[BitSet])
+ kryo.register(classOf[VertexIdToIndexMap])
+ kryo.register(classOf[VertexAttributeBlock[Object]])
+ kryo.register(classOf[PartitionStrategy])
+ kryo.register(classOf[BoundedPriorityQueue[Object]])
+ kryo.register(classOf[EdgeDirection])
+
+ // This avoids a large number of hash table lookups.
+ kryo.setReferences(false)
+ }
+}
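To use this registrator, an application would typically point Spark's standard Kryo settings at it. A sketch (the two property names are the usual Spark serializer settings, not something introduced by this patch):

{{{
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("graphx-app")
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
val sc = new SparkContext(conf)
}}}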
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
new file mode 100644
index 0000000000..5904aa3a28
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
@@ -0,0 +1,72 @@
+package org.apache.spark.graphx
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.graphx.impl.{EdgePartitionBuilder, GraphImpl}
+
+/**
+ * Provides utilities for loading [[Graph]]s from files.
+ */
+object GraphLoader extends Logging {
+
+ /**
+ * Loads a graph from an edge list formatted file where each line contains two integers: a source
+ * id and a target id. Skips lines that begin with `#`.
+ *
+ * If desired the edges can be automatically oriented in the positive
+ * direction (source Id < target Id) by setting `canonicalOrientation` to
+ * true.
+ *
+ * @example Loads a file in the following format:
+ * {{{
+ * # Comment Line
+ * # Source Id <\t> Target Id
+ * 1 -5
+ * 1 2
+ * 2 7
+ * 1 8
+ * }}}
+ *
+ * @param sc SparkContext
+ * @param path the path to the file (e.g., /home/data/file or hdfs://file)
+ * @param canonicalOrientation whether to orient edges in the positive
+ * direction
+ * @param minEdgePartitions the number of partitions for the
+ * edge RDD
+ */
+ def edgeListFile(
+ sc: SparkContext,
+ path: String,
+ canonicalOrientation: Boolean = false,
+ minEdgePartitions: Int = 1)
+ : Graph[Int, Int] =
+ {
+ val startTime = System.currentTimeMillis
+
+ // Parse the edge data table directly into edge partitions
+ val edges = sc.textFile(path, minEdgePartitions).mapPartitionsWithIndex { (pid, iter) =>
+ val builder = new EdgePartitionBuilder[Int]
+ iter.foreach { line =>
+ if (!line.isEmpty && line(0) != '#') {
+ val lineArray = line.split("\\s+")
+ if (lineArray.length < 2) {
+ logWarning("Invalid line: " + line)
+ }
+ val srcId = lineArray(0).toLong
+ val dstId = lineArray(1).toLong
+ if (canonicalOrientation && srcId > dstId) {
+ builder.add(dstId, srcId, 1)
+ } else {
+ builder.add(srcId, dstId, 1)
+ }
+ }
+ }
+ Iterator((pid, builder.toEdgePartition))
+ }.cache()
+ edges.count()
+
+ logInfo("It took %d ms to load the edges".format(System.currentTimeMillis - startTime))
+
+ GraphImpl.fromEdgePartitions(edges, defaultVertexAttr = 1)
+ } // end of edgeListFile
+
+}
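For example, the `followers.txt` edge list added by this patch could be loaded as follows (a sketch; assumes a SparkContext `sc` and that the application runs from the repository root):

{{{
val followerGraph: Graph[Int, Int] = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
followerGraph.edges.count()   // 8 edges in the sample file
}}}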
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
new file mode 100644
index 0000000000..f10e63f059
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -0,0 +1,301 @@
+package org.apache.spark.graphx
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.SparkException
+import org.apache.spark.graphx.lib._
+import org.apache.spark.rdd.RDD
+
+/**
+ * Contains additional functionality for [[Graph]]. All operations are expressed in terms of the
+ * efficient GraphX API. This class is implicitly constructed for each Graph object.
+ *
+ * @tparam VD the vertex attribute type
+ * @tparam ED the edge attribute type
+ */
+class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Serializable {
+
+ /** The number of edges in the graph. */
+ lazy val numEdges: Long = graph.edges.count()
+
+ /** The number of vertices in the graph. */
+ lazy val numVertices: Long = graph.vertices.count()
+
+ /**
+ * The in-degree of each vertex in the graph.
+ * @note Vertices with no in-edges are not returned in the resulting RDD.
+ */
+ lazy val inDegrees: VertexRDD[Int] = degreesRDD(EdgeDirection.In)
+
+ /**
+ * The out-degree of each vertex in the graph.
+ * @note Vertices with no out-edges are not returned in the resulting RDD.
+ */
+ lazy val outDegrees: VertexRDD[Int] = degreesRDD(EdgeDirection.Out)
+
+ /**
+ * The degree of each vertex in the graph.
+ * @note Vertices with no edges are not returned in the resulting RDD.
+ */
+ lazy val degrees: VertexRDD[Int] = degreesRDD(EdgeDirection.Either)
+
+ /**
+ * Computes the neighboring vertex degrees.
+ *
+ * @param edgeDirection the direction along which to collect neighboring vertex attributes
+ */
+ private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = {
+ if (edgeDirection == EdgeDirection.In) {
+ graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _)
+ } else if (edgeDirection == EdgeDirection.Out) {
+ graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _)
+ } else { // EdgeDirection.Either
+ graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _)
+ }
+ }
+
+ /**
+ * Collect the neighbor vertex ids for each vertex.
+ *
+ * @param edgeDirection the direction along which to collect
+ * neighboring vertices
+ *
+ * @return the set of neighboring ids for each vertex
+ */
+ def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] = {
+ val nbrs =
+ if (edgeDirection == EdgeDirection.Either) {
+ graph.mapReduceTriplets[Array[VertexID]](
+ mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))),
+ reduceFunc = _ ++ _
+ )
+ } else if (edgeDirection == EdgeDirection.Out) {
+ graph.mapReduceTriplets[Array[VertexID]](
+ mapFunc = et => Iterator((et.srcId, Array(et.dstId))),
+ reduceFunc = _ ++ _)
+ } else if (edgeDirection == EdgeDirection.In) {
+ graph.mapReduceTriplets[Array[VertexID]](
+ mapFunc = et => Iterator((et.dstId, Array(et.srcId))),
+ reduceFunc = _ ++ _)
+ } else {
+ throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
+ "direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
+ }
+ graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
+ nbrsOpt.getOrElse(Array.empty[VertexID])
+ }
+ } // end of collectNeighborIds
+
+ /**
+ * Collect the neighbor vertex attributes for each vertex.
+ *
+ * @note This function could be highly inefficient on power-law
+ * graphs where high degree vertices may force a large amount of
+ * information to be collected to a single location.
+ *
+ * @param edgeDirection the direction along which to collect
+ * neighboring vertices
+ *
+ * @return the vertex set of neighboring vertex attributes for each vertex
+ */
+ def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] = {
+ val nbrs = graph.mapReduceTriplets[Array[(VertexID,VD)]](
+ edge => {
+ val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr)))
+ val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr)))
+ edgeDirection match {
+ case EdgeDirection.Either => Iterator(msgToSrc, msgToDst)
+ case EdgeDirection.In => Iterator(msgToDst)
+ case EdgeDirection.Out => Iterator(msgToSrc)
+ case EdgeDirection.Both =>
+ throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use" +
+ "EdgeDirection.Either instead.")
+ }
+ },
+ (a, b) => a ++ b)
+
+ graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
+ nbrsOpt.getOrElse(Array.empty[(VertexID, VD)])
+ }
+ } // end of collectNeighbor
+
+ /**
+ * Join the vertices with an RDD and then apply a function from the
+ * vertex and RDD entry to a new vertex value. The input table
+ * should contain at most one entry for each vertex. If no entry is
+ * provided the map function is skipped and the old value is used.
+ *
+ * @tparam U the type of entry in the table of updates
+ * @param table the table to join with the vertices in the graph.
+ * The table should contain at most one entry for each vertex.
+ * @param mapFunc the function used to compute the new vertex
+ * values. The map function is invoked only for vertices with a
+ * corresponding entry in the table otherwise the old vertex value
+ * is used.
+ *
+ * @example This function is used to update the vertices with new
+ * values based on external data. For example we could add the out
+ * degree to each vertex record
+ *
+ * {{{
+ * val rawGraph: Graph[Int, Int] = GraphLoader.edgeListFile(sc, "webgraph")
+ * .mapVertices((vid, data) => 0)
+ * val outDeg: RDD[(VertexID, Int)] = rawGraph.outDegrees
+ * val graph = rawGraph.joinVertices(outDeg) { (vid, data, deg) => deg }
+ * }}}
+ *
+ */
+ def joinVertices[U: ClassTag](table: RDD[(VertexID, U)])(mapFunc: (VertexID, VD, U) => VD)
+ : Graph[VD, ED] = {
+ val uf = (id: VertexID, data: VD, o: Option[U]) => {
+ o match {
+ case Some(u) => mapFunc(id, data, u)
+ case None => data
+ }
+ }
+ graph.outerJoinVertices(table)(uf)
+ }
+
+ /**
+ * Filter the graph by computing some values to filter on, and applying the predicates.
+ *
+ * @param preprocess a function to compute new vertex and edge data before filtering
+ * @param epred edge pred to filter on after preprocess, see more details under
+ * [[org.apache.spark.graphx.Graph#subgraph]]
+ * @param vpred vertex pred to filter on after preprocess, see more details under
+ * [[org.apache.spark.graphx.Graph#subgraph]]
+ * @tparam VD2 vertex type the vpred operates on
+ * @tparam ED2 edge type the epred operates on
+ * @return a subgraph of the original graph, with its data unchanged
+ *
+ * @example This function can be used to filter the graph based on some property, without
+ * changing the vertex and edge values in your program. For example, we could remove the vertices
+ * in a graph with 0 outdegree
+ *
+ * {{{
+ * graph.filter(
+ * graph => {
+ * val degrees: VertexRDD[Int] = graph.outDegrees
+ * graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)}
+ * },
+ * vpred = (vid: VertexID, deg:Int) => deg > 0
+ * )
+ * }}}
+ *
+ */
+ def filter[VD2: ClassTag, ED2: ClassTag](
+ preprocess: Graph[VD, ED] => Graph[VD2, ED2],
+ epred: (EdgeTriplet[VD2, ED2]) => Boolean = (x: EdgeTriplet[VD2, ED2]) => true,
+ vpred: (VertexID, VD2) => Boolean = (v:VertexID, d:VD2) => true): Graph[VD, ED] = {
+ graph.mask(preprocess(graph).subgraph(epred, vpred))
+ }
+
+ /**
+ * Execute a Pregel-like iterative vertex-parallel abstraction. The
+ * user-defined vertex-program `vprog` is executed in parallel on
+ * each vertex receiving any inbound messages and computing a new
+ * value for the vertex. The `sendMsg` function is then invoked on
+ * all out-edges and is used to compute an optional message to the
+ * destination vertex. The `mergeMsg` function is a commutative
+ * associative function used to combine messages destined to the
+ * same vertex.
+ *
+ * On the first iteration all vertices receive the `initialMsg` and
+ * on subsequent iterations if a vertex does not receive a message
+ * then the vertex-program is not invoked.
+ *
+ * This function iterates until there are no remaining messages, or
+ * for `maxIterations` iterations.
+ *
+ * @tparam A the Pregel message type
+ *
+ * @param initialMsg the message each vertex will receive on
+ * the first iteration
+ *
+ * @param maxIterations the maximum number of iterations to run for
+ *
+ * @param activeDirection the direction of edges incident to a vertex that received a message in
+ * the previous round on which to run `sendMsg`. For example, if this is `EdgeDirection.Out`, only
+ * out-edges of vertices that received a message in the previous round will run.
+ *
+ * @param vprog the user-defined vertex program which runs on each
+ * vertex and receives the inbound message and computes a new vertex
+ * value. On the first iteration the vertex program is invoked on
+ * all vertices and is passed the default message. On subsequent
+ * iterations the vertex program is only invoked on those vertices
+ * that receive messages.
+ *
+ * @param sendMsg a user supplied function that is applied to out
+ * edges of vertices that received messages in the current
+ * iteration
+ *
+ * @param mergeMsg a user supplied function that takes two incoming
+ * messages of type A and merges them into a single message of type
+ * A. ''This function must be commutative and associative and
+ * ideally the size of A should not increase.''
+ *
+ * @return the resulting graph at the end of the computation
+ *
+ */
+ def pregel[A: ClassTag](
+ initialMsg: A,
+ maxIterations: Int = Int.MaxValue,
+ activeDirection: EdgeDirection = EdgeDirection.Either)(
+ vprog: (VertexID, VD, A) => VD,
+ sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexID,A)],
+ mergeMsg: (A, A) => A)
+ : Graph[VD, ED] = {
+ Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg)
+ }
+
+ /**
+ * Run a dynamic version of PageRank returning a graph with vertex attributes containing the
+ * PageRank and edge attributes containing the normalized edge weight.
+ *
+ * @see [[org.apache.spark.graphx.lib.PageRank$#runUntilConvergence]]
+ */
+ def pageRank(tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = {
+ PageRank.runUntilConvergence(graph, tol, resetProb)
+ }
+
+ /**
+ * Run PageRank for a fixed number of iterations returning a graph with vertex attributes
+ * containing the PageRank and edge attributes the normalized edge weight.
+ *
+ * @see [[org.apache.spark.graphx.lib.PageRank$#run]]
+ */
+ def staticPageRank(numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = {
+ PageRank.run(graph, numIter, resetProb)
+ }
+
+ /**
+ * Compute the connected component membership of each vertex and return a graph with the vertex
+ * value containing the lowest vertex id in the connected component containing that vertex.
+ *
+ * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]]
+ */
+ def connectedComponents(): Graph[VertexID, ED] = {
+ ConnectedComponents.run(graph)
+ }
+
+ /**
+ * Compute the number of triangles passing through each vertex.
+ *
+ * @see [[org.apache.spark.graphx.lib.TriangleCount$#run]]
+ */
+ def triangleCount(): Graph[Int, ED] = {
+ TriangleCount.run(graph)
+ }
+
+ /**
+ * Compute the strongly connected component (SCC) of each vertex and return a graph with the
+ * vertex value containing the lowest vertex id in the SCC containing that vertex.
+ *
+ * @see [[org.apache.spark.graphx.lib.StronglyConnectedComponents$#run]]
+ */
+ def stronglyConnectedComponents(numIter: Int): Graph[VertexID, ED] = {
+ StronglyConnectedComponents.run(graph, numIter)
+ }
+} // end of GraphOps
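Through the implicit conversion defined on the Graph companion object, these operations become available directly on any graph. A hedged example over the sample data added in this patch (assumes a SparkContext `sc`):

{{{
val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
val ranks = graph.pageRank(0.0001).vertices                   // dynamic PageRank
val components = graph.connectedComponents().vertices         // lowest vertex id per component
val maxInDegree = graph.inDegrees.reduce((a, b) => if (a._2 > b._2) a else b)
}}}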
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala
new file mode 100644
index 0000000000..6d2990a3f6
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala
@@ -0,0 +1,103 @@
+package org.apache.spark.graphx
+
+/**
+ * Represents the way edges are assigned to edge partitions based on their source and destination
+ * vertex IDs.
+ */
+trait PartitionStrategy extends Serializable {
+ /** Returns the partition number for a given edge. */
+ def getPartition(src: VertexID, dst: VertexID, numParts: PartitionID): PartitionID
+}
+
+/**
+ * Collection of built-in [[PartitionStrategy]] implementations.
+ */
+object PartitionStrategy {
+ /**
+ * Assigns edges to partitions using a 2D partitioning of the sparse edge adjacency matrix,
+ * guaranteeing a `2 * sqrt(numParts)` bound on vertex replication.
+ *
+ * Suppose we have a graph with 12 vertices that we want to partition
+ * over 9 machines. We can use the following sparse matrix representation:
+ *
+ * <pre>
+ * __________________________________
+ * v0 | P0 * | P1 | P2 * |
+ * v1 | **** | * | |
+ * v2 | ******* | ** | **** |
+ * v3 | ***** | * * | * |
+ * ----------------------------------
+ * v4 | P3 * | P4 *** | P5 ** * |
+ * v5 | * * | * | |
+ * v6 | * | ** | **** |
+ * v7 | * * * | * * | * |
+ * ----------------------------------
+ * v8 | P6 * | P7 * | P8 * *|
+ * v9 | * | * * | |
+ * v10 | * | ** | * * |
+ * v11 | * <-E | *** | ** |
+ * ----------------------------------
+ * </pre>
+ *
+ * The edge denoted by `E` connects `v11` with `v1` and is assigned to processor `P6`. To get the
+ * processor number we divide the matrix into `sqrt(numParts)` by `sqrt(numParts)` blocks. Notice
+ * that edges adjacent to `v11` can only be in the first column of blocks `(P0, P3, P6)` or the last
+ * row of blocks `(P6, P7, P8)`. As a consequence we can guarantee that `v11` will need to be
+ * replicated to at most `2 * sqrt(numParts)` machines.
+ *
+ * Notice that `P0` has many edges and as a consequence this partitioning would lead to poor work
+ * balance. To improve balance we first multiply each vertex id by a large prime to shuffle the
+ * vertex locations.
+ *
+ * One limitation of this approach is that the number of machines must be a perfect
+ * square. We partially address this limitation by computing the machine assignment to the next
+ * largest perfect square and then mapping back down to the actual number of machines.
+ * Unfortunately, this can also lead to work imbalance and so it is suggested that a perfect square
+ * is used.
+ */
+ case object EdgePartition2D extends PartitionStrategy {
+ override def getPartition(src: VertexID, dst: VertexID, numParts: PartitionID): PartitionID = {
+ val ceilSqrtNumParts: PartitionID = math.ceil(math.sqrt(numParts)).toInt
+ val mixingPrime: VertexID = 1125899906842597L
+ val col: PartitionID = ((math.abs(src) * mixingPrime) % ceilSqrtNumParts).toInt
+ val row: PartitionID = ((math.abs(dst) * mixingPrime) % ceilSqrtNumParts).toInt
+ (col * ceilSqrtNumParts + row) % numParts
+ }
+ }
+
+ /**
+ * Assigns edges to partitions using only the source vertex ID, colocating edges with the same
+ * source.
+ */
+ case object EdgePartition1D extends PartitionStrategy {
+ override def getPartition(src: VertexID, dst: VertexID, numParts: PartitionID): PartitionID = {
+ val mixingPrime: VertexID = 1125899906842597L
+ (math.abs(src) * mixingPrime).toInt % numParts
+ }
+ }
+
+
+ /**
+ * Assigns edges to partitions by hashing the source and destination vertex IDs, resulting in a
+ * random vertex cut that colocates all same-direction edges between two vertices.
+ */
+ case object RandomVertexCut extends PartitionStrategy {
+ override def getPartition(src: VertexID, dst: VertexID, numParts: PartitionID): PartitionID = {
+ math.abs((src, dst).hashCode()) % numParts
+ }
+ }
+
+
+ /**
+ * Assigns edges to partitions by hashing the source and destination vertex IDs in a canonical
+ * direction, resulting in a random vertex cut that colocates all edges between two vertices,
+ * regardless of direction.
+ */
+ case object CanonicalRandomVertexCut extends PartitionStrategy {
+ override def getPartition(src: VertexID, dst: VertexID, numParts: PartitionID): PartitionID = {
+ val lower = math.min(src, dst)
+ val higher = math.max(src, dst)
+ math.abs((lower, higher).hashCode()) % numParts
+ }
+ }
+}
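The strategies above are applied through `Graph.partitionBy`, or a partition can be computed directly (a sketch; `graph` is assumed to exist):

{{{
val partitioned = graph.partitionBy(PartitionStrategy.EdgePartition2D)
val pid = PartitionStrategy.RandomVertexCut.getPartition(1L, 2L, numParts = 9)
}}}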
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
new file mode 100644
index 0000000000..fc18f7e785
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -0,0 +1,139 @@
+package org.apache.spark.graphx
+
+import scala.reflect.ClassTag
+
+
+/**
+ * Implements a Pregel-like bulk-synchronous message-passing API.
+ *
+ * Unlike the original Pregel API, the GraphX Pregel API factors the sendMessage computation over
+ * edges, enables the message sending computation to read both vertex attributes, and constrains
+ * messages to the graph structure. These changes allow for substantially more efficient
+ * distributed execution while also exposing greater flexibility for graph-based computation.
+ *
+ * @example We can use the Pregel abstraction to implement PageRank:
+ * {{{
+ * val pagerankGraph: Graph[Double, Double] = graph
+ * // Associate the degree with each vertex
+ * .outerJoinVertices(graph.outDegrees) {
+ * (vid, vdata, deg) => deg.getOrElse(0)
+ * }
+ * // Set the weight on the edges based on the degree
+ * .mapTriplets(e => 1.0 / e.srcAttr)
+ * // Set the vertex attributes to the initial pagerank values
+ * .mapVertices((id, attr) => 1.0)
+ *
+ * def vertexProgram(id: VertexID, attr: Double, msgSum: Double): Double =
+ * resetProb + (1.0 - resetProb) * msgSum
+ * def sendMessage(edge: EdgeTriplet[Double, Double]): Iterator[(VertexID, Double)] =
+ * Iterator((edge.dstId, edge.srcAttr * edge.attr))
+ * def messageCombiner(a: Double, b: Double): Double = a + b
+ * val initialMessage = 0.0
+ * // Execute Pregel for a fixed number of iterations.
+ * Pregel(pagerankGraph, initialMessage, numIter)(
+ * vertexProgram, sendMessage, messageCombiner)
+ * }}}
+ *
+ */
+object Pregel {
+
+ /**
+ * Execute a Pregel-like iterative vertex-parallel abstraction. The
+ * user-defined vertex-program `vprog` is executed in parallel on
+ * each vertex receiving any inbound messages and computing a new
+ * value for the vertex. The `sendMsg` function is then invoked on
+ * all out-edges and is used to compute an optional message to the
+ * destination vertex. The `mergeMsg` function is a commutative
+ * associative function used to combine messages destined to the
+ * same vertex.
+ *
+ * On the first iteration all vertices receive the `initialMsg` and
+ * on subsequent iterations if a vertex does not receive a message
+ * then the vertex-program is not invoked.
+ *
+ * This function iterates until there are no remaining messages, or
+ * for `maxIterations` iterations.
+ *
+ * @tparam VD the vertex data type
+ * @tparam ED the edge data type
+ * @tparam A the Pregel message type
+ *
+ * @param graph the input graph.
+ *
+ * @param initialMsg the message each vertex will receive on the
+ * first iteration
+ *
+ * @param maxIterations the maximum number of iterations to run for
+ *
+ * @param activeDirection the direction of edges incident to a vertex that received a message in
+ * the previous round on which to run `sendMsg`. For example, if this is `EdgeDirection.Out`, only
+ * out-edges of vertices that received a message in the previous round will run. The default is
+ * `EdgeDirection.Either`, which will run `sendMsg` on edges where either side received a message
+ * in the previous round. If this is `EdgeDirection.Both`, `sendMsg` will only run on edges where
+ * *both* vertices received a message.
+ *
+ * @param vprog the user-defined vertex program which runs on each
+ * vertex and receives the inbound message and computes a new vertex
+ * value. On the first iteration the vertex program is invoked on
+ * all vertices and is passed the default message. On subsequent
+ * iterations the vertex program is only invoked on those vertices
+ * that receive messages.
+ *
+ * @param sendMsg a user supplied function that is applied to out
+ * edges of vertices that received messages in the current
+ * iteration
+ *
+ * @param mergeMsg a user supplied function that takes two incoming
+ * messages of type A and merges them into a single message of type
+ * A. ''This function must be commutative and associative and
+ * ideally the size of A should not increase.''
+ *
+ * @return the resulting graph at the end of the computation
+ *
+ */
+ def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
+ (graph: Graph[VD, ED],
+ initialMsg: A,
+ maxIterations: Int = Int.MaxValue,
+ activeDirection: EdgeDirection = EdgeDirection.Either)
+ (vprog: (VertexID, VD, A) => VD,
+ sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)],
+ mergeMsg: (A, A) => A)
+ : Graph[VD, ED] =
+ {
+ var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
+ // compute the messages
+ var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
+ var activeMessages = messages.count()
+ // Loop
+ var prevG: Graph[VD, ED] = null
+ var i = 0
+ while (activeMessages > 0 && i < maxIterations) {
+ // Receive the messages. Vertices that didn't get any messages do not appear in newVerts.
+ val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
+ // Update the graph with the new vertices.
+ prevG = g
+ g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }
+ g.cache()
+
+ val oldMessages = messages
+ // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
+ // get to send messages. We must cache messages so it can be materialized on the next line,
+ // allowing us to uncache the previous iteration.
+ messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache()
+ // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This
+ // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
+ // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
+ activeMessages = messages.count()
+ // Unpersist the RDDs hidden by newly-materialized RDDs
+ oldMessages.unpersist(blocking=false)
+ newVerts.unpersist(blocking=false)
+ prevG.unpersistVertices(blocking=false)
+ // count the iteration
+ i += 1
+ }
+
+ g
+ } // end of apply
+
+} // end of class Pregel
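
For context (not part of the patch): a hedged sketch of single-source shortest paths written against the `Pregel.apply` signature above. It assumes edge attributes are `Double` edge lengths and that an input graph and source vertex ID are already available.

```scala
import org.apache.spark.graphx._

def shortestPaths[VD](graph: Graph[VD, Double], sourceId: VertexID): Graph[Double, Double] = {
  // Start with distance 0 at the source and infinity everywhere else.
  val initialGraph = graph.mapVertices((id, _) =>
    if (id == sourceId) 0.0 else Double.PositiveInfinity)
  Pregel(initialGraph, Double.PositiveInfinity)(
    // Vertex program: keep the smaller of the current and the incoming distance.
    (id, dist, newDist) => math.min(dist, newDist),
    // Send a message along edges that offer a shorter path to the destination.
    triplet =>
      if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
        Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
      } else {
        Iterator.empty
      },
    // Merge incoming messages by taking the minimum candidate distance.
    (a, b) => math.min(a, b))
}
```

Because `activeDirection` defaults to `EdgeDirection.Either`, `sendMsg` only runs on edges adjacent to vertices that received a message in the previous round, and the loop above handles the cache/uncache bookkeeping for each iteration.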
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
new file mode 100644
index 0000000000..9a95364cb1
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx
+
+import scala.reflect.ClassTag
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd._
+import org.apache.spark.storage.StorageLevel
+
+import org.apache.spark.graphx.impl.MsgRDDFunctions
+import org.apache.spark.graphx.impl.VertexPartition
+
+/**
+ * Extends `RDD[(VertexID, VD)]` by ensuring that there is only one entry for each vertex and by
+ * pre-indexing the entries for fast, efficient joins. Two VertexRDDs with the same index can be
+ * joined efficiently. All operations except [[reindex]] preserve the index. To construct a
+ * `VertexRDD`, use the [[org.apache.spark.graphx.VertexRDD$ VertexRDD object]].
+ *
+ * @example Construct a `VertexRDD` from a plain RDD:
+ * {{{
+ * // Construct an initial vertex set
+ * val someData: RDD[(VertexID, SomeType)] = loadData(someFile)
+ * val vset = VertexRDD(someData)
+ * // If there were redundant values in someData we would use a reduceFunc
+ * val vset2 = VertexRDD(someData, reduceFunc)
+ * // Finally we can use the VertexRDD to index another dataset
+ * val otherData: RDD[(VertexID, OtherType)] = loadData(otherFile)
+ * val vset3 = vset2.innerJoin(otherData) { (vid, a, b) => b }
+ * // Now we can construct very fast joins between the two sets
+ * val vset4: VertexRDD[(SomeType, OtherType)] = vset.innerJoin(vset3) { (vid, a, b) => (a, b) }
+ * }}}
+ *
+ * @tparam VD the vertex attribute associated with each vertex in the set.
+ */
+class VertexRDD[@specialized VD: ClassTag](
+ val partitionsRDD: RDD[VertexPartition[VD]])
+ extends RDD[(VertexID, VD)](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) {
+
+ require(partitionsRDD.partitioner.isDefined)
+
+ partitionsRDD.setName("VertexRDD")
+
+ /**
+ * Construct a new VertexRDD that is indexed by only the visible vertices. The resulting
+ * VertexRDD will be based on a different index and can no longer be quickly joined with this RDD.
+ */
+ def reindex(): VertexRDD[VD] = new VertexRDD(partitionsRDD.map(_.reindex()))
+
+ override val partitioner = partitionsRDD.partitioner
+
+ override protected def getPartitions: Array[Partition] = partitionsRDD.partitions
+
+ override protected def getPreferredLocations(s: Partition): Seq[String] =
+ partitionsRDD.preferredLocations(s)
+
+ override def persist(newLevel: StorageLevel): VertexRDD[VD] = {
+ partitionsRDD.persist(newLevel)
+ this
+ }
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ override def persist(): VertexRDD[VD] = persist(StorageLevel.MEMORY_ONLY)
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ override def cache(): VertexRDD[VD] = persist()
+
+ override def unpersist(blocking: Boolean = true): VertexRDD[VD] = {
+ partitionsRDD.unpersist(blocking)
+ this
+ }
+
+ /** The number of vertices in the RDD. */
+ override def count(): Long = {
+ partitionsRDD.map(_.size).reduce(_ + _)
+ }
+
+ /**
+ * Provides the `RDD[(VertexID, VD)]` equivalent output.
+ */
+ override def compute(part: Partition, context: TaskContext): Iterator[(VertexID, VD)] = {
+ firstParent[VertexPartition[VD]].iterator(part, context).next.iterator
+ }
+
+ /**
+ * Applies a function to each `VertexPartition` of this RDD and returns a new VertexRDD.
+ */
+ private[graphx] def mapVertexPartitions[VD2: ClassTag](f: VertexPartition[VD] => VertexPartition[VD2])
+ : VertexRDD[VD2] = {
+ val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true)
+ new VertexRDD(newPartitionsRDD)
+ }
+
+
+ /**
+ * Restricts the vertex set to the set of vertices satisfying the given predicate. This operation
+ * preserves the index for efficient joins with the original RDD, and it sets bits in the bitmask
+ * rather than allocating new memory.
+ *
+ * @param pred the user defined predicate, which takes a tuple to conform to the
+ * `RDD[(VertexID, VD)]` interface
+ */
+ override def filter(pred: Tuple2[VertexID, VD] => Boolean): VertexRDD[VD] =
+ this.mapVertexPartitions(_.filter(Function.untupled(pred)))
+
+ /**
+ * Maps each vertex attribute, preserving the index.
+ *
+ * @tparam VD2 the type returned by the map function
+ *
+ * @param f the function applied to each value in the RDD
+ * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the
+ * original VertexRDD
+ */
+ def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] =
+ this.mapVertexPartitions(_.map((vid, attr) => f(attr)))
+
+ /**
+ * Maps each vertex attribute, additionally supplying the vertex ID.
+ *
+ * @tparam VD2 the type returned by the map function
+ *
+ * @param f the function applied to each ID-value pair in the RDD
+ * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the
+ * original VertexRDD. The resulting VertexRDD retains the same index.
+ */
+ def mapValues[VD2: ClassTag](f: (VertexID, VD) => VD2): VertexRDD[VD2] =
+ this.mapVertexPartitions(_.map(f))
+
+ /**
+ * Hides vertices that are the same between `this` and `other`; for vertices that are different,
+ * keeps the values from `other`.
+ */
+ def diff(other: VertexRDD[VD]): VertexRDD[VD] = {
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.diff(otherPart))
+ }
+ new VertexRDD(newPartitionsRDD)
+ }
+
+ /**
+ * Left joins this RDD with another VertexRDD with the same index. This function will fail if the
+ * two VertexRDDs do not share the same index. The resulting vertex set contains an entry for each
+ * vertex in `this`. If `other` is missing any vertex in this VertexRDD, `f` is passed `None`.
+ *
+ * @tparam VD2 the attribute type of the other VertexRDD
+ * @tparam VD3 the attribute type of the resulting VertexRDD
+ *
+ * @param other the other VertexRDD with which to join.
+ * @param f the function mapping a vertex id and its attributes in this and the other vertex set
+ * to a new vertex attribute.
+ * @return a VertexRDD containing the results of `f`
+ */
+ def leftZipJoin[VD2: ClassTag, VD3: ClassTag]
+ (other: VertexRDD[VD2])(f: (VertexID, VD, Option[VD2]) => VD3): VertexRDD[VD3] = {
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.leftJoin(otherPart)(f))
+ }
+ new VertexRDD(newPartitionsRDD)
+ }
+
+ /**
+ * Left joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is
+ * backed by a VertexRDD with the same index then the efficient [[leftZipJoin]] implementation is
+ * used. The resulting VertexRDD contains an entry for each vertex in `this`. If `other` is
+ * missing any vertex in this VertexRDD, `f` is passed `None`. If there are duplicates, the vertex
+ * is picked arbitrarily.
+ *
+ * @tparam VD2 the attribute type of the other VertexRDD
+ * @tparam VD3 the attribute type of the resulting VertexRDD
+ *
+ * @param other the other VertexRDD with which to join
+ * @param f the function mapping a vertex id and its attributes in this and the other vertex set
+ * to a new vertex attribute.
+ * @return a VertexRDD containing all the vertices in this VertexRDD with the attributes emitted
+ * by `f`.
+ */
+ def leftJoin[VD2: ClassTag, VD3: ClassTag]
+ (other: RDD[(VertexID, VD2)])
+ (f: (VertexID, VD, Option[VD2]) => VD3)
+ : VertexRDD[VD3] = {
+ // Test if the other vertex is a VertexRDD to choose the optimal join strategy.
+ // If the other set is a VertexRDD then we use the much more efficient leftZipJoin
+ other match {
+ case other: VertexRDD[_] =>
+ leftZipJoin(other)(f)
+ case _ =>
+ new VertexRDD[VD3](
+ partitionsRDD.zipPartitions(
+ other.partitionBy(this.partitioner.get), preservesPartitioning = true)
+ { (part, msgs) =>
+ val vertexPartition: VertexPartition[VD] = part.next()
+ Iterator(vertexPartition.leftJoin(msgs)(f))
+ }
+ )
+ }
+ }
+
+ /**
+ * Efficiently inner joins this VertexRDD with another VertexRDD sharing the same index. See
+ * [[innerJoin]] for the behavior of the join.
+ */
+ def innerZipJoin[U: ClassTag, VD2: ClassTag](other: VertexRDD[U])
+ (f: (VertexID, VD, U) => VD2): VertexRDD[VD2] = {
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.innerJoin(otherPart)(f))
+ }
+ new VertexRDD(newPartitionsRDD)
+ }
+
+ /**
+ * Inner joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is
+ * backed by a VertexRDD with the same index then the efficient [[innerZipJoin]] implementation is
+ * used.
+ *
+ * @param other an RDD containing vertices to join. If there are multiple entries for the same
+ * vertex, one is picked arbitrarily. Use [[aggregateUsingIndex]] to merge multiple entries.
+ * @param f the join function applied to corresponding values of `this` and `other`
+ * @return a VertexRDD co-indexed with `this`, containing only vertices that appear in both `this`
+ * and `other`, with values supplied by `f`
+ */
+ def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexID, U)])
+ (f: (VertexID, VD, U) => VD2): VertexRDD[VD2] = {
+ // Test if the other vertex is a VertexRDD to choose the optimal join strategy.
+ // If the other set is a VertexRDD then we use the much more efficient innerZipJoin
+ other match {
+ case other: VertexRDD[_] =>
+ innerZipJoin(other)(f)
+ case _ =>
+ new VertexRDD(
+ partitionsRDD.zipPartitions(
+ other.partitionBy(this.partitioner.get), preservesPartitioning = true)
+ { (part, msgs) =>
+ val vertexPartition: VertexPartition[VD] = part.next()
+ Iterator(vertexPartition.innerJoin(msgs)(f))
+ }
+ )
+ }
+ }
+
+ /**
+ * Aggregates vertices in `messages` that have the same ids using `reduceFunc`, returning a
+ * VertexRDD co-indexed with `this`.
+ *
+ * @param messages an RDD containing messages to aggregate, where each message is a pair of its
+ * target vertex ID and the message data
+ * @param reduceFunc the associative aggregation function for merging messages to the same vertex
+ * @return a VertexRDD co-indexed with `this`, containing only vertices that received messages.
+ * For those vertices, their values are the result of applying `reduceFunc` to all received
+ * messages.
+ */
+ def aggregateUsingIndex[VD2: ClassTag](
+ messages: RDD[(VertexID, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = {
+ val shuffled = MsgRDDFunctions.partitionForAggregation(messages, this.partitioner.get)
+ val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) =>
+ val vertexPartition: VertexPartition[VD] = thisIter.next()
+ Iterator(vertexPartition.aggregateUsingIndex(msgIter, reduceFunc))
+ }
+ new VertexRDD[VD2](parts)
+ }
+
+} // end of VertexRDD
+
+
+/**
+ * The VertexRDD singleton is used to construct VertexRDDs.
+ */
+object VertexRDD {
+
+ /**
+ * Construct a `VertexRDD` from an RDD of vertex-attribute pairs.
+ * Duplicate entries are removed arbitrarily.
+ *
+ * @tparam VD the vertex attribute type
+ *
+ * @param rdd the collection of vertex-attribute pairs
+ */
+ def apply[VD: ClassTag](rdd: RDD[(VertexID, VD)]): VertexRDD[VD] = {
+ val partitioned: RDD[(VertexID, VD)] = rdd.partitioner match {
+ case Some(p) => rdd
+ case None => rdd.partitionBy(new HashPartitioner(rdd.partitions.size))
+ }
+ val vertexPartitions = partitioned.mapPartitions(
+ iter => Iterator(VertexPartition(iter)),
+ preservesPartitioning = true)
+ new VertexRDD(vertexPartitions)
+ }
+
+ /**
+ * Constructs a `VertexRDD` from an RDD of vertex-attribute pairs, merging duplicates using
+ * `mergeFunc`.
+ *
+ * @tparam VD the vertex attribute type
+ *
+ * @param rdd the collection of vertex-attribute pairs
+ * @param mergeFunc the associative, commutative merge function.
+ */
+ def apply[VD: ClassTag](rdd: RDD[(VertexID, VD)], mergeFunc: (VD, VD) => VD): VertexRDD[VD] = {
+ val partitioned: RDD[(VertexID, VD)] = rdd.partitioner match {
+ case Some(p) => rdd
+ case None => rdd.partitionBy(new HashPartitioner(rdd.partitions.size))
+ }
+ val vertexPartitions = partitioned.mapPartitions(
+ iter => Iterator(VertexPartition(iter)),
+ preservesPartitioning = true)
+ new VertexRDD(vertexPartitions)
+ }
+
+ /**
+ * Constructs a VertexRDD from the vertex IDs in `vids`, taking attributes from `rdd` and using
+ * `defaultVal` otherwise.
+ */
+ def apply[VD: ClassTag](vids: RDD[VertexID], rdd: RDD[(VertexID, VD)], defaultVal: VD)
+ : VertexRDD[VD] = {
+ VertexRDD(vids.map(vid => (vid, defaultVal))).leftJoin(rdd) { (vid, default, value) =>
+ value.getOrElse(default)
+ }
+ }
+}
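
For context (not part of the patch): a hedged sketch of the VertexRDD construction and index-aware joins defined above. `sc` is assumed to be an existing SparkContext; the vertex IDs and attributes are illustrative.

```scala
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

val raw: RDD[(VertexID, Int)] = sc.parallelize(Seq((1L, 10), (2L, 20), (2L, 5)))
// Merge duplicate entries for the same vertex with a user-supplied function.
val verts: VertexRDD[Int] = VertexRDD(raw, (a: Int, b: Int) => a + b)

// Aggregate messages keyed by vertex ID, reusing the existing index.
val messages: RDD[(VertexID, Int)] = sc.parallelize(Seq((1L, 1), (1L, 1), (2L, 1)))
val counts: VertexRDD[Int] = verts.aggregateUsingIndex(messages, (a: Int, b: Int) => a + b)

// Left join back against the original attributes; vertices missing from
// `counts` see None and default to zero.
val joined: VertexRDD[(Int, Int)] =
  verts.leftJoin(counts) { (vid, attr, cntOpt) => (attr, cntOpt.getOrElse(0)) }
```

Since `counts` is co-indexed with `verts`, the `leftJoin` call takes the fast `leftZipJoin` path rather than reshuffling.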
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
new file mode 100644
index 0000000000..ee95ead3ad
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
@@ -0,0 +1,220 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+
+/**
+ * A collection of edges stored in 3 large columnar arrays (src, dst, attribute). The arrays are
+ * clustered by src.
+ *
+ * @param srcIds the source vertex id of each edge
+ * @param dstIds the destination vertex id of each edge
+ * @param data the attribute associated with each edge
+ * @param index a clustered index on source vertex id
+ * @tparam ED the edge attribute type.
+ */
+private[graphx]
+class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag](
+ val srcIds: Array[VertexID],
+ val dstIds: Array[VertexID],
+ val data: Array[ED],
+ val index: PrimitiveKeyOpenHashMap[VertexID, Int]) extends Serializable {
+
+ /**
+ * Reverse all the edges in this partition.
+ *
+ * @return a new edge partition with all edges reversed.
+ */
+ def reverse: EdgePartition[ED] = {
+ val builder = new EdgePartitionBuilder(size)
+ for (e <- iterator) {
+ builder.add(e.dstId, e.srcId, e.attr)
+ }
+ builder.toEdgePartition
+ }
+
+ /**
+ * Construct a new edge partition by applying the function f to all
+ * edges in this partition.
+ *
+ * @param f a function from an edge to a new attribute
+ * @tparam ED2 the type of the new attribute
+ * @return a new edge partition with the result of the function `f`
+ * applied to each edge
+ */
+ def map[ED2: ClassTag](f: Edge[ED] => ED2): EdgePartition[ED2] = {
+ val newData = new Array[ED2](data.size)
+ val edge = new Edge[ED]()
+ val size = data.size
+ var i = 0
+ while (i < size) {
+ edge.srcId = srcIds(i)
+ edge.dstId = dstIds(i)
+ edge.attr = data(i)
+ newData(i) = f(edge)
+ i += 1
+ }
+ new EdgePartition(srcIds, dstIds, newData, index)
+ }
+
+ /**
+ * Construct a new edge partition by using the edge attributes
+ * contained in the iterator.
+ *
+ * @note The input iterator should return edge attributes in the
+ * order of the edges returned by `EdgePartition.iterator` and
+ * should return exactly one attribute for each edge.
+ *
+ * @param iter an iterator of new edge attribute values, one per edge
+ * @tparam ED2 the type of the new attribute
+ * @return a new edge partition with the edge attributes replaced by the
+ * values from `iter`
+ */
+ def map[ED2: ClassTag](iter: Iterator[ED2]): EdgePartition[ED2] = {
+ val newData = new Array[ED2](data.size)
+ var i = 0
+ while (iter.hasNext) {
+ newData(i) = iter.next()
+ i += 1
+ }
+ assert(newData.size == i)
+ new EdgePartition(srcIds, dstIds, newData, index)
+ }
+
+ /**
+ * Apply the function f to all edges in this partition.
+ *
+ * @param f an external state mutating user defined function.
+ */
+ def foreach(f: Edge[ED] => Unit) {
+ iterator.foreach(f)
+ }
+
+ /**
+ * Merge all the edges with the same src and dest id into a single
+ * edge using the `merge` function
+ *
+ * @param merge a commutative associative merge operation
+ * @return a new edge partition without duplicate edges
+ */
+ def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED] = {
+ val builder = new EdgePartitionBuilder[ED]
+ var currSrcId: VertexID = null.asInstanceOf[VertexID]
+ var currDstId: VertexID = null.asInstanceOf[VertexID]
+ var currAttr: ED = null.asInstanceOf[ED]
+ var i = 0
+ while (i < size) {
+ if (i > 0 && currSrcId == srcIds(i) && currDstId == dstIds(i)) {
+ currAttr = merge(currAttr, data(i))
+ } else {
+ if (i > 0) {
+ builder.add(currSrcId, currDstId, currAttr)
+ }
+ currSrcId = srcIds(i)
+ currDstId = dstIds(i)
+ currAttr = data(i)
+ }
+ i += 1
+ }
+ if (size > 0) {
+ builder.add(currSrcId, currDstId, currAttr)
+ }
+ builder.toEdgePartition
+ }
+
+ /**
+ * Apply `f` to all edges present in both `this` and `other` and return a new EdgePartition
+ * containing the resulting edges.
+ *
+ * If there are multiple edges with the same src and dst in `this`, `f` will be invoked once for
+ * each edge, but each time it may be invoked on any corresponding edge in `other`.
+ *
+ * If there are multiple edges with the same src and dst in `other`, `f` will only be invoked
+ * once.
+ */
+ def innerJoin[ED2: ClassTag, ED3: ClassTag]
+ (other: EdgePartition[ED2])
+ (f: (VertexID, VertexID, ED, ED2) => ED3): EdgePartition[ED3] = {
+ val builder = new EdgePartitionBuilder[ED3]
+ var i = 0
+ var j = 0
+ // For i = index of each edge in `this`...
+ while (i < size && j < other.size) {
+ val srcId = this.srcIds(i)
+ val dstId = this.dstIds(i)
+ // ... forward j to the index of the corresponding edge in `other`, and...
+ while (j < other.size && other.srcIds(j) < srcId) { j += 1 }
+ if (j < other.size && other.srcIds(j) == srcId) {
+ while (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) < dstId) { j += 1 }
+ if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) {
+ // ... run `f` on the matching edge
+ builder.add(srcId, dstId, f(srcId, dstId, this.data(i), other.data(j)))
+ }
+ }
+ i += 1
+ }
+ builder.toEdgePartition
+ }
+
+ /**
+ * The number of edges in this partition
+ *
+ * @return size of the partition
+ */
+ def size: Int = srcIds.size
+
+ /** The number of unique source vertices in the partition. */
+ def indexSize: Int = index.size
+
+ /**
+ * Get an iterator over the edges in this partition.
+ *
+ * @return an iterator over edges in the partition
+ */
+ def iterator = new Iterator[Edge[ED]] {
+ private[this] val edge = new Edge[ED]
+ private[this] var pos = 0
+
+ override def hasNext: Boolean = pos < EdgePartition.this.size
+
+ override def next(): Edge[ED] = {
+ edge.srcId = srcIds(pos)
+ edge.dstId = dstIds(pos)
+ edge.attr = data(pos)
+ pos += 1
+ edge
+ }
+ }
+
+ /**
+ * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The
+ * iterator is generated using an index scan, so it is efficient at skipping edges that don't
+ * match srcIdPred.
+ */
+ def indexIterator(srcIdPred: VertexID => Boolean): Iterator[Edge[ED]] =
+ index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator))
+
+ /**
+ * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The
+ * cluster must start at position `index`.
+ */
+ private def clusterIterator(srcId: VertexID, index: Int) = new Iterator[Edge[ED]] {
+ private[this] val edge = new Edge[ED]
+ private[this] var pos = index
+
+ override def hasNext: Boolean = {
+ pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId
+ }
+
+ override def next(): Edge[ED] = {
+ assert(srcIds(pos) == srcId)
+ edge.srcId = srcIds(pos)
+ edge.dstId = dstIds(pos)
+ edge.attr = data(pos)
+ pos += 1
+ edge
+ }
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
new file mode 100644
index 0000000000..9d072f9335
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
@@ -0,0 +1,45 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+import scala.util.Sorting
+
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.util.collection.PrimitiveVector
+
+private[graphx]
+class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag](size: Int = 64) {
+ var edges = new PrimitiveVector[Edge[ED]](size)
+
+ /** Add a new edge to the partition. */
+ def add(src: VertexID, dst: VertexID, d: ED) {
+ edges += Edge(src, dst, d)
+ }
+
+ def toEdgePartition: EdgePartition[ED] = {
+ val edgeArray = edges.trim().array
+ Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering)
+ val srcIds = new Array[VertexID](edgeArray.size)
+ val dstIds = new Array[VertexID](edgeArray.size)
+ val data = new Array[ED](edgeArray.size)
+ val index = new PrimitiveKeyOpenHashMap[VertexID, Int]
+ // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and
+ // adding them to the index
+ if (edgeArray.length > 0) {
+ index.update(edgeArray(0).srcId, 0)
+ var currSrcId: VertexID = edgeArray(0).srcId
+ var i = 0
+ while (i < edgeArray.size) {
+ srcIds(i) = edgeArray(i).srcId
+ dstIds(i) = edgeArray(i).dstId
+ data(i) = edgeArray(i).attr
+ if (edgeArray(i).srcId != currSrcId) {
+ currSrcId = edgeArray(i).srcId
+ index.update(currSrcId, i)
+ }
+ i += 1
+ }
+ }
+ new EdgePartition(srcIds, dstIds, data, index)
+ }
+}
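
For context (not part of the patch): a hedged sketch exercising `EdgePartitionBuilder` and `EdgePartition.groupEdges` from the two files above. Both classes are `private[graphx]`, so the sketch assumes it lives inside the `org.apache.spark.graphx.impl` package.

```scala
package org.apache.spark.graphx.impl

object EdgePartitionSketch {
  def main(args: Array[String]): Unit = {
    val builder = new EdgePartitionBuilder[Int]
    builder.add(1L, 2L, 1)
    builder.add(1L, 2L, 2)   // duplicate (src, dst) pair, merged below
    builder.add(2L, 3L, 5)
    // toEdgePartition sorts the edges and builds the clustered source index.
    val part = builder.toEdgePartition
    // Sum the attributes of duplicate edges; prints 1 -> 2: 3 and 2 -> 3: 5.
    val deduped = part.groupEdges(_ + _)
    deduped.iterator.foreach(e => println(s"${e.srcId} -> ${e.dstId}: ${e.attr}"))
  }
}
```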
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
new file mode 100644
index 0000000000..bad840f1cd
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
@@ -0,0 +1,42 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+
+/**
+ * The Iterator type returned when constructing edge triplets. This class technically could be
+ * an anonymous class in GraphImpl.triplets, but we name it here explicitly so it is easier to
+ * debug / profile.
+ */
+private[impl]
+class EdgeTripletIterator[VD: ClassTag, ED: ClassTag](
+ val vidToIndex: VertexIdToIndexMap,
+ val vertexArray: Array[VD],
+ val edgePartition: EdgePartition[ED])
+ extends Iterator[EdgeTriplet[VD, ED]] {
+
+ // Current position in the array.
+ private var pos = 0
+
+ // The triplet object returned by this iterator's next() call. We reuse this object to avoid
+ // allocating too many temporary Java objects.
+ private val triplet = new EdgeTriplet[VD, ED]
+
+ private val vmap = new PrimitiveKeyOpenHashMap[VertexID, VD](vidToIndex, vertexArray)
+
+ override def hasNext: Boolean = pos < edgePartition.size
+
+ override def next() = {
+ triplet.srcId = edgePartition.srcIds(pos)
+ // assert(vmap.containsKey(e.src.id))
+ triplet.srcAttr = vmap(triplet.srcId)
+ triplet.dstId = edgePartition.dstIds(pos)
+ // assert(vmap.containsKey(e.dst.id))
+ triplet.dstAttr = vmap(triplet.dstId)
+ triplet.attr = edgePartition.data(pos)
+ pos += 1
+ triplet
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
new file mode 100644
index 0000000000..56d1d9efea
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -0,0 +1,379 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.{classTag, ClassTag}
+
+import org.apache.spark.util.collection.PrimitiveVector
+import org.apache.spark.{HashPartitioner, Partitioner}
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.GraphImpl._
+import org.apache.spark.graphx.impl.MsgRDDFunctions._
+import org.apache.spark.graphx.util.BytecodeUtils
+import org.apache.spark.rdd.{ShuffledRDD, RDD}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.ClosureCleaner
+
+
+/**
+ * An implementation of [[org.apache.spark.graphx.Graph]] that supports computation on graphs.
+ *
+ * Graphs are represented using two classes of data: vertex-partitioned and
+ * edge-partitioned. `vertices` contains vertex attributes, which are vertex-partitioned. `edges`
+ * contains edge attributes, which are edge-partitioned. For operations on vertex neighborhoods,
+ * vertex attributes are replicated to the edge partitions where they appear as sources or
+ * destinations. `routingTable` stores the routing information for shipping vertex attributes to
+ * edge partitions. `replicatedVertexView` stores a view of the replicated vertex attributes created
+ * using the routing table.
+ */
+class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
+ @transient val vertices: VertexRDD[VD],
+ @transient val edges: EdgeRDD[ED],
+ @transient val routingTable: RoutingTable,
+ @transient val replicatedVertexView: ReplicatedVertexView[VD])
+ extends Graph[VD, ED] with Serializable {
+
+ /** Default constructor is provided to support serialization */
+ protected def this() = this(null, null, null, null)
+
+ /** Return an RDD that brings edges together with their source and destination vertices. */
+ @transient override val triplets: RDD[EdgeTriplet[VD, ED]] = {
+ val vdTag = classTag[VD]
+ val edTag = classTag[ED]
+ edges.partitionsRDD.zipPartitions(
+ replicatedVertexView.get(true, true), true) { (ePartIter, vPartIter) =>
+ val (pid, ePart) = ePartIter.next()
+ val (_, vPart) = vPartIter.next()
+ new EdgeTripletIterator(vPart.index, vPart.values, ePart)(vdTag, edTag)
+ }
+ }
+
+ override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
+ vertices.persist(newLevel)
+ edges.persist(newLevel)
+ this
+ }
+
+ override def cache(): Graph[VD, ED] = persist(StorageLevel.MEMORY_ONLY)
+
+ override def unpersistVertices(blocking: Boolean = true): Graph[VD, ED] = {
+ vertices.unpersist(blocking)
+ replicatedVertexView.unpersist(blocking)
+ this
+ }
+
+ override def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] = {
+ val numPartitions = edges.partitions.size
+ val edTag = classTag[ED]
+ val newEdges = new EdgeRDD(edges.map { e =>
+ val part: PartitionID = partitionStrategy.getPartition(e.srcId, e.dstId, numPartitions)
+
+ // Should we be using 3-tuple or an optimized class
+ new MessageToPartition(part, (e.srcId, e.dstId, e.attr))
+ }
+ .partitionBy(new HashPartitioner(numPartitions))
+ .mapPartitionsWithIndex( { (pid, iter) =>
+ val builder = new EdgePartitionBuilder[ED]()(edTag)
+ iter.foreach { message =>
+ val data = message.data
+ builder.add(data._1, data._2, data._3)
+ }
+ val edgePartition = builder.toEdgePartition
+ Iterator((pid, edgePartition))
+ }, preservesPartitioning = true).cache())
+ GraphImpl(vertices, newEdges)
+ }
+
+ override def reverse: Graph[VD, ED] = {
+ val newETable = edges.mapEdgePartitions((pid, part) => part.reverse)
+ new GraphImpl(vertices, newETable, routingTable, replicatedVertexView)
+ }
+
+ override def mapVertices[VD2: ClassTag](f: (VertexID, VD) => VD2): Graph[VD2, ED] = {
+ if (classTag[VD] equals classTag[VD2]) {
+ // The map preserves type, so we can use incremental replication
+ val newVerts = vertices.mapVertexPartitions(_.map(f)).cache()
+ val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts)
+ val newReplicatedVertexView = new ReplicatedVertexView[VD2](
+ changedVerts, edges, routingTable,
+ Some(replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2]]))
+ new GraphImpl(newVerts, edges, routingTable, newReplicatedVertexView)
+ } else {
+ // The map does not preserve type, so we must re-replicate all vertices
+ GraphImpl(vertices.mapVertexPartitions(_.map(f)), edges, routingTable)
+ }
+ }
+
+ override def mapEdges[ED2: ClassTag](
+ f: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2] = {
+ val newETable = edges.mapEdgePartitions((pid, part) => part.map(f(pid, part.iterator)))
+ new GraphImpl(vertices, newETable, routingTable, replicatedVertexView)
+ }
+
+ override def mapTriplets[ED2: ClassTag](
+ f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = {
+ val newEdgePartitions =
+ edges.partitionsRDD.zipPartitions(replicatedVertexView.get(true, true), true) {
+ (ePartIter, vTableReplicatedIter) =>
+ val (ePid, edgePartition) = ePartIter.next()
+ val (vPid, vPart) = vTableReplicatedIter.next()
+ assert(!vTableReplicatedIter.hasNext)
+ assert(ePid == vPid)
+ val et = new EdgeTriplet[VD, ED]
+ val inputIterator = edgePartition.iterator.map { e =>
+ et.set(e)
+ et.srcAttr = vPart(e.srcId)
+ et.dstAttr = vPart(e.dstId)
+ et
+ }
+ // Apply the user function to the vertex partition
+ val outputIter = f(ePid, inputIterator)
+ // Consume the iterator to update the edge attributes
+ val newEdgePartition = edgePartition.map(outputIter)
+ Iterator((ePid, newEdgePartition))
+ }
+ new GraphImpl(vertices, new EdgeRDD(newEdgePartitions), routingTable, replicatedVertexView)
+ }
+
+ override def subgraph(
+ epred: EdgeTriplet[VD, ED] => Boolean = x => true,
+ vpred: (VertexID, VD) => Boolean = (a, b) => true): Graph[VD, ED] = {
+ // Filter the vertices, reusing the partitioner and the index from this graph
+ val newVerts = vertices.mapVertexPartitions(_.filter(vpred))
+
+ // Filter the edges
+ val edTag = classTag[ED]
+ val newEdges = new EdgeRDD[ED](triplets.filter { et =>
+ vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)
+ }.mapPartitionsWithIndex( { (pid, iter) =>
+ val builder = new EdgePartitionBuilder[ED]()(edTag)
+ iter.foreach { et => builder.add(et.srcId, et.dstId, et.attr) }
+ val edgePartition = builder.toEdgePartition
+ Iterator((pid, edgePartition))
+ }, preservesPartitioning = true)).cache()
+
+ // Reuse the previous ReplicatedVertexView unmodified. The replicated vertices that have been
+ // removed will be ignored, since we only refer to replicated vertices when they are adjacent to
+ // an edge.
+ new GraphImpl(newVerts, newEdges, new RoutingTable(newEdges, newVerts), replicatedVertexView)
+ } // end of subgraph
+
+ override def mask[VD2: ClassTag, ED2: ClassTag] (
+ other: Graph[VD2, ED2]): Graph[VD, ED] = {
+ val newVerts = vertices.innerJoin(other.vertices) { (vid, v, w) => v }
+ val newEdges = edges.innerJoin(other.edges) { (src, dst, v, w) => v }
+ // Reuse the previous ReplicatedVertexView unmodified. The replicated vertices that have been
+ // removed will be ignored, since we only refer to replicated vertices when they are adjacent to
+ // an edge.
+ new GraphImpl(newVerts, newEdges, routingTable, replicatedVertexView)
+ }
+
+ override def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] = {
+ ClosureCleaner.clean(merge)
+ val newETable = edges.mapEdgePartitions((pid, part) => part.groupEdges(merge))
+ new GraphImpl(vertices, newETable, routingTable, replicatedVertexView)
+ }
+
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+ // Lower level transformation methods
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+
+ override def mapReduceTriplets[A: ClassTag](
+ mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)],
+ reduceFunc: (A, A) => A,
+ activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None) = {
+
+ ClosureCleaner.clean(mapFunc)
+ ClosureCleaner.clean(reduceFunc)
+
+ // For each vertex, replicate its attribute only to partitions where it is
+ // in the relevant position in an edge.
+ val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr")
+ val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr")
+ val vs = activeSetOpt match {
+ case Some((activeSet, _)) =>
+ replicatedVertexView.get(mapUsesSrcAttr, mapUsesDstAttr, activeSet)
+ case None =>
+ replicatedVertexView.get(mapUsesSrcAttr, mapUsesDstAttr)
+ }
+ val activeDirectionOpt = activeSetOpt.map(_._2)
+
+ // Map and combine.
+ val preAgg = edges.partitionsRDD.zipPartitions(vs, true) { (ePartIter, vPartIter) =>
+ val (ePid, edgePartition) = ePartIter.next()
+ val (vPid, vPart) = vPartIter.next()
+ assert(!vPartIter.hasNext)
+ assert(ePid == vPid)
+ // Choose scan method
+ val activeFraction = vPart.numActives.getOrElse(0) / edgePartition.indexSize.toFloat
+ val edgeIter = activeDirectionOpt match {
+ case Some(EdgeDirection.Both) =>
+ if (activeFraction < 0.8) {
+ edgePartition.indexIterator(srcVertexID => vPart.isActive(srcVertexID))
+ .filter(e => vPart.isActive(e.dstId))
+ } else {
+ edgePartition.iterator.filter(e => vPart.isActive(e.srcId) && vPart.isActive(e.dstId))
+ }
+ case Some(EdgeDirection.Either) =>
+ // TODO: Because we only have a clustered index on the source vertex ID, we can't filter
+ // the index here. Instead we have to scan all edges and then do the filter.
+ edgePartition.iterator.filter(e => vPart.isActive(e.srcId) || vPart.isActive(e.dstId))
+ case Some(EdgeDirection.Out) =>
+ if (activeFraction < 0.8) {
+ edgePartition.indexIterator(srcVertexID => vPart.isActive(srcVertexID))
+ } else {
+ edgePartition.iterator.filter(e => vPart.isActive(e.srcId))
+ }
+ case Some(EdgeDirection.In) =>
+ edgePartition.iterator.filter(e => vPart.isActive(e.dstId))
+ case _ => // None
+ edgePartition.iterator
+ }
+
+ // Scan edges and run the map function
+ val et = new EdgeTriplet[VD, ED]
+ val mapOutputs = edgeIter.flatMap { e =>
+ et.set(e)
+ if (mapUsesSrcAttr) {
+ et.srcAttr = vPart(e.srcId)
+ }
+ if (mapUsesDstAttr) {
+ et.dstAttr = vPart(e.dstId)
+ }
+ mapFunc(et)
+ }
+ // Note: This doesn't allow users to send messages to arbitrary vertices.
+ vPart.aggregateUsingIndex(mapOutputs, reduceFunc).iterator
+ }
+
+ // do the final reduction reusing the index map
+ vertices.aggregateUsingIndex(preAgg, reduceFunc)
+ } // end of mapReduceTriplets
+
+ override def outerJoinVertices[U: ClassTag, VD2: ClassTag]
+ (other: RDD[(VertexID, U)])
+ (updateF: (VertexID, VD, Option[U]) => VD2): Graph[VD2, ED] =
+ {
+ if (classTag[VD] equals classTag[VD2]) {
+ // updateF preserves type, so we can use incremental replication
+ val newVerts = vertices.leftJoin(other)(updateF)
+ val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts)
+ val newReplicatedVertexView = new ReplicatedVertexView[VD2](
+ changedVerts, edges, routingTable,
+ Some(replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2]]))
+ new GraphImpl(newVerts, edges, routingTable, newReplicatedVertexView)
+ } else {
+ // updateF does not preserve type, so we must re-replicate all vertices
+ val newVerts = vertices.leftJoin(other)(updateF)
+ GraphImpl(newVerts, edges, routingTable)
+ }
+ }
+
+ /** Test whether the closure accesses the attribute with name `attrName`. */
+ private def accessesVertexAttr(closure: AnyRef, attrName: String): Boolean = {
+ try {
+ BytecodeUtils.invokedMethod(closure, classOf[EdgeTriplet[VD, ED]], attrName)
+ } catch {
+ case _: ClassNotFoundException => true // if we don't know, be conservative
+ }
+ }
+} // end of class GraphImpl
+
+
+object GraphImpl {
+
+ def apply[VD: ClassTag, ED: ClassTag](
+ edges: RDD[Edge[ED]],
+ defaultVertexAttr: VD): GraphImpl[VD, ED] =
+ {
+ fromEdgeRDD(createEdgeRDD(edges), defaultVertexAttr)
+ }
+
+ def fromEdgePartitions[VD: ClassTag, ED: ClassTag](
+ edgePartitions: RDD[(PartitionID, EdgePartition[ED])],
+ defaultVertexAttr: VD): GraphImpl[VD, ED] = {
+ fromEdgeRDD(new EdgeRDD(edgePartitions), defaultVertexAttr)
+ }
+
+ def apply[VD: ClassTag, ED: ClassTag](
+ vertices: RDD[(VertexID, VD)],
+ edges: RDD[Edge[ED]],
+ defaultVertexAttr: VD): GraphImpl[VD, ED] =
+ {
+ val edgeRDD = createEdgeRDD(edges).cache()
+
+ // Get the set of all vids
+ val partitioner = Partitioner.defaultPartitioner(vertices)
+ val vPartitioned = vertices.partitionBy(partitioner)
+ val vidsFromEdges = collectVertexIDsFromEdges(edgeRDD, partitioner)
+ val vids = vPartitioned.zipPartitions(vidsFromEdges) { (vertexIter, vidsFromEdgesIter) =>
+ vertexIter.map(_._1) ++ vidsFromEdgesIter.map(_._1)
+ }
+
+ val vertexRDD = VertexRDD(vids, vPartitioned, defaultVertexAttr)
+
+ GraphImpl(vertexRDD, edgeRDD)
+ }
+
+ def apply[VD: ClassTag, ED: ClassTag](
+ vertices: VertexRDD[VD],
+ edges: EdgeRDD[ED]): GraphImpl[VD, ED] = {
+ // Cache RDDs that are referenced multiple times
+ edges.cache()
+
+ GraphImpl(vertices, edges, new RoutingTable(edges, vertices))
+ }
+
+ def apply[VD: ClassTag, ED: ClassTag](
+ vertices: VertexRDD[VD],
+ edges: EdgeRDD[ED],
+ routingTable: RoutingTable): GraphImpl[VD, ED] = {
+ // Cache RDDs that are referenced multiple times. `routingTable` is cached by default, so we
+ // don't cache it explicitly.
+ vertices.cache()
+ edges.cache()
+
+ new GraphImpl(
+ vertices, edges, routingTable, new ReplicatedVertexView(vertices, edges, routingTable))
+ }
+
+ /**
+ * Create the edge RDD, which is much more efficient for Java heap storage than the normal edges
+ * data structure (RDD[(VertexID, VertexID, ED)]).
+ *
+ * The edge RDD contains multiple partitions, and each partition contains only one RDD key-value
+ * pair: the key is the partition id, and the value is an EdgePartition object containing all the
+ * edges in a partition.
+ */
+ private def createEdgeRDD[ED: ClassTag](
+ edges: RDD[Edge[ED]]): EdgeRDD[ED] = {
+ val edgePartitions = edges.mapPartitionsWithIndex { (pid, iter) =>
+ val builder = new EdgePartitionBuilder[ED]
+ iter.foreach { e =>
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+ Iterator((pid, builder.toEdgePartition))
+ }
+ new EdgeRDD(edgePartitions)
+ }
+
+ private def fromEdgeRDD[VD: ClassTag, ED: ClassTag](
+ edges: EdgeRDD[ED],
+ defaultVertexAttr: VD): GraphImpl[VD, ED] = {
+ edges.cache()
+ // Get the set of all vids
+ val vids = collectVertexIDsFromEdges(edges, new HashPartitioner(edges.partitions.size))
+ // Create the VertexRDD.
+ val vertices = VertexRDD(vids.mapValues(x => defaultVertexAttr))
+ GraphImpl(vertices, edges)
+ }
+
+ /** Collects all vids mentioned in edges and partitions them by partitioner. */
+ private def collectVertexIDsFromEdges(
+ edges: EdgeRDD[_],
+ partitioner: Partitioner): RDD[(VertexID, Int)] = {
+ // TODO: Consider doing map side distinct before shuffle.
+ new ShuffledRDD[VertexID, Int, (VertexID, Int)](
+ edges.collectVertexIDs.map(vid => (vid, 0)), partitioner)
+ .setSerializer(classOf[VertexIDMsgSerializer].getName)
+ }
+} // end of object GraphImpl
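
For context (not part of the patch): a hedged sketch that builds a small graph through the `GraphImpl` factory above and computes in-degrees with `mapReduceTriplets`. `sc` is assumed to be an existing SparkContext, and the three-argument `Edge` constructor from Edge.scala is assumed.

```scala
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.rdd.RDD

val edges: RDD[Edge[Int]] =
  sc.parallelize(Seq(Edge(1L, 2L, 0), Edge(3L, 2L, 0), Edge(2L, 1L, 0)))
// Vertices are derived from the edges and given the default attribute 0.
val graph: Graph[Int, Int] = GraphImpl(edges, defaultVertexAttr = 0)

// Emit a 1 to the destination of every edge and sum the messages per vertex.
// Vertices with no in-edges receive no messages and are absent from the result.
val inDegrees: VertexRDD[Int] =
  graph.mapReduceTriplets[Int](et => Iterator((et.dstId, 1)), _ + _)
```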
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
new file mode 100644
index 0000000000..05508ff716
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
@@ -0,0 +1,98 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.{classTag, ClassTag}
+
+import org.apache.spark.Partitioner
+import org.apache.spark.graphx.{PartitionID, VertexID}
+import org.apache.spark.rdd.{ShuffledRDD, RDD}
+
+
+private[graphx]
+class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T](
+ @transient var partition: PartitionID,
+ var vid: VertexID,
+ var data: T)
+ extends Product2[PartitionID, (VertexID, T)] with Serializable {
+
+ override def _1 = partition
+
+ override def _2 = (vid, data)
+
+ override def canEqual(that: Any): Boolean = that.isInstanceOf[VertexBroadcastMsg[_]]
+}
+
+
+/**
+ * A message used to send a specific value to a partition.
+ * @param partition index of the target partition.
+ * @param data value to send
+ */
+private[graphx]
+class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T](
+ @transient var partition: PartitionID,
+ var data: T)
+ extends Product2[PartitionID, T] with Serializable {
+
+ override def _1 = partition
+
+ override def _2 = data
+
+ override def canEqual(that: Any): Boolean = that.isInstanceOf[MessageToPartition[_]]
+}
+
+
+private[graphx]
+class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T]]) {
+ def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = {
+ val rdd = new ShuffledRDD[PartitionID, (VertexID, T), VertexBroadcastMsg[T]](self, partitioner)
+
+ // Set a custom serializer if the data is of int or double type.
+ if (classTag[T] == ClassTag.Int) {
+ rdd.setSerializer(classOf[IntVertexBroadcastMsgSerializer].getName)
+ } else if (classTag[T] == ClassTag.Long) {
+ rdd.setSerializer(classOf[LongVertexBroadcastMsgSerializer].getName)
+ } else if (classTag[T] == ClassTag.Double) {
+ rdd.setSerializer(classOf[DoubleVertexBroadcastMsgSerializer].getName)
+ }
+ rdd
+ }
+}
+
+
+private[graphx]
+class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) {
+
+ /**
+ * Return a copy of the RDD partitioned using the specified partitioner.
+ */
+ def partitionBy(partitioner: Partitioner): RDD[MessageToPartition[T]] = {
+ new ShuffledRDD[PartitionID, T, MessageToPartition[T]](self, partitioner)
+ }
+
+}
+
+
+private[graphx]
+object MsgRDDFunctions {
+ implicit def rdd2PartitionRDDFunctions[T: ClassTag](rdd: RDD[MessageToPartition[T]]) = {
+ new MsgRDDFunctions(rdd)
+ }
+
+ implicit def rdd2vertexMessageRDDFunctions[T: ClassTag](rdd: RDD[VertexBroadcastMsg[T]]) = {
+ new VertexBroadcastMsgRDDFunctions(rdd)
+ }
+
+ def partitionForAggregation[T: ClassTag](msgs: RDD[(VertexID, T)], partitioner: Partitioner) = {
+ val rdd = new ShuffledRDD[VertexID, T, (VertexID, T)](msgs, partitioner)
+
+ // Set a custom serializer if the data is of int or double type.
+ if (classTag[T] == ClassTag.Int) {
+ rdd.setSerializer(classOf[IntAggMsgSerializer].getName)
+ } else if (classTag[T] == ClassTag.Long) {
+ rdd.setSerializer(classOf[LongAggMsgSerializer].getName)
+ } else if (classTag[T] == ClassTag.Double) {
+ rdd.setSerializer(classOf[DoubleAggMsgSerializer].getName)
+ }
+ rdd
+ }
+}
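
For context (not part of the patch): a hedged sketch of how aggregation messages are routed before `VertexRDD.aggregateUsingIndex` consumes them. `MsgRDDFunctions` is `private[graphx]`, so the sketch assumes it sits inside that package, and the input RDD of messages is assumed to exist.

```scala
package org.apache.spark.graphx.impl

import org.apache.spark.HashPartitioner
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

object MsgRoutingSketch {
  /** Shuffle (target vertex, message) pairs to the partitions that own the targets. */
  def route(msgs: RDD[(VertexID, Double)]): RDD[(VertexID, Double)] =
    // Because the message type is Double, partitionForAggregation also installs
    // the specialized DoubleAggMsgSerializer for this shuffle.
    MsgRDDFunctions.partitionForAggregation(msgs, new HashPartitioner(msgs.partitions.size))
}
```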
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
new file mode 100644
index 0000000000..4ebe0b0267
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
@@ -0,0 +1,195 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.{classTag, ClassTag}
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.collection.{PrimitiveVector, OpenHashSet}
+
+import org.apache.spark.graphx._
+
+/**
+ * A view of the vertices after they are shipped to the join sites specified in
+ * `routingTable`. The resulting view is co-partitioned with `edges`. If `prevViewOpt` is
+ * specified, `updatedVerts` are treated as incremental updates to the previous view. Otherwise, a
+ * fresh view is created.
+ *
+ * The view is always cached (i.e., once it is evaluated, it remains materialized). This avoids
+ * constructing it twice if the user calls graph.triplets followed by graph.mapReduceTriplets, for
+ * example. However, it means iterative algorithms must manually call `Graph.unpersistVertices` on previous
+ * iterations' graphs for best GC performance. See the implementation of
+ * [[org.apache.spark.graphx.Pregel]] for an example.
+ */
+private[impl]
+class ReplicatedVertexView[VD: ClassTag](
+ updatedVerts: VertexRDD[VD],
+ edges: EdgeRDD[_],
+ routingTable: RoutingTable,
+ prevViewOpt: Option[ReplicatedVertexView[VD]] = None) {
+
+ /**
+ * Within each edge partition, create a local map from vid to an index into the attribute
+ * array. Each map contains a superset of the vertices that it will receive, because it stores
+ * vids from both the source and destination of edges. It must always include both source and
+ * destination vids because some operations, such as GraphImpl.mapReduceTriplets, rely on this.
+ */
+ private val localVertexIDMap: RDD[(Int, VertexIdToIndexMap)] = prevViewOpt match {
+ case Some(prevView) =>
+ prevView.localVertexIDMap
+ case None =>
+ edges.partitionsRDD.mapPartitions(_.map {
+ case (pid, epart) =>
+ val vidToIndex = new VertexIdToIndexMap
+ epart.foreach { e =>
+ vidToIndex.add(e.srcId)
+ vidToIndex.add(e.dstId)
+ }
+ (pid, vidToIndex)
+ }, preservesPartitioning = true).cache().setName("ReplicatedVertexView localVertexIDMap")
+ }
+
+ private lazy val bothAttrs: RDD[(PartitionID, VertexPartition[VD])] = create(true, true)
+ private lazy val srcAttrOnly: RDD[(PartitionID, VertexPartition[VD])] = create(true, false)
+ private lazy val dstAttrOnly: RDD[(PartitionID, VertexPartition[VD])] = create(false, true)
+ private lazy val noAttrs: RDD[(PartitionID, VertexPartition[VD])] = create(false, false)
+
+ def unpersist(blocking: Boolean = true): ReplicatedVertexView[VD] = {
+ bothAttrs.unpersist(blocking)
+ srcAttrOnly.unpersist(blocking)
+ dstAttrOnly.unpersist(blocking)
+ noAttrs.unpersist(blocking)
+ // Don't unpersist localVertexIDMap because a future ReplicatedVertexView may be using it
+ // without modification
+ this
+ }
+
+ def get(includeSrc: Boolean, includeDst: Boolean): RDD[(PartitionID, VertexPartition[VD])] = {
+ (includeSrc, includeDst) match {
+ case (true, true) => bothAttrs
+ case (true, false) => srcAttrOnly
+ case (false, true) => dstAttrOnly
+ case (false, false) => noAttrs
+ }
+ }
+
+ def get(
+ includeSrc: Boolean,
+ includeDst: Boolean,
+ actives: VertexRDD[_]): RDD[(PartitionID, VertexPartition[VD])] = {
+ // Ship active sets to edge partitions using the routing table, but ignoring includeSrc and
+ // includeDst. These flags govern attribute shipping, but the activeness of a vertex must be
+ // shipped to all edges mentioning that vertex, regardless of whether the vertex attribute is
+ // also shipped there.
+ val shippedActives = routingTable.get(true, true)
+ .zipPartitions(actives.partitionsRDD)(ReplicatedVertexView.buildActiveBuffer(_, _))
+ .partitionBy(edges.partitioner.get)
+ // Update the view with shippedActives, setting activeness flags in the resulting
+ // VertexPartitions
+ get(includeSrc, includeDst).zipPartitions(shippedActives) { (viewIter, shippedActivesIter) =>
+ val (pid, vPart) = viewIter.next()
+ val newPart = vPart.replaceActives(shippedActivesIter.flatMap(_._2.iterator))
+ Iterator((pid, newPart))
+ }
+ }
+
+ private def create(includeSrc: Boolean, includeDst: Boolean)
+ : RDD[(PartitionID, VertexPartition[VD])] = {
+ val vdTag = classTag[VD]
+
+ // Ship vertex attributes to edge partitions according to the routing table
+ val verts = updatedVerts.partitionsRDD
+ val shippedVerts = routingTable.get(includeSrc, includeDst)
+ .zipPartitions(verts)(ReplicatedVertexView.buildBuffer(_, _)(vdTag))
+ .partitionBy(edges.partitioner.get)
+ // TODO: Consider using a specialized shuffler.
+
+ prevViewOpt match {
+ case Some(prevView) =>
+ // Update prevView with shippedVerts, setting staleness flags in the resulting
+ // VertexPartitions
+ prevView.get(includeSrc, includeDst).zipPartitions(shippedVerts) {
+ (prevViewIter, shippedVertsIter) =>
+ val (pid, prevVPart) = prevViewIter.next()
+ val newVPart = prevVPart.innerJoinKeepLeft(shippedVertsIter.flatMap(_._2.iterator))
+ Iterator((pid, newVPart))
+ }.cache().setName("ReplicatedVertexView delta %s %s".format(includeSrc, includeDst))
+
+ case None =>
+ // Within each edge partition, place the shipped vertex attributes into the correct
+ // locations specified in localVertexIDMap
+ localVertexIDMap.zipPartitions(shippedVerts) { (mapIter, shippedVertsIter) =>
+ val (pid, vidToIndex) = mapIter.next()
+ assert(!mapIter.hasNext)
+ // Populate the vertex array using the vidToIndex map
+ val vertexArray = vdTag.newArray(vidToIndex.capacity)
+ for ((_, block) <- shippedVertsIter) {
+ for (i <- 0 until block.vids.size) {
+ val vid = block.vids(i)
+ val attr = block.attrs(i)
+ val ind = vidToIndex.getPos(vid)
+ vertexArray(ind) = attr
+ }
+ }
+ val newVPart = new VertexPartition(
+ vidToIndex, vertexArray, vidToIndex.getBitSet)(vdTag)
+ Iterator((pid, newVPart))
+ }.cache().setName("ReplicatedVertexView %s %s".format(includeSrc, includeDst))
+ }
+ }
+}
+
+private object ReplicatedVertexView {
+ protected def buildBuffer[VD: ClassTag](
+ pid2vidIter: Iterator[Array[Array[VertexID]]],
+ vertexPartIter: Iterator[VertexPartition[VD]]) = {
+ val pid2vid: Array[Array[VertexID]] = pid2vidIter.next()
+ val vertexPart: VertexPartition[VD] = vertexPartIter.next()
+
+ Iterator.tabulate(pid2vid.size) { pid =>
+ val vidsCandidate = pid2vid(pid)
+ val size = vidsCandidate.length
+ val vids = new PrimitiveVector[VertexID](pid2vid(pid).size)
+ val attrs = new PrimitiveVector[VD](pid2vid(pid).size)
+ var i = 0
+ while (i < size) {
+ val vid = vidsCandidate(i)
+ if (vertexPart.isDefined(vid)) {
+ vids += vid
+ attrs += vertexPart(vid)
+ }
+ i += 1
+ }
+ (pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array))
+ }
+ }
+
+ protected def buildActiveBuffer(
+ pid2vidIter: Iterator[Array[Array[VertexID]]],
+ activePartIter: Iterator[VertexPartition[_]])
+ : Iterator[(Int, Array[VertexID])] = {
+ val pid2vid: Array[Array[VertexID]] = pid2vidIter.next()
+ val activePart: VertexPartition[_] = activePartIter.next()
+
+ Iterator.tabulate(pid2vid.size) { pid =>
+ val vidsCandidate = pid2vid(pid)
+ val size = vidsCandidate.length
+ val actives = new PrimitiveVector[VertexID](vidsCandidate.size)
+ var i = 0
+ while (i < size) {
+ val vid = vidsCandidate(i)
+ if (activePart.isDefined(vid)) {
+ actives += vid
+ }
+ i += 1
+ }
+ (pid, actives.trim().array)
+ }
+ }
+}
+
+private[graphx]
+class VertexAttributeBlock[VD: ClassTag](val vids: Array[VertexID], val attrs: Array[VD])
+ extends Serializable {
+ def iterator: Iterator[(VertexID, VD)] =
+ (0 until vids.size).iterator.map { i => (vids(i), attrs(i)) }
+}
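
For context (not part of the patch): a hedged sketch of the cache/unpersist discipline that the comment at the top of this file prescribes for iterative algorithms, mirroring what the Pregel loop does. The `step` function standing in for one iteration is hypothetical.

```scala
import org.apache.spark.graphx._

def iterate[VD, ED](graph: Graph[VD, ED], numIters: Int)
    (step: Graph[VD, ED] => Graph[VD, ED]): Graph[VD, ED] = {
  var g = graph.cache()
  var i = 0
  while (i < numIters) {
    val prevG = g
    g = step(g).cache()
    // Materialize the new iteration before dropping the old one, so the old
    // replicated vertex view is not needed again.
    g.vertices.count()
    prevG.unpersistVertices(blocking = false)
    i += 1
  }
  g
}
```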
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala
new file mode 100644
index 0000000000..f342fd7437
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala
@@ -0,0 +1,65 @@
+package org.apache.spark.graphx.impl
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.PrimitiveVector
+
+/**
+ * Stores the locations of edge-partition join sites for each vertex attribute; that is, the routing
+ * information for shipping vertex attributes to edge partitions. This is always cached because it
+ * may be used multiple times in ReplicatedVertexView -- once to ship the vertex attributes and
+ * (possibly) once to ship the active-set information.
+ */
+private[impl]
+class RoutingTable(edges: EdgeRDD[_], vertices: VertexRDD[_]) {
+
+ val bothAttrs: RDD[Array[Array[VertexID]]] = createPid2Vid(true, true)
+ val srcAttrOnly: RDD[Array[Array[VertexID]]] = createPid2Vid(true, false)
+ val dstAttrOnly: RDD[Array[Array[VertexID]]] = createPid2Vid(false, true)
+ val noAttrs: RDD[Array[Array[VertexID]]] = createPid2Vid(false, false)
+
+ def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[Array[Array[VertexID]]] =
+ (includeSrcAttr, includeDstAttr) match {
+ case (true, true) => bothAttrs
+ case (true, false) => srcAttrOnly
+ case (false, true) => dstAttrOnly
+ case (false, false) => noAttrs
+ }
+
+ private def createPid2Vid(
+ includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[Array[Array[VertexID]]] = {
+ // Determine which vertices each edge partition needs by creating a mapping from vid to pid.
+ val vid2pid: RDD[(VertexID, PartitionID)] = edges.partitionsRDD.mapPartitions { iter =>
+ val (pid: PartitionID, edgePartition: EdgePartition[_]) = iter.next()
+ val numEdges = edgePartition.size
+ val vSet = new VertexSet
+ if (includeSrcAttr) { // Add src vertices to the set.
+ var i = 0
+ while (i < numEdges) {
+ vSet.add(edgePartition.srcIds(i))
+ i += 1
+ }
+ }
+ if (includeDstAttr) { // Add dst vertices to the set.
+ var i = 0
+ while (i < numEdges) {
+ vSet.add(edgePartition.dstIds(i))
+ i += 1
+ }
+ }
+ vSet.iterator.map { vid => (vid, pid) }
+ }
+
+ val numPartitions = vertices.partitions.size
+ vid2pid.partitionBy(vertices.partitioner.get).mapPartitions { iter =>
+ val pid2vid = Array.fill(numPartitions)(new PrimitiveVector[VertexID])
+ for ((vid, pid) <- iter) {
+ pid2vid(pid) += vid
+ }
+
+ Iterator(pid2vid.map(_.trim().array))
+ }.cache().setName("RoutingTable %s %s".format(includeSrcAttr, includeDstAttr))
+ }
+}
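
For intuition, a minimal sketch (plain Scala collections rather than RDDs, names hypothetical) of the bucketing step `createPid2Vid` performs after the shuffle: each vertex partition groups the `(vid, pid)` pairs it receives into one array of vertex ids per edge partition.

{{{
// Sketch only: mirrors the per-partition grouping inside createPid2Vid.
object RoutingTableSketch {
  type VertexID = Long
  type PartitionID = Int

  def buildPid2Vid(
      vid2pid: Seq[(VertexID, PartitionID)],
      numPartitions: Int): Array[Array[VertexID]] = {
    val buckets = Array.fill(numPartitions)(Vector.empty[VertexID])
    for ((vid, pid) <- vid2pid) {
      buckets(pid) = buckets(pid) :+ vid  // edge partition `pid` needs vertex `vid`
    }
    buckets.map(_.toArray)
  }

  def main(args: Array[String]): Unit = {
    // Edge partition 0 references vertices 1 and 2; edge partition 1 references vertex 3.
    val pid2vid = buildPid2Vid(Seq((1L, 0), (2L, 0), (3L, 1)), numPartitions = 2)
    pid2vid.zipWithIndex.foreach { case (vids, pid) =>
      println(s"ship to edge partition $pid: ${vids.mkString(", ")}")
    }
  }
}
}}}
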
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
new file mode 100644
index 0000000000..cbd6318f33
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
@@ -0,0 +1,395 @@
+package org.apache.spark.graphx.impl
+
+import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+
+import org.apache.spark.SparkConf
+import org.apache.spark.graphx._
+import org.apache.spark.serializer._
+
+private[graphx]
+class VertexIDMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[(VertexID, _)]
+ writeVarLong(msg._1, optimizePositive = false)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ override def readObject[T](): T = {
+ (readVarLong(optimizePositive = false), null).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for VertexBroadcastMessage[Int]. */
+private[graphx]
+class IntVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[VertexBroadcastMsg[Int]]
+ writeVarLong(msg.vid, optimizePositive = false)
+ writeInt(msg.data)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ override def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readInt()
+ new VertexBroadcastMsg[Int](0, a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for VertexBroadcastMessage[Long]. */
+private[graphx]
+class LongVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[VertexBroadcastMsg[Long]]
+ writeVarLong(msg.vid, optimizePositive = false)
+ writeLong(msg.data)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ override def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readLong()
+ new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for VertexBroadcastMessage[Double]. */
+private[graphx]
+class DoubleVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[VertexBroadcastMsg[Double]]
+ writeVarLong(msg.vid, optimizePositive = false)
+ writeDouble(msg.data)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readDouble()
+ new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for AggregationMessage[Int]. */
+private[graphx]
+class IntAggMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[(VertexID, Int)]
+ writeVarLong(msg._1, optimizePositive = false)
+ writeUnsignedVarInt(msg._2)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ override def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readUnsignedVarInt()
+ (a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for AggregationMessage[Long]. */
+private[graphx]
+class LongAggMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[(VertexID, Long)]
+ writeVarLong(msg._1, optimizePositive = false)
+ writeVarLong(msg._2, optimizePositive = true)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ override def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readVarLong(optimizePositive = true)
+ (a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for AggregationMessage[Double]. */
+private[graphx]
+class DoubleAggMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[(VertexID, Double)]
+ writeVarLong(msg._1, optimizePositive = false)
+ writeDouble(msg._2)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readDouble()
+ (a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Helper classes to shorten the implementation of those special serializers.
+////////////////////////////////////////////////////////////////////////////////
+
+private[graphx]
+abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
+ // The implementation should override this one.
+ def writeObject[T](t: T): SerializationStream
+
+ def writeInt(v: Int) {
+ s.write(v >> 24)
+ s.write(v >> 16)
+ s.write(v >> 8)
+ s.write(v)
+ }
+
+ def writeUnsignedVarInt(value: Int) {
+ if ((value >>> 7) == 0) {
+ s.write(value.toInt)
+ } else if ((value >>> 14) == 0) {
+ s.write((value & 0x7F) | 0x80)
+ s.write(value >>> 7)
+ } else if ((value >>> 21) == 0) {
+ s.write((value & 0x7F) | 0x80)
+ s.write(value >>> 7 | 0x80)
+ s.write(value >>> 14)
+ } else if ((value >>> 28) == 0) {
+ s.write((value & 0x7F) | 0x80)
+ s.write(value >>> 7 | 0x80)
+ s.write(value >>> 14 | 0x80)
+ s.write(value >>> 21)
+ } else {
+ s.write((value & 0x7F) | 0x80)
+ s.write(value >>> 7 | 0x80)
+ s.write(value >>> 14 | 0x80)
+ s.write(value >>> 21 | 0x80)
+ s.write(value >>> 28)
+ }
+ }
+
+ def writeVarLong(value: Long, optimizePositive: Boolean) {
+ val v = if (!optimizePositive) (value << 1) ^ (value >> 63) else value
+ if ((v >>> 7) == 0) {
+ s.write(v.toInt)
+ } else if ((v >>> 14) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7).toInt)
+ } else if ((v >>> 21) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14).toInt)
+ } else if ((v >>> 28) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21).toInt)
+ } else if ((v >>> 35) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21 | 0x80).toInt)
+ s.write((v >>> 28).toInt)
+ } else if ((v >>> 42) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21 | 0x80).toInt)
+ s.write((v >>> 28 | 0x80).toInt)
+ s.write((v >>> 35).toInt)
+ } else if ((v >>> 49) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21 | 0x80).toInt)
+ s.write((v >>> 28 | 0x80).toInt)
+ s.write((v >>> 35 | 0x80).toInt)
+ s.write((v >>> 42).toInt)
+ } else if ((v >>> 56) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21 | 0x80).toInt)
+ s.write((v >>> 28 | 0x80).toInt)
+ s.write((v >>> 35 | 0x80).toInt)
+ s.write((v >>> 42 | 0x80).toInt)
+ s.write((v >>> 49).toInt)
+ } else {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21 | 0x80).toInt)
+ s.write((v >>> 28 | 0x80).toInt)
+ s.write((v >>> 35 | 0x80).toInt)
+ s.write((v >>> 42 | 0x80).toInt)
+ s.write((v >>> 49 | 0x80).toInt)
+ s.write((v >>> 56).toInt)
+ }
+ }
+
+ def writeLong(v: Long) {
+ s.write((v >>> 56).toInt)
+ s.write((v >>> 48).toInt)
+ s.write((v >>> 40).toInt)
+ s.write((v >>> 32).toInt)
+ s.write((v >>> 24).toInt)
+ s.write((v >>> 16).toInt)
+ s.write((v >>> 8).toInt)
+ s.write(v.toInt)
+ }
+
+ //def writeDouble(v: Double): Unit = writeUnsignedVarLong(java.lang.Double.doubleToLongBits(v))
+ def writeDouble(v: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(v))
+
+ override def flush(): Unit = s.flush()
+
+ override def close(): Unit = s.close()
+}
+
+private[graphx]
+abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
+ // The implementation should override this one.
+ def readObject[T](): T
+
+ def readInt(): Int = {
+ val first = s.read()
+ if (first < 0) throw new EOFException
+ (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
+ }
+
+ def readUnsignedVarInt(): Int = {
+ var value: Int = 0
+ var i: Int = 0
+ def readOrThrow(): Int = {
+ val in = s.read()
+ if (in < 0) throw new EOFException
+ in & 0xFF
+ }
+ var b: Int = readOrThrow()
+ while ((b & 0x80) != 0) {
+ value |= (b & 0x7F) << i
+ i += 7
+ if (i > 35) throw new IllegalArgumentException("Variable length quantity is too long")
+ b = readOrThrow()
+ }
+ value | (b << i)
+ }
+
+ def readVarLong(optimizePositive: Boolean): Long = {
+ def readOrThrow(): Int = {
+ val in = s.read()
+ if (in < 0) throw new EOFException
+ in & 0xFF
+ }
+ var b = readOrThrow()
+ var ret: Long = b & 0x7F
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F) << 7
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F) << 14
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F) << 21
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F).toLong << 28
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F).toLong << 35
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F).toLong << 42
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F).toLong << 49
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= b.toLong << 56
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ if (!optimizePositive) (ret >>> 1) ^ -(ret & 1) else ret
+ }
+
+ def readLong(): Long = {
+ val first = s.read()
+ if (first < 0) throw new EOFException()
+ (first.toLong << 56) |
+ (s.read() & 0xFF).toLong << 48 |
+ (s.read() & 0xFF).toLong << 40 |
+ (s.read() & 0xFF).toLong << 32 |
+ (s.read() & 0xFF).toLong << 24 |
+ (s.read() & 0xFF) << 16 |
+ (s.read() & 0xFF) << 8 |
+ (s.read() & 0xFF)
+ }
+
+ //def readDouble(): Double = java.lang.Double.longBitsToDouble(readUnsignedVarLong())
+ def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong())
+
+ override def close(): Unit = s.close()
+}
+
+private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance {
+
+ override def serialize[T](t: T): ByteBuffer = throw new UnsupportedOperationException
+
+ override def deserialize[T](bytes: ByteBuffer): T = throw new UnsupportedOperationException
+
+ override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T =
+ throw new UnsupportedOperationException
+
+ // The implementation should override the following two.
+ override def serializeStream(s: OutputStream): SerializationStream
+ override def deserializeStream(s: InputStream): DeserializationStream
+}
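
As a standalone illustration of the variable-length encoding above: when `optimizePositive` is false, `writeVarLong`/`readVarLong` apply a zig-zag transform so that small negative values map to small unsigned values and still encode in a few bytes. A minimal sketch of just that transform:

{{{
// Sketch only: the zig-zag step from writeVarLong / readVarLong.
object ZigZagSketch {
  def encode(value: Long): Long = (value << 1) ^ (value >> 63)
  def decode(encoded: Long): Long = (encoded >>> 1) ^ -(encoded & 1)

  def main(args: Array[String]): Unit = {
    for (v <- Seq(0L, -1L, 1L, -2L, 2L, Long.MinValue, Long.MaxValue)) {
      assert(decode(encode(v)) == v)
      println(s"$v encodes to ${encode(v)}")
    }
  }
}
}}}
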
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala
new file mode 100644
index 0000000000..f97ff75fb2
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala
@@ -0,0 +1,261 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.util.collection.BitSet
+
+private[graphx] object VertexPartition {
+
+ def apply[VD: ClassTag](iter: Iterator[(VertexID, VD)]): VertexPartition[VD] = {
+ val map = new PrimitiveKeyOpenHashMap[VertexID, VD]
+ iter.foreach { case (k, v) =>
+ map(k) = v
+ }
+ new VertexPartition(map.keySet, map._values, map.keySet.getBitSet)
+ }
+
+ def apply[VD: ClassTag](iter: Iterator[(VertexID, VD)], mergeFunc: (VD, VD) => VD)
+ : VertexPartition[VD] =
+ {
+ val map = new PrimitiveKeyOpenHashMap[VertexID, VD]
+ iter.foreach { case (k, v) =>
+ map.setMerge(k, v, mergeFunc)
+ }
+ new VertexPartition(map.keySet, map._values, map.keySet.getBitSet)
+ }
+}
+
+
+private[graphx]
+class VertexPartition[@specialized(Long, Int, Double) VD: ClassTag](
+ val index: VertexIdToIndexMap,
+ val values: Array[VD],
+ val mask: BitSet,
+ /** A set of vids of active vertices. May contain vids not in index due to join rewrite. */
+ private val activeSet: Option[VertexSet] = None)
+ extends Logging {
+
+ val capacity: Int = index.capacity
+
+ def size: Int = mask.cardinality()
+
+ /** Return the vertex attribute for the given vertex ID. */
+ def apply(vid: VertexID): VD = values(index.getPos(vid))
+
+ def isDefined(vid: VertexID): Boolean = {
+ val pos = index.getPos(vid)
+ pos >= 0 && mask.get(pos)
+ }
+
+ /** Look up vid in activeSet, throwing an exception if it is None. */
+ def isActive(vid: VertexID): Boolean = {
+ activeSet.get.contains(vid)
+ }
+
+ /** The number of active vertices, if any exist. */
+ def numActives: Option[Int] = activeSet.map(_.size)
+
+ /**
+ * Pass each vertex attribute along with the vertex id through a map
+ * function and retain the original RDD's partitioning and index.
+ *
+ * @tparam VD2 the type returned by the map function
+ *
+ * @param f the function applied to each vertex id and vertex
+ * attribute in the RDD
+ *
+ * @return a new VertexPartition with values obtained by applying `f` to
+ * each of the entries in the original VertexRDD. The resulting
+ * VertexPartition retains the same index.
+ */
+ def map[VD2: ClassTag](f: (VertexID, VD) => VD2): VertexPartition[VD2] = {
+ // Construct a view of the map transformation
+ val newValues = new Array[VD2](capacity)
+ var i = mask.nextSetBit(0)
+ while (i >= 0) {
+ newValues(i) = f(index.getValue(i), values(i))
+ i = mask.nextSetBit(i + 1)
+ }
+ new VertexPartition[VD2](index, newValues, mask)
+ }
+
+ /**
+ * Restrict the vertex set to the set of vertices satisfying the given predicate.
+ *
+ * @param pred the user defined predicate
+ *
+ * @note The vertex set preserves the original index structure which means that the returned
+ * RDD can be easily joined with the original vertex-set. Furthermore, the filter only
+ * modifies the bitmap index and so no new values are allocated.
+ */
+ def filter(pred: (VertexID, VD) => Boolean): VertexPartition[VD] = {
+ // Allocate the array to store the results into
+ val newMask = new BitSet(capacity)
+ // Iterate over the active bits in the old mask and evaluate the predicate
+ var i = mask.nextSetBit(0)
+ while (i >= 0) {
+ if (pred(index.getValue(i), values(i))) {
+ newMask.set(i)
+ }
+ i = mask.nextSetBit(i + 1)
+ }
+ new VertexPartition(index, values, newMask)
+ }
+
+ /**
+ * Hides vertices that are the same between this and other. For vertices that are different, keeps
+ * the values from `other`. The indices of `this` and `other` must be the same.
+ */
+ def diff(other: VertexPartition[VD]): VertexPartition[VD] = {
+ if (index != other.index) {
+ logWarning("Diffing two VertexPartitions with different indexes is slow.")
+ diff(createUsingIndex(other.iterator))
+ } else {
+ val newMask = mask & other.mask
+ var i = newMask.nextSetBit(0)
+ while (i >= 0) {
+ if (values(i) == other.values(i)) {
+ newMask.unset(i)
+ }
+ i = newMask.nextSetBit(i + 1)
+ }
+ new VertexPartition(index, other.values, newMask)
+ }
+ }
+
+ /** Left outer join another VertexPartition. */
+ def leftJoin[VD2: ClassTag, VD3: ClassTag]
+ (other: VertexPartition[VD2])
+ (f: (VertexID, VD, Option[VD2]) => VD3): VertexPartition[VD3] = {
+ if (index != other.index) {
+ logWarning("Joining two VertexPartitions with different indexes is slow.")
+ leftJoin(createUsingIndex(other.iterator))(f)
+ } else {
+ val newValues = new Array[VD3](capacity)
+
+ var i = mask.nextSetBit(0)
+ while (i >= 0) {
+ val otherV: Option[VD2] = if (other.mask.get(i)) Some(other.values(i)) else None
+ newValues(i) = f(index.getValue(i), values(i), otherV)
+ i = mask.nextSetBit(i + 1)
+ }
+ new VertexPartition(index, newValues, mask)
+ }
+ }
+
+ /** Left outer join another iterator of messages. */
+ def leftJoin[VD2: ClassTag, VD3: ClassTag]
+ (other: Iterator[(VertexID, VD2)])
+ (f: (VertexID, VD, Option[VD2]) => VD3): VertexPartition[VD3] = {
+ leftJoin(createUsingIndex(other))(f)
+ }
+
+ /** Inner join another VertexPartition. */
+ def innerJoin[U: ClassTag, VD2: ClassTag](other: VertexPartition[U])
+ (f: (VertexID, VD, U) => VD2): VertexPartition[VD2] = {
+ if (index != other.index) {
+ logWarning("Joining two VertexPartitions with different indexes is slow.")
+ innerJoin(createUsingIndex(other.iterator))(f)
+ } else {
+ val newMask = mask & other.mask
+ val newValues = new Array[VD2](capacity)
+ var i = newMask.nextSetBit(0)
+ while (i >= 0) {
+ newValues(i) = f(index.getValue(i), values(i), other.values(i))
+ i = newMask.nextSetBit(i + 1)
+ }
+ new VertexPartition(index, newValues, newMask)
+ }
+ }
+
+ /**
+ * Inner join an iterator of messages.
+ */
+ def innerJoin[U: ClassTag, VD2: ClassTag]
+ (iter: Iterator[Product2[VertexID, U]])
+ (f: (VertexID, VD, U) => VD2): VertexPartition[VD2] = {
+ innerJoin(createUsingIndex(iter))(f)
+ }
+
+ /**
+ * Has the same effect as `aggregateUsingIndex((a, b) => a)`.
+ */
+ def createUsingIndex[VD2: ClassTag](iter: Iterator[Product2[VertexID, VD2]])
+ : VertexPartition[VD2] = {
+ val newMask = new BitSet(capacity)
+ val newValues = new Array[VD2](capacity)
+ iter.foreach { case (vid, vdata) =>
+ val pos = index.getPos(vid)
+ if (pos >= 0) {
+ newMask.set(pos)
+ newValues(pos) = vdata
+ }
+ }
+ new VertexPartition[VD2](index, newValues, newMask)
+ }
+
+ /**
+ * Similar to innerJoin, but vertices from the left side that don't appear in iter will remain in
+ * the partition, hidden by the bitmask.
+ */
+ def innerJoinKeepLeft(iter: Iterator[Product2[VertexID, VD]]): VertexPartition[VD] = {
+ val newMask = new BitSet(capacity)
+ val newValues = new Array[VD](capacity)
+ System.arraycopy(values, 0, newValues, 0, newValues.length)
+ iter.foreach { case (vid, vdata) =>
+ val pos = index.getPos(vid)
+ if (pos >= 0) {
+ newMask.set(pos)
+ newValues(pos) = vdata
+ }
+ }
+ new VertexPartition(index, newValues, newMask)
+ }
+
+ def aggregateUsingIndex[VD2: ClassTag](
+ iter: Iterator[Product2[VertexID, VD2]],
+ reduceFunc: (VD2, VD2) => VD2): VertexPartition[VD2] = {
+ val newMask = new BitSet(capacity)
+ val newValues = new Array[VD2](capacity)
+ iter.foreach { product =>
+ val vid = product._1
+ val vdata = product._2
+ val pos = index.getPos(vid)
+ if (pos >= 0) {
+ if (newMask.get(pos)) {
+ newValues(pos) = reduceFunc(newValues(pos), vdata)
+ } else { // otherwise just store the new value
+ newMask.set(pos)
+ newValues(pos) = vdata
+ }
+ }
+ }
+ new VertexPartition[VD2](index, newValues, newMask)
+ }
+
+ def replaceActives(iter: Iterator[VertexID]): VertexPartition[VD] = {
+ val newActiveSet = new VertexSet
+ iter.foreach(newActiveSet.add(_))
+ new VertexPartition(index, values, mask, Some(newActiveSet))
+ }
+
+ /**
+ * Construct a new VertexPartition whose index contains only the vertices in the mask.
+ */
+ def reindex(): VertexPartition[VD] = {
+ val hashMap = new PrimitiveKeyOpenHashMap[VertexID, VD]
+ val arbitraryMerge = (a: VD, b: VD) => a
+ for ((k, v) <- this.iterator) {
+ hashMap.setMerge(k, v, arbitraryMerge)
+ }
+ new VertexPartition(hashMap.keySet, hashMap._values, hashMap.keySet.getBitSet)
+ }
+
+ def iterator: Iterator[(VertexID, VD)] =
+ mask.iterator.map(ind => (index.getValue(ind), values(ind)))
+
+ def vidIterator: Iterator[VertexID] = mask.iterator.map(ind => index.getValue(ind))
+}
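
For intuition, a minimal sketch (over plain immutable Maps rather than the bitmask-based index above, names hypothetical) of the `diff` semantics: only entries whose value changed survive, taking the new value from `other`.

{{{
// Sketch only: VertexPartition.diff expressed over immutable Maps.
object DiffSketch {
  def diff[VD](self: Map[Long, VD], other: Map[Long, VD]): Map[Long, VD] =
    other.filter { case (vid, value) => self.get(vid).exists(_ != value) }

  def main(args: Array[String]): Unit = {
    val before = Map(1L -> 1.0, 2L -> 2.0, 3L -> 3.0)
    val after  = Map(1L -> 1.0, 2L -> 5.0, 3L -> 3.0)
    println(diff(before, after))  // Map(2 -> 5.0): only the changed entry survives
  }
}
}}}
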
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/package.scala
new file mode 100644
index 0000000000..cfc3281b64
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/package.scala
@@ -0,0 +1,7 @@
+package org.apache.spark.graphx
+
+import org.apache.spark.util.collection.OpenHashSet
+
+package object impl {
+ private[graphx] type VertexIdToIndexMap = OpenHashSet[VertexID]
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala
new file mode 100644
index 0000000000..e0aff5644e
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala
@@ -0,0 +1,136 @@
+package org.apache.spark.graphx.lib
+
+import org.apache.spark._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.PartitionStrategy._
+
+/**
+ * Driver program for running graph algorithms.
+ */
+object Analytics extends Logging {
+
+ def main(args: Array[String]) = {
+ val host = args(0)
+ val taskType = args(1)
+ val fname = args(2)
+ val options = args.drop(3).map { arg =>
+ arg.dropWhile(_ == '-').split('=') match {
+ case Array(opt, v) => (opt -> v)
+ case _ => throw new IllegalArgumentException("Invalid argument: " + arg)
+ }
+ }
+
+ def pickPartitioner(v: String): PartitionStrategy = {
+ // TODO: Use reflection rather than listing all the partitioning strategies here.
+ v match {
+ case "RandomVertexCut" => RandomVertexCut
+ case "EdgePartition1D" => EdgePartition1D
+ case "EdgePartition2D" => EdgePartition2D
+ case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut
+ case _ => throw new IllegalArgumentException("Invalid PartitionStrategy: " + v)
+ }
+ }
+
+ val conf = new SparkConf()
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+
+ taskType match {
+ case "pagerank" =>
+ var tol: Float = 0.001F
+ var outFname = ""
+ var numEPart = 4
+ var partitionStrategy: Option[PartitionStrategy] = None
+
+ options.foreach{
+ case ("tol", v) => tol = v.toFloat
+ case ("output", v) => outFname = v
+ case ("numEPart", v) => numEPart = v.toInt
+ case ("partStrategy", v) => partitionStrategy = Some(pickPartitioner(v))
+ case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt)
+ }
+
+ println("======================================")
+ println("| PageRank |")
+ println("======================================")
+
+ val sc = new SparkContext(host, "PageRank(" + fname + ")", conf)
+
+ val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname,
+ minEdgePartitions = numEPart).cache()
+ val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
+
+ println("GRAPHX: Number of vertices " + graph.vertices.count)
+ println("GRAPHX: Number of edges " + graph.edges.count)
+
+ val pr = graph.pageRank(tol).vertices.cache()
+
+ println("GRAPHX: Total rank: " + pr.map(_._2).reduce(_+_))
+
+ if (!outFname.isEmpty) {
+ logWarning("Saving pageranks of pages to " + outFname)
+ pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname)
+ }
+
+ sc.stop()
+
+ case "cc" =>
+ var numIter = Int.MaxValue
+ var numVPart = 4
+ var numEPart = 4
+ var isDynamic = false
+ var partitionStrategy: Option[PartitionStrategy] = None
+
+ options.foreach{
+ case ("numIter", v) => numIter = v.toInt
+ case ("dynamic", v) => isDynamic = v.toBoolean
+ case ("numEPart", v) => numEPart = v.toInt
+ case ("numVPart", v) => numVPart = v.toInt
+ case ("partStrategy", v) => partitionStrategy = Some(pickPartitioner(v))
+ case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt)
+ }
+
+ if (!isDynamic && numIter == Int.MaxValue) {
+ println("Set number of iterations!")
+ sys.exit(1)
+ }
+ println("======================================")
+ println("| Connected Components |")
+ println("======================================")
+
+ val sc = new SparkContext(host, "ConnectedComponents(" + fname + ")", conf)
+ val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname,
+ minEdgePartitions = numEPart).cache()
+ val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
+
+ val cc = ConnectedComponents.run(graph)
+ println("Components: " + cc.vertices.map{ case (vid,data) => data}.distinct())
+ sc.stop()
+
+ case "triangles" =>
+ var numEPart = 4
+ // TriangleCount requires the graph to be partitioned
+ var partitionStrategy: PartitionStrategy = RandomVertexCut
+
+ options.foreach{
+ case ("numEPart", v) => numEPart = v.toInt
+ case ("partStrategy", v) => partitionStrategy = pickPartitioner(v)
+ case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt)
+ }
+ println("======================================")
+ println("| Triangle Count |")
+ println("======================================")
+ val sc = new SparkContext(host, "TriangleCount(" + fname + ")", conf)
+ val graph = GraphLoader.edgeListFile(sc, fname, canonicalOrientation = true,
+ minEdgePartitions = numEPart).partitionBy(partitionStrategy).cache()
+ val triangles = TriangleCount.run(graph)
+ println("Triangles: " + triangles.vertices.map {
+ case (vid,data) => data.toLong
+ }.reduce(_ + _) / 3)
+ sc.stop()
+
+ case _ =>
+ println("Invalid task type.")
+ }
+ }
+}
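
The driver above expects three positional arguments (host, task type, edge-list file) followed by `--key=value` options. A small sketch of just the option-parsing convention, using hypothetical flag values:

{{{
// Sketch only: the same option parsing Analytics.main applies to args.drop(3).
object AnalyticsOptionsSketch {
  def main(args: Array[String]): Unit = {
    val sample = Array("--tol=0.001", "--numEPart=8", "--partStrategy=EdgePartition2D")
    val options = sample.map { arg =>
      arg.dropWhile(_ == '-').split('=') match {
        case Array(opt, v) => opt -> v
        case _ => throw new IllegalArgumentException("Invalid argument: " + arg)
      }
    }
    options.foreach { case (opt, v) => println(s"$opt = $v") }
  }
}
}}}
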
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
new file mode 100644
index 0000000000..4d1f5e74df
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
@@ -0,0 +1,38 @@
+package org.apache.spark.graphx.lib
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+
+/** Connected components algorithm. */
+object ConnectedComponents {
+ /**
+ * Compute the connected component membership of each vertex and return a graph with the vertex
+ * value containing the lowest vertex id in the connected component containing that vertex.
+ *
+ * @tparam VD the vertex attribute type (discarded in the computation)
+ * @tparam ED the edge attribute type (preserved in the computation)
+ *
+ * @param graph the graph for which to compute the connected components
+ *
+ * @return a graph with vertex attributes containing the smallest vertex in each
+ * connected component
+ */
+ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexID, ED] = {
+ val ccGraph = graph.mapVertices { case (vid, _) => vid }
+ def sendMessage(edge: EdgeTriplet[VertexID, ED]) = {
+ if (edge.srcAttr < edge.dstAttr) {
+ Iterator((edge.dstId, edge.srcAttr))
+ } else if (edge.srcAttr > edge.dstAttr) {
+ Iterator((edge.srcId, edge.dstAttr))
+ } else {
+ Iterator.empty
+ }
+ }
+ val initialMessage = Long.MaxValue
+ Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)(
+ vprog = (id, attr, msg) => math.min(attr, msg),
+ sendMsg = sendMessage,
+ mergeMsg = (a, b) => math.min(a, b))
+ } // end of connectedComponents
+}
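
A minimal usage sketch, assuming a SparkContext named `sc` (for example in the Spark shell); the edge tuples are hypothetical:

{{{
import org.apache.spark.graphx._
import org.apache.spark.graphx.lib.ConnectedComponents

// Two components: {1, 2, 3} and {4, 5}.
val edges = sc.parallelize(Seq((1L, 2L), (2L, 3L), (4L, 5L)))
val graph = Graph.fromEdgeTuples(edges, 1)
val cc = ConnectedComponents.run(graph)
// Each vertex is labeled with the smallest vertex id in its component.
cc.vertices.collect.foreach(println)  // expect (1,1), (2,1), (3,1), (4,4), (5,4)
}}}
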
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
new file mode 100644
index 0000000000..2f4d6d6864
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -0,0 +1,147 @@
+package org.apache.spark.graphx.lib
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.graphx._
+
+/**
+ * PageRank algorithm implementation. There are two implementations of PageRank.
+ *
+ * The first implementation uses the [[Pregel]] interface and runs PageRank for a fixed number
+ * of iterations:
+ * {{{
+ * var PR = Array.fill(n)( 1.0 )
+ * val oldPR = Array.fill(n)( 1.0 )
+ * for( iter <- 0 until numIter ) {
+ * swap(oldPR, PR)
+ * for( i <- 0 until n ) {
+ * PR[i] = alpha + (1 - alpha) * inNbrs[i].map(j => oldPR[j] / outDeg[j]).sum
+ * }
+ * }
+ * }}}
+ *
+ * The second implementation uses the standalone [[Graph]] interface and runs PageRank until
+ * convergence:
+ *
+ * {{{
+ * var PR = Array.fill(n)( 1.0 )
+ * val oldPR = Array.fill(n)( 0.0 )
+ * while( max(abs(PR - oldPr)) > tol ) {
+ * swap(oldPR, PR)
+ * for( i <- 0 until n if abs(PR[i] - oldPR[i]) > tol ) {
+ * PR[i] = alpha + (1 - alpha) * inNbrs[i].map(j => oldPR[j] / outDeg[j]).sum
+ * }
+ * }
+ * }}}
+ *
+ * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of
+ * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`.
+ *
+ * Note that this is not the "normalized" PageRank and as a consequence pages that have no
+ * inlinks will have a PageRank of alpha.
+ */
+object PageRank extends Logging {
+
+ /**
+ * Run PageRank for a fixed number of iterations, returning a graph
+ * with vertex attributes containing the PageRank and edge
+ * attributes containing the normalized edge weight.
+ *
+ * @tparam VD the original vertex attribute (not used)
+ * @tparam ED the original edge attribute (not used)
+ *
+ * @param graph the graph on which to compute PageRank
+ * @param numIter the number of iterations of PageRank to run
+ * @param resetProb the random reset probability (alpha)
+ *
+ * @return the graph with each vertex attribute containing the PageRank and each edge
+ *         attribute containing the normalized weight.
+ *
+ */
+ def run[VD: ClassTag, ED: ClassTag](
+ graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] =
+ {
+ // Initialize the pagerankGraph with each edge attribute having
+ // weight 1/outDegree and each vertex with attribute 1.0.
+ val pagerankGraph: Graph[Double, Double] = graph
+ // Associate the degree with each vertex
+ .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) }
+ // Set the weight on the edges based on the degree
+ .mapTriplets( e => 1.0 / e.srcAttr )
+ // Set the vertex attributes to the initial pagerank values
+ .mapVertices( (id, attr) => 1.0 )
+ .cache()
+
+ // Define the three functions needed to implement PageRank in the GraphX
+ // version of Pregel
+ def vertexProgram(id: VertexID, attr: Double, msgSum: Double): Double =
+ resetProb + (1.0 - resetProb) * msgSum
+ def sendMessage(edge: EdgeTriplet[Double, Double]) =
+ Iterator((edge.dstId, edge.srcAttr * edge.attr))
+ def messageCombiner(a: Double, b: Double): Double = a + b
+ // The initial message received by all vertices in PageRank
+ val initialMessage = 0.0
+
+ // Execute pregel for a fixed number of iterations.
+ Pregel(pagerankGraph, initialMessage, numIter, activeDirection = EdgeDirection.Out)(
+ vertexProgram, sendMessage, messageCombiner)
+ }
+
+ /**
+ * Run a dynamic version of PageRank returning a graph with vertex attributes containing the
+ * PageRank and edge attributes containing the normalized edge weight.
+ *
+ * @tparam VD the original vertex attribute (not used)
+ * @tparam ED the original edge attribute (not used)
+ *
+ * @param graph the graph on which to compute PageRank
+ * @param tol the tolerance allowed at convergence (smaller => more accurate).
+ * @param resetProb the random reset probability (alpha)
+ *
+ * @return the graph with each vertex attribute containing the PageRank and each edge
+ *         attribute containing the normalized weight.
+ */
+ def runUntilConvergence[VD: ClassTag, ED: ClassTag](
+ graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] =
+ {
+ // Initialize the pagerankGraph with each edge attribute
+ // having weight 1/outDegree and each vertex with attribute 1.0.
+ val pagerankGraph: Graph[(Double, Double), Double] = graph
+ // Associate the degree with each vertex
+ .outerJoinVertices(graph.outDegrees) {
+ (vid, vdata, deg) => deg.getOrElse(0)
+ }
+ // Set the weight on the edges based on the degree
+ .mapTriplets( e => 1.0 / e.srcAttr )
+ // Set the vertex attributes to (initialPR, delta = 0)
+ .mapVertices( (id, attr) => (0.0, 0.0) )
+ .cache()
+
+ // Define the three functions needed to implement PageRank in the GraphX
+ // version of Pregel
+ def vertexProgram(id: VertexID, attr: (Double, Double), msgSum: Double): (Double, Double) = {
+ val (oldPR, lastDelta) = attr
+ val newPR = oldPR + (1.0 - resetProb) * msgSum
+ (newPR, newPR - oldPR)
+ }
+
+ def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = {
+ if (edge.srcAttr._2 > tol) {
+ Iterator((edge.dstId, edge.srcAttr._2 * edge.attr))
+ } else {
+ Iterator.empty
+ }
+ }
+
+ def messageCombiner(a: Double, b: Double): Double = a + b
+
+ // The initial message received by all vertices in PageRank
+ val initialMessage = resetProb / (1.0 - resetProb)
+
+ // Execute a dynamic version of Pregel.
+ Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)(
+ vertexProgram, sendMessage, messageCombiner)
+ .mapVertices((vid, attr) => attr._1)
+ } // end of deltaPageRank
+}
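
A minimal usage sketch of both variants, assuming a SparkContext `sc`; the edge-list file path is hypothetical:

{{{
import org.apache.spark.graphx._
import org.apache.spark.graphx.lib.PageRank

val graph = GraphLoader.edgeListFile(sc, "edges.txt").cache()

// Fixed number of iterations.
val staticRanks = PageRank.run(graph, numIter = 10).vertices

// Run until the per-vertex change drops below the tolerance.
val dynamicRanks = PageRank.runUntilConvergence(graph, tol = 0.001).vertices

staticRanks.take(5).foreach(println)
}}}
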
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
new file mode 100644
index 0000000000..ba6517e012
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -0,0 +1,138 @@
+package org.apache.spark.graphx.lib
+
+import scala.util.Random
+import org.apache.commons.math.linear._
+import org.apache.spark.rdd._
+import org.apache.spark.graphx._
+
+/** Implementation of SVD++ algorithm. */
+object SVDPlusPlus {
+
+ /** Configuration parameters for SVDPlusPlus. */
+ class Conf(
+ var rank: Int,
+ var maxIters: Int,
+ var minVal: Double,
+ var maxVal: Double,
+ var gamma1: Double,
+ var gamma2: Double,
+ var gamma6: Double,
+ var gamma7: Double)
+ extends Serializable
+
+ /**
+ * Implement SVD++ based on "Factorization Meets the Neighborhood:
+ * a Multifaceted Collaborative Filtering Model",
+ * available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]].
+ *
+ * The prediction rule is r_ui = u + b_u + b_i + q_i * (p_u + |N(u)|^(-0.5) * sum(y));
+ * see the details on page 6 of the paper.
+ *
+ * @param edges edges for constructing the graph
+ *
+ * @param conf SVDPlusPlus parameters
+ *
+ * @return a graph with vertex attributes containing the trained model
+ */
+ def run(edges: RDD[Edge[Double]], conf: Conf)
+ : (Graph[(RealVector, RealVector, Double, Double), Double], Double) =
+ {
+ // Generate default vertex attribute
+ def defaultF(rank: Int): (RealVector, RealVector, Double, Double) = {
+ val v1 = new ArrayRealVector(rank)
+ val v2 = new ArrayRealVector(rank)
+ for (i <- 0 until rank) {
+ v1.setEntry(i, Random.nextDouble())
+ v2.setEntry(i, Random.nextDouble())
+ }
+ (v1, v2, 0.0, 0.0)
+ }
+
+ // calculate global rating mean
+ edges.cache()
+ val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))
+ val u = rs / rc
+
+ // construct graph
+ var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
+
+ // Calculate initial bias and norm
+ val t0 = g.mapReduceTriplets(
+ et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))),
+ (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2))
+
+ g = g.outerJoinVertices(t0) {
+ (vid: VertexID, vd: (RealVector, RealVector, Double, Double), msg: Option[(Long, Double)]) =>
+ (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
+ }
+
+ def mapTrainF(conf: Conf, u: Double)
+ (et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double])
+ : Iterator[(VertexID, (RealVector, RealVector, Double))] = {
+ val (usr, itm) = (et.srcAttr, et.dstAttr)
+ val (p, q) = (usr._1, itm._1)
+ var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2)
+ pred = math.max(pred, conf.minVal)
+ pred = math.min(pred, conf.maxVal)
+ val err = et.attr - pred
+ val updateP = q.mapMultiply(err)
+ .subtract(p.mapMultiply(conf.gamma7))
+ .mapMultiply(conf.gamma2)
+ val updateQ = usr._2.mapMultiply(err)
+ .subtract(q.mapMultiply(conf.gamma7))
+ .mapMultiply(conf.gamma2)
+ val updateY = q.mapMultiply(err * usr._4)
+ .subtract(itm._2.mapMultiply(conf.gamma7))
+ .mapMultiply(conf.gamma2)
+ Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)),
+ (et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)))
+ }
+
+ for (i <- 0 until conf.maxIters) {
+ // Phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes
+ g.cache()
+ val t1 = g.mapReduceTriplets(
+ et => Iterator((et.srcId, et.dstAttr._2)),
+ (g1: RealVector, g2: RealVector) => g1.add(g2))
+ g = g.outerJoinVertices(t1) {
+ (vid: VertexID, vd: (RealVector, RealVector, Double, Double), msg: Option[RealVector]) =>
+ if (msg.isDefined) (vd._1, vd._1.add(msg.get.mapMultiply(vd._4)), vd._3, vd._4) else vd
+ }
+
+ // Phase 2, update p for user nodes and q, y for item nodes
+ g.cache()
+ val t2 = g.mapReduceTriplets(
+ mapTrainF(conf, u),
+ (g1: (RealVector, RealVector, Double), g2: (RealVector, RealVector, Double)) =>
+ (g1._1.add(g2._1), g1._2.add(g2._2), g1._3 + g2._3))
+ g = g.outerJoinVertices(t2) {
+ (vid: VertexID,
+ vd: (RealVector, RealVector, Double, Double),
+ msg: Option[(RealVector, RealVector, Double)]) =>
+ (vd._1.add(msg.get._1), vd._2.add(msg.get._2), vd._3 + msg.get._3, vd._4)
+ }
+ }
+
+ // calculate error on training set
+ def mapTestF(conf: Conf, u: Double)
+ (et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double])
+ : Iterator[(VertexID, Double)] =
+ {
+ val (usr, itm) = (et.srcAttr, et.dstAttr)
+ val (p, q) = (usr._1, itm._1)
+ var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2)
+ pred = math.max(pred, conf.minVal)
+ pred = math.min(pred, conf.maxVal)
+ val err = (et.attr - pred) * (et.attr - pred)
+ Iterator((et.dstId, err))
+ }
+ g.cache()
+ val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2)
+ g = g.outerJoinVertices(t3) {
+ (vid: VertexID, vd: (RealVector, RealVector, Double, Double), msg: Option[Double]) =>
+ if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
+ }
+
+ (g, u)
+ }
+}
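
A minimal usage sketch, assuming a SparkContext `sc`; the ratings and the gamma/learning-rate values below are hypothetical placeholders, not recommended settings:

{{{
import org.apache.spark.graphx._
import org.apache.spark.graphx.lib.SVDPlusPlus

// user -> item edges weighted by the rating.
val ratings = sc.parallelize(Seq(
  Edge(1L, 101L, 4.0), Edge(1L, 102L, 3.0), Edge(2L, 101L, 5.0)))

val conf = new SVDPlusPlus.Conf(
  rank = 2, maxIters = 5, minVal = 0.0, maxVal = 5.0,
  gamma1 = 0.007, gamma2 = 0.007, gamma6 = 0.005, gamma7 = 0.015)

val (model, meanRating) = SVDPlusPlus.run(ratings, conf)
println("global rating mean: " + meanRating)
}}}
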
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
new file mode 100644
index 0000000000..d3d496e335
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
@@ -0,0 +1,94 @@
+package org.apache.spark.graphx.lib
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+
+/** Strongly connected components algorithm implementation. */
+object StronglyConnectedComponents {
+
+ /**
+ * Compute the strongly connected component (SCC) of each vertex and return a graph with the
+ * vertex value containing the lowest vertex id in the SCC containing that vertex.
+ *
+ * @tparam VD the vertex attribute type (discarded in the computation)
+ * @tparam ED the edge attribute type (preserved in the computation)
+ *
+ * @param graph the graph for which to compute the SCC
+ *
+ * @return a graph with vertex attributes containing the smallest vertex id in each SCC
+ */
+ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int): Graph[VertexID, ED] = {
+
+ // the graph we update with final SCC ids, and the graph we return at the end
+ var sccGraph = graph.mapVertices { case (vid, _) => vid }
+ // graph we are going to work with in our iterations
+ var sccWorkGraph = graph.mapVertices { case (vid, _) => (vid, false) }.cache()
+
+ var numVertices = sccWorkGraph.numVertices
+ var iter = 0
+ while (sccWorkGraph.numVertices > 0 && iter < numIter) {
+ iter += 1
+ do {
+ numVertices = sccWorkGraph.numVertices
+ sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.outDegrees) {
+ (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true)
+ }.outerJoinVertices(sccWorkGraph.inDegrees) {
+ (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true)
+ }.cache()
+
+ // get all vertices to be removed
+ val finalVertices = sccWorkGraph.vertices
+ .filter { case (vid, (scc, isFinal)) => isFinal}
+ .mapValues { (vid, data) => data._1}
+
+ // write values to sccGraph
+ sccGraph = sccGraph.outerJoinVertices(finalVertices) {
+ (vid, scc, opt) => opt.getOrElse(scc)
+ }
+ // only keep vertices that are not final
+ sccWorkGraph = sccWorkGraph.subgraph(vpred = (vid, data) => !data._2).cache()
+ } while (sccWorkGraph.numVertices < numVertices)
+
+ sccWorkGraph = sccWorkGraph.mapVertices{ case (vid, (color, isFinal)) => (vid, isFinal) }
+
+ // collect min of all my neighbor's scc values, update if it's smaller than mine
+ // then notify any neighbors with scc values larger than mine
+ sccWorkGraph = Pregel[(VertexID, Boolean), ED, VertexID](
+ sccWorkGraph, Long.MaxValue, activeDirection = EdgeDirection.Out)(
+ (vid, myScc, neighborScc) => (math.min(myScc._1, neighborScc), myScc._2),
+ e => {
+ if (e.srcId < e.dstId) {
+ Iterator((e.dstId, e.srcAttr._1))
+ } else {
+ Iterator()
+ }
+ },
+ (vid1, vid2) => math.min(vid1, vid2))
+
+ // start at root of SCCs. Traverse values in reverse, notify all my neighbors
+ // do not propagate if colors do not match!
+ sccWorkGraph = Pregel[(VertexID, Boolean), ED, Boolean](
+ sccWorkGraph, false, activeDirection = EdgeDirection.In)(
+ // vertex is final if it is the root of a color
+ // or it has the same color as a neighbor that is final
+ (vid, myScc, existsSameColorFinalNeighbor) => {
+ val isColorRoot = vid == myScc._1
+ (myScc._1, myScc._2 || isColorRoot || existsSameColorFinalNeighbor)
+ },
+ // activate neighbor if they are not final, you are, and you have the same color
+ e => {
+ val sameColor = e.dstAttr._1 == e.srcAttr._1
+ val onlyDstIsFinal = e.dstAttr._2 && !e.srcAttr._2
+ if (sameColor && onlyDstIsFinal) {
+ Iterator((e.srcId, e.dstAttr._2))
+ } else {
+ Iterator()
+ }
+ },
+ (final1, final2) => final1 || final2)
+ }
+ sccGraph
+ }
+
+}
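
A minimal usage sketch, assuming a SparkContext `sc`; the hypothetical edges form one directed cycle (1 -> 2 -> 3 -> 1) plus a vertex 4 that is only reachable, so it ends up in its own component:

{{{
import org.apache.spark.graphx._
import org.apache.spark.graphx.lib.StronglyConnectedComponents

val edges = sc.parallelize(Seq((1L, 2L), (2L, 3L), (3L, 1L), (3L, 4L)))
val graph = Graph.fromEdgeTuples(edges, 1)
val scc = StronglyConnectedComponents.run(graph, numIter = 10)
// expect vertices 1, 2, 3 labeled with SCC id 1, and vertex 4 labeled 4
scc.vertices.collect.foreach(println)
}}}
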
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
new file mode 100644
index 0000000000..23c9c40594
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
@@ -0,0 +1,76 @@
+package org.apache.spark.graphx.lib
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+
+/**
+ * Compute the number of triangles passing through each vertex.
+ *
+ * The algorithm is relatively straightforward and can be computed in three steps:
+ *
+ * <ul>
+ * <li>Compute the set of neighbors for each vertex
+ * <li>For each edge compute the intersection of the sets and send the count to both vertices.
+ * <li> Compute the sum at each vertex and divide by two since each triangle is counted twice.
+ * </ul>
+ *
+ * Note that the input graph should have its edges in canonical direction
+ * (i.e. `srcId` less than `dstId`). Also the graph must have been partitioned
+ * using [[org.apache.spark.graphx.Graph#partitionBy]].
+ */
+object TriangleCount {
+
+ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD,ED]): Graph[Int, ED] = {
+ // Remove redundant edges
+ val g = graph.groupEdges((a, b) => a).cache()
+
+ // Construct set representations of the neighborhoods
+ val nbrSets: VertexRDD[VertexSet] =
+ g.collectNeighborIds(EdgeDirection.Either).mapValues { (vid, nbrs) =>
+ val set = new VertexSet(4)
+ var i = 0
+ while (i < nbrs.size) {
+ // prevent self cycle
+ if(nbrs(i) != vid) {
+ set.add(nbrs(i))
+ }
+ i += 1
+ }
+ set
+ }
+ // join the sets with the graph
+ val setGraph: Graph[VertexSet, ED] = g.outerJoinVertices(nbrSets) {
+ (vid, _, optSet) => optSet.getOrElse(null)
+ }
+ // Edge function computes intersection of smaller vertex with larger vertex
+ def edgeFunc(et: EdgeTriplet[VertexSet, ED]): Iterator[(VertexID, Int)] = {
+ assert(et.srcAttr != null)
+ assert(et.dstAttr != null)
+ val (smallSet, largeSet) = if (et.srcAttr.size < et.dstAttr.size) {
+ (et.srcAttr, et.dstAttr)
+ } else {
+ (et.dstAttr, et.srcAttr)
+ }
+ val iter = smallSet.iterator
+ var counter: Int = 0
+ while (iter.hasNext) {
+ val vid = iter.next()
+ if (vid != et.srcId && vid != et.dstId && largeSet.contains(vid)) {
+ counter += 1
+ }
+ }
+ Iterator((et.srcId, counter), (et.dstId, counter))
+ }
+ // compute the intersection along edges
+ val counters: VertexRDD[Int] = setGraph.mapReduceTriplets(edgeFunc, _ + _)
+ // Merge counters with the graph and divide by two since each triangle is counted twice
+ g.outerJoinVertices(counters) {
+ (vid, _, optCounter: Option[Int]) =>
+ val dblCount = optCounter.getOrElse(0)
+ // double count should be even (divisible by two)
+ assert((dblCount & 1) == 0)
+ dblCount / 2
+ }
+ } // end of TriangleCount
+}
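
A minimal usage sketch, assuming a SparkContext `sc` and a hypothetical edge-list file. Per the note above, the graph is loaded with a canonical orientation and explicitly partitioned before counting:

{{{
import org.apache.spark.graphx._
import org.apache.spark.graphx.PartitionStrategy._
import org.apache.spark.graphx.lib.TriangleCount

val graph = GraphLoader.edgeListFile(sc, "edges.txt", canonicalOrientation = true)
  .partitionBy(RandomVertexCut)
  .cache()

val triCounts = TriangleCount.run(graph).vertices
// Summing the per-vertex counts triple-counts each triangle.
println("total triangles: " + triCounts.map(_._2.toLong).reduce(_ + _) / 3)
}}}
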
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/package.scala
new file mode 100644
index 0000000000..60dfc1dc37
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/package.scala
@@ -0,0 +1,18 @@
+package org.apache.spark
+
+import org.apache.spark.util.collection.OpenHashSet
+
+/** GraphX is a graph processing framework built on top of Spark. */
+package object graphx {
+ /**
+ * A 64-bit vertex identifier that uniquely identifies a vertex within a graph. It does not need
+ * to follow any ordering or satisfy any constraints other than uniqueness.
+ */
+ type VertexID = Long
+
+ /** Integer identifier of a graph partition. */
+ // TODO: Consider using Char.
+ type PartitionID = Int
+
+ private[graphx] type VertexSet = OpenHashSet[VertexID]
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
new file mode 100644
index 0000000000..1c5b234d74
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
@@ -0,0 +1,117 @@
+package org.apache.spark.graphx.util
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import scala.collection.mutable.HashSet
+
+import org.apache.spark.util.Utils
+
+import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor}
+import org.objectweb.asm.Opcodes._
+
+
+/**
+ * Includes a utility function to test whether a function accesses a specific attribute
+ * of an object.
+ */
+private[graphx] object BytecodeUtils {
+
+ /**
+ * Test whether the given closure invokes the specified method in the specified class.
+ */
+ def invokedMethod(closure: AnyRef, targetClass: Class[_], targetMethod: String): Boolean = {
+ if (_invokedMethod(closure.getClass, "apply", targetClass, targetMethod)) {
+ true
+ } else {
+ // look at closures enclosed in this closure
+ for (f <- closure.getClass.getDeclaredFields
+ if f.getType.getName.startsWith("scala.Function")) {
+ f.setAccessible(true)
+ if (invokedMethod(f.get(closure), targetClass, targetMethod)) {
+ return true
+ }
+ }
+ return false
+ }
+ }
+
+ private def _invokedMethod(cls: Class[_], method: String,
+ targetClass: Class[_], targetMethod: String): Boolean = {
+
+ val seen = new HashSet[(Class[_], String)]
+ var stack = List[(Class[_], String)]((cls, method))
+
+ while (stack.nonEmpty) {
+ val (c, m) = stack.head
+ stack = stack.tail
+ seen.add((c, m))
+ val finder = new MethodInvocationFinder(c.getName, m)
+ getClassReader(c).accept(finder, 0)
+ for (classMethod <- finder.methodsInvoked) {
+ //println(classMethod)
+ if (classMethod._1 == targetClass && classMethod._2 == targetMethod) {
+ return true
+ } else if (!seen.contains(classMethod)) {
+ stack = classMethod :: stack
+ }
+ }
+ }
+ return false
+ }
+
+ /**
+ * Get an ASM class reader for a given class from the JAR that loaded it.
+ */
+ private def getClassReader(cls: Class[_]): ClassReader = {
+ // Copy data over, before delegating to ClassReader - else we can run out of open file handles.
+ val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
+ val resourceStream = cls.getResourceAsStream(className)
+ // todo: Fixme - continuing with earlier behavior ...
+ if (resourceStream == null) return new ClassReader(resourceStream)
+
+ val baos = new ByteArrayOutputStream(128)
+ Utils.copyStream(resourceStream, baos, true)
+ new ClassReader(new ByteArrayInputStream(baos.toByteArray))
+ }
+
+ /**
+ * Given the class name, return whether we should look into the class or not. This is used to
+ * skip examining a large number of Java or Scala classes that we know for sure wouldn't access
+ * the closures. Note that the class name is expected in ASM style (i.e. use "/" instead of ".").
+ */
+ private def skipClass(className: String): Boolean = {
+ val c = className
+ c.startsWith("java/") || c.startsWith("scala/") || c.startsWith("javax/")
+ }
+
+ /**
+ * Find the set of methods invoked by the specified method in the specified class.
+ * For example, after running the visitor
+ * MethodInvocationFinder("spark/graph/Foo", "test"),
+ * its methodsInvoked variable will contain the set of methods invoked directly by
+ * Foo.test(). Interface invocations are not returned as part of the result set because we cannot
+ * determine the actual method invoked by inspecting the bytecode.
+ */
+ private class MethodInvocationFinder(className: String, methodName: String)
+ extends ClassVisitor(ASM4) {
+
+ val methodsInvoked = new HashSet[(Class[_], String)]
+
+ override def visitMethod(access: Int, name: String, desc: String,
+ sig: String, exceptions: Array[String]): MethodVisitor = {
+ if (name == methodName) {
+ new MethodVisitor(ASM4) {
+ override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
+ if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) {
+ if (!skipClass(owner)) {
+ methodsInvoked.add((Class.forName(owner.replace("/", ".")), name))
+ }
+ }
+ }
+ }
+ } else {
+ null
+ }
+ }
+ }
+}
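
For intuition, a sketch of how `invokedMethod` can be used. Since `BytecodeUtils` is `private[graphx]`, this would only compile from inside the graphx package; the closures and the attribute being probed are illustrative:

{{{
package org.apache.spark.graphx.util

import org.apache.spark.graphx.EdgeTriplet

object BytecodeUtilsSketch {
  def main(args: Array[String]): Unit = {
    val usesDstAttr = (et: EdgeTriplet[Int, Int]) => et.dstAttr + 1
    val srcOnly     = (et: EdgeTriplet[Int, Int]) => et.srcAttr + 1
    // True when the closure's bytecode calls EdgeTriplet.dstAttr, false otherwise.
    println(BytecodeUtils.invokedMethod(usesDstAttr, classOf[EdgeTriplet[Int, Int]], "dstAttr"))
    println(BytecodeUtils.invokedMethod(srcOnly, classOf[EdgeTriplet[Int, Int]], "dstAttr"))
  }
}
}}}
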
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
new file mode 100644
index 0000000000..57422ce3f1
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
@@ -0,0 +1,218 @@
+package org.apache.spark.graphx.util
+
+import scala.annotation.tailrec
+import scala.math._
+import scala.reflect.ClassTag
+import scala.util._
+
+import org.apache.spark._
+import org.apache.spark.serializer._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.Graph
+import org.apache.spark.graphx.Edge
+import org.apache.spark.graphx.impl.GraphImpl
+
+/** A collection of graph generating functions. */
+object GraphGenerators {
+
+ val RMATa = 0.45
+ val RMATb = 0.15
+ val RMATc = 0.15
+ val RMATd = 0.25
+
+ // Right now it just generates a bunch of edges where
+ // the edge data is the weight (default 1)
+ /**
+ * Generate a graph whose vertex out degree is log normal.
+ */
+ def logNormalGraph(sc: SparkContext, numVertices: Int): Graph[Int, Int] = {
+ // based on Pregel settings
+ val mu = 4
+ val sigma = 1.3
+
+ val vertices: RDD[(VertexID, Int)] = sc.parallelize(0 until numVertices).map{
+ src => (src, sampleLogNormal(mu, sigma, numVertices))
+ }
+ val edges = vertices.flatMap { v =>
+ generateRandomEdges(v._1.toInt, v._2, numVertices)
+ }
+ Graph(vertices, edges, 0)
+ }
+
+ def generateRandomEdges(src: Int, numEdges: Int, maxVertexID: Int): Array[Edge[Int]] = {
+ val rand = new Random()
+ Array.fill(maxVertexID) { Edge[Int](src, rand.nextInt(maxVertexID), 1) }
+ }
+
+ /**
+ * Randomly samples from a log normal distribution whose corresponding normal distribution has
+ * the given mean and standard deviation. It uses the formula `X = exp(mu + sigma*Z)` where
+ * `Z ~ N(0, 1)`. The resulting lognormal distribution then has mean
+ * `m = e^(mu+sigma^2/2)` and standard deviation `s = sqrt[(e^(sigma^2) - 1) * e^(2*mu+sigma^2)]`.
+ *
+ * @param mu the mean of the normal distribution
+ * @param sigma the standard deviation of the normal distribution
+ * @param maxVal exclusive upper bound on the value of the sample
+ */
+ private def sampleLogNormal(mu: Double, sigma: Double, maxVal: Int): Int = {
+ val rand = new Random()
+ val m = math.exp(mu+(sigma*sigma)/2.0)
+ val s = math.sqrt((math.exp(sigma*sigma) - 1) * math.exp(2*mu + sigma*sigma))
+ // Z ~ N(0, 1)
+ var X: Double = maxVal
+
+ while (X >= maxVal) {
+ val Z = rand.nextGaussian()
+ X = math.exp(mu + sigma*Z)
+ }
+ math.round(X.toFloat)
+ }
+
+ /**
+ * A random graph generator using the R-MAT model, proposed in
+ * "R-MAT: A Recursive Model for Graph Mining" by Chakrabarti et al.
+ *
+ * See [[http://www.cs.cmu.edu/~christos/PUBLICATIONS/siam04.pdf]].
+ */
+ def rmatGraph(sc: SparkContext, requestedNumVertices: Int, numEdges: Int): Graph[Int, Int] = {
+ // let N = requestedNumVertices
+ // the number of vertices is 2^n where n=ceil(log2[N])
+ // This ensures that the 4 quadrants are the same size at all recursion levels
+ val numVertices = math.round(
+ math.pow(2.0, math.ceil(math.log(requestedNumVertices) / math.log(2.0)))).toInt
+ var edges: Set[Edge[Int]] = Set()
+ while (edges.size < numEdges) {
+ if (edges.size % 100 == 0) {
+ println(edges.size + " edges")
+ }
+ edges += addEdge(numVertices)
+ }
+ outDegreeFromEdges(sc.parallelize(edges.toList))
+ }
+
+ private def outDegreeFromEdges[ED: ClassTag](edges: RDD[Edge[ED]]): Graph[Int, ED] = {
+ val vertices = edges.flatMap { edge => List((edge.srcId, 1)) }
+ .reduceByKey(_ + _)
+ .map{ case (vid, degree) => (vid, degree) }
+ Graph(vertices, edges, 0)
+ }
+
+ /**
+ * @param numVertices Specifies the total number of vertices in the graph (used to get
+ * the dimensions of the adjacency matrix
+ */
+ private def addEdge(numVertices: Int): Edge[Int] = {
+ //val (src, dst) = chooseCell(numVertices/2.0, numVertices/2.0, numVertices/2.0)
+ val v = math.round(numVertices.toFloat/2.0).toInt
+
+ val (src, dst) = chooseCell(v, v, v)
+ Edge[Int](src, dst, 1)
+ }
+
+ /**
+ * This method recursively subdivides the adjacency matrix into quadrants
+ * until it picks a single cell. The naming conventions here match
+ * those of the R-MAT paper. The number of nodes in the graph is a power of 2.
+ * The adjacency matrix looks like:
+ * <pre>
+ *
+ * dst ->
+ * (x,y) *************** _
+ * | | | |
+ * | a | b | |
+ * src | | | |
+ * | *************** | T
+ * \|/ | | | |
+ * | c | d | |
+ * | | | |
+ * *************** -
+ * </pre>
+ *
+ * where this represents the subquadrant of the adj matrix currently being
+ * subdivided. (x,y) represent the upper left hand corner of the subquadrant,
+ * and T represents the side length (guaranteed to be a power of 2).
+ *
+ * After choosing the next level subquadrant, we get the resulting sets
+ * of parameters:
+ * {{{
+ * quad = a, x'=x, y'=y, T'=T/2
+ * quad = b, x'=x+T/2, y'=y, T'=T/2
+ * quad = c, x'=x, y'=y+T/2, T'=T/2
+ * quad = d, x'=x+T/2, y'=y+T/2, T'=T/2
+ * }}}
+ */
+ @tailrec
+ private def chooseCell(x: Int, y: Int, t: Int): (Int, Int) = {
+ if (t <= 1) {
+ (x, y)
+ } else {
+ val newT = math.round(t.toFloat/2.0).toInt
+ pickQuadrant(RMATa, RMATb, RMATc, RMATd) match {
+ case 0 => chooseCell(x, y, newT)
+ case 1 => chooseCell(x+newT, y, newT)
+ case 2 => chooseCell(x, y+newT, newT)
+ case 3 => chooseCell(x+newT, y+newT, newT)
+ }
+ }
+ }
+
+ // TODO(crankshaw) turn result into an enum (or case class for pattern matching)
+ private def pickQuadrant(a: Double, b: Double, c: Double, d: Double): Int = {
+ if (a + b + c + d != 1.0) {
+ throw new IllegalArgumentException(
+ "R-MAT probability parameters sum to " + (a+b+c+d) + ", should sum to 1.0")
+ }
+ val rand = new Random()
+ val result = rand.nextDouble()
+ result match {
+ case x if x < a => 0 // 0 corresponds to quadrant a
+ case x if (x >= a && x < a + b) => 1 // 1 corresponds to b
+ case x if (x >= a + b && x < a + b + c) => 2 // 2 corresponds to c
+ case _ => 3 // 3 corresponds to d
+ }
+ }
+
+ /**
+ * Create a `rows` by `cols` grid graph with each vertex connected to its
+ * row+1 and col+1 neighbors. Vertex ids are assigned in row-major
+ * order.
+ *
+ * @param sc the spark context in which to construct the graph
+ * @param rows the number of rows
+ * @param cols the number of columns
+ *
+   * @return A graph whose vertex attributes are the (row, column) ids
+   * and whose edge attributes are all 1.0.
+ */
+ def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int,Int), Double] = {
+ // Convert row column address into vertex ids (row major order)
+ def sub2ind(r: Int, c: Int): VertexID = r * cols + c
+
+ val vertices: RDD[(VertexID, (Int,Int))] =
+ sc.parallelize(0 until rows).flatMap( r => (0 until cols).map( c => (sub2ind(r,c), (r,c)) ) )
+ val edges: RDD[Edge[Double]] =
+ vertices.flatMap{ case (vid, (r,c)) =>
+ (if (r+1 < rows) { Seq( (sub2ind(r, c), sub2ind(r+1, c))) } else { Seq.empty }) ++
+ (if (c+1 < cols) { Seq( (sub2ind(r, c), sub2ind(r, c+1))) } else { Seq.empty })
+ }.map{ case (src, dst) => Edge(src, dst, 1.0) }
+ Graph(vertices, edges)
+ } // end of gridGraph
+
+ /**
+ * Create a star graph with vertex 0 being the center.
+ *
+ * @param sc the spark context in which to construct the graph
+ * @param nverts the number of vertices in the star
+ *
+ * @return A star graph containing `nverts` vertices with vertex 0
+ * being the center vertex.
+ */
+ def starGraph(sc: SparkContext, nverts: Int): Graph[Int, Int] = {
+ val edges: RDD[(VertexID, VertexID)] = sc.parallelize(1 until nverts).map(vid => (vid, 0))
+ Graph.fromEdgeTuples(edges, 1)
+ } // end of starGraph
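+
+  // A minimal usage sketch of the generators above (illustrative only; assumes a running
+  // SparkContext named `sc`, e.g. in the Spark shell):
+  //
+  //   import org.apache.spark.graphx.util.GraphGenerators
+  //   val rmat = GraphGenerators.rmatGraph(sc, requestedNumVertices = 32, numEdges = 100)
+  //   val grid = GraphGenerators.gridGraph(sc, rows = 4, cols = 4)
+  //   val star = GraphGenerators.starGraph(sc, nverts = 10)
+  //   println(rmat.edges.count() + " " + grid.vertices.count() + " " + star.vertices.count())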
+
+} // end of Graph Generators
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala
new file mode 100644
index 0000000000..7b02e2ed1a
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx.util.collection
+
+import org.apache.spark.util.collection.OpenHashSet
+
+import scala.reflect._
+
+/**
+ * A fast hash map implementation for primitive, non-null keys. This hash map supports
+ * insertions and updates, but not deletions. It is about an order of magnitude
+ * faster than java.util.HashMap, with far lower memory overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
+ */
+private[graphx]
+class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
+ @specialized(Long, Int, Double) V: ClassTag](
+ val keySet: OpenHashSet[K], var _values: Array[V])
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ /**
+ * Allocate an OpenHashMap with a fixed initial capacity
+ */
+ def this(initialCapacity: Int) =
+ this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity))
+
+ /**
+ * Allocate an OpenHashMap with a default initial capacity, providing a true
+ * no-argument constructor.
+ */
+ def this() = this(64)
+
+ /**
+   * Allocate a map backed by an existing OpenHashSet of keys.
+ */
+ def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity))
+
+ require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int])
+
+ private var _oldValues: Array[V] = null
+
+ override def size = keySet.size
+
+  /** Get the value for a given key; the key is assumed to be present. */
+ def apply(k: K): V = {
+ val pos = keySet.getPos(k)
+ _values(pos)
+ }
+
+  /** Get the value for a given key, or return `elseValue` if the key is not present. */
+ def getOrElse(k: K, elseValue: V): V = {
+ val pos = keySet.getPos(k)
+ if (pos >= 0) _values(pos) else elseValue
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+
+
+  /** Set the value for a key, merging the new value with any existing value using `mergeF`. */
+ def setMerge(k: K, v: V, mergeF: (V, V) => V) {
+ val pos = keySet.addWithoutResize(k)
+ val ind = pos & OpenHashSet.POSITION_MASK
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { // if first add
+ _values(ind) = v
+ } else {
+ _values(ind) = mergeF(_values(ind), v)
+ }
+ keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ val pos = keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
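+
+  // A minimal usage sketch (illustrative only); keys can be inserted and updated but never
+  // removed:
+  //
+  //   val m = new PrimitiveKeyOpenHashMap[Long, Int](16)
+  //   m.update(1L, 10)                  // insert
+  //   m.update(1L, 11)                  // overwrite
+  //   m.setMerge(1L, 5, _ + _)          // 11 merged with 5 => 16
+  //   m.changeValue(2L, 1, _ + 1)       // key absent  => stores the default, 1
+  //   m.changeValue(2L, 1, _ + 1)       // key present => mergeValue(1) = 2
+  //   assert(m.getOrElse(3L, -1) == -1) // missing keys fall back to elseValue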
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = 0
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ pos = keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the unspecialized one and needs access
+ // to the "private" variables).
+  // They should also have been vals. We use vars because a Scala compiler bug would otherwise
+  // throw an IllegalAccessError at runtime if they were declared as vals.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties
new file mode 100644
index 0000000000..85e57f0c4b
--- /dev/null
+++ b/graphx/src/test/resources/log4j.properties
@@ -0,0 +1,28 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file core/target/unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=false
+log4j.appender.file.file=graphx/target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.eclipse.jetty=WARN
+org.eclipse.jetty.LEVEL=WARN
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
new file mode 100644
index 0000000000..280f50e39a
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -0,0 +1,66 @@
+package org.apache.spark.graphx
+
+import org.apache.spark.SparkContext
+import org.apache.spark.graphx.Graph._
+import org.apache.spark.graphx.impl.EdgePartition
+import org.apache.spark.rdd._
+import org.scalatest.FunSuite
+
+class GraphOpsSuite extends FunSuite with LocalSparkContext {
+
+ test("joinVertices") {
+ withSpark { sc =>
+ val vertices =
+ sc.parallelize(Seq[(VertexID, String)]((1, "one"), (2, "two"), (3, "three")), 2)
+ val edges = sc.parallelize((Seq(Edge(1, 2, "onetwo"))))
+ val g: Graph[String, String] = Graph(vertices, edges)
+
+ val tbl = sc.parallelize(Seq[(VertexID, Int)]((1, 10), (2, 20)))
+ val g1 = g.joinVertices(tbl) { (vid: VertexID, attr: String, u: Int) => attr + u }
+
+ val v = g1.vertices.collect().toSet
+ assert(v === Set((1, "one10"), (2, "two20"), (3, "three")))
+ }
+ }
+
+ test("collectNeighborIds") {
+ withSpark { sc =>
+ val chain = (0 until 100).map(x => (x, (x+1)%100) )
+ val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) }
+ val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+ val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache()
+ assert(nbrs.count === chain.size)
+ assert(graph.numVertices === nbrs.count)
+ nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) }
+ nbrs.collect.foreach { case (vid, nbrs) =>
+ val s = nbrs.toSet
+ assert(s.contains((vid + 1) % 100))
+ assert(s.contains(if (vid > 0) vid - 1 else 99 ))
+ }
+ }
+ }
+
+ test ("filter") {
+ withSpark { sc =>
+ val n = 5
+ val vertices = sc.parallelize((0 to n).map(x => (x:VertexID, x)))
+ val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x)))
+ val graph: Graph[Int, Int] = Graph(vertices, edges).cache()
+ val filteredGraph = graph.filter(
+ graph => {
+ val degrees: VertexRDD[Int] = graph.outDegrees
+ graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)}
+ },
+ vpred = (vid: VertexID, deg:Int) => deg > 0
+ ).cache()
+
+ val v = filteredGraph.vertices.collect().toSet
+ assert(v === Set((0,0)))
+
+ // the map is necessary because of object-reuse in the edge iterator
+ val e = filteredGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet
+ assert(e.isEmpty)
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
new file mode 100644
index 0000000000..9587f04c3e
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -0,0 +1,273 @@
+package org.apache.spark.graphx
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.graphx.Graph._
+import org.apache.spark.graphx.PartitionStrategy._
+import org.apache.spark.rdd._
+
+class GraphSuite extends FunSuite with LocalSparkContext {
+
+ def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = {
+ Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexID, x: VertexID)), 3), "v")
+ }
+
+ test("Graph.fromEdgeTuples") {
+ withSpark { sc =>
+ val ring = (0L to 100L).zip((1L to 99L) :+ 0L)
+ val doubleRing = ring ++ ring
+ val graph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1)
+ assert(graph.edges.count() === doubleRing.size)
+ assert(graph.edges.collect.forall(e => e.attr == 1))
+
+ // uniqueEdges option should uniquify edges and store duplicate count in edge attributes
+ val uniqueGraph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1, Some(RandomVertexCut))
+ assert(uniqueGraph.edges.count() === ring.size)
+ assert(uniqueGraph.edges.collect.forall(e => e.attr == 2))
+ }
+ }
+
+ test("Graph.fromEdges") {
+ withSpark { sc =>
+ val ring = (0L to 100L).zip((1L to 99L) :+ 0L).map { case (a, b) => Edge(a, b, 1) }
+ val graph = Graph.fromEdges(sc.parallelize(ring), 1.0F)
+ assert(graph.edges.count() === ring.size)
+ }
+ }
+
+ test("Graph.apply") {
+ withSpark { sc =>
+ val rawEdges = (0L to 98L).zip((1L to 99L) :+ 0L)
+ val edges: RDD[Edge[Int]] = sc.parallelize(rawEdges).map { case (s, t) => Edge(s, t, 1) }
+ val vertices: RDD[(VertexID, Boolean)] = sc.parallelize((0L until 10L).map(id => (id, true)))
+ val graph = Graph(vertices, edges, false)
+ assert( graph.edges.count() === rawEdges.size )
+ // Vertices not explicitly provided but referenced by edges should be created automatically
+ assert( graph.vertices.count() === 100)
+ graph.triplets.map { et =>
+ assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr))
+ assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr))
+ }
+ }
+ }
+
+ test("triplets") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n)
+ assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect.toSet ===
+ (1 to n).map(x => (0: VertexID, x: VertexID, "v", "v")).toSet)
+ }
+ }
+
+ test("partitionBy") {
+ withSpark { sc =>
+ def mkGraph(edges: List[(Long, Long)]) = Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0)
+ def nonemptyParts(graph: Graph[Int, Int]) = {
+ graph.edges.partitionsRDD.mapPartitions { iter =>
+ Iterator(iter.next()._2.iterator.toList)
+ }.filter(_.nonEmpty)
+ }
+ val identicalEdges = List((0L, 1L), (0L, 1L))
+ val canonicalEdges = List((0L, 1L), (1L, 0L))
+ val sameSrcEdges = List((0L, 1L), (0L, 2L))
+
+ // The two edges start out in different partitions
+ for (edges <- List(identicalEdges, canonicalEdges, sameSrcEdges)) {
+ assert(nonemptyParts(mkGraph(edges)).count === 2)
+ }
+ // partitionBy(RandomVertexCut) puts identical edges in the same partition
+ assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(RandomVertexCut)).count === 1)
+ // partitionBy(EdgePartition1D) puts same-source edges in the same partition
+ assert(nonemptyParts(mkGraph(sameSrcEdges).partitionBy(EdgePartition1D)).count === 1)
+ // partitionBy(CanonicalRandomVertexCut) puts edges that are identical modulo direction into
+ // the same partition
+ assert(nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1)
+ // partitionBy(EdgePartition2D) puts identical edges in the same partition
+ assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(EdgePartition2D)).count === 1)
+
+ // partitionBy(EdgePartition2D) ensures that vertices need only be replicated to 2 * sqrt(p)
+ // partitions
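+      // (With EdgePartition2D, edges are laid out on a sqrt(p) x sqrt(p) grid of partitions
+      // using the source id along one axis and the destination id along the other, so any
+      // single vertex's edges fall within one row plus one column of the grid.)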
+ val n = 100
+ val p = 100
+ val verts = 1 to n
+ val graph = Graph.fromEdgeTuples(sc.parallelize(verts.flatMap(x =>
+ verts.filter(y => y % x == 0).map(y => (x: VertexID, y: VertexID))), p), 0)
+ assert(graph.edges.partitions.length === p)
+ val partitionedGraph = graph.partitionBy(EdgePartition2D)
+ assert(graph.edges.partitions.length === p)
+ val bound = 2 * math.sqrt(p)
+ // Each vertex should be replicated to at most 2 * sqrt(p) partitions
+ val partitionSets = partitionedGraph.edges.partitionsRDD.mapPartitions { iter =>
+ val part = iter.next()._2
+ Iterator((part.srcIds ++ part.dstIds).toSet)
+ }.collect
+ assert(verts.forall(id => partitionSets.count(_.contains(id)) <= bound))
+ // This should not be true for the default hash partitioning
+ val partitionSetsUnpartitioned = graph.edges.partitionsRDD.mapPartitions { iter =>
+ val part = iter.next()._2
+ Iterator((part.srcIds ++ part.dstIds).toSet)
+ }.collect
+ assert(verts.exists(id => partitionSetsUnpartitioned.count(_.contains(id)) > bound))
+ }
+ }
+
+ test("mapVertices") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n)
+ // mapVertices preserving type
+ val mappedVAttrs = star.mapVertices((vid, attr) => attr + "2")
+ assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexID, "v2")).toSet)
+ // mapVertices changing type
+ val mappedVAttrs2 = star.mapVertices((vid, attr) => attr.length)
+ assert(mappedVAttrs2.vertices.collect.toSet === (0 to n).map(x => (x: VertexID, 1)).toSet)
+ }
+ }
+
+ test("mapEdges") {
+ withSpark { sc =>
+ val n = 3
+ val star = starGraph(sc, n)
+ val starWithEdgeAttrs = star.mapEdges(e => e.dstId)
+
+ val edges = starWithEdgeAttrs.edges.collect()
+ assert(edges.size === n)
+ assert(edges.toSet === (1 to n).map(x => Edge(0, x, x)).toSet)
+ }
+ }
+
+ test("mapTriplets") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n)
+ assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect.toSet ===
+ (1L to n).map(x => Edge(0, x, "vv")).toSet)
+ }
+ }
+
+ test("reverse") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n)
+ assert(star.reverse.outDegrees.collect.toSet === (1 to n).map(x => (x: VertexID, 1)).toSet)
+ }
+ }
+
+ test("subgraph") {
+ withSpark { sc =>
+      // Create a star graph with 10 leaf vertices.
+ val n = 10
+ val star = starGraph(sc, n)
+ // Take only vertices whose vids are even
+ val subgraph = star.subgraph(vpred = (vid, attr) => vid % 2 == 0)
+
+      // We should keep 6 vertices: 0, 2, 4, 6, 8 and 10.
+ assert(subgraph.vertices.collect().toSet === (0 to n by 2).map(x => (x, "v")).toSet)
+
+      // And 5 edges: from 0 to each of 2, 4, 6, 8 and 10.
+ assert(subgraph.edges.map(_.copy()).collect().toSet === (2 to n by 2).map(x => Edge(0, x, 1)).toSet)
+ }
+ }
+
+ test("mask") {
+ withSpark { sc =>
+ val n = 5
+ val vertices = sc.parallelize((0 to n).map(x => (x:VertexID, x)))
+ val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x)))
+ val graph: Graph[Int, Int] = Graph(vertices, edges).cache()
+
+ val subgraph = graph.subgraph(
+ e => e.dstId != 4L,
+ (vid, vdata) => vid != 3L
+ ).mapVertices((vid, vdata) => -1).mapEdges(e => -1)
+
+ val projectedGraph = graph.mask(subgraph)
+
+ val v = projectedGraph.vertices.collect().toSet
+ assert(v === Set((0,0), (1,1), (2,2), (4,4), (5,5)))
+
+ // the map is necessary because of object-reuse in the edge iterator
+ val e = projectedGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet
+ assert(e === Set(Edge(0,1,1), Edge(0,2,2), Edge(0,5,5)))
+
+ }
+ }
+
+ test("groupEdges") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n)
+ val doubleStar = Graph.fromEdgeTuples(
+ sc.parallelize((1 to n).flatMap(x =>
+ List((0: VertexID, x: VertexID), (0: VertexID, x: VertexID))), 1), "v")
+ val star2 = doubleStar.groupEdges { (a, b) => a}
+ assert(star2.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]) ===
+ star.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]))
+ assert(star2.vertices.collect.toSet === star.vertices.collect.toSet)
+ }
+ }
+
+ test("mapReduceTriplets") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n).mapVertices { (_, _) => 0 }.cache()
+ val starDeg = star.joinVertices(star.degrees){ (vid, oldV, deg) => deg }
+ val neighborDegreeSums = starDeg.mapReduceTriplets(
+ edge => Iterator((edge.srcId, edge.dstAttr), (edge.dstId, edge.srcAttr)),
+ (a: Int, b: Int) => a + b)
+ assert(neighborDegreeSums.collect().toSet === (0 to n).map(x => (x, n)).toSet)
+
+ // activeSetOpt
+ val allPairs = for (x <- 1 to n; y <- 1 to n) yield (x: VertexID, y: VertexID)
+ val complete = Graph.fromEdgeTuples(sc.parallelize(allPairs, 3), 0)
+ val vids = complete.mapVertices((vid, attr) => vid).cache()
+ val active = vids.vertices.filter { case (vid, attr) => attr % 2 == 0 }
+ val numEvenNeighbors = vids.mapReduceTriplets(et => {
+ // Map function should only run on edges with destination in the active set
+ if (et.dstId % 2 != 0) {
+ throw new Exception("map ran on edge with dst vid %d, which is odd".format(et.dstId))
+ }
+ Iterator((et.srcId, 1))
+ }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect.toSet
+ assert(numEvenNeighbors === (1 to n).map(x => (x: VertexID, n / 2)).toSet)
+
+ // outerJoinVertices followed by mapReduceTriplets(activeSetOpt)
+ val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexID, (x+1) % n: VertexID)), 3)
+      val ring = Graph.fromEdgeTuples(ringEdges, 0).mapVertices((vid, attr) => vid).cache()
+ val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_).cache()
+ val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => newOpt.getOrElse(old) }
+ val numOddNeighbors = changedGraph.mapReduceTriplets(et => {
+ // Map function should only run on edges with source in the active set
+ if (et.srcId % 2 != 1) {
+          throw new Exception("map ran on edge with src vid %d, which is even".format(et.srcId))
+ }
+ Iterator((et.dstId, 1))
+ }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect.toSet
+ assert(numOddNeighbors === (2 to n by 2).map(x => (x: VertexID, 1)).toSet)
+
+ }
+ }
+
+ test("outerJoinVertices") {
+ withSpark { sc =>
+ val n = 5
+ val reverseStar = starGraph(sc, n).reverse.cache()
+ // outerJoinVertices changing type
+ val reverseStarDegrees =
+ reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) }
+ val neighborDegreeSums = reverseStarDegrees.mapReduceTriplets(
+ et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)),
+ (a: Int, b: Int) => a + b).collect.toSet
+ assert(neighborDegreeSums === Set((0: VertexID, n)) ++ (1 to n).map(x => (x: VertexID, 0)))
+ // outerJoinVertices preserving type
+ val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString }
+ val newReverseStar =
+ reverseStar.outerJoinVertices(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") }
+ assert(newReverseStar.vertices.map(_._2).collect.toSet ===
+ (0 to n).map(x => "v%d".format(x)).toSet)
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
new file mode 100644
index 0000000000..aa9ba84084
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
@@ -0,0 +1,28 @@
+package org.apache.spark.graphx
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+
+/**
+ * Provides a method to run tests against a {@link SparkContext} variable that is correctly stopped
+ * after each test.
+ */
+trait LocalSparkContext {
+ /** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */
+ def withSpark[T](f: SparkContext => T) = {
+ val conf = new SparkConf()
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+ val sc = new SparkContext("local", "test", conf)
+ try {
+ f(sc)
+ } finally {
+ sc.stop()
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ }
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
new file mode 100644
index 0000000000..bceff11b8e
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
@@ -0,0 +1,41 @@
+package org.apache.spark.graphx
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd._
+
+class PregelSuite extends FunSuite with LocalSparkContext {
+
+ test("1 iteration") {
+ withSpark { sc =>
+ val n = 5
+ val starEdges = (1 to n).map(x => (0: VertexID, x: VertexID))
+ val star = Graph.fromEdgeTuples(sc.parallelize(starEdges, 3), "v").cache()
+ val result = Pregel(star, 0)(
+ (vid, attr, msg) => attr,
+ et => Iterator.empty,
+ (a: Int, b: Int) => throw new Exception("mergeMsg run unexpectedly"))
+ assert(result.vertices.collect.toSet === star.vertices.collect.toSet)
+ }
+ }
+
+ test("chain propagation") {
+ withSpark { sc =>
+ val n = 5
+ val chain = Graph.fromEdgeTuples(
+ sc.parallelize((1 until n).map(x => (x: VertexID, x + 1: VertexID)), 3),
+ 0).cache()
+ assert(chain.vertices.collect.toSet === (1 to n).map(x => (x: VertexID, 0)).toSet)
+ val chainWithSeed = chain.mapVertices { (vid, attr) => if (vid == 1) 1 else 0 }.cache()
+ assert(chainWithSeed.vertices.collect.toSet ===
+ Set((1: VertexID, 1)) ++ (2 to n).map(x => (x: VertexID, 0)).toSet)
+ val result = Pregel(chainWithSeed, 0)(
+ (vid, attr, msg) => math.max(msg, attr),
+ et => if (et.dstAttr != et.srcAttr) Iterator((et.dstId, et.srcAttr)) else Iterator.empty,
+ (a: Int, b: Int) => math.max(a, b))
+ assert(result.vertices.collect.toSet ===
+ chain.vertices.mapValues { (vid, attr) => attr + 1 }.collect.toSet)
+ }
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
new file mode 100644
index 0000000000..3ba412c1f8
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
@@ -0,0 +1,183 @@
+package org.apache.spark.graphx
+
+import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
+
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+import org.apache.spark._
+import org.apache.spark.graphx.impl._
+import org.apache.spark.graphx.impl.MsgRDDFunctions._
+import org.apache.spark.serializer.SerializationStream
+
+
+class SerializerSuite extends FunSuite with LocalSparkContext {
+
+ test("IntVertexBroadcastMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = new VertexBroadcastMsg[Int](3, 4, 5)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new IntVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new IntVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject()
+ val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject()
+ assert(outMsg.vid === inMsg1.vid)
+ assert(outMsg.vid === inMsg2.vid)
+ assert(outMsg.data === inMsg1.data)
+ assert(outMsg.data === inMsg2.data)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("LongVertexBroadcastMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = new VertexBroadcastMsg[Long](3, 4, 5)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new LongVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new LongVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject()
+ val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject()
+ assert(outMsg.vid === inMsg1.vid)
+ assert(outMsg.vid === inMsg2.vid)
+ assert(outMsg.data === inMsg1.data)
+ assert(outMsg.data === inMsg2.data)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("DoubleVertexBroadcastMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = new VertexBroadcastMsg[Double](3, 4, 5.0)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new DoubleVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new DoubleVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject()
+ val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject()
+ assert(outMsg.vid === inMsg1.vid)
+ assert(outMsg.vid === inMsg2.vid)
+ assert(outMsg.data === inMsg1.data)
+ assert(outMsg.data === inMsg2.data)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("IntAggMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = (4: VertexID, 5)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new IntAggMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new IntAggMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: (VertexID, Int) = inStrm.readObject()
+ val inMsg2: (VertexID, Int) = inStrm.readObject()
+ assert(outMsg === inMsg1)
+ assert(outMsg === inMsg2)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("LongAggMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = (4: VertexID, 1L << 32)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new LongAggMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new LongAggMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: (VertexID, Long) = inStrm.readObject()
+ val inMsg2: (VertexID, Long) = inStrm.readObject()
+ assert(outMsg === inMsg1)
+ assert(outMsg === inMsg2)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("DoubleAggMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = (4: VertexID, 5.0)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new DoubleAggMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new DoubleAggMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: (VertexID, Double) = inStrm.readObject()
+ val inMsg2: (VertexID, Double) = inStrm.readObject()
+ assert(outMsg === inMsg1)
+ assert(outMsg === inMsg2)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("TestShuffleVertexBroadcastMsg") {
+ withSpark { sc =>
+ val bmsgs = sc.parallelize(0 until 100, 10).map { pid =>
+ new VertexBroadcastMsg[Int](pid, pid, pid)
+ }
+ bmsgs.partitionBy(new HashPartitioner(3)).collect()
+ }
+ }
+
+ test("variable long encoding") {
+ def testVarLongEncoding(v: Long, optimizePositive: Boolean) {
+ val bout = new ByteArrayOutputStream
+ val stream = new ShuffleSerializationStream(bout) {
+ def writeObject[T](t: T): SerializationStream = {
+ writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive)
+ this
+ }
+ }
+ stream.writeObject(v)
+
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val dstream = new ShuffleDeserializationStream(bin) {
+ def readObject[T](): T = {
+ readVarLong(optimizePositive).asInstanceOf[T]
+ }
+ }
+ val read = dstream.readObject[Long]()
+ assert(read === v)
+ }
+
+ // Test all variable encoding code path (each branch uses 7 bits, i.e. 1L << 7 difference)
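+    // (A variable-length encoding writes 7 payload bits per output byte plus a continuation
+    // bit, so the encoded length changes at each 7-bit boundary of the value; the shifts
+    // below are chosen to exercise every such boundary.)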
+ val d = Random.nextLong() % 128
+ Seq[Long](0, 1L << 0 + d, 1L << 7 + d, 1L << 14 + d, 1L << 21 + d, 1L << 28 + d, 1L << 35 + d,
+ 1L << 42 + d, 1L << 49 + d, 1L << 56 + d, 1L << 63 + d).foreach { number =>
+ testVarLongEncoding(number, optimizePositive = false)
+ testVarLongEncoding(number, optimizePositive = true)
+ testVarLongEncoding(-number, optimizePositive = false)
+ testVarLongEncoding(-number, optimizePositive = true)
+ }
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
new file mode 100644
index 0000000000..d94a3aa67c
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
@@ -0,0 +1,85 @@
+package org.apache.spark.graphx
+
+import org.apache.spark.SparkContext
+import org.apache.spark.graphx.Graph._
+import org.apache.spark.graphx.impl.EdgePartition
+import org.apache.spark.rdd._
+import org.scalatest.FunSuite
+
+class VertexRDDSuite extends FunSuite with LocalSparkContext {
+
+ def vertices(sc: SparkContext, n: Int) = {
+ VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5))
+ }
+
+ test("filter") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n)
+ val evens = verts.filter(q => ((q._2 % 2) == 0))
+ assert(evens.count === (0 to n).filter(_ % 2 == 0).size)
+ }
+ }
+
+ test("mapValues") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n)
+ val negatives = verts.mapValues(x => -x).cache() // Allow joining b with a derived RDD of b
+ assert(negatives.count === n + 1)
+ }
+ }
+
+ test("diff") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n).cache()
+ val flipEvens = verts.mapValues(x => if (x % 2 == 0) -x else x).cache()
+ // diff should keep only the changed vertices
+ assert(verts.diff(flipEvens).map(_._2).collect().toSet === (2 to n by 2).map(-_).toSet)
+ // diff should keep the vertex values from `other`
+ assert(flipEvens.diff(verts).map(_._2).collect().toSet === (2 to n by 2).toSet)
+ }
+ }
+
+ test("leftJoin") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n).cache()
+ val evens = verts.filter(q => ((q._2 % 2) == 0)).cache()
+ // leftJoin with another VertexRDD
+ assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet)
+ // leftJoin with an RDD
+ val evensRDD = evens.map(identity)
+ assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet)
+ }
+ }
+
+ test("innerJoin") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n).cache()
+ val evens = verts.filter(q => ((q._2 % 2) == 0)).cache()
+ // innerJoin with another VertexRDD
+ assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet)
+ // innerJoin with an RDD
+ val evensRDD = evens.map(identity)
+ assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet) }
+ }
+
+ test("aggregateUsingIndex") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n)
+ val messageTargets = (0 to n) ++ (0 to n by 2)
+ val messages = sc.parallelize(messageTargets.map(x => (x.toLong, 1)))
+ assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect.toSet ===
+ (0 to n).map(x => (x.toLong, if (x % 2 == 0) 2 else 1)).toSet)
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
new file mode 100644
index 0000000000..eb82436f09
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -0,0 +1,76 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.graphx._
+
+class EdgePartitionSuite extends FunSuite {
+
+ test("reverse") {
+ val edges = List(Edge(0, 1, 0), Edge(1, 2, 0), Edge(2, 0, 0))
+ val reversedEdges = List(Edge(0, 2, 0), Edge(1, 0, 0), Edge(2, 1, 0))
+ val builder = new EdgePartitionBuilder[Int]
+ for (e <- edges) {
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+ val edgePartition = builder.toEdgePartition
+ assert(edgePartition.reverse.iterator.map(_.copy()).toList === reversedEdges)
+ assert(edgePartition.reverse.reverse.iterator.map(_.copy()).toList === edges)
+ }
+
+ test("map") {
+ val edges = List(Edge(0, 1, 0), Edge(1, 2, 0), Edge(2, 0, 0))
+ val builder = new EdgePartitionBuilder[Int]
+ for (e <- edges) {
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+ val edgePartition = builder.toEdgePartition
+ assert(edgePartition.map(e => e.srcId + e.dstId).iterator.map(_.copy()).toList ===
+ edges.map(e => e.copy(attr = e.srcId + e.dstId)))
+ }
+
+ test("groupEdges") {
+ val edges = List(
+ Edge(0, 1, 1), Edge(1, 2, 2), Edge(2, 0, 4), Edge(0, 1, 8), Edge(1, 2, 16), Edge(2, 0, 32))
+ val groupedEdges = List(Edge(0, 1, 9), Edge(1, 2, 18), Edge(2, 0, 36))
+ val builder = new EdgePartitionBuilder[Int]
+ for (e <- edges) {
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+ val edgePartition = builder.toEdgePartition
+ assert(edgePartition.groupEdges(_ + _).iterator.map(_.copy()).toList === groupedEdges)
+ }
+
+ test("indexIterator") {
+ val edgesFrom0 = List(Edge(0, 1, 0))
+ val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0))
+ val sortedEdges = edgesFrom0 ++ edgesFrom1
+ val builder = new EdgePartitionBuilder[Int]
+ for (e <- Random.shuffle(sortedEdges)) {
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+
+ val edgePartition = builder.toEdgePartition
+ assert(edgePartition.iterator.map(_.copy()).toList === sortedEdges)
+ assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0)
+ assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1)
+ }
+
+ test("innerJoin") {
+ def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A] = {
+ val builder = new EdgePartitionBuilder[A]
+ for ((src, dst, attr) <- xs) { builder.add(src: VertexID, dst: VertexID, attr) }
+ builder.toEdgePartition
+ }
+ val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
+ val bList = List((0, 1, 0), (1, 0, 0), (1, 1, 0), (3, 4, 0), (5, 5, 0))
+ val a = makeEdgePartition(aList)
+ val b = makeEdgePartition(bList)
+
+ assert(a.innerJoin(b) { (src, dst, a, b) => a }.iterator.map(_.copy()).toList ===
+ List(Edge(0, 1, 0), Edge(1, 0, 0), Edge(5, 5, 0)))
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
new file mode 100644
index 0000000000..d37d64e8c8
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
@@ -0,0 +1,113 @@
+package org.apache.spark.graphx.impl
+
+import org.apache.spark.graphx._
+import org.scalatest.FunSuite
+
+class VertexPartitionSuite extends FunSuite {
+
+ test("isDefined, filter") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 }
+ assert(vp.isDefined(0))
+ assert(!vp.isDefined(1))
+ assert(!vp.isDefined(2))
+ assert(!vp.isDefined(-1))
+ }
+
+ test("isActive, numActives, replaceActives") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1)))
+ .filter { (vid, attr) => vid == 0 }
+ .replaceActives(Iterator(0, 2, 0))
+ assert(vp.isActive(0))
+ assert(!vp.isActive(1))
+ assert(vp.isActive(2))
+ assert(!vp.isActive(-1))
+ assert(vp.numActives == Some(2))
+ }
+
+ test("map") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).map { (vid, attr) => 2 }
+ assert(vp(0) === 2)
+ }
+
+ test("diff") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2 = vp.filter { (vid, attr) => vid <= 1 }
+ val vp3a = vp.map { (vid, attr) => 2 }
+ val vp3b = VertexPartition(vp3a.iterator)
+ // diff with same index
+ val diff1 = vp2.diff(vp3a)
+ assert(diff1(0) === 2)
+ assert(diff1(1) === 2)
+ assert(diff1(2) === 2)
+ assert(!diff1.isDefined(2))
+ // diff with different indexes
+ val diff2 = vp2.diff(vp3b)
+ assert(diff2(0) === 2)
+ assert(diff2(1) === 2)
+ assert(diff2(2) === 2)
+ assert(!diff2.isDefined(2))
+ }
+
+ test("leftJoin") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2a = vp.filter { (vid, attr) => vid <= 1 }.map { (vid, attr) => 2 }
+ val vp2b = VertexPartition(vp2a.iterator)
+ // leftJoin with same index
+ val join1 = vp.leftJoin(vp2a) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+ assert(join1.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+ // leftJoin with different indexes
+ val join2 = vp.leftJoin(vp2b) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+ assert(join2.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+ // leftJoin an iterator
+ val join3 = vp.leftJoin(vp2a.iterator) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+ assert(join3.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+ }
+
+ test("innerJoin") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2a = vp.filter { (vid, attr) => vid <= 1 }.map { (vid, attr) => 2 }
+ val vp2b = VertexPartition(vp2a.iterator)
+ // innerJoin with same index
+ val join1 = vp.innerJoin(vp2a) { (vid, a, b) => b }
+ assert(join1.iterator.toSet === Set((0L, 2), (1L, 2)))
+ // innerJoin with different indexes
+ val join2 = vp.innerJoin(vp2b) { (vid, a, b) => b }
+ assert(join2.iterator.toSet === Set((0L, 2), (1L, 2)))
+ // innerJoin an iterator
+ val join3 = vp.innerJoin(vp2a.iterator) { (vid, a, b) => b }
+ assert(join3.iterator.toSet === Set((0L, 2), (1L, 2)))
+ }
+
+ test("createUsingIndex") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val elems = List((0L, 2), (2L, 2), (3L, 2))
+ val vp2 = vp.createUsingIndex(elems.iterator)
+ assert(vp2.iterator.toSet === Set((0L, 2), (2L, 2)))
+ assert(vp.index === vp2.index)
+ }
+
+ test("innerJoinKeepLeft") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val elems = List((0L, 2), (2L, 2), (3L, 2))
+ val vp2 = vp.innerJoinKeepLeft(elems.iterator)
+ assert(vp2.iterator.toSet === Set((0L, 2), (2L, 2)))
+ assert(vp2(1) === 1)
+ }
+
+ test("aggregateUsingIndex") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val messages = List((0L, "a"), (2L, "b"), (0L, "c"), (3L, "d"))
+ val vp2 = vp.aggregateUsingIndex[String](messages.iterator, _ + _)
+ assert(vp2.iterator.toSet === Set((0L, "ac"), (2L, "b")))
+ }
+
+ test("reindex") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2 = vp.filter { (vid, attr) => vid <= 1 }
+ val vp3 = vp2.reindex()
+ assert(vp2.iterator.toSet === vp3.iterator.toSet)
+ assert(vp2(2) === 1)
+ assert(vp3.index.getPos(2) === -1)
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
new file mode 100644
index 0000000000..27c8705bca
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
@@ -0,0 +1,113 @@
+package org.apache.spark.graphx.lib
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.GraphGenerators
+import org.apache.spark.rdd._
+
+
+class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
+
+ test("Grid Connected Components") {
+ withSpark { sc =>
+ val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
+ val ccGraph = gridGraph.connectedComponents()
+      // Every vertex should be in component 0, so the sum of the component ids must be 0.
+      val ccSum = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
+      assert(ccSum === 0)
+ }
+ } // end of Grid connected components
+
+
+ test("Reverse Grid Connected Components") {
+ withSpark { sc =>
+ val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse
+ val ccGraph = gridGraph.connectedComponents()
+      // Every vertex should be in component 0, so the sum of the component ids must be 0.
+      val ccSum = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
+      assert(ccSum === 0)
+ }
+ } // end of Grid connected components
+
+
+ test("Chain Connected Components") {
+ withSpark { sc =>
+ val chain1 = (0 until 9).map(x => (x, x+1) )
+ val chain2 = (10 until 20).map(x => (x, x+1) )
+ val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
+ val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0)
+ val ccGraph = twoChains.connectedComponents()
+ val vertices = ccGraph.vertices.collect()
+ for ( (id, cc) <- vertices ) {
+ if(id < 10) { assert(cc === 0) }
+ else { assert(cc === 10) }
+ }
+ val ccMap = vertices.toMap
+ for (id <- 0 until 20) {
+ if (id < 10) {
+ assert(ccMap(id) === 0)
+ } else {
+ assert(ccMap(id) === 10)
+ }
+ }
+ }
+ } // end of chain connected components
+
+ test("Reverse Chain Connected Components") {
+ withSpark { sc =>
+ val chain1 = (0 until 9).map(x => (x, x+1) )
+ val chain2 = (10 until 20).map(x => (x, x+1) )
+ val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
+ val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse
+ val ccGraph = twoChains.connectedComponents()
+ val vertices = ccGraph.vertices.collect
+ for ( (id, cc) <- vertices ) {
+ if (id < 10) {
+ assert(cc === 0)
+ } else {
+ assert(cc === 10)
+ }
+ }
+ val ccMap = vertices.toMap
+ for ( id <- 0 until 20 ) {
+ if (id < 10) {
+ assert(ccMap(id) === 0)
+ } else {
+ assert(ccMap(id) === 10)
+ }
+ }
+ }
+ } // end of reverse chain connected components
+
+ test("Connected Components on a Toy Connected Graph") {
+ withSpark { sc =>
+ // Create an RDD for the vertices
+ val users: RDD[(VertexID, (String, String))] =
+ sc.parallelize(Array((3L, ("rxin", "student")), (7L, ("jgonzal", "postdoc")),
+ (5L, ("franklin", "prof")), (2L, ("istoica", "prof")),
+ (4L, ("peter", "student"))))
+ // Create an RDD for edges
+ val relationships: RDD[Edge[String]] =
+ sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"),
+ Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi"),
+ Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague")))
+ // Edges are:
+ // 2 ---> 5 ---> 3
+ // | \
+ // V \|
+ // 4 ---> 0 7
+ //
+      // Define a default user in case there are relationships with missing users
+ val defaultUser = ("John Doe", "Missing")
+ // Build the initial Graph
+ val graph = Graph(users, relationships, defaultUser)
+ val ccGraph = graph.connectedComponents()
+ val vertices = ccGraph.vertices.collect
+ for ( (id, cc) <- vertices ) {
+ assert(cc == 0)
+ }
+ }
+ } // end of toy connected components
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
new file mode 100644
index 0000000000..fe7e4261f8
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
@@ -0,0 +1,119 @@
+package org.apache.spark.graphx.lib
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.lib._
+import org.apache.spark.graphx.util.GraphGenerators
+import org.apache.spark.rdd._
+
+object GridPageRank {
+ def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double) = {
+ val inNbrs = Array.fill(nRows * nCols)(collection.mutable.MutableList.empty[Int])
+ val outDegree = Array.fill(nRows * nCols)(0)
+ // Convert row column address into vertex ids (row major order)
+ def sub2ind(r: Int, c: Int): Int = r * nCols + c
+ // Make the grid graph
+ for (r <- 0 until nRows; c <- 0 until nCols) {
+ val ind = sub2ind(r,c)
+ if (r+1 < nRows) {
+ outDegree(ind) += 1
+ inNbrs(sub2ind(r+1,c)) += ind
+ }
+ if (c+1 < nCols) {
+ outDegree(ind) += 1
+ inNbrs(sub2ind(r,c+1)) += ind
+ }
+ }
+ // compute the pagerank
+ var pr = Array.fill(nRows * nCols)(resetProb)
+ for (iter <- 0 until nIter) {
+ val oldPr = pr
+ pr = new Array[Double](nRows * nCols)
+ for (ind <- 0 until (nRows * nCols)) {
+ pr(ind) = resetProb + (1.0 - resetProb) *
+ inNbrs(ind).map( nbr => oldPr(nbr) / outDegree(nbr)).sum
+ }
+ }
+ (0L until (nRows * nCols)).zip(pr)
+ }
+
+}
+
+
+class PageRankSuite extends FunSuite with LocalSparkContext {
+
+ def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = {
+ a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) }
+ .map { case (id, error) => error }.sum
+ }
+
+ test("Star PageRank") {
+ withSpark { sc =>
+ val nVertices = 100
+ val starGraph = GraphGenerators.starGraph(sc, nVertices).cache()
+ val resetProb = 0.15
+ val errorTol = 1.0e-5
+
+ val staticRanks1 = starGraph.staticPageRank(numIter = 1, resetProb).vertices
+ val staticRanks2 = starGraph.staticPageRank(numIter = 2, resetProb).vertices.cache()
+
+ // Static PageRank should only take 2 iterations to converge
+ val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) =>
+ if (pr1 != pr2) 1 else 0
+ }.map { case (vid, test) => test }.sum
+ assert(notMatching === 0)
+
+ val staticErrors = staticRanks2.map { case (vid, pr) =>
+ val correct = (vid > 0 && pr == resetProb) ||
+ (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5)
+ if (!correct) 1 else 0
+ }
+ assert(staticErrors.sum === 0)
+
+ val dynamicRanks = starGraph.pageRank(0, resetProb).vertices.cache()
+ assert(compareRanks(staticRanks2, dynamicRanks) < errorTol)
+ }
+ } // end of test Star PageRank
+
+
+
+ test("Grid PageRank") {
+ withSpark { sc =>
+ val rows = 10
+ val cols = 10
+ val resetProb = 0.15
+ val tol = 0.0001
+ val numIter = 50
+ val errorTol = 1.0e-5
+ val gridGraph = GraphGenerators.gridGraph(sc, rows, cols).cache()
+
+ val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache()
+ val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache()
+ val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache()
+
+ assert(compareRanks(staticRanks, referenceRanks) < errorTol)
+ assert(compareRanks(dynamicRanks, referenceRanks) < errorTol)
+ }
+ } // end of Grid PageRank
+
+
+ test("Chain PageRank") {
+ withSpark { sc =>
+ val chain1 = (0 until 9).map(x => (x, x+1) )
+ val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) }
+ val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+ val resetProb = 0.15
+ val tol = 0.0001
+ val numIter = 10
+ val errorTol = 1.0e-5
+
+ val staticRanks = chain.staticPageRank(numIter, resetProb).vertices
+ val dynamicRanks = chain.pageRank(tol, resetProb).vertices
+
+ assert(compareRanks(staticRanks, dynamicRanks) < errorTol)
+ }
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
new file mode 100644
index 0000000000..e173c652a5
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
@@ -0,0 +1,31 @@
+package org.apache.spark.graphx.lib
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.GraphGenerators
+import org.apache.spark.rdd._
+
+
+class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
+
+ test("Test SVD++ with mean square error on training set") {
+ withSpark { sc =>
+ val svdppErr = 8.0
+ val edges = sc.textFile("mllib/data/als/test.data").map { line =>
+ val fields = line.split(",")
+ Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
+ }
+ val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
+ var (graph, u) = SVDPlusPlus.run(edges, conf)
+ graph.cache()
+ val err = graph.vertices.collect.map{ case (vid, vd) =>
+ if (vid % 2 == 1) vd._4 else 0.0
+ }.reduce(_ + _) / graph.triplets.collect.size
+ assert(err <= svdppErr)
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
new file mode 100644
index 0000000000..0458311661
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
@@ -0,0 +1,57 @@
+package org.apache.spark.graphx.lib
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.GraphGenerators
+import org.apache.spark.rdd._
+
+
+class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
+
+ test("Island Strongly Connected Components") {
+ withSpark { sc =>
+ val vertices = sc.parallelize((1L to 5L).map(x => (x, -1)))
+ val edges = sc.parallelize(Seq.empty[Edge[Int]])
+ val graph = Graph(vertices, edges)
+ val sccGraph = graph.stronglyConnectedComponents(5)
+ for ((id, scc) <- sccGraph.vertices.collect) {
+ assert(id == scc)
+ }
+ }
+ }
+
+ test("Cycle Strongly Connected Components") {
+ withSpark { sc =>
+ val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7)))
+ val graph = Graph.fromEdgeTuples(rawEdges, -1)
+ val sccGraph = graph.stronglyConnectedComponents(20)
+ for ((id, scc) <- sccGraph.vertices.collect) {
+ assert(0L == scc)
+ }
+ }
+ }
+
+ test("2 Cycle Strongly Connected Components") {
+ withSpark { sc =>
+ val edges =
+ Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
+ Array(3L -> 4L, 4L -> 5L, 5L -> 3L) ++
+ Array(6L -> 0L, 5L -> 7L)
+ val rawEdges = sc.parallelize(edges)
+ val graph = Graph.fromEdgeTuples(rawEdges, -1)
+ val sccGraph = graph.stronglyConnectedComponents(20)
+ for ((id, scc) <- sccGraph.vertices.collect) {
+ if (id < 3)
+ assert(0L == scc)
+ else if (id < 6)
+ assert(3L == scc)
+ else
+ assert(id == scc)
+ }
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
new file mode 100644
index 0000000000..3452ce9764
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
@@ -0,0 +1,70 @@
+package org.apache.spark.graphx.lib
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut
+
+
+class TriangleCountSuite extends FunSuite with LocalSparkContext {
+
+ test("Count a single triangle") {
+ withSpark { sc =>
+ val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2)
+ val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
+ val triangleCount = graph.triangleCount()
+ val verts = triangleCount.vertices
+ verts.collect.foreach { case (vid, count) => assert(count === 1) }
+ }
+ }
+
+ test("Count two triangles") {
+ withSpark { sc =>
+ val triangles = Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
+ Array(0L -> -1L, -1L -> -2L, -2L -> 0L)
+ val rawEdges = sc.parallelize(triangles, 2)
+ val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
+ val triangleCount = graph.triangleCount()
+ val verts = triangleCount.vertices
+ verts.collect().foreach { case (vid, count) =>
+ if (vid == 0) {
+ assert(count === 2)
+ } else {
+ assert(count === 1)
+ }
+ }
+ }
+ }
+
+ test("Count two triangles with bi-directed edges") {
+ withSpark { sc =>
+ val triangles =
+ Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
+ Array(0L -> -1L, -1L -> -2L, -2L -> 0L)
+ val revTriangles = triangles.map { case (a,b) => (b,a) }
+ val rawEdges = sc.parallelize(triangles ++ revTriangles, 2)
+ val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
+ val triangleCount = graph.triangleCount()
+ val verts = triangleCount.vertices
+ verts.collect().foreach { case (vid, count) =>
+ if (vid == 0) {
+ assert(count === 4)
+ } else {
+ assert(count === 2)
+ }
+ }
+ }
+ }
+
+ test("Count a single triangle with duplicate edges") {
+ withSpark { sc =>
+ val rawEdges = sc.parallelize(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
+ Array(0L -> 1L, 1L -> 2L, 2L -> 0L), 2)
+ val graph = Graph.fromEdgeTuples(rawEdges, true, uniqueEdges = Some(RandomVertexCut)).cache()
+ val triangleCount = graph.triangleCount()
+ val verts = triangleCount.vertices
+ verts.collect.foreach { case (vid, count) => assert(count === 1) }
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
new file mode 100644
index 0000000000..11db339750
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
@@ -0,0 +1,93 @@
+package org.apache.spark.graphx.util
+
+import org.scalatest.FunSuite
+
+
+class BytecodeUtilsSuite extends FunSuite {
+
+ import BytecodeUtilsSuite.TestClass
+
+ test("closure invokes a method") {
+ val c1 = {e: TestClass => println(e.foo); println(e.bar); println(e.baz); }
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo"))
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "baz"))
+
+ val c2 = {e: TestClass => println(e.foo); println(e.bar); }
+ assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "foo"))
+ assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "bar"))
+ assert(!BytecodeUtils.invokedMethod(c2, classOf[TestClass], "baz"))
+
+ val c3 = {e: TestClass => println(e.foo); }
+ assert(BytecodeUtils.invokedMethod(c3, classOf[TestClass], "foo"))
+ assert(!BytecodeUtils.invokedMethod(c3, classOf[TestClass], "bar"))
+ assert(!BytecodeUtils.invokedMethod(c3, classOf[TestClass], "baz"))
+ }
+
+ test("closure inside a closure invokes a method") {
+ val c1 = {e: TestClass => println(e.foo); println(e.bar); println(e.baz); }
+ val c2 = {e: TestClass => c1(e); println(e.foo); }
+ assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "foo"))
+ assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "baz"))
+ }
+
+ test("closure inside a closure inside a closure invokes a method") {
+ val c1 = {e: TestClass => println(e.baz); }
+ val c2 = {e: TestClass => c1(e); println(e.foo); }
+ val c3 = {e: TestClass => c2(e) }
+ assert(BytecodeUtils.invokedMethod(c3, classOf[TestClass], "foo"))
+ assert(!BytecodeUtils.invokedMethod(c3, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c3, classOf[TestClass], "baz"))
+ }
+
+ test("closure calling a function that invokes a method") {
+ def zoo(e: TestClass) {
+ println(e.baz)
+ }
+ val c1 = {e: TestClass => zoo(e)}
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo"))
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "baz"))
+ }
+
+ test("closure calling a function that invokes a method which uses another closure") {
+ val c2 = {e: TestClass => println(e.baz)}
+ def zoo(e: TestClass) {
+ c2(e)
+ }
+ val c1 = {e: TestClass => zoo(e)}
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo"))
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "baz"))
+ }
+
+ test("nested closure") {
+ val c2 = {e: TestClass => println(e.baz)}
+ def zoo(e: TestClass, c: TestClass => Unit) {
+ c(e)
+ }
+ val c1 = {e: TestClass => zoo(e, c2)}
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo"))
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "baz"))
+ }
+
+ // The following doesn't work yet, because the byte code doesn't contain any information
+ // about what exactly "c" is.
+// test("invoke interface") {
+// val c1 = {e: TestClass => c(e)}
+// assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo"))
+// assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar"))
+// assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "baz"))
+// }
+
+ private val c = {e: TestClass => println(e.baz)}
+}
+
+
+object BytecodeUtilsSuite {
+ class TestClass(val foo: Int, val bar: Long) {
+ def baz: Boolean = false
+ }
+}