diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-01-27 13:29:09 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-01-27 13:29:09 -0800 |
commit | 680afabe78b77e4e63e793236453d69567d24290 (patch) | |
tree | 483b4a1e2669aefec50f6293408ee16f0e5dcdad /sql | |
parent | 32f741115bda5d7d7dbfcd9fe827ecbea7303ffa (diff) | |
download | spark-680afabe78b77e4e63e793236453d69567d24290.tar.gz spark-680afabe78b77e4e63e793236453d69567d24290.tar.bz2 spark-680afabe78b77e4e63e793236453d69567d24290.zip |
[SPARK-12938][SQL] DataFrame API for Bloom filter
This PR integrates Bloom filter from spark-sketch into DataFrame. This version resorts to RDD.aggregate for building the filter. A more performant UDAF version can be built in future follow-up PRs.
This PR also add 2 specify `put` version(`putBinary` and `putLong`) into `BloomFilter`, which makes it easier to build a Bloom filter over a `DataFrame`.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #10937 from cloud-fan/bloom-filter.
Diffstat (limited to 'sql')
3 files changed, 127 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 465b12bb59..b0b6995a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -22,9 +22,10 @@ import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.sketch.CountMinSketch +import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * :: Experimental :: @@ -390,4 +391,75 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { } ) } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, fpp)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * @since 2.0.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(col, BloomFilter.create(expectedNumItems, fpp)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, numBits)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. + * @since 2.0.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(col, BloomFilter.create(expectedNumItems, numBits)) + } + + private def buildBloomFilter(col: Column, zero: BloomFilter): BloomFilter = { + val singleCol = df.select(col) + val colType = singleCol.schema.head.dataType + + require(colType == StringType || colType.isInstanceOf[IntegralType], + s"Bloom filter only supports string type and integral types, but got $colType.") + + val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) { + (filter, row) => + // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` + // instead of `putString` to avoid unnecessary conversion. + filter.putBinary(row.getUTF8String(0).getBytes) + filter + } else { + (filter, row) => + // TODO: specialize it. + filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue()) + filter + } + + singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 9cf94e72d3..0d4c128cb3 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -40,6 +40,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.util.sketch.CountMinSketch; import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.types.DataTypes.*; +import org.apache.spark.util.sketch.BloomFilter; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -300,6 +301,7 @@ public class JavaDataFrameSuite { Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); } + @Test public void testGenericLoad() { DataFrame df1 = context.read().format("text").load( Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); @@ -347,4 +349,33 @@ public class JavaDataFrameSuite { Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4); Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3); } + + @Test + public void testBloomFilter() { + DataFrame df = context.range(1000); + + BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); + assert (filter1.expectedFpp() - 0.03 < 1e-3); + for (int i = 0; i < 1000; i++) { + assert (filter1.mightContain(i)); + } + + BloomFilter filter2 = df.stat().bloomFilter(col("id").multiply(3), 1000, 0.03); + assert (filter2.expectedFpp() - 0.03 < 1e-3); + for (int i = 0; i < 1000; i++) { + assert (filter2.mightContain(i * 3)); + } + + BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5); + assert (filter3.bitSize() == 64 * 5); + for (int i = 0; i < 1000; i++) { + assert (filter3.mightContain(i)); + } + + BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5); + assert (filter4.bitSize() == 64 * 5); + for (int i = 0; i < 1000; i++) { + assert (filter4.mightContain(i * 3)); + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 8f3ea5a286..f01f126f76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -246,4 +246,26 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { .countMinSketch('id, depth = 10, width = 20, seed = 42) } } + + // This test only verifies some basic requirements, more correctness tests can be found in + // `BloomFilterSuite` in project spark-sketch. + test("Bloom filter") { + val df = sqlContext.range(1000) + + val filter1 = df.stat.bloomFilter("id", 1000, 0.03) + assert(filter1.expectedFpp() - 0.03 < 1e-3) + assert(0.until(1000).forall(filter1.mightContain)) + + val filter2 = df.stat.bloomFilter($"id" * 3, 1000, 0.03) + assert(filter2.expectedFpp() - 0.03 < 1e-3) + assert(0.until(1000).forall(i => filter2.mightContain(i * 3))) + + val filter3 = df.stat.bloomFilter("id", 1000, 64 * 5) + assert(filter3.bitSize() == 64 * 5) + assert(0.until(1000).forall(filter3.mightContain)) + + val filter4 = df.stat.bloomFilter($"id" * 3, 1000, 64 * 5) + assert(filter4.bitSize() == 64 * 5) + assert(0.until(1000).forall(i => filter4.mightContain(i * 3))) + } } |