author    Reynold Xin <rxin@databricks.com>  2016-12-01 21:38:52 -0800
committer Reynold Xin <rxin@databricks.com>  2016-12-01 21:38:52 -0800
commit    d3c90b74edecc527ee468bead41d1cca0b667668 (patch)
tree      1b64571522c38155e472e0da58dac55907a22225 /sql/catalyst/src
parent    a5f02b00291e0a22429a3dca81f12cf6d38fea0b (diff)
[SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation
## What changes were proposed in this pull request?

SPARK-18429 introduced a count-min sketch aggregate function for SQL, but the implementation and testing are more complicated than needed. This simplifies the test cases and removes support for data types that don't have clear equality semantics:

1. Removed support for floating point and decimal types.
2. Removed the heavy randomized tests. The underlying CountMinSketch implementation already has good test coverage through randomized tests, and the SPARK-18429 implementation just adds an aggregate function wrapper around CountMinSketch. There is no need for randomized tests at three different levels of the implementation.

## How was this patch tested?

Most of the change simplifies existing test cases.

Author: Reynold Xin <rxin@databricks.com>

Closes #16093 from rxin/SPARK-18663.
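For context, a minimal usage sketch (not part of this patch): it assumes a live `spark` session and the `count_min_sketch` SQL function registered by SPARK-18429, with foldable literal arguments as required by the type checks in the diff below.

```scala
import org.apache.spark.util.sketch.CountMinSketch

// Hypothetical end-to-end example: aggregate a column into a serialized
// count-min sketch, then deserialize the bytes to query estimated frequencies.
val bytes = spark
  .sql("SELECT count_min_sketch(id, 0.0001d, 0.99d, 42) FROM range(1000)")
  .head()
  .getAs[Array[Byte]](0)

// The aggregate returns the sketch as a byte array; readFrom rebuilds it.
val sketch = CountMinSketch.readFrom(bytes)

// Every id in range(1000) occurs exactly once; the estimate is the true
// count plus a bounded overcount.
assert(sketch.estimateCount(500L) >= 1L)
```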
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala         |  27
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala |   2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala     | 304
3 files changed, 103 insertions, 230 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
index f5f185f2c5..612c19831f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
@@ -42,9 +40,9 @@ import org.apache.spark.util.sketch.CountMinSketch
@ExpressionDescription(
usage = """
_FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given eps,
- confidence and seed. The result is an array of bytes, which should be deserialized to a
- `CountMinSketch` before usage. `CountMinSketch` is useful for equality predicates and join
- size estimation.
+ confidence and seed. The result is an array of bytes, which can be deserialized to a
+ `CountMinSketch` before usage. Count-min sketch is a probabilistic data structure used for
+ frequency estimation in sub-linear space.
""")
case class CountMinSketchAgg(
child: Expression,
@@ -75,13 +73,13 @@ case class CountMinSketchAgg(
} else if (!epsExpression.foldable || !confidenceExpression.foldable ||
!seedExpression.foldable) {
TypeCheckFailure(
- "The eps, confidence or seed provided must be a literal or constant foldable")
+ "The eps, confidence or seed provided must be a literal or foldable")
} else if (epsExpression.eval() == null || confidenceExpression.eval() == null ||
seedExpression.eval() == null) {
TypeCheckFailure("The eps, confidence or seed provided should not be null")
- } else if (eps <= 0D) {
+ } else if (eps <= 0.0) {
TypeCheckFailure(s"Relative error must be positive (current value = $eps)")
- } else if (confidence <= 0D || confidence >= 1D) {
+ } else if (confidence <= 0.0 || confidence >= 1.0) {
TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)")
} else {
TypeCheckSuccess
@@ -97,9 +95,6 @@ case class CountMinSketchAgg(
// Ignore empty rows
if (value != null) {
child.dataType match {
- // `Decimal` and `UTF8String` are internal types in spark sql, we need to convert them
- // into acceptable types for `CountMinSketch`.
- case DecimalType() => buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal)
// For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
// instead of `addString` to avoid unnecessary conversion.
case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
@@ -115,14 +110,11 @@ case class CountMinSketchAgg(
override def eval(buffer: CountMinSketch): Any = serialize(buffer)
override def serialize(buffer: CountMinSketch): Array[Byte] = {
- val out = new ByteArrayOutputStream()
- buffer.writeTo(out)
- out.toByteArray
+ buffer.toByteArray
}
override def deserialize(storageFormat: Array[Byte]): CountMinSketch = {
- val in = new ByteArrayInputStream(storageFormat)
- CountMinSketch.readFrom(in)
+ CountMinSketch.readFrom(storageFormat)
}
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CountMinSketchAgg =
@@ -132,8 +124,7 @@ case class CountMinSketchAgg(
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def inputTypes: Seq[AbstractDataType] = {
- Seq(TypeCollection(NumericType, StringType, DateType, TimestampType, BooleanType, BinaryType),
- DoubleType, DoubleType, IntegerType)
+ Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, DoubleType, IntegerType)
}
override def nullable: Boolean = false
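The serialize/deserialize hunk above replaces the stream plumbing with CountMinSketch's direct byte-array helpers. A minimal standalone sketch of the equivalence the change relies on, using only methods visible in this diff:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.util.sketch.CountMinSketch

val cms = CountMinSketch.create(0.0001, 0.99, 42)
(1 to 100).foreach(i => cms.add(i))

// Old path (removed): round-trip through explicit byte streams.
val out = new ByteArrayOutputStream()
cms.writeTo(out)
val viaStreams = CountMinSketch.readFrom(new ByteArrayInputStream(out.toByteArray))

// New path (added): toByteArray / readFrom(Array[Byte]), one call each.
val viaBytes = CountMinSketch.readFrom(cms.toByteArray)

// Both round trips reconstruct a sketch equal to the original.
assert(viaStreams == cms && viaBytes == cms)
```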
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
index 8456e24460..fcb370ae84 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
@@ -86,7 +86,7 @@ class ApproximatePercentileSuite extends SparkFunSuite {
(headBufferSize + bufferSize) * 2
}
- val sizePerInputs = Seq(100, 1000, 10000, 100000, 1000000, 10000000).map { count =>
+ Seq(100, 1000, 10000, 100000, 1000000, 10000000).foreach { count =>
val buffer = new PercentileDigest(relativeError)
// Worst case, data is linear sorted
(0 until count).foreach(buffer.add(_))
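The one-line test change above drops a `val` (`sizePerInputs`) whose value was never read, switching from `map` to `foreach`. A tiny illustration of the idiom, with hypothetical values:

```scala
val counts = Seq(100, 1000, 10000)

// `map` builds and returns a new collection even if the caller discards it.
val discarded = counts.map { count => count * 2 }

// `foreach` returns Unit, making it the idiomatic choice when the loop body
// is executed only for its side effects.
counts.foreach { count => println(s"adding $count items") }
```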
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
index 6e08e29c04..10479630f3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
@@ -17,199 +17,114 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import java.io.ByteArrayInputStream
-import java.nio.charset.StandardCharsets
+import java.{lang => jl}
-import scala.reflect.ClassTag
import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Cast, GenericInternalRow, Literal}
-import org.apache.spark.sql.types.{DecimalType, _}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch
+/**
+ * Unit test suite for the count-min sketch SQL aggregate function [[CountMinSketchAgg]].
+ */
class CountMinSketchAggSuite extends SparkFunSuite {
private val childExpression = BoundReference(0, IntegerType, nullable = true)
private val epsOfTotalCount = 0.0001
private val confidence = 0.99
private val seed = 42
-
- test("serialize and de-serialize") {
- // Check empty serialize and de-serialize
- val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence),
- Literal(seed))
- val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed)
- assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
-
- // Check non-empty serialize and de-serialize
- val random = new Random(31)
- (0 until 10000).map(_ => random.nextInt(100)).foreach { value =>
- buffer.add(value)
- }
- assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
+ private val rand = new Random(seed)
+
+ /** Creates a count-min sketch aggregate expression, using the child expression defined above. */
+ private def cms(eps: jl.Double, confidence: jl.Double, seed: jl.Integer): CountMinSketchAgg = {
+ new CountMinSketchAgg(
+ child = childExpression,
+ epsExpression = Literal(eps, DoubleType),
+ confidenceExpression = Literal(confidence, DoubleType),
+ seedExpression = Literal(seed, IntegerType))
}
- def testHighLevelInterface[T: ClassTag](
- dataType: DataType,
- sampledItemIndices: Array[Int],
- allItems: Array[T],
- exactFreq: Map[Any, Long]): Any = {
- test(s"high level interface, update, merge, eval... - $dataType") {
+ /**
+ * Creates a new test case that compares our aggregate function with a reference implementation
+ * (using the underlying [[CountMinSketch]]).
+ *
+ * This works by splitting the items into two separate groups, aggregating each group, and then
+ * merging the two partial results back together (to emulate partial aggregation), before
+ * comparing the result with one generated directly by [[CountMinSketch]]. This assumes
+ * insertion order does not impact the result of a count-min sketch.
+ */
+ private def testDataType[T](dataType: DataType, items: Seq[T]): Unit = {
+ test("test data type " + dataType) {
val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true),
Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
assert(!agg.nullable)
- val group1 = 0 until sampledItemIndices.length / 2
- val group1Buffer = agg.createAggregationBuffer()
- group1.foreach { index =>
- val input = InternalRow(allItems(sampledItemIndices(index)))
- agg.update(group1Buffer, input)
+ val (seq1, seq2) = items.splitAt(items.size / 2)
+ val buf1 = addToAggregateBuffer(agg, seq1)
+ val buf2 = addToAggregateBuffer(agg, seq2)
+
+ val sketch = agg.createAggregationBuffer()
+ agg.merge(sketch, buf1)
+ agg.merge(sketch, buf2)
+
+ // Validate frequency estimates against the reference implementation.
+ val referenceSketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ items.foreach { item =>
+ referenceSketch.add(item match {
+ case u: UTF8String => u.getBytes
+ case _ => item
+ })
}
- val group2 = sampledItemIndices.length / 2 until sampledItemIndices.length
- val group2Buffer = agg.createAggregationBuffer()
- group2.foreach { index =>
- val input = InternalRow(allItems(sampledItemIndices(index)))
- agg.update(group2Buffer, input)
+ items.foreach { item =>
+ withClue(s"For item $item") {
+ val itemToTest = item match {
+ case u: UTF8String => u.getBytes
+ case _ => item
+ }
+ assert(referenceSketch.estimateCount(itemToTest) == sketch.estimateCount(itemToTest))
+ }
}
-
- var mergeBuffer = agg.createAggregationBuffer()
- agg.merge(mergeBuffer, group1Buffer)
- agg.merge(mergeBuffer, group2Buffer)
- checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
-
- // Merge in a different order
- mergeBuffer = agg.createAggregationBuffer()
- agg.merge(mergeBuffer, group2Buffer)
- agg.merge(mergeBuffer, group1Buffer)
- checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
-
- // Merge with an empty partition
- val emptyBuffer = agg.createAggregationBuffer()
- agg.merge(mergeBuffer, emptyBuffer)
- checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
}
- }
-
- def testLowLevelInterface[T: ClassTag](
- dataType: DataType,
- sampledItemIndices: Array[Int],
- allItems: Array[T],
- exactFreq: Map[Any, Long]): Any = {
- test(s"low level interface, update, merge, eval... - ${dataType.typeName}") {
- val inputAggregationBufferOffset = 1
- val mutableAggregationBufferOffset = 2
- // Phase one, partial mode aggregation
- val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true),
- Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
- .withNewInputAggBufferOffset(inputAggregationBufferOffset)
- .withNewMutableAggBufferOffset(mutableAggregationBufferOffset)
-
- val mutableAggBuffer = new GenericInternalRow(
- new Array[Any](mutableAggregationBufferOffset + 1))
- agg.initialize(mutableAggBuffer)
-
- sampledItemIndices.foreach { i =>
- agg.update(mutableAggBuffer, InternalRow(allItems(i)))
- }
- agg.serializeAggregateBufferInPlace(mutableAggBuffer)
-
- // Serialize the aggregation buffer
- val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset)
- val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized))
-
- // Phase 2: final mode aggregation
- // Re-initialize the aggregation buffer
- agg.initialize(mutableAggBuffer)
- agg.merge(mutableAggBuffer, inputAggBuffer)
- checkResult(agg.eval(mutableAggBuffer), allItems, exactFreq)
+ def addToAggregateBuffer[T](agg: CountMinSketchAgg, items: Seq[T]): CountMinSketch = {
+ val buf = agg.createAggregationBuffer()
+ items.foreach { item => agg.update(buf, InternalRow(item)) }
+ buf
}
}
- private def checkResult[T: ClassTag](
- result: Any,
- data: Array[T],
- exactFreq: Map[Any, Long]): Unit = {
- result match {
- case bytesData: Array[Byte] =>
- val in = new ByteArrayInputStream(bytesData)
- val cms = CountMinSketch.readFrom(in)
- val probCorrect = {
- val numErrors = data.map { i =>
- val count = exactFreq.getOrElse(getProbeItem(i), 0L)
- val item = i match {
- case dec: Decimal => dec.toJavaBigDecimal
- case str: UTF8String => str.getBytes
- case _ => i
- }
- val ratio = (cms.estimateCount(item) - count).toDouble / data.length
- if (ratio > epsOfTotalCount) 1 else 0
- }.sum
+ testDataType[Byte](ByteType, Seq.fill(100) { rand.nextInt(10).toByte })
- 1D - numErrors.toDouble / data.length
- }
+ testDataType[Short](ShortType, Seq.fill(100) { rand.nextInt(10).toShort })
- assert(
- probCorrect > confidence,
- s"Confidence not reached: required $confidence, reached $probCorrect"
- )
- case _ => fail("unexpected return type")
- }
- }
+ testDataType[Int](IntegerType, Seq.fill(100) { rand.nextInt(10) })
- private def getProbeItem[T: ClassTag](item: T): Any = item match {
- // Use a string to represent the content of an array of bytes
- case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
- case i => identity(i)
- }
+ testDataType[Long](LongType, Seq.fill(100) { rand.nextInt(10) })
- def testItemType[T: ClassTag](dataType: DataType)(itemGenerator: Random => T): Unit = {
- // Uses fixed seed to ensure reproducible test execution
- val r = new Random(31)
+ testDataType[UTF8String](StringType, Seq.fill(100) { UTF8String.fromString(rand.nextString(1)) })
- val numAllItems = 1000000
- val allItems = Array.fill(numAllItems)(itemGenerator(r))
+ testDataType[Array[Byte]](BinaryType, Seq.fill(100) { rand.nextString(1).getBytes() })
- val numSamples = numAllItems / 10
- val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+ test("serialize and de-serialize") {
+ // Check empty serialize and de-serialize
+ val agg = cms(epsOfTotalCount, confidence, seed)
+ val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+ assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
- val exactFreq = {
- val sampledItems = sampledItemIndices.map(allItems)
- sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
+ // Check non-empty serialize and de-serialize
+ val random = new Random(31)
+ for (i <- 0 until 10) {
+ buffer.add(random.nextInt(100))
}
-
- testLowLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
- testHighLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
- }
-
- testItemType[Byte](ByteType) { _.nextInt().toByte }
-
- testItemType[Short](ShortType) { _.nextInt().toShort }
-
- testItemType[Int](IntegerType) { _.nextInt() }
-
- testItemType[Long](LongType) { _.nextLong() }
-
- testItemType[UTF8String](StringType) { r => UTF8String.fromString(r.nextString(r.nextInt(20))) }
-
- testItemType[Float](FloatType) { _.nextFloat() }
-
- testItemType[Double](DoubleType) { _.nextDouble() }
-
- testItemType[Decimal](new DecimalType()) { r => Decimal(r.nextDouble()) }
-
- testItemType[Boolean](BooleanType) { _.nextBoolean() }
-
- testItemType[Array[Byte]](BinaryType) { r =>
- r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
+ assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
}
-
- test("fails analysis if eps, confidence or seed provided is not a literal or constant foldable") {
+ test("fails analysis if eps, confidence or seed provided is not foldable") {
val wrongEps = new CountMinSketchAgg(
childExpression,
epsExpression = AttributeReference("a", DoubleType)(),
@@ -227,88 +142,55 @@ class CountMinSketchAggSuite extends SparkFunSuite {
seedExpression = AttributeReference("c", IntegerType)())
Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
- assertEqual(
- wrongAgg.checkInputDataTypes(),
- TypeCheckFailure(
- "The eps, confidence or seed provided must be a literal or constant foldable")
- )
+ assertResult(
+ TypeCheckFailure("The eps, confidence or seed provided must be a literal or foldable")) {
+ wrongAgg.checkInputDataTypes()
+ }
}
}
test("fails analysis if parameters are invalid") {
// parameters are null
- val wrongEps = new CountMinSketchAgg(
- childExpression,
- epsExpression = Cast(Literal(null), DoubleType),
- confidenceExpression = Literal(confidence),
- seedExpression = Literal(seed))
- val wrongConfidence = new CountMinSketchAgg(
- childExpression,
- epsExpression = Literal(epsOfTotalCount),
- confidenceExpression = Cast(Literal(null), DoubleType),
- seedExpression = Literal(seed))
- val wrongSeed = new CountMinSketchAgg(
- childExpression,
- epsExpression = Literal(epsOfTotalCount),
- confidenceExpression = Literal(confidence),
- seedExpression = Cast(Literal(null), IntegerType))
+ val wrongEps = cms(null, confidence, seed)
+ val wrongConfidence = cms(epsOfTotalCount, null, seed)
+ val wrongSeed = cms(epsOfTotalCount, confidence, null)
Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
- assertEqual(
- wrongAgg.checkInputDataTypes(),
- TypeCheckFailure("The eps, confidence or seed provided should not be null")
- )
+ assertResult(TypeCheckFailure("The eps, confidence or seed provided should not be null")) {
+ wrongAgg.checkInputDataTypes()
+ }
}
// parameters are out of the valid range
Seq(0.0, -1000.0).foreach { invalidEps =>
- val invalidAgg = new CountMinSketchAgg(
- childExpression,
- epsExpression = Literal(invalidEps),
- confidenceExpression = Literal(confidence),
- seedExpression = Literal(seed))
- assertEqual(
- invalidAgg.checkInputDataTypes(),
- TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")
- )
+ val invalidAgg = cms(invalidEps, confidence, seed)
+ assertResult(
+ TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")) {
+ invalidAgg.checkInputDataTypes()
+ }
}
Seq(0.0, 1.0, -2.0, 2.0).foreach { invalidConfidence =>
- val invalidAgg = new CountMinSketchAgg(
- childExpression,
- epsExpression = Literal(epsOfTotalCount),
- confidenceExpression = Literal(invalidConfidence),
- seedExpression = Literal(seed))
- assertEqual(
- invalidAgg.checkInputDataTypes(),
- TypeCheckFailure(
- s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")
- )
+ val invalidAgg = cms(epsOfTotalCount, invalidConfidence, seed)
+ assertResult(TypeCheckFailure(
+ s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")) {
+ invalidAgg.checkInputDataTypes()
+ }
}
}
- private def assertEqual[T](left: T, right: T): Unit = {
- assert(left == right)
- }
-
test("null handling") {
def isEqual(result: Any, other: CountMinSketch): Boolean = {
- result match {
- case bytesData: Array[Byte] =>
- val in = new ByteArrayInputStream(bytesData)
- val cms = CountMinSketch.readFrom(in)
- cms.equals(other)
- case _ => fail("unexpected return type")
- }
+ other.equals(CountMinSketch.readFrom(result.asInstanceOf[Array[Byte]]))
}
- val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence),
- Literal(seed))
+ val agg = cms(epsOfTotalCount, confidence, seed)
val emptyCms = CountMinSketch.create(epsOfTotalCount, confidence, seed)
val buffer = new GenericInternalRow(new Array[Any](1))
agg.initialize(buffer)
// Empty aggregation buffer
assert(isEqual(agg.eval(buffer), emptyCms))
+
// Empty input row
agg.update(buffer, InternalRow(null))
assert(isEqual(agg.eval(buffer), emptyCms))