author      Wenchen Fan <wenchen@databricks.com>        2016-01-26 00:53:05 -0800
committer   Reynold Xin <rxin@databricks.com>           2016-01-26 00:53:05 -0800
commit      6743de3a98e3f0d0e6064ca1872fa88c3aeaa143 (patch)
tree        e51af54a94ebb5481c0c4e2cfe75cbd5ec42cfba /common/sketch
parent      d54cfed5a6953a9ce2b9de2f31ee2d673cb5cc62 (diff)
[SPARK-12937][SQL] bloom filter serialization
This PR adds serialization support for BloomFilter. A version number is written out first so that the serialized binary format can be versioned.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10920 from cloud-fan/bloom-filter.
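For context, the round trip this patch enables looks roughly like the sketch below. It is not part of the diff; it assumes the pre-existing BloomFilter.create(expectedNumItems) factory and put() method, which this patch does not touch.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

import org.apache.spark.util.sketch.BloomFilter;

public class BloomFilterSerDeExample {
  public static void main(String[] args) throws IOException {
    // Build a filter and add a couple of items (create() and put() are assumed
    // from the existing BloomFilter API; they are not added by this patch).
    BloomFilter filter = BloomFilter.create(1000);
    filter.put("spark");
    filter.put("sketch");

    // Serialize with the writeTo() added in this patch; the caller owns the stream.
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    filter.writeTo(out);
    out.close();

    // Deserialize with the new static BloomFilter.readFrom() and verify equivalence.
    BloomFilter deserialized =
        BloomFilter.readFrom(new ByteArrayInputStream(out.toByteArray()));
    if (!filter.equals(deserialized) || !deserialized.mightContain("spark")) {
      throw new AssertionError("round trip lost information");
    }
  }
}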
Diffstat (limited to 'common/sketch')
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java           | 46
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java        | 42
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java    | 48
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java     | 25
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java | 22
-rw-r--r--  common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala | 20
6 files changed, 159 insertions(+), 44 deletions(-)
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
index 1bc665ad54..2a0484e324 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
@@ -17,6 +17,9 @@
package org.apache.spark.util.sketch;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
import java.util.Arrays;
public final class BitArray {
@@ -24,6 +27,9 @@ public final class BitArray {
private long bitCount;
static int numWords(long numBits) {
+ if (numBits <= 0) {
+ throw new IllegalArgumentException("numBits must be positive, but got " + numBits);
+ }
long numWords = (long) Math.ceil(numBits / 64.0);
if (numWords > Integer.MAX_VALUE) {
throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits");
@@ -32,13 +38,14 @@ public final class BitArray {
}
BitArray(long numBits) {
- if (numBits <= 0) {
- throw new IllegalArgumentException("numBits must be positive");
- }
- this.data = new long[numWords(numBits)];
+ this(new long[numWords(numBits)]);
+ }
+
+ private BitArray(long[] data) {
+ this.data = data;
long bitCount = 0;
- for (long value : data) {
- bitCount += Long.bitCount(value);
+ for (long word : data) {
+ bitCount += Long.bitCount(word);
}
this.bitCount = bitCount;
}
@@ -78,13 +85,28 @@ public final class BitArray {
this.bitCount = bitCount;
}
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || !(o instanceof BitArray)) return false;
+ void writeTo(DataOutputStream out) throws IOException {
+ out.writeInt(data.length);
+ for (long datum : data) {
+ out.writeLong(datum);
+ }
+ }
+
+ static BitArray readFrom(DataInputStream in) throws IOException {
+ int numWords = in.readInt();
+ long[] data = new long[numWords];
+ for (int i = 0; i < numWords; i++) {
+ data[i] = in.readLong();
+ }
+ return new BitArray(data);
+ }
- BitArray bitArray = (BitArray) o;
- return Arrays.equals(data, bitArray.data);
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) return true;
+ if (other == null || !(other instanceof BitArray)) return false;
+ BitArray that = (BitArray) other;
+ return Arrays.equals(data, that.data);
}
@Override
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index 38949c6311..00378d5851 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -17,6 +17,10 @@
package org.apache.spark.util.sketch;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
/**
* A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether
* an element is a member of a set. It returns false when the element is definitely not in the
@@ -39,6 +43,28 @@ package org.apache.spark.util.sketch;
* The implementation is largely based on the {@code BloomFilter} class from guava.
*/
public abstract class BloomFilter {
+
+ public enum Version {
+ /**
+ * {@code BloomFilter} binary format version 1 (all values written in big-endian order):
+ * - Version number, always 1 (32 bit)
+ * - Total number of words of the underlying bit array (32 bit)
+ * - The words/longs (numWords * 64 bit)
+ * - Number of hash functions (32 bit)
+ */
+ V1(1);
+
+ private final int versionNumber;
+
+ Version(int versionNumber) {
+ this.versionNumber = versionNumber;
+ }
+
+ int getVersionNumber() {
+ return versionNumber;
+ }
+ }
+
/**
* Returns the false positive probability, i.e. the probability that
* {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that
@@ -83,7 +109,7 @@ public abstract class BloomFilter {
* bloom filters are appropriately sized to avoid saturating them.
*
* @param other The bloom filter to combine this bloom filter with. It is not mutated.
- * @throws IllegalArgumentException if {@code isCompatible(that) == false}
+ * @throws IncompatibleMergeException if {@code isCompatible(other) == false}
*/
public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException;
@@ -94,6 +120,20 @@ public abstract class BloomFilter {
public abstract boolean mightContain(Object item);
/**
+ * Writes out this {@link BloomFilter} to an output stream in binary format.
+ * It is the caller's responsibility to close the stream.
+ */
+ public abstract void writeTo(OutputStream out) throws IOException;
+
+ /**
+ * Reads in a {@link BloomFilter} from an input stream.
+ * It is the caller's responsibility to close the stream.
+ */
+ public static BloomFilter readFrom(InputStream in) throws IOException {
+ return BloomFilterImpl.readFrom(in);
+ }
+
+ /**
* Computes the optimal k (number of hashes per element inserted in Bloom filter), given the
* expected insertions and total number of bits in the Bloom filter.
*
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
index bbd6cf719d..1c08d07afa 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -17,7 +17,7 @@
package org.apache.spark.util.sketch;
-import java.io.UnsupportedEncodingException;
+import java.io.*;
public class BloomFilterImpl extends BloomFilter {
@@ -25,8 +25,32 @@ public class BloomFilterImpl extends BloomFilter {
private final BitArray bits;
BloomFilterImpl(int numHashFunctions, long numBits) {
+ this(new BitArray(numBits), numHashFunctions);
+ }
+
+ private BloomFilterImpl(BitArray bits, int numHashFunctions) {
+ this.bits = bits;
this.numHashFunctions = numHashFunctions;
- this.bits = new BitArray(numBits);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == this) {
+ return true;
+ }
+
+ if (other == null || !(other instanceof BloomFilterImpl)) {
+ return false;
+ }
+
+ BloomFilterImpl that = (BloomFilterImpl) other;
+
+ return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits);
+ }
+
+ @Override
+ public int hashCode() {
+ return bits.hashCode() * 31 + numHashFunctions;
}
@Override
@@ -161,4 +185,24 @@ public class BloomFilterImpl extends BloomFilter {
this.bits.putAll(that.bits);
return this;
}
+
+ @Override
+ public void writeTo(OutputStream out) throws IOException {
+ DataOutputStream dos = new DataOutputStream(out);
+
+ dos.writeInt(Version.V1.getVersionNumber());
+ bits.writeTo(dos);
+ dos.writeInt(numHashFunctions);
+ }
+
+ public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+ DataInputStream dis = new DataInputStream(in);
+
+ int version = dis.readInt();
+ if (version != Version.V1.getVersionNumber()) {
+ throw new IOException("Unexpected Bloom filter version number (" + version + ")");
+ }
+
+ return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt());
+ }
}
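As documented in the new Version.V1 javadoc above, the serialized form is just the version number, the bit-array words, and the hash-function count, all written big-endian. The decoder below is only an illustration of that layout and is not part of the patch.

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;

public class BloomFilterV1Layout {
  // Reads a V1-serialized BloomFilter field by field, in the order writeTo() emits them:
  // version (32 bit), number of 64-bit words (32 bit), the words themselves, and the
  // number of hash functions (32 bit). DataInputStream reads big-endian, matching the format.
  public static void dump(byte[] serialized) throws IOException {
    DataInputStream in = new DataInputStream(new ByteArrayInputStream(serialized));
    int version = in.readInt();            // always 1 for V1
    int numWords = in.readInt();           // length of the underlying bit array, in longs
    long[] words = new long[numWords];
    for (int i = 0; i < numWords; i++) {
      words[i] = in.readLong();
    }
    int numHashFunctions = in.readInt();
    System.out.printf("version=%d, numWords=%d (%d bits), numHashFunctions=%d%n",
        version, numWords, numWords * 64L, numHashFunctions);
  }
}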
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 9f4ff42403..00c0b1b9e2 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -55,10 +55,21 @@ import java.io.OutputStream;
* This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
*/
abstract public class CountMinSketch {
- /**
- * Version number of the serialized binary format.
- */
+
public enum Version {
+ /**
+ * {@code CountMinSketch} binary format version 1 (all values written in big-endian order):
+ * - Version number, always 1 (32 bit)
+ * - Total count of added items (64 bit)
+ * - Depth (32 bit)
+ * - Width (32 bit)
+ * - Hash functions (depth * 64 bit)
+ * - Count table
+ * - Row 0 (width * 64 bit)
+ * - Row 1 (width * 64 bit)
+ * - ...
+ * - Row depth - 1 (width * 64 bit)
+ */
V1(1);
private final int versionNumber;
@@ -67,13 +78,11 @@ abstract public class CountMinSketch {
this.versionNumber = versionNumber;
}
- public int getVersionNumber() {
+ int getVersionNumber() {
return versionNumber;
}
}
- public abstract Version version();
-
/**
* Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
*/
@@ -128,13 +137,13 @@ abstract public class CountMinSketch {
/**
* Writes out this {@link CountMinSketch} to an output stream in binary format.
- * It is the caller's responsibility to close the stream
+ * It is the caller's responsibility to close the stream.
*/
public abstract void writeTo(OutputStream out) throws IOException;
/**
* Reads in a {@link CountMinSketch} from an input stream.
- * It is the caller's responsibility to close the stream
+ * It is the caller's responsibility to close the stream.
*/
public static CountMinSketch readFrom(InputStream in) throws IOException {
return CountMinSketchImpl.readFrom(in);
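The CountMinSketch V1 layout, now documented on its Version enum rather than on the implementation class, can be read the same way. The sketch below decodes only the fixed-size header described above and is not part of the patch.

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;

public class CountMinSketchV1Header {
  // Reads the V1 header in the documented order: version (32 bit), total count (64 bit),
  // depth (32 bit), width (32 bit). The depth hash seeds and the depth x width count
  // table (all 64 bit, big-endian) follow and are not decoded here.
  public static void dump(byte[] serialized) throws IOException {
    DataInputStream in = new DataInputStream(new ByteArrayInputStream(serialized));
    int version = in.readInt();
    long totalCount = in.readLong();
    int depth = in.readInt();
    int width = in.readInt();
    System.out.printf("version=%d, totalCount=%d, depth=%d, width=%d%n",
        version, totalCount, depth, width);
  }
}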
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 0209446ea3..d08809605a 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -26,21 +26,6 @@ import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;
-/*
- * Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian
- * order):
- *
- * - Version number, always 1 (32 bit)
- * - Total count of added items (64 bit)
- * - Depth (32 bit)
- * - Width (32 bit)
- * - Hash functions (depth * 64 bit)
- * - Count table
- * - Row 0 (width * 64 bit)
- * - Row 1 (width * 64 bit)
- * - ...
- * - Row depth - 1 (width * 64 bit)
- */
class CountMinSketchImpl extends CountMinSketch {
public static final long PRIME_MODULUS = (1L << 31) - 1;
@@ -112,11 +97,6 @@ class CountMinSketchImpl extends CountMinSketch {
return hash;
}
- @Override
- public Version version() {
- return Version.V1;
- }
-
private void initTablesWith(int depth, int width, int seed) {
this.table = new long[depth][width];
this.hashA = new long[depth];
@@ -327,7 +307,7 @@ class CountMinSketchImpl extends CountMinSketch {
public void writeTo(OutputStream out) throws IOException {
DataOutputStream dos = new DataOutputStream(out);
- dos.writeInt(version().getVersionNumber());
+ dos.writeInt(Version.V1.getVersionNumber());
dos.writeLong(this.totalCount);
dos.writeInt(this.depth);
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
index d2de509f19..a0408d2da4 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.util.sketch
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
import scala.reflect.ClassTag
import scala.util.Random
@@ -25,6 +27,20 @@ import org.scalatest.FunSuite // scalastyle:ignore funsuite
class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
private final val EPSILON = 0.01
+ // Serializes and deserializes a given `BloomFilter`, then checks whether the deserialized
+ // version is equivalent to the original one.
+ private def checkSerDe(filter: BloomFilter): Unit = {
+ val out = new ByteArrayOutputStream()
+ filter.writeTo(out)
+ out.close()
+
+ val in = new ByteArrayInputStream(out.toByteArray)
+ val deserialized = BloomFilter.readFrom(in)
+ in.close()
+
+ assert(filter == deserialized)
+ }
+
def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = {
test(s"accuracy - $typeName") {
// use a fixed seed to make the test predictable.
@@ -51,6 +67,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
// Also check the actual fpp is not significantly higher than we expected.
val actualFpp = errorCount.toDouble / (numItems - numInsertion)
assert(actualFpp - fpp < EPSILON)
+
+ checkSerDe(filter)
}
}
@@ -76,6 +94,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
items1.foreach(i => assert(filter1.mightContain(i)))
items2.foreach(i => assert(filter1.mightContain(i)))
+
+ checkSerDe(filter1)
}
}