From 0cba535af3c65618f342fa2d7db9647f5e6f6f1b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 1 Nov 2016 17:30:37 +0100 Subject: Revert "[SPARK-16839][SQL] redundant aliases after cleanupAliases" This reverts commit 5441a6269e00e3903ae6c1ea8deb4ddf3d2e9975. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 +- .../spark/sql/catalyst/analysis/Analyzer.scala | 53 +++--- .../sql/catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/expressions/Projection.scala | 2 + .../catalyst/expressions/complexTypeCreator.scala | 211 ++++++++++++++------- .../spark/sql/catalyst/parser/AstBuilder.scala | 4 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 +--- .../catalyst/expressions/ComplexTypeSuite.scala | 1 + .../main/scala/org/apache/spark/sql/Column.scala | 3 - .../execution/command/AnalyzeColumnCommand.scala | 4 +- .../test/resources/sql-tests/inputs/group-by.sql | 2 +- .../resources/sql-tests/results/group-by.sql.out | 4 +- .../org/apache/spark/sql/hive/test/TestHive.scala | 20 +- .../test/resources/sqlgen/subquery_in_having_2.sql | 2 +- .../spark/sql/catalyst/LogicalPlanToSQLSuite.scala | 12 +- 15 files changed, 200 insertions(+), 170 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 5002655fc0..9289db57b6 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1222,16 +1222,16 @@ test_that("column functions", { # Test struct() df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) - result <- collect(select(df, alias(struct("a", "c"), "d"))) + result <- collect(select(df, struct("a", "c"))) expected <- data.frame(row.names = 1:2) - expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)), - listToStruct(list(a = 4L, c = 6L))) + expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) - result <- collect(select(df, alias(struct(df$a, df$b), "d"))) + result <- collect(select(df, struct(df$a, df$b))) expected <- data.frame(row.names = 1:2) - expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)), - listToStruct(list(a = 4L, b = 5L))) + expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) # Test encode(), decode() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5011f2fdbf..f8f4799322 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.trees.{TreeNodeRef} import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types._ @@ -83,7 +83,6 @@ class Analyzer( ResolveTableValuedFunctions :: ResolveRelations :: ResolveReferences :: - ResolveCreateNamedStruct :: ResolveDeserializer :: ResolveNewInstance :: ResolveUpCast :: @@ -654,12 +653,11 @@ class Analyzer( case s: Star => s.expand(child, resolver) case o => o :: Nil }) - case c: CreateNamedStruct 
if containsStar(c.valExprs) => - val newChildren = c.children.grouped(2).flatMap { - case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children - case kv => kv - } - c.copy(children = newChildren.toList ) + case c: CreateStruct if containsStar(c.children) => + c.copy(children = c.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) case c: CreateArray if containsStar(c.children) => c.copy(children = c.children.flatMap { case s: Star => s.expand(child, resolver) @@ -1143,7 +1141,7 @@ class Analyzer( case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => // Get the left hand side expressions. val expressions = e match { - case cns : CreateNamedStruct => cns.valExprs + case CreateStruct(exprs) => exprs case expr => Seq(expr) } resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => @@ -2074,8 +2072,18 @@ object EliminateUnions extends Rule[LogicalPlan] { */ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { + var stop = false e.transformDown { - case Alias(child, _) => child + // CreateStruct is a special case, we need to retain its top level Aliases as they decide the + // name of StructField. We also need to stop transform down this expression, or the Aliases + // under CreateStruct will be mistakenly trimmed. + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child } } @@ -2108,8 +2116,15 @@ object CleanupAliases extends Rule[LogicalPlan] { case a: AppendColumns => a case other => + var stop = false other transformExpressionsDown { - case Alias(child, _) => child + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child } } } @@ -2202,19 +2217,3 @@ object TimeWindowing extends Rule[LogicalPlan] { } } } - -/** - * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. 
- */ -object ResolveCreateNamedStruct extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { - case e: CreateNamedStruct if !e.resolved => - val children = e.children.grouped(2).flatMap { - case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => - Seq(Literal(e.name), e) - case kv => - kv - } - CreateNamedStruct(children.toList) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b028d07fb8..3e836ca375 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -357,7 +357,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), - CreateStruct.registryEntry, + expression[CreateStruct]("struct"), // misc functions expression[AssertTrue]("assert_true"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 03e054d098..a81fa1ce3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -119,6 +119,7 @@ object UnsafeProjection { */ def create(exprs: Seq[Expression]): UnsafeProjection = { val unsafeExprs = exprs.map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(unsafeExprs) @@ -144,6 +145,7 @@ object UnsafeProjection { subexpressionEliminationEnabled: Boolean): UnsafeProjection = { val e = exprs.map(BindReferences.bindReference(_, inputSchema)) .map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index e9623f96e1..917aa08731 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -174,70 +172,101 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } /** - * An expression representing a not yet available attribute name. 
This expression is unevaluable - * and as its name suggests it is a temporary place holder until we're able to determine the - * actual attribute name. + * Returns a Row containing the evaluation of all children expressions. */ -case object NamePlaceholder extends LeafExpression with Unevaluable { - override lazy val resolved: Boolean = false - override def foldable: Boolean = false +@ExpressionDescription( + usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.") +case class CreateStruct(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val dataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } + StructType(fields) + } + override def nullable: Boolean = false - override def dataType: DataType = StringType - override def prettyName: String = "NamePlaceholder" - override def toString: String = prettyName -} -/** - * Returns a Row containing the evaluation of all children expressions. - */ -object CreateStruct extends FunctionBuilder { - def apply(children: Seq[Expression]): CreateNamedStruct = { - CreateNamedStruct(children.zipWithIndex.flatMap { - case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) - case (e: NamedExpression, _) => Seq(NamePlaceholder, e) - case (e, index) => Seq(Literal(s"col${index + 1}"), e) - }) + override def eval(input: InternalRow): Any = { + InternalRow(children.map(_.eval(input)): _*) } - /** - * Entry to use in the function registry. - */ - val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = { - val info: ExpressionInfo = new ExpressionInfo( - "org.apache.spark.sql.catalyst.expressions.NamedStruct", - "struct", - "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.", - "") - ("struct", (info, this)) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val rowClass = classOf[GenericInternalRow].getName + val values = ctx.freshName("values") + ctx.addMutableState("Object[]", values, s"this.$values = null;") + + ev.copy(code = s""" + boolean ${ev.isNull} = false; + this.$values = new Object[${children.size}];""" + + ctx.splitExpressions( + ctx.INPUT_ROW, + children.zipWithIndex.map { case (e, i) => + val eval = e.genCode(ctx) + eval.code + s""" + if (${eval.isNull}) { + $values[$i] = null; + } else { + $values[$i] = ${eval.value}; + }""" + }) + + s""" + final InternalRow ${ev.value} = new $rowClass($values); + this.$values = null; + """) } + + override def prettyName: String = "struct" } + /** - * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]]. + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) */ -trait CreateNamedStructLike extends Expression { - lazy val (nameExprs, valExprs) = children.grouped(2).map { - case Seq(name, value) => (name, value) - }.toList.unzip +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(name1, val1, name2, val2, ...) 
- Creates a struct with the given field names and values.") +// scalastyle:on line.size.limit +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { - lazy val names = nameExprs.map(_.eval(EmptyRow)) + /** + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this + * StructType. + */ + def flatten: Seq[NamedExpression] = valExprs.zip(names).map { + case (v, n) => Alias(v, n.toString)() + } - override def nullable: Boolean = false + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - override def foldable: Boolean = valExprs.forall(_.foldable) + private lazy val names = nameExprs.map(_.eval(EmptyRow)) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { - case (name, expr) => - val metadata = expr match { - case ne: NamedExpression => ne.metadata - case _ => Metadata.empty - } - StructField(name.toString, expr.dataType, expr.nullable, metadata) + case (name, valExpr: NamedExpression) => + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, valExpr.metadata) + case (name, valExpr) => + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") @@ -245,8 +274,8 @@ trait CreateNamedStructLike extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") + s"Only foldable StringType expressions are allowed to appear at odd position , got :" + + s" ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { @@ -255,29 +284,9 @@ trait CreateNamedStructLike extends Expression { } } - /** - * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this - * StructType. - */ - def flatten: Seq[NamedExpression] = valExprs.zip(names).map { - case (v, n) => Alias(v, n.toString)() - } - override def eval(input: InternalRow): Any = { InternalRow(valExprs.map(_.eval(input)): _*) } -} - -/** - * Creates a struct with the given field names and values - * - * @param children Seq(name1, val1, name2, val2, ...) - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.") -// scalastyle:on line.size.limit -case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName @@ -307,6 +316,44 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def prettyName: String = "named_struct" } +/** + * Returns a Row containing the evaluation of all children expressions. This is a variant that + * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with + * this expression automatically at runtime. 
+ */ +case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + InternalRow(children.map(_.eval(input)): _*) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval = GenerateUnsafeProjection.createCode(ctx, children) + ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) + } + + override def prettyName: String = "struct_unsafe" +} + + /** * Creates a struct with the given field names and values. This is a variant that returns * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with @@ -314,7 +361,31 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc * * @param children Seq(name1, val1, name2, val2, ...) */ -case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { +case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression { + + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + + private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + + override lazy val dataType: StructType = { + val fields = names.zip(valExprs).map { + case (name, valExpr: NamedExpression) => + StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata) + case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + } + StructType(fields) + } + + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + InternalRow(valExprs.map(_.eval(input)): _*) + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 35aca91cf8..38e9bb6c16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -681,8 +681,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // inline table comes in two styles: // style 1: values (1), (2), (3) -- multiple columns are supported // style 2: values 1, 2, 3 -- only a single column is supported here - case struct: CreateNamedStruct => struct.valExprs // style 1 - case child => Seq(child) // style 2 + case CreateStruct(children) => children // style 1 + case child => Seq(child) // style 2 } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 817de48de2..590774c043 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.ShouldMatchers - import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -27,8 +25,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ - -class AnalysisSuite extends AnalysisTest with ShouldMatchers { +class AnalysisSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.analysis.TestRelations._ test("union project *") { @@ -221,36 +218,9 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { // CreateStruct is a special case that we should not trim Alias for it. plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col")) - expected = testRelation.select(CreateNamedStruct(Seq( - Literal(a.name), a, - Literal("a+1"), (a + 1))).as("col")) - checkAnalysis(plan, expected) - } - - test("Analysis may leave unnecassary aliases") { - val att1 = testRelation.output.head - var plan = testRelation.select( - CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"), - att1 - ) - val prevPlan = getAnalyzer(true).execute(plan) - plan = prevPlan.select(CreateArray(Seq( - CreateStruct(Seq(att1, (att1 + 1).as("a_plus_1"))).as("col1"), - /** alias should be eliminated by [[CleanupAliases]] */ - "col".attr.as("col2") - )).as("arr")) - plan = getAnalyzer(true).execute(plan) - - val expectedPlan = prevPlan.select( - CreateArray(Seq( - CreateNamedStruct(Seq( - Literal(att1.name), att1, - Literal("a_plus_1"), (att1 + 1))), - 'col.struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull - )).as("arr") - ) - - checkAnalysis(plan, expectedPlan) + checkAnalysis(plan, plan) + plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) } test("SPARK-10534: resolve attribute references in order by clause") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index c21c6de32c..0c307b2b85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -243,6 +243,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val b = AttributeReference("b", IntegerType)() checkMetadata(CreateStruct(Seq(a, b))) checkMetadata(CreateNamedStruct(Seq("a", a, "b", b))) + checkMetadata(CreateStructUnsafe(Seq(a, b))) checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 067b0bac63..05e867bf5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -183,9 +183,6 @@ class Column(protected[sql] val expr: Expression) extends Logging { case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) - // Wait until the struct is 
resolved. This will generate a nicer looking alias. - case struct: CreateNamedStructLike => UnresolvedAlias(struct) - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 6141fab4af..f873f34a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -137,7 +137,7 @@ object ColumnStatStruct { private def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) - private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = { + private def getStruct(exprs: Seq[Expression]): CreateStruct = { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() @@ -168,7 +168,7 @@ object ColumnStatStruct { } } - def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match { + def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match { // Use aggregate functions to compute statistics we need. case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD)) case StringType => getStruct(stringColumnStat(attr, relativeSD)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index d496af686d..6741703d9d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -14,4 +14,4 @@ select 'foo' from myview where int_col == 0 group by 1; select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1; -- group-by should not produce any rows (sort aggregate). 
-select 'foo', max(struct(int_col)) as agg_struct from myview where int_col == 0 group by 1; +select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index dede3a09ce..9127bd4dd4 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -44,8 +44,8 @@ struct -- !query 5 -select 'foo', max(struct(int_col)) as agg_struct from myview where int_col == 0 group by 1 +select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1 -- !query 5 schema -struct> +struct> -- !query 5 output diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 90000445df..6eb571b91f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -190,12 +190,6 @@ private[hive] class TestHiveSparkSession( new File(Thread.currentThread().getContextClassLoader.getResource(path).getFile) } - private def quoteHiveFile(path : String) = if (Utils.isWindows) { - getHiveFile(path).getPath.replace('\\', '/') - } else { - getHiveFile(path).getPath - } - def getWarehousePath(): String = { val tempConf = new SQLConf sc.conf.getAll.foreach { case (k, v) => tempConf.setConfString(k, v) } @@ -231,16 +225,16 @@ private[hive] class TestHiveSparkSession( val hiveQTestUtilTables: Seq[TestTable] = Seq( TestTable("src", "CREATE TABLE src (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), TestTable("src1", "CREATE TABLE src1 (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), TestTable("srcpart", () => { sql( "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { sql( - s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -250,7 +244,7 @@ private[hive] class TestHiveSparkSession( "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { sql( - s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') """.stripMargin) } @@ -275,7 +269,7 @@ private[hive] class TestHiveSparkSession( sql( s""" - |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' + |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' |INTO TABLE src_thrift """.stripMargin) }), @@ -314,7 +308,7 @@ private[hive] class TestHiveSparkSession( |) """.stripMargin.cmd, s""" - |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/episodes.avro")}' + |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' |INTO TABLE episodes """.stripMargin.cmd ), @@ -385,7 +379,7 @@ private[hive] class 
TestHiveSparkSession( TestTable("src_json", s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd) ) hiveQTestUtilTables.foreach(registerTestTable) diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql index cdda29af50..de0116a4dc 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql @@ -7,4 +7,4 @@ having b.key in (select a.key where a.value > 'val_9' and a.value = min(b.value)) order by b.key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (named_struct('gen_attr_0', `gen_attr_0`, 'gen_attr_4', `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 12d18dc87c..c7f10e569f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst import java.nio.charset.StandardCharsets import java.nio.file.{Files, NoSuchFileException, Paths} -import scala.io.Source import scala.util.control.NonFatal import org.apache.spark.sql.Column @@ -110,15 +109,12 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { Files.write(path, answerText.getBytes(StandardCharsets.UTF_8)) } else { val goldenFileName = s"sqlgen/$answerFile.sql" - val resourceStream = getClass.getClassLoader.getResourceAsStream(goldenFileName) - if (resourceStream == null) { + val resourceFile = getClass.getClassLoader.getResource(goldenFileName) + if (resourceFile == null) { throw new NoSuchFileException(goldenFileName) } - val answerText = 
try { - Source.fromInputStream(resourceStream).mkString - } finally { - resourceStream.close - } + val path = resourceFile.getPath + val answerText = new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8) val sqls = answerText.split(separator) assert(sqls.length == 2, "Golden sql files should have a separator.") val expectedSQL = sqls(1).trim() -- cgit v1.2.3
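
A note on the behavior this revert restores (an illustration added here, not part of the patch): the reinstated CreateStruct.dataType derives each StructField name from the child expression, so a NamedExpression child keeps its own name while any other child falls back to the positional name "col<N>". That is also why the updated test_sparkSQL.R expectations go back to a "struct(a, c)" column name. The Scala sketch below models just that naming rule; Expr, Named, Anon, and StructFieldNames are simplified stand-ins invented for this example, not Catalyst classes.

    // Minimal standalone sketch of the field-naming rule implemented by the
    // restored CreateStruct.dataType. Named stands in for NamedExpression;
    // Anon stands in for any other (unnamed) child expression.
    sealed trait Expr
    case class Named(name: String) extends Expr
    case class Anon(desc: String) extends Expr

    object StructFieldNames {
      def fieldNames(children: Seq[Expr]): Seq[String] =
        children.zipWithIndex.map {
          case (Named(n), _) => n              // named child keeps its name
          case (_, idx)      => s"col${idx + 1}" // unnamed child gets col<N>, 1-based
        }

      def main(args: Array[String]): Unit = {
        // struct(a, b + 1) yields fields "a" and "col2" under the restored rule.
        println(fieldNames(Seq(Named("a"), Anon("b + 1")))) // List(a, col2)
      }
    }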