From 415d0a859b7a76f3a866ec62ab472c4050f2a01b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 28 Jan 2016 12:26:03 -0800 Subject: [SPARK-12818][SQL] Specialized integral and string types for Count-min Sketch This PR is a follow-up of #10911. It adds specialized update methods for `CountMinSketch` so that we can avoid doing internal/external row format conversion in `DataFrame.countMinSketch()`. Author: Cheng Lian Closes #10968 from liancheng/cms-specialized. --- .../apache/spark/sql/DataFrameStatFunctions.scala | 65 +++++++++++++--------- 1 file changed, 39 insertions(+), 26 deletions(-) (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index b0b6995a22..bb3cc02800 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ -import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** @@ -109,7 +109,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Null elements will be replaced by "null", and back ticks will be dropped from elements if they * exist. * - * * @param col1 The name of the first column. Distinct items will make the first item of * each row. * @param col2 The name of the second column. Distinct items will make the column names @@ -374,21 +373,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { val singleCol = df.select(col) val colType = singleCol.schema.head.dataType - require( - colType == StringType || colType.isInstanceOf[IntegralType], - s"Count-min Sketch only supports string type and integral types, " + - s"and does not support type $colType." - ) + val updater: (CountMinSketch, InternalRow) => Unit = colType match { + // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary` + // instead of `addString` to avoid unnecessary conversion. + case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes) + case ByteType => (sketch, row) => sketch.addLong(row.getByte(0)) + case ShortType => (sketch, row) => sketch.addLong(row.getShort(0)) + case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0)) + case LongType => (sketch, row) => sketch.addLong(row.getLong(0)) + case _ => + throw new IllegalArgumentException( + s"Count-min Sketch only supports string type and integral types, " + + s"and does not support type $colType." + ) + } - singleCol.rdd.aggregate(zero)( - (sketch: CountMinSketch, row: Row) => { - sketch.add(row.get(0)) + singleCol.queryExecution.toRdd.aggregate(zero)( + (sketch: CountMinSketch, row: InternalRow) => { + updater(sketch, row) sketch }, - - (sketch1: CountMinSketch, sketch2: CountMinSketch) => { - sketch1.mergeInPlace(sketch2) - } + (sketch1, sketch2) => sketch1.mergeInPlace(sketch2) ) } @@ -447,19 +452,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { require(colType == StringType || colType.isInstanceOf[IntegralType], s"Bloom filter only supports string type and integral types, but got $colType.") - val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) { - (filter, row) => - // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` - // instead of `putString` to avoid unnecessary conversion. - filter.putBinary(row.getUTF8String(0).getBytes) - filter - } else { - (filter, row) => - // TODO: specialize it. - filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue()) - filter + val updater: (BloomFilter, InternalRow) => Unit = colType match { + // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` + // instead of `putString` to avoid unnecessary conversion. + case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes) + case ByteType => (filter, row) => filter.putLong(row.getByte(0)) + case ShortType => (filter, row) => filter.putLong(row.getShort(0)) + case IntegerType => (filter, row) => filter.putLong(row.getInt(0)) + case LongType => (filter, row) => filter.putLong(row.getLong(0)) + case _ => + throw new IllegalArgumentException( + s"Bloom filter only supports string type and integral types, " + + s"and does not support type $colType." + ) } - singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _) + singleCol.queryExecution.toRdd.aggregate(zero)( + (filter: BloomFilter, row: InternalRow) => { + updater(filter, row) + filter + }, + (filter1, filter2) => filter1.mergeInPlace(filter2) + ) } } -- cgit v1.2.3