diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-03-21 22:22:15 +0800 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2016-03-21 22:22:15 +0800 |
commit | 17a3f00676ca02155557f6ee55a1565e96893792 (patch) | |
tree | f82a7bab8007eae9bf8d1b7c586ba3f9c8dd4c44 /sql | |
parent | 2c5b18fb0fdeabd378dd97e91f72d1eac4e21cc7 (diff) | |
download | spark-17a3f00676ca02155557f6ee55a1565e96893792.tar.gz spark-17a3f00676ca02155557f6ee55a1565e96893792.tar.bz2 spark-17a3f00676ca02155557f6ee55a1565e96893792.zip |
[SPARK-14000][SQL] case class with a tuple field can't work in Dataset
## What changes were proposed in this pull request?
When we validate an encoder, we may call `dataType` on unresolved expressions. This PR fix the validation so that we will resolve attributes first.
## How was this patch tested?
a new test in `DatasetSuite`
Author: Wenchen Fan <wenchen@databricks.com>
Closes #11816 from cloud-fan/encoder.
Diffstat (limited to 'sql')
4 files changed, 29 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 333a54ee76..ccc65b4e52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -600,7 +600,10 @@ class Analyzer( exprs.exists(_.collect { case _: Star => true }.nonEmpty) } - private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = { + protected[sql] def resolveExpression( + expr: Expression, + plan: LogicalPlan, + throws: Boolean = false) = { // Resolve expression in one round. // If throws == false or the desired attribute doesn't exist // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 58f6d0eb9e..918233ddcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -282,8 +282,14 @@ case class ExpressionEncoder[T]( // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid. // Note that, `BoundReference` contains the expected type, but here we need the actual type, so - // we unbound it by the given `schema` and propagate the actual type to `GetStructField`. - val unbound = fromRowExpression transform { + // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after + // we resolve the `fromRowExpression`. + val resolved = SimpleAnalyzer.resolveExpression( + fromRowExpression, + LocalRelation(schema), + throws = true) + + val unbound = resolved transform { case b: BoundReference => schema(b.ordinal) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 44cdc8d881..c06dcc9867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -110,7 +110,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] override def dataType: DataType = childSchema(ordinal).dataType override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable - override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}" + + override def toString: String = { + val fieldName = if (resolved) childSchema(ordinal).name else s"_$ordinal" + s"$child.${name.getOrElse(fieldName)}" + } + override def sql: String = child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d7fa23651b..04d3a25fcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -case class OtherTuple(_1: String, _2: Int) - class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -636,8 +634,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(OuterObject.InnerClass("foo")).toDS(), OuterObject.InnerClass("foo")) } + + test("SPARK-14000: case class with tuple type field") { + checkDataset( + Seq(TupleClass((1, "a"))).toDS(), + TupleClass(1, "a") + ) + } } +case class OtherTuple(_1: String, _2: Int) + +case class TupleClass(data: (Int, String)) + class OuterClass extends Serializable { case class InnerClass(a: String) } |