 common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java     | 34
 common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java | 35
 sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala        | 65
 3 files changed, 99 insertions(+), 35 deletions(-)
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 5692e574d4..f0aac5bb00 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -115,16 +115,46 @@ abstract public class CountMinSketch {
public abstract long totalCount();
/**
- * Adds 1 to {@code item}.
+ * Increments {@code item}'s count by one.
*/
public abstract void add(Object item);
/**
- * Adds {@code count} to {@code item}.
+ * Increments {@code item}'s count by {@code count}.
*/
public abstract void add(Object item, long count);
/**
+ * Increments {@code item}'s count by one.
+ */
+ public abstract void addLong(long item);
+
+ /**
+ * Increments {@code item}'s count by {@code count}.
+ */
+ public abstract void addLong(long item, long count);
+
+ /**
+ * Increments {@code item}'s count by one.
+ */
+ public abstract void addString(String item);
+
+ /**
+ * Increments {@code item}'s count by {@code count}.
+ */
+ public abstract void addString(String item, long count);
+
+ /**
+ * Increments {@code item}'s count by one.
+ */
+ public abstract void addBinary(byte[] item);
+
+ /**
+ * Increments {@code item}'s count by {@code count}.
+ */
+ public abstract void addBinary(byte[] item, long count);
+
+ /**
* Returns the estimated frequency of {@code item}.
*/
public abstract long estimateCount(Object item);
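
Taken together, the new overloads give callers statically typed entry points that skip the runtime type dispatch inside `add(Object)`. A minimal usage sketch in Scala (the `CountMinSketch.create(depth, width, seed)` factory already exists on this class; the concrete numbers are illustrative):

    import org.apache.spark.util.sketch.CountMinSketch

    val sketch = CountMinSketch.create(10, 1000, 42)  // depth, width, seed

    sketch.addLong(7L)                 // increment by one, no boxing
    sketch.addLong(7L, 5)              // increment by five
    sketch.addString("spark")          // hashes the UTF-8 bytes of the string
    sketch.addBinary(Array[Byte](1, 2, 3), 2)

    // Count-min sketches may overestimate but never underestimate.
    assert(sketch.estimateCount(7L) >= 6)
    assert(sketch.estimateCount("spark") >= 1)
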
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index e49ae22906..c0631c6778 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -25,7 +25,6 @@ import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
-import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;
@@ -146,27 +145,49 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
}
}
- private void addString(String item, long count) {
+ @Override
+ public void addString(String item) {
+ addString(item, 1);
+ }
+
+ @Override
+ public void addString(String item, long count) {
+ addBinary(Utils.getBytesFromUTF8String(item), count);
+ }
+
+ @Override
+ public void addLong(long item) {
+ addLong(item, 1);
+ }
+
+ @Override
+ public void addLong(long item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}
- int[] buckets = getHashBuckets(item, depth, width);
-
for (int i = 0; i < depth; ++i) {
- table[i][buckets[i]] += count;
+ table[i][hash(item, i)] += count;
}
totalCount += count;
}
- private void addLong(long item, long count) {
+ @Override
+ public void addBinary(byte[] item) {
+ addBinary(item, 1);
+ }
+
+ @Override
+ public void addBinary(byte[] item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}
+ int[] buckets = getHashBuckets(item, depth, width);
+
for (int i = 0; i < depth; ++i) {
- table[i][hash(item, i)] += count;
+ table[i][buckets[i]] += count;
}
totalCount += count;
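
The update logic is the textbook count-min recurrence: each of the `depth` rows increments exactly one of its `width` counters, chosen by a per-row hash, and the estimate is the minimum over rows. A standalone Scala toy showing the same pattern (the hash function here is hypothetical, not Spark's):

    class ToyCountMin(depth: Int, width: Int) {
      private val table = Array.ofDim[Long](depth, width)
      private var totalCount = 0L

      // Hypothetical per-row hash; Spark derives its row-specific hashes differently.
      private def bucket(item: Long, row: Int): Int =
        (((item + row) * 0x9E3779B97F4A7C15L >>> 33).toInt & Int.MaxValue) % width

      def addLong(item: Long, count: Long = 1): Unit = {
        require(count >= 0, "Negative increments not implemented")
        for (i <- 0 until depth) table(i)(bucket(item, i)) += count
        totalCount += count
      }

      // Minimum over rows: hash collisions can only inflate counters.
      def estimateCount(item: Long): Long =
        (0 until depth).map(i => table(i)(bucket(item, i))).min
    }
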
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index b0b6995a22..bb3cc02800 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.stat._
-import org.apache.spark.sql.types.{IntegralType, StringType}
+import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
/**
@@ -109,7 +109,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* Null elements will be replaced by "null", and back ticks will be dropped from elements if they
* exist.
*
- *
* @param col1 The name of the first column. Distinct items will make the first item of
* each row.
* @param col2 The name of the second column. Distinct items will make the column names
@@ -374,21 +373,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
val singleCol = df.select(col)
val colType = singleCol.schema.head.dataType
- require(
- colType == StringType || colType.isInstanceOf[IntegralType],
- s"Count-min Sketch only supports string type and integral types, " +
- s"and does not support type $colType."
- )
+ val updater: (CountMinSketch, InternalRow) => Unit = colType match {
+ // For string type, we can get the bytes of the `UTF8String` directly and call `addBinary`
+ // instead of `addString` to avoid an unnecessary conversion.
+ case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes)
+ case ByteType => (sketch, row) => sketch.addLong(row.getByte(0))
+ case ShortType => (sketch, row) => sketch.addLong(row.getShort(0))
+ case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0))
+ case LongType => (sketch, row) => sketch.addLong(row.getLong(0))
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Count-min Sketch only supports string type and integral types, " +
+ s"and does not support type $colType."
+ )
+ }
- singleCol.rdd.aggregate(zero)(
- (sketch: CountMinSketch, row: Row) => {
- sketch.add(row.get(0))
+ singleCol.queryExecution.toRdd.aggregate(zero)(
+ (sketch: CountMinSketch, row: InternalRow) => {
+ updater(sketch, row)
sketch
},
-
- (sketch1: CountMinSketch, sketch2: CountMinSketch) => {
- sketch1.mergeInPlace(sketch2)
- }
+ (sketch1, sketch2) => sketch1.mergeInPlace(sketch2)
)
}
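
On the SQL side the public API is unchanged; the rewrite only changes how each `InternalRow` feeds the sketch, so a `LongType` column now flows through `addLong` with no boxing. A usage sketch (assuming a Spark 2.x-style `SparkSession` entry point; the column name `id` is illustrative):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("cms").getOrCreate()
    val df = spark.range(0, 1000).toDF("id")   // a LongType column

    // eps, confidence, and seed are the documented countMinSketch parameters.
    val cms = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42)
    println(cms.estimateCount(42L))            // >= 1; never underestimates
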
@@ -447,19 +452,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
require(colType == StringType || colType.isInstanceOf[IntegralType],
s"Bloom filter only supports string type and integral types, but got $colType.")
- val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
- (filter, row) =>
- // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
- // instead of `putString` to avoid unnecessary conversion.
- filter.putBinary(row.getUTF8String(0).getBytes)
- filter
- } else {
- (filter, row) =>
- // TODO: specialize it.
- filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue())
- filter
+ val updater: (BloomFilter, InternalRow) => Unit = colType match {
+ // For string type, we can get the bytes of the `UTF8String` directly and call `putBinary`
+ // instead of `putString` to avoid an unnecessary conversion.
+ case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes)
+ case ByteType => (filter, row) => filter.putLong(row.getByte(0))
+ case ShortType => (filter, row) => filter.putLong(row.getShort(0))
+ case IntegerType => (filter, row) => filter.putLong(row.getInt(0))
+ case LongType => (filter, row) => filter.putLong(row.getLong(0))
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Bloom filter only supports string type and integral types, " +
+ s"and does not support type $colType."
+ )
}
- singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _)
+ singleCol.queryExecution.toRdd.aggregate(zero)(
+ (filter: BloomFilter, row: InternalRow) => {
+ updater(filter, row)
+ filter
+ },
+ (filter1, filter2) => filter1.mergeInPlace(filter2)
+ )
}
}
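
The Bloom filter path now mirrors the count-min one: a type-dispatched updater plus an explicitly typed `aggregate` over the internal-row RDD. A usage sketch under the same assumptions as above:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("bf").getOrCreate()
    val df = spark.range(0, 1000).toDF("id")

    // expectedNumItems and fpp are the documented bloomFilter parameters.
    val bf = df.stat.bloomFilter("id", expectedNumItems = 1000L, fpp = 0.03)
    assert(bf.mightContainLong(500L))   // Bloom filters have no false negatives
    println(bf.mightContainLong(-1L))   // usually false; false positives are possible
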