aboutsummaryrefslogtreecommitdiff
path: root/common/sketch
diff options
context:
space:
mode:
authorwangzhenhua <wangzhenhua@huawei.com>2016-11-29 13:16:46 -0800
committerReynold Xin <rxin@databricks.com>2016-11-29 13:16:46 -0800
commitd57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e (patch)
tree847dfe0de2a6ec831917709f169708695d09f95f /common/sketch
parentf643fe47f4889faf68da3da8d7850ee48df7c22f (diff)
downloadspark-d57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e.tar.gz
spark-d57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e.tar.bz2
spark-d57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e.zip
[SPARK-18429][SQL] implement a new Aggregate for CountMinSketch
## What changes were proposed in this pull request? This PR implements a new Aggregate to generate count min sketch, which is a wrapper of CountMinSketch. ## How was this patch tested? add test cases Author: wangzhenhua <wangzhenhua@huawei.com> Closes #15877 from wzhfy/cms.
Diffstat (limited to 'common/sketch')
-rw-r--r--common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java4
-rw-r--r--common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java30
-rw-r--r--common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala23
3 files changed, 49 insertions, 8 deletions
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 40fa20c4a3..0011096d4a 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -30,6 +30,10 @@ import java.io.OutputStream;
* <li>{@link Integer}</li>
* <li>{@link Long}</li>
* <li>{@link String}</li>
+ * <li>{@link Float}</li>
+ * <li>{@link Double}</li>
+ * <li>{@link java.math.BigDecimal}</li>
+ * <li>{@link Boolean}</li>
* </ul>
* A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters:
* <ol>
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 2acbb247b1..94ab3a98cb 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -25,6 +25,7 @@ import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
+import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Random;
@@ -152,6 +153,16 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
public void add(Object item, long count) {
if (item instanceof String) {
addString((String) item, count);
+ } else if (item instanceof BigDecimal) {
+ addString(((BigDecimal) item).toString(), count);
+ } else if (item instanceof byte[]) {
+ addBinary((byte[]) item, count);
+ } else if (item instanceof Float) {
+ addLong(Float.floatToIntBits((Float) item), count);
+ } else if (item instanceof Double) {
+ addLong(Double.doubleToLongBits((Double) item), count);
+ } else if (item instanceof Boolean) {
+ addLong(((Boolean) item) ? 1L : 0L, count);
} else {
addLong(Utils.integralToLong(item), count);
}
@@ -216,10 +227,6 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
return ((int) hash) % width;
}
- private static int[] getHashBuckets(String key, int hashCount, int max) {
- return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
- }
-
private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
int[] result = new int[hashCount];
int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
@@ -233,7 +240,18 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
@Override
public long estimateCount(Object item) {
if (item instanceof String) {
- return estimateCountForStringItem((String) item);
+ return estimateCountForBinaryItem(Utils.getBytesFromUTF8String((String) item));
+ } else if (item instanceof BigDecimal) {
+ return estimateCountForBinaryItem(
+ Utils.getBytesFromUTF8String(((BigDecimal) item).toString()));
+ } else if (item instanceof byte[]) {
+ return estimateCountForBinaryItem((byte[]) item);
+ } else if (item instanceof Float) {
+ return estimateCountForLongItem(Float.floatToIntBits((Float) item));
+ } else if (item instanceof Double) {
+ return estimateCountForLongItem(Double.doubleToLongBits((Double) item));
+ } else if (item instanceof Boolean) {
+ return estimateCountForLongItem(((Boolean) item) ? 1L : 0L);
} else {
return estimateCountForLongItem(Utils.integralToLong(item));
}
@@ -247,7 +265,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
return res;
}
- private long estimateCountForStringItem(String item) {
+ private long estimateCountForBinaryItem(byte[] item) {
long res = Long.MAX_VALUE;
int[] buckets = getHashBuckets(item, depth, width);
for (int i = 0; i < depth; ++i) {
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
index b9c7f5c23a..2c358fcee4 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util.sketch
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.charset.StandardCharsets
import scala.reflect.ClassTag
import scala.util.Random
@@ -44,6 +45,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
}
def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+ def getProbeItem(item: T): Any = item match {
+ // Use a string to represent the content of an array of bytes
+ case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
+ case i => identity(i)
+ }
+
test(s"accuracy - $typeName") {
// Uses fixed seed to ensure reproducible test execution
val r = new Random(31)
@@ -56,7 +63,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
val exactFreq = {
val sampledItems = sampledItemIndices.map(allItems)
- sampledItems.groupBy(identity).mapValues(_.length.toLong)
+ sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
}
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -67,7 +74,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
val probCorrect = {
val numErrors = allItems.map { item =>
- val count = exactFreq.getOrElse(item, 0L)
+ val count = exactFreq.getOrElse(getProbeItem(item), 0L)
val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
if (ratio > epsOfTotalCount) 1 else 0
}.sum
@@ -135,6 +142,18 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
+ testItemType[Float]("Float") { _.nextFloat() }
+
+ testItemType[Double]("Double") { _.nextDouble() }
+
+ testItemType[java.math.BigDecimal]("Decimal") { r => new java.math.BigDecimal(r.nextDouble()) }
+
+ testItemType[Boolean]("Boolean") { _.nextBoolean() }
+
+ testItemType[Array[Byte]]("Binary") { r =>
+ Utils.getBytesFromUTF8String(r.nextString(r.nextInt(20)))
+ }
+
test("incompatible merge") {
intercept[IncompatibleMergeException] {
CountMinSketch.create(10, 10, 1).mergeInPlace(null)