aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala2
-rw-r--r--python/pyspark/rdd.py39
-rw-r--r--python/pyspark/tests.py16
3 files changed, 51 insertions, 6 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index daea2617e6..af9e31ba7b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -993,7 +993,7 @@ abstract class RDD[T: ClassTag](
*/
@Experimental
def countApproxDistinct(p: Int, sp: Int): Long = {
- require(p >= 4, s"p ($p) must be greater than 0")
+ require(p >= 4, s"p ($p) must be at least 4")
require(sp <= 32, s"sp ($sp) cannot be greater than 32")
require(sp == 0 || p <= sp, s"p ($p) cannot be greater than sp ($sp)")
val zeroCounter = new HyperLogLogPlus(p, sp)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2d80fad796..6fc9f66bc5 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -62,7 +62,7 @@ def portable_hash(x):
>>> portable_hash(None)
0
- >>> portable_hash((None, 1))
+ >>> portable_hash((None, 1)) & 0xffffffff
219750521
"""
if x is None:
@@ -72,7 +72,7 @@ def portable_hash(x):
for i in x:
h ^= portable_hash(i)
h *= 1000003
- h &= 0xffffffff
+ h &= sys.maxint
h ^= len(x)
if h == -1:
h = -2
@@ -1942,7 +1942,7 @@ class RDD(object):
return True
return False
- def _to_jrdd(self):
+ def _to_java_object_rdd(self):
""" Return an JavaRDD of Object by unpickling
It will convert each Python object into Java object by Pyrolite, whenever the
@@ -1977,7 +1977,7 @@ class RDD(object):
>>> (rdd.sumApprox(1000) - r) / r < 0.05
True
"""
- jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_jrdd()
+ jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd()
jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
r = jdrdd.sumApprox(timeout, confidence).getFinalValue()
return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high())
@@ -1993,11 +1993,40 @@ class RDD(object):
>>> (rdd.meanApprox(1000) - r) / r < 0.05
True
"""
- jrdd = self.map(float)._to_jrdd()
+ jrdd = self.map(float)._to_java_object_rdd()
jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
r = jdrdd.meanApprox(timeout, confidence).getFinalValue()
return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high())
+ def countApproxDistinct(self, relativeSD=0.05):
+ """
+ :: Experimental ::
+ Return approximate number of distinct elements in the RDD.
+
+ The algorithm used is based on streamlib's implementation of
+ "HyperLogLog in Practice: Algorithmic Engineering of a State
+ of The Art Cardinality Estimation Algorithm", available
+ <a href="http://dx.doi.org/10.1145/2452376.2452456">here</a>.
+
+ @param relativeSD Relative accuracy. Smaller values create
+ counters that require more space.
+ It must be greater than 0.000017.
+
+ >>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct()
+ >>> 950 < n < 1050
+ True
+ >>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct()
+ >>> 18 < n < 22
+ True
+ """
+ if relativeSD < 0.000017:
+ raise ValueError("relativeSD should be greater than 0.000017")
+ if relativeSD > 0.37:
+ raise ValueError("relativeSD should be smaller than 0.37")
+ # the hash space in Java is 2^32
+ hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
+ return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)
+
class PipelinedRDD(RDD):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 3e7040eade..f1a75cbff5 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -404,6 +404,22 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEquals(a.count(), b.count())
self.assertRaises(Exception, lambda: a.zip(b).count())
+ def test_count_approx_distinct(self):
+ rdd = self.sc.parallelize(range(1000))
+ self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050)
+ self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050)
+ self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050)
+ self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050)
+
+ rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
+ self.assertTrue(18 < rdd.countApproxDistinct() < 22)
+ self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22)
+ self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22)
+ self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22)
+
+ self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001))
+ self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.5))
+
def test_histogram(self):
# empty
rdd = self.sc.parallelize([])