path: root/sql
author     Wenchen Fan <wenchen@databricks.com>    2016-01-27 13:29:09 -0800
committer  Reynold Xin <rxin@databricks.com>       2016-01-27 13:29:09 -0800
commit     680afabe78b77e4e63e793236453d69567d24290
tree       483b4a1e2669aefec50f6293408ee16f0e5dcdad /sql
parent     32f741115bda5d7d7dbfcd9fe827ecbea7303ffa
[SPARK-12938][SQL] DataFrame API for Bloom filter
This PR integrates the Bloom filter from spark-sketch into DataFrame. This version resorts to RDD.aggregate for building the filter; a more performant UDAF version can be built in future follow-up PRs. This PR also adds two specialized `put` variants (`putBinary` and `putLong`) to `BloomFilter`, which make it easier to build a Bloom filter over a `DataFrame`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10937 from cloud-fan/bloom-filter.
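A minimal usage sketch of the new API, mirroring the tests added in this patch; the `sqlContext` handle is assumed to already be in scope, as in `DataFrameStatSuite`:

    // Build a Bloom filter over the "id" column of a 1000-row DataFrame,
    // sized for ~1000 items with a 3% target false-positive probability.
    val df = sqlContext.range(1000)
    val filter = df.stat.bloomFilter("id", 1000, 0.03)

    // Every value that was put into the filter must be reported as possibly present.
    assert((0 until 1000).forall(filter.mightContain))

    // Values that were never inserted may still occasionally report true;
    // that rate is bounded (approximately) by the configured fpp.
    println(filter.mightContain(123456))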
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala   76
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java    31
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala       22
3 files changed, 127 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 465b12bb59..b0b6995a22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -22,9 +22,10 @@ import java.{lang => jl, util => ju}
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.stat._
-import org.apache.spark.sql.types._
-import org.apache.spark.util.sketch.CountMinSketch
+import org.apache.spark.sql.types.{IntegralType, StringType}
+import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
/**
* :: Experimental ::
@@ -390,4 +391,75 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
}
)
}
+
+ /**
+ * Builds a Bloom filter over a specified column.
+ *
+ * @param colName name of the column over which the filter is built
+ * @param expectedNumItems expected number of items which will be put into the filter.
+ * @param fpp expected false positive probability of the filter.
+ * @since 2.0.0
+ */
+ def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
+ buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, fpp))
+ }
+
+ /**
+ * Builds a Bloom filter over a specified column.
+ *
+ * @param col the column over which the filter is built
+ * @param expectedNumItems expected number of items which will be put into the filter.
+ * @param fpp expected false positive probability of the filter.
+ * @since 2.0.0
+ */
+ def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
+ buildBloomFilter(col, BloomFilter.create(expectedNumItems, fpp))
+ }
+
+ /**
+ * Builds a Bloom filter over a specified column.
+ *
+ * @param colName name of the column over which the filter is built
+ * @param expectedNumItems expected number of items which will be put into the filter.
+ * @param numBits expected number of bits of the filter.
+ * @since 2.0.0
+ */
+ def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
+ buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, numBits))
+ }
+
+ /**
+ * Builds a Bloom filter over a specified column.
+ *
+ * @param col the column over which the filter is built
+ * @param expectedNumItems expected number of items which will be put into the filter.
+ * @param numBits expected number of bits of the filter.
+ * @since 2.0.0
+ */
+ def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
+ buildBloomFilter(col, BloomFilter.create(expectedNumItems, numBits))
+ }
+
+ private def buildBloomFilter(col: Column, zero: BloomFilter): BloomFilter = {
+ val singleCol = df.select(col)
+ val colType = singleCol.schema.head.dataType
+
+ require(colType == StringType || colType.isInstanceOf[IntegralType],
+ s"Bloom filter only supports string type and integral types, but got $colType.")
+
+ val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
+ (filter, row) =>
+ // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
+ // instead of `putString` to avoid unnecessary conversion.
+ filter.putBinary(row.getUTF8String(0).getBytes)
+ filter
+ } else {
+ (filter, row) =>
+ // TODO: specialize it.
+ filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue())
+ filter
+ }
+
+ singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _)
+ }
}
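The `putBinary`/`putLong` calls above are the two specialized `put` variants mentioned in the commit message. A rough standalone sketch of the spark-sketch calls that `buildBloomFilter` relies on; the two "partition" filters here are illustrative stand-ins for the per-partition filters that `RDD.aggregate`'s seqOp actually produces:

    // Illustrative sketch: create, putLong, mergeInPlace and mightContain
    // all appear in the patch; the partition split below is made up.
    import org.apache.spark.util.sketch.BloomFilter

    val partitionA = BloomFilter.create(1000, 0.03)
    val partitionB = BloomFilter.create(1000, 0.03)

    // Integral columns go through putLong; string columns go through
    // putBinary with the raw UTF8String bytes to skip a String conversion.
    (0L until 500L).foreach(partitionA.putLong)
    (500L until 1000L).foreach(partitionB.putLong)

    // aggregate's combOp merges the per-partition filters with mergeInPlace.
    val merged = partitionA.mergeInPlace(partitionB)
    assert((0L until 1000L).forall(merged.mightContain))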
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 9cf94e72d3..0d4c128cb3 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -40,6 +40,7 @@ import org.apache.spark.sql.types.*;
import org.apache.spark.util.sketch.CountMinSketch;
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.types.DataTypes.*;
+import org.apache.spark.util.sketch.BloomFilter;
public class JavaDataFrameSuite {
private transient JavaSparkContext jsc;
@@ -300,6 +301,7 @@ public class JavaDataFrameSuite {
Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
}
+ @Test
public void testGenericLoad() {
DataFrame df1 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
@@ -347,4 +349,33 @@ public class JavaDataFrameSuite {
Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4);
Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3);
}
+
+ @Test
+ public void testBloomFilter() {
+ DataFrame df = context.range(1000);
+
+ BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
+ assert (filter1.expectedFpp() - 0.03 < 1e-3);
+ for (int i = 0; i < 1000; i++) {
+ assert (filter1.mightContain(i));
+ }
+
+ BloomFilter filter2 = df.stat().bloomFilter(col("id").multiply(3), 1000, 0.03);
+ assert (filter2.expectedFpp() - 0.03 < 1e-3);
+ for (int i = 0; i < 1000; i++) {
+ assert (filter2.mightContain(i * 3));
+ }
+
+ BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5);
+ assert (filter3.bitSize() == 64 * 5);
+ for (int i = 0; i < 1000; i++) {
+ assert (filter3.mightContain(i));
+ }
+
+ BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5);
+ assert (filter4.bitSize() == 64 * 5);
+ for (int i = 0; i < 1000; i++) {
+ assert (filter4.mightContain(i * 3));
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 8f3ea5a286..f01f126f76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -246,4 +246,26 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
.countMinSketch('id, depth = 10, width = 20, seed = 42)
}
}
+
+ // This test only verifies some basic requirements, more correctness tests can be found in
+ // `BloomFilterSuite` in project spark-sketch.
+ test("Bloom filter") {
+ val df = sqlContext.range(1000)
+
+ val filter1 = df.stat.bloomFilter("id", 1000, 0.03)
+ assert(filter1.expectedFpp() - 0.03 < 1e-3)
+ assert(0.until(1000).forall(filter1.mightContain))
+
+ val filter2 = df.stat.bloomFilter($"id" * 3, 1000, 0.03)
+ assert(filter2.expectedFpp() - 0.03 < 1e-3)
+ assert(0.until(1000).forall(i => filter2.mightContain(i * 3)))
+
+ val filter3 = df.stat.bloomFilter("id", 1000, 64 * 5)
+ assert(filter3.bitSize() == 64 * 5)
+ assert(0.until(1000).forall(filter3.mightContain))
+
+ val filter4 = df.stat.bloomFilter($"id" * 3, 1000, 64 * 5)
+ assert(filter4.bitSize() == 64 * 5)
+ assert(0.until(1000).forall(i => filter4.mightContain(i * 3)))
+ }
}