From 7ad579ee972987863c09827554a6330aa54433b1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 17 Dec 2014 12:43:51 -0800 Subject: [SPARK-3698][SQL] Fix case insensitive resolution of GetField. Based on #2543. Author: Michael Armbrust Closes #3724 from marmbrus/resolveGetField and squashes the following commits: 0a47aae [Michael Armbrust] Fix case insensitive resolution of GetField. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 ++++++++++ .../apache/spark/sql/catalyst/expressions/complexTypes.scala | 8 +++++++- .../apache/spark/sql/hive/execution/HiveResolutionSuite.scala | 11 +++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ea9bb39786..3705fcc1f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.types.StructType /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -187,6 +188,15 @@ class Analyzer(catalog: Catalog, val result = q.resolveChildren(name, resolver).getOrElse(u) logDebug(s"Resolving $u to $result") result + + // Resolve field names using the resolver. + case f @ GetField(child, fieldName) if !f.resolved && child.resolved => + child.dataType match { + case StructType(fields) => + val resolvedFieldName = fields.map(_.name).find(resolver(_, fieldName)) + resolvedFieldName.map(n => f.copy(fieldName = n)).getOrElse(f) + case _ => f + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 917b346086..b12821d44b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -92,7 +92,13 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio lazy val ordinal = structType.fields.indexOf(field) - override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType] + override lazy val resolved = childrenResolved && fieldResolved + + /** Returns true only if the fieldName is found in the child struct. */ + private def fieldResolved = child.dataType match { + case StructType(fields) => fields.map(_.name).contains(fieldName) + case _ => false + } override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Row] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index ee9d08ff75..422e843d2b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -27,6 +27,17 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. */ class HiveResolutionSuite extends HiveComparisonTest { + + case class NestedData(a: Seq[NestedData2], B: NestedData2) + case class NestedData2(a: NestedData3, B: NestedData3) + case class NestedData3(a: Int, B: Int) + + test("SPARK-3698: case insensitive test for nested data") { + sparkContext.makeRDD(Seq.empty[NestedData]).registerTempTable("nested") + // This should be successfully analyzed + sql("SELECT a[0].A.A from nested").queryExecution.analyzed + } + createQueryTest("table.attr", "SELECT src.key FROM src ORDER BY key LIMIT 1") -- cgit v1.2.3