diff options
author | Wenchen Fan <wenchen@databricks.com> | 2015-12-16 13:18:56 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-12-16 13:20:12 -0800 |
commit | a783a8ed49814a09fde653433a3d6de398ddf888 (patch) | |
tree | 9d6a7be7840d682edfefd16a450ce4128c882a4c | |
parent | 1a8b2a17db7ab7a213d553079b83274aeebba86f (diff) | |
download | spark-a783a8ed49814a09fde653433a3d6de398ddf888.tar.gz spark-a783a8ed49814a09fde653433a3d6de398ddf888.tar.bz2 spark-a783a8ed49814a09fde653433a3d6de398ddf888.zip |
[SPARK-12320][SQL] throw exception if the number of fields does not line up for Tuple encoder
Author: Wenchen Fan <wenchen@databricks.com>
Closes #10293 from cloud-fan/err-msg.
5 files changed, 93 insertions, 18 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index e50971173c..8102c93c6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -227,9 +227,10 @@ package object dsl { AttributeReference(s, mapType, nullable = true)() /** Creates a new AttributeReference of type struct */ - def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) def struct(structType: StructType): AttributeReference = AttributeReference(s, structType, nullable = true)() + def struct(attrs: AttributeReference*): AttributeReference = + struct(StructType.fromAttributes(attrs)) } implicit class DslAttribute(a: AttributeReference) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 363178b0e2..7a4401cf58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -244,9 +244,41 @@ case class ExpressionEncoder[T]( def resolve( schema: Seq[Attribute], outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(schema) + def fail(st: StructType, maxOrdinal: Int): Unit = { + throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" + + " - Target schema: " + this.schema.simpleString) + } + + var maxOrdinal = -1 + fromRowExpression.foreach { + case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal + case _ => + } + if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) { + fail(StructType.fromAttributes(schema), maxOrdinal) + } + val unbound = fromRowExpression transform { - case b: BoundReference => positionToAttribute(b.ordinal) + case b: BoundReference => schema(b.ordinal) + } + + val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] + unbound.foreach { + case g: GetStructField => + val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1) + if (maxOrdinal < g.ordinal) { + exprToMaxOrdinal.update(g.child, g.ordinal) + } + case _ => + } + exprToMaxOrdinal.foreach { + case (expr, maxOrdinal) => + val schema = expr.dataType.asInstanceOf[StructType] + if (maxOrdinal != schema.length - 1) { + fail(schema, maxOrdinal) + } } val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 10ce10aaf6..58f6a7ec8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -104,14 +104,14 @@ object ExtractValue { case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression { - private lazy val field = child.dataType.asInstanceOf[StructType](ordinal) + private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] - override def dataType: DataType = field.dataType - override def nullable: Boolean = child.nullable || field.nullable - override def toString: String = s"$child.${name.getOrElse(field.name)}" + override def dataType: DataType = childSchema(ordinal).dataType + override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable + override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}" protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[InternalRow].get(ordinal, field.dataType) + input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 0289988342..815a03f7c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -64,22 +64,21 @@ class EncoderResolutionSuite extends PlanTest { val innerCls = classOf[StringLongClass] val cls = classOf[ComplexClass] - val structType = new StructType().add("a", IntegerType).add("b", LongType) - val attrs = Seq('a.int, 'b.struct(structType)) + val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( cls, Seq( 'a.int.cast(LongType), If( - 'b.struct(structType).isNull, + 'b.struct('a.int, 'b.long).isNull, Literal.create(null, ObjectType(innerCls)), NewInstance( innerCls, Seq( toExternalString( - GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)), - GetStructField('b.struct(structType), 1, Some("b"))), + GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)), + GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))), false, ObjectType(innerCls)) )), @@ -94,8 +93,7 @@ class EncoderResolutionSuite extends PlanTest { ExpressionEncoder[Long]) val cls = classOf[StringLongClass] - val structType = new StructType().add("a", StringType).add("b", ByteType) - val attrs = Seq('a.struct(structType), 'b.int) + val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( classOf[Tuple2[_, _]], @@ -103,8 +101,8 @@ class EncoderResolutionSuite extends PlanTest { NewInstance( cls, Seq( - toExternalString(GetStructField('a.struct(structType), 0, Some("a"))), - GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)), + toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))), + GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)), false, ObjectType(cls)), 'b.int.cast(LongType)), @@ -113,6 +111,50 @@ class EncoderResolutionSuite extends PlanTest { compareExpressions(fromRowExpr, expected) } + test("the real number of fields doesn't match encoder schema: tuple encoder") { + val encoder = ExpressionEncoder[(String, Long)] + + { + val attrs = Seq('a.string, 'b.long, 'c.int) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct<a:string,b:bigint,c:int> to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct<a:string,b:bigint,c:int>\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + + { + val attrs = Seq('a.string) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct<a:string> to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct<a:string>\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + } + + test("the real number of fields doesn't match encoder schema: nested tuple encoder") { + val encoder = ExpressionEncoder[(String, (Long, String))] + + { + val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct<x:bigint,y:string,z:int> to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + + { + val attrs = Seq('a.string, 'b.struct('x.long)) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct<x:bigint> to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct<a:string,b:struct<x:bigint>>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + } + private def toExternalString(e: Expression): Expression = { Invoke(e, "toString", ObjectType(classOf[String]), Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 62fd47234b..9f1b19253e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -165,7 +165,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { "b", create_row(Map("a" -> "b"))) checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), "b", create_row(Seq("a", "b"))) - checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")), + checkEvaluation(quickResolve('c.struct('a.int).at(0).getField("a")), 1, create_row(create_row(1))) } |