author    Herman van Hovell <hvanhovell@databricks.com>  2016-11-01 17:30:37 +0100
committer Herman van Hovell <hvanhovell@databricks.com>  2016-11-01 17:30:37 +0100
commit    0cba535af3c65618f342fa2d7db9647f5e6f6f1b (patch)
tree      5b16f0dd27997083403ef7c621c001e615189e2a
parent    5441a6269e00e3903ae6c1ea8deb4ddf3d2e9975 (diff)
Revert "[SPARK-16839][SQL] redundant aliases after cleanupAliases"
This reverts commit 5441a6269e00e3903ae6c1ea8deb4ddf3d2e9975.
 R/pkg/inst/tests/testthat/test_sparkSQL.R                                                      |  12
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala              |  53
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala      |   2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala         |   2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala | 211
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala              |   4
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala         |  38
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala   |   1
 sql/core/src/main/scala/org/apache/spark/sql/Column.scala                                      |   3
 sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala      |   4
 sql/core/src/test/resources/sql-tests/inputs/group-by.sql                                      |   2
 sql/core/src/test/resources/sql-tests/results/group-by.sql.out                                 |   4
 sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala                          |  20
 sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql                                    |   2
 sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala              |  12
 15 files changed, 200 insertions(+), 170 deletions(-)
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 5002655fc0..9289db57b6 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1222,16 +1222,16 @@ test_that("column functions", {
# Test struct()
df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)),
schema = c("a", "b", "c"))
- result <- collect(select(df, alias(struct("a", "c"), "d")))
+ result <- collect(select(df, struct("a", "c")))
expected <- data.frame(row.names = 1:2)
- expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)),
- listToStruct(list(a = 4L, c = 6L)))
+ expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)),
+ listToStruct(list(a = 4L, c = 6L)))
expect_equal(result, expected)
- result <- collect(select(df, alias(struct(df$a, df$b), "d")))
+ result <- collect(select(df, struct(df$a, df$b)))
expected <- data.frame(row.names = 1:2)
- expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)),
- listToStruct(list(a = 4L, b = 5L)))
+ expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)),
+ listToStruct(list(a = 4L, b = 5L)))
expect_equal(result, expected)
# Test encode(), decode()
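
The R test change above shows what this revert restores: without an explicit alias, a struct column is named after its SQL expression (e.g. "struct(a, c)") rather than requiring alias(). A minimal Scala sketch of the same behaviour, assuming a SparkSession named `spark` (illustrative only, not part of this commit):

    import org.apache.spark.sql.functions.struct

    // Hypothetical DataFrame with columns a, b, c.
    val df = spark.range(2).selectExpr("id AS a", "id + 1 AS b", "id + 2 AS c")

    // Unaliased struct: the column is named from the expression itself.
    df.select(struct("a", "c")).printSchema()             // column: struct(a, c)

    // An explicit alias still works, as in the pre-revert tests.
    df.select(struct("a", "c").alias("d")).printSchema()  // column: d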
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 5011f2fdbf..f8f4799322 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.trees.{TreeNodeRef}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.types._
@@ -83,7 +83,6 @@ class Analyzer(
ResolveTableValuedFunctions ::
ResolveRelations ::
ResolveReferences ::
- ResolveCreateNamedStruct ::
ResolveDeserializer ::
ResolveNewInstance ::
ResolveUpCast ::
@@ -654,12 +653,11 @@ class Analyzer(
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
- case c: CreateNamedStruct if containsStar(c.valExprs) =>
- val newChildren = c.children.grouped(2).flatMap {
- case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children
- case kv => kv
- }
- c.copy(children = newChildren.toList )
+ case c: CreateStruct if containsStar(c.children) =>
+ c.copy(children = c.children.flatMap {
+ case s: Star => s.expand(child, resolver)
+ case o => o :: Nil
+ })
case c: CreateArray if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
@@ -1143,7 +1141,7 @@ class Analyzer(
case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>
// Get the left hand side expressions.
val expressions = e match {
- case cns : CreateNamedStruct => cns.valExprs
+ case CreateStruct(exprs) => exprs
case expr => Seq(expr)
}
resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) =>
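
The hunk above routes a multi-column IN predicate through CreateStruct: the tuple on the left-hand side parses to a struct whose children are matched one-to-one against the subquery's output. A hedged sketch of the SQL shape involved, assuming tables t1(a, b) and t2(x, y) and a SparkSession `spark`:

    // (a, b) parses to CreateStruct(Seq(a, b)); the rule unpacks its
    // children so each can be paired with a subquery output column.
    spark.sql("SELECT * FROM t1 WHERE (a, b) IN (SELECT x, y FROM t2)")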
@@ -2074,8 +2072,18 @@ object EliminateUnions extends Rule[LogicalPlan] {
*/
object CleanupAliases extends Rule[LogicalPlan] {
private def trimAliases(e: Expression): Expression = {
+ var stop = false
e.transformDown {
- case Alias(child, _) => child
+ // CreateStruct is a special case, we need to retain its top level Aliases as they decide the
+ // name of StructField. We also need to stop transform down this expression, or the Aliases
+ // under CreateStruct will be mistakenly trimmed.
+ case c: CreateStruct if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case c: CreateStructUnsafe if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case Alias(child, _) if !stop => child
}
}
@@ -2108,8 +2116,15 @@ object CleanupAliases extends Rule[LogicalPlan] {
case a: AppendColumns => a
case other =>
+ var stop = false
other transformExpressionsDown {
- case Alias(child, _) => child
+ case c: CreateStruct if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case c: CreateStructUnsafe if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case Alias(child, _) if !stop => child
}
}
}
@@ -2202,19 +2217,3 @@ object TimeWindowing extends Rule[LogicalPlan] {
}
}
}
-
-/**
- * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s.
- */
-object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
- case e: CreateNamedStruct if !e.resolved =>
- val children = e.children.grouped(2).flatMap {
- case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
- Seq(Literal(e.name), e)
- case kv =>
- kv
- }
- CreateNamedStruct(children.toList)
- }
-}
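
The CleanupAliases change above restores a one-shot transformDown: the `stop` flag halts further rewriting once a CreateStruct (or CreateStructUnsafe) is seen, so the aliases directly under it, which decide its StructField names, survive while all others are trimmed. A self-contained sketch of the intended semantics on stand-in types (not the Catalyst classes, and written as plain recursion rather than the mutable-flag idiom):

    sealed trait Expr
    case class Ref(name: String) extends Expr
    case class Alias(child: Expr, name: String) extends Expr
    case class Struct(children: Seq[Expr]) extends Expr

    def trimAliases(e: Expr): Expr = e match {
      // Keep an alias sitting directly under a struct: it names the
      // field. Recurse below it to trim any deeper aliases.
      case Struct(cs) => Struct(cs.map {
        case Alias(c, n) => Alias(trimAliases(c), n)
        case c           => trimAliases(c)
      })
      // Everywhere else an alias is redundant and can be dropped.
      case Alias(c, _) => trimAliases(c)
      case other       => other
    }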
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index b028d07fb8..3e836ca375 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -357,7 +357,7 @@ object FunctionRegistry {
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
- CreateStruct.registryEntry,
+ expression[CreateStruct]("struct"),
// misc functions
expression[AssertTrue]("assert_true"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 03e054d098..a81fa1ce3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -119,6 +119,7 @@ object UnsafeProjection {
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
val unsafeExprs = exprs.map(_ transform {
+ case CreateStruct(children) => CreateStructUnsafe(children)
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(unsafeExprs)
@@ -144,6 +145,7 @@ object UnsafeProjection {
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
val e = exprs.map(BindReferences.bindReference(_, inputSchema))
.map(_ transform {
+ case CreateStruct(children) => CreateStructUnsafe(children)
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index e9623f96e1..917aa08731 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -18,11 +18,9 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -174,70 +172,101 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
}
/**
- * An expression representing a not yet available attribute name. This expression is unevaluable
- * and as its name suggests it is a temporary place holder until we're able to determine the
- * actual attribute name.
+ * Returns a Row containing the evaluation of all children expressions.
*/
-case object NamePlaceholder extends LeafExpression with Unevaluable {
- override lazy val resolved: Boolean = false
- override def foldable: Boolean = false
+@ExpressionDescription(
+ usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.")
+case class CreateStruct(children: Seq[Expression]) extends Expression {
+
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ override lazy val dataType: StructType = {
+ val fields = children.zipWithIndex.map { case (child, idx) =>
+ child match {
+ case ne: NamedExpression =>
+ StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
+ case _ =>
+ StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
+ }
+ }
+ StructType(fields)
+ }
+
override def nullable: Boolean = false
- override def dataType: DataType = StringType
- override def prettyName: String = "NamePlaceholder"
- override def toString: String = prettyName
-}
-/**
- * Returns a Row containing the evaluation of all children expressions.
- */
-object CreateStruct extends FunctionBuilder {
- def apply(children: Seq[Expression]): CreateNamedStruct = {
- CreateNamedStruct(children.zipWithIndex.flatMap {
- case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e)
- case (e: NamedExpression, _) => Seq(NamePlaceholder, e)
- case (e, index) => Seq(Literal(s"col${index + 1}"), e)
- })
+ override def eval(input: InternalRow): Any = {
+ InternalRow(children.map(_.eval(input)): _*)
}
- /**
- * Entry to use in the function registry.
- */
- val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = {
- val info: ExpressionInfo = new ExpressionInfo(
- "org.apache.spark.sql.catalyst.expressions.NamedStruct",
- "struct",
- "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.",
- "")
- ("struct", (info, this))
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val rowClass = classOf[GenericInternalRow].getName
+ val values = ctx.freshName("values")
+ ctx.addMutableState("Object[]", values, s"this.$values = null;")
+
+ ev.copy(code = s"""
+ boolean ${ev.isNull} = false;
+ this.$values = new Object[${children.size}];""" +
+ ctx.splitExpressions(
+ ctx.INPUT_ROW,
+ children.zipWithIndex.map { case (e, i) =>
+ val eval = e.genCode(ctx)
+ eval.code + s"""
+ if (${eval.isNull}) {
+ $values[$i] = null;
+ } else {
+ $values[$i] = ${eval.value};
+ }"""
+ }) +
+ s"""
+ final InternalRow ${ev.value} = new $rowClass($values);
+ this.$values = null;
+ """)
}
+
+ override def prettyName: String = "struct"
}
+
/**
- * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]].
+ * Creates a struct with the given field names and values
+ *
+ * @param children Seq(name1, val1, name2, val2, ...)
*/
-trait CreateNamedStructLike extends Expression {
- lazy val (nameExprs, valExprs) = children.grouped(2).map {
- case Seq(name, value) => (name, value)
- }.toList.unzip
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.")
+// scalastyle:on line.size.limit
+case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
- lazy val names = nameExprs.map(_.eval(EmptyRow))
+ /**
+ * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
+ * StructType.
+ */
+ def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
+ case (v, n) => Alias(v, n.toString)()
+ }
- override def nullable: Boolean = false
+ private lazy val (nameExprs, valExprs) =
+ children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
- override def foldable: Boolean = valExprs.forall(_.foldable)
+ private lazy val names = nameExprs.map(_.eval(EmptyRow))
override lazy val dataType: StructType = {
val fields = names.zip(valExprs).map {
- case (name, expr) =>
- val metadata = expr match {
- case ne: NamedExpression => ne.metadata
- case _ => Metadata.empty
- }
- StructField(name.toString, expr.dataType, expr.nullable, metadata)
+ case (name, valExpr: NamedExpression) =>
+ StructField(name.asInstanceOf[UTF8String].toString,
+ valExpr.dataType, valExpr.nullable, valExpr.metadata)
+ case (name, valExpr) =>
+ StructField(name.asInstanceOf[UTF8String].toString,
+ valExpr.dataType, valExpr.nullable, Metadata.empty)
}
StructType(fields)
}
+ override def foldable: Boolean = valExprs.forall(_.foldable)
+
+ override def nullable: Boolean = false
+
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
@@ -245,8 +274,8 @@ trait CreateNamedStructLike extends Expression {
val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
if (invalidNames.nonEmpty) {
TypeCheckResult.TypeCheckFailure(
- "Only foldable StringType expressions are allowed to appear at odd position, got:" +
- s" ${invalidNames.mkString(",")}")
+ s"Only foldable StringType expressions are allowed to appear at odd position , got :" +
+ s" ${invalidNames.mkString(",")}")
} else if (!names.contains(null)) {
TypeCheckResult.TypeCheckSuccess
} else {
@@ -255,29 +284,9 @@ trait CreateNamedStructLike extends Expression {
}
}
- /**
- * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
- * StructType.
- */
- def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
- case (v, n) => Alias(v, n.toString)()
- }
-
override def eval(input: InternalRow): Any = {
InternalRow(valExprs.map(_.eval(input)): _*)
}
-}
-
-/**
- * Creates a struct with the given field names and values
- *
- * @param children Seq(name1, val1, name2, val2, ...)
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.")
-// scalastyle:on line.size.limit
-case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rowClass = classOf[GenericInternalRow].getName
@@ -308,13 +317,75 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
}
/**
+ * Returns a Row containing the evaluation of all children expressions. This is a variant that
+ * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with
+ * this expression automatically at runtime.
+ */
+case class CreateStructUnsafe(children: Seq[Expression]) extends Expression {
+
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ override lazy val resolved: Boolean = childrenResolved
+
+ override lazy val dataType: StructType = {
+ val fields = children.zipWithIndex.map { case (child, idx) =>
+ child match {
+ case ne: NamedExpression =>
+ StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
+ case _ =>
+ StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
+ }
+ }
+ StructType(fields)
+ }
+
+ override def nullable: Boolean = false
+
+ override def eval(input: InternalRow): Any = {
+ InternalRow(children.map(_.eval(input)): _*)
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val eval = GenerateUnsafeProjection.createCode(ctx, children)
+ ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value)
+ }
+
+ override def prettyName: String = "struct_unsafe"
+}
+
+
+/**
* Creates a struct with the given field names and values. This is a variant that returns
* UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with
* this expression automatically at runtime.
*
* @param children Seq(name1, val1, name2, val2, ...)
*/
-case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike {
+case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression {
+
+ private lazy val (nameExprs, valExprs) =
+ children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
+
+ private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)
+
+ override lazy val dataType: StructType = {
+ val fields = names.zip(valExprs).map {
+ case (name, valExpr: NamedExpression) =>
+ StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata)
+ case (name, valExpr) =>
+ StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
+ }
+ StructType(fields)
+ }
+
+ override def foldable: Boolean = valExprs.forall(_.foldable)
+
+ override def nullable: Boolean = false
+
+ override def eval(input: InternalRow): Any = {
+ InternalRow(valExprs.map(_.eval(input)): _*)
+ }
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value)
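
The dataType logic restored above names struct fields after their children: a NamedExpression contributes its own name, and anything else falls back to a positional col1, col2, ... default. A quick sketch, assuming a SparkSession `spark` (the schema in the comments is indicative):

    val df = spark.range(1).selectExpr("id AS a")
    df.selectExpr("struct(a, a + 1) AS s").printSchema()
    // root
    //  |-- s: struct
    //  |    |-- a: long      <- named child keeps its name
    //  |    |-- col2: long   <- unnamed child gets the positional default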
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 35aca91cf8..38e9bb6c16 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -681,8 +681,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// inline table comes in two styles:
// style 1: values (1), (2), (3) -- multiple columns are supported
// style 2: values 1, 2, 3 -- only a single column is supported here
- case struct: CreateNamedStruct => struct.valExprs // style 1
- case child => Seq(child) // style 2
+ case CreateStruct(children) => children // style 1
+ case child => Seq(child) // style 2
}
}
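
The parser comment above distinguishes the two inline-table shapes. A hedged pair of examples, assuming a SparkSession `spark`:

    // Style 1: each row is a struct literal; multiple columns supported.
    spark.sql("SELECT * FROM VALUES (1, 'a'), (2, 'b') AS t(id, name)")

    // Style 2: bare expressions; only a single column is supported.
    spark.sql("SELECT * FROM VALUES 1, 2, 3 AS t(id)")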
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 817de48de2..590774c043 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import org.scalatest.ShouldMatchers
-
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -27,8 +25,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
-
-class AnalysisSuite extends AnalysisTest with ShouldMatchers {
+class AnalysisSuite extends AnalysisTest {
import org.apache.spark.sql.catalyst.analysis.TestRelations._
test("union project *") {
@@ -221,36 +218,9 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
// CreateStruct is a special case that we should not trim Alias for it.
plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
- expected = testRelation.select(CreateNamedStruct(Seq(
- Literal(a.name), a,
- Literal("a+1"), (a + 1))).as("col"))
- checkAnalysis(plan, expected)
- }
-
- test("Analysis may leave unnecassary aliases") {
- val att1 = testRelation.output.head
- var plan = testRelation.select(
- CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"),
- att1
- )
- val prevPlan = getAnalyzer(true).execute(plan)
- plan = prevPlan.select(CreateArray(Seq(
- CreateStruct(Seq(att1, (att1 + 1).as("a_plus_1"))).as("col1"),
- /** alias should be eliminated by [[CleanupAliases]] */
- "col".attr.as("col2")
- )).as("arr"))
- plan = getAnalyzer(true).execute(plan)
-
- val expectedPlan = prevPlan.select(
- CreateArray(Seq(
- CreateNamedStruct(Seq(
- Literal(att1.name), att1,
- Literal("a_plus_1"), (att1 + 1))),
- 'col.struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull
- )).as("arr")
- )
-
- checkAnalysis(plan, expectedPlan)
+ checkAnalysis(plan, plan)
+ plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col"))
+ checkAnalysis(plan, plan)
}
test("SPARK-10534: resolve attribute references in order by clause") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index c21c6de32c..0c307b2b85 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -243,6 +243,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
val b = AttributeReference("b", IntegerType)()
checkMetadata(CreateStruct(Seq(a, b)))
checkMetadata(CreateNamedStruct(Seq("a", a, "b", b)))
+ checkMetadata(CreateStructUnsafe(Seq(a, b)))
checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 067b0bac63..05e867bf5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -183,9 +183,6 @@ class Column(protected[sql] val expr: Expression) extends Logging {
case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] =>
UnresolvedAlias(a, Some(Column.generateAlias))
- // Wait until the struct is resolved. This will generate a nicer looking alias.
- case struct: CreateNamedStructLike => UnresolvedAlias(struct)
-
case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index 6141fab4af..f873f34a84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -137,7 +137,7 @@ object ColumnStatStruct {
private def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
- private def getStruct(exprs: Seq[Expression]): CreateNamedStruct = {
+ private def getStruct(exprs: Seq[Expression]): CreateStruct = {
CreateStruct(exprs.map { expr: Expression =>
expr.transformUp {
case af: AggregateFunction => af.toAggregateExpression()
@@ -168,7 +168,7 @@ object ColumnStatStruct {
}
}
- def apply(attr: Attribute, relativeSD: Double): CreateNamedStruct = attr.dataType match {
+ def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match {
// Use aggregate functions to compute statistics we need.
case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD))
case StringType => getStruct(stringColumnStat(attr, relativeSD))
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index d496af686d..6741703d9d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -14,4 +14,4 @@ select 'foo' from myview where int_col == 0 group by 1;
select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1;
-- group-by should not produce any rows (sort aggregate).
-select 'foo', max(struct(int_col)) as agg_struct from myview where int_col == 0 group by 1;
+select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1;
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index dede3a09ce..9127bd4dd4 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -44,8 +44,8 @@ struct<foo:string,approx_count_distinct(int_col):bigint>
-- !query 5
-select 'foo', max(struct(int_col)) as agg_struct from myview where int_col == 0 group by 1
+select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1
-- !query 5 schema
-struct<foo:string,agg_struct:struct<int_col:int>>
+struct<foo:string,max(struct(int_col)):struct<int_col:int>>
-- !query 5 output
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 90000445df..6eb571b91f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -190,12 +190,6 @@ private[hive] class TestHiveSparkSession(
new File(Thread.currentThread().getContextClassLoader.getResource(path).getFile)
}
- private def quoteHiveFile(path : String) = if (Utils.isWindows) {
- getHiveFile(path).getPath.replace('\\', '/')
- } else {
- getHiveFile(path).getPath
- }
-
def getWarehousePath(): String = {
val tempConf = new SQLConf
sc.conf.getAll.foreach { case (k, v) => tempConf.setConfString(k, v) }
@@ -231,16 +225,16 @@ private[hive] class TestHiveSparkSession(
val hiveQTestUtilTables: Seq[TestTable] = Seq(
TestTable("src",
"CREATE TABLE src (key INT, value STRING)".cmd,
- s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd),
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd),
TestTable("src1",
"CREATE TABLE src1 (key INT, value STRING)".cmd,
- s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd),
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd),
TestTable("srcpart", () => {
sql(
"CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)")
for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) {
sql(
- s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}'
+ s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
|OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr')
""".stripMargin)
}
@@ -250,7 +244,7 @@ private[hive] class TestHiveSparkSession(
"CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)")
for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) {
sql(
- s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}'
+ s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
|OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr')
""".stripMargin)
}
@@ -275,7 +269,7 @@ private[hive] class TestHiveSparkSession(
sql(
s"""
- |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}'
+ |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}'
|INTO TABLE src_thrift
""".stripMargin)
}),
@@ -314,7 +308,7 @@ private[hive] class TestHiveSparkSession(
|)
""".stripMargin.cmd,
s"""
- |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/episodes.avro")}'
+ |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}'
|INTO TABLE episodes
""".stripMargin.cmd
),
@@ -385,7 +379,7 @@ private[hive] class TestHiveSparkSession(
TestTable("src_json",
s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE
""".stripMargin.cmd,
- s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd)
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd)
)
hiveQTestUtilTables.foreach(registerTestTable)
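
The deleted quoteHiveFile helper normalised Windows backslashes to forward slashes before interpolating test-data paths into LOAD DATA statements; the revert inlines getHiveFile and drops that normalisation along with the rest of the patch. A sketch of what the helper did, with `isWindows` standing in for org.apache.spark.util.Utils.isWindows:

    import java.io.File

    // Normalise a test-data path so LOAD DATA accepts it on Windows.
    def quoteHiveFile(f: File, isWindows: Boolean): String =
      if (isWindows) f.getPath.replace('\\', '/') else f.getPath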
diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql
index cdda29af50..de0116a4dc 100644
--- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql
+++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql
@@ -7,4 +7,4 @@ having b.key in (select a.key
where a.value > 'val_9' and a.value = min(b.value))
order by b.key
--------------------------------------------------------------------------------
-SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (named_struct('gen_attr_0', `gen_attr_0`, 'gen_attr_4', `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b
+SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS b
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
index 12d18dc87c..c7f10e569f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, NoSuchFileException, Paths}
-import scala.io.Source
import scala.util.control.NonFatal
import org.apache.spark.sql.Column
@@ -110,15 +109,12 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
Files.write(path, answerText.getBytes(StandardCharsets.UTF_8))
} else {
val goldenFileName = s"sqlgen/$answerFile.sql"
- val resourceStream = getClass.getClassLoader.getResourceAsStream(goldenFileName)
- if (resourceStream == null) {
+ val resourceFile = getClass.getClassLoader.getResource(goldenFileName)
+ if (resourceFile == null) {
throw new NoSuchFileException(goldenFileName)
}
- val answerText = try {
- Source.fromInputStream(resourceStream).mkString
- } finally {
- resourceStream.close
- }
+ val path = resourceFile.getPath
+ val answerText = new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8)
val sqls = answerText.split(separator)
assert(sqls.length == 2, "Golden sql files should have a separator.")
val expectedSQL = sqls(1).trim()
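
The final hunk swaps a stream-based golden-file read for a URL-path read. The removed variant is sketched below for contrast; unlike getResource().getPath, getResourceAsStream also works when the resource lives inside a jar (a hedged observation, not a claim made by the commit itself):

    // Removed approach (also works for jar-packaged resources):
    def readResource(name: String): String = {
      val in = getClass.getClassLoader.getResourceAsStream(name)
      if (in == null) throw new java.nio.file.NoSuchFileException(name)
      try scala.io.Source.fromInputStream(in).mkString finally in.close()
    }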