author     wangzhenhua <wangzhenhua@huawei.com>    2016-11-29 13:16:46 -0800
committer  Reynold Xin <rxin@databricks.com>       2016-11-29 13:16:46 -0800
commit     d57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e
tree       847dfe0de2a6ec831917709f169708695d09f95f /sql
parent     f643fe47f4889faf68da3da8d7850ee48df7c22f
[SPARK-18429][SQL] implement a new Aggregate for CountMinSketch
## What changes were proposed in this pull request?

This PR implements a new Aggregate to generate a count-min sketch, which is a wrapper of `CountMinSketch`.

## How was this patch tested?

Added test cases.

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #15877 from wzhfy/cms.
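For reference, a minimal end-to-end sketch of the new function from SQL (hypothetical table `items` with a string column `name`; assumes an existing `SparkSession` named `spark`):

```scala
import java.io.ByteArrayInputStream

import org.apache.spark.util.sketch.CountMinSketch

// count_min_sketch(col, eps, confidence, seed) aggregates the column into a
// serialized count-min sketch, returned as an array of bytes.
val row = spark.sql(
  "SELECT count_min_sketch(name, 0.01, 0.99, 42) FROM items").head()

// Deserialize the bytes back into a CountMinSketch before querying it.
val bytes = row.getAs[Array[Byte]](0)
val cms = CountMinSketch.readFrom(new ByteArrayInputStream(bytes))
val estimate = cms.estimateCount("widget") // estimated frequency of "widget"
```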
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/pom.xml                                                                                          |   5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala                     |   1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala       | 146
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala  | 320
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala                                | 189
5 files changed, 661 insertions, 0 deletions
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index f118a9a984..82a5a85317 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -62,6 +62,11 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sketch_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 2636afe620..e41f1cad93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -262,6 +262,7 @@ object FunctionRegistry {
expression[VarianceSamp]("var_samp"),
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
+ expression[CountMinSketchAgg]("count_min_sketch"),
// string functions
expression[Ascii]("ascii"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
new file mode 100644
index 0000000000..1bfae9e5a4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.sketch.CountMinSketch
+
+/**
+ * This function returns a count-min sketch of a column with the given eps, confidence and seed.
+ * A count-min sketch is a probabilistic data structure used for summarizing streams of data in
+ * sub-linear space, which is useful for equality predicates and join size estimation.
+ * The result returned by the function is an array of bytes, which should be deserialized to a
+ * `CountMinSketch` before usage.
+ *
+ * @param child child expression that can produce column value with `child.eval(inputRow)`
+ * @param epsExpression relative error, must be positive
+ * @param confidenceExpression confidence, must be positive and less than 1.0
+ * @param seedExpression random seed
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given eps,
+ confidence and seed. The result is an array of bytes, which should be deserialized to a
+ `CountMinSketch` before usage. `CountMinSketch` is useful for equality predicates and join
+ size estimation.
+ """)
+case class CountMinSketchAgg(
+ child: Expression,
+ epsExpression: Expression,
+ confidenceExpression: Expression,
+ seedExpression: Expression,
+ override val mutableAggBufferOffset: Int,
+ override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[CountMinSketch] {
+
+ def this(
+ child: Expression,
+ epsExpression: Expression,
+ confidenceExpression: Expression,
+ seedExpression: Expression) = {
+ this(child, epsExpression, confidenceExpression, seedExpression, 0, 0)
+ }
+
+ // Mark as lazy so that they are not evaluated during tree transformation.
+ private lazy val eps: Double = epsExpression.eval().asInstanceOf[Double]
+ private lazy val confidence: Double = confidenceExpression.eval().asInstanceOf[Double]
+ private lazy val seed: Int = seedExpression.eval().asInstanceOf[Int]
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else if (!epsExpression.foldable || !confidenceExpression.foldable ||
+ !seedExpression.foldable) {
+ TypeCheckFailure(
+ "The eps, confidence or seed provided must be a literal or constant foldable")
+ } else if (epsExpression.eval() == null || confidenceExpression.eval() == null ||
+ seedExpression.eval() == null) {
+ TypeCheckFailure("The eps, confidence or seed provided should not be null")
+ } else if (eps <= 0D) {
+ TypeCheckFailure(s"Relative error must be positive (current value = $eps)")
+ } else if (confidence <= 0D || confidence >= 1D) {
+ TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+
+ override def createAggregationBuffer(): CountMinSketch = {
+ CountMinSketch.create(eps, confidence, seed)
+ }
+
+ override def update(buffer: CountMinSketch, input: InternalRow): Unit = {
+ val value = child.eval(input)
+ // Ignore null values
+ if (value != null) {
+ child.dataType match {
+ // `Decimal` and `UTF8String` are internal types in Spark SQL, so we need to convert them
+ // into acceptable types for `CountMinSketch`.
+ case DecimalType() => buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal)
+ // For string type, we can get the bytes of our `UTF8String` directly, and call `addBinary`
+ // instead of `addString` to avoid unnecessary conversion.
+ case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
+ case _ => buffer.add(value)
+ }
+ }
+ }
+
+ override def merge(buffer: CountMinSketch, input: CountMinSketch): Unit = {
+ buffer.mergeInPlace(input)
+ }
+
+ override def eval(buffer: CountMinSketch): Any = serialize(buffer)
+
+ override def serialize(buffer: CountMinSketch): Array[Byte] = {
+ val out = new ByteArrayOutputStream()
+ buffer.writeTo(out)
+ out.toByteArray
+ }
+
+ override def deserialize(storageFormat: Array[Byte]): CountMinSketch = {
+ val in = new ByteArrayInputStream(storageFormat)
+ CountMinSketch.readFrom(in)
+ }
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CountMinSketchAgg =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CountMinSketchAgg =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ Seq(TypeCollection(NumericType, StringType, DateType, TimestampType, BooleanType, BinaryType),
+ DoubleType, DoubleType, IntegerType)
+ }
+
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = BinaryType
+
+ override def children: Seq[Expression] =
+ Seq(child, epsExpression, confidenceExpression, seedExpression)
+
+ override def prettyName: String = "count_min_sketch"
+}
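The aggregate is a thin wrapper over `CountMinSketch` from the spark-sketch module, so its accuracy guarantee comes directly from that class: an estimate never under-counts, and with probability at least `confidence` it over-counts the true frequency by at most `eps * N`, where `N` is the total number of items added. A minimal illustration of that guarantee using the sketch library alone (parameter values chosen purely for demonstration):

```scala
import org.apache.spark.util.sketch.CountMinSketch

val cms = CountMinSketch.create(0.001, 0.99, 42) // eps, confidence, seed
(1 to 10000).foreach(i => cms.add(i % 100))      // 100 distinct ints, 100 occurrences each

// A count-min sketch never under-estimates; it over-estimates by at most
// eps * totalCount with probability >= confidence.
val estimate = cms.estimateCount(7)
assert(estimate >= 100L)
println(s"estimate = $estimate, upper bound = ${100 + 0.001 * cms.totalCount()}")
```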
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
new file mode 100644
index 0000000000..6e08e29c04
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.io.ByteArrayInputStream
+import java.nio.charset.StandardCharsets
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Cast, GenericInternalRow, Literal}
+import org.apache.spark.sql.types.{DecimalType, _}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.sketch.CountMinSketch
+
+class CountMinSketchAggSuite extends SparkFunSuite {
+ private val childExpression = BoundReference(0, IntegerType, nullable = true)
+ private val epsOfTotalCount = 0.0001
+ private val confidence = 0.99
+ private val seed = 42
+
+ test("serialize and de-serialize") {
+ // Check serialize/deserialize on an empty sketch
+ val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence),
+ Literal(seed))
+ val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
+
+ // Check serialize/deserialize on a non-empty sketch
+ val random = new Random(31)
+ (0 until 10000).map(_ => random.nextInt(100)).foreach { value =>
+ buffer.add(value)
+ }
+ assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
+ }
+
+ def testHighLevelInterface[T: ClassTag](
+ dataType: DataType,
+ sampledItemIndices: Array[Int],
+ allItems: Array[T],
+ exactFreq: Map[Any, Long]): Any = {
+ test(s"high level interface, update, merge, eval... - $dataType") {
+ val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true),
+ Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
+ assert(!agg.nullable)
+
+ val group1 = 0 until sampledItemIndices.length / 2
+ val group1Buffer = agg.createAggregationBuffer()
+ group1.foreach { index =>
+ val input = InternalRow(allItems(sampledItemIndices(index)))
+ agg.update(group1Buffer, input)
+ }
+
+ val group2 = sampledItemIndices.length / 2 until sampledItemIndices.length
+ val group2Buffer = agg.createAggregationBuffer()
+ group2.foreach { index =>
+ val input = InternalRow(allItems(sampledItemIndices(index)))
+ agg.update(group2Buffer, input)
+ }
+
+ var mergeBuffer = agg.createAggregationBuffer()
+ agg.merge(mergeBuffer, group1Buffer)
+ agg.merge(mergeBuffer, group2Buffer)
+ checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
+
+ // Merge in a different order
+ mergeBuffer = agg.createAggregationBuffer()
+ agg.merge(mergeBuffer, group2Buffer)
+ agg.merge(mergeBuffer, group1Buffer)
+ checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
+
+ // Merge with an empty partition
+ val emptyBuffer = agg.createAggregationBuffer()
+ agg.merge(mergeBuffer, emptyBuffer)
+ checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
+ }
+ }
+
+ def testLowLevelInterface[T: ClassTag](
+ dataType: DataType,
+ sampledItemIndices: Array[Int],
+ allItems: Array[T],
+ exactFreq: Map[Any, Long]): Any = {
+ test(s"low level interface, update, merge, eval... - ${dataType.typeName}") {
+ val inputAggregationBufferOffset = 1
+ val mutableAggregationBufferOffset = 2
+
+ // Phase 1: partial mode aggregation
+ val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true),
+ Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
+ .withNewInputAggBufferOffset(inputAggregationBufferOffset)
+ .withNewMutableAggBufferOffset(mutableAggregationBufferOffset)
+
+ val mutableAggBuffer = new GenericInternalRow(
+ new Array[Any](mutableAggregationBufferOffset + 1))
+ agg.initialize(mutableAggBuffer)
+
+ sampledItemIndices.foreach { i =>
+ agg.update(mutableAggBuffer, InternalRow(allItems(i)))
+ }
+ agg.serializeAggregateBufferInPlace(mutableAggBuffer)
+
+ // Read the serialized sketch back out of the aggregation buffer
+ val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset)
+ val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized))
+
+ // Phase 2: final mode aggregation
+ // Re-initialize the aggregation buffer
+ agg.initialize(mutableAggBuffer)
+ agg.merge(mutableAggBuffer, inputAggBuffer)
+ checkResult(agg.eval(mutableAggBuffer), allItems, exactFreq)
+ }
+ }
+
+ private def checkResult[T: ClassTag](
+ result: Any,
+ data: Array[T],
+ exactFreq: Map[Any, Long]): Unit = {
+ result match {
+ case bytesData: Array[Byte] =>
+ val in = new ByteArrayInputStream(bytesData)
+ val cms = CountMinSketch.readFrom(in)
+ val probCorrect = {
+ val numErrors = data.map { i =>
+ val count = exactFreq.getOrElse(getProbeItem(i), 0L)
+ val item = i match {
+ case dec: Decimal => dec.toJavaBigDecimal
+ case str: UTF8String => str.getBytes
+ case _ => i
+ }
+ val ratio = (cms.estimateCount(item) - count).toDouble / data.length
+ if (ratio > epsOfTotalCount) 1 else 0
+ }.sum
+
+ 1D - numErrors.toDouble / data.length
+ }
+
+ assert(
+ probCorrect > confidence,
+ s"Confidence not reached: required $confidence, reached $probCorrect"
+ )
+ case _ => fail("unexpected return type")
+ }
+ }
+
+ private def getProbeItem[T: ClassTag](item: T): Any = item match {
+ // Use a string to represent the content of an array of bytes
+ case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
+ case i => identity(i)
+ }
+
+ def testItemType[T: ClassTag](dataType: DataType)(itemGenerator: Random => T): Unit = {
+ // Uses fixed seed to ensure reproducible test execution
+ val r = new Random(31)
+
+ val numAllItems = 1000000
+ val allItems = Array.fill(numAllItems)(itemGenerator(r))
+
+ val numSamples = numAllItems / 10
+ val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+
+ val exactFreq = {
+ val sampledItems = sampledItemIndices.map(allItems)
+ sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
+ }
+
+ testLowLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
+ testHighLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
+ }
+
+ testItemType[Byte](ByteType) { _.nextInt().toByte }
+
+ testItemType[Short](ShortType) { _.nextInt().toShort }
+
+ testItemType[Int](IntegerType) { _.nextInt() }
+
+ testItemType[Long](LongType) { _.nextLong() }
+
+ testItemType[UTF8String](StringType) { r => UTF8String.fromString(r.nextString(r.nextInt(20))) }
+
+ testItemType[Float](FloatType) { _.nextFloat() }
+
+ testItemType[Double](DoubleType) { _.nextDouble() }
+
+ testItemType[Decimal](new DecimalType()) { r => Decimal(r.nextDouble()) }
+
+ testItemType[Boolean](BooleanType) { _.nextBoolean() }
+
+ testItemType[Array[Byte]](BinaryType) { r =>
+ r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
+ }
+
+
+ test("fails analysis if eps, confidence or seed provided is not a literal or constant foldable") {
+ val wrongEps = new CountMinSketchAgg(
+ childExpression,
+ epsExpression = AttributeReference("a", DoubleType)(),
+ confidenceExpression = Literal(confidence),
+ seedExpression = Literal(seed))
+ val wrongConfidence = new CountMinSketchAgg(
+ childExpression,
+ epsExpression = Literal(epsOfTotalCount),
+ confidenceExpression = AttributeReference("b", DoubleType)(),
+ seedExpression = Literal(seed))
+ val wrongSeed = new CountMinSketchAgg(
+ childExpression,
+ epsExpression = Literal(epsOfTotalCount),
+ confidenceExpression = Literal(confidence),
+ seedExpression = AttributeReference("c", IntegerType)())
+
+ Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
+ assertEqual(
+ wrongAgg.checkInputDataTypes(),
+ TypeCheckFailure(
+ "The eps, confidence or seed provided must be a literal or constant foldable")
+ )
+ }
+ }
+
+ test("fails analysis if parameters are invalid") {
+ // parameters are null
+ val wrongEps = new CountMinSketchAgg(
+ childExpression,
+ epsExpression = Cast(Literal(null), DoubleType),
+ confidenceExpression = Literal(confidence),
+ seedExpression = Literal(seed))
+ val wrongConfidence = new CountMinSketchAgg(
+ childExpression,
+ epsExpression = Literal(epsOfTotalCount),
+ confidenceExpression = Cast(Literal(null), DoubleType),
+ seedExpression = Literal(seed))
+ val wrongSeed = new CountMinSketchAgg(
+ childExpression,
+ epsExpression = Literal(epsOfTotalCount),
+ confidenceExpression = Literal(confidence),
+ seedExpression = Cast(Literal(null), IntegerType))
+
+ Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
+ assertEqual(
+ wrongAgg.checkInputDataTypes(),
+ TypeCheckFailure("The eps, confidence or seed provided should not be null")
+ )
+ }
+
+ // parameters are out of the valid range
+ Seq(0.0, -1000.0).foreach { invalidEps =>
+ val invalidAgg = new CountMinSketchAgg(
+ childExpression,
+ epsExpression = Literal(invalidEps),
+ confidenceExpression = Literal(confidence),
+ seedExpression = Literal(seed))
+ assertEqual(
+ invalidAgg.checkInputDataTypes(),
+ TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")
+ )
+ }
+
+ Seq(0.0, 1.0, -2.0, 2.0).foreach { invalidConfidence =>
+ val invalidAgg = new CountMinSketchAgg(
+ childExpression,
+ epsExpression = Literal(epsOfTotalCount),
+ confidenceExpression = Literal(invalidConfidence),
+ seedExpression = Literal(seed))
+ assertEqual(
+ invalidAgg.checkInputDataTypes(),
+ TypeCheckFailure(
+ s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")
+ )
+ }
+ }
+
+ private def assertEqual[T](left: T, right: T): Unit = {
+ assert(left == right)
+ }
+
+ test("null handling") {
+ def isEqual(result: Any, other: CountMinSketch): Boolean = {
+ result match {
+ case bytesData: Array[Byte] =>
+ val in = new ByteArrayInputStream(bytesData)
+ val cms = CountMinSketch.readFrom(in)
+ cms.equals(other)
+ case _ => fail("unexpected return type")
+ }
+ }
+
+ val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence),
+ Literal(seed))
+ val emptyCms = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ val buffer = new GenericInternalRow(new Array[Any](1))
+ agg.initialize(buffer)
+ // Empty aggregation buffer
+ assert(isEqual(agg.eval(buffer), emptyCms))
+ // Input row containing only a null value
+ agg.update(buffer, InternalRow(null))
+ assert(isEqual(agg.eval(buffer), emptyCms))
+
+ // Add a row with a non-null value
+ agg.update(buffer, InternalRow(0))
+ assert(!isEqual(agg.eval(buffer), emptyCms))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
new file mode 100644
index 0000000000..4cc50604bc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.ByteArrayInputStream
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{Decimal, StringType, _}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.sketch.CountMinSketch
+
+class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext {
+
+ private val table = "count_min_sketch_table"
+
+ /** Uses fixed seed to ensure reproducible test execution */
+ private val r = new Random(42)
+ private val numAllItems = 1000
+ private val numSamples = numAllItems / 10
+
+ private val eps = 0.1D
+ private val confidence = 0.95D
+ private val seed = 11
+
+ val startDate = DateTimeUtils.fromJavaDate(Date.valueOf("1900-01-01"))
+ val endDate = DateTimeUtils.fromJavaDate(Date.valueOf("2016-01-01"))
+ val startTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("1900-01-01 00:00:00"))
+ val endTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-01-01 00:00:00"))
+
+ test(s"compute count-min sketch for multiple columns of different types") {
+ val (allBytes, sampledByteIndices, exactByteFreq) =
+ generateTestData[Byte] { _.nextInt().toByte }
+ val (allShorts, sampledShortIndices, exactShortFreq) =
+ generateTestData[Short] { _.nextInt().toShort }
+ val (allInts, sampledIntIndices, exactIntFreq) =
+ generateTestData[Int] { _.nextInt() }
+ val (allLongs, sampledLongIndices, exactLongFreq) =
+ generateTestData[Long] { _.nextLong() }
+ val (allStrings, sampledStringIndices, exactStringFreq) =
+ generateTestData[String] { r => r.nextString(r.nextInt(20)) }
+ val (allDates, sampledDateIndices, exactDateFreq) = generateTestData[Date] { r =>
+ DateTimeUtils.toJavaDate(r.nextInt(endDate - startDate) + startDate)
+ }
+ val (allTimestamps, sampledTSIndices, exactTSFreq) = generateTestData[Timestamp] { r =>
+ DateTimeUtils.toJavaTimestamp(r.nextLong() % (endTS - startTS) + startTS)
+ }
+ val (allFloats, sampledFloatIndices, exactFloatFreq) =
+ generateTestData[Float] { _.nextFloat() }
+ val (allDoubles, sampledDoubleIndices, exactDoubleFreq) =
+ generateTestData[Double] { _.nextDouble() }
+ val (allDecimals, sampledDecimalIndices, exactDecimalFreq) =
+ generateTestData[Decimal] { r => Decimal(r.nextDouble()) }
+ val (allBooleans, sampledBooleanIndices, exactBooleanFreq) =
+ generateTestData[Boolean] { _.nextBoolean() }
+ val (allBinaries, sampledBinaryIndices, exactBinaryFreq) = generateTestData[Array[Byte]] { r =>
+ r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
+ }
+
+ val data = (0 until numSamples).map { i =>
+ Row(allBytes(sampledByteIndices(i)),
+ allShorts(sampledShortIndices(i)),
+ allInts(sampledIntIndices(i)),
+ allLongs(sampledLongIndices(i)),
+ allStrings(sampledStringIndices(i)),
+ allDates(sampledDateIndices(i)),
+ allTimestamps(sampledTSIndices(i)),
+ allFloats(sampledFloatIndices(i)),
+ allDoubles(sampledDoubleIndices(i)),
+ allDecimals(sampledDecimalIndices(i)),
+ allBooleans(sampledBooleanIndices(i)),
+ allBinaries(sampledBinaryIndices(i)))
+ }
+
+ val schema = StructType(Seq(
+ StructField("c1", ByteType),
+ StructField("c2", ShortType),
+ StructField("c3", IntegerType),
+ StructField("c4", LongType),
+ StructField("c5", StringType),
+ StructField("c6", DateType),
+ StructField("c7", TimestampType),
+ StructField("c8", FloatType),
+ StructField("c9", DoubleType),
+ StructField("c10", new DecimalType()),
+ StructField("c11", BooleanType),
+ StructField("c12", BinaryType)))
+
+ withTempView(table) {
+ val rdd: RDD[Row] = spark.sparkContext.parallelize(data)
+ spark.createDataFrame(rdd, schema).createOrReplaceTempView(table)
+ val cmsSql = schema.fieldNames.map(col => s"count_min_sketch($col, $eps, $confidence, $seed)")
+ .mkString(", ")
+ val result = sql(s"SELECT $cmsSql FROM $table").head()
+ schema.indices.foreach { i =>
+ val binaryData = result.getAs[Array[Byte]](i)
+ val in = new ByteArrayInputStream(binaryData)
+ val cms = CountMinSketch.readFrom(in)
+ schema.fields(i).dataType match {
+ case ByteType => checkResult(cms, allBytes, exactByteFreq)
+ case ShortType => checkResult(cms, allShorts, exactShortFreq)
+ case IntegerType => checkResult(cms, allInts, exactIntFreq)
+ case LongType => checkResult(cms, allLongs, exactLongFreq)
+ case StringType => checkResult(cms, allStrings, exactStringFreq)
+ case DateType =>
+ checkResult(cms,
+ allDates.map(DateTimeUtils.fromJavaDate),
+ exactDateFreq.map { e =>
+ (DateTimeUtils.fromJavaDate(e._1.asInstanceOf[Date]), e._2)
+ })
+ case TimestampType =>
+ checkResult(cms,
+ allTimestamps.map(DateTimeUtils.fromJavaTimestamp),
+ exactTSFreq.map { e =>
+ (DateTimeUtils.fromJavaTimestamp(e._1.asInstanceOf[Timestamp]), e._2)
+ })
+ case FloatType => checkResult(cms, allFloats, exactFloatFreq)
+ case DoubleType => checkResult(cms, allDoubles, exactDoubleFreq)
+ case DecimalType() => checkResult(cms, allDecimals, exactDecimalFreq)
+ case BooleanType => checkResult(cms, allBooleans, exactBooleanFreq)
+ case BinaryType => checkResult(cms, allBinaries, exactBinaryFreq)
+ }
+ }
+ }
+ }
+
+ private def checkResult[T: ClassTag](
+ cms: CountMinSketch,
+ data: Array[T],
+ exactFreq: Map[Any, Long]): Unit = {
+ val probCorrect = {
+ val numErrors = data.map { i =>
+ val count = exactFreq.getOrElse(getProbeItem(i), 0L)
+ val item = i match {
+ case dec: Decimal => dec.toJavaBigDecimal
+ case str: UTF8String => str.getBytes
+ case _ => i
+ }
+ val ratio = (cms.estimateCount(item) - count).toDouble / data.length
+ if (ratio > eps) 1 else 0
+ }.sum
+
+ 1D - numErrors.toDouble / data.length
+ }
+
+ assert(
+ probCorrect > confidence,
+ s"Confidence not reached: required $confidence, reached $probCorrect"
+ )
+ }
+
+ private def getProbeItem[T: ClassTag](item: T): Any = item match {
+ // Use a string to represent the content of an array of bytes
+ case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
+ case i => identity(i)
+ }
+
+ private def generateTestData[T: ClassTag](
+ itemGenerator: Random => T): (Array[T], Array[Int], Map[Any, Long]) = {
+ val allItems = Array.fill(numAllItems)(itemGenerator(r))
+ val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+ val exactFreq = {
+ val sampledItems = sampledItemIndices.map(allItems)
+ sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
+ }
+ (allItems, sampledItemIndices, exactFreq)
+ }
+}