about | summary | refs | log | tree | commit | diff
path: root/sql/core/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core/src/test')
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 37
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala | 44
2 files changed, 79 insertions, 2 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index dca9e5e503..ede7d9a0c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -660,11 +660,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("limit") {
checkAnswer(
- sql("SELECT * FROM testData LIMIT 10"),
+ sql("SELECT * FROM testData LIMIT 9 + 1"),
testData.take(10).toSeq)
checkAnswer(
- sql("SELECT * FROM arrayData LIMIT 1"),
+ sql("SELECT * FROM arrayData LIMIT CAST(1 AS Integer)"),
arrayData.collect().take(1).map(Row.fromTuple).toSeq)
checkAnswer(
@@ -672,6 +672,39 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
mapData.collect().take(1).map(Row.fromTuple).toSeq)
}
+ test("non-foldable expressions in LIMIT") {
+ val e = intercept[AnalysisException] {
+ sql("SELECT * FROM testData LIMIT key > 3")
+ }.getMessage
+ assert(e.contains("The limit expression must evaluate to a constant value, " +
+ "but got (testdata.`key` > 3)"))
+ }
+
+ test("Expressions in limit clause are not integer") {
+ var e = intercept[AnalysisException] {
+ sql("SELECT * FROM testData LIMIT true")
+ }.getMessage
+ assert(e.contains("The limit expression must be integer type, but got boolean"))
+
+ e = intercept[AnalysisException] {
+ sql("SELECT * FROM testData LIMIT 'a'")
+ }.getMessage
+ assert(e.contains("The limit expression must be integer type, but got string"))
+ }
+
+ test("negative in LIMIT or TABLESAMPLE") {
+ val expected = "The limit expression must be equal to or greater than 0, but got -1"
+ var e = intercept[AnalysisException] {
+ sql("SELECT * FROM testData TABLESAMPLE (-1 rows)")
+ }.getMessage
+ assert(e.contains(expected))
+
+ e = intercept[AnalysisException] {
+ sql("SELECT * FROM testData LIMIT -1")
+ }.getMessage
+ assert(e.contains(expected))
+ }
+
test("CTE feature") {
checkAnswer(
sql("with q1 as (select * from testData limit 10) select * from q1"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
index 4de3cf605c..ab55242ec0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
class StatisticsSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
@@ -31,4 +33,46 @@ class StatisticsSuite extends QueryTest with SharedSQLContext {
spark.sessionState.conf.autoBroadcastJoinThreshold)
}
+ test("estimates the size of limit") {
+ withTempTable("test") {
+ Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
+ .createOrReplaceTempView("test")
+ Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
+ val df = sql(s"""SELECT * FROM test limit $limit""")
+
+ val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
+ g.statistics.sizeInBytes
+ }
+ assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+ assert(sizesGlobalLimit.head === BigInt(expected),
+ s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")
+
+ val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
+ l.statistics.sizeInBytes
+ }
+ assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+ assert(sizesLocalLimit.head === BigInt(expected),
+ s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
+ }
+ }
+ }
+
+ test("estimates the size of a limit 0 on outer join") {
+ withTempTable("test") {
+ Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
+ .createOrReplaceTempView("test")
+ val df1 = spark.table("test")
+ val df2 = spark.table("test").limit(0)
+ val df = df1.join(df2, Seq("k"), "left")
+
+ val sizes = df.queryExecution.analyzed.collect { case g: Join =>
+ g.statistics.sizeInBytes
+ }
+
+ assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
+ assert(sizes.head === BigInt(96),
+ s"expected exact size 96 for table 'test', got: ${sizes.head}")
+ }
+ }
+
}