aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-06-30 07:58:49 -0700
committerDavies Liu <davies@databricks.com>2015-06-30 07:58:49 -0700
commit08fab4843845136358f3a7251e8d90135126b419 (patch)
tree59c5e50bb7f22038de5cc4c64879d6d61ff4e2df /sql
parent2ed0c0ac4686ea779f98713978e37b97094edc1c (diff)
downloadspark-08fab4843845136358f3a7251e8d90135126b419.tar.gz
spark-08fab4843845136358f3a7251e8d90135126b419.tar.bz2
spark-08fab4843845136358f3a7251e8d90135126b419.zip
[SPARK-8590] [SQL] add code gen for ExtractValue
TODO: use array instead of Seq as internal representation for `ArrayType` Author: Wenchen Fan <cloud0fan@outlook.com> Closes #6982 from cloud-fan/extract-value and squashes the following commits: e203bc1 [Wenchen Fan] address comments 4da0f0b [Wenchen Fan] some clean up f679969 [Wenchen Fan] fix bug e64f942 [Wenchen Fan] remove generic e3f8427 [Wenchen Fan] fix style and address comments fc694e8 [Wenchen Fan] add code gen for extract value
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala46
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala76
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala15
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala131
11 files changed, 199 insertions, 101 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 5db2fcfcb2..dc0b4ac5cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -47,7 +47,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
s"""
boolean ${ev.isNull} = i.isNullAt($ordinal);
${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
- ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
+ ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)});
"""
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index e5dc7b9b5c..aed48921bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -179,9 +179,10 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def toString: String = s"($left $symbol $right)"
override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe
+
/**
- * Short hand for generating binary evaluation code, which depends on two sub-evaluations of
- * the same type. If either of the sub-expressions is null, the result of this computation
+ * Short hand for generating binary evaluation code.
+ * If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f accepts two variable names and returns Java code to compute the output.
@@ -190,15 +191,23 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
- // TODO: Right now some timestamp tests fail if we enforce this...
- if (left.dataType != right.dataType) {
- // log.warn(s"${left.dataType} != ${right.dataType}")
- }
+ nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
+ s"$result = ${f(eval1, eval2)};"
+ })
+ }
+ /**
+ * Short hand for generating binary evaluation code.
+ * If either of the sub-expressions is null, the result of this computation
+ * is assumed to be null.
+ */
+ protected def nullSafeCodeGen(
+ ctx: CodeGenContext,
+ ev: GeneratedExpressionCode,
+ f: (String, String, String) => String): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
- val resultCode = f(eval1.primitive, eval2.primitive)
-
+ val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive)
s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
@@ -206,7 +215,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
if (!${ev.isNull}) {
${eval2.code}
if (!${eval2.isNull}) {
- ${ev.primitive} = $resultCode;
+ $resultCode
} else {
${ev.isNull} = true;
}
@@ -245,13 +254,26 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: String => String): String = {
+ nullSafeCodeGen(ctx, ev, (result, eval) => {
+ s"$result = ${f(eval)};"
+ })
+ }
+
+ /**
+ * Called by unary expressions to generate a code block that returns null if its parent returns
+ * null, and if not null, use `f` to generate the expression.
+ */
+ protected def nullSafeCodeGen(
+ ctx: CodeGenContext,
+ ev: GeneratedExpressionCode,
+ f: (String, String) => String): String = {
val eval = child.gen(ctx)
- // reuse the previous isNull
- ev.isNull = eval.isNull
+ val resultCode = f(ev.primitive, eval.primitive)
eval.code + s"""
+ boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
- ${ev.primitive} = ${f(eval.primitive)};
+ $resultCode
}
"""
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 4d7c95ffd1..3020e7fc96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -21,6 +21,7 @@ import scala.collection.Map
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types._
object ExtractValue {
@@ -38,7 +39,7 @@ object ExtractValue {
def apply(
child: Expression,
extraction: Expression,
- resolver: Resolver): ExtractValue = {
+ resolver: Resolver): Expression = {
(child.dataType, extraction) match {
case (StructType(fields), NonNullLiteral(v, StringType)) =>
@@ -73,7 +74,7 @@ object ExtractValue {
def unapply(g: ExtractValue): Option[(Expression, Expression)] = {
g match {
case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal))
- case _ => Some((g.child, null))
+ case s: ExtractValueWithStruct => Some((s.child, null))
}
}
@@ -101,11 +102,11 @@ object ExtractValue {
* Note: concrete extract value expressions are created only by `ExtractValue.apply`,
* we don't need to do type check for them.
*/
-trait ExtractValue extends UnaryExpression {
- self: Product =>
+trait ExtractValue {
+ self: Expression =>
}
-abstract class ExtractValueWithStruct extends ExtractValue {
+abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue {
self: Product =>
def field: StructField
@@ -125,6 +126,18 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
val baseValue = child.eval(input).asInstanceOf[InternalRow]
if (baseValue == null) null else baseValue(ordinal)
}
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (result, eval) => {
+ s"""
+ if ($eval.isNullAt($ordinal)) {
+ ${ev.isNull} = true;
+ } else {
+ $result = ${ctx.getColumn(eval, dataType, ordinal)};
+ }
+ """
+ })
+ }
}
/**
@@ -137,6 +150,7 @@ case class GetArrayStructFields(
containsNull: Boolean) extends ExtractValueWithStruct {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
+ override def nullable: Boolean = child.nullable || containsNull || field.nullable
override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
@@ -146,18 +160,39 @@ case class GetArrayStructFields(
}
}
}
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val arraySeqClass = "scala.collection.mutable.ArraySeq"
+ // TODO: consider using Array[_] for ArrayType child to avoid
+ // boxing of primitives
+ nullSafeCodeGen(ctx, ev, (result, eval) => {
+ s"""
+ final int n = $eval.size();
+ final $arraySeqClass<Object> values = new $arraySeqClass<Object>(n);
+ for (int j = 0; j < n; j++) {
+ InternalRow row = (InternalRow) $eval.apply(j);
+ if (row != null && !row.isNullAt($ordinal)) {
+ values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)});
+ }
+ }
+ $result = (${ctx.javaType(dataType)}) values;
+ """
+ })
+ }
}
-abstract class ExtractValueWithOrdinal extends ExtractValue {
+abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValue {
self: Product =>
def ordinal: Expression
+ def child: Expression
+
+ override def left: Expression = child
+ override def right: Expression = ordinal
/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = true
- override def foldable: Boolean = child.foldable && ordinal.foldable
override def toString: String = s"$child[$ordinal]"
- override def children: Seq[Expression] = child :: ordinal :: Nil
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
@@ -195,6 +230,19 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
baseValue(index)
}
}
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
+ s"""
+ final int index = (int)$eval2;
+ if (index >= $eval1.size() || index < 0) {
+ ${ev.isNull} = true;
+ } else {
+ $result = (${ctx.boxedType(dataType)})$eval1.apply(index);
+ }
+ """
+ })
+ }
}
/**
@@ -209,4 +257,16 @@ case class GetMapValue(child: Expression, ordinal: Expression)
val baseValue = value.asInstanceOf[Map[Any, _]]
baseValue.get(ordinal).orNull
}
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
+ s"""
+ if ($eval1.contains($eval2)) {
+ $result = (${ctx.boxedType(dataType)})$eval1.apply($eval2);
+ } else {
+ ${ev.isNull} = true;
+ }
+ """
+ })
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 3d4d9e2d79..ae765c1653 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -82,8 +82,6 @@ case class Abs(child: Expression) extends UnaryArithmetic {
abstract class BinaryArithmetic extends BinaryExpression {
self: Product =>
- /** Name of the function for this expression on a [[Decimal]] type. */
- def decimalMethod: String = ""
override def dataType: DataType = left.dataType
@@ -113,6 +111,10 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
}
+ /** Name of the function for this expression on a [[Decimal]] type. */
+ def decimalMethod: String =
+ sys.error("BinaryArithmetics must override either decimalMethod or genCode")
+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 57e0bede5d..bf6a6a1240 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -82,24 +82,24 @@ class CodeGenContext {
/**
* Returns the code to access a column in Row for a given DataType.
*/
- def getColumn(dataType: DataType, ordinal: Int): String = {
+ def getColumn(row: String, dataType: DataType, ordinal: Int): String = {
val jt = javaType(dataType)
if (isPrimitiveType(jt)) {
- s"i.get${primitiveTypeName(jt)}($ordinal)"
+ s"$row.get${primitiveTypeName(jt)}($ordinal)"
} else {
- s"($jt)i.apply($ordinal)"
+ s"($jt)$row.apply($ordinal)"
}
}
/**
* Returns the code to update a column in Row for a given DataType.
*/
- def setColumn(dataType: DataType, ordinal: Int, value: String): String = {
+ def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
val jt = javaType(dataType)
if (isPrimitiveType(jt)) {
- s"set${primitiveTypeName(jt)}($ordinal, $value)"
+ s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
} else {
- s"update($ordinal, $value)"
+ s"$row.update($ordinal, $value)"
}
}
@@ -127,6 +127,9 @@ class CodeGenContext {
case dt: DecimalType => decimalType
case BinaryType => "byte[]"
case StringType => stringType
+ case _: StructType => "InternalRow"
+ case _: ArrayType => s"scala.collection.Seq"
+ case _: MapType => s"scala.collection.Map"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case _ => "Object"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 64ef357a4f..addb8023d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -43,7 +43,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
if(${evaluationCode.isNull})
mutableRow.setNullAt($i);
else
- mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)};
+ ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
"""
}.mkString("\n")
val code = s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index a022f3727b..da63f2fa97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -78,17 +78,14 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
def funcName: String = name.toLowerCase
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val eval = child.gen(ctx)
- eval.code + s"""
- boolean ${ev.isNull} = ${eval.isNull};
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive});
+ nullSafeCodeGen(ctx, ev, (result, eval) => {
+ s"""
+ ${ev.primitive} = java.lang.Math.${funcName}($eval);
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
- }
- """
+ """
+ })
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 386cf6a8df..98cd5aa814 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -69,10 +69,7 @@ trait PredicateHelper {
expr.references.subsetOf(plan.outputSet)
}
-
case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes {
- override def foldable: Boolean = child.foldable
- override def nullable: Boolean = child.nullable
override def toString: String = s"NOT $child"
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index efc6f50b78..daa9f4403f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -135,8 +135,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
*/
case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
- override def nullable: Boolean = left.nullable || right.nullable
-
override def dataType: DataType = left.dataType
override def symbol: String = "++="
@@ -185,8 +183,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
*/
case class CountSet(child: Expression) extends UnaryExpression {
- override def nullable: Boolean = child.nullable
-
override def dataType: DataType = LongType
override def eval(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 8656cc334d..3148309a21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.types._
/**
- * Helper function to check for valid data types
+ * Helper functions to check for valid data types.
*/
object TypeUtils {
def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index b80911e725..3515d044b2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -40,51 +40,42 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("GetArrayItem") {
+ val typeA = ArrayType(StringType)
+ val array = Literal.create(Seq("a", "b"), typeA)
testIntegralDataTypes { convert =>
- val array = Literal.create(Seq("a", "b"), ArrayType(StringType))
checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b")
}
+ val nullArray = Literal.create(null, typeA)
+ val nullInt = Literal.create(null, IntegerType)
+ checkEvaluation(GetArrayItem(nullArray, Literal(1)), null)
+ checkEvaluation(GetArrayItem(array, nullInt), null)
+ checkEvaluation(GetArrayItem(nullArray, nullInt), null)
+
+ val nestedArray = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType)))
+ checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
}
- test("CreateStruct") {
- val row = InternalRow(1, 2, 3)
- val c1 = 'a.int.at(0).as("a")
- val c3 = 'c.int.at(2).as("c")
- checkEvaluation(CreateStruct(Seq(c1, c3)), InternalRow(1, 3), row)
+ test("GetMapValue") {
+ val typeM = MapType(StringType, StringType)
+ val map = Literal.create(Map("a" -> "b"), typeM)
+ val nullMap = Literal.create(null, typeM)
+ val nullString = Literal.create(null, StringType)
+
+ checkEvaluation(GetMapValue(map, Literal("a")), "b")
+ checkEvaluation(GetMapValue(map, nullString), null)
+ checkEvaluation(GetMapValue(nullMap, nullString), null)
+ checkEvaluation(GetMapValue(map, nullString), null)
+
+ val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")), MapType(StringType, typeM))
+ checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c"))
}
- test("complex type") {
- val row = create_row(
- "^Ba*n", // 0
- null.asInstanceOf[UTF8String], // 1
- create_row("aa", "bb"), // 2
- Map("aa"->"bb"), // 3
- Seq("aa", "bb") // 4
- )
-
- val typeS = StructType(
- StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
- )
- val typeMap = MapType(StringType, StringType)
- val typeArray = ArrayType(StringType)
-
- checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
- Literal("aa")), "bb", row)
- checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row)
- checkEvaluation(
- GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row)
- checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
- Literal.create(null, StringType)), null, row)
-
- checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
- Literal(1)), "bb", row)
- checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row)
- checkEvaluation(
- GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row)
- checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
- Literal.create(null, IntegerType)), null, row)
-
- def getStructField(expr: Expression, fieldName: String): ExtractValue = {
+ test("GetStructField") {
+ val typeS = StructType(StructField("a", IntegerType) :: Nil)
+ val struct = Literal.create(create_row(1), typeS)
+ val nullStruct = Literal.create(null, typeS)
+
+ def getStructField(expr: Expression, fieldName: String): GetStructField = {
expr.dataType match {
case StructType(fields) =>
val field = fields.find(_.name == fieldName).get
@@ -92,28 +83,58 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
- def quickResolve(u: UnresolvedExtractValue): ExtractValue = {
- ExtractValue(u.child, u.extraction, _ == _)
- }
+ checkEvaluation(getStructField(struct, "a"), 1)
+ checkEvaluation(getStructField(nullStruct, "a"), null)
+
+ val nestedStruct = Literal.create(create_row(create_row(1)),
+ StructType(StructField("a", typeS) :: Nil))
+ checkEvaluation(getStructField(nestedStruct, "a"), create_row(1))
+
+ val typeS_fieldNotNullable = StructType(StructField("a", IntegerType, false) :: Nil)
+ val struct_fieldNotNullable = Literal.create(create_row(1), typeS_fieldNotNullable)
+ val nullStruct_fieldNotNullable = Literal.create(null, typeS_fieldNotNullable)
+
+ assert(getStructField(struct_fieldNotNullable, "a").nullable === false)
+ assert(getStructField(struct, "a").nullable === true)
+ assert(getStructField(nullStruct_fieldNotNullable, "a").nullable === true)
+ assert(getStructField(nullStruct, "a").nullable === true)
+ }
- checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
- checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row)
+ test("GetArrayStructFields") {
+ val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+ val arrayStruct = Literal.create(Seq(create_row(1)), typeAS)
+ val nullArrayStruct = Literal.create(null, typeAS)
- val typeS_notNullable = StructType(
- StructField("a", StringType, nullable = false)
- :: StructField("b", StringType, nullable = false) :: Nil
- )
+ def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = {
+ expr.dataType match {
+ case ArrayType(StructType(fields), containsNull) =>
+ val field = fields.find(_.name == fieldName).get
+ GetArrayStructFields(expr, field, fields.indexOf(field), containsNull)
+ }
+ }
+
+ checkEvaluation(getArrayStructFields(arrayStruct, "a"), Seq(1))
+ checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null)
+ }
- assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true)
- assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
- === false)
+ test("CreateStruct") {
+ val row = create_row(1, 2, 3)
+ val c1 = 'a.int.at(0).as("a")
+ val c3 = 'c.int.at(2).as("c")
+ checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row)
+ }
- assert(getStructField(Literal.create(null, typeS), "a").nullable === true)
- assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true)
+ test("test dsl for complex type") {
+ def quickResolve(u: UnresolvedExtractValue): Expression = {
+ ExtractValue(u.child, u.extraction, _ == _)
+ }
- checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row)
- checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row)
- checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row)
+ checkEvaluation(quickResolve('c.map(MapType(StringType, StringType)).at(0).getItem("a")),
+ "b", create_row(Map("a" -> "b")))
+ checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)),
+ "b", create_row(Seq("a", "b")))
+ checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")),
+ 1, create_row(create_row(1)))
}
test("error message of ExtractValue") {