aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-05-02 10:21:14 -0700
committerMichael Armbrust <michael@databricks.com>2016-05-02 10:21:14 -0700
commit0513c3ac93e0a25d6eedbafe6c0561e71c92880a (patch)
tree53952439f22f1c6c13a8c477343ea0172a67a577 /sql/catalyst/src
parent214d1be4fd4a34399b6a2adb2618784de459a48d (diff)
downloadspark-0513c3ac93e0a25d6eedbafe6c0561e71c92880a.tar.gz
spark-0513c3ac93e0a25d6eedbafe6c0561e71c92880a.tar.bz2
spark-0513c3ac93e0a25d6eedbafe6c0561e71c92880a.zip
[SPARK-14637][SQL] object expressions cleanup
## What changes were proposed in this pull request? Simplify and clean up some object expressions: 1. simplify the logic to handle `propagateNull` 2. add `propagateNull` parameter to `Invoke` 3. simplify the unbox logic in `Invoke` 4. other minor cleanup TODO: simplify `MapObjects` ## How was this patch tested? existing tests. Author: Wenchen Fan <wenchen@databricks.com> Closes #12399 from cloud-fan/object.
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala218
1 files changed, 100 insertions, 118 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 1e418540a2..523eed825f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -64,33 +64,29 @@ case class StaticInvoke(
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")
- if (propagateNull) {
- val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
- s"${ev.isNull} = ${ev.value} == null;"
- } else {
- ""
- }
-
- val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
- ev.copy(code = s"""
- ${argGen.map(_.code).mkString("\n")}
-
- boolean ${ev.isNull} = !$argsNonNull;
- $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
+ val callFunc = s"$objectName.$functionName($argString)"
- if ($argsNonNull) {
- ${ev.value} = $objectName.$functionName($argString);
- $objNullCheck
- }
- """)
+ val setIsNull = if (propagateNull && arguments.nonEmpty) {
+ s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
} else {
- ev.copy(code = s"""
- ${argGen.map(_.code).mkString("\n")}
+ s"boolean ${ev.isNull} = false;"
+ }
- $javaType ${ev.value} = $objectName.$functionName($argString);
- final boolean ${ev.isNull} = ${ev.value} == null;
- """)
+ // If the function can return null, we do an extra check to make sure our null bit is still set
+ // correctly.
+ val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
+ s"${ev.isNull} = ${ev.value} == null;"
+ } else {
+ ""
}
+
+ val code = s"""
+ ${argGen.map(_.code).mkString("\n")}
+ $setIsNull
+ final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;
+ $postNullCheck
+ """
+ ev.copy(code = code)
}
}
@@ -111,7 +107,8 @@ case class Invoke(
targetObject: Expression,
functionName: String,
dataType: DataType,
- arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
+ arguments: Seq[Expression] = Nil,
+ propagateNull: Boolean = true) extends Expression with NonSQLExpression {
override def nullable: Boolean = true
override def children: Seq[Expression] = targetObject +: arguments
@@ -130,60 +127,53 @@ case class Invoke(
case _ => None
}
- lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
- case (IntegerType, "java.lang.Object") => (s: String) =>
- s"((java.lang.Integer)$s).intValue()"
- case (LongType, "java.lang.Object") => (s: String) =>
- s"((java.lang.Long)$s).longValue()"
- case (FloatType, "java.lang.Object") => (s: String) =>
- s"((java.lang.Float)$s).floatValue()"
- case (ShortType, "java.lang.Object") => (s: String) =>
- s"((java.lang.Short)$s).shortValue()"
- case (ByteType, "java.lang.Object") => (s: String) =>
- s"((java.lang.Byte)$s).byteValue()"
- case (DoubleType, "java.lang.Object") => (s: String) =>
- s"((java.lang.Double)$s).doubleValue()"
- case (BooleanType, "java.lang.Object") => (s: String) =>
- s"((java.lang.Boolean)$s).booleanValue()"
- case _ => identity[String] _
- }
-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val obj = targetObject.genCode(ctx)
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")
- // If the function can return null, we do an extra check to make sure our null bit is still set
- // correctly.
- val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
- s"boolean ${ev.isNull} = ${ev.value} == null;"
+ val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) {
+ s"${obj.value}.$functionName($argString)"
} else {
- ev.isNull = obj.isNull
- ""
+ s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)"
}
- val value = unboxer(s"${obj.value}.$functionName($argString)")
+ val setIsNull = if (propagateNull && arguments.nonEmpty) {
+ s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
+ } else {
+ s"boolean ${ev.isNull} = ${obj.isNull};"
+ }
val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
- s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
+ s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;"
} else {
s"""
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
try {
- ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
+ ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc;
} catch (Exception e) {
org.apache.spark.unsafe.Platform.throwException(e);
}
"""
}
- ev.copy(code = s"""
+ // If the function can return null, we do an extra check to make sure our null bit is still set
+ // correctly.
+ val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
+ s"${ev.isNull} = ${ev.value} == null;"
+ } else {
+ ""
+ }
+
+ val code = s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
+ $setIsNull
$evaluate
- $objNullCheck
- """)
+ $postNullCheck
+ """
+ ev.copy(code = code)
}
override def toString: String = s"$targetObject.$functionName"
@@ -246,11 +236,13 @@ case class NewInstance(
val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))
- val setup =
- s"""
- ${argGen.map(_.code).mkString("\n")}
- ${outer.map(_.code).getOrElse("")}
- """.stripMargin
+ var isNull = ev.isNull
+ val setIsNull = if (propagateNull && arguments.nonEmpty) {
+ s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};"
+ } else {
+ isNull = "false"
+ ""
+ }
val constructorCall = outer.map { gen =>
s"""${gen.value}.new ${cls.getSimpleName}($argString)"""
@@ -258,27 +250,13 @@ case class NewInstance(
s"new $className($argString)"
}
- if (propagateNull && argGen.nonEmpty) {
- val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
-
- ev.copy(code = s"""
- $setup
-
- boolean ${ev.isNull} = true;
- $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
- if ($argsNonNull) {
- ${ev.value} = $constructorCall;
- ${ev.isNull} = false;
- }
- """)
- } else {
- ev.copy(code = s"""
- $setup
-
- final $javaType ${ev.value} = $constructorCall;
- final boolean ${ev.isNull} = false;
- """)
- }
+ val code = s"""
+ ${argGen.map(_.code).mkString("\n")}
+ ${outer.map(_.code).getOrElse("")}
+ $setIsNull
+ final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall;
+ """
+ ev.copy(code = code, isNull = isNull)
}
override def toString: String = s"newInstance($cls)"
@@ -306,13 +284,14 @@ case class UnwrapOption(
val javaType = ctx.javaType(dataType)
val inputObject = child.genCode(ctx)
- ev.copy(code = s"""
+ val code = s"""
${inputObject.code}
- boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty();
+ final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
$javaType ${ev.value} =
- ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get();
- """)
+ ${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get();
+ """
+ ev.copy(code = code)
}
}
@@ -338,14 +317,14 @@ case class WrapOption(child: Expression, optType: DataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val inputObject = child.genCode(ctx)
- ev.copy(code = s"""
+ val code = s"""
${inputObject.code}
- boolean ${ev.isNull} = false;
scala.Option ${ev.value} =
${inputObject.isNull} ?
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
- """)
+ """
+ ev.copy(code = code, isNull = "false")
}
}
@@ -474,7 +453,7 @@ case class MapObjects private(
s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
}
- ev.copy(code = s"""
+ val code = s"""
${genInputData.code}
boolean ${ev.isNull} = ${genInputData.value} == null;
@@ -504,7 +483,8 @@ case class MapObjects private(
${ev.isNull} = false;
${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
}
- """)
+ """
+ ev.copy(code = code)
}
}
@@ -539,14 +519,16 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
}
"""
}
+
val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
val schemaField = ctx.addReferenceObj("schema", schema)
- ev.copy(code = s"""
- boolean ${ev.isNull} = false;
+
+ val code = s"""
$values = new Object[${children.size}];
$childrenCode
final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);
- """)
+ """
+ ev.copy(code = code, isNull = "false")
}
}
@@ -579,14 +561,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
// Code to serialize.
val input = child.genCode(ctx)
- ev.copy(code = s"""
+ val javaType = ctx.javaType(dataType)
+ val serialize = s"$serializer.serialize(${input.value}, null).array()"
+
+ val code = s"""
${input.code}
- final boolean ${ev.isNull} = ${input.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${ev.value} = $serializer.serialize(${input.value}, null).array();
- }
- """)
+ final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize;
+ """
+ ev.copy(code = code, isNull = input.isNull)
}
override def dataType: DataType = BinaryType
@@ -617,17 +599,17 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
serializer,
s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")
- // Code to serialize.
+ // Code to deserialize.
val input = child.genCode(ctx)
- ev.copy(code = s"""
+ val javaType = ctx.javaType(dataType)
+ val deserialize =
+ s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
+
+ val code = s"""
${input.code}
- final boolean ${ev.isNull} = ${input.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${ev.value} = (${ctx.javaType(dataType)})
- $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
- }
- """)
+ final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize;
+ """
+ ev.copy(code = code, isNull = input.isNull)
}
override def dataType: DataType = ObjectType(tag.runtimeClass)
@@ -658,15 +640,13 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
"""
}
- ev.isNull = instanceGen.isNull
- ev.value = instanceGen.value
-
- ev.copy(code = s"""
+ val code = s"""
${instanceGen.code}
if (!${instanceGen.isNull}) {
${initialize.mkString("\n")}
}
- """)
+ """
+ ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
}
}
@@ -696,13 +676,15 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
"If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
"please try to use scala.Option[_] or other nullable types " +
"(e.g. java.lang.Integer instead of int/scala.Int)."
- val idx = ctx.references.length
- ctx.references += errMsg
- ExprCode(code = s"""
+ val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
+
+ val code = s"""
${childGen.code}
if (${childGen.isNull}) {
- throw new RuntimeException((String) references[$idx]);
- }""", isNull = "false", value = childGen.value)
+ throw new RuntimeException(this.$errMsgField);
+ }
+ """
+ ev.copy(code = code, isNull = "false", value = childGen.value)
}
}