diff options
author | Cheng Lian <lian@databricks.com> | 2016-01-23 00:34:55 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-01-23 00:34:55 -0800 |
commit | 1c690ddafa8376c55cbc5b7a7a750200abfbe2a6 (patch) | |
tree | 1be95d50cb9c14eb6051c1f068f6f708b1a34e9c /common/sketch/src/test/scala/org/apache | |
parent | 5af5a02160b42115579003b749c4d1831bf9d48e (diff) | |
download | spark-1c690ddafa8376c55cbc5b7a7a750200abfbe2a6.tar.gz spark-1c690ddafa8376c55cbc5b7a7a750200abfbe2a6.tar.bz2 spark-1c690ddafa8376c55cbc5b7a7a750200abfbe2a6.zip |
[SPARK-12933][SQL] Initial implementation of Count-Min sketch
This PR adds an initial implementation of count min sketch, contained in a new module spark-sketch under `common/sketch`. The implementation is based on the [`CountMinSketch` class in stream-lib][1].
As required by the [design doc][2], spark-sketch should have no external dependency.
Two classes, `Murmur3_x86_32` and `Platform` are copied to spark-sketch from spark-unsafe for hashing facilities. They'll also be used in the upcoming bloom filter implementation.
The following features will be added in future follow-up PRs:
- Serialization support
- DataFrame API integration
[1]: https://github.com/addthis/stream-lib/blob/aac6b4d23a8686b000f80baa447e0922ecac3bcb/src/main/java/com/clearspring/analytics/stream/frequency/CountMinSketch.java
[2]: https://issues.apache.org/jira/secure/attachment/12782378/BloomFilterandCount-MinSketchinSpark2.0.pdf
Author: Cheng Lian <lian@databricks.com>
Closes #10851 from liancheng/count-min-sketch.
Diffstat (limited to 'common/sketch/src/test/scala/org/apache')
-rw-r--r-- | common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala | 112 |
1 files changed, 112 insertions, 0 deletions
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala new file mode 100644 index 0000000000..ec5b4eddec --- /dev/null +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch + +import scala.reflect.ClassTag +import scala.util.Random + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite + private val epsOfTotalCount = 0.0001 + + private val confidence = 0.99 + + private val seed = 42 + + def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + test(s"accuracy - $typeName") { + val r = new Random() + + val numAllItems = 1000000 + val allItems = Array.fill(numAllItems)(itemGenerator(r)) + + val numSamples = numAllItems / 10 + val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems)) + + val exactFreq = { + val sampledItems = sampledItemIndices.map(allItems) + sampledItems.groupBy(identity).mapValues(_.length.toLong) + } + + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + sampledItemIndices.foreach(i => sketch.add(allItems(i))) + + val probCorrect = { + val numErrors = allItems.map { item => + val count = exactFreq.getOrElse(item, 0L) + val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems + if (ratio > epsOfTotalCount) 1 else 0 + }.sum + + 1D - numErrors.toDouble / numAllItems + } + + assert( + probCorrect > confidence, + s"Confidence not reached: required $confidence, reached $probCorrect" + ) + } + } + + def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + test(s"mergeInPlace - $typeName") { + val r = new Random() + val numToMerge = 5 + val numItemsPerSketch = 100000 + val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { + itemGenerator(r) + } + + val sketches = perSketchItems.map { items => + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + items.foreach(sketch.add) + sketch + } + + val mergedSketch = sketches.reduce(_ mergeInPlace _) + + val expectedSketch = { + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + perSketchItems.foreach(_.foreach(sketch.add)) + sketch + } + + perSketchItems.foreach { + _.foreach { item => + assert(mergedSketch.estimateCount(item) === expectedSketch.estimateCount(item)) + } + } + } + } + + def testItemType[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + testAccuracy[T](typeName)(itemGenerator) + testMergeInPlace[T](typeName)(itemGenerator) + } + + testItemType[Byte]("Byte") { _.nextInt().toByte } + + testItemType[Short]("Short") { _.nextInt().toShort } + + testItemType[Int]("Int") { _.nextInt() } + + testItemType[Long]("Long") { _.nextLong() } + + testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } +} |