author    Eric Liang <ekl@databricks.com>    2016-08-22 15:48:35 -0700
committer Reynold Xin <rxin@databricks.com>  2016-08-22 15:48:35 -0700
commit    84770b59f773f132073cd2af4204957fc2d7bf35 (patch)
tree      f1f4c739df710ebcc7bfe7a459234102c1cb698b /sql
parent    929cb8beed9b7014231580cc002853236a5337d6 (diff)
[SPARK-17162] Range does not support SQL generation
## What changes were proposed in this pull request?

The range operator previously didn't support SQL generation, which made it impossible to use in views.

## How was this patch tested?

Unit tests.

cc hvanhovell

Author: Eric Liang <ekl@databricks.com>

Closes #14724 from ericl/spark-17162.
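For context, a minimal repro of what this patch enables, assuming a spark-shell session (the view name is illustrative):

```scala
// Before this patch, the CREATE VIEW below failed because the Range
// operator produced by range() had no SQL representation.
spark.sql("CREATE VIEW range_view AS SELECT * FROM range(100)")
spark.sql("SELECT count(*) FROM range_view").show()
```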
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala  | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala                               |  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala                  |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala                           |  3
-rw-r--r--  sql/hive/src/test/resources/sqlgen/range.sql                                                          |  4
-rw-r--r--  sql/hive/src/test/resources/sqlgen/range_with_splits.sql                                              |  4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala                     | 14
8 files changed, 44 insertions(+), 18 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
index 7fdf7fa0c0..6b3bb68538 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
@@ -28,9 +28,6 @@ import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
* Rule that resolves table-valued function references.
*/
object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
- private lazy val defaultParallelism =
- SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism
-
/**
* List of argument names and their types, used to declare a function.
*/
@@ -84,25 +81,25 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
"range" -> Map(
/* range(end) */
tvf("end" -> LongType) { case Seq(end: Long) =>
- Range(0, end, 1, defaultParallelism)
+ Range(0, end, 1, None)
},
/* range(start, end) */
tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) =>
- Range(start, end, 1, defaultParallelism)
+ Range(start, end, 1, None)
},
/* range(start, end, step) */
tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) {
case Seq(start: Long, end: Long, step: Long) =>
- Range(start, end, step, defaultParallelism)
+ Range(start, end, step, None)
},
/* range(start, end, step, numPartitions) */
tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
"numPartitions" -> IntegerType) {
case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
- Range(start, end, step, numPartitions)
+ Range(start, end, step, Some(numPartitions))
})
)
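The rule no longer captures SparkContext.defaultParallelism at analysis time; an unspecified partition count is carried as None and filled in at execution (see the RangeExec change below), which keeps the logical plan free of cluster-specific state. A sketch of the four resolved signatures, assuming a spark-shell session:

```scala
// The four range() table-valued function forms resolved above.
spark.sql("SELECT * FROM range(100)")            // range(end)
spark.sql("SELECT * FROM range(1, 100)")         // range(start, end)
spark.sql("SELECT * FROM range(1, 100, 20)")     // range(start, end, step)
spark.sql("SELECT * FROM range(1, 100, 20, 10)") // range(start, end, step, numPartitions)
```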
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index af1736e607..010aec7ba1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -422,17 +422,20 @@ case class Sort(
/** Factory for constructing new `Range` nodes. */
object Range {
- def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = {
+ def apply(start: Long, end: Long, step: Long, numSlices: Option[Int]): Range = {
val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes
new Range(start, end, step, numSlices, output)
}
+ def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = {
+ Range(start, end, step, Some(numSlices))
+ }
}
case class Range(
start: Long,
end: Long,
step: Long,
- numSlices: Int,
+ numSlices: Option[Int],
output: Seq[Attribute])
extends LeafNode with MultiInstanceRelation {
@@ -449,6 +452,14 @@ case class Range(
}
}
+ def toSQL(): String = {
+ if (numSlices.isDefined) {
+ s"SELECT id AS `${output.head.name}` FROM range($start, $end, $step, ${numSlices.get})"
+ } else {
+ s"SELECT id AS `${output.head.name}` FROM range($start, $end, $step)"
+ }
+ }
+
override def newInstance(): Range = copy(output = output.map(_.newInstance()))
override lazy val statistics: Statistics = {
@@ -457,11 +468,7 @@ case class Range(
}
override def simpleString: String = {
- if (step == 1) {
- s"Range ($start, $end, splits=$numSlices)"
- } else {
- s"Range ($start, $end, step=$step, splits=$numSlices)"
- }
+ s"Range ($start, $end, step=$step, splits=$numSlices)"
}
}
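The new toSQL emits the numPartitions argument only when numSlices is set, so regenerated SQL preserves an explicit split count without baking in a derived default. A sketch of the expected output per branch, based on the method above (the `id` column name comes from the factory):

```scala
import org.apache.spark.sql.catalyst.plans.logical.Range

Range(0, 100, 1, None).toSQL()      // SELECT id AS `id` FROM range(0, 100, 1)
Range(1, 100, 20, Some(10)).toSQL() // SELECT id AS `id` FROM range(1, 100, 20, 10)
```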
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
index af1de511da..dde91b0a86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -208,6 +208,9 @@ class SQLBuilder private (
case p: LocalRelation =>
p.toSQL(newSubqueryName())
+ case p: Range =>
+ p.toSQL()
+
case OneRowRelation =>
""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index ad8a716898..3562083b06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -318,7 +318,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
def start: Long = range.start
def step: Long = range.step
- def numSlices: Int = range.numSlices
+ def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism)
def numElements: BigInt = range.numElements
override val output: Seq[Attribute] = range.output
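With the default now resolved here at execution rather than during analysis, the observable partitioning should be unchanged; a hedged sketch, assuming a spark-shell session:

```scala
// Partition counts observed at execution time.
spark.sql("SELECT * FROM range(100)").rdd.getNumPartitions
// => sparkContext.defaultParallelism (applied in RangeExec above)
spark.sql("SELECT * FROM range(1, 100, 20, 10)").rdd.getNumPartitions
// => 10, the explicit numPartitions argument
```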
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index e397cfa058..f0d7b64c3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -179,8 +179,7 @@ case class CreateViewCommand(
sparkSession.sql(viewSQL).queryExecution.assertAnalyzed()
} catch {
case NonFatal(e) =>
- throw new RuntimeException(
- "Failed to analyze the canonicalized SQL. It is possible there is a bug in Spark.", e)
+ throw new RuntimeException(s"Failed to analyze the canonicalized SQL: ${viewSQL}", e)
}
val viewSchema = if (userSpecifiedColumns.isEmpty) {
diff --git a/sql/hive/src/test/resources/sqlgen/range.sql b/sql/hive/src/test/resources/sqlgen/range.sql
new file mode 100644
index 0000000000..53c72ea71e
--- /dev/null
+++ b/sql/hive/src/test/resources/sqlgen/range.sql
@@ -0,0 +1,4 @@
+-- This file is automatically generated by LogicalPlanToSQLSuite.
+select * from range(100)
+--------------------------------------------------------------------------------
+SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT id AS `gen_attr_0` FROM range(0, 100, 1)) AS gen_subquery_0) AS gen_subquery_1
diff --git a/sql/hive/src/test/resources/sqlgen/range_with_splits.sql b/sql/hive/src/test/resources/sqlgen/range_with_splits.sql
new file mode 100644
index 0000000000..83d637d54a
--- /dev/null
+++ b/sql/hive/src/test/resources/sqlgen/range_with_splits.sql
@@ -0,0 +1,4 @@
+-- This file is automatically generated by LogicalPlanToSQLSuite.
+select * from range(1, 100, 20, 10)
+--------------------------------------------------------------------------------
+SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT id AS `gen_attr_0` FROM range(1, 100, 20, 10)) AS gen_subquery_0) AS gen_subquery_1
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
index 742b065891..9c6da6a628 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
@@ -23,7 +23,10 @@ import java.nio.file.{Files, NoSuchFileException, Paths}
import scala.util.control.NonFatal
import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
@@ -180,7 +183,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
}
test("Test should fail if the SQL query cannot be regenerated") {
- spark.range(10).createOrReplaceTempView("not_sql_gen_supported_table_so_far")
+ case class Unsupported() extends LeafNode with MultiInstanceRelation {
+ override def newInstance(): Unsupported = copy()
+ override def output: Seq[Attribute] = Nil
+ }
+ Unsupported().createOrReplaceTempView("not_sql_gen_supported_table_so_far")
sql("select * from not_sql_gen_supported_table_so_far")
val m3 = intercept[org.scalatest.exceptions.TestFailedException] {
checkSQL("select * from not_sql_gen_supported_table_so_far", "in")
@@ -196,6 +203,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
}
}
+ test("range") {
+ checkSQL("select * from range(100)", "range")
+ checkSQL("select * from range(1, 100, 20, 10)", "range_with_splits")
+ }
+
test("in") {
checkSQL("SELECT id FROM parquet_t0 WHERE id IN (1, 2, 3)", "in")
}