aboutsummaryrefslogtreecommitdiff
path: root/common/sketch
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2016-01-25 15:05:05 -0800
committerReynold Xin <rxin@databricks.com>2016-01-25 15:05:05 -0800
commit6f0f1d9e04a8db47e2f6f8fcfe9dea9de0f633da (patch)
treee008931434bb051449db9f7fcc8523bb88060b93 /common/sketch
parentdcae355c64d7f6fdf61df2feefe464eb96c4cf5e (diff)
downloadspark-6f0f1d9e04a8db47e2f6f8fcfe9dea9de0f633da.tar.gz
spark-6f0f1d9e04a8db47e2f6f8fcfe9dea9de0f633da.tar.bz2
spark-6f0f1d9e04a8db47e2f6f8fcfe9dea9de0f633da.zip
[SPARK-12934][SQL] Count-min sketch serialization
This PR adds serialization support for `CountMinSketch`. A version number is added to version the serialized binary format. Author: Cheng Lian <lian@databricks.com> Closes #10893 from liancheng/cms-serialization.
Diffstat (limited to 'common/sketch')
-rw-r--r--common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java32
-rw-r--r--common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java129
-rw-r--r--common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java24
-rw-r--r--common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala47
4 files changed, 213 insertions, 19 deletions
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 21b161bc74..67938644d9 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -17,6 +17,7 @@
package org.apache.spark.util.sketch;
+import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@@ -55,6 +56,25 @@ import java.io.OutputStream;
*/
abstract public class CountMinSketch {
/**
+ * Version number of the serialized binary format.
+ */
+ public enum Version {
+ V1(1);
+
+ private final int versionNumber;
+
+ Version(int versionNumber) {
+ this.versionNumber = versionNumber;
+ }
+
+ public int getVersionNumber() {
+ return versionNumber;
+ }
+ }
+
+ public abstract Version version();
+
+ /**
* Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
*/
public abstract double relativeError();
@@ -99,19 +119,23 @@ abstract public class CountMinSketch {
*
* Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed
* can be merged.
+ *
+ * @exception IncompatibleMergeException if the {@code other} {@link CountMinSketch} has
+ * incompatible depth, width, relative-error, confidence, or random seed.
*/
- public abstract CountMinSketch mergeInPlace(CountMinSketch other);
+ public abstract CountMinSketch mergeInPlace(CountMinSketch other)
+ throws IncompatibleMergeException;
/**
* Writes out this {@link CountMinSketch} to an output stream in binary format.
*/
- public abstract void writeTo(OutputStream out);
+ public abstract void writeTo(OutputStream out) throws IOException;
/**
* Reads in a {@link CountMinSketch} from an input stream.
*/
- public static CountMinSketch readFrom(InputStream in) {
- throw new UnsupportedOperationException("Not implemented yet");
+ public static CountMinSketch readFrom(InputStream in) throws IOException {
+ return CountMinSketchImpl.readFrom(in);
}
/**
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index e9fdbe3a86..0209446ea3 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -17,11 +17,30 @@
package org.apache.spark.util.sketch;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;
+/*
+ * Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian
+ * order):
+ *
+ * - Version number, always 1 (32 bit)
+ * - Total count of added items (64 bit)
+ * - Depth (32 bit)
+ * - Width (32 bit)
+ * - Hash functions (depth * 64 bit)
+ * - Count table
+ * - Row 0 (width * 64 bit)
+ * - Row 1 (width * 64 bit)
+ * - ...
+ * - Row depth - 1 (width * 64 bit)
+ */
class CountMinSketchImpl extends CountMinSketch {
public static final long PRIME_MODULUS = (1L << 31) - 1;
@@ -33,7 +52,7 @@ class CountMinSketchImpl extends CountMinSketch {
private double eps;
private double confidence;
- public CountMinSketchImpl(int depth, int width, int seed) {
+ CountMinSketchImpl(int depth, int width, int seed) {
this.depth = depth;
this.width = width;
this.eps = 2.0 / width;
@@ -41,7 +60,7 @@ class CountMinSketchImpl extends CountMinSketch {
initTablesWith(depth, width, seed);
}
- public CountMinSketchImpl(double eps, double confidence, int seed) {
+ CountMinSketchImpl(double eps, double confidence, int seed) {
// 2/w = eps ; w = 2/eps
// 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence)
this.eps = eps;
@@ -51,6 +70,53 @@ class CountMinSketchImpl extends CountMinSketch {
initTablesWith(depth, width, seed);
}
+ CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) {
+ this.depth = depth;
+ this.width = width;
+ this.eps = 2.0 / width;
+ this.confidence = 1 - 1 / Math.pow(2, depth);
+ this.hashA = hashA;
+ this.table = table;
+ this.totalCount = totalCount;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == this) {
+ return true;
+ }
+
+ if (other == null || !(other instanceof CountMinSketchImpl)) {
+ return false;
+ }
+
+ CountMinSketchImpl that = (CountMinSketchImpl) other;
+
+ return
+ this.depth == that.depth &&
+ this.width == that.width &&
+ this.totalCount == that.totalCount &&
+ Arrays.equals(this.hashA, that.hashA) &&
+ Arrays.deepEquals(this.table, that.table);
+ }
+
+ @Override
+ public int hashCode() {
+ int hash = depth;
+
+ hash = hash * 31 + width;
+ hash = hash * 31 + (int) (totalCount ^ (totalCount >>> 32));
+ hash = hash * 31 + Arrays.hashCode(hashA);
+ hash = hash * 31 + Arrays.deepHashCode(table);
+
+ return hash;
+ }
+
+ @Override
+ public Version version() {
+ return Version.V1;
+ }
+
private void initTablesWith(int depth, int width, int seed) {
this.table = new long[depth][width];
this.hashA = new long[depth];
@@ -221,27 +287,29 @@ class CountMinSketchImpl extends CountMinSketch {
}
@Override
- public CountMinSketch mergeInPlace(CountMinSketch other) {
+ public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException {
if (other == null) {
- throw new CMSMergeException("Cannot merge null estimator");
+ throw new IncompatibleMergeException("Cannot merge null estimator");
}
if (!(other instanceof CountMinSketchImpl)) {
- throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName());
+ throw new IncompatibleMergeException(
+ "Cannot merge estimator of class " + other.getClass().getName()
+ );
}
CountMinSketchImpl that = (CountMinSketchImpl) other;
if (this.depth != that.depth) {
- throw new CMSMergeException("Cannot merge estimators of different depth");
+ throw new IncompatibleMergeException("Cannot merge estimators of different depth");
}
if (this.width != that.width) {
- throw new CMSMergeException("Cannot merge estimators of different width");
+ throw new IncompatibleMergeException("Cannot merge estimators of different width");
}
if (!Arrays.equals(this.hashA, that.hashA)) {
- throw new CMSMergeException("Cannot merge estimators of different seed");
+ throw new IncompatibleMergeException("Cannot merge estimators of different seed");
}
for (int i = 0; i < this.table.length; ++i) {
@@ -256,13 +324,48 @@ class CountMinSketchImpl extends CountMinSketch {
}
@Override
- public void writeTo(OutputStream out) {
- throw new UnsupportedOperationException("Not implemented yet");
+ public void writeTo(OutputStream out) throws IOException {
+ DataOutputStream dos = new DataOutputStream(out);
+
+ dos.writeInt(version().getVersionNumber());
+
+ dos.writeLong(this.totalCount);
+ dos.writeInt(this.depth);
+ dos.writeInt(this.width);
+
+ for (int i = 0; i < this.depth; ++i) {
+ dos.writeLong(this.hashA[i]);
+ }
+
+ for (int i = 0; i < this.depth; ++i) {
+ for (int j = 0; j < this.width; ++j) {
+ dos.writeLong(table[i][j]);
+ }
+ }
}
- protected static class CMSMergeException extends RuntimeException {
- public CMSMergeException(String message) {
- super(message);
+ public static CountMinSketchImpl readFrom(InputStream in) throws IOException {
+ DataInputStream dis = new DataInputStream(in);
+
+ // Ignores version number
+ dis.readInt();
+
+ long totalCount = dis.readLong();
+ int depth = dis.readInt();
+ int width = dis.readInt();
+
+ long hashA[] = new long[depth];
+ for (int i = 0; i < depth; ++i) {
+ hashA[i] = dis.readLong();
+ }
+
+ long table[][] = new long[depth][width];
+ for (int i = 0; i < depth; ++i) {
+ for (int j = 0; j < width; ++j) {
+ table[i][j] = dis.readLong();
+ }
}
+
+ return new CountMinSketchImpl(depth, width, totalCount, hashA, table);
}
}
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java
new file mode 100644
index 0000000000..64b567caa5
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+public class IncompatibleMergeException extends Exception {
+ public IncompatibleMergeException(String message) {
+ super(message);
+ }
+}
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
index ec5b4eddec..b9c7f5c23a 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.util.sketch
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
import scala.reflect.ClassTag
import scala.util.Random
@@ -29,9 +31,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
private val seed = 42
+ // Serializes and deserializes a given `CountMinSketch`, then checks whether the deserialized
+ // version is equivalent to the original one.
+ private def checkSerDe(sketch: CountMinSketch): Unit = {
+ val out = new ByteArrayOutputStream()
+ sketch.writeTo(out)
+
+ val in = new ByteArrayInputStream(out.toByteArray)
+ val deserialized = CountMinSketch.readFrom(in)
+
+ assert(sketch === deserialized)
+ }
+
def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
test(s"accuracy - $typeName") {
- val r = new Random()
+ // Uses fixed seed to ensure reproducible test execution
+ val r = new Random(31)
val numAllItems = 1000000
val allItems = Array.fill(numAllItems)(itemGenerator(r))
@@ -45,7 +60,10 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
}
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ checkSerDe(sketch)
+
sampledItemIndices.foreach(i => sketch.add(allItems(i)))
+ checkSerDe(sketch)
val probCorrect = {
val numErrors = allItems.map { item =>
@@ -66,7 +84,9 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
test(s"mergeInPlace - $typeName") {
- val r = new Random()
+ // Uses fixed seed to ensure reproducible test execution
+ val r = new Random(31)
+
val numToMerge = 5
val numItemsPerSketch = 100000
val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) {
@@ -75,11 +95,16 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
val sketches = perSketchItems.map { items =>
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ checkSerDe(sketch)
+
items.foreach(sketch.add)
+ checkSerDe(sketch)
+
sketch
}
val mergedSketch = sketches.reduce(_ mergeInPlace _)
+ checkSerDe(mergedSketch)
val expectedSketch = {
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -109,4 +134,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
testItemType[Long]("Long") { _.nextLong() }
testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
+
+ test("incompatible merge") {
+ intercept[IncompatibleMergeException] {
+ CountMinSketch.create(10, 10, 1).mergeInPlace(null)
+ }
+
+ intercept[IncompatibleMergeException] {
+ val sketch1 = CountMinSketch.create(10, 20, 1)
+ val sketch2 = CountMinSketch.create(10, 20, 2)
+ sketch1.mergeInPlace(sketch2)
+ }
+
+ intercept[IncompatibleMergeException] {
+ val sketch1 = CountMinSketch.create(10, 10, 1)
+ val sketch2 = CountMinSketch.create(10, 20, 2)
+ sketch1.mergeInPlace(sketch2)
+ }
+ }
}