From d57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Tue, 29 Nov 2016 13:16:46 -0800 Subject: [SPARK-18429][SQL] implement a new Aggregate for CountMinSketch ## What changes were proposed in this pull request? This PR implements a new Aggregate to generate count min sketch, which is a wrapper of CountMinSketch. ## How was this patch tested? add test cases Author: wangzhenhua Closes #15877 from wzhfy/cms. --- .../apache/spark/util/sketch/CountMinSketch.java | 4 +++ .../spark/util/sketch/CountMinSketchImpl.java | 30 +++++++++++++++++----- .../spark/util/sketch/CountMinSketchSuite.scala | 23 +++++++++++++++-- 3 files changed, 49 insertions(+), 8 deletions(-) (limited to 'common/sketch') diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 40fa20c4a3..0011096d4a 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -30,6 +30,10 @@ import java.io.OutputStream; *
   <li>{@link Integer}</li>
  *   <li>{@link Long}</li>
  *   <li>{@link String}</li>
+ *   <li>{@link Float}</li>
+ *   <li>{@link Double}</li>
+ *   <li>{@link java.math.BigDecimal}</li>
+ *   <li>{@link Boolean}</li>
  * </ul>
  * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters:
  * <ol>
      diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 2acbb247b1..94ab3a98cb 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -25,6 +25,7 @@ import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.OutputStream; import java.io.Serializable; +import java.math.BigDecimal; import java.util.Arrays; import java.util.Random; @@ -152,6 +153,16 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { public void add(Object item, long count) { if (item instanceof String) { addString((String) item, count); + } else if (item instanceof BigDecimal) { + addString(((BigDecimal) item).toString(), count); + } else if (item instanceof byte[]) { + addBinary((byte[]) item, count); + } else if (item instanceof Float) { + addLong(Float.floatToIntBits((Float) item), count); + } else if (item instanceof Double) { + addLong(Double.doubleToLongBits((Double) item), count); + } else if (item instanceof Boolean) { + addLong(((Boolean) item) ? 
1L : 0L, count); } else { addLong(Utils.integralToLong(item), count); } @@ -216,10 +227,6 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { return ((int) hash) % width; } - private static int[] getHashBuckets(String key, int hashCount, int max) { - return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max); - } - private static int[] getHashBuckets(byte[] b, int hashCount, int max) { int[] result = new int[hashCount]; int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0); @@ -233,7 +240,18 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { @Override public long estimateCount(Object item) { if (item instanceof String) { - return estimateCountForStringItem((String) item); + return estimateCountForBinaryItem(Utils.getBytesFromUTF8String((String) item)); + } else if (item instanceof BigDecimal) { + return estimateCountForBinaryItem( + Utils.getBytesFromUTF8String(((BigDecimal) item).toString())); + } else if (item instanceof byte[]) { + return estimateCountForBinaryItem((byte[]) item); + } else if (item instanceof Float) { + return estimateCountForLongItem(Float.floatToIntBits((Float) item)); + } else if (item instanceof Double) { + return estimateCountForLongItem(Double.doubleToLongBits((Double) item)); + } else if (item instanceof Boolean) { + return estimateCountForLongItem(((Boolean) item) ? 
1L : 0L); } else { return estimateCountForLongItem(Utils.integralToLong(item)); } @@ -247,7 +265,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { return res; } - private long estimateCountForStringItem(String item) { + private long estimateCountForBinaryItem(byte[] item) { long res = Long.MAX_VALUE; int[] buckets = getHashBuckets(item, depth, width); for (int i = 0; i < depth; ++i) { diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala index b9c7f5c23a..2c358fcee4 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util.sketch import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.charset.StandardCharsets import scala.reflect.ClassTag import scala.util.Random @@ -44,6 +45,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite } def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + def getProbeItem(item: T): Any = item match { + // Use a string to represent the content of an array of bytes + case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8) + case i => identity(i) + } + test(s"accuracy - $typeName") { // Uses fixed seed to ensure reproducible test execution val r = new Random(31) @@ -56,7 +63,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val exactFreq = { val sampledItems = sampledItemIndices.map(allItems) - sampledItems.groupBy(identity).mapValues(_.length.toLong) + sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong) } val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) @@ -67,7 +74,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val 
probCorrect = { val numErrors = allItems.map { item => - val count = exactFreq.getOrElse(item, 0L) + val count = exactFreq.getOrElse(getProbeItem(item), 0L) val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems if (ratio > epsOfTotalCount) 1 else 0 }.sum @@ -135,6 +142,18 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } + testItemType[Float]("Float") { _.nextFloat() } + + testItemType[Double]("Double") { _.nextDouble() } + + testItemType[java.math.BigDecimal]("Decimal") { r => new java.math.BigDecimal(r.nextDouble()) } + + testItemType[Boolean]("Boolean") { _.nextBoolean() } + + testItemType[Array[Byte]]("Binary") { r => + Utils.getBytesFromUTF8String(r.nextString(r.nextInt(20))) + } + test("incompatible merge") { intercept[IncompatibleMergeException] { CountMinSketch.create(10, 10, 1).mergeInPlace(null) -- cgit v1.2.3