From d57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e Mon Sep 17 00:00:00 2001
From: wangzhenhua
Date: Tue, 29 Nov 2016 13:16:46 -0800
Subject: [SPARK-18429][SQL] implement a new Aggregate for CountMinSketch

## What changes were proposed in this pull request?

This PR implements a new aggregate function that generates a count-min sketch; the function is a wrapper around CountMinSketch.

## How was this patch tested?

Added test cases.

Author: wangzhenhua

Closes #15877 from wzhfy/cms.
---
 sql/catalyst/pom.xml                               |   5 +
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../expressions/aggregate/CountMinSketchAgg.scala  | 146 ++++++++++
 .../aggregate/CountMinSketchAggSuite.scala         | 320 +++++++++++++++++++++
 .../spark/sql/CountMinSketchAggQuerySuite.scala    | 189 ++++++++++++
 5 files changed, 661 insertions(+)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala

diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index f118a9a984..82a5a85317 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -61,6 +61,11 @@
       <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sketch_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
     <dependency>
       <groupId>org.scalacheck</groupId>
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 2636afe620..e41f1cad93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -262,6 +262,7 @@ object FunctionRegistry {
     expression[VarianceSamp]("var_samp"),
     expression[CollectList]("collect_list"),
     expression[CollectSet]("collect_set"),
+    expression[CountMinSketchAgg]("count_min_sketch"),
 
     // string functions
     expression[Ascii]("ascii"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
new file mode 100644
index 0000000000..1bfae9e5a4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.sketch.CountMinSketch
+
+/**
+ * This function returns a count-min sketch of a column with the given eps, confidence and seed.
+ * A count-min sketch is a probabilistic data structure used for summarizing streams of data in
+ * sub-linear space, which is useful for equality predicates and join size estimation.
+ * The result returned by the function is an array of bytes, which should be deserialized to a
+ * `CountMinSketch` before usage.
+ *
+ * @param child child expression that can produce a column value with `child.eval(inputRow)`
+ * @param epsExpression relative error, must be positive
+ * @param confidenceExpression confidence, must be positive and less than 1.0
+ * @param seedExpression random seed
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given eps,
+      confidence and seed. The result is an array of bytes, which should be deserialized to a
+      `CountMinSketch` before usage. `CountMinSketch` is useful for equality predicates and join
+      size estimation.
+  """)
+case class CountMinSketchAgg(
+    child: Expression,
+    epsExpression: Expression,
+    confidenceExpression: Expression,
+    seedExpression: Expression,
+    override val mutableAggBufferOffset: Int,
+    override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[CountMinSketch] {
+
+  def this(
+      child: Expression,
+      epsExpression: Expression,
+      confidenceExpression: Expression,
+      seedExpression: Expression) = {
+    this(child, epsExpression, confidenceExpression, seedExpression, 0, 0)
+  }
+
+  // Mark as lazy so that they are not evaluated during tree transformation.
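+  // checkInputDataTypes() verifies that the parameter expressions are foldable and non-null
+  // before any of these values is forced.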
+  private lazy val eps: Double = epsExpression.eval().asInstanceOf[Double]
+  private lazy val confidence: Double = confidenceExpression.eval().asInstanceOf[Double]
+  private lazy val seed: Int = seedExpression.eval().asInstanceOf[Int]
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else if (!epsExpression.foldable || !confidenceExpression.foldable ||
+        !seedExpression.foldable) {
+      TypeCheckFailure(
+        "The eps, confidence or seed provided must be a literal or constant foldable")
+    } else if (epsExpression.eval() == null || confidenceExpression.eval() == null ||
+        seedExpression.eval() == null) {
+      TypeCheckFailure("The eps, confidence or seed provided should not be null")
+    } else if (eps <= 0D) {
+      TypeCheckFailure(s"Relative error must be positive (current value = $eps)")
+    } else if (confidence <= 0D || confidence >= 1D) {
+      TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)")
+    } else {
+      TypeCheckSuccess
+    }
+  }
+
+  override def createAggregationBuffer(): CountMinSketch = {
+    CountMinSketch.create(eps, confidence, seed)
+  }
+
+  override def update(buffer: CountMinSketch, input: InternalRow): Unit = {
+    val value = child.eval(input)
+    // Ignore null values.
+    if (value != null) {
+      child.dataType match {
+        // `Decimal` and `UTF8String` are internal types in Spark SQL; we need to convert them
+        // into types that `CountMinSketch` accepts.
+        case DecimalType() => buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal)
+        // For string type, we can get the bytes of our `UTF8String` directly, and call
+        // `addBinary` instead of `addString` to avoid an unnecessary conversion.
+        case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
+        case _ => buffer.add(value)
+      }
+    }
+  }
+
+  override def merge(buffer: CountMinSketch, input: CountMinSketch): Unit = {
+    buffer.mergeInPlace(input)
+  }
+
+  override def eval(buffer: CountMinSketch): Any = serialize(buffer)
+
+  override def serialize(buffer: CountMinSketch): Array[Byte] = {
+    val out = new ByteArrayOutputStream()
+    buffer.writeTo(out)
+    out.toByteArray
+  }
+
+  override def deserialize(storageFormat: Array[Byte]): CountMinSketch = {
+    val in = new ByteArrayInputStream(storageFormat)
+    CountMinSketch.readFrom(in)
+  }
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CountMinSketchAgg =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CountMinSketchAgg =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    Seq(TypeCollection(NumericType, StringType, DateType, TimestampType, BooleanType, BinaryType),
+      DoubleType, DoubleType, IntegerType)
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = BinaryType
+
+  override def children: Seq[Expression] =
+    Seq(child, epsExpression, confidenceExpression, seedExpression)
+
+  override def prettyName: String = "count_min_sketch"
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
new file mode 100644
index 0000000000..6e08e29c04
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.io.ByteArrayInputStream
+import java.nio.charset.StandardCharsets
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Cast, GenericInternalRow, Literal}
+import org.apache.spark.sql.types.{DecimalType, _}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.sketch.CountMinSketch
+
+class CountMinSketchAggSuite extends SparkFunSuite {
+  private val childExpression = BoundReference(0, IntegerType, nullable = true)
+  private val epsOfTotalCount = 0.0001
+  private val confidence = 0.99
+  private val seed = 42
+
+  test("serialize and deserialize") {
+    // Check serializing and deserializing an empty sketch
+    val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence),
+      Literal(seed))
+    val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
+
+    // Check serializing and deserializing a non-empty sketch
+    val random = new Random(31)
+    (0 until 10000).map(_ => random.nextInt(100)).foreach { value =>
+      buffer.add(value)
+    }
+    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
+  }
+
+  def testHighLevelInterface[T: ClassTag](
+      dataType: DataType,
+      sampledItemIndices: Array[Int],
+      allItems: Array[T],
+      exactFreq: Map[Any, Long]): Any = {
+    test(s"high level interface, update, merge, eval... - $dataType") {
+      val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true),
+        Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
+      assert(!agg.nullable)
+
+      val group1 = 0 until sampledItemIndices.length / 2
+      val group1Buffer = agg.createAggregationBuffer()
+      group1.foreach { index =>
+        val input = InternalRow(allItems(sampledItemIndices(index)))
+        agg.update(group1Buffer, input)
+      }
+
+      val group2 = sampledItemIndices.length / 2 until sampledItemIndices.length
+      val group2Buffer = agg.createAggregationBuffer()
+      group2.foreach { index =>
+        val input = InternalRow(allItems(sampledItemIndices(index)))
+        agg.update(group2Buffer, input)
+      }
+
+      var mergeBuffer = agg.createAggregationBuffer()
+      agg.merge(mergeBuffer, group1Buffer)
+      agg.merge(mergeBuffer, group2Buffer)
+      checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
+
+      // Merge in a different order
+      mergeBuffer = agg.createAggregationBuffer()
+      agg.merge(mergeBuffer, group2Buffer)
+      agg.merge(mergeBuffer, group1Buffer)
+      checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
+
+      // Merge with an empty partition
+      val emptyBuffer = agg.createAggregationBuffer()
+      agg.merge(mergeBuffer, emptyBuffer)
+      checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
+    }
+  }
+
+  def testLowLevelInterface[T: ClassTag](
+      dataType: DataType,
+      sampledItemIndices: Array[Int],
+      allItems: Array[T],
+      exactFreq: Map[Any, Long]): Any = {
+    test(s"low level interface, update, merge, eval... - ${dataType.typeName}") {
+      val inputAggregationBufferOffset = 1
+      val mutableAggregationBufferOffset = 2
+
+      // Phase 1: partial mode aggregation
+      val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true),
+        Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
+        .withNewInputAggBufferOffset(inputAggregationBufferOffset)
+        .withNewMutableAggBufferOffset(mutableAggregationBufferOffset)
+
+      val mutableAggBuffer = new GenericInternalRow(
+        new Array[Any](mutableAggregationBufferOffset + 1))
+      agg.initialize(mutableAggBuffer)
+
+      sampledItemIndices.foreach { i =>
+        agg.update(mutableAggBuffer, InternalRow(allItems(i)))
+      }
+      agg.serializeAggregateBufferInPlace(mutableAggBuffer)
+
+      // Serialize the aggregation buffer
+      val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset)
+      val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized))
+
+      // Phase 2: final mode aggregation
+      // Re-initialize the aggregation buffer
+      agg.initialize(mutableAggBuffer)
+      agg.merge(mutableAggBuffer, inputAggBuffer)
+      checkResult(agg.eval(mutableAggBuffer), allItems, exactFreq)
+    }
+  }
+
+  private def checkResult[T: ClassTag](
+      result: Any,
+      data: Array[T],
+      exactFreq: Map[Any, Long]): Unit = {
+    result match {
+      case bytesData: Array[Byte] =>
+        val in = new ByteArrayInputStream(bytesData)
+        val cms = CountMinSketch.readFrom(in)
+        val probCorrect = {
+          val numErrors = data.map { i =>
+            val count = exactFreq.getOrElse(getProbeItem(i), 0L)
+            val item = i match {
+              case dec: Decimal => dec.toJavaBigDecimal
+              case str: UTF8String => str.getBytes
+              case _ => i
+            }
+            val ratio = (cms.estimateCount(item) - count).toDouble / data.length
+            if (ratio > epsOfTotalCount) 1 else 0
+          }.sum
+
+          1D - numErrors.toDouble / data.length
+        }
+
+        assert(
+          probCorrect > confidence,
+          s"Confidence not reached: required $confidence, reached $probCorrect"
+        )
+      case _ => fail("unexpected return type")
+    }
+  }
+
+  private def getProbeItem[T: ClassTag](item: T): Any = item match {
+    // Use a string to represent the content of an array of bytes
+    case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
+    case i => i
+  }
+
+  def testItemType[T: ClassTag](dataType: DataType)(itemGenerator: Random => T): Unit = {
+    // Use a fixed seed to ensure reproducible test execution
+    val r = new Random(31)
+
+    val numAllItems = 1000000
+    val allItems = Array.fill(numAllItems)(itemGenerator(r))
+
+    val numSamples = numAllItems / 10
+    val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+
+    val exactFreq = {
+      val sampledItems = sampledItemIndices.map(allItems)
+      sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
+    }
+
+    testLowLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
+    testHighLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
+  }
+
+  testItemType[Byte](ByteType) { _.nextInt().toByte }
+
+  testItemType[Short](ShortType) { _.nextInt().toShort }
+
+  testItemType[Int](IntegerType) { _.nextInt() }
+
+  testItemType[Long](LongType) { _.nextLong() }
+
+  testItemType[UTF8String](StringType) { r => UTF8String.fromString(r.nextString(r.nextInt(20))) }
+
+  testItemType[Float](FloatType) { _.nextFloat() }
+
+  testItemType[Double](DoubleType) { _.nextDouble() }
+
+  testItemType[Decimal](new DecimalType()) { r => Decimal(r.nextDouble()) }
+
+  testItemType[Boolean](BooleanType) { _.nextBoolean() }
+
+  testItemType[Array[Byte]](BinaryType) { r =>
+    r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
+  }
+
+
+  test("fails analysis if eps, confidence or seed provided is not a literal or constant foldable") {
+    val wrongEps = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = AttributeReference("a", DoubleType)(),
+      confidenceExpression = Literal(confidence),
+      seedExpression = Literal(seed))
+    val wrongConfidence = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = Literal(epsOfTotalCount),
+      confidenceExpression = AttributeReference("b", DoubleType)(),
+      seedExpression = Literal(seed))
+    val wrongSeed = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = Literal(epsOfTotalCount),
+      confidenceExpression = Literal(confidence),
+      seedExpression = AttributeReference("c", IntegerType)())
+
+    Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
+      assertEqual(
+        wrongAgg.checkInputDataTypes(),
+        TypeCheckFailure(
+          "The eps, confidence or seed provided must be a literal or constant foldable")
+      )
+    }
+  }
+
+  test("fails analysis if parameters are invalid") {
+    // parameters are null
+    val wrongEps = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = Cast(Literal(null), DoubleType),
+      confidenceExpression = Literal(confidence),
+      seedExpression = Literal(seed))
+    val wrongConfidence = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = Literal(epsOfTotalCount),
+      confidenceExpression = Cast(Literal(null), DoubleType),
+      seedExpression = Literal(seed))
+    val wrongSeed = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = Literal(epsOfTotalCount),
+      confidenceExpression = Literal(confidence),
+      seedExpression = Cast(Literal(null), IntegerType))
+
+    Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
+      assertEqual(
+        wrongAgg.checkInputDataTypes(),
+        TypeCheckFailure("The eps, confidence or seed provided should not be null")
+      )
+    }
+
+    // parameters are out of the valid range
+    Seq(0.0, -1000.0).foreach { invalidEps =>
+      val invalidAgg = new CountMinSketchAgg(
+        childExpression,
+        epsExpression = Literal(invalidEps),
+        confidenceExpression = Literal(confidence),
+        seedExpression = Literal(seed))
+      assertEqual(
+        invalidAgg.checkInputDataTypes(),
+        TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")
+      )
+    }
+
+    Seq(0.0, 1.0, -2.0, 2.0).foreach { invalidConfidence =>
+      val invalidAgg = new CountMinSketchAgg(
+        childExpression,
+        epsExpression = Literal(epsOfTotalCount),
+        confidenceExpression = Literal(invalidConfidence),
+        seedExpression = Literal(seed))
+      assertEqual(
+        invalidAgg.checkInputDataTypes(),
+        TypeCheckFailure(
+          s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")
+      )
+    }
+  }
+
+  private def assertEqual[T](left: T, right: T): Unit = {
+    assert(left == right)
+  }
+
+  test("null handling") {
+    def isEqual(result: Any, other: CountMinSketch): Boolean = {
+      result match {
+        case bytesData: Array[Byte] =>
+          val in = new ByteArrayInputStream(bytesData)
+          val cms = CountMinSketch.readFrom(in)
+          cms.equals(other)
+        case _ => fail("unexpected return type")
+      }
+    }
+
+    val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence),
+      Literal(seed))
+    val emptyCms = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+    val buffer = new GenericInternalRow(new Array[Any](1))
+    agg.initialize(buffer)
+    // Empty aggregation buffer
+    assert(isEqual(agg.eval(buffer), emptyCms))
+    // Input row with a null value
+    agg.update(buffer, InternalRow(null))
+    assert(isEqual(agg.eval(buffer), emptyCms))
+
+    // Add a row with a non-null value
+    agg.update(buffer, InternalRow(0))
+    assert(!isEqual(agg.eval(buffer), emptyCms))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
new file mode 100644
index 0000000000..4cc50604bc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.ByteArrayInputStream
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{Decimal, StringType, _}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.sketch.CountMinSketch
+
+class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext {
+
+  private val table = "count_min_sketch_table"
+
+  /** Uses a fixed seed to ensure reproducible test execution */
+  private val r = new Random(42)
+  private val numAllItems = 1000
+  private val numSamples = numAllItems / 10
+
+  private val eps = 0.1D
+  private val confidence = 0.95D
+  private val seed = 11
+
+  val startDate = DateTimeUtils.fromJavaDate(Date.valueOf("1900-01-01"))
+  val endDate = DateTimeUtils.fromJavaDate(Date.valueOf("2016-01-01"))
+  val startTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("1900-01-01 00:00:00"))
+  val endTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-01-01 00:00:00"))
+
+  test("compute count-min sketch for multiple columns of different types") {
+    val (allBytes, sampledByteIndices, exactByteFreq) =
+      generateTestData[Byte] { _.nextInt().toByte }
+    val (allShorts, sampledShortIndices, exactShortFreq) =
+      generateTestData[Short] { _.nextInt().toShort }
+    val (allInts, sampledIntIndices, exactIntFreq) =
+      generateTestData[Int] { _.nextInt() }
+    val (allLongs, sampledLongIndices, exactLongFreq) =
+      generateTestData[Long] { _.nextLong() }
+    val (allStrings, sampledStringIndices, exactStringFreq) =
+      generateTestData[String] { r => r.nextString(r.nextInt(20)) }
+    val (allDates, sampledDateIndices, exactDateFreq) = generateTestData[Date] { r =>
+      DateTimeUtils.toJavaDate(r.nextInt(endDate - startDate) + startDate)
+    }
+    val (allTimestamps, sampledTSIndices, exactTSFreq) = generateTestData[Timestamp] { r =>
+      DateTimeUtils.toJavaTimestamp(r.nextLong() % (endTS - startTS) + startTS)
+    }
+    val (allFloats, sampledFloatIndices, exactFloatFreq) =
+      generateTestData[Float] { _.nextFloat() }
+    val (allDoubles, sampledDoubleIndices, exactDoubleFreq) =
+      generateTestData[Double] { _.nextDouble() }
+    val (allDecimals, sampledDecimalIndices, exactDecimalFreq) =
+      generateTestData[Decimal] { r => Decimal(r.nextDouble()) }
+    val (allBooleans, sampledBooleanIndices, exactBooleanFreq) =
+      generateTestData[Boolean] { _.nextBoolean() }
+    val (allBinaries, sampledBinaryIndices, exactBinaryFreq) = generateTestData[Array[Byte]] { r =>
+      r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
+    }
+
+    val data = (0 until numSamples).map { i =>
+      Row(allBytes(sampledByteIndices(i)),
+        allShorts(sampledShortIndices(i)),
+        allInts(sampledIntIndices(i)),
+        allLongs(sampledLongIndices(i)),
+        allStrings(sampledStringIndices(i)),
+        allDates(sampledDateIndices(i)),
+        allTimestamps(sampledTSIndices(i)),
+        allFloats(sampledFloatIndices(i)),
+        allDoubles(sampledDoubleIndices(i)),
+        allDecimals(sampledDecimalIndices(i)),
+        allBooleans(sampledBooleanIndices(i)),
+        allBinaries(sampledBinaryIndices(i)))
+    }
+
+    val schema = StructType(Seq(
+      StructField("c1", ByteType),
+      StructField("c2", ShortType),
+      StructField("c3", IntegerType),
+      StructField("c4", LongType),
+      StructField("c5", StringType),
+      StructField("c6", DateType),
+      StructField("c7", TimestampType),
+      StructField("c8", FloatType),
+      StructField("c9", DoubleType),
+      StructField("c10", new DecimalType()),
+      StructField("c11", BooleanType),
+      StructField("c12", BinaryType)))
+
+    withTempView(table) {
+      val rdd: RDD[Row] = spark.sparkContext.parallelize(data)
+      spark.createDataFrame(rdd, schema).createOrReplaceTempView(table)
+      val cmsSql = schema.fieldNames.map(col => s"count_min_sketch($col, $eps, $confidence, $seed)")
+        .mkString(", ")
+      val result = sql(s"SELECT $cmsSql FROM $table").head()
+      schema.indices.foreach { i =>
+        val binaryData = result.getAs[Array[Byte]](i)
+        val in = new ByteArrayInputStream(binaryData)
+        val cms = CountMinSketch.readFrom(in)
+        schema.fields(i).dataType match {
+          case ByteType => checkResult(cms, allBytes, exactByteFreq)
+          case ShortType => checkResult(cms, allShorts, exactShortFreq)
+          case IntegerType => checkResult(cms, allInts, exactIntFreq)
+          case LongType => checkResult(cms, allLongs, exactLongFreq)
+          case StringType => checkResult(cms, allStrings, exactStringFreq)
+          case DateType =>
+            checkResult(cms,
+              allDates.map(DateTimeUtils.fromJavaDate),
+              exactDateFreq.map { e =>
+                (DateTimeUtils.fromJavaDate(e._1.asInstanceOf[Date]), e._2)
+              })
+          case TimestampType =>
+            checkResult(cms,
+              allTimestamps.map(DateTimeUtils.fromJavaTimestamp),
+              exactTSFreq.map { e =>
+                (DateTimeUtils.fromJavaTimestamp(e._1.asInstanceOf[Timestamp]), e._2)
+              })
+          case FloatType => checkResult(cms, allFloats, exactFloatFreq)
+          case DoubleType => checkResult(cms, allDoubles, exactDoubleFreq)
+          case DecimalType() => checkResult(cms, allDecimals, exactDecimalFreq)
+          case BooleanType => checkResult(cms, allBooleans, exactBooleanFreq)
+          case BinaryType => checkResult(cms, allBinaries, exactBinaryFreq)
+        }
+      }
+    }
+  }
+
+  private def checkResult[T: ClassTag](
+      cms: CountMinSketch,
+      data: Array[T],
+      exactFreq: Map[Any, Long]): Unit = {
+    val probCorrect = {
+      val numErrors = data.map { i =>
+        val count = exactFreq.getOrElse(getProbeItem(i), 0L)
+        val item = i match {
+          case dec: Decimal => dec.toJavaBigDecimal
+          case str: UTF8String => str.getBytes
+          case _ => i
+        }
+        val ratio = (cms.estimateCount(item) - count).toDouble / data.length
+        if (ratio > eps) 1 else 0
+      }.sum
+
+      1D - numErrors.toDouble / data.length
+    }
+
+    assert(
+      probCorrect > confidence,
+      s"Confidence not reached: required $confidence, reached $probCorrect"
+    )
+  }
+
+  private def getProbeItem[T: ClassTag](item: T): Any = item match {
+    // Use a string to represent the content of an array of bytes
+    case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
+    case i => i
+  }
+
+  private def generateTestData[T: ClassTag](
+      itemGenerator: Random => T): (Array[T], Array[Int], Map[Any, Long]) = {
+    val allItems = Array.fill(numAllItems)(itemGenerator(r))
+    val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+    val exactFreq = {
+      val sampledItems = sampledItemIndices.map(allItems)
+      sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
+    }
+    (allItems, sampledItemIndices, exactFreq)
+  }
+}
--
cgit v1.2.3
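
For reference, a minimal usage sketch of the new function. This is illustrative only: it mirrors the flow of CountMinSketchAggQuerySuite and assumes a SparkSession named `spark` and a temp view `tab` with an integer column `col`.

  import java.io.ByteArrayInputStream

  import org.apache.spark.util.sketch.CountMinSketch

  // Run the aggregate in SQL; the result column holds the serialized sketch bytes.
  val bytes = spark.sql("SELECT count_min_sketch(col, 0.1, 0.95, 42) FROM tab")
    .head().getAs[Array[Byte]](0)
  // Deserialize before usage, then query estimated frequencies.
  val cms = CountMinSketch.readFrom(new ByteArrayInputStream(bytes))
  val estimate = cms.estimateCount(5)  // estimated number of rows where col = 5

The estimate never undercounts; it overcounts the true frequency by at most eps * N (where N is the number of items added) with probability at least the given confidence.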