author     Cheng Lian <lian@databricks.com>    2016-01-23 00:34:55 -0800
committer  Reynold Xin <rxin@databricks.com>   2016-01-23 00:34:55 -0800
commit     1c690ddafa8376c55cbc5b7a7a750200abfbe2a6 (patch)
tree       1be95d50cb9c14eb6051c1f068f6f708b1a34e9c
parent     5af5a02160b42115579003b749c4d1831bf9d48e (diff)
[SPARK-12933][SQL] Initial implementation of Count-Min sketch
This PR adds an initial implementation of Count-Min sketch, contained in a new module spark-sketch under `common/sketch`. The implementation is based on the [`CountMinSketch` class in stream-lib][1].

As required by the [design doc][2], spark-sketch should have no external dependencies. Two classes, `Murmur3_x86_32` and `Platform`, are copied into spark-sketch from spark-unsafe to provide hashing facilities. They'll also be used in the upcoming Bloom filter implementation.

The following features will be added in future follow-up PRs:

- Serialization support
- DataFrame API integration

[1]: https://github.com/addthis/stream-lib/blob/aac6b4d23a8686b000f80baa447e0922ecac3bcb/src/main/java/com/clearspring/analytics/stream/frequency/CountMinSketch.java
[2]: https://issues.apache.org/jira/secure/attachment/12782378/BloomFilterandCount-MinSketchinSpark2.0.pdf

Author: Cheng Lian <lian@databricks.com>

Closes #10851 from liancheng/count-min-sketch.
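For orientation, a minimal sketch of how the new API is intended to be used (illustrative driver code, not part of this patch; note that `writeTo` and `readFrom` still throw `UnsupportedOperationException` at this stage):

    import org.apache.spark.util.sketch.CountMinSketch;

    public class CountMinSketchExample {
      public static void main(String[] args) {
        // eps = 0.01: an estimate may exceed the true count by at most 1% of the
        // total number of items added. confidence = 0.99: that bound holds with
        // probability at least 0.99. The seed (42) fixes the hash functions.
        CountMinSketch sketch = CountMinSketch.create(0.01, 0.99, 42);

        for (int i = 0; i < 1000; i++) {
          sketch.add("foo");        // increment "foo" by 1
        }
        sketch.add("bar", 10L);     // increment "bar" by 10

        // Estimates never undercount: 1000 <= estimate <= 1000 + eps * totalCount.
        System.out.println(sketch.estimateCount("foo"));

        // Sketches created with the same depth, width, and seed can be merged.
        CountMinSketch other = CountMinSketch.create(0.01, 0.99, 42);
        other.add("foo", 5L);
        sketch.mergeInPlace(other);
        System.out.println(sketch.estimateCount("foo"));  // now at least 1005
      }
    }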
-rw-r--r--  common/sketch/pom.xml                                                                |  42
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java        | 132
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java    | 268
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java        | 126
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java              | 172
-rw-r--r--  common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala | 112
-rw-r--r--  dev/sparktestsupport/modules.py                                                      |  12
-rw-r--r--  pom.xml                                                                              |   1
-rw-r--r--  project/SparkBuild.scala                                                             |  39
9 files changed, 892 insertions(+), 12 deletions(-)
diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml
new file mode 100644
index 0000000000..67723fa421
--- /dev/null
+++ b/common/sketch/pom.xml
@@ -0,0 +1,42 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent_2.10</artifactId>
+ <version>2.0.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sketch_2.10</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project Sketch</name>
+ <url>http://spark.apache.org/</url>
+ <properties>
+ <sbt.project.name>sketch</sbt.project.name>
+ </properties>
+
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ </build>
+</project>
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
new file mode 100644
index 0000000000..21b161bc74
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.io.InputStream;
+import java.io.OutputStream;
+
+/**
+ * A Count-Min sketch is a probabilistic data structure used for summarizing streams of data in
+ * sub-linear space. Currently, supported data types include:
+ * <ul>
+ * <li>{@link Byte}</li>
+ * <li>{@link Short}</li>
+ * <li>{@link Integer}</li>
+ * <li>{@link Long}</li>
+ * <li>{@link String}</li>
+ * </ul>
+ * Each {@link CountMinSketch} is initialized with a random seed, and a pair
+ * of parameters:
+ * <ol>
+ * <li>relative error (or {@code eps}), and
+ * <li>confidence (or {@code delta})
+ * </ol>
+ * Suppose you want to estimate the number of times an element {@code x} has appeared in a data
+ * stream so far. With probability at least {@code delta}, the estimate of this frequency is
+ * within the range {@code true frequency <= estimate <= true frequency + eps * N}, where
+ * {@code N} is the total count of items that have appeared in the data stream so far.
+ *
+ * Under the covers, a {@link CountMinSketch} is essentially a two-dimensional {@code long} array
+ * with depth {@code d} and width {@code w}, where
+ * <ul>
+ *   <li>{@code w = ceil(2 / eps)}</li>
+ *   <li>{@code d = ceil(-log(1 - confidence) / log(2))}</li>
+ * </ul>
+ *
+ * See http://www.eecs.harvard.edu/~michaelm/CS222/countmin.pdf for technical details,
+ * including proofs of the estimates and error bounds used in this implementation.
+ *
+ * This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
+ */
+public abstract class CountMinSketch {
+ /**
+ * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
+ */
+ public abstract double relativeError();
+
+ /**
+ * Returns the confidence (or {@code delta}) of this {@link CountMinSketch}.
+ */
+ public abstract double confidence();
+
+ /**
+ * Depth of this {@link CountMinSketch}.
+ */
+ public abstract int depth();
+
+ /**
+ * Width of this {@link CountMinSketch}.
+ */
+ public abstract int width();
+
+ /**
+ * Total count of items added to this {@link CountMinSketch} so far.
+ */
+ public abstract long totalCount();
+
+ /**
+   * Increments {@code item}'s count by one.
+ */
+ public abstract void add(Object item);
+
+ /**
+   * Increments {@code item}'s count by {@code count}.
+ */
+ public abstract void add(Object item, long count);
+
+ /**
+ * Returns the estimated frequency of {@code item}.
+ */
+ public abstract long estimateCount(Object item);
+
+ /**
+ * Merges another {@link CountMinSketch} with this one in place.
+ *
+ * Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed
+ * can be merged.
+ */
+ public abstract CountMinSketch mergeInPlace(CountMinSketch other);
+
+ /**
+ * Writes out this {@link CountMinSketch} to an output stream in binary format.
+ */
+ public abstract void writeTo(OutputStream out);
+
+ /**
+ * Reads in a {@link CountMinSketch} from an input stream.
+ */
+ public static CountMinSketch readFrom(InputStream in) {
+ throw new UnsupportedOperationException("Not implemented yet");
+ }
+
+ /**
+ * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random
+ * {@code seed}.
+ */
+ public static CountMinSketch create(int depth, int width, int seed) {
+ return new CountMinSketchImpl(depth, width, seed);
+ }
+
+ /**
+ * Creates a {@link CountMinSketch} with given relative error ({@code eps}), {@code confidence},
+ * and random {@code seed}.
+ */
+ public static CountMinSketch create(double eps, double confidence, int seed) {
+ return new CountMinSketchImpl(eps, confidence, seed);
+ }
+}
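The two `create` overloads map onto the two `CountMinSketchImpl` constructors shown next: one takes the table dimensions directly, the other derives them from the error bounds. A back-of-the-envelope illustration of that conversion (the numbers in the comments are worked by hand from the formulas above):

    double eps = 0.001;        // allowed overestimate: 0.1% of the total count
    double confidence = 0.99;  // the bound should hold with probability >= 0.99

    int width = (int) Math.ceil(2 / eps);                                  // 2000
    int depth = (int) Math.ceil(-Math.log(1 - confidence) / Math.log(2));  // 7

    // The sketch then occupies depth * width longs: 7 * 2000 * 8 bytes, roughly
    // 110 KB, no matter how many items are added.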
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
new file mode 100644
index 0000000000..e9fdbe3a86
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.io.OutputStream;
+import java.io.UnsupportedEncodingException;
+import java.util.Arrays;
+import java.util.Random;
+
+class CountMinSketchImpl extends CountMinSketch {
+ public static final long PRIME_MODULUS = (1L << 31) - 1;
+
+ private int depth;
+ private int width;
+ private long[][] table;
+ private long[] hashA;
+ private long totalCount;
+ private double eps;
+ private double confidence;
+
+ public CountMinSketchImpl(int depth, int width, int seed) {
+ this.depth = depth;
+ this.width = width;
+ this.eps = 2.0 / width;
+ this.confidence = 1 - 1 / Math.pow(2, depth);
+ initTablesWith(depth, width, seed);
+ }
+
+ public CountMinSketchImpl(double eps, double confidence, int seed) {
+ // 2/w = eps ; w = 2/eps
+ // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence)
+ this.eps = eps;
+ this.confidence = confidence;
+ this.width = (int) Math.ceil(2 / eps);
+ this.depth = (int) Math.ceil(-Math.log(1 - confidence) / Math.log(2));
+ initTablesWith(depth, width, seed);
+ }
+
+ private void initTablesWith(int depth, int width, int seed) {
+ this.table = new long[depth][width];
+ this.hashA = new long[depth];
+ Random r = new Random(seed);
+    // We're using linear hash functions of the form (a*x + b) mod p,
+    // where a and b are chosen independently for each hash function.
+    // However, we can set b = 0, since all it does is shift the
+    // results without compromising their uniformity or independence
+    // from the other hashes. p is the Mersenne prime 2^31 - 1, which
+    // makes the modular reduction cheap (see hash() below).
+ for (int i = 0; i < depth; ++i) {
+ hashA[i] = r.nextInt(Integer.MAX_VALUE);
+ }
+ }
+
+ @Override
+ public double relativeError() {
+ return eps;
+ }
+
+ @Override
+ public double confidence() {
+ return confidence;
+ }
+
+ @Override
+ public int depth() {
+ return depth;
+ }
+
+ @Override
+ public int width() {
+ return width;
+ }
+
+ @Override
+ public long totalCount() {
+ return totalCount;
+ }
+
+ @Override
+ public void add(Object item) {
+ add(item, 1);
+ }
+
+ @Override
+ public void add(Object item, long count) {
+ if (item instanceof String) {
+ addString((String) item, count);
+ } else {
+ long longValue;
+
+ if (item instanceof Long) {
+ longValue = (Long) item;
+ } else if (item instanceof Integer) {
+ longValue = ((Integer) item).longValue();
+ } else if (item instanceof Short) {
+ longValue = ((Short) item).longValue();
+ } else if (item instanceof Byte) {
+ longValue = ((Byte) item).longValue();
+ } else {
+ throw new IllegalArgumentException(
+ "Support for " + item.getClass().getName() + " not implemented"
+ );
+ }
+
+ addLong(longValue, count);
+ }
+ }
+
+ private void addString(String item, long count) {
+ if (count < 0) {
+ throw new IllegalArgumentException("Negative increments not implemented");
+ }
+
+ int[] buckets = getHashBuckets(item, depth, width);
+
+ for (int i = 0; i < depth; ++i) {
+ table[i][buckets[i]] += count;
+ }
+
+ totalCount += count;
+ }
+
+ private void addLong(long item, long count) {
+ if (count < 0) {
+ throw new IllegalArgumentException("Negative increments not implemented");
+ }
+
+ for (int i = 0; i < depth; ++i) {
+ table[i][hash(item, i)] += count;
+ }
+
+ totalCount += count;
+ }
+
+  private int hash(long item, int row) {
+    long hash = hashA[row] * item;
+ // A super fast way of computing x mod 2^p-1
+ // See http://www.cs.princeton.edu/courses/archive/fall09/cos521/Handouts/universalclasses.pdf
+ // page 149, right after Proposition 7.
+ hash += hash >> 32;
+ hash &= PRIME_MODULUS;
+ // Doing "%" after (int) conversion is ~2x faster than %'ing longs.
+ return ((int) hash) % width;
+ }
+
+ private static int[] getHashBuckets(String key, int hashCount, int max) {
+ byte[] b;
+ try {
+ b = key.getBytes("UTF-8");
+ } catch (UnsupportedEncodingException e) {
+ throw new RuntimeException(e);
+ }
+ return getHashBuckets(b, hashCount, max);
+ }
+
+ private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
+ int[] result = new int[hashCount];
+ int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
+ int hash2 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, hash1);
+ for (int i = 0; i < hashCount; i++) {
+ result[i] = Math.abs((hash1 + i * hash2) % max);
+ }
+ return result;
+ }
+
+ @Override
+ public long estimateCount(Object item) {
+ if (item instanceof String) {
+ return estimateCountForStringItem((String) item);
+ } else {
+ long longValue;
+
+ if (item instanceof Long) {
+ longValue = (Long) item;
+ } else if (item instanceof Integer) {
+ longValue = ((Integer) item).longValue();
+ } else if (item instanceof Short) {
+ longValue = ((Short) item).longValue();
+ } else if (item instanceof Byte) {
+ longValue = ((Byte) item).longValue();
+ } else {
+ throw new IllegalArgumentException(
+ "Support for " + item.getClass().getName() + " not implemented"
+ );
+ }
+
+ return estimateCountForLongItem(longValue);
+ }
+ }
+
+ private long estimateCountForLongItem(long item) {
+ long res = Long.MAX_VALUE;
+ for (int i = 0; i < depth; ++i) {
+ res = Math.min(res, table[i][hash(item, i)]);
+ }
+ return res;
+ }
+
+ private long estimateCountForStringItem(String item) {
+ long res = Long.MAX_VALUE;
+ int[] buckets = getHashBuckets(item, depth, width);
+ for (int i = 0; i < depth; ++i) {
+ res = Math.min(res, table[i][buckets[i]]);
+ }
+ return res;
+ }
+
+ @Override
+ public CountMinSketch mergeInPlace(CountMinSketch other) {
+ if (other == null) {
+ throw new CMSMergeException("Cannot merge null estimator");
+ }
+
+ if (!(other instanceof CountMinSketchImpl)) {
+ throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName());
+ }
+
+ CountMinSketchImpl that = (CountMinSketchImpl) other;
+
+ if (this.depth != that.depth) {
+ throw new CMSMergeException("Cannot merge estimators of different depth");
+ }
+
+ if (this.width != that.width) {
+ throw new CMSMergeException("Cannot merge estimators of different width");
+ }
+
+ if (!Arrays.equals(this.hashA, that.hashA)) {
+      throw new CMSMergeException("Cannot merge estimators with different seeds");
+ }
+
+ for (int i = 0; i < this.table.length; ++i) {
+ for (int j = 0; j < this.table[i].length; ++j) {
+ this.table[i][j] = this.table[i][j] + that.table[i][j];
+ }
+ }
+
+ this.totalCount += that.totalCount;
+
+ return this;
+ }
+
+ @Override
+ public void writeTo(OutputStream out) {
+ throw new UnsupportedOperationException("Not implemented yet");
+ }
+
+ protected static class CMSMergeException extends RuntimeException {
+ public CMSMergeException(String message) {
+ super(message);
+ }
+ }
+}
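The private `hash` method above avoids a 64-bit division when reducing modulo `PRIME_MODULUS`: it folds the high bits of the product into the low bits and masks with the Mersenne prime 2^31 - 1 (the linked Princeton handout explains why this works for this hash family). A standalone rendering of those steps, with assumed example values:

    long PRIME_MODULUS = (1L << 31) - 1;  // the Mersenne prime 2^31 - 1
    long a = 1234567891L;                 // stands in for one entry of hashA
    long item = 987654321L;               // the item being hashed

    long hash = a * item;   // may wrap around; acceptable for hashing
    hash += hash >> 32;     // fold the high 32 bits into the low bits
    hash &= PRIME_MODULUS;  // mask down to [0, 2^31 - 1]

    int bucket = ((int) hash) % 2000;  // column index, assuming width = 2000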
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java
new file mode 100644
index 0000000000..3d1f28bcb9
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+/**
+ * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction.
+ */
+// This class is duplicated from `org.apache.spark.unsafe.hash.Murmur3_x86_32` to make sure
+// spark-sketch has no external dependencies.
+final class Murmur3_x86_32 {
+ private static final int C1 = 0xcc9e2d51;
+ private static final int C2 = 0x1b873593;
+
+ private final int seed;
+
+ public Murmur3_x86_32(int seed) {
+ this.seed = seed;
+ }
+
+ @Override
+ public String toString() {
+ return "Murmur3_32(seed=" + seed + ")";
+ }
+
+ public int hashInt(int input) {
+ return hashInt(input, seed);
+ }
+
+ public static int hashInt(int input, int seed) {
+ int k1 = mixK1(input);
+ int h1 = mixH1(seed, k1);
+
+ return fmix(h1, 4);
+ }
+
+ public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
+ return hashUnsafeWords(base, offset, lengthInBytes, seed);
+ }
+
+ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
+ // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
+ assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
+ int h1 = hashBytesByInt(base, offset, lengthInBytes, seed);
+ return fmix(h1, lengthInBytes);
+ }
+
+ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
+ assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
+ int lengthAligned = lengthInBytes - lengthInBytes % 4;
+ int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
+ for (int i = lengthAligned; i < lengthInBytes; i++) {
+ int halfWord = Platform.getByte(base, offset + i);
+ int k1 = mixK1(halfWord);
+ h1 = mixH1(h1, k1);
+ }
+ return fmix(h1, lengthInBytes);
+ }
+
+ private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
+ assert (lengthInBytes % 4 == 0);
+ int h1 = seed;
+ for (int i = 0; i < lengthInBytes; i += 4) {
+ int halfWord = Platform.getInt(base, offset + i);
+ int k1 = mixK1(halfWord);
+ h1 = mixH1(h1, k1);
+ }
+ return h1;
+ }
+
+ public int hashLong(long input) {
+ return hashLong(input, seed);
+ }
+
+ public static int hashLong(long input, int seed) {
+ int low = (int) input;
+ int high = (int) (input >>> 32);
+
+ int k1 = mixK1(low);
+ int h1 = mixH1(seed, k1);
+
+ k1 = mixK1(high);
+ h1 = mixH1(h1, k1);
+
+ return fmix(h1, 8);
+ }
+
+ private static int mixK1(int k1) {
+ k1 *= C1;
+ k1 = Integer.rotateLeft(k1, 15);
+ k1 *= C2;
+ return k1;
+ }
+
+ private static int mixH1(int h1, int k1) {
+ h1 ^= k1;
+ h1 = Integer.rotateLeft(h1, 13);
+ h1 = h1 * 5 + 0xe6546b64;
+ return h1;
+ }
+
+ // Finalization mix - force all bits of a hash block to avalanche
+ private static int fmix(int h1, int length) {
+ h1 ^= length;
+ h1 ^= h1 >>> 16;
+ h1 *= 0x85ebca6b;
+ h1 ^= h1 >>> 13;
+ h1 *= 0xc2b2ae35;
+ h1 ^= h1 >>> 16;
+ return h1;
+ }
+}
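`CountMinSketchImpl.getHashBuckets` uses this class in the classic double-hashing pattern: two Murmur3 hashes `h1` and `h2` generate all `depth` bucket indices as `h1 + i * h2`. A minimal sketch of that pattern (it must live in the `org.apache.spark.util.sketch` package because both helper classes are package-private; the `depth` and `width` values are assumptions):

    package org.apache.spark.util.sketch;

    import java.nio.charset.StandardCharsets;

    class DoubleHashingDemo {
      public static void main(String[] args) {
        byte[] b = "example-key".getBytes(StandardCharsets.UTF_8);
        int h1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
        int h2 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, h1);

        int depth = 7, width = 2000;  // assumed sketch dimensions
        for (int i = 0; i < depth; i++) {
          // Math.abs guards against negative remainders from int overflow.
          int bucket = Math.abs((h1 + i * h2) % width);
          System.out.println("row " + i + " -> column " + bucket);
        }
      }
    }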
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java
new file mode 100644
index 0000000000..75d6a6beec
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.lang.reflect.Field;
+
+import sun.misc.Unsafe;
+
+// This class is duplicated from `org.apache.spark.unsafe.Platform` to make sure spark-sketch has no
+// external dependencies.
+final class Platform {
+
+ private static final Unsafe _UNSAFE;
+
+ public static final int BYTE_ARRAY_OFFSET;
+
+ public static final int INT_ARRAY_OFFSET;
+
+ public static final int LONG_ARRAY_OFFSET;
+
+ public static final int DOUBLE_ARRAY_OFFSET;
+
+ public static int getInt(Object object, long offset) {
+ return _UNSAFE.getInt(object, offset);
+ }
+
+ public static void putInt(Object object, long offset, int value) {
+ _UNSAFE.putInt(object, offset, value);
+ }
+
+ public static boolean getBoolean(Object object, long offset) {
+ return _UNSAFE.getBoolean(object, offset);
+ }
+
+ public static void putBoolean(Object object, long offset, boolean value) {
+ _UNSAFE.putBoolean(object, offset, value);
+ }
+
+ public static byte getByte(Object object, long offset) {
+ return _UNSAFE.getByte(object, offset);
+ }
+
+ public static void putByte(Object object, long offset, byte value) {
+ _UNSAFE.putByte(object, offset, value);
+ }
+
+ public static short getShort(Object object, long offset) {
+ return _UNSAFE.getShort(object, offset);
+ }
+
+ public static void putShort(Object object, long offset, short value) {
+ _UNSAFE.putShort(object, offset, value);
+ }
+
+ public static long getLong(Object object, long offset) {
+ return _UNSAFE.getLong(object, offset);
+ }
+
+ public static void putLong(Object object, long offset, long value) {
+ _UNSAFE.putLong(object, offset, value);
+ }
+
+ public static float getFloat(Object object, long offset) {
+ return _UNSAFE.getFloat(object, offset);
+ }
+
+ public static void putFloat(Object object, long offset, float value) {
+ _UNSAFE.putFloat(object, offset, value);
+ }
+
+ public static double getDouble(Object object, long offset) {
+ return _UNSAFE.getDouble(object, offset);
+ }
+
+ public static void putDouble(Object object, long offset, double value) {
+ _UNSAFE.putDouble(object, offset, value);
+ }
+
+ public static Object getObjectVolatile(Object object, long offset) {
+ return _UNSAFE.getObjectVolatile(object, offset);
+ }
+
+ public static void putObjectVolatile(Object object, long offset, Object value) {
+ _UNSAFE.putObjectVolatile(object, offset, value);
+ }
+
+ public static long allocateMemory(long size) {
+ return _UNSAFE.allocateMemory(size);
+ }
+
+ public static void freeMemory(long address) {
+ _UNSAFE.freeMemory(address);
+ }
+
+ public static void copyMemory(
+ Object src, long srcOffset, Object dst, long dstOffset, long length) {
+ // Check if dstOffset is before or after srcOffset to determine if we should copy
+ // forward or backwards. This is necessary in case src and dst overlap.
+ if (dstOffset < srcOffset) {
+ while (length > 0) {
+ long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
+ _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
+ length -= size;
+ srcOffset += size;
+ dstOffset += size;
+ }
+ } else {
+ srcOffset += length;
+ dstOffset += length;
+ while (length > 0) {
+ long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
+ srcOffset -= size;
+ dstOffset -= size;
+ _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
+ length -= size;
+ }
+
+ }
+ }
+
+ /**
+ * Raises an exception bypassing compiler checks for checked exceptions.
+ */
+ public static void throwException(Throwable t) {
+ _UNSAFE.throwException(t);
+ }
+
+ /**
+ * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to
+ * allow safepoint polling during a large copy.
+ */
+ private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L;
+
+ static {
+ sun.misc.Unsafe unsafe;
+ try {
+ Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
+ unsafeField.setAccessible(true);
+ unsafe = (sun.misc.Unsafe) unsafeField.get(null);
+ } catch (Throwable cause) {
+ unsafe = null;
+ }
+ _UNSAFE = unsafe;
+
+ if (_UNSAFE != null) {
+ BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class);
+ INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class);
+ LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class);
+ DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class);
+ } else {
+ BYTE_ARRAY_OFFSET = 0;
+ INT_ARRAY_OFFSET = 0;
+ LONG_ARRAY_OFFSET = 0;
+ DOUBLE_ARRAY_OFFSET = 0;
+ }
+ }
+}
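`copyMemory` above chooses the copy direction from the operand order so that overlapping source and destination ranges behave like `memmove`, chunking by `UNSAFE_COPY_THRESHOLD` so the JVM can reach a safepoint during a large copy. A small illustration of the overlapping case (hypothetical, and again assuming package access):

    long[] a = {1, 2, 3, 4, 5, 0, 0};

    // Shift a[0..4] right by two slots within the same array. Source and
    // destination overlap, so copyMemory copies backwards to avoid clobbering
    // source elements that have not been copied yet.
    Platform.copyMemory(
      a, Platform.LONG_ARRAY_OFFSET,           // source:      a[0]
      a, Platform.LONG_ARRAY_OFFSET + 2 * 8L,  // destination: a[2] (8 bytes/long)
      5 * 8L);                                 // length: five longs

    // a is now {1, 2, 1, 2, 3, 4, 5}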
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
new file mode 100644
index 0000000000..ec5b4eddec
--- /dev/null
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.scalatest.FunSuite // scalastyle:ignore funsuite
+
+class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
+ private val epsOfTotalCount = 0.0001
+
+ private val confidence = 0.99
+
+ private val seed = 42
+
+ def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+ test(s"accuracy - $typeName") {
+ val r = new Random()
+
+ val numAllItems = 1000000
+ val allItems = Array.fill(numAllItems)(itemGenerator(r))
+
+ val numSamples = numAllItems / 10
+ val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+
+ val exactFreq = {
+ val sampledItems = sampledItemIndices.map(allItems)
+ sampledItems.groupBy(identity).mapValues(_.length.toLong)
+ }
+
+ val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ sampledItemIndices.foreach(i => sketch.add(allItems(i)))
+
+ val probCorrect = {
+ val numErrors = allItems.map { item =>
+ val count = exactFreq.getOrElse(item, 0L)
+ val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
+ if (ratio > epsOfTotalCount) 1 else 0
+ }.sum
+
+ 1D - numErrors.toDouble / numAllItems
+ }
+
+ assert(
+ probCorrect > confidence,
+ s"Confidence not reached: required $confidence, reached $probCorrect"
+ )
+ }
+ }
+
+ def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+ test(s"mergeInPlace - $typeName") {
+ val r = new Random()
+ val numToMerge = 5
+ val numItemsPerSketch = 100000
+ val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) {
+ itemGenerator(r)
+ }
+
+ val sketches = perSketchItems.map { items =>
+ val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ items.foreach(sketch.add)
+ sketch
+ }
+
+ val mergedSketch = sketches.reduce(_ mergeInPlace _)
+
+ val expectedSketch = {
+ val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ perSketchItems.foreach(_.foreach(sketch.add))
+ sketch
+ }
+
+ perSketchItems.foreach {
+ _.foreach { item =>
+ assert(mergedSketch.estimateCount(item) === expectedSketch.estimateCount(item))
+ }
+ }
+ }
+ }
+
+ def testItemType[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+ testAccuracy[T](typeName)(itemGenerator)
+ testMergeInPlace[T](typeName)(itemGenerator)
+ }
+
+ testItemType[Byte]("Byte") { _.nextInt().toByte }
+
+ testItemType[Short]("Short") { _.nextInt().toShort }
+
+ testItemType[Int]("Int") { _.nextInt() }
+
+ testItemType[Long]("Long") { _.nextLong() }
+
+ testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
+}
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index efe58ea2e0..032c0616ed 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -113,6 +113,18 @@ hive_thriftserver = Module(
)
+sketch = Module(
+ name="sketch",
+ dependencies=[],
+ source_file_regexes=[
+ "common/sketch/",
+ ],
+ sbt_test_goals=[
+ "sketch/test"
+ ]
+)
+
+
graphx = Module(
name="graphx",
dependencies=[],
diff --git a/pom.xml b/pom.xml
index f08642f606..fb7750602c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -86,6 +86,7 @@
</mailingLists>
<modules>
+ <module>common/sketch</module>
<module>tags</module>
<module>core</module>
<module>graphx</module>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 3927b88fb0..4224a65a82 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -34,13 +34,24 @@ object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
- val allProjects@Seq(catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
- sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka,
- streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe, testTags) =
- Seq("catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
- "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
- "streaming-flume", "streaming-akka", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
- "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _))
+ val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer) = Seq(
+ "catalyst", "sql", "hive", "hive-thriftserver"
+ ).map(ProjectRef(buildLocation, _))
+
+ val streamingProjects@Seq(
+ streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka, streamingMqtt,
+ streamingTwitter, streamingZeromq
+ ) = Seq(
+ "streaming", "streaming-flume-sink", "streaming-flume", "streaming-akka", "streaming-kafka",
+ "streaming-mqtt", "streaming-twitter", "streaming-zeromq"
+ ).map(ProjectRef(buildLocation, _))
+
+ val allProjects@Seq(
+ core, graphx, mllib, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _*
+ ) = Seq(
+ "core", "graphx", "mllib", "repl", "network-common", "network-shuffle", "launcher", "unsafe",
+ "test-tags", "sketch"
+ ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects
val optionallyEnabledProjects@Seq(yarn, java8Tests, sparkGangliaLgpl,
streamingKinesisAsl, dockerIntegrationTests) =
@@ -232,11 +243,15 @@ object SparkBuild extends PomBuild {
/* Enable tests settings for all projects except examples, assembly and tools */
(allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
- // TODO: remove streamingAkka from this list after 2.0.0
- allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl,
- networkCommon, networkShuffle, networkYarn, unsafe, streamingAkka, testTags).contains(x)).foreach {
- x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
- }
+ // TODO: remove streamingAkka and sketch from this list after 2.0.0
+ allProjects.filterNot { x =>
+ Seq(
+ spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn,
+ unsafe, streamingAkka, testTags, sketch
+ ).contains(x)
+ }.foreach { x =>
+ enable(MimaBuild.mimaSettings(sparkHome, x))(x)
+ }
/* Unsafe settings */
enable(Unsafe.settings)(unsafe)