diff options
author | Cheng Lian <lian@databricks.com> | 2016-01-28 12:26:03 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-01-28 12:26:03 -0800 |
commit | 415d0a859b7a76f3a866ec62ab472c4050f2a01b (patch) | |
tree | 34ffe59512387a0f1a02b282339a75b08fff4aa4 /common/sketch | |
parent | c2204436a15838f2dce44e3cfb0fe58236ef6196 (diff) | |
download | spark-415d0a859b7a76f3a866ec62ab472c4050f2a01b.tar.gz spark-415d0a859b7a76f3a866ec62ab472c4050f2a01b.tar.bz2 spark-415d0a859b7a76f3a866ec62ab472c4050f2a01b.zip |
[SPARK-12818][SQL] Specialized integral and string types for Count-min Sketch
This PR is a follow-up of #10911. It adds specialized update methods for `CountMinSketch` so that we can avoid doing internal/external row format conversion in `DataFrame.countMinSketch()`.
Author: Cheng Lian <lian@databricks.com>
Closes #10968 from liancheng/cms-specialized.
Diffstat (limited to 'common/sketch')
-rw-r--r-- | common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java | 34 | ||||
-rw-r--r-- | common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java | 35 |
2 files changed, 60 insertions, 9 deletions
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 5692e574d4..f0aac5bb00 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -115,16 +115,46 @@ abstract public class CountMinSketch { public abstract long totalCount(); /** - * Adds 1 to {@code item}. + * Increments {@code item}'s count by one. */ public abstract void add(Object item); /** - * Adds {@code count} to {@code item}. + * Increments {@code item}'s count by {@code count}. */ public abstract void add(Object item, long count); /** + * Increments {@code item}'s count by one. + */ + public abstract void addLong(long item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addLong(long item, long count); + + /** + * Increments {@code item}'s count by one. + */ + public abstract void addString(String item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addString(String item, long count); + + /** + * Increments {@code item}'s count by one. + */ + public abstract void addBinary(byte[] item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addBinary(byte[] item, long count); + + /** * Returns the estimated frequency of {@code item}. */ public abstract long estimateCount(Object item); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index e49ae22906..c0631c6778 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -25,7 +25,6 @@ import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.OutputStream; import java.io.Serializable; -import java.io.UnsupportedEncodingException; import java.util.Arrays; import java.util.Random; @@ -146,27 +145,49 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { } } - private void addString(String item, long count) { + @Override + public void addString(String item) { + addString(item, 1); + } + + @Override + public void addString(String item, long count) { + addBinary(Utils.getBytesFromUTF8String(item), count); + } + + @Override + public void addLong(long item) { + addLong(item, 1); + } + + @Override + public void addLong(long item, long count) { if (count < 0) { throw new IllegalArgumentException("Negative increments not implemented"); } - int[] buckets = getHashBuckets(item, depth, width); - for (int i = 0; i < depth; ++i) { - table[i][buckets[i]] += count; + table[i][hash(item, i)] += count; } totalCount += count; } - private void addLong(long item, long count) { + @Override + public void addBinary(byte[] item) { + addBinary(item, 1); + } + + @Override + public void addBinary(byte[] item, long count) { if (count < 0) { throw new IllegalArgumentException("Negative increments not implemented"); } + int[] buckets = getHashBuckets(item, depth, width); + for (int i = 0; i < depth; ++i) { - table[i][hash(item, i)] += count; + table[i][buckets[i]] += count; } totalCount += count; |