aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYijie Shen <henry.yijieshen@gmail.com>2015-07-22 12:19:59 -0700
committerMichael Armbrust <michael@databricks.com>2015-07-22 12:19:59 -0700
commit86f80e2b4759e574fe3eb91695f81b644db87242 (patch)
tree5b6d6fa49365f71ad5d2c93ee9e51aafc381ae42
parent76520955fddbda87a5c53d0a394dedc91dce67e8 (diff)
downloadspark-86f80e2b4759e574fe3eb91695f81b644db87242.tar.gz
spark-86f80e2b4759e574fe3eb91695f81b644db87242.tar.bz2
spark-86f80e2b4759e574fe3eb91695f81b644db87242.zip
[SPARK-9165] [SQL] codegen for CreateArray, CreateStruct and CreateNamedStruct
JIRA: https://issues.apache.org/jira/browse/SPARK-9165 Author: Yijie Shen <henry.yijieshen@gmail.com> Closes #7537 from yjshen/array_struct_codegen and squashes the following commits: 3a6dce6 [Yijie Shen] use infix notion in createArray test 5e90f0a [Yijie Shen] resolve comments: classOf 39cefb8 [Yijie Shen] codegen for createArray createStruct & createNamedStruct
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala65
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala16
2 files changed, 76 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index f9fd04c02a..20b1eaab8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -17,16 +17,18 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.mutable
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
* Returns an Array containing the evaluation of all children expressions.
*/
-case class CreateArray(children: Seq[Expression]) extends Expression with CodegenFallback {
+case class CreateArray(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
@@ -45,14 +47,31 @@ case class CreateArray(children: Seq[Expression]) extends Expression with Codege
children.map(_.eval(input))
}
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName
+ s"""
+ boolean ${ev.isNull} = false;
+ $arraySeqClass<Object> ${ev.primitive} = new $arraySeqClass<Object>(${children.size});
+ """ +
+ children.zipWithIndex.map { case (e, i) =>
+ val eval = e.gen(ctx)
+ eval.code + s"""
+ if (${eval.isNull}) {
+ ${ev.primitive}.update($i, null);
+ } else {
+ ${ev.primitive}.update($i, ${eval.primitive});
+ }
+ """
+ }.mkString("\n")
+ }
+
override def prettyName: String = "array"
}
/**
* Returns a Row containing the evaluation of all children expressions.
- * TODO: [[CreateStruct]] does not support codegen.
*/
-case class CreateStruct(children: Seq[Expression]) extends Expression with CodegenFallback {
+case class CreateStruct(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
@@ -76,6 +95,24 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg
InternalRow(children.map(_.eval(input)): _*)
}
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val rowClass = classOf[GenericMutableRow].getName
+ s"""
+ boolean ${ev.isNull} = false;
+ final $rowClass ${ev.primitive} = new $rowClass(${children.size});
+ """ +
+ children.zipWithIndex.map { case (e, i) =>
+ val eval = e.gen(ctx)
+ eval.code + s"""
+ if (${eval.isNull}) {
+ ${ev.primitive}.update($i, null);
+ } else {
+ ${ev.primitive}.update($i, ${eval.primitive});
+ }
+ """
+ }.mkString("\n")
+ }
+
override def prettyName: String = "struct"
}
@@ -84,7 +121,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg
*
* @param children Seq(name1, val1, name2, val2, ...)
*/
-case class CreateNamedStruct(children: Seq[Expression]) extends Expression with CodegenFallback {
+case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
private lazy val (nameExprs, valExprs) =
children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
@@ -122,5 +159,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with
InternalRow(valExprs.map(_.eval(input)): _*)
}
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val rowClass = classOf[GenericMutableRow].getName
+ s"""
+ boolean ${ev.isNull} = false;
+ final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size});
+ """ +
+ valExprs.zipWithIndex.map { case (e, i) =>
+ val eval = e.gen(ctx)
+ eval.code + s"""
+ if (${eval.isNull}) {
+ ${ev.primitive}.update($i, null);
+ } else {
+ ${ev.primitive}.update($i, ${eval.primitive});
+ }
+ """
+ }.mkString("\n")
+ }
+
override def prettyName: String = "named_struct"
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index e304214363..a8aee8f634 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -117,6 +117,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null)
}
+ test("CreateArray") {
+ val intSeq = Seq(5, 10, 15, 20, 25)
+ val longSeq = intSeq.map(_.toLong)
+ val strSeq = intSeq.map(_.toString)
+ checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow)
+ checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow)
+ checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow)
+
+ val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType)
+ val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType)
+ val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType)
+ checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow)
+ checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow)
+ checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow)
+ }
+
test("CreateStruct") {
val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0)