author     Wenchen Fan <wenchen@databricks.com>        2015-12-16 13:18:56 -0800
committer  Michael Armbrust <michael@databricks.com>   2015-12-16 13:20:12 -0800
commit     a783a8ed49814a09fde653433a3d6de398ddf888 (patch)
tree       9d6a7be7840d682edfefd16a450ce4128c882a4c
parent     1a8b2a17db7ab7a213d553079b83274aeebba86f (diff)
[SPARK-12320][SQL] throw exception if the number of fields does not line up for Tuple encoder
Author: Wenchen Fan <wenchen@databricks.com>

Closes #10293 from cloud-fan/err-msg.
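For context, this is the user-facing situation the new check targets: a Dataset conversion whose tuple encoder expects fewer (or more) fields than the incoming rows provide. A minimal sketch against the Spark 1.6-era API follows; the column names and the SQLContext wiring are illustrative only and are not taken from this patch.

    // Sketch only: assumes an existing Spark 1.6-style SQLContext named sqlContext.
    import sqlContext.implicits._

    // Three columns: a (string), b (bigint), c (int)...
    val df = sqlContext.range(3).selectExpr("'a' as a", "id as b", "1 as c")

    // ...but the target encoder is Tuple2[String, Long]. With this patch the mismatch
    // fails analysis with a message like:
    //   Try to map struct<a:string,b:bigint,c:int> to Tuple2,
    //   but failed as the number of fields does not line up.
    val ds = df.as[(String, Long)]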
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala                        |  3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala         | 36
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala  | 10
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala    | 60
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala       |  2
5 files changed, 93 insertions(+), 18 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index e50971173c..8102c93c6f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -227,9 +227,10 @@ package object dsl {
         AttributeReference(s, mapType, nullable = true)()
 
       /** Creates a new AttributeReference of type struct */
-      def struct(fields: StructField*): AttributeReference = struct(StructType(fields))
       def struct(structType: StructType): AttributeReference =
         AttributeReference(s, structType, nullable = true)()
+      def struct(attrs: AttributeReference*): AttributeReference =
+        struct(StructType.fromAttributes(attrs))
     }
 
     implicit class DslAttribute(a: AttributeReference) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 363178b0e2..7a4401cf58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -244,9 +244,41 @@ case class ExpressionEncoder[T](
   def resolve(
       schema: Seq[Attribute],
       outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(schema)
+    def fail(st: StructType, maxOrdinal: Int): Unit = {
+      throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " +
+        "but failed as the number of fields does not line up.\n" +
+        " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" +
+        " - Target schema: " + this.schema.simpleString)
+    }
+
+    var maxOrdinal = -1
+    fromRowExpression.foreach {
+      case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
+      case _ =>
+    }
+    if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) {
+      fail(StructType.fromAttributes(schema), maxOrdinal)
+    }
+
     val unbound = fromRowExpression transform {
-      case b: BoundReference => positionToAttribute(b.ordinal)
+      case b: BoundReference => schema(b.ordinal)
+    }
+
+    val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int]
+    unbound.foreach {
+      case g: GetStructField =>
+        val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1)
+        if (maxOrdinal < g.ordinal) {
+          exprToMaxOrdinal.update(g.child, g.ordinal)
+        }
+      case _ =>
+    }
+    exprToMaxOrdinal.foreach {
+      case (expr, maxOrdinal) =>
+        val schema = expr.dataType.asInstanceOf[StructType]
+        if (maxOrdinal != schema.length - 1) {
+          fail(schema, maxOrdinal)
+        }
     }
 
     val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
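The logic added to resolve() above amounts to two passes: first find the largest top-level BoundReference ordinal in the deserializer expression and require it to be the last field of the input schema, then do the same per struct-typed child for GetStructField ordinals. Below is a minimal standalone sketch of the first pass using toy stand-ins rather than Spark's expression classes; everything in it is hypothetical illustration, not project code.

    // Toy stand-ins for illustration only; these are not Spark's expression classes.
    object FieldCountCheck {
      sealed trait Expr { def children: Seq[Expr] }
      case class BoundRef(ordinal: Int) extends Expr { def children: Seq[Expr] = Nil }
      case class MakeTuple(args: Expr*) extends Expr { def children: Seq[Expr] = args }

      // Largest row ordinal referenced anywhere in the expression tree (-1 if none).
      def maxOrdinal(e: Expr): Int = e match {
        case BoundRef(o) => o
        case _ => (-1 +: e.children.map(maxOrdinal)).max
      }

      // Same rule the patch adds to resolve(): the highest ordinal the deserializer
      // reads must be the last field of the input schema, otherwise fail fast.
      def check(deserializer: Expr, schemaWidth: Int): Unit = {
        val m = maxOrdinal(deserializer)
        if (m >= 0 && m != schemaWidth - 1) {
          throw new IllegalStateException(
            s"Tuple${m + 1} deserializer does not line up with a $schemaWidth-field schema")
        }
      }

      def main(args: Array[String]): Unit = {
        check(MakeTuple(BoundRef(0), BoundRef(1)), schemaWidth = 2)     // fine
        try check(MakeTuple(BoundRef(0), BoundRef(1)), schemaWidth = 3) // does not line up
        catch { case e: IllegalStateException => println(e.getMessage) }
      }
    }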
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 10ce10aaf6..58f6a7ec8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -104,14 +104,14 @@ object ExtractValue {
 case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
   extends UnaryExpression {
 
-  private lazy val field = child.dataType.asInstanceOf[StructType](ordinal)
+  private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType]
 
-  override def dataType: DataType = field.dataType
-  override def nullable: Boolean = child.nullable || field.nullable
-  override def toString: String = s"$child.${name.getOrElse(field.name)}"
+  override def dataType: DataType = childSchema(ordinal).dataType
+  override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable
+  override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}"
 
   protected override def nullSafeEval(input: Any): Any =
-    input.asInstanceOf[InternalRow].get(ordinal, field.dataType)
+    input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     nullSafeCodeGen(ctx, ev, eval => {
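The GetStructField change above is behaviour-preserving: each accessor still reads the field at `ordinal`, but the whole child schema is now visible as `private[sql] childSchema`, which is what lets the encoder resolution code compare the highest ordinal it extracts against the struct's width. A small illustration using only the public StructType API; the schema and values below are made up for the example.

    import org.apache.spark.sql.types._

    object ChildSchemaDemo {
      def main(args: Array[String]): Unit = {
        val childSchema = new StructType().add("a", StringType).add("b", LongType)

        // What GetStructField still does per ordinal: look up one field.
        val field = childSchema(1)
        println(s"${field.name}: ${field.dataType}")       // b: LongType

        // What exposing the whole schema additionally enables: width checks, e.g. a
        // deserializer reading ordinal 2 cannot line up with this two-field struct.
        val maxOrdinalRead = 2
        println(maxOrdinalRead == childSchema.length - 1)  // false -> new error path
      }
    }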
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 0289988342..815a03f7c1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -64,22 +64,21 @@ class EncoderResolutionSuite extends PlanTest {
     val innerCls = classOf[StringLongClass]
     val cls = classOf[ComplexClass]
 
-    val structType = new StructType().add("a", IntegerType).add("b", LongType)
-    val attrs = Seq('a.int, 'b.struct(structType))
+    val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
     val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
     val expected: Expression = NewInstance(
       cls,
       Seq(
         'a.int.cast(LongType),
         If(
-          'b.struct(structType).isNull,
+          'b.struct('a.int, 'b.long).isNull,
           Literal.create(null, ObjectType(innerCls)),
           NewInstance(
             innerCls,
             Seq(
               toExternalString(
-                GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)),
-              GetStructField('b.struct(structType), 1, Some("b"))),
+                GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
+              GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))),
             false,
             ObjectType(innerCls))
         )),
@@ -94,8 +93,7 @@ class EncoderResolutionSuite extends PlanTest {
       ExpressionEncoder[Long])
     val cls = classOf[StringLongClass]
 
-    val structType = new StructType().add("a", StringType).add("b", ByteType)
-    val attrs = Seq('a.struct(structType), 'b.int)
+    val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
     val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
     val expected: Expression = NewInstance(
       classOf[Tuple2[_, _]],
@@ -103,8 +101,8 @@ class EncoderResolutionSuite extends PlanTest {
         NewInstance(
           cls,
           Seq(
-            toExternalString(GetStructField('a.struct(structType), 0, Some("a"))),
-            GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)),
+            toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
+            GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)),
           false,
           ObjectType(cls)),
         'b.int.cast(LongType)),
@@ -113,6 +111,50 @@ class EncoderResolutionSuite extends PlanTest {
     compareExpressions(fromRowExpr, expected)
   }
 
+  test("the real number of fields doesn't match encoder schema: tuple encoder") {
+    val encoder = ExpressionEncoder[(String, Long)]
+
+    {
+      val attrs = Seq('a.string, 'b.long, 'c.int)
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+        "Try to map struct<a:string,b:bigint,c:int> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string,b:bigint,c:int>\n" +
+          " - Target schema: struct<_1:string,_2:bigint>")
+    }
+
+    {
+      val attrs = Seq('a.string)
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+        "Try to map struct<a:string> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string>\n" +
+          " - Target schema: struct<_1:string,_2:bigint>")
+    }
+  }
+
+  test("the real number of fields doesn't match encoder schema: nested tuple encoder") {
+    val encoder = ExpressionEncoder[(String, (Long, String))]
+
+    {
+      val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int))
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+        "Try to map struct<x:bigint,y:string,z:int> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" +
+          " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+    }
+
+    {
+      val attrs = Seq('a.string, 'b.struct('x.long))
+      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+        "Try to map struct<x:bigint> to Tuple2, " +
+          "but failed as the number of fields does not line up.\n" +
+          " - Input schema: struct<a:string,b:struct<x:bigint>>\n" +
+          " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+    }
+  }
+
   private def toExternalString(e: Expression): Expression = {
     Invoke(e, "toString", ObjectType(classOf[String]), Nil)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 62fd47234b..9f1b19253e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -165,7 +165,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
"b", create_row(Map("a" -> "b")))
checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)),
"b", create_row(Seq("a", "b")))
- checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")),
+ checkEvaluation(quickResolve('c.struct('a.int).at(0).getField("a")),
1, create_row(create_row(1)))
}