-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala | 26 +++++++++++++++-----------
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala | 24 +++++++++++++++++++++---
2 files changed, 36 insertions(+), 14 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 9329148aa2..db463029ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -20,17 +20,15 @@ package org.apache.spark.sql.execution.stat
import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.{Row, Column, DataFrame}

private[sql] object FrequentItems extends Logging {

/** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */
private class FreqItemCounter(size: Int) extends Serializable {
val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]
-
/**
* Add a new example to the counts if it exists, otherwise deduct the count
* from existing items.
@@ -42,9 +40,15 @@ private[sql] object FrequentItems extends Logging {
if (baseMap.size < size) {
baseMap += key -> count
} else {
- // TODO: Make this more efficient... A flatMap?
- baseMap.retain((k, v) => v > count)
- baseMap.transform((k, v) => v - count)
+ val minCount = baseMap.values.min
+ val remainder = count - minCount
+ if (remainder >= 0) {
+ baseMap += key -> count // something will get kicked out, so we can add this
+ baseMap.retain((k, v) => v > minCount)
+ baseMap.transform((k, v) => v - minCount)
+ } else {
+ baseMap.transform((k, v) => v - count)
+ }
}
}
this
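The counter above is a Misra-Gries style summary: once the map is full, an incoming key whose count is at least the current minimum must displace the minimum entries, whereas the old retain/transform-only path could silently drop a heavy key that first appears during a later merge. A minimal standalone sketch of the fixed add logic (SimpleFreqCounter is a hypothetical name used only for illustration, not part of the patch):

    import scala.collection.mutable.{Map => MutableMap}

    // Hypothetical, simplified mirror of FreqItemCounter.add after the fix.
    class SimpleFreqCounter(size: Int) {
      val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]

      def add(key: Any, count: Long): this.type = {
        if (baseMap.contains(key)) {
          baseMap(key) += count
        } else if (baseMap.size < size) {
          baseMap += key -> count
        } else {
          val minCount = baseMap.values.min
          if (count >= minCount) {
            baseMap += key -> count                    // new key evicts the minimum entries
            baseMap.retain((_, v) => v > minCount)
            baseMap.transform((_, v) => v - minCount)
          } else {
            baseMap.transform((_, v) => v - count)     // too small to enter; decrement everyone
          }
        }
        this
      }
    }

    val c = new SimpleFreqCounter(2)
    c.add("a", 1L).add("b", 1L)          // map is now full
    c.add("c", 5L)                       // 5 >= current minimum of 1, so "c" is admitted
    assert(c.baseMap.contains("c"))      // the old code never added "c"; it only pruned the existing keys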
@@ -90,12 +94,12 @@ private[sql] object FrequentItems extends Logging {
(name, originalSchema.fields(index).dataType)
}.toArray

- val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)(
+ val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
seqOp = (counts, row) => {
var i = 0
while (i < numCols) {
val thisMap = counts(i)
- val key = row.get(i, colInfo(i)._2)
+ val key = row.get(i)
thisMap.add(key, 1L)
i += 1
}
@@ -110,13 +114,13 @@ private[sql] object FrequentItems extends Logging {
baseCounts
}
)
- val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_))
- val resultRow = InternalRow(justItems : _*)
+ val justItems = freqItems.map(m => m.baseMap.keys.toArray)
+ val resultRow = Row(justItems : _*)
// append frequent Items to the column name for easy debugging
val outputCols = colInfo.map { v =>
StructField(v._1 + "_freqItems", ArrayType(v._2, false))
}
val schema = StructType(outputCols).toAttributes
- new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
+ new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
}
}
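The result-assembly change at the end of this file swaps the Catalyst-internal construction (InternalRow plus GenericArrayData) for an external Row handed to LocalRelation.fromExternalRows, which performs the conversion to internal rows itself. A rough sketch of a comparable one-row result built through the public API (sqlContext, the column names, and the example item arrays are assumptions for illustration, not taken from the patch):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    // One external Row whose fields are the frequent-item arrays, one per input column.
    val freqItems: Seq[Seq[Any]] = Seq(Seq(1, 2), Seq("a", "b"))
    val schema = StructType(Seq(
      StructField("numbers_freqItems", ArrayType(IntegerType, containsNull = false)),
      StructField("letters_freqItems", ArrayType(StringType, containsNull = false))))
    val resultRow = Row(freqItems: _*)

    // createDataFrame converts external Rows against the schema, much as
    // LocalRelation.fromExternalRows does for the operator's single-row result.
    val resultDF = sqlContext.createDataFrame(
      sqlContext.sparkContext.parallelize(Seq(resultRow)), schema)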
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 07a675e64f..0e7659f443 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -123,12 +123,30 @@ class DataFrameStatSuite extends QueryTest {

val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
val items = results.collect().head
- items.getSeq[Int](0) should contain (1)
- items.getSeq[String](1) should contain (toLetter(1))
+ assert(items.getSeq[Int](0).contains(1))
+ assert(items.getSeq[String](1).contains(toLetter(1)))

val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
val items2 = singleColResults.collect().head
- items2.getSeq[Double](0) should contain (-1.0)
+ assert(items2.getSeq[Double](0).contains(-1.0))
+ }
+
+ test("Frequent Items 2") {
+ val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4)
+ // this is a regression test, where when merging partitions, we omitted values with higher
+ // counts than those that existed in the map when the map was full. This test should also fail
+ // if anything like SPARK-9614 is observed once again
+ val df = rows.mapPartitionsWithIndex { (idx, iter) =>
+ if (idx == 3) { // must come from one of the later merges, therefore higher partition index
+ Iterator("3", "3", "3", "3", "3")
+ } else {
+ Iterator("0", "1", "2", "3", "4")
+ }
+ }.toDF("a")
+ val results = df.stat.freqItems(Array("a"), 0.25)
+ val items = results.collect().head.getSeq[String](0)
+ assert(items.contains("3"))
+ assert(items.length === 1)
}

test("sampleBy") {