aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2014-02-11 14:48:59 -0800
committerReynold Xin <rxin@apache.org>2014-02-11 14:48:59 -0800
commitb0dab1bb9f4cfacae68b106a44d9b14f6bea3d29 (patch)
tree88b9686b7aefd99fe4510b2d74bca80ed6e2d418 /core
parentba38d9892ec922ff11f204cd4c1b8ddc90f1bd55 (diff)
downloadspark-b0dab1bb9f4cfacae68b106a44d9b14f6bea3d29.tar.gz
spark-b0dab1bb9f4cfacae68b106a44d9b14f6bea3d29.tar.bz2
spark-b0dab1bb9f4cfacae68b106a44d9b14f6bea3d29.zip
Merge pull request #571 from holdenk/switchtobinarysearch.
SPARK-1072 Use binary search when needed in RangePartioner Author: Holden Karau <holden@pigscanfly.ca> Closes #571 and squashes the following commits: f31a2e1 [Holden Karau] Swith to using CollectionsUtils in Partitioner 4c7a0c3 [Holden Karau] Add CollectionsUtil as suggested by aarondav 7099962 [Holden Karau] Add the binary search to only init once 1bef01d [Holden Karau] CR feedback a21e097 [Holden Karau] Use binary search if we have more than 1000 elements inside of RangePartitioner
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/Partitioner.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala46
-rw-r--r--core/src/test/scala/org/apache/spark/PartitioningSuite.scala29
3 files changed, 91 insertions, 5 deletions
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index cfba43dec3..ad99882264 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD
+import org.apache.spark.util.CollectionsUtils
import org.apache.spark.util.Utils
/**
@@ -118,12 +119,26 @@ class RangePartitioner[K <% Ordered[K]: ClassTag, V](
def numPartitions = partitions
+ private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
+
def getPartition(key: Any): Int = {
- // TODO: Use a binary search here if number of partitions is large
val k = key.asInstanceOf[K]
var partition = 0
- while (partition < rangeBounds.length && k > rangeBounds(partition)) {
- partition += 1
+ if (rangeBounds.length < 1000) {
+ // Fewer than 1000 range bounds: a naive linear scan is fast enough
+ while (partition < rangeBounds.length && k > rangeBounds(partition)) {
+ partition += 1
+ }
+ } else {
+ // Use the type-specialised binary search selected once at construction time.
+ partition = binarySearch(rangeBounds, k)
+ // binarySearch either returns the match location or -[insertion point]-1
+ if (partition < 0) {
+ partition = -partition-1
+ }
+ if (partition > rangeBounds.length) {
+ partition = rangeBounds.length
+ }
}
if (ascending) {
partition
diff --git a/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala b/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala
new file mode 100644
index 0000000000..db3db87e66
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util
+
+import scala.Array
+import scala.reflect._
+
+object CollectionsUtils {
+ def makeBinarySearch[K <% Ordered[K] : ClassTag] : (Array[K], K) => Int = {
+ classTag[K] match {
+ case ClassTag.Float =>
+ (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Float]], x.asInstanceOf[Float])
+ case ClassTag.Double =>
+ (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Double]], x.asInstanceOf[Double])
+ case ClassTag.Byte =>
+ (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Byte]], x.asInstanceOf[Byte])
+ case ClassTag.Char =>
+ (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Char]], x.asInstanceOf[Char])
+ case ClassTag.Short =>
+ (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Short]], x.asInstanceOf[Short])
+ case ClassTag.Int =>
+ (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Int]], x.asInstanceOf[Int])
+ case ClassTag.Long =>
+ (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Long]], x.asInstanceOf[Long])
+ case _ =>
+ (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[AnyRef]], x)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 1374d01774..1c5d5ea436 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark
import scala.math.abs
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
+import org.scalatest.{FunSuite, PrivateMethodTester}
import org.apache.spark.SparkContext._
import org.apache.spark.util.StatCounter
import org.apache.spark.rdd.RDD
-class PartitioningSuite extends FunSuite with SharedSparkContext {
+class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester {
test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
@@ -67,6 +67,31 @@ class PartitioningSuite extends FunSuite with SharedSparkContext {
assert(descendingP4 != p4)
}
+ test("RangePartitioner getPartition") {
+ val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
+ // getPartition behaves differently when there are fewer than 1000 range bounds
+ // than when there are 1000 or more, so test partition counts on both sides.
+ val partitionSizes = List(1, 2, 10, 100, 500, 1000, 1500)
+ val partitioners = partitionSizes.map(p => (p, new RangePartitioner(p, rdd)))
+ val decoratedRangeBounds = PrivateMethod[Array[Int]]('rangeBounds)
+ partitioners.map { case (numPartitions, partitioner) =>
+ val rangeBounds = partitioner.invokePrivate(decoratedRangeBounds())
+ 1.to(1000).map { element => {
+ val partition = partitioner.getPartition(element)
+ if (numPartitions > 1) {
+ if (partition < rangeBounds.size) {
+ assert(element <= rangeBounds(partition))
+ }
+ if (partition > 0) {
+ assert(element > rangeBounds(partition - 1))
+ }
+ } else {
+ assert(partition === 0)
+ }
+ }}
+ }
+ }
+
test("HashPartitioner not equal to RangePartitioner") {
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
val rangeP2 = new RangePartitioner(2, rdd)