aboutsummaryrefslogtreecommitdiff
path: root/graphx/src
diff options
context:
space:
mode:
Diffstat (limited to 'graphx/src')
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/Graph.scala10
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala8
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala6
3 files changed, 20 insertions, 4 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index c4f9d6514c..14ae50e665 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -106,10 +106,20 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
/**
* Repartitions the edges in the graph according to `partitionStrategy`.
+ *
+ * @param the partitioning strategy to use when partitioning the edges in the graph.
*/
def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED]
/**
+ * Repartitions the edges in the graph according to `partitionStrategy`.
+ *
+ * @param the partitioning strategy to use when partitioning the edges in the graph.
+ * @param numPartitions the number of edge partitions in the new graph.
+ */
+ def partitionBy(partitionStrategy: PartitionStrategy, numPartitions: Int): Graph[VD, ED]
+
+ /**
* Transforms each vertex attribute in the graph using the map function.
*
* @note The new graph has the same structure. As a consequence the underlying index structures
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala
index ef412cfd4e..5e7e72a764 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala
@@ -114,9 +114,11 @@ object PartitionStrategy {
*/
case object CanonicalRandomVertexCut extends PartitionStrategy {
override def getPartition(src: VertexId, dst: VertexId, numParts: PartitionID): PartitionID = {
- val lower = math.min(src, dst)
- val higher = math.max(src, dst)
- math.abs((lower, higher).hashCode()) % numParts
+ if (src < dst) {
+ math.abs((src, dst).hashCode()) % numParts
+ } else {
+ math.abs((dst, src).hashCode()) % numParts
+ }
}
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 59d9a8808e..15ea05cbe2 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -74,7 +74,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
}
override def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] = {
- val numPartitions = edges.partitions.size
+ partitionBy(partitionStrategy, edges.partitions.size)
+ }
+
+ override def partitionBy(
+ partitionStrategy: PartitionStrategy, numPartitions: Int): Graph[VD, ED] = {
val edTag = classTag[ED]
val vdTag = classTag[VD]
val newEdges = edges.withPartitionsRDD(edges.map { e =>