about | summary | refs | log | tree | commit | diff
path: root/common/sketch
diff options
context:
space:
mode:
author: Cheng Lian <lian@databricks.com> 2016-01-28 12:26:03 -0800
committer: Reynold Xin <rxin@databricks.com> 2016-01-28 12:26:03 -0800
commit: 415d0a859b7a76f3a866ec62ab472c4050f2a01b (patch)
tree: 34ffe59512387a0f1a02b282339a75b08fff4aa4 /common/sketch
parent: c2204436a15838f2dce44e3cfb0fe58236ef6196 (diff)
download: spark-415d0a859b7a76f3a866ec62ab472c4050f2a01b.tar.gz
          spark-415d0a859b7a76f3a866ec62ab472c4050f2a01b.tar.bz2
          spark-415d0a859b7a76f3a866ec62ab472c4050f2a01b.zip
[SPARK-12818][SQL] Specialized integral and string types for Count-min Sketch
This PR is a follow-up of #10911. It adds specialized update methods for `CountMinSketch` so that we can avoid doing internal/external row format conversion in `DataFrame.countMinSketch()`.

Author: Cheng Lian <lian@databricks.com>

Closes #10968 from liancheng/cms-specialized.
Diffstat (limited to 'common/sketch')
-rw-r--r-- common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java | 34
-rw-r--r-- common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java | 35
2 files changed, 60 insertions, 9 deletions
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 5692e574d4..f0aac5bb00 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -115,16 +115,46 @@ abstract public class CountMinSketch {
public abstract long totalCount();
/**
- * Adds 1 to {@code item}.
+ * Increments {@code item}'s count by one.
*/
public abstract void add(Object item);
/**
- * Adds {@code count} to {@code item}.
+ * Increments {@code item}'s count by {@code count}.
*/
public abstract void add(Object item, long count);
/**
+ * Increments {@code item}'s count by one.
+ */
+ public abstract void addLong(long item);
+
+ /**
+ * Increments {@code item}'s count by {@code count}.
+ */
+ public abstract void addLong(long item, long count);
+
+ /**
+ * Increments {@code item}'s count by one.
+ */
+ public abstract void addString(String item);
+
+ /**
+ * Increments {@code item}'s count by {@code count}.
+ */
+ public abstract void addString(String item, long count);
+
+ /**
+ * Increments {@code item}'s count by one.
+ */
+ public abstract void addBinary(byte[] item);
+
+ /**
+ * Increments {@code item}'s count by {@code count}.
+ */
+ public abstract void addBinary(byte[] item, long count);
+
+ /**
* Returns the estimated frequency of {@code item}.
*/
public abstract long estimateCount(Object item);
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index e49ae22906..c0631c6778 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -25,7 +25,6 @@ import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
-import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;
@@ -146,27 +145,49 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
}
}
- private void addString(String item, long count) {
+ @Override
+ public void addString(String item) {
+ addString(item, 1);
+ }
+
+ @Override
+ public void addString(String item, long count) {
+ addBinary(Utils.getBytesFromUTF8String(item), count);
+ }
+
+ @Override
+ public void addLong(long item) {
+ addLong(item, 1);
+ }
+
+ @Override
+ public void addLong(long item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}
- int[] buckets = getHashBuckets(item, depth, width);
-
for (int i = 0; i < depth; ++i) {
- table[i][buckets[i]] += count;
+ table[i][hash(item, i)] += count;
}
totalCount += count;
}
- private void addLong(long item, long count) {
+ @Override
+ public void addBinary(byte[] item) {
+ addBinary(item, 1);
+ }
+
+ @Override
+ public void addBinary(byte[] item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}
+ int[] buckets = getHashBuckets(item, depth, width);
+
for (int i = 0; i < depth; ++i) {
- table[i][hash(item, i)] += count;
+ table[i][buckets[i]] += count;
}
totalCount += count;