-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java        |  34
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java    | 141
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java |  47
-rw-r--r--  common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java              |  48
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala        |  76
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java         |  31
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala            |  22
7 files changed, 306 insertions(+), 93 deletions(-)
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index d392fb187a..81772fcea0 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -49,9 +49,9 @@ public abstract class BloomFilter {
* {@code BloomFilter} binary format version 1 (all values written in big-endian order):
* <ul>
* <li>Version number, always 1 (32 bit)</li>
+ * <li>Number of hash functions (32 bit)</li>
* <li>Total number of words of the underlying bit array (32 bit)</li>
* <li>The words/longs (numWords * 64 bit)</li>
- * <li>Number of hash functions (32 bit)</li>
* </ul>
*/
V1(1);
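The reordering above (the number of hash functions now precedes the bit array) matches the stream-reading order used by readFrom0 later in this patch. A minimal sketch of parsing a V1 header in the new layout, assuming the bit array serializes as a 32-bit word count followed by that many longs, exactly as the format doc above describes:

    import java.io.DataInputStream;
    import java.io.IOException;
    import java.io.InputStream;

    static void dumpV1Header(InputStream in) throws IOException {
        DataInputStream dis = new DataInputStream(in);
        int version = dis.readInt();           // always 1 for V1
        int numHashFunctions = dis.readInt();  // now written before the bit array
        int numWords = dis.readInt();          // word count of the underlying bit array
        for (int i = 0; i < numWords; i++) {
            dis.readLong();                    // one 64-bit word of the bit array
        }
    }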
@@ -98,6 +98,21 @@ public abstract class BloomFilter {
public abstract boolean put(Object item);
/**
+ * A specialized variant of {@link #put(Object)} that can only be used to put a UTF-8 string.
+ */
+ public abstract boolean putString(String str);
+
+ /**
+ * A specialized variant of {@link #put(Object)} that can only be used to put a long.
+ */
+ public abstract boolean putLong(long l);
+
+ /**
+ * A specialized variant of {@link #put(Object)} that can only be used to put a byte array.
+ */
+ public abstract boolean putBinary(byte[] bytes);
+
+ /**
* Determines whether a given bloom filter is compatible with this bloom filter. For two
* bloom filters to be compatible, they must have the same bit size.
*
@@ -122,6 +137,23 @@ public abstract class BloomFilter {
public abstract boolean mightContain(Object item);
/**
+ * A specialized variant of {@link #mightContain(Object)} that can only be used to test a
+ * UTF-8 string.
+ */
+ public abstract boolean mightContainString(String str);
+
+ /**
+ * A specialized variant of {@link #mightContain(Object)} that can only be used to test a long.
+ */
+ public abstract boolean mightContainLong(long l);
+
+ /**
+ * A specialized variant of {@link #mightContain(Object)} that can only be used to test a
+ * byte array.
+ */
+ public abstract boolean mightContainBinary(byte[] bytes);
+
+ /**
* Writes out this {@link BloomFilter} to an output stream in binary format.
* It is the caller's responsibility to close the stream.
*/
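A minimal usage sketch of the new specialized variants, using the BloomFilter.create(expectedNumItems, fpp) factory that the DataFrameStatFunctions change below relies on:

    BloomFilter filter = BloomFilter.create(1000, 0.03);
    filter.putString("spark");
    filter.putLong(42L);
    filter.putBinary(new byte[] {1, 2, 3});

    filter.mightContainString("spark"); // true: definitely inserted
    filter.mightContainLong(42L);       // true: definitely inserted
    filter.mightContainLong(7L);        // usually false, true with probability ~fpp

The specialized methods let callers skip the instanceof dispatch that put(Object) and mightContain(Object) perform on every call.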
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
index 1c08d07afa..35107e0b38 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -19,10 +19,10 @@ package org.apache.spark.util.sketch;
import java.io.*;
-public class BloomFilterImpl extends BloomFilter {
+public class BloomFilterImpl extends BloomFilter implements Serializable {
- private final int numHashFunctions;
- private final BitArray bits;
+ private int numHashFunctions;
+ private BitArray bits;
BloomFilterImpl(int numHashFunctions, long numBits) {
this(new BitArray(numBits), numHashFunctions);
@@ -33,6 +33,8 @@ public class BloomFilterImpl extends BloomFilter {
this.numHashFunctions = numHashFunctions;
}
+ private BloomFilterImpl() {}
+
@Override
public boolean equals(Object other) {
if (other == this) {
@@ -63,55 +65,75 @@ public class BloomFilterImpl extends BloomFilter {
return bits.bitSize();
}
- private static long hashObjectToLong(Object item) {
+ @Override
+ public boolean put(Object item) {
if (item instanceof String) {
- try {
- byte[] bytes = ((String) item).getBytes("utf-8");
- return hashBytesToLong(bytes);
- } catch (UnsupportedEncodingException e) {
- throw new RuntimeException("Only support utf-8 string", e);
- }
+ return putString((String) item);
+ } else if (item instanceof byte[]) {
+ return putBinary((byte[]) item);
} else {
- long longValue;
-
- if (item instanceof Long) {
- longValue = (Long) item;
- } else if (item instanceof Integer) {
- longValue = ((Integer) item).longValue();
- } else if (item instanceof Short) {
- longValue = ((Short) item).longValue();
- } else if (item instanceof Byte) {
- longValue = ((Byte) item).longValue();
- } else {
- throw new IllegalArgumentException(
- "Support for " + item.getClass().getName() + " not implemented"
- );
- }
-
- int h1 = Murmur3_x86_32.hashLong(longValue, 0);
- int h2 = Murmur3_x86_32.hashLong(longValue, h1);
- return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);
+ return putLong(Utils.integralToLong(item));
}
}
- private static long hashBytesToLong(byte[] bytes) {
+ @Override
+ public boolean putString(String str) {
+ return putBinary(Utils.getBytesFromUTF8String(str));
+ }
+
+ @Override
+ public boolean putBinary(byte[] bytes) {
int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
- return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);
+
+ long bitSize = bits.bitSize();
+ boolean bitsChanged = false;
+ for (int i = 1; i <= numHashFunctions; i++) {
+ int combinedHash = h1 + (i * h2);
+ // Flip all the bits if it's negative (guaranteed positive number)
+ if (combinedHash < 0) {
+ combinedHash = ~combinedHash;
+ }
+ bitsChanged |= bits.set(combinedHash % bitSize);
+ }
+ return bitsChanged;
}
@Override
- public boolean put(Object item) {
+ public boolean mightContainString(String str) {
+ return mightContainBinary(Utils.getBytesFromUTF8String(str));
+ }
+
+ @Override
+ public boolean mightContainBinary(byte[] bytes) {
+ int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
+ int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
+
long bitSize = bits.bitSize();
+ for (int i = 1; i <= numHashFunctions; i++) {
+ int combinedHash = h1 + (i * h2);
+ // Flip all the bits if it's negative (guaranteed positive number)
+ if (combinedHash < 0) {
+ combinedHash = ~combinedHash;
+ }
+ if (!bits.get(combinedHash % bitSize)) {
+ return false;
+ }
+ }
+ return true;
+ }
- // Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash
- // values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
- // Note that `CountMinSketch` use a different strategy for long type, it hash the input long
- // element with every i to produce n hash values.
- long hash64 = hashObjectToLong(item);
- int h1 = (int) (hash64 >> 32);
- int h2 = (int) hash64;
+ @Override
+ public boolean putLong(long l) {
+ // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n
+ // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
+ // Note that `CountMinSketch` uses a different strategy: it hashes the input long element
+ // with every i to produce n hash values.
+ // TODO: the strategy of `CountMinSketch` looks more advanced; should we follow it here?
+ int h1 = Murmur3_x86_32.hashLong(l, 0);
+ int h2 = Murmur3_x86_32.hashLong(l, h1);
+ long bitSize = bits.bitSize();
boolean bitsChanged = false;
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
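All the put/mightContain pairs above derive n bit indexes from two base hashes via h1 + i * h2 (double hashing, as in the Kirsch–Mitzenmacher scheme). A worked sketch of the index computation with illustrative hash values, showing why the bit flip keeps the index non-negative:

    int h1 = -1220143252;   // e.g. Murmur3_x86_32.hashLong(l, 0); value is illustrative
    int h2 = 747714;        // e.g. Murmur3_x86_32.hashLong(l, h1); value is illustrative
    long bitSize = 64 * 5;  // bits.bitSize()
    for (int i = 1; i <= 3; i++) {
        int combinedHash = h1 + (i * h2);
        if (combinedHash < 0) {
            combinedHash = ~combinedHash;        // for x < 0, ~x == -x - 1 >= 0
        }
        long index = combinedHash % bitSize;     // in [0, bitSize) since combinedHash >= 0
    }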
@@ -125,12 +147,11 @@ public class BloomFilterImpl extends BloomFilter {
}
@Override
- public boolean mightContain(Object item) {
- long bitSize = bits.bitSize();
- long hash64 = hashObjectToLong(item);
- int h1 = (int) (hash64 >> 32);
- int h2 = (int) hash64;
+ public boolean mightContainLong(long l) {
+ int h1 = Murmur3_x86_32.hashLong(l, 0);
+ int h2 = Murmur3_x86_32.hashLong(l, h1);
+ long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
// Flip all the bits if it's negative (guaranteed positive number)
@@ -145,6 +166,17 @@ public class BloomFilterImpl extends BloomFilter {
}
@Override
+ public boolean mightContain(Object item) {
+ if (item instanceof String) {
+ return mightContainString((String) item);
+ } else if (item instanceof byte[]) {
+ return mightContainBinary((byte[]) item);
+ } else {
+ return mightContainLong(Utils.integralToLong(item));
+ }
+ }
+
+ @Override
public boolean isCompatible(BloomFilter other) {
if (other == null) {
return false;
@@ -191,11 +223,11 @@ public class BloomFilterImpl extends BloomFilter {
DataOutputStream dos = new DataOutputStream(out);
dos.writeInt(Version.V1.getVersionNumber());
- bits.writeTo(dos);
dos.writeInt(numHashFunctions);
+ bits.writeTo(dos);
}
- public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+ private void readFrom0(InputStream in) throws IOException {
DataInputStream dis = new DataInputStream(in);
int version = dis.readInt();
@@ -203,6 +235,21 @@ public class BloomFilterImpl extends BloomFilter {
throw new IOException("Unexpected Bloom filter version number (" + version + ")");
}
- return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt());
+ this.numHashFunctions = dis.readInt();
+ this.bits = BitArray.readFrom(dis);
+ }
+
+ public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+ BloomFilterImpl filter = new BloomFilterImpl();
+ filter.readFrom0(in);
+ return filter;
+ }
+
+ private void writeObject(ObjectOutputStream out) throws IOException {
+ writeTo(out);
+ }
+
+ private void readObject(ObjectInputStream in) throws IOException {
+ readFrom0(in);
}
}
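Because writeObject/readObject now delegate to the same binary format, Java serialization and the explicit stream API round-trip identically. A minimal round-trip sketch:

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;

    BloomFilter original = BloomFilter.create(1000, 0.03);
    original.putLong(42L);

    ByteArrayOutputStream out = new ByteArrayOutputStream();
    original.writeTo(out); // version, numHashFunctions, then the bit array
    BloomFilter restored =
        BloomFilterImpl.readFrom(new ByteArrayInputStream(out.toByteArray()));

    restored.mightContainLong(42L); // true: restored.equals(original) holds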
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 8cc29e4076..e49ae22906 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -40,8 +40,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
private double eps;
private double confidence;
- private CountMinSketchImpl() {
- }
+ private CountMinSketchImpl() {}
CountMinSketchImpl(int depth, int width, int seed) {
this.depth = depth;
@@ -143,23 +142,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
if (item instanceof String) {
addString((String) item, count);
} else {
- long longValue;
-
- if (item instanceof Long) {
- longValue = (Long) item;
- } else if (item instanceof Integer) {
- longValue = ((Integer) item).longValue();
- } else if (item instanceof Short) {
- longValue = ((Short) item).longValue();
- } else if (item instanceof Byte) {
- longValue = ((Byte) item).longValue();
- } else {
- throw new IllegalArgumentException(
- "Support for " + item.getClass().getName() + " not implemented"
- );
- }
-
- addLong(longValue, count);
+ addLong(Utils.integralToLong(item), count);
}
}
@@ -201,13 +184,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
}
private static int[] getHashBuckets(String key, int hashCount, int max) {
- byte[] b;
- try {
- b = key.getBytes("UTF-8");
- } catch (UnsupportedEncodingException e) {
- throw new RuntimeException(e);
- }
- return getHashBuckets(b, hashCount, max);
+ return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
}
private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
@@ -225,23 +202,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
if (item instanceof String) {
return estimateCountForStringItem((String) item);
} else {
- long longValue;
-
- if (item instanceof Long) {
- longValue = (Long) item;
- } else if (item instanceof Integer) {
- longValue = ((Integer) item).longValue();
- } else if (item instanceof Short) {
- longValue = ((Short) item).longValue();
- } else if (item instanceof Byte) {
- longValue = ((Byte) item).longValue();
- } else {
- throw new IllegalArgumentException(
- "Support for " + item.getClass().getName() + " not implemented"
- );
- }
-
- return estimateCountForLongItem(longValue);
+ return estimateCountForLongItem(Utils.integralToLong(item));
}
}
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java
new file mode 100644
index 0000000000..a6b3331303
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.io.UnsupportedEncodingException;
+
+public class Utils {
+ public static byte[] getBytesFromUTF8String(String str) {
+ try {
+ return str.getBytes("utf-8");
+ } catch (UnsupportedEncodingException e) {
+ throw new IllegalArgumentException("Only support utf-8 string", e);
+ }
+ }
+
+ public static long integralToLong(Object i) {
+ long longValue;
+
+ if (i instanceof Long) {
+ longValue = (Long) i;
+ } else if (i instanceof Integer) {
+ longValue = ((Integer) i).longValue();
+ } else if (i instanceof Short) {
+ longValue = ((Short) i).longValue();
+ } else if (i instanceof Byte) {
+ longValue = ((Byte) i).longValue();
+ } else {
+ throw new IllegalArgumentException("Unsupported data type " + i.getClass().getName());
+ }
+
+ return longValue;
+ }
+}
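A side note on getBytesFromUTF8String: the checked UnsupportedEncodingException can never actually be thrown, since every JVM is required to support UTF-8. On Java 7+ the Charset overload sidesteps the try/catch entirely; a possible alternative:

    import java.nio.charset.StandardCharsets;

    public static byte[] getBytesFromUTF8String(String str) {
        return str.getBytes(StandardCharsets.UTF_8); // no checked exception to handle
    }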
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 465b12bb59..b0b6995a22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -22,9 +22,10 @@ import java.{lang => jl, util => ju}
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.stat._
-import org.apache.spark.sql.types._
-import org.apache.spark.util.sketch.CountMinSketch
+import org.apache.spark.sql.types.{IntegralType, StringType}
+import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
/**
* :: Experimental ::
@@ -390,4 +391,75 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
}
)
}
+
+ /**
+ * Builds a Bloom filter over a specified column.
+ *
+ * @param colName name of the column over which the filter is built
+ * @param expectedNumItems expected number of items which will be put into the filter.
+ * @param fpp expected false positive probability of the filter.
+ * @since 2.0.0
+ */
+ def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
+ buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, fpp))
+ }
+
+ /**
+ * Builds a Bloom filter over a specified column.
+ *
+ * @param col the column over which the filter is built
+ * @param expectedNumItems expected number of items which will be put into the filter.
+ * @param fpp expected false positive probability of the filter.
+ * @since 2.0.0
+ */
+ def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
+ buildBloomFilter(col, BloomFilter.create(expectedNumItems, fpp))
+ }
+
+ /**
+ * Builds a Bloom filter over a specified column.
+ *
+ * @param colName name of the column over which the filter is built
+ * @param expectedNumItems expected number of items which will be put into the filter.
+ * @param numBits expected number of bits of the filter.
+ * @since 2.0.0
+ */
+ def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
+ buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, numBits))
+ }
+
+ /**
+ * Builds a Bloom filter over a specified column.
+ *
+ * @param col the column over which the filter is built
+ * @param expectedNumItems expected number of items which will be put into the filter.
+ * @param numBits expected number of bits of the filter.
+ * @since 2.0.0
+ */
+ def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
+ buildBloomFilter(col, BloomFilter.create(expectedNumItems, numBits))
+ }
+
+ private def buildBloomFilter(col: Column, zero: BloomFilter): BloomFilter = {
+ val singleCol = df.select(col)
+ val colType = singleCol.schema.head.dataType
+
+ require(colType == StringType || colType.isInstanceOf[IntegralType],
+ s"Bloom filter only supports string type and integral types, but got $colType.")
+
+ val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
+ (filter, row) =>
+ // For string type, we can get the bytes of our `UTF8String` directly and call `putBinary`
+ // instead of `putString` to avoid an unnecessary conversion.
+ filter.putBinary(row.getUTF8String(0).getBytes)
+ filter
+ } else {
+ (filter, row) =>
+ // TODO: specialize it.
+ filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue())
+ filter
+ }
+
+ singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _)
+ }
}
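buildBloomFilter folds each partition's rows into a per-partition filter (the seqOp) and then combines the partial filters with mergeInPlace, which is effectively a bitwise OR of the underlying bit arrays. A minimal sketch of the merge step in Java, assuming two filters created with identical parameters (merging requires the same bit size, per isCompatible):

    BloomFilter left = BloomFilter.create(1000, 0.03);
    left.putLong(1L);
    BloomFilter right = BloomFilter.create(1000, 0.03);
    right.putLong(2L);

    left.mergeInPlace(right);
    left.mightContainLong(1L); // true
    left.mightContainLong(2L); // true: merged in from `right`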
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 9cf94e72d3..0d4c128cb3 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -40,6 +40,7 @@ import org.apache.spark.sql.types.*;
import org.apache.spark.util.sketch.CountMinSketch;
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.types.DataTypes.*;
+import org.apache.spark.util.sketch.BloomFilter;
public class JavaDataFrameSuite {
private transient JavaSparkContext jsc;
@@ -300,6 +301,7 @@ public class JavaDataFrameSuite {
Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
}
+ @Test
public void testGenericLoad() {
DataFrame df1 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
@@ -347,4 +349,33 @@ public class JavaDataFrameSuite {
Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4);
Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3);
}
+
+ @Test
+ public void testBloomFilter() {
+ DataFrame df = context.range(1000);
+
+ BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
+ assert (filter1.expectedFpp() - 0.03 < 1e-3);
+ for (int i = 0; i < 1000; i++) {
+ assert (filter1.mightContain(i));
+ }
+
+ BloomFilter filter2 = df.stat().bloomFilter(col("id").multiply(3), 1000, 0.03);
+ assert (filter2.expectedFpp() - 0.03 < 1e-3);
+ for (int i = 0; i < 1000; i++) {
+ assert (filter2.mightContain(i * 3));
+ }
+
+ BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5);
+ assert (filter3.bitSize() == 64 * 5);
+ for (int i = 0; i < 1000; i++) {
+ assert (filter3.mightContain(i));
+ }
+
+ BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5);
+ assert (filter4.bitSize() == 64 * 5);
+ for (int i = 0; i < 1000; i++) {
+ assert (filter4.mightContain(i * 3));
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 8f3ea5a286..f01f126f76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -246,4 +246,26 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
.countMinSketch('id, depth = 10, width = 20, seed = 42)
}
}
+
+ // This test only verifies some basic requirements; more correctness tests can be found in
+ // `BloomFilterSuite` in the spark-sketch project.
+ test("Bloom filter") {
+ val df = sqlContext.range(1000)
+
+ val filter1 = df.stat.bloomFilter("id", 1000, 0.03)
+ assert(filter1.expectedFpp() - 0.03 < 1e-3)
+ assert(0.until(1000).forall(filter1.mightContain))
+
+ val filter2 = df.stat.bloomFilter($"id" * 3, 1000, 0.03)
+ assert(filter2.expectedFpp() - 0.03 < 1e-3)
+ assert(0.until(1000).forall(i => filter2.mightContain(i * 3)))
+
+ val filter3 = df.stat.bloomFilter("id", 1000, 64 * 5)
+ assert(filter3.bitSize() == 64 * 5)
+ assert(0.until(1000).forall(filter3.mightContain))
+
+ val filter4 = df.stat.bloomFilter($"id" * 3, 1000, 64 * 5)
+ assert(filter4.bitSize() == 64 * 5)
+ assert(0.until(1000).forall(i => filter4.mightContain(i * 3)))
+ }
}