diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-06-30 07:58:49 -0700 |
---|---|---|
committer | Davies Liu <davies@databricks.com> | 2015-06-30 07:58:49 -0700 |
commit | 08fab4843845136358f3a7251e8d90135126b419 (patch) | |
tree | 59c5e50bb7f22038de5cc4c64879d6d61ff4e2df /sql | |
parent | 2ed0c0ac4686ea779f98713978e37b97094edc1c (diff) | |
download | spark-08fab4843845136358f3a7251e8d90135126b419.tar.gz spark-08fab4843845136358f3a7251e8d90135126b419.tar.bz2 spark-08fab4843845136358f3a7251e8d90135126b419.zip |
[SPARK-8590] [SQL] add code gen for ExtractValue
TODO: use array instead of Seq as internal representation for `ArrayType`
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #6982 from cloud-fan/extract-value and squashes the following commits:
e203bc1 [Wenchen Fan] address comments
4da0f0b [Wenchen Fan] some clean up
f679969 [Wenchen Fan] fix bug
e64f942 [Wenchen Fan] remove generic
e3f8427 [Wenchen Fan] fix style and address comments
fc694e8 [Wenchen Fan] add code gen for extract value
Diffstat (limited to 'sql')
11 files changed, 199 insertions, 101 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 5db2fcfcb2..dc0b4ac5cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -47,7 +47,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) s""" boolean ${ev.isNull} = i.isNullAt($ordinal); ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)}); + ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)}); """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e5dc7b9b5c..aed48921bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -179,9 +179,10 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe + /** - * Short hand for generating binary evaluation code, which depends on two sub-evaluations of - * the same type. If either of the sub-expressions is null, the result of this computation + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * * @param f accepts two variable names and returns Java code to compute the output. @@ -190,15 +191,23 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express ctx: CodeGenContext, ev: GeneratedExpressionCode, f: (String, String) => String): String = { - // TODO: Right now some timestamp tests fail if we enforce this... - if (left.dataType != right.dataType) { - // log.warn(s"${left.dataType} != ${right.dataType}") - } + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val resultCode = f(eval1.primitive, eval2.primitive) - + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; @@ -206,7 +215,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express if (!${ev.isNull}) { ${eval2.code} if (!${eval2.isNull}) { - ${ev.primitive} = $resultCode; + $resultCode } else { ${ev.isNull} = true; } @@ -245,13 +254,26 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio ctx: CodeGenContext, ev: GeneratedExpressionCode, f: String => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval) => { + s"$result = ${f(eval)};" + }) + } + + /** + * Called by unary expressions to generate a code block that returns null if its parent returns + * null, and if not not null, use `f` to generate the expression. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { val eval = child.gen(ctx) - // reuse the previous isNull - ev.isNull = eval.isNull + val resultCode = f(ev.primitive, eval.primitive) eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = ${f(eval.primitive)}; + $resultCode } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 4d7c95ffd1..3020e7fc96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -21,6 +21,7 @@ import scala.collection.Map import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ object ExtractValue { @@ -38,7 +39,7 @@ object ExtractValue { def apply( child: Expression, extraction: Expression, - resolver: Resolver): ExtractValue = { + resolver: Resolver): Expression = { (child.dataType, extraction) match { case (StructType(fields), NonNullLiteral(v, StringType)) => @@ -73,7 +74,7 @@ object ExtractValue { def unapply(g: ExtractValue): Option[(Expression, Expression)] = { g match { case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal)) - case _ => Some((g.child, null)) + case s: ExtractValueWithStruct => Some((s.child, null)) } } @@ -101,11 +102,11 @@ object ExtractValue { * Note: concrete extract value expressions are created only by `ExtractValue.apply`, * we don't need to do type check for them. */ -trait ExtractValue extends UnaryExpression { - self: Product => +trait ExtractValue { + self: Expression => } -abstract class ExtractValueWithStruct extends ExtractValue { +abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue { self: Product => def field: StructField @@ -125,6 +126,18 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) val baseValue = child.eval(input).asInstanceOf[InternalRow] if (baseValue == null) null else baseValue(ordinal) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + $result = ${ctx.getColumn(eval, dataType, ordinal)}; + } + """ + }) + } } /** @@ -137,6 +150,7 @@ case class GetArrayStructFields( containsNull: Boolean) extends ExtractValueWithStruct { override def dataType: DataType = ArrayType(field.dataType, containsNull) + override def nullable: Boolean = child.nullable || containsNull || field.nullable override def eval(input: InternalRow): Any = { val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]] @@ -146,18 +160,39 @@ case class GetArrayStructFields( } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arraySeqClass = "scala.collection.mutable.ArraySeq" + // TODO: consider using Array[_] for ArrayType child to avoid + // boxing of primitives + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + final int n = $eval.size(); + final $arraySeqClass<Object> values = new $arraySeqClass<Object>(n); + for (int j = 0; j < n; j++) { + InternalRow row = (InternalRow) $eval.apply(j); + if (row != null && !row.isNullAt($ordinal)) { + values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)}); + } + } + $result = (${ctx.javaType(dataType)}) values; + """ + }) + } } -abstract class ExtractValueWithOrdinal extends ExtractValue { +abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValue { self: Product => def ordinal: Expression + def child: Expression + + override def left: Expression = child + override def right: Expression = ordinal /** `Null` is returned for invalid ordinals. */ override def nullable: Boolean = true - override def foldable: Boolean = child.foldable && ordinal.foldable override def toString: String = s"$child[$ordinal]" - override def children: Seq[Expression] = child :: ordinal :: Nil override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -195,6 +230,19 @@ case class GetArrayItem(child: Expression, ordinal: Expression) baseValue(index) } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s""" + final int index = (int)$eval2; + if (index >= $eval1.size() || index < 0) { + ${ev.isNull} = true; + } else { + $result = (${ctx.boxedType(dataType)})$eval1.apply(index); + } + """ + }) + } } /** @@ -209,4 +257,16 @@ case class GetMapValue(child: Expression, ordinal: Expression) val baseValue = value.asInstanceOf[Map[Any, _]] baseValue.get(ordinal).orNull } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s""" + if ($eval1.contains($eval2)) { + $result = (${ctx.boxedType(dataType)})$eval1.apply($eval2); + } else { + ${ev.isNull} = true; + } + """ + }) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 3d4d9e2d79..ae765c1653 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -82,8 +82,6 @@ case class Abs(child: Expression) extends UnaryArithmetic { abstract class BinaryArithmetic extends BinaryExpression { self: Product => - /** Name of the function for this expression on a [[Decimal]] type. */ - def decimalMethod: String = "" override def dataType: DataType = left.dataType @@ -113,6 +111,10 @@ abstract class BinaryArithmetic extends BinaryExpression { } } + /** Name of the function for this expression on a [[Decimal]] type. */ + def decimalMethod: String = + sys.error("BinaryArithmetics must override either decimalMethod or genCode") + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 57e0bede5d..bf6a6a1240 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -82,24 +82,24 @@ class CodeGenContext { /** * Returns the code to access a column in Row for a given DataType. */ - def getColumn(dataType: DataType, ordinal: Int): String = { + def getColumn(row: String, dataType: DataType, ordinal: Int): String = { val jt = javaType(dataType) if (isPrimitiveType(jt)) { - s"i.get${primitiveTypeName(jt)}($ordinal)" + s"$row.get${primitiveTypeName(jt)}($ordinal)" } else { - s"($jt)i.apply($ordinal)" + s"($jt)$row.apply($ordinal)" } } /** * Returns the code to update a column in Row for a given DataType. */ - def setColumn(dataType: DataType, ordinal: Int, value: String): String = { + def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { val jt = javaType(dataType) if (isPrimitiveType(jt)) { - s"set${primitiveTypeName(jt)}($ordinal, $value)" + s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" } else { - s"update($ordinal, $value)" + s"$row.update($ordinal, $value)" } } @@ -127,6 +127,9 @@ class CodeGenContext { case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType + case _: StructType => "InternalRow" + case _: ArrayType => s"scala.collection.Seq" + case _: MapType => s"scala.collection.Map" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 64ef357a4f..addb8023d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -43,7 +43,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu if(${evaluationCode.isNull}) mutableRow.setNullAt($i); else - mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)}; + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index a022f3727b..da63f2fa97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -78,17 +78,14 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) def funcName: String = name.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + ${ev.primitive} = java.lang.Math.${funcName}($eval); if (Double.valueOf(${ev.primitive}).isNaN()) { ${ev.isNull} = true; } - } - """ + """ + }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 386cf6a8df..98cd5aa814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -69,10 +69,7 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } - case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index efc6f50b78..daa9f4403f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -135,8 +135,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { - override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = left.dataType override def symbol: String = "++=" @@ -185,8 +183,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres */ case class CountSet(child: Expression) extends UnaryExpression { - override def nullable: Boolean = child.nullable - override def dataType: DataType = LongType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 8656cc334d..3148309a21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types._ /** - * Helper function to check for valid data types + * Helper functions to check for valid data types. */ object TypeUtils { def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b80911e725..3515d044b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -40,51 +40,42 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } test("GetArrayItem") { + val typeA = ArrayType(StringType) + val array = Literal.create(Seq("a", "b"), typeA) testIntegralDataTypes { convert => - val array = Literal.create(Seq("a", "b"), ArrayType(StringType)) checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b") } + val nullArray = Literal.create(null, typeA) + val nullInt = Literal.create(null, IntegerType) + checkEvaluation(GetArrayItem(nullArray, Literal(1)), null) + checkEvaluation(GetArrayItem(array, nullInt), null) + checkEvaluation(GetArrayItem(nullArray, nullInt), null) + + val nestedArray = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType))) + checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1)) } - test("CreateStruct") { - val row = InternalRow(1, 2, 3) - val c1 = 'a.int.at(0).as("a") - val c3 = 'c.int.at(2).as("c") - checkEvaluation(CreateStruct(Seq(c1, c3)), InternalRow(1, 3), row) + test("GetMapValue") { + val typeM = MapType(StringType, StringType) + val map = Literal.create(Map("a" -> "b"), typeM) + val nullMap = Literal.create(null, typeM) + val nullString = Literal.create(null, StringType) + + checkEvaluation(GetMapValue(map, Literal("a")), "b") + checkEvaluation(GetMapValue(map, nullString), null) + checkEvaluation(GetMapValue(nullMap, nullString), null) + checkEvaluation(GetMapValue(map, nullString), null) + + val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")), MapType(StringType, typeM)) + checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c")) } - test("complex type") { - val row = create_row( - "^Ba*n", // 0 - null.asInstanceOf[UTF8String], // 1 - create_row("aa", "bb"), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 - ) - - val typeS = StructType( - StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil - ) - val typeMap = MapType(StringType, StringType) - val typeArray = ArrayType(StringType) - - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal("aa")), "bb", row) - checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row) - checkEvaluation( - GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal.create(null, StringType)), null, row) - - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal(1)), "bb", row) - checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row) - checkEvaluation( - GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal.create(null, IntegerType)), null, row) - - def getStructField(expr: Expression, fieldName: String): ExtractValue = { + test("GetStructField") { + val typeS = StructType(StructField("a", IntegerType) :: Nil) + val struct = Literal.create(create_row(1), typeS) + val nullStruct = Literal.create(null, typeS) + + def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { case StructType(fields) => val field = fields.find(_.name == fieldName).get @@ -92,28 +83,58 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } } - def quickResolve(u: UnresolvedExtractValue): ExtractValue = { - ExtractValue(u.child, u.extraction, _ == _) - } + checkEvaluation(getStructField(struct, "a"), 1) + checkEvaluation(getStructField(nullStruct, "a"), null) + + val nestedStruct = Literal.create(create_row(create_row(1)), + StructType(StructField("a", typeS) :: Nil)) + checkEvaluation(getStructField(nestedStruct, "a"), create_row(1)) + + val typeS_fieldNotNullable = StructType(StructField("a", IntegerType, false) :: Nil) + val struct_fieldNotNullable = Literal.create(create_row(1), typeS_fieldNotNullable) + val nullStruct_fieldNotNullable = Literal.create(null, typeS_fieldNotNullable) + + assert(getStructField(struct_fieldNotNullable, "a").nullable === false) + assert(getStructField(struct, "a").nullable === true) + assert(getStructField(nullStruct_fieldNotNullable, "a").nullable === true) + assert(getStructField(nullStruct, "a").nullable === true) + } - checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row) + test("GetArrayStructFields") { + val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val arrayStruct = Literal.create(Seq(create_row(1)), typeAS) + val nullArrayStruct = Literal.create(null, typeAS) - val typeS_notNullable = StructType( - StructField("a", StringType, nullable = false) - :: StructField("b", StringType, nullable = false) :: Nil - ) + def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = { + expr.dataType match { + case ArrayType(StructType(fields), containsNull) => + val field = fields.find(_.name == fieldName).get + GetArrayStructFields(expr, field, fields.indexOf(field), containsNull) + } + } + + checkEvaluation(getArrayStructFields(arrayStruct, "a"), Seq(1)) + checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) + } - assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true) - assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable - === false) + test("CreateStruct") { + val row = create_row(1, 2, 3) + val c1 = 'a.int.at(0).as("a") + val c3 = 'c.int.at(2).as("c") + checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) + } - assert(getStructField(Literal.create(null, typeS), "a").nullable === true) - assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true) + test("test dsl for complex type") { + def quickResolve(u: UnresolvedExtractValue): Expression = { + ExtractValue(u.child, u.extraction, _ == _) + } - checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row) - checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row) - checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row) + checkEvaluation(quickResolve('c.map(MapType(StringType, StringType)).at(0).getItem("a")), + "b", create_row(Map("a" -> "b"))) + checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), + "b", create_row(Seq("a", "b"))) + checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")), + 1, create_row(create_row(1))) } test("error message of ExtractValue") { |