author     Wenchen Fan <wenchen@databricks.com>    2016-03-21 22:22:15 +0800
committer  Cheng Lian <lian@databricks.com>        2016-03-21 22:22:15 +0800
commit     17a3f00676ca02155557f6ee55a1565e96893792 (patch)
tree       f82a7bab8007eae9bf8d1b7c586ba3f9c8dd4c44 /sql
parent     2c5b18fb0fdeabd378dd97e91f72d1eac4e21cc7 (diff)
[SPARK-14000][SQL] case class with a tuple field can't work in Dataset
## What changes were proposed in this pull request?

When we validate an encoder, we may call `dataType` on unresolved expressions. This PR fixes the validation so that we resolve attributes first.

## How was this patch tested?

A new test in `DatasetSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11816 from cloud-fan/encoder.
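For context, a minimal repro of the reported failure, mirroring the new `DatasetSuite` test in this patch (a sketch assuming a Spark shell with the SQL implicits imported):

    case class TupleClass(data: (Int, String))

    // Before this patch, encoder validation called `dataType` on unresolved
    // expressions and failed for case classes with tuple-typed fields.
    val ds = Seq(TupleClass((1, "a"))).toDS()
    ds.collect()  // after the fix: Array(TupleClass((1,a)))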
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                 |  5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala        | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala |  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                                   | 13
4 files changed, 29 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 333a54ee76..ccc65b4e52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -600,7 +600,10 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
- private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = {
+ protected[sql] def resolveExpression(
+ expr: Expression,
+ plan: LogicalPlan,
+ throws: Boolean = false) = {
// Resolve expression in one round.
// If throws == false or the desired attribute doesn't exist
// (like trying to resolve `a.b` when `a` doesn't exist), fail and return the original one.
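For reference, the one-round contract that the comment above describes looks roughly like this (a simplified sketch, not the method's exact body; it assumes Catalyst's `UnresolvedAttribute`, `UnresolvedExtractValue`, and `ExtractValue` helpers and the analyzer's `resolver`):

    // Simplified sketch: resolve attributes in a single pass. If resolution
    // fails and `throws` is false, return the original unresolved expression
    // instead of failing the whole analysis.
    try {
      expr transformUp {
        case u @ UnresolvedAttribute(nameParts) =>
          plan.resolve(nameParts, resolver).getOrElse(u)
        case UnresolvedExtractValue(child, fieldName) if child.resolved =>
          ExtractValue(child, fieldName, resolver)
      }
    } catch {
      case _: AnalysisException if !throws => expr
    }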
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 58f6d0eb9e..918233ddcd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -282,8 +282,14 @@ case class ExpressionEncoder[T](
// If we have a nested tuple, the `fromRowExpression` will contain `GetStructField` instead of
// `UnresolvedExtractValue`, so we need to check that their ordinals are all valid.
// Note that `BoundReference` contains the expected type, but here we need the actual type, so
- // we unbound it by the given `schema` and propagate the actual type to `GetStructField`.
- val unbound = fromRowExpression transform {
+ // we unbind it using the given `schema` and propagate the actual type to `GetStructField`, after
+ // we resolve the `fromRowExpression`.
+ val resolved = SimpleAnalyzer.resolveExpression(
+ fromRowExpression,
+ LocalRelation(schema),
+ throws = true)
+
+ val unbound = resolved transform {
case b: BoundReference => schema(b.ordinal)
}
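The resolve-first ordering is what makes the subsequent ordinal check safe. That check can be sketched as follows (a hypothetical simplification of the surrounding `resolve` method; `fail` stands in for the actual error reporting):

    // Sketch: after resolution and unbinding, every GetStructField must index
    // into its child's actual struct type; a bad ordinal means the encoder
    // was built against a mismatched schema.
    unbound foreach {
      case g: GetStructField =>
        val fieldCount = g.child.dataType.asInstanceOf[StructType].length
        if (g.ordinal < 0 || g.ordinal >= fieldCount) {
          fail(s"Invalid ordinal ${g.ordinal} for schema ${g.child.dataType}")
        }
      case _ =>
    }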
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 44cdc8d881..c06dcc9867 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -110,7 +110,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
override def dataType: DataType = childSchema(ordinal).dataType
override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable
- override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}"
+
+ override def toString: String = {
+ val fieldName = if (resolved) childSchema(ordinal).name else s"_$ordinal"
+ s"$child.${name.getOrElse(fieldName)}"
+ }
+
override def sql: String =
child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}"
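Previously, calling `toString` on an unresolved `GetStructField` could itself throw, because `childSchema` requires `child.dataType` to be a resolved struct type. A hedged illustration (hypothetical REPL values):

    // With an unresolved child, toString now falls back to the positional
    // placeholder "_<ordinal>" instead of throwing:
    GetStructField(UnresolvedAttribute("data"), 0).toString
    //   "'data._0"
    // Once the child is resolved, the actual field name is printed, e.g.
    //   "data#0._1"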
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d7fa23651b..04d3a25fcb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -27,8 +27,6 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-case class OtherTuple(_1: String, _2: Int)
-
class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -636,8 +634,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Seq(OuterObject.InnerClass("foo")).toDS(),
OuterObject.InnerClass("foo"))
}
+
+ test("SPARK-14000: case class with tuple type field") {
+ checkDataset(
+ Seq(TupleClass((1, "a"))).toDS(),
+ TupleClass((1, "a"))
+ )
+ }
}
+case class OtherTuple(_1: String, _2: Int)
+
+case class TupleClass(data: (Int, String))
+
class OuterClass extends Serializable {
case class InnerClass(a: String)
}