Diffstat (limited to 'sql/hive')
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala                    22
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala                      43
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala     15
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala         64
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala         4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala            22
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala                  5
7 files changed, 137 insertions, 38 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index bf5edb4759..1dda39d44e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -24,10 +24,11 @@ import scala.util.control.NonFatal
import org.apache.spark.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, NonSQLExpression, SortOrder}
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.execution.datasources.LogicalRelation
/**
@@ -37,11 +38,21 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
* supported by this builder (yet).
*/
class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
+ require(logicalPlan.resolved, "SQLBuilder only supports resolved logical query plans")
+
def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext)
def toSQL: String = {
val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
try {
+ canonicalizedPlan.transformAllExpressions {
+ case e: NonSQLExpression =>
+ throw new UnsupportedOperationException(
+ s"Expression $e doesn't have a SQL representation"
+ )
+ case e => e
+ }
+
val generatedSQL = toSQL(canonicalizedPlan)
logDebug(
s"""Built SQL query string successfully from given logical plan:
@@ -95,7 +106,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
p.child match {
// Persisted data source relation
case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) =>
- s"`$database`.`$table`"
+ s"${quoteIdentifier(database)}.${quoteIdentifier(table)}"
// Parentheses are not used for persisted data source relations
// e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1
case Subquery(_, _: LogicalRelation | _: MetastoreRelation) =>
@@ -114,8 +125,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
case p: MetastoreRelation =>
build(
- s"`${p.databaseName}`.`${p.tableName}`",
- p.alias.map(a => s" AS `$a`").getOrElse("")
+ s"${quoteIdentifier(p.databaseName)}.${quoteIdentifier(p.tableName)}",
+ p.alias.map(a => s" AS ${quoteIdentifier(a)}").getOrElse("")
)
case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
@@ -148,7 +159,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
* The segments are trimmed so only a single space appears in the separation.
* For example, `build("a", " b ", " c")` becomes "a b c".
*/
- private def build(segments: String*): String = segments.map(_.trim).mkString(" ")
+ private def build(segments: String*): String =
+ segments.map(_.trim).filter(_.nonEmpty).mkString(" ")
private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
build(
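
The SQLBuilder change above replaces hand-rolled backtick quoting with the shared quoteIdentifier helper. A minimal sketch of what such a helper does, assuming the usual backtick-escaping convention (the actual Catalyst implementation may cover more cases):

// Sketch only, not the real org.apache.spark.sql.catalyst.util.quoteIdentifier:
// embedded backticks are doubled, so a name like we`ird renders as `we``ird`.
object QuoteIdentifierSketch {
  def quoteIdentifier(name: String): String =
    "`" + name.replace("`", "``") + "`"

  def main(args: Array[String]): Unit = {
    println(quoteIdentifier("we`ird"))                              // `we``ird`
    println(s"${quoteIdentifier("db")}.${quoteIdentifier("tbl")}")  // `db`.`tbl`
  }
}
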
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index d5ed838ca4..bcafa045e0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.client.HiveClientImpl
import org.apache.spark.sql.types._
@@ -70,18 +69,28 @@ private[hive] class HiveFunctionRegistry(
// catch the exception and throw AnalysisException instead.
try {
if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUDF(
+ val udf = HiveGenericUDF(
name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
+ udf.dataType // Force it to check input data types.
+ udf
} else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
+ val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
+ udf.dataType // Force it to check input data types.
+ udf
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
+ val udf = HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
+ udf.dataType // Force it to check input data types.
+ udf
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
+ val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
+ udaf.dataType // Force it to check input data types.
+ udaf
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveUDAFFunction(
+ val udaf = HiveUDAFFunction(
name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
+ udaf.dataType // Force it to check input data types.
+ udaf
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children)
udtf.elementTypes // Force it to check input data types.
@@ -163,7 +172,7 @@ private[hive] case class HiveSimpleUDF(
@transient
private lazy val conversionHelper = new ConversionHelper(method, arguments)
- override val dataType = javaClassToDataType(method.getReturnType)
+ override lazy val dataType = javaClassToDataType(method.getReturnType)
@transient
lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector(
@@ -189,6 +198,8 @@ private[hive] case class HiveSimpleUDF(
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
+ override def prettyName: String = name
+
override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
@@ -233,11 +244,11 @@ private[hive] case class HiveGenericUDF(
}
@transient
- private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
+ private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
new DeferredObjectAdapter(inspect, child.dataType)
}.toArray[DeferredObject]
- override val dataType: DataType = inspectorToDataType(returnInspector)
+ override lazy val dataType: DataType = inspectorToDataType(returnInspector)
override def eval(input: InternalRow): Any = {
returnInspector // Make sure initialized.
@@ -245,20 +256,20 @@ private[hive] case class HiveGenericUDF(
var i = 0
while (i < children.length) {
val idx = i
- deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(
+ deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set(
() => {
children(idx).eval(input)
})
i += 1
}
- unwrap(function.evaluate(deferedObjects), returnInspector)
+ unwrap(function.evaluate(deferredObjects), returnInspector)
}
+ override def prettyName: String = name
+
override def toString: String = {
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
-
- override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
/**
@@ -340,7 +351,7 @@ private[hive] case class HiveGenericUDTF(
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
- override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
+ override def prettyName: String = name
}
/**
@@ -432,7 +443,9 @@ private[hive] case class HiveUDAFFunction(
override def supportsPartial: Boolean = false
- override val dataType: DataType = inspectorToDataType(returnInspector)
+ override lazy val dataType: DataType = inspectorToDataType(returnInspector)
+
+ override def prettyName: String = name
override def sql(isDistinct: Boolean): String = {
val distinct = if (isDistinct) "DISTINCT " else " "
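
The hiveUDFs.scala change makes dataType a lazy val and forces it right after constructing each UDF/UDAF, so invalid argument types fail during function lookup (where the surrounding try/catch can rethrow as AnalysisException) rather than at execution time. A minimal sketch of that pattern with illustrative names, not the real Hive wrapper classes:

import scala.util.control.NonFatal

// Illustrative stand-in for a Hive UDF expression: resolving the return type
// is what validates the argument types, so it stays lazy and is forced
// explicitly by the registry.
class SketchUdf(argTypes: Seq[String]) {
  lazy val dataType: String =
    if (argTypes.forall(_ == "int")) "int"
    else throw new IllegalArgumentException(s"unsupported argument types: $argTypes")
}

def lookupFunction(argTypes: Seq[String]): SketchUdf =
  try {
    val udf = new SketchUdf(argTypes)
    udf.dataType // Force it to check input data types at lookup time.
    udf
  } catch {
    // The real registry rethrows as AnalysisException at this point.
    case NonFatal(e) => throw new RuntimeException(s"No handler for UDF: ${e.getMessage}")
  }
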
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
index 3fb6543b1a..e4b4d1861a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
@@ -26,17 +26,24 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest {
test("literal") {
checkSQL(Literal("foo"), "\"foo\"")
checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"")
- checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)")
- checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)")
+ checkSQL(Literal(1: Byte), "1Y")
+ checkSQL(Literal(2: Short), "2S")
checkSQL(Literal(4: Int), "4")
- checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)")
+ checkSQL(Literal(8: Long), "8L")
checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)")
- checkSQL(Literal(2.5D), "2.5")
+ checkSQL(Literal(2.5D), "2.5D")
checkSQL(
Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')")
// TODO tests for decimals
}
+ test("attributes") {
+ checkSQL('a.int, "`a`")
+ checkSQL(Symbol("foo bar").int, "`foo bar`")
+ // Keyword
+ checkSQL('int.int, "`int`")
+ }
+
test("binary comparisons") {
checkSQL('a.int === 'b.int, "(`a` = `b`)")
checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)")
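
The updated literal expectations use Hive-style typed suffixes instead of explicit CASTs: 1Y for TINYINT, 2S for SMALLINT, 8L for BIGINT, 2.5D for DOUBLE, while FLOAT still needs a CAST. A small sketch of a suffix-based renderer matching those expectations (the real Literal.sql in Catalyst covers more types such as decimals and timestamps):

// Sketch only: renders literals the way the test cases above expect.
def literalToSQL(value: Any): String = value match {
  case b: Byte   => s"${b}Y"                // TINYINT
  case s: Short  => s"${s}S"                // SMALLINT
  case i: Int    => i.toString
  case l: Long   => s"${l}L"                // BIGINT
  case f: Float  => s"CAST($f AS FLOAT)"    // no float suffix in HiveQL
  case d: Double => s"${d}D"
  case s: String => "\"" + s.replace("\"", "\\\"") + "\""
}
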
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index dc8ac7e47f..5255b150aa 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -29,6 +29,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
sql("DROP TABLE IF EXISTS t0")
sql("DROP TABLE IF EXISTS t1")
sql("DROP TABLE IF EXISTS t2")
+
sqlContext.range(10).write.saveAsTable("t0")
sqlContext
@@ -168,4 +169,67 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
}
}
}
+
+ test("plans with non-SQL expressions") {
+ sqlContext.udf.register("foo", (_: Int) * 2)
+ intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL)
+ }
+
+ test("named expression in column names shouldn't be quoted") {
+ def checkColumnNames(query: String, expectedColNames: String*): Unit = {
+ checkHiveQl(query)
+ assert(sql(query).columns === expectedColNames)
+ }
+
+ // Attributes
+ checkColumnNames(
+ """SELECT * FROM (
+ | SELECT 1 AS a, 2 AS b, 3 AS `we``ird`
+ |) s
+ """.stripMargin,
+ "a", "b", "we`ird"
+ )
+
+ checkColumnNames(
+ """SELECT x.a, y.a, x.b, y.b
+ |FROM (SELECT 1 AS a, 2 AS b) x
+ |INNER JOIN (SELECT 1 AS a, 2 AS b) y
+ |ON x.a = y.a
+ """.stripMargin,
+ "a", "a", "b", "b"
+ )
+
+ // String literal
+ checkColumnNames(
+ "SELECT 'foo', '\"bar\\''",
+ "foo", "\"bar\'"
+ )
+
+ // Numeric literals (should not have CAST or suffixes in column names)
+ checkColumnNames(
+ "SELECT 1Y, 2S, 3, 4L, 5.1, 6.1D",
+ "1", "2", "3", "4", "5.1", "6.1"
+ )
+
+ // Aliases
+ checkColumnNames(
+ "SELECT 1 AS a",
+ "a"
+ )
+
+ // Complex type extractors
+ checkColumnNames(
+ """SELECT
+ | a.f1, b[0].f1, b.f1, c["foo"], d[0]
+ |FROM (
+ | SELECT
+ | NAMED_STRUCT("f1", 1, "f2", "foo") AS a,
+ | ARRAY(NAMED_STRUCT("f1", 1, "f2", "foo")) AS b,
+ | MAP("foo", 1) AS c,
+ | ARRAY(1) AS d
+ |) s
+ """.stripMargin,
+ "f1", "b[0].f1", "f1", "c[foo]", "d[0]"
+ )
+ }
}
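
The new tests rely on SQLBuilder producing SQL that can be parsed and executed again. A hedged sketch of that round-trip idea, assuming a helper along the lines of checkHiveQl (names here are illustrative):

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.SQLBuilder

// Sketch: regenerate SQL from an analyzed plan and check that the regenerated
// query returns the same rows as the original one.
def checkRoundTrip(sqlContext: SQLContext, query: String): Unit = {
  val original = sqlContext.sql(query)
  // Throws UnsupportedOperationException if the plan contains a
  // NonSQLExpression (e.g. a Scala UDF), as exercised above.
  val regenerated = new SQLBuilder(original).toSQL
  val roundTripped = sqlContext.sql(regenerated)
  assert(original.collect().toSeq == roundTripped.collect().toSeq)
}
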
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
index cd055f9eca..d8d3448add 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
@@ -29,8 +29,8 @@ class HivePlanTest extends QueryTest with TestHiveSingleton {
test("udf constant folding") {
Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t")
- val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan
- val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan
+ val optimized = sql("SELECT cos(null) AS c FROM t").queryExecution.optimizedPlan
+ val correctAnswer = sql("SELECT cast(null as double) AS c FROM t").queryExecution.optimizedPlan
comparePlans(optimized, correctAnswer)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
index 27ea3e8041..fe446774ef 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
@@ -131,17 +131,17 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA
val df = sql(
"""
|SELECT
- | CAST(null as TINYINT),
- | CAST(null as SMALLINT),
- | CAST(null as INT),
- | CAST(null as BIGINT),
- | CAST(null as FLOAT),
- | CAST(null as DOUBLE),
- | CAST(null as DECIMAL(7,2)),
- | CAST(null as TIMESTAMP),
- | CAST(null as DATE),
- | CAST(null as STRING),
- | CAST(null as VARCHAR(10))
+ | CAST(null as TINYINT) as c0,
+ | CAST(null as SMALLINT) as c1,
+ | CAST(null as INT) as c2,
+ | CAST(null as BIGINT) as c3,
+ | CAST(null as FLOAT) as c4,
+ | CAST(null as DOUBLE) as c5,
+ | CAST(null as DECIMAL(7,2)) as c6,
+ | CAST(null as TIMESTAMP) as c7,
+ | CAST(null as DATE) as c8,
+ | CAST(null as STRING) as c9,
+ | CAST(null as VARCHAR(10)) as c10
|FROM orc_temp_table limit 1
""".stripMargin)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index c997453803..a6ca7d0386 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -626,7 +626,10 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
sql(
s"""CREATE TABLE array_of_struct
|STORED AS PARQUET LOCATION '$path'
- |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b'))
+ |AS SELECT
+ | '1st' AS a,
+ | '2nd' AS b,
+ | ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) AS c
""".stripMargin)
checkAnswer(