author      Cheng Lian <lian@databricks.com>    2016-01-23 00:34:55 -0800
committer   Reynold Xin <rxin@databricks.com>   2016-01-23 00:34:55 -0800
commit      1c690ddafa8376c55cbc5b7a7a750200abfbe2a6 (patch)
tree        1be95d50cb9c14eb6051c1f068f6f708b1a34e9c /common/sketch/src/test/scala/org/apache
parent      5af5a02160b42115579003b749c4d1831bf9d48e (diff)
[SPARK-12933][SQL] Initial implementation of Count-Min sketch
This PR adds an initial implementation of Count-Min sketch, contained in a new module, spark-sketch, under `common/sketch`. The implementation is based on the [`CountMinSketch` class in stream-lib][1].

As required by the [design doc][2], spark-sketch should have no external dependencies. Two classes, `Murmur3_x86_32` and `Platform`, are copied into spark-sketch from spark-unsafe to provide hashing facilities; they will also be used in the upcoming Bloom filter implementation.

The following features will be added in future follow-up PRs:

- Serialization support
- DataFrame API integration

[1]: https://github.com/addthis/stream-lib/blob/aac6b4d23a8686b000f80baa447e0922ecac3bcb/src/main/java/com/clearspring/analytics/stream/frequency/CountMinSketch.java
[2]: https://issues.apache.org/jira/secure/attachment/12782378/BloomFilterandCount-MinSketchinSpark2.0.pdf

Author: Cheng Lian <lian@databricks.com>

Closes #10851 from liancheng/count-min-sketch.
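For readers skimming the page, here is a minimal, hypothetical usage sketch of the new public API. It is based only on the calls exercised by the test suite in the diff below (`CountMinSketch.create(eps, confidence, seed)`, `add`, `estimateCount`, `mergeInPlace`); the object name and the printed values are illustrative, and the parameter values simply mirror the test's constants.

import org.apache.spark.util.sketch.CountMinSketch

object CountMinSketchUsageSketch {
  def main(args: Array[String]): Unit = {
    // Relative error as a fraction of the total count, and the probability
    // that an estimate stays within that bound (same values as the test suite).
    val eps = 0.0001
    val confidence = 0.99
    val seed = 42

    val left = CountMinSketch.create(eps, confidence, seed)
    val right = CountMinSketch.create(eps, confidence, seed)

    // The suite adds Bytes, Shorts, Ints, Longs, and Strings; two of those here.
    (1 to 1000).foreach(left.add)
    Seq("a", "b", "a").foreach(right.add)

    // A point query never underestimates; it may overestimate by at most
    // eps * totalCount with probability >= confidence.
    println(left.estimateCount(1))

    // Merge sketches built with the same parameters, in place.
    left.mergeInPlace(right)
    println(left.estimateCount("a"))
  }
}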
Diffstat (limited to 'common/sketch/src/test/scala/org/apache')
-rw-r--r--  common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala  112
1 file changed, 112 insertions(+), 0 deletions(-)
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
new file mode 100644
index 0000000000..ec5b4eddec
--- /dev/null
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.scalatest.FunSuite // scalastyle:ignore funsuite
+
+class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
+ private val epsOfTotalCount = 0.0001
+
+ private val confidence = 0.99
+
+ private val seed = 42
+
+ def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+ test(s"accuracy - $typeName") {
+ val r = new Random()
+
+ val numAllItems = 1000000
+ val allItems = Array.fill(numAllItems)(itemGenerator(r))
+
+ val numSamples = numAllItems / 10
+ val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+
+ val exactFreq = {
+ val sampledItems = sampledItemIndices.map(allItems)
+ sampledItems.groupBy(identity).mapValues(_.length.toLong)
+ }
+
+ val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ sampledItemIndices.foreach(i => sketch.add(allItems(i)))
+
+ val probCorrect = {
+ val numErrors = allItems.map { item =>
+ val count = exactFreq.getOrElse(item, 0L)
+ val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
+ if (ratio > epsOfTotalCount) 1 else 0
+ }.sum
+
+ 1D - numErrors.toDouble / numAllItems
+ }
+
+ assert(
+ probCorrect > confidence,
+ s"Confidence not reached: required $confidence, reached $probCorrect"
+ )
+ }
+ }
+
+ def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+ test(s"mergeInPlace - $typeName") {
+ val r = new Random()
+ val numToMerge = 5
+ val numItemsPerSketch = 100000
+ val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) {
+ itemGenerator(r)
+ }
+
+ val sketches = perSketchItems.map { items =>
+ val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ items.foreach(sketch.add)
+ sketch
+ }
+
+ val mergedSketch = sketches.reduce(_ mergeInPlace _)
+
+ val expectedSketch = {
+ val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ perSketchItems.foreach(_.foreach(sketch.add))
+ sketch
+ }
+
+ perSketchItems.foreach {
+ _.foreach { item =>
+ assert(mergedSketch.estimateCount(item) === expectedSketch.estimateCount(item))
+ }
+ }
+ }
+ }
+
+ def testItemType[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+ testAccuracy[T](typeName)(itemGenerator)
+ testMergeInPlace[T](typeName)(itemGenerator)
+ }
+
+ testItemType[Byte]("Byte") { _.nextInt().toByte }
+
+ testItemType[Short]("Short") { _.nextInt().toShort }
+
+ testItemType[Int]("Int") { _.nextInt() }
+
+ testItemType[Long]("Long") { _.nextLong() }
+
+ testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
+}
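To make the accuracy test above easier to follow: with eps = 0.0001 over 1,000,000 generated items, an estimate counts as an error only when it exceeds the exact frequency by more than eps * 1,000,000 = 100, and the suite requires that fewer than 1% of queries err (confidence = 0.99). The toy class below sketches the underlying data structure, assuming the classic Count-Min construction (a depth x width table of counters, one hash function per row, point estimate = minimum over rows). It is not the spark-sketch implementation, and the textbook width/depth formulas shown may differ from the module's actual sizing.

import scala.util.hashing.MurmurHash3

// Illustrative only: a classic Count-Min sketch, NOT the spark-sketch code.
class ToyCountMinSketch(eps: Double, confidence: Double, seed: Int) {
  // Textbook sizing: width ~ e / eps, depth ~ ln(1 / (1 - confidence)).
  private val width = math.ceil(math.E / eps).toInt
  private val depth = math.ceil(math.log(1.0 / (1.0 - confidence))).toInt
  private val table = Array.ofDim[Long](depth, width)

  // One hash function per row, derived by mixing the row index with the seed.
  private def bucket(item: Any, row: Int): Int = {
    val h = MurmurHash3.productHash((item, row, seed))
    (h & Int.MaxValue) % width
  }

  def add(item: Any): Unit =
    for (row <- 0 until depth) table(row)(bucket(item, row)) += 1L

  // Collisions only inflate counters, so taking the minimum over the rows
  // never underestimates the true frequency.
  def estimateCount(item: Any): Long =
    (0 until depth).map(row => table(row)(bucket(item, row))).min
}

Merging two such tables of identical dimensions and hash seeds is element-wise addition of counters, which is why the suite's mergeInPlace test can compare a merged sketch against a single sketch built from all the items at once.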