Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala            |  53
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala    |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala       |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala | 212
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala            |   4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala       |  38
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala |   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                                    |   3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala    |   4
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/group-by.sql.out                               |   2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala                        |  20
-rw-r--r--  sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql                                  |   2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala            |  12
13 files changed, 163 insertions, 192 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f8f4799322..5011f2fdbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.{TreeNodeRef}
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.types._
@@ -83,6 +83,7 @@ class Analyzer(
ResolveTableValuedFunctions ::
ResolveRelations ::
ResolveReferences ::
+ ResolveCreateNamedStruct ::
ResolveDeserializer ::
ResolveNewInstance ::
ResolveUpCast ::
@@ -653,11 +654,12 @@ class Analyzer(
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
- case c: CreateStruct if containsStar(c.children) =>
- c.copy(children = c.children.flatMap {
- case s: Star => s.expand(child, resolver)
- case o => o :: Nil
- })
+ case c: CreateNamedStruct if containsStar(c.valExprs) =>
+ val newChildren = c.children.grouped(2).flatMap {
+ case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children
+ case kv => kv
+ }
+ c.copy(children = newChildren.toList)
case c: CreateArray if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
@@ -1141,7 +1143,7 @@ class Analyzer(
case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>
// Get the left hand side expressions.
val expressions = e match {
- case CreateStruct(exprs) => exprs
+ case cns : CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) =>
@@ -2072,18 +2074,8 @@ object EliminateUnions extends Rule[LogicalPlan] {
*/
object CleanupAliases extends Rule[LogicalPlan] {
private def trimAliases(e: Expression): Expression = {
- var stop = false
e.transformDown {
- // CreateStruct is a special case, we need to retain its top level Aliases as they decide the
- // name of StructField. We also need to stop transform down this expression, or the Aliases
- // under CreateStruct will be mistakenly trimmed.
- case c: CreateStruct if !stop =>
- stop = true
- c.copy(children = c.children.map(trimNonTopLevelAliases))
- case c: CreateStructUnsafe if !stop =>
- stop = true
- c.copy(children = c.children.map(trimNonTopLevelAliases))
- case Alias(child, _) if !stop => child
+ case Alias(child, _) => child
}
}
@@ -2116,15 +2108,8 @@ object CleanupAliases extends Rule[LogicalPlan] {
case a: AppendColumns => a
case other =>
- var stop = false
other transformExpressionsDown {
- case c: CreateStruct if !stop =>
- stop = true
- c.copy(children = c.children.map(trimNonTopLevelAliases))
- case c: CreateStructUnsafe if !stop =>
- stop = true
- c.copy(children = c.children.map(trimNonTopLevelAliases))
- case Alias(child, _) if !stop => child
+ case Alias(child, _) => child
}
}
}
@@ -2217,3 +2202,19 @@ object TimeWindowing extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s.
+ */
+object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
+ case e: CreateNamedStruct if !e.resolved =>
+ val children = e.children.grouped(2).flatMap {
+ case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
+ Seq(Literal(e.name), e)
+ case kv =>
+ kv
+ }
+ CreateNamedStruct(children.toList)
+ }
+}
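
A minimal standalone sketch of the resolution step above, using simplified stand-in types rather than Spark's actual expression classes: the rule walks the interleaved name/value children two at a time and, once a value is a resolved named expression, replaces its NamePlaceholder with a string literal carrying that expression's name.

sealed trait Expr
case object NamePlaceholder extends Expr
final case class Lit(value: String) extends Expr
final case class Named(name: String, resolved: Boolean) extends Expr

// Walk the children in (name, value) pairs; fill in any placeholder name
// whose value expression has been resolved.
def resolvePlaceholders(children: Seq[Expr]): List[Expr] =
  children.grouped(2).flatMap {
    case Seq(NamePlaceholder, e @ Named(name, true)) => Seq(Lit(name), e)
    case kv => kv
  }.toList

// struct(a, b) arrives with placeholder names and leaves fully named.
val resolved = resolvePlaceholders(Seq(
  NamePlaceholder, Named("a", resolved = true),
  NamePlaceholder, Named("b", resolved = true)))
assert(resolved == List(Lit("a"), Named("a", true), Lit("b"), Named("b", true)))
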
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3e836ca375..b028d07fb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -357,7 +357,7 @@ object FunctionRegistry {
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
- expression[CreateStruct]("struct"),
+ CreateStruct.registryEntry,
// misc functions
expression[AssertTrue]("assert_true"),
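
The generic expression[T](name) helper derives a builder from a case-class constructor, which no longer works now that CreateStruct is an object whose apply returns a CreateNamedStruct, so it contributes a hand-written entry instead. A hedged sketch of that entry's shape, with simplified stand-in types (these are not Spark's actual FunctionRegistry or ExpressionInfo types):

object StructRegistry {
  // Stand-ins: Spark's real FunctionBuilder maps Seq[Expression] to Expression.
  final case class ExpressionInfo(className: String, name: String, usage: String)
  type FunctionBuilder = Seq[String] => String

  // The registry maps a function name to an (info, builder) pair; here the
  // builder lowers struct(...) into the named_struct form, mirroring how
  // CreateStruct.apply now returns a CreateNamedStruct.
  val entry: (String, (ExpressionInfo, FunctionBuilder)) = {
    val info = ExpressionInfo(
      "org.apache.spark.sql.catalyst.expressions.NamedStruct",
      "struct",
      "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.")
    ("struct", (info, args => s"named_struct(${args.mkString(", ")})"))
  }
}
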
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index a81fa1ce3a..03e054d098 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -119,7 +119,6 @@ object UnsafeProjection {
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
val unsafeExprs = exprs.map(_ transform {
- case CreateStruct(children) => CreateStructUnsafe(children)
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(unsafeExprs)
@@ -145,7 +144,6 @@ object UnsafeProjection {
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
val e = exprs.map(BindReferences.bindReference(_, inputSchema))
.map(_ transform {
- case CreateStruct(children) => CreateStructUnsafe(children)
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 917aa08731..dbfb2996ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -18,9 +18,11 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -172,101 +174,71 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
}
/**
- * Returns a Row containing the evaluation of all children expressions.
+ * An expression representing a not yet available attribute name. This expression is
+ * unevaluable and, as its name suggests, a temporary placeholder until we are able to
+ * determine the actual attribute name.
*/
-@ExpressionDescription(
- usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.")
-case class CreateStruct(children: Seq[Expression]) extends Expression {
-
- override def foldable: Boolean = children.forall(_.foldable)
-
- override lazy val dataType: StructType = {
- val fields = children.zipWithIndex.map { case (child, idx) =>
- child match {
- case ne: NamedExpression =>
- StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
- case _ =>
- StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
- }
- }
- StructType(fields)
- }
-
+case object NamePlaceholder extends LeafExpression with Unevaluable {
+ override lazy val resolved: Boolean = false
+ override def foldable: Boolean = false
override def nullable: Boolean = false
+ override def dataType: DataType = StringType
+ override def prettyName: String = "NamePlaceholder"
+ override def toString: String = prettyName
+}
- override def eval(input: InternalRow): Any = {
- InternalRow(children.map(_.eval(input)): _*)
+/**
+ * Returns a Row containing the evaluation of all children expressions.
+ */
+object CreateStruct extends FunctionBuilder {
+ def apply(children: Seq[Expression]): CreateNamedStruct = {
+ CreateNamedStruct(children.zipWithIndex.flatMap {
+ case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e)
+ case (e: NamedExpression, _) => Seq(NamePlaceholder, e)
+ case (e, index) => Seq(Literal(s"col${index + 1}"), e)
+ })
}
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val rowClass = classOf[GenericInternalRow].getName
- val values = ctx.freshName("values")
- ctx.addMutableState("Object[]", values, s"this.$values = null;")
-
- ev.copy(code = s"""
- boolean ${ev.isNull} = false;
- this.$values = new Object[${children.size}];""" +
- ctx.splitExpressions(
- ctx.INPUT_ROW,
- children.zipWithIndex.map { case (e, i) =>
- val eval = e.genCode(ctx)
- eval.code + s"""
- if (${eval.isNull}) {
- $values[$i] = null;
- } else {
- $values[$i] = ${eval.value};
- }"""
- }) +
- s"""
- final InternalRow ${ev.value} = new $rowClass($values);
- this.$values = null;
- """)
+ /**
+ * Entry to use in the function registry.
+ */
+ val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = {
+ val info: ExpressionInfo = new ExpressionInfo(
+ "org.apache.spark.sql.catalyst.expressions.NamedStruct",
+ null,
+ "struct",
+ "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.",
+ "")
+ ("struct", (info, this))
}
-
- override def prettyName: String = "struct"
}
-
/**
- * Creates a struct with the given field names and values
- *
- * @param children Seq(name1, val1, name2, val2, ...)
+ * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]].
*/
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.")
-// scalastyle:on line.size.limit
-case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
+trait CreateNamedStructLike extends Expression {
+ lazy val (nameExprs, valExprs) = children.grouped(2).map {
+ case Seq(name, value) => (name, value)
+ }.toList.unzip
- /**
- * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
- * StructType.
- */
- def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
- case (v, n) => Alias(v, n.toString)()
- }
+ lazy val names = nameExprs.map(_.eval(EmptyRow))
- private lazy val (nameExprs, valExprs) =
- children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
+ override def nullable: Boolean = false
- private lazy val names = nameExprs.map(_.eval(EmptyRow))
+ override def foldable: Boolean = valExprs.forall(_.foldable)
override lazy val dataType: StructType = {
val fields = names.zip(valExprs).map {
- case (name, valExpr: NamedExpression) =>
- StructField(name.asInstanceOf[UTF8String].toString,
- valExpr.dataType, valExpr.nullable, valExpr.metadata)
- case (name, valExpr) =>
- StructField(name.asInstanceOf[UTF8String].toString,
- valExpr.dataType, valExpr.nullable, Metadata.empty)
+ case (name, expr) =>
+ val metadata = expr match {
+ case ne: NamedExpression => ne.metadata
+ case _ => Metadata.empty
+ }
+ StructField(name.toString, expr.dataType, expr.nullable, metadata)
}
StructType(fields)
}
- override def foldable: Boolean = valExprs.forall(_.foldable)
-
- override def nullable: Boolean = false
-
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
@@ -274,8 +246,8 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
if (invalidNames.nonEmpty) {
TypeCheckResult.TypeCheckFailure(
- s"Only foldable StringType expressions are allowed to appear at odd position , got :" +
- s" ${invalidNames.mkString(",")}")
+ "Only foldable StringType expressions are allowed to appear at odd position, got:" +
+ s" ${invalidNames.mkString(",")}")
} else if (!names.contains(null)) {
TypeCheckResult.TypeCheckSuccess
} else {
@@ -284,9 +256,29 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
}
}
+ /**
+ * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
+ * StructType.
+ */
+ def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
+ case (v, n) => Alias(v, n.toString)()
+ }
+
override def eval(input: InternalRow): Any = {
InternalRow(valExprs.map(_.eval(input)): _*)
}
+}
+
+/**
+ * Creates a struct with the given field names and values
+ *
+ * @param children Seq(name1, val1, name2, val2, ...)
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.")
+// scalastyle:on line.size.limit
+case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rowClass = classOf[GenericInternalRow].getName
@@ -317,75 +309,13 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
}
/**
- * Returns a Row containing the evaluation of all children expressions. This is a variant that
- * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with
- * this expression automatically at runtime.
- */
-case class CreateStructUnsafe(children: Seq[Expression]) extends Expression {
-
- override def foldable: Boolean = children.forall(_.foldable)
-
- override lazy val resolved: Boolean = childrenResolved
-
- override lazy val dataType: StructType = {
- val fields = children.zipWithIndex.map { case (child, idx) =>
- child match {
- case ne: NamedExpression =>
- StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
- case _ =>
- StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
- }
- }
- StructType(fields)
- }
-
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- InternalRow(children.map(_.eval(input)): _*)
- }
-
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val eval = GenerateUnsafeProjection.createCode(ctx, children)
- ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value)
- }
-
- override def prettyName: String = "struct_unsafe"
-}
-
-
-/**
* Creates a struct with the given field names and values. This is a variant that returns
* UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with
* this expression automatically at runtime.
*
* @param children Seq(name1, val1, name2, val2, ...)
*/
-case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression {
-
- private lazy val (nameExprs, valExprs) =
- children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
-
- private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)
-
- override lazy val dataType: StructType = {
- val fields = names.zip(valExprs).map {
- case (name, valExpr: NamedExpression) =>
- StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata)
- case (name, valExpr) =>
- StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
- }
- StructType(fields)
- }
-
- override def foldable: Boolean = valExprs.forall(_.foldable)
-
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- InternalRow(valExprs.map(_.eval(input)): _*)
- }
-
+case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value)
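
To make the naming rules above concrete, here is a minimal standalone sketch of the new CreateStruct.apply, again with simplified stand-in types rather than Spark's classes: a name is interleaved in front of every value, taken from the expression itself when it is a resolved named expression, deferred through NamePlaceholder when the expression is named but not yet resolved, and falling back to a positional colN otherwise.

sealed trait Expr
case object NamePlaceholder extends Expr
final case class Lit(value: String) extends Expr
final case class Attr(name: String, resolved: Boolean) extends Expr // named expression
final case class Plus(left: Expr, right: Expr) extends Expr         // unnamed expression

def createStruct(children: Seq[Expr]): Seq[Expr] =
  children.zipWithIndex.flatMap {
    case (e @ Attr(name, true), _) => Seq(Lit(name), e)           // use the resolved name
    case (e: Attr, _)              => Seq(NamePlaceholder, e)     // name known later
    case (e, i)                    => Seq(Lit(s"col${i + 1}"), e) // positional fallback
  }

// struct(a, a + 1) lowers to named_struct('a', a, 'col2', a + 1).
val a = Attr("a", resolved = true)
assert(createStruct(Seq(a, Plus(a, Lit("1")))) ==
  Seq(Lit("a"), a, Lit("col2"), Plus(a, Lit("1"))))
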
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index ac1577b3ab..4b151c81d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -688,8 +688,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// inline table comes in two styles:
// style 1: values (1), (2), (3) -- multiple columns are supported
// style 2: values 1, 2, 3 -- only a single column is supported here
- case CreateStruct(children) => children // style 1
- case child => Seq(child) // style 2
+ case struct: CreateNamedStruct => struct.valExprs // style 1
+ case child => Seq(child) // style 2
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 590774c043..817de48de2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
+import org.scalatest.ShouldMatchers
+
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -25,7 +27,8 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
-class AnalysisSuite extends AnalysisTest {
+
+class AnalysisSuite extends AnalysisTest with ShouldMatchers {
import org.apache.spark.sql.catalyst.analysis.TestRelations._
test("union project *") {
@@ -218,9 +221,36 @@ class AnalysisSuite extends AnalysisTest {
// CreateStruct is a special case that we should not trim Alias for it.
plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
- checkAnalysis(plan, plan)
- plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col"))
- checkAnalysis(plan, plan)
+ expected = testRelation.select(CreateNamedStruct(Seq(
+ Literal(a.name), a,
+ Literal("a+1"), (a + 1))).as("col"))
+ checkAnalysis(plan, expected)
+ }
+
+ test("Analysis may leave unnecessary aliases") {
+ val att1 = testRelation.output.head
+ var plan = testRelation.select(
+ CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"),
+ att1
+ )
+ val prevPlan = getAnalyzer(true).execute(plan)
+ plan = prevPlan.select(CreateArray(Seq(
+ CreateStruct(Seq(att1, (att1 + 1).as("a_plus_1"))).as("col1"),
+ /** alias should be eliminated by [[CleanupAliases]] */
+ "col".attr.as("col2")
+ )).as("arr"))
+ plan = getAnalyzer(true).execute(plan)
+
+ val expectedPlan = prevPlan.select(
+ CreateArray(Seq(
+ CreateNamedStruct(Seq(
+ Literal(att1.name), att1,
+ Literal("a_plus_1"), (att1 + 1))),
+ 'col.struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull
+ )).as("arr")
+ )
+
+ checkAnalysis(plan, expectedPlan)
}
test("SPARK-10534: resolve attribute references in order by clause") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 0c307b2b85..c21c6de32c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -243,7 +243,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
val b = AttributeReference("b", IntegerType)()
checkMetadata(CreateStruct(Seq(a, b)))
checkMetadata(CreateNamedStruct(Seq("a", a, "b", b)))
- checkMetadata(CreateStructUnsafe(Seq(a, b)))
checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 249408e0fb..7a131b30ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -186,6 +186,9 @@ class Column(val expr: Expression) extends Logging {
case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] =>
UnresolvedAlias(a, Some(Column.generateAlias))
+ // Wait until the struct is resolved. This will generate a nicer looking alias.
+ case struct: CreateNamedStructLike => UnresolvedAlias(struct)
+
case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index f873f34a84..6141fab4af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -137,7 +137,7 @@ object ColumnStatStruct {
private def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
- private def getStruct(exprs: Seq[Expression]): CreateStruct = {
+ private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = {
CreateStruct(exprs.map { expr: Expression =>
expr.transformUp {
case af: AggregateFunction => af.toAggregateExpression()
@@ -168,7 +168,7 @@ object ColumnStatStruct {
}
}
- def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match {
+ def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match {
// Use aggregate functions to compute statistics we need.
case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD))
case StringType => getStruct(stringColumnStat(attr, relativeSD))
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index a91f04e098..af6c930d64 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -87,7 +87,7 @@ struct<foo:string,approx_count_distinct(a):bigint>
-- !query 9
SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1
-- !query 9 schema
-struct<foo:string,max(struct(a)):struct<a:int>>
+struct<foo:string,max(named_struct(a, a)):struct<a:int>>
-- !query 9 output
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 6eb571b91f..90000445df 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -190,6 +190,12 @@ private[hive] class TestHiveSparkSession(
new File(Thread.currentThread().getContextClassLoader.getResource(path).getFile)
}
+ private def quoteHiveFile(path: String) = if (Utils.isWindows) {
+ getHiveFile(path).getPath.replace('\\', '/')
+ } else {
+ getHiveFile(path).getPath
+ }
+
def getWarehousePath(): String = {
val tempConf = new SQLConf
sc.conf.getAll.foreach { case (k, v) => tempConf.setConfString(k, v) }
@@ -225,16 +231,16 @@ private[hive] class TestHiveSparkSession(
val hiveQTestUtilTables: Seq[TestTable] = Seq(
TestTable("src",
"CREATE TABLE src (key INT, value STRING)".cmd,
- s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd),
+ s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd),
TestTable("src1",
"CREATE TABLE src1 (key INT, value STRING)".cmd,
- s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd),
+ s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd),
TestTable("srcpart", () => {
sql(
"CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)")
for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) {
sql(
- s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
+ s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}'
|OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr')
""".stripMargin)
}
@@ -244,7 +250,7 @@ private[hive] class TestHiveSparkSession(
"CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)")
for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) {
sql(
- s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
+ s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}'
|OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr')
""".stripMargin)
}
@@ -269,7 +275,7 @@ private[hive] class TestHiveSparkSession(
sql(
s"""
- |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}'
+ |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}'
|INTO TABLE src_thrift
""".stripMargin)
}),
@@ -308,7 +314,7 @@ private[hive] class TestHiveSparkSession(
|)
""".stripMargin.cmd,
s"""
- |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}'
+ |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/episodes.avro")}'
|INTO TABLE episodes
""".stripMargin.cmd
),
@@ -379,7 +385,7 @@ private[hive] class TestHiveSparkSession(
TestTable("src_json",
s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE
""".stripMargin.cmd,
- s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd)
+ s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd)
)
hiveQTestUtilTables.foreach(registerTestTable)
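
The quoteHiveFile helper above exists because the path is spliced verbatim into a LOAD DATA ... INPATH '...' statement, where Hive treats backslashes as escape characters; on Windows the separators are therefore normalized to forward slashes first. A minimal sketch of that normalization, with a plain boolean standing in for Spark's Utils.isWindows:

// Normalize a local file path before inlining it into a Hive LOAD DATA statement.
def quotePath(path: String, isWindows: Boolean): String =
  if (isWindows) path.replace('\\', '/') else path

assert(quotePath("""C:\spark\data\files\kv1.txt""", isWindows = true) ==
  "C:/spark/data/files/kv1.txt")
assert(quotePath("/spark/data/files/kv1.txt", isWindows = false) ==
  "/spark/data/files/kv1.txt")
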
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql
index de0116a4dc..cdda29af50 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql
@@ -7,4 +7,4 @@ having b.key in (select a.key
where a.value > 'val_9' and a.value = min(b.value))
order by b.key
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b
+SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (named_struct('gen_attr_0', `gen_attr_0`, 'gen_attr_4', `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
index c7f10e569f..12d18dc87c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, NoSuchFileException, Paths}
+import scala.io.Source
import scala.util.control.NonFatal
import org.apache.spark.sql.Column
@@ -109,12 +110,15 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
Files.write(path, answerText.getBytes(StandardCharsets.UTF_8))
} else {
val goldenFileName = s"sqlgen/$answerFile.sql"
- val resourceFile = getClass.getClassLoader.getResource(goldenFileName)
- if (resourceFile == null) {
+ val resourceStream = getClass.getClassLoader.getResourceAsStream(goldenFileName)
+ if (resourceStream == null) {
throw new NoSuchFileException(goldenFileName)
}
- val path = resourceFile.getPath
- val answerText = new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8)
+ val answerText = try {
+ Source.fromInputStream(resourceStream).mkString
+ } finally {
+ resourceStream.close
+ }
val sqls = answerText.split(separator)
assert(sqls.length == 2, "Golden sql files should have a separator.")
val expectedSQL = sqls(1).trim()
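
The switch from getResource to getResourceAsStream above matters because a resource URL only converts to a usable filesystem path when the resource sits in an exploded directory; once it is packaged inside a jar (or contains URL-encoded characters in the path), reading it via Paths.get fails. A minimal sketch of the stream-based pattern the test now uses; the object and method names here are illustrative:

import scala.io.Source

object GoldenFiles {
  // Read a classpath resource fully, failing clearly when it is missing.
  def read(name: String): String = {
    val stream = getClass.getClassLoader.getResourceAsStream(name)
    if (stream == null) throw new java.nio.file.NoSuchFileException(name)
    try Source.fromInputStream(stream).mkString
    finally stream.close()
  }
}

// e.g. GoldenFiles.read("sqlgen/subquery_in_having_2.sql")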