aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2016-01-26 20:12:34 -0800
committerReynold Xin <rxin@databricks.com>2016-01-26 20:12:34 -0800
commitce38a35b764397fcf561ac81de6da96579f5c13e (patch)
tree0f03dfb31f4840488fabc75d5b4edbdc7eb0d874
parente7f9199e709c46a6b5ad6b03c9ecf12cc19e3a41 (diff)
downloadspark-ce38a35b764397fcf561ac81de6da96579f5c13e.tar.gz
spark-ce38a35b764397fcf561ac81de6da96579f5c13e.tar.bz2
spark-ce38a35b764397fcf561ac81de6da96579f5c13e.zip
[SPARK-12935][SQL] DataFrame API for Count-Min Sketch
This PR integrates Count-Min Sketch from spark-sketch into DataFrame. This version resorts to `RDD.aggregate` for building the sketch. A more performant UDAF version can be built in future follow-up PRs. Author: Cheng Lian <lian@databricks.com> Closes #10911 from liancheng/cms-df-api.
-rw-r--r--common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java10
-rw-r--r--common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java26
-rw-r--r--common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java56
-rw-r--r--sql/core/pom.xml5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala81
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java28
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala36
7 files changed, 205 insertions, 37 deletions
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index 00378d5851..d392fb187a 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -47,10 +47,12 @@ public abstract class BloomFilter {
public enum Version {
/**
* {@code BloomFilter} binary format version 1 (all values written in big-endian order):
- * - Version number, always 1 (32 bit)
- * - Total number of words of the underlying bit array (32 bit)
- * - The words/longs (numWords * 64 bit)
- * - Number of hash functions (32 bit)
+ * <ul>
+ * <li>Version number, always 1 (32 bit)</li>
+ * <li>Total number of words of the underlying bit array (32 bit)</li>
+ * <li>The words/longs (numWords * 64 bit)</li>
+ * <li>Number of hash functions (32 bit)</li>
+ * </ul>
*/
V1(1);
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 00c0b1b9e2..5692e574d4 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -59,16 +59,22 @@ abstract public class CountMinSketch {
public enum Version {
/**
* {@code CountMinSketch} binary format version 1 (all values written in big-endian order):
- * - Version number, always 1 (32 bit)
- * - Total count of added items (64 bit)
- * - Depth (32 bit)
- * - Width (32 bit)
- * - Hash functions (depth * 64 bit)
- * - Count table
- * - Row 0 (width * 64 bit)
- * - Row 1 (width * 64 bit)
- * - ...
- * - Row depth - 1 (width * 64 bit)
+ * <ul>
+ * <li>Version number, always 1 (32 bit)</li>
+ * <li>Total count of added items (64 bit)</li>
+ * <li>Depth (32 bit)</li>
+ * <li>Width (32 bit)</li>
+ * <li>Hash functions (depth * 64 bit)</li>
+ * <li>
+ * Count table
+ * <ul>
+ * <li>Row 0 (width * 64 bit)</li>
+ * <li>Row 1 (width * 64 bit)</li>
+ * <li>...</li>
+ * <li>Row {@code depth - 1} (width * 64 bit)</li>
+ * </ul>
+ * </li>
+ * </ul>
*/
V1(1);
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index d08809605a..8cc29e4076 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -21,13 +21,16 @@ import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
import java.io.OutputStream;
+import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;
-class CountMinSketchImpl extends CountMinSketch {
- public static final long PRIME_MODULUS = (1L << 31) - 1;
+class CountMinSketchImpl extends CountMinSketch implements Serializable {
+ private static final long PRIME_MODULUS = (1L << 31) - 1;
private int depth;
private int width;
@@ -37,6 +40,9 @@ class CountMinSketchImpl extends CountMinSketch {
private double eps;
private double confidence;
+ private CountMinSketchImpl() {
+ }
+
CountMinSketchImpl(int depth, int width, int seed) {
this.depth = depth;
this.width = width;
@@ -55,16 +61,6 @@ class CountMinSketchImpl extends CountMinSketch {
initTablesWith(depth, width, seed);
}
- CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) {
- this.depth = depth;
- this.width = width;
- this.eps = 2.0 / width;
- this.confidence = 1 - 1 / Math.pow(2, depth);
- this.hashA = hashA;
- this.table = table;
- this.totalCount = totalCount;
- }
-
@Override
public boolean equals(Object other) {
if (other == this) {
@@ -325,27 +321,43 @@ class CountMinSketchImpl extends CountMinSketch {
}
public static CountMinSketchImpl readFrom(InputStream in) throws IOException {
+ CountMinSketchImpl sketch = new CountMinSketchImpl();
+ sketch.readFrom0(in);
+ return sketch;
+ }
+
+ private void readFrom0(InputStream in) throws IOException {
DataInputStream dis = new DataInputStream(in);
- // Ignores version number
- dis.readInt();
+ int version = dis.readInt();
+ if (version != Version.V1.getVersionNumber()) {
+ throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")");
+ }
- long totalCount = dis.readLong();
- int depth = dis.readInt();
- int width = dis.readInt();
+ this.totalCount = dis.readLong();
+ this.depth = dis.readInt();
+ this.width = dis.readInt();
+ this.eps = 2.0 / width;
+ this.confidence = 1 - 1 / Math.pow(2, depth);
- long hashA[] = new long[depth];
+ this.hashA = new long[depth];
for (int i = 0; i < depth; ++i) {
- hashA[i] = dis.readLong();
+ this.hashA[i] = dis.readLong();
}
- long table[][] = new long[depth][width];
+ this.table = new long[depth][width];
for (int i = 0; i < depth; ++i) {
for (int j = 0; j < width; ++j) {
- table[i][j] = dis.readLong();
+ this.table[i][j] = dis.readLong();
}
}
+ }
+
+ private void writeObject(ObjectOutputStream out) throws IOException {
+ this.writeTo(out);
+ }
- return new CountMinSketchImpl(depth, width, totalCount, hashA, table);
+ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+ this.readFrom0(in);
}
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 31b364f351..4bb55f6b7f 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -44,6 +44,11 @@
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
+ <artifactId>spark-sketch_2.10</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index e66aa5f947..465b12bb59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -23,6 +23,8 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.sketch.CountMinSketch
/**
* :: Experimental ::
@@ -309,4 +311,83 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
}
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
+ *
+ * @param colName name of the column over which the sketch is built
+ * @param depth depth of the sketch
+ * @param width width of the sketch
+ * @param seed random seed
+ * @return a [[CountMinSketch]] over column `colName`
+ * @since 2.0.0
+ */
+ def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = {
+ countMinSketch(Column(colName), depth, width, seed)
+ }
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
+ *
+ * @param colName name of the column over which the sketch is built
+ * @param eps relative error of the sketch
+ * @param confidence confidence of the sketch
+ * @param seed random seed
+ * @return a [[CountMinSketch]] over column `colName`
+ * @since 2.0.0
+ */
+ def countMinSketch(
+ colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
+ countMinSketch(Column(colName), eps, confidence, seed)
+ }
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
+ *
+ * @param col the column over which the sketch is built
+ * @param depth depth of the sketch
+ * @param width width of the sketch
+ * @param seed random seed
+ * @return a [[CountMinSketch]] over column `colName`
+ * @since 2.0.0
+ */
+ def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
+ countMinSketch(col, CountMinSketch.create(depth, width, seed))
+ }
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
+ *
+ * @param col the column over which the sketch is built
+ * @param eps relative error of the sketch
+ * @param confidence confidence of the sketch
+ * @param seed random seed
+ * @return a [[CountMinSketch]] over column `colName`
+ * @since 2.0.0
+ */
+ def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
+ countMinSketch(col, CountMinSketch.create(eps, confidence, seed))
+ }
+
+ private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = {
+ val singleCol = df.select(col)
+ val colType = singleCol.schema.head.dataType
+
+ require(
+ colType == StringType || colType.isInstanceOf[IntegralType],
+ s"Count-min Sketch only supports string type and integral types, " +
+ s"and does not support type $colType."
+ )
+
+ singleCol.rdd.aggregate(zero)(
+ (sketch: CountMinSketch, row: Row) => {
+ sketch.add(row.get(0))
+ sketch
+ },
+
+ (sketch1: CountMinSketch, sketch2: CountMinSketch) => {
+ sketch1.mergeInPlace(sketch2)
+ }
+ )
+ }
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index ac1607ba35..9cf94e72d3 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -35,9 +35,10 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
-import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.types.*;
+import org.apache.spark.util.sketch.CountMinSketch;
+import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.types.DataTypes.*;
public class JavaDataFrameSuite {
@@ -321,4 +322,29 @@ public class JavaDataFrameSuite {
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
}
+
+ @Test
+ public void testCountMinSketch() {
+ DataFrame df = context.range(1000);
+
+ CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
+ Assert.assertEquals(sketch1.totalCount(), 1000);
+ Assert.assertEquals(sketch1.depth(), 10);
+ Assert.assertEquals(sketch1.width(), 20);
+
+ CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42);
+ Assert.assertEquals(sketch2.totalCount(), 1000);
+ Assert.assertEquals(sketch2.depth(), 10);
+ Assert.assertEquals(sketch2.width(), 20);
+
+ CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42);
+ Assert.assertEquals(sketch3.totalCount(), 1000);
+ Assert.assertEquals(sketch3.relativeError(), 0.001, 1e-4);
+ Assert.assertEquals(sketch3.confidence(), 0.99, 5e-3);
+
+ CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42);
+ Assert.assertEquals(sketch4.totalCount(), 1000);
+ Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4);
+ Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3);
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 63ad6c439a..8f3ea5a286 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -19,8 +19,11 @@ package org.apache.spark.sql
import java.util.Random
+import org.scalatest.Matchers._
+
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.DoubleType
class DataFrameStatSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -210,4 +213,37 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
sampled.groupBy("key").count().orderBy("key"),
Seq(Row(0, 6), Row(1, 11)))
}
+
+ // This test case only verifies that `DataFrame.countMinSketch()` methods do return
+ // `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in
+ // `CountMinSketchSuite` in project spark-sketch.
+ test("countMinSketch") {
+ val df = sqlContext.range(1000)
+
+ val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42)
+ assert(sketch1.totalCount() === 1000)
+ assert(sketch1.depth() === 10)
+ assert(sketch1.width() === 20)
+
+ val sketch2 = df.stat.countMinSketch($"id", depth = 10, width = 20, seed = 42)
+ assert(sketch2.totalCount() === 1000)
+ assert(sketch2.depth() === 10)
+ assert(sketch2.width() === 20)
+
+ val sketch3 = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42)
+ assert(sketch3.totalCount() === 1000)
+ assert(sketch3.relativeError() === 0.001)
+ assert(sketch3.confidence() === 0.99 +- 5e-3)
+
+ val sketch4 = df.stat.countMinSketch($"id", eps = 0.001, confidence = 0.99, seed = 42)
+ assert(sketch4.totalCount() === 1000)
+ assert(sketch4.relativeError() === 0.001 +- 1e04)
+ assert(sketch4.confidence() === 0.99 +- 5e-3)
+
+ intercept[IllegalArgumentException] {
+ df.select('id cast DoubleType as 'id)
+ .stat
+ .countMinSketch('id, depth = 10, width = 20, seed = 42)
+ }
+ }
}