path: root/common
author     Wenchen Fan <wenchen@databricks.com>    2016-01-27 13:29:09 -0800
committer  Reynold Xin <rxin@databricks.com>       2016-01-27 13:29:09 -0800
commit     680afabe78b77e4e63e793236453d69567d24290 (patch)
tree       483b4a1e2669aefec50f6293408ee16f0e5dcdad /common
parent     32f741115bda5d7d7dbfcd9fe827ecbea7303ffa (diff)
[SPARK-12938][SQL] DataFrame API for Bloom filter
This PR integrates the Bloom filter from spark-sketch into DataFrame. This version resorts to RDD.aggregate for building the filter; a more performant UDAF version can be built in future follow-up PRs. This PR also adds two specialized `put` variants (`putBinary` and `putLong`) to `BloomFilter`, which makes it easier to build a Bloom filter over a `DataFrame`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10937 from cloud-fan/bloom-filter.
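As a rough usage sketch of the API this change targets: the DataFrame-side method and the `BloomFilter.create` factory are not part of the diff below, so their exact names and signatures here are assumptions.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.util.sketch.BloomFilter;

public class BloomFilterUsageSketch {
  // Hypothetical illustration only: builds a filter directly via the sketch API.
  public static void directUsage() {
    BloomFilter filter = BloomFilter.create(1000000, 0.03); // expected items, target FPP (assumed factory)
    filter.putLong(42L);
    filter.putString("spark");
    filter.putBinary(new byte[] {1, 2, 3});
    boolean maybePresent = filter.mightContainLong(42L);    // true: no false negatives
    boolean likelyAbsent = filter.mightContainString("beam"); // false with high probability
    System.out.println(maybePresent + " " + likelyAbsent);
  }

  public static BloomFilter fromDataFrame(Dataset<Row> df) {
    // Assumed DataFrame-side helper added by this PR; not shown in this diff.
    return df.stat().bloomFilter("id", 1000000L, 0.03);
  }
}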
Diffstat (limited to 'common')
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java        |  34
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java    | 141
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java |  47
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java              |  48
4 files changed, 179 insertions, 91 deletions
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index d392fb187a..81772fcea0 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -49,9 +49,9 @@ public abstract class BloomFilter {
* {@code BloomFilter} binary format version 1 (all values written in big-endian order):
* <ul>
* <li>Version number, always 1 (32 bit)</li>
+ * <li>Number of hash functions (32 bit)</li>
* <li>Total number of words of the underlying bit array (32 bit)</li>
* <li>The words/longs (numWords * 64 bit)</li>
- * <li>Number of hash functions (32 bit)</li>
* </ul>
*/
V1(1);
@@ -98,6 +98,21 @@ public abstract class BloomFilter {
public abstract boolean put(Object item);
/**
+ * A specialized variant of {@link #put(Object)}, that can only be used to put utf-8 string.
+ */
+ public abstract boolean putString(String str);
+
+ /**
+ * A specialized variant of {@link #put(Object)}, that can only be used to put long.
+ */
+ public abstract boolean putLong(long l);
+
+ /**
+ * A specialized variant of {@link #put(Object)}, that can only be used to put byte array.
+ */
+ public abstract boolean putBinary(byte[] bytes);
+
+ /**
* Determines whether a given bloom filter is compatible with this bloom filter. For two
* bloom filters to be compatible, they must have the same bit size.
*
@@ -122,6 +137,23 @@ public abstract class BloomFilter {
public abstract boolean mightContain(Object item);
/**
+ * A specialized variant of {@link #mightContain(Object)}, that can only be used to test utf-8
+ * string.
+ */
+ public abstract boolean mightContainString(String str);
+
+ /**
+ * A specialized variant of {@link #mightContain(Object)}, that can only be used to test long.
+ */
+ public abstract boolean mightContainLong(long l);
+
+ /**
+ * A specialized variant of {@link #mightContain(Object)}, that can only be used to test byte
+ * array.
+ */
+ public abstract boolean mightContainBinary(byte[] bytes);
+
+ /**
* Writes out this {@link BloomFilter} to an output stream in binary format.
* It is the caller's responsibility to close the stream.
*/
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
index 1c08d07afa..35107e0b38 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -19,10 +19,10 @@ package org.apache.spark.util.sketch;
import java.io.*;
-public class BloomFilterImpl extends BloomFilter {
+public class BloomFilterImpl extends BloomFilter implements Serializable {
- private final int numHashFunctions;
- private final BitArray bits;
+ private int numHashFunctions;
+ private BitArray bits;
BloomFilterImpl(int numHashFunctions, long numBits) {
this(new BitArray(numBits), numHashFunctions);
@@ -33,6 +33,8 @@ public class BloomFilterImpl extends BloomFilter {
this.numHashFunctions = numHashFunctions;
}
+ private BloomFilterImpl() {}
+
@Override
public boolean equals(Object other) {
if (other == this) {
@@ -63,55 +65,75 @@ public class BloomFilterImpl extends BloomFilter {
return bits.bitSize();
}
- private static long hashObjectToLong(Object item) {
+ @Override
+ public boolean put(Object item) {
if (item instanceof String) {
- try {
- byte[] bytes = ((String) item).getBytes("utf-8");
- return hashBytesToLong(bytes);
- } catch (UnsupportedEncodingException e) {
- throw new RuntimeException("Only support utf-8 string", e);
- }
+ return putString((String) item);
+ } else if (item instanceof byte[]) {
+ return putBinary((byte[]) item);
} else {
- long longValue;
-
- if (item instanceof Long) {
- longValue = (Long) item;
- } else if (item instanceof Integer) {
- longValue = ((Integer) item).longValue();
- } else if (item instanceof Short) {
- longValue = ((Short) item).longValue();
- } else if (item instanceof Byte) {
- longValue = ((Byte) item).longValue();
- } else {
- throw new IllegalArgumentException(
- "Support for " + item.getClass().getName() + " not implemented"
- );
- }
-
- int h1 = Murmur3_x86_32.hashLong(longValue, 0);
- int h2 = Murmur3_x86_32.hashLong(longValue, h1);
- return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);
+ return putLong(Utils.integralToLong(item));
}
}
- private static long hashBytesToLong(byte[] bytes) {
+ @Override
+ public boolean putString(String str) {
+ return putBinary(Utils.getBytesFromUTF8String(str));
+ }
+
+ @Override
+ public boolean putBinary(byte[] bytes) {
int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
- return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);
+
+ long bitSize = bits.bitSize();
+ boolean bitsChanged = false;
+ for (int i = 1; i <= numHashFunctions; i++) {
+ int combinedHash = h1 + (i * h2);
+ // Flip all the bits if it's negative (guaranteed positive number)
+ if (combinedHash < 0) {
+ combinedHash = ~combinedHash;
+ }
+ bitsChanged |= bits.set(combinedHash % bitSize);
+ }
+ return bitsChanged;
}
@Override
- public boolean put(Object item) {
+ public boolean mightContainString(String str) {
+ return mightContainBinary(Utils.getBytesFromUTF8String(str));
+ }
+
+ @Override
+ public boolean mightContainBinary(byte[] bytes) {
+ int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
+ int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
+
long bitSize = bits.bitSize();
+ for (int i = 1; i <= numHashFunctions; i++) {
+ int combinedHash = h1 + (i * h2);
+ // Flip all the bits if it's negative (guaranteed positive number)
+ if (combinedHash < 0) {
+ combinedHash = ~combinedHash;
+ }
+ if (!bits.get(combinedHash % bitSize)) {
+ return false;
+ }
+ }
+ return true;
+ }
- // Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash
- // values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
- // Note that `CountMinSketch` use a different strategy for long type, it hash the input long
- // element with every i to produce n hash values.
- long hash64 = hashObjectToLong(item);
- int h1 = (int) (hash64 >> 32);
- int h2 = (int) hash64;
+ @Override
+ public boolean putLong(long l) {
+ // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n
+ // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
+ // Note that `CountMinSketch` use a different strategy, it hash the input long element with
+ // every i to produce n hash values.
+ // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here?
+ int h1 = Murmur3_x86_32.hashLong(l, 0);
+ int h2 = Murmur3_x86_32.hashLong(l, h1);
+ long bitSize = bits.bitSize();
boolean bitsChanged = false;
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
@@ -125,12 +147,11 @@ public class BloomFilterImpl extends BloomFilter {
}
@Override
- public boolean mightContain(Object item) {
- long bitSize = bits.bitSize();
- long hash64 = hashObjectToLong(item);
- int h1 = (int) (hash64 >> 32);
- int h2 = (int) hash64;
+ public boolean mightContainLong(long l) {
+ int h1 = Murmur3_x86_32.hashLong(l, 0);
+ int h2 = Murmur3_x86_32.hashLong(l, h1);
+ long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
// Flip all the bits if it's negative (guaranteed positive number)
@@ -145,6 +166,17 @@ public class BloomFilterImpl extends BloomFilter {
}
@Override
+ public boolean mightContain(Object item) {
+ if (item instanceof String) {
+ return mightContainString((String) item);
+ } else if (item instanceof byte[]) {
+ return mightContainBinary((byte[]) item);
+ } else {
+ return mightContainLong(Utils.integralToLong(item));
+ }
+ }
+
+ @Override
public boolean isCompatible(BloomFilter other) {
if (other == null) {
return false;
@@ -191,11 +223,11 @@ public class BloomFilterImpl extends BloomFilter {
DataOutputStream dos = new DataOutputStream(out);
dos.writeInt(Version.V1.getVersionNumber());
- bits.writeTo(dos);
dos.writeInt(numHashFunctions);
+ bits.writeTo(dos);
}
- public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+ private void readFrom0(InputStream in) throws IOException {
DataInputStream dis = new DataInputStream(in);
int version = dis.readInt();
@@ -203,6 +235,21 @@ public class BloomFilterImpl extends BloomFilter {
throw new IOException("Unexpected Bloom filter version number (" + version + ")");
}
- return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt());
+ this.numHashFunctions = dis.readInt();
+ this.bits = BitArray.readFrom(dis);
+ }
+
+ public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+ BloomFilterImpl filter = new BloomFilterImpl();
+ filter.readFrom0(in);
+ return filter;
+ }
+
+ private void writeObject(ObjectOutputStream out) throws IOException {
+ writeTo(out);
+ }
+
+ private void readObject(ObjectInputStream in) throws IOException {
+ readFrom0(in);
}
}
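
All of the put*/mightContain* implementations above share the same double-hashing scheme: two 32-bit Murmur3 hashes h1 and h2 are combined as h1 + i * h2 for 1 <= i <= numHashFunctions, negative values are bit-flipped to stay non-negative, and the result is reduced modulo the bit size. A minimal sketch of just that index derivation, with placeholder inputs standing in for Murmur3_x86_32 output:

public class DoubleHashingDemo {
  // Illustrative only: derives k bit positions from two base hashes, mirroring the
  // combinedHash logic in BloomFilterImpl above.
  static long[] bitIndexes(int h1, int h2, int numHashFunctions, long bitSize) {
    long[] indexes = new long[numHashFunctions];
    for (int i = 1; i <= numHashFunctions; i++) {
      int combinedHash = h1 + (i * h2);
      // Flip all the bits if it's negative (guarantees a non-negative number).
      if (combinedHash < 0) {
        combinedHash = ~combinedHash;
      }
      indexes[i - 1] = combinedHash % bitSize;
    }
    return indexes;
  }

  public static void main(String[] args) {
    // Placeholder hash values; in BloomFilterImpl these come from Murmur3_x86_32.
    for (long idx : bitIndexes(0x1234abcd, 0x9e3779b9, 5, 1L << 20)) {
      System.out.println(idx);
    }
  }
}
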
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 8cc29e4076..e49ae22906 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -40,8 +40,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
private double eps;
private double confidence;
- private CountMinSketchImpl() {
- }
+ private CountMinSketchImpl() {}
CountMinSketchImpl(int depth, int width, int seed) {
this.depth = depth;
@@ -143,23 +142,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
if (item instanceof String) {
addString((String) item, count);
} else {
- long longValue;
-
- if (item instanceof Long) {
- longValue = (Long) item;
- } else if (item instanceof Integer) {
- longValue = ((Integer) item).longValue();
- } else if (item instanceof Short) {
- longValue = ((Short) item).longValue();
- } else if (item instanceof Byte) {
- longValue = ((Byte) item).longValue();
- } else {
- throw new IllegalArgumentException(
- "Support for " + item.getClass().getName() + " not implemented"
- );
- }
-
- addLong(longValue, count);
+ addLong(Utils.integralToLong(item), count);
}
}
@@ -201,13 +184,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
}
private static int[] getHashBuckets(String key, int hashCount, int max) {
- byte[] b;
- try {
- b = key.getBytes("UTF-8");
- } catch (UnsupportedEncodingException e) {
- throw new RuntimeException(e);
- }
- return getHashBuckets(b, hashCount, max);
+ return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
}
private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
@@ -225,23 +202,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
if (item instanceof String) {
return estimateCountForStringItem((String) item);
} else {
- long longValue;
-
- if (item instanceof Long) {
- longValue = (Long) item;
- } else if (item instanceof Integer) {
- longValue = ((Integer) item).longValue();
- } else if (item instanceof Short) {
- longValue = ((Short) item).longValue();
- } else if (item instanceof Byte) {
- longValue = ((Byte) item).longValue();
- } else {
- throw new IllegalArgumentException(
- "Support for " + item.getClass().getName() + " not implemented"
- );
- }
-
- return estimateCountForLongItem(longValue);
+ return estimateCountForLongItem(Utils.integralToLong(item));
}
}
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java
new file mode 100644
index 0000000000..a6b3331303
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.io.UnsupportedEncodingException;
+
+public class Utils {
+ public static byte[] getBytesFromUTF8String(String str) {
+ try {
+ return str.getBytes("utf-8");
+ } catch (UnsupportedEncodingException e) {
+ throw new IllegalArgumentException("Only support utf-8 string", e);
+ }
+ }
+
+ public static long integralToLong(Object i) {
+ long longValue;
+
+ if (i instanceof Long) {
+ longValue = (Long) i;
+ } else if (i instanceof Integer) {
+ longValue = ((Integer) i).longValue();
+ } else if (i instanceof Short) {
+ longValue = ((Short) i).longValue();
+ } else if (i instanceof Byte) {
+ longValue = ((Byte) i).longValue();
+ } else {
+ throw new IllegalArgumentException("Unsupported data type " + i.getClass().getName());
+ }
+
+ return longValue;
+ }
+}
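
The new Utils class simply consolidates the UTF-8 encoding and integral-to-long widening that BloomFilterImpl and CountMinSketchImpl previously duplicated inline (on Java 7+, String.getBytes(StandardCharsets.UTF_8) would avoid the checked exception, but the diff keeps the charset-name overload). A hypothetical caller:

import org.apache.spark.util.sketch.Utils;

public class UtilsUsageSketch {
  public static void main(String[] args) {
    byte[] utf8 = Utils.getBytesFromUTF8String("héllo");  // UTF-8 encoded bytes
    long widened = Utils.integralToLong((short) 7);       // accepts Byte, Short, Integer, Long
    System.out.println(utf8.length + " bytes, value " + widened);
    // Utils.integralToLong("oops") would throw
    // IllegalArgumentException("Unsupported data type java.lang.String")
  }
}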