From ce38a35b764397fcf561ac81de6da96579f5c13e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 26 Jan 2016 20:12:34 -0800 Subject: [SPARK-12935][SQL] DataFrame API for Count-Min Sketch This PR integrates Count-Min Sketch from spark-sketch into DataFrame. This version resorts to `RDD.aggregate` for building the sketch. A more performant UDAF version can be built in future follow-up PRs. Author: Cheng Lian Closes #10911 from liancheng/cms-df-api. --- .../org/apache/spark/util/sketch/BloomFilter.java | 10 +-- .../apache/spark/util/sketch/CountMinSketch.java | 26 ++++--- .../spark/util/sketch/CountMinSketchImpl.java | 56 +++++++++------ sql/core/pom.xml | 5 ++ .../apache/spark/sql/DataFrameStatFunctions.scala | 81 ++++++++++++++++++++++ .../org/apache/spark/sql/JavaDataFrameSuite.java | 28 +++++++- .../org/apache/spark/sql/DataFrameStatSuite.scala | 36 ++++++++++ 7 files changed, 205 insertions(+), 37 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 00378d5851..d392fb187a 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -47,10 +47,12 @@ public abstract class BloomFilter { public enum Version { /** * {@code BloomFilter} binary format version 1 (all values written in big-endian order): - * - Version number, always 1 (32 bit) - * - Total number of words of the underlying bit array (32 bit) - * - The words/longs (numWords * 64 bit) - * - Number of hash functions (32 bit) + * */ V1(1); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 00c0b1b9e2..5692e574d4 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ 
b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -59,16 +59,22 @@ abstract public class CountMinSketch { public enum Version { /** * {@code CountMinSketch} binary format version 1 (all values written in big-endian order): - * - Version number, always 1 (32 bit) - * - Total count of added items (64 bit) - * - Depth (32 bit) - * - Width (32 bit) - * - Hash functions (depth * 64 bit) - * - Count table - * - Row 0 (width * 64 bit) - * - Row 1 (width * 64 bit) - * - ... - * - Row depth - 1 (width * 64 bit) + * */ V1(1); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index d08809605a..8cc29e4076 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -21,13 +21,16 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.io.OutputStream; +import java.io.Serializable; import java.io.UnsupportedEncodingException; import java.util.Arrays; import java.util.Random; -class CountMinSketchImpl extends CountMinSketch { - public static final long PRIME_MODULUS = (1L << 31) - 1; +class CountMinSketchImpl extends CountMinSketch implements Serializable { + private static final long PRIME_MODULUS = (1L << 31) - 1; private int depth; private int width; @@ -37,6 +40,9 @@ class CountMinSketchImpl extends CountMinSketch { private double eps; private double confidence; + private CountMinSketchImpl() { + } + CountMinSketchImpl(int depth, int width, int seed) { this.depth = depth; this.width = width; @@ -55,16 +61,6 @@ class CountMinSketchImpl extends CountMinSketch { initTablesWith(depth, width, seed); } - CountMinSketchImpl(int depth, int width, long 
totalCount, long hashA[], long table[][]) { - this.depth = depth; - this.width = width; - this.eps = 2.0 / width; - this.confidence = 1 - 1 / Math.pow(2, depth); - this.hashA = hashA; - this.table = table; - this.totalCount = totalCount; - } - @Override public boolean equals(Object other) { if (other == this) { @@ -325,27 +321,43 @@ class CountMinSketchImpl extends CountMinSketch { } public static CountMinSketchImpl readFrom(InputStream in) throws IOException { + CountMinSketchImpl sketch = new CountMinSketchImpl(); + sketch.readFrom0(in); + return sketch; + } + + private void readFrom0(InputStream in) throws IOException { DataInputStream dis = new DataInputStream(in); - // Ignores version number - dis.readInt(); + int version = dis.readInt(); + if (version != Version.V1.getVersionNumber()) { + throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")"); + } - long totalCount = dis.readLong(); - int depth = dis.readInt(); - int width = dis.readInt(); + this.totalCount = dis.readLong(); + this.depth = dis.readInt(); + this.width = dis.readInt(); + this.eps = 2.0 / width; + this.confidence = 1 - 1 / Math.pow(2, depth); - long hashA[] = new long[depth]; + this.hashA = new long[depth]; for (int i = 0; i < depth; ++i) { - hashA[i] = dis.readLong(); + this.hashA[i] = dis.readLong(); } - long table[][] = new long[depth][width]; + this.table = new long[depth][width]; for (int i = 0; i < depth; ++i) { for (int j = 0; j < width; ++j) { - table[i][j] = dis.readLong(); + this.table[i][j] = dis.readLong(); } } + } + + private void writeObject(ObjectOutputStream out) throws IOException { + this.writeTo(out); + } - return new CountMinSketchImpl(depth, width, totalCount, hashA, table); + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + this.readFrom0(in); } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 31b364f351..4bb55f6b7f 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -42,6 +42,11 @@ 1.5.6 
jar + + org.apache.spark + spark-sketch_2.10 + ${project.version} + org.apache.spark spark-core_${scala.binary.version} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index e66aa5f947..465b12bb59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -23,6 +23,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.sketch.CountMinSketch /** * :: Experimental :: @@ -309,4 +311,83 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param colName name of the column over which the sketch is built + * @param depth depth of the sketch + * @param width width of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = { + countMinSketch(Column(colName), depth, width, seed) + } + + /** + * Builds a Count-min Sketch over a specified column. 
+ * + * @param colName name of the column over which the sketch is built + * @param eps relative error of the sketch + * @param confidence confidence of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch( + colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + countMinSketch(Column(colName), eps, confidence, seed) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param col the column over which the sketch is built + * @param depth depth of the sketch + * @param width width of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `col` + * @since 2.0.0 + */ + def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { + countMinSketch(col, CountMinSketch.create(depth, width, seed)) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param col the column over which the sketch is built + * @param eps relative error of the sketch + * @param confidence confidence of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `col` + * @since 2.0.0 + */ + def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + countMinSketch(col, CountMinSketch.create(eps, confidence, seed)) + } + + private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = { + val singleCol = df.select(col) + val colType = singleCol.schema.head.dataType + + require( + colType == StringType || colType.isInstanceOf[IntegralType], + s"Count-min Sketch only supports string type and integral types, " + + s"and does not support type $colType." 
+ ) + + singleCol.rdd.aggregate(zero)( + (sketch: CountMinSketch, row: Row) => { + sketch.add(row.get(0)) + sketch + }, + + (sketch1: CountMinSketch, sketch2: CountMinSketch) => { + sketch1.mergeInPlace(sketch2) + } + ) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index ac1607ba35..9cf94e72d3 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -35,9 +35,10 @@ import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; -import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; +import org.apache.spark.util.sketch.CountMinSketch; +import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.types.DataTypes.*; public class JavaDataFrameSuite { @@ -321,4 +322,29 @@ public class JavaDataFrameSuite { Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); Assert.assertEquals(5L, df2.count()); } + + @Test + public void testCountMinSketch() { + DataFrame df = context.range(1000); + + CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); + Assert.assertEquals(sketch1.totalCount(), 1000); + Assert.assertEquals(sketch1.depth(), 10); + Assert.assertEquals(sketch1.width(), 20); + + CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42); + Assert.assertEquals(sketch2.totalCount(), 1000); + Assert.assertEquals(sketch2.depth(), 10); + Assert.assertEquals(sketch2.width(), 20); + + CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42); + Assert.assertEquals(sketch3.totalCount(), 1000); + Assert.assertEquals(sketch3.relativeError(), 0.001, 1e-4); + 
Assert.assertEquals(sketch3.confidence(), 0.99, 5e-3); + + CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42); + Assert.assertEquals(sketch4.totalCount(), 1000); + Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4); + Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 63ad6c439a..8f3ea5a286 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql import java.util.Random +import org.scalatest.Matchers._ + import org.apache.spark.sql.functions.col import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.DoubleType class DataFrameStatSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -210,4 +213,37 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { sampled.groupBy("key").count().orderBy("key"), Seq(Row(0, 6), Row(1, 11))) } + + // This test case only verifies that `DataFrame.countMinSketch()` methods do return + // `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in + // `CountMinSketchSuite` in project spark-sketch. 
+ test("countMinSketch") { + val df = sqlContext.range(1000) + + val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42) + assert(sketch1.totalCount() === 1000) + assert(sketch1.depth() === 10) + assert(sketch1.width() === 20) + + val sketch2 = df.stat.countMinSketch($"id", depth = 10, width = 20, seed = 42) + assert(sketch2.totalCount() === 1000) + assert(sketch2.depth() === 10) + assert(sketch2.width() === 20) + + val sketch3 = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42) + assert(sketch3.totalCount() === 1000) + assert(sketch3.relativeError() === 0.001) + assert(sketch3.confidence() === 0.99 +- 5e-3) + + val sketch4 = df.stat.countMinSketch($"id", eps = 0.001, confidence = 0.99, seed = 42) + assert(sketch4.totalCount() === 1000) + assert(sketch4.relativeError() === 0.001 +- 1e-4) + assert(sketch4.confidence() === 0.99 +- 5e-3) + + intercept[IllegalArgumentException] { + df.select('id cast DoubleType as 'id) + .stat + .countMinSketch('id, depth = 10, width = 20, seed = 42) + } + } } -- cgit v1.2.3