 sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala |  4 ++--
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala           | 11 +++++++++++
 2 files changed, 13 insertions(+), 2 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 34bd243d58..b19344f043 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -40,7 +40,7 @@ private[sql] object FrequentItems extends Logging {
         if (baseMap.size < size) {
           baseMap += key -> count
         } else {
-          val minCount = baseMap.values.min
+          val minCount = if (baseMap.values.isEmpty) 0 else baseMap.values.min
           val remainder = count - minCount
           if (remainder >= 0) {
             baseMap += key -> count // something will get kicked out, so we can add this
@@ -83,7 +83,7 @@ private[sql] object FrequentItems extends Logging {
       df: DataFrame,
       cols: Seq[String],
       support: Double): DataFrame = {
-    require(support >= 1e-4, s"support ($support) must be greater than 1e-4.")
+    require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
     val numCols = cols.length
     // number of max items to keep counts for
     val sizeOfMap = (1 / support).toInt
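
The guard added in the first hunk targets the `UnsupportedOperationException: empty.min` named in the test below: when `support` exceeds 1, `sizeOfMap = (1 / support).toInt` becomes 0, so `baseMap` never receives an entry and the old `baseMap.values.min` runs on an empty collection. A minimal standalone sketch of that failure mode and of the guard (illustrative only; `EmptyMinSketch` and its local names are hypothetical, not Spark code):

import scala.collection.mutable.{Map => MutableMap}

object EmptyMinSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical stand-in for FreqItemCounter's baseMap when no item was ever added.
    val baseMap = MutableMap.empty[Any, Long]

    // Unguarded: calling min on an empty collection throws
    // java.lang.UnsupportedOperationException: empty.min
    // val broken = baseMap.values.min

    // Guarded, mirroring the patched line: fall back to 0 for an empty map.
    val minCount: Long = if (baseMap.values.isEmpty) 0L else baseMap.values.min
    println(minCount) // prints 0
  }
}
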
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index ab7733b239..73026c749d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -235,6 +235,17 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
     assert(items.length === 1)
   }
 
+  test("SPARK-15709: Prevent `UnsupportedOperationException: empty.min` in `freqItems`") {
+    val ds = spark.createDataset(Seq(1, 2, 2, 3, 3, 3))
+
+    intercept[IllegalArgumentException] {
+      ds.stat.freqItems(Seq("value"), 0)
+    }
+    intercept[IllegalArgumentException] {
+      ds.stat.freqItems(Seq("value"), 2)
+    }
+  }
+
   test("sampleBy") {
     val df = spark.range(0, 100).select((col("id") % 3).as("key"))
     val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
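
Together, the tightened `require` and the `isEmpty` guard mean an out-of-range `support` is rejected at the top of the single-pass computation rather than surfacing `empty.min` from inside the counting loop. A hedged end-to-end sketch of the behavior after this change (the object name and the local SparkSession setup are illustrative assumptions, not part of the patch):

import org.apache.spark.sql.SparkSession

object FreqItemsSupportSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("freqItems-support").getOrCreate()
    import spark.implicits._

    // Same shape of data as the new test: a single-column Dataset whose column is named "value".
    val ds = Seq(1, 2, 2, 3, 3, 3).toDS()

    // Valid support in [1e-4, 1]: returns a DataFrame of frequent items per column.
    ds.stat.freqItems(Seq("value"), 0.4).show()

    // support = 2 now fails fast with an IllegalArgumentException from the new require(...).
    try {
      ds.stat.freqItems(Seq("value"), 2)
    } catch {
      case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}")
    }

    spark.stop()
  }
}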