 graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala |  4 ++--
 graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala       | 20 ++++++++++++++++++++
 2 files changed, 22 insertions(+), 2 deletions(-)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
index 897c7ee12a..f1550ac2e1 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
@@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl
 
 import scala.reflect.{classTag, ClassTag}
 
-import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext}
+import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 
@@ -46,7 +46,7 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] (
    * partitioner that allows co-partitioning with `partitionsRDD`.
    */
   override val partitioner =
-    partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD)))
+    partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size)))
 
   override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect()
 
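Context for the one-line fix above: Partitioner.defaultPartitioner consults spark.default.parallelism when that property is set, so the partitioner it returns can advertise a partition count that differs from the number of partitions partitionsRDD actually has, breaking the co-partitioning contract described in the scaladoc. Below is a minimal standalone sketch of the mismatch, not part of the patch, under the same conditions as the new test further down (the object name is illustrative):

    import org.apache.spark.{HashPartitioner, Partitioner, SparkConf, SparkContext}

    object PartitionerMismatchSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local").setAppName("sketch")
          .set("spark.default.parallelism", "3")
        val sc = new SparkContext(conf)
        try {
          // An RDD with 4 partitions and no partitioner of its own.
          val rdd = sc.parallelize(1 to 10, 4)
          // Old behaviour: sized by spark.default.parallelism, so 3 != 4.
          println(Partitioner.defaultPartitioner(rdd).numPartitions)
          // New behaviour: sized by the RDD's actual partition count, so 4.
          println(new HashPartitioner(rdd.partitions.size).numPartitions)
        } finally {
          sc.stop()
        }
      }
    }

Sizing the HashPartitioner by partitions.size keeps the advertised count in lock-step with the real one; operations that zip the supposedly co-partitioned RDDs together refuse to run when the partition counts disagree.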
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 9da0064104..ed9876b8dc 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -386,4 +386,24 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("non-default number of edge partitions") {
+    val n = 10
+    val defaultParallelism = 3
+    val numEdgePartitions = 4
+    assert(defaultParallelism != numEdgePartitions)
+    val conf = new org.apache.spark.SparkConf()
+      .set("spark.default.parallelism", defaultParallelism.toString)
+    val sc = new SparkContext("local", "test", conf)
+    try {
+      val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)),
+        numEdgePartitions)
+      val graph = Graph.fromEdgeTuples(edges, 1)
+      val neighborAttrSums = graph.mapReduceTriplets[Int](
+        et => Iterator((et.dstId, et.srcAttr)), _ + _)
+      assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n)))
+    } finally {
+      sc.stop()
+    }
+  }
+
 }
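A usage note on the test: it exercises mapReduceTriplets, which Spark 1.2 deprecated in favor of aggregateMessages, and the same regression check can be written against the newer API. A sketch under the test's assumptions (a live SparkContext in sc, local master, a non-default edge partition count of 4):

    import org.apache.spark.graphx._

    // Same star graph as the test: n edges all pointing at vertex 0,
    // spread over a non-default number of partitions.
    val n = 10
    val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)), 4)
    val graph = Graph.fromEdgeTuples(edges, 1)
    // Each edge sends its source attribute (the default, 1) to its destination,
    // and the messages are summed per vertex.
    val neighborAttrSums = graph.aggregateMessages[Int](
      ctx => ctx.sendToDst(ctx.srcAttr), _ + _)
    assert(neighborAttrSums.collect().toSet === Set((0: VertexId, n)))

The graph is a star, so every one of the n edges contributes 1 to vertex 0 and the aggregate there must equal n.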