diff options
author | wangzhenhua <wangzhenhua@huawei.com> | 2016-11-29 13:16:46 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-11-29 13:16:46 -0800 |
commit | d57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e (patch) | |
tree | 847dfe0de2a6ec831917709f169708695d09f95f /common/sketch | |
parent | f643fe47f4889faf68da3da8d7850ee48df7c22f (diff) | |
download | spark-d57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e.tar.gz spark-d57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e.tar.bz2 spark-d57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e.zip |
[SPARK-18429][SQL] implement a new Aggregate for CountMinSketch
## What changes were proposed in this pull request?
This PR implements a new Aggregate to generate count min sketch, which is a wrapper of CountMinSketch.
## How was this patch tested?
add test cases
Author: wangzhenhua <wangzhenhua@huawei.com>
Closes #15877 from wzhfy/cms.
Diffstat (limited to 'common/sketch')
3 files changed, 49 insertions, 8 deletions
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 40fa20c4a3..0011096d4a 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -30,6 +30,10 @@ import java.io.OutputStream; * <li>{@link Integer}</li> * <li>{@link Long}</li> * <li>{@link String}</li> + * <li>{@link Float}</li> + * <li>{@link Double}</li> + * <li>{@link java.math.BigDecimal}</li> + * <li>{@link Boolean}</li> * </ul> * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters: * <ol> diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 2acbb247b1..94ab3a98cb 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -25,6 +25,7 @@ import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.OutputStream; import java.io.Serializable; +import java.math.BigDecimal; import java.util.Arrays; import java.util.Random; @@ -152,6 +153,16 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { public void add(Object item, long count) { if (item instanceof String) { addString((String) item, count); + } else if (item instanceof BigDecimal) { + addString(((BigDecimal) item).toString(), count); + } else if (item instanceof byte[]) { + addBinary((byte[]) item, count); + } else if (item instanceof Float) { + addLong(Float.floatToIntBits((Float) item), count); + } else if (item instanceof Double) { + addLong(Double.doubleToLongBits((Double) item), count); + } else if (item instanceof Boolean) { + addLong(((Boolean) item) ? 1L : 0L, count); } else { addLong(Utils.integralToLong(item), count); } @@ -216,10 +227,6 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { return ((int) hash) % width; } - private static int[] getHashBuckets(String key, int hashCount, int max) { - return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max); - } - private static int[] getHashBuckets(byte[] b, int hashCount, int max) { int[] result = new int[hashCount]; int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0); @@ -233,7 +240,18 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { @Override public long estimateCount(Object item) { if (item instanceof String) { - return estimateCountForStringItem((String) item); + return estimateCountForBinaryItem(Utils.getBytesFromUTF8String((String) item)); + } else if (item instanceof BigDecimal) { + return estimateCountForBinaryItem( + Utils.getBytesFromUTF8String(((BigDecimal) item).toString())); + } else if (item instanceof byte[]) { + return estimateCountForBinaryItem((byte[]) item); + } else if (item instanceof Float) { + return estimateCountForLongItem(Float.floatToIntBits((Float) item)); + } else if (item instanceof Double) { + return estimateCountForLongItem(Double.doubleToLongBits((Double) item)); + } else if (item instanceof Boolean) { + return estimateCountForLongItem(((Boolean) item) ? 1L : 0L); } else { return estimateCountForLongItem(Utils.integralToLong(item)); } @@ -247,7 +265,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { return res; } - private long estimateCountForStringItem(String item) { + private long estimateCountForBinaryItem(byte[] item) { long res = Long.MAX_VALUE; int[] buckets = getHashBuckets(item, depth, width); for (int i = 0; i < depth; ++i) { diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala index b9c7f5c23a..2c358fcee4 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util.sketch import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.charset.StandardCharsets import scala.reflect.ClassTag import scala.util.Random @@ -44,6 +45,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite } def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + def getProbeItem(item: T): Any = item match { + // Use a string to represent the content of an array of bytes + case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8) + case i => identity(i) + } + test(s"accuracy - $typeName") { // Uses fixed seed to ensure reproducible test execution val r = new Random(31) @@ -56,7 +63,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val exactFreq = { val sampledItems = sampledItemIndices.map(allItems) - sampledItems.groupBy(identity).mapValues(_.length.toLong) + sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong) } val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) @@ -67,7 +74,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val probCorrect = { val numErrors = allItems.map { item => - val count = exactFreq.getOrElse(item, 0L) + val count = exactFreq.getOrElse(getProbeItem(item), 0L) val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems if (ratio > epsOfTotalCount) 1 else 0 }.sum @@ -135,6 +142,18 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } + testItemType[Float]("Float") { _.nextFloat() } + + testItemType[Double]("Double") { _.nextDouble() } + + testItemType[java.math.BigDecimal]("Decimal") { r => new java.math.BigDecimal(r.nextDouble()) } + + testItemType[Boolean]("Boolean") { _.nextBoolean() } + + testItemType[Array[Byte]]("Binary") { r => + Utils.getBytesFromUTF8String(r.nextString(r.nextInt(20))) + } + test("incompatible merge") { intercept[IncompatibleMergeException] { CountMinSketch.create(10, 10, 1).mergeInPlace(null) |