From 6743de3a98e3f0d0e6064ca1872fa88c3aeaa143 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 26 Jan 2016 00:53:05 -0800
Subject: [SPARK-12937][SQL] bloom filter serialization

This PR adds serialization support for BloomFilter. A version number is added to version the
serialized binary format.

Author: Wenchen Fan

Closes #10920 from cloud-fan/bloom-filter.
---
 .../org/apache/spark/util/sketch/BitArray.java     | 46 +++++++++++++++------
 .../org/apache/spark/util/sketch/BloomFilter.java  | 42 ++++++++++++++++++-
 .../apache/spark/util/sketch/BloomFilterImpl.java  | 48 +++++++++++++++++++++-
 .../apache/spark/util/sketch/CountMinSketch.java   | 25 +++++++----
 .../spark/util/sketch/CountMinSketchImpl.java      | 22 +---------
 .../spark/util/sketch/BloomFilterSuite.scala       | 20 +++++++++
 6 files changed, 159 insertions(+), 44 deletions(-)

diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
index 1bc665ad54..2a0484e324 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
@@ -17,6 +17,9 @@
 
 package org.apache.spark.util.sketch;
 
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
 import java.util.Arrays;
 
 public final class BitArray {
@@ -24,6 +27,9 @@ public final class BitArray {
   private long bitCount;
 
   static int numWords(long numBits) {
+    if (numBits <= 0) {
+      throw new IllegalArgumentException("numBits must be positive, but got " + numBits);
+    }
     long numWords = (long) Math.ceil(numBits / 64.0);
     if (numWords > Integer.MAX_VALUE) {
       throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits");
@@ -32,13 +38,14 @@ public final class BitArray {
   }
 
   BitArray(long numBits) {
-    if (numBits <= 0) {
-      throw new IllegalArgumentException("numBits must be positive");
-    }
-    this.data = new long[numWords(numBits)];
+    this(new long[numWords(numBits)]);
+  }
+
+  private BitArray(long[] data) {
+    this.data = data;
     long bitCount = 0;
-    for (long value : data) {
-      bitCount += Long.bitCount(value);
+    for (long word : data) {
+      bitCount += Long.bitCount(word);
     }
     this.bitCount = bitCount;
   }
@@ -78,13 +85,28 @@ public final class BitArray {
     this.bitCount = bitCount;
   }
 
-  @Override
-  public boolean equals(Object o) {
-    if (this == o) return true;
-    if (o == null || !(o instanceof BitArray)) return false;
+  void writeTo(DataOutputStream out) throws IOException {
+    out.writeInt(data.length);
+    for (long datum : data) {
+      out.writeLong(datum);
+    }
+  }
+
+  static BitArray readFrom(DataInputStream in) throws IOException {
+    int numWords = in.readInt();
+    long[] data = new long[numWords];
+    for (int i = 0; i < numWords; i++) {
+      data[i] = in.readLong();
+    }
+    return new BitArray(data);
+  }
 
-    BitArray bitArray = (BitArray) o;
-    return Arrays.equals(data, bitArray.data);
+  @Override
+  public boolean equals(Object other) {
+    if (this == other) return true;
+    if (other == null || !(other instanceof BitArray)) return false;
+    BitArray that = (BitArray) other;
+    return Arrays.equals(data, that.data);
   }
 
   @Override
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index 38949c6311..00378d5851 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -17,6 +17,10 @@
 
 package org.apache.spark.util.sketch;
 
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
 /**
  * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether
  * an element is a member of a set. It returns false when the element is definitely not in the
@@ -39,6 +43,28 @@ package org.apache.spark.util.sketch;
  * The implementation is largely based on the {@code BloomFilter} class from guava.
  */
 public abstract class BloomFilter {
+
+  public enum Version {
+    /**
+     * {@code BloomFilter} binary format version 1 (all values written in big-endian order):
+     * - Version number, always 1 (32 bit)
+     * - Total number of words of the underlying bit array (32 bit)
+     * - The words/longs (numWords * 64 bit)
+     * - Number of hash functions (32 bit)
+     */
+    V1(1);
+
+    private final int versionNumber;
+
+    Version(int versionNumber) {
+      this.versionNumber = versionNumber;
+    }
+
+    int getVersionNumber() {
+      return versionNumber;
+    }
+  }
+
   /**
    * Returns the false positive probability, i.e. the probability that
    * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that
@@ -83,7 +109,7 @@ public abstract class BloomFilter {
    * bloom filters are appropriately sized to avoid saturating them.
    *
    * @param other The bloom filter to combine this bloom filter with. It is not mutated.
-   * @throws IllegalArgumentException if {@code isCompatible(that) == false}
+   * @throws IncompatibleMergeException if {@code isCompatible(other) == false}
    */
   public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException;
 
@@ -93,6 +119,20 @@ public abstract class BloomFilter {
    */
   public abstract boolean mightContain(Object item);
 
+  /**
+   * Writes out this {@link BloomFilter} to an output stream in binary format.
+   * It is the caller's responsibility to close the stream.
+   */
+  public abstract void writeTo(OutputStream out) throws IOException;
+
+  /**
+   * Reads in a {@link BloomFilter} from an input stream.
+   * It is the caller's responsibility to close the stream.
+   */
+  public static BloomFilter readFrom(InputStream in) throws IOException {
+    return BloomFilterImpl.readFrom(in);
+  }
+
   /**
    * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the
    * expected insertions and total number of bits in the Bloom filter.
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
index bbd6cf719d..1c08d07afa 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -17,7 +17,7 @@
 
 package org.apache.spark.util.sketch;
 
-import java.io.UnsupportedEncodingException;
+import java.io.*;
 
 public class BloomFilterImpl extends BloomFilter {
 
@@ -25,8 +25,32 @@ public class BloomFilterImpl extends BloomFilter {
   private final BitArray bits;
 
   BloomFilterImpl(int numHashFunctions, long numBits) {
+    this(new BitArray(numBits), numHashFunctions);
+  }
+
+  private BloomFilterImpl(BitArray bits, int numHashFunctions) {
+    this.bits = bits;
     this.numHashFunctions = numHashFunctions;
-    this.bits = new BitArray(numBits);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other == this) {
+      return true;
+    }
+
+    if (other == null || !(other instanceof BloomFilterImpl)) {
+      return false;
+    }
+
+    BloomFilterImpl that = (BloomFilterImpl) other;
+
+    return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits);
+  }
+
+  @Override
+  public int hashCode() {
+    return bits.hashCode() * 31 + numHashFunctions;
   }
 
   @Override
@@ -161,4 +185,24 @@ public class BloomFilterImpl extends BloomFilter {
     this.bits.putAll(that.bits);
     return this;
   }
+
+  @Override
+  public void writeTo(OutputStream out) throws IOException {
+    DataOutputStream dos = new DataOutputStream(out);
+
+    dos.writeInt(Version.V1.getVersionNumber());
+    bits.writeTo(dos);
+    dos.writeInt(numHashFunctions);
+  }
+
+  public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+    DataInputStream dis = new DataInputStream(in);
+
+    int version = dis.readInt();
+    if (version != Version.V1.getVersionNumber()) {
+      throw new IOException("Unexpected Bloom filter version number (" + version + ")");
+    }
+
+    return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt());
+  }
 }
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 9f4ff42403..00c0b1b9e2 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -55,10 +55,21 @@ import java.io.OutputStream;
  * This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
  */
 abstract public class CountMinSketch {
-  /**
-   * Version number of the serialized binary format.
-   */
+
   public enum Version {
+    /**
+     * {@code CountMinSketch} binary format version 1 (all values written in big-endian order):
+     * - Version number, always 1 (32 bit)
+     * - Total count of added items (64 bit)
+     * - Depth (32 bit)
+     * - Width (32 bit)
+     * - Hash functions (depth * 64 bit)
+     * - Count table
+     *   - Row 0 (width * 64 bit)
+     *   - Row 1 (width * 64 bit)
+     *   - ...
+     *   - Row depth - 1 (width * 64 bit)
+     */
     V1(1);
 
     private final int versionNumber;
@@ -67,13 +78,11 @@ abstract public class CountMinSketch {
       this.versionNumber = versionNumber;
     }
 
-    public int getVersionNumber() {
+    int getVersionNumber() {
       return versionNumber;
    }
   }
 
-  public abstract Version version();
-
   /**
    * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
   */
@@ -128,13 +137,13 @@ abstract public class CountMinSketch {
 
   /**
    * Writes out this {@link CountMinSketch} to an output stream in binary format.
-   * It is the caller's responsibility to close the stream
+   * It is the caller's responsibility to close the stream.
    */
   public abstract void writeTo(OutputStream out) throws IOException;
 
   /**
    * Reads in a {@link CountMinSketch} from an input stream.
-   * It is the caller's responsibility to close the stream
+   * It is the caller's responsibility to close the stream.
    */
   public static CountMinSketch readFrom(InputStream in) throws IOException {
     return CountMinSketchImpl.readFrom(in);
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 0209446ea3..d08809605a 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -26,21 +26,6 @@ import java.io.UnsupportedEncodingException;
 import java.util.Arrays;
 import java.util.Random;
 
-/*
- * Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian
- * order):
- *
- *   - Version number, always 1 (32 bit)
- *   - Total count of added items (64 bit)
- *   - Depth (32 bit)
- *   - Width (32 bit)
- *   - Hash functions (depth * 64 bit)
- *   - Count table
- *     - Row 0 (width * 64 bit)
- *     - Row 1 (width * 64 bit)
- *     - ...
- *     - Row depth - 1 (width * 64 bit)
- */
 class CountMinSketchImpl extends CountMinSketch {
   public static final long PRIME_MODULUS = (1L << 31) - 1;
 
@@ -112,11 +97,6 @@ class CountMinSketchImpl extends CountMinSketch {
     return hash;
   }
 
-  @Override
-  public Version version() {
-    return Version.V1;
-  }
-
   private void initTablesWith(int depth, int width, int seed) {
     this.table = new long[depth][width];
     this.hashA = new long[depth];
@@ -327,7 +307,7 @@ class CountMinSketchImpl extends CountMinSketch {
   public void writeTo(OutputStream out) throws IOException {
     DataOutputStream dos = new DataOutputStream(out);
 
-    dos.writeInt(version().getVersionNumber());
+    dos.writeInt(Version.V1.getVersionNumber());
 
     dos.writeLong(this.totalCount);
     dos.writeInt(this.depth);
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
index d2de509f19..a0408d2da4 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.util.sketch
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
 import scala.reflect.ClassTag
 import scala.util.Random
 
@@ -25,6 +27,20 @@ import org.scalatest.FunSuite // scalastyle:ignore funsuite
 class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
   private final val EPSILON = 0.01
 
+  // Serializes and deserializes a given `BloomFilter`, then checks whether the deserialized
+  // version is equivalent to the original one.
+  private def checkSerDe(filter: BloomFilter): Unit = {
+    val out = new ByteArrayOutputStream()
+    filter.writeTo(out)
+    out.close()
+
+    val in = new ByteArrayInputStream(out.toByteArray)
+    val deserialized = BloomFilter.readFrom(in)
+    in.close()
+
+    assert(filter == deserialized)
+  }
+
   def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = {
     test(s"accuracy - $typeName") {
       // use a fixed seed to make the test predictable.
@@ -51,6 +67,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
       // Also check the actual fpp is not significantly higher than we expected.
       val actualFpp = errorCount.toDouble / (numItems - numInsertion)
       assert(actualFpp - fpp < EPSILON)
+
+      checkSerDe(filter)
     }
   }
 
@@ -76,6 +94,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
 
       items1.foreach(i => assert(filter1.mightContain(i)))
      items2.foreach(i => assert(filter1.mightContain(i)))
+
+      checkSerDe(filter1)
     }
   }
 }
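
Usage note (not part of the patch): outside the test suite, a round trip through the new
writeTo/readFrom API would look roughly like the sketch below. It assumes the BloomFilter.create
factory and put method from the earlier bloom filter data-structure PR; the class name and
parameter values are illustrative only.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

import org.apache.spark.util.sketch.BloomFilter;

public class BloomFilterSerDeExample {
  public static void main(String[] args) throws IOException {
    // Build a small filter; create(expectedNumItems, fpp) and put() are assumed
    // to exist from the earlier data-structure PR and are not added by this patch.
    BloomFilter filter = BloomFilter.create(1000, 0.03);
    for (long i = 0; i < 1000; i++) {
      filter.put(i);
    }

    // writeTo emits the version-1 binary format; closing the stream is the caller's job.
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    filter.writeTo(out);
    out.close();

    // readFrom validates the version number and rebuilds an equivalent filter.
    ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
    BloomFilter deserialized = BloomFilter.readFrom(in);
    in.close();

    // The equals() added to BloomFilterImpl makes this comparison meaningful.
    if (!filter.equals(deserialized) || !deserialized.mightContain(42L)) {
      throw new AssertionError("Bloom filter round trip failed");
    }
  }
}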
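Since the Version.V1 javadoc above is the only description of the wire layout, the following
hand-decoding sketch shows what writeTo actually emits: version (int), word count (int), the
long[] words of the bit array, then the number of hash functions (int), all big-endian via
DataOutputStream/DataInputStream. The helper class name is hypothetical and, again, the create
factory is assumed from the earlier PR.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;

import org.apache.spark.util.sketch.BloomFilter;

public class BloomFilterV1LayoutCheck {
  public static void main(String[] args) throws IOException {
    BloomFilter filter = BloomFilter.create(100, 0.03);
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    filter.writeTo(out);

    DataInputStream in = new DataInputStream(new ByteArrayInputStream(out.toByteArray()));
    int version = in.readInt();           // always 1 for Version.V1
    int numWords = in.readInt();          // length of the long[] written by BitArray.writeTo
    long[] words = new long[numWords];
    for (int i = 0; i < numWords; i++) {
      words[i] = in.readLong();           // the bit array words themselves
    }
    int numHashFunctions = in.readInt();  // trailing k, written by BloomFilterImpl.writeTo
    System.out.printf("version=%d, numWords=%d, numHashFunctions=%d%n",
      version, numWords, numHashFunctions);
  }
}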