aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-07-06 22:13:50 -0700
committerReynold Xin <rxin@databricks.com>2015-07-06 22:13:50 -0700
commitc46aaf47f38163e9c7be671d7b8398512df34e62 (patch)
tree78f48031b0c0b330cec13a9239dce660c5e33b87
parent6718c1eb671faaf5c1d865ad5d01dbf78dae9cd2 (diff)
downloadspark-c46aaf47f38163e9c7be671d7b8398512df34e62.tar.gz
spark-c46aaf47f38163e9c7be671d7b8398512df34e62.tar.bz2
spark-c46aaf47f38163e9c7be671d7b8398512df34e62.zip
[SPARK-8759][SQL] add default eval to binary and unary expression according to default behavior of nullable
We have `nullSafeCodeGen` to provide default code generation for binary and unary expression, and we can do the same thing for `eval`. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7157 from cloud-fan/refactor and squashes the following commits: f3987c6 [Wenchen Fan] refactor Expression
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala69
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala51
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala97
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala33
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala186
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala153
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala84
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala9
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala138
11 files changed, 292 insertions, 543 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 2d99d1a3fe..4f73ba40b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -114,8 +114,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
}
- override def foldable: Boolean = child.foldable
-
override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable
override def toString: String = s"CAST($child, $dataType)"
@@ -426,10 +424,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
- override def eval(input: InternalRow): Any = {
- val evaluated = child.eval(input)
- if (evaluated == null) null else cast(evaluated)
- }
+ protected override def nullSafeEval(input: Any): Any = cast(input)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
// TODO: Add support for more data types.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index cafbbafdca..386feb95b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -184,6 +184,27 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
override def nullable: Boolean = child.nullable
/**
+ * Default behavior of evaluation according to the default nullability of UnaryExpression.
+ * If subclass of UnaryExpression override nullable, probably should also override this.
+ */
+ override def eval(input: InternalRow): Any = {
+ val value = child.eval(input)
+ if (value == null) {
+ null
+ } else {
+ nullSafeEval(value)
+ }
+ }
+
+ /**
+ * Called by default [[eval]] implementation. If subclass of UnaryExpression keep the default
+ * nullability, they can override this method to save null-check code. If we need full control
+ * of evaluation process, we should override [[eval]].
+ */
+ protected def nullSafeEval(input: Any): Any =
+ sys.error(s"UnaryExpressions must override either eval or nullSafeEval")
+
+ /**
* Called by unary expressions to generate a code block that returns null if its parent returns
* null, and if not not null, use `f` to generate the expression.
*
@@ -198,21 +219,24 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: String => String): String = {
- nullSafeCodeGen(ctx, ev, (result, eval) => {
- s"$result = ${f(eval)};"
+ nullSafeCodeGen(ctx, ev, eval => {
+ s"${ev.primitive} = ${f(eval)};"
})
}
/**
* Called by unary expressions to generate a code block that returns null if its parent returns
* null, and if not not null, use `f` to generate the expression.
+ *
+ * @param f function that accepts the non-null evaluation result name of child and returns Java
+ * code to compute the output.
*/
protected def nullSafeCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
- f: (String, String) => String): String = {
+ f: String => String): String = {
val eval = child.gen(ctx)
- val resultCode = f(ev.primitive, eval.primitive)
+ val resultCode = f(eval.primitive)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
@@ -236,6 +260,32 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def nullable: Boolean = left.nullable || right.nullable
/**
+ * Default behavior of evaluation according to the default nullability of BinaryExpression.
+ * If subclass of BinaryExpression override nullable, probably should also override this.
+ */
+ override def eval(input: InternalRow): Any = {
+ val value1 = left.eval(input)
+ if (value1 == null) {
+ null
+ } else {
+ val value2 = right.eval(input)
+ if (value2 == null) {
+ null
+ } else {
+ nullSafeEval(value1, value2)
+ }
+ }
+ }
+
+ /**
+ * Called by default [[eval]] implementation. If subclass of BinaryExpression keep the default
+ * nullability, they can override this method to save null-check code. If we need full control
+ * of evaluation process, we should override [[eval]].
+ */
+ protected def nullSafeEval(input1: Any, input2: Any): Any =
+ sys.error(s"BinaryExpressions must override either eval or nullSafeEval")
+
+ /**
* Short hand for generating binary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
@@ -246,8 +296,8 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (String, String) => String): String = {
- nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
- s"$result = ${f(eval1, eval2)};"
+ nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+ s"${ev.primitive} = ${f(eval1, eval2)};"
})
}
@@ -255,14 +305,17 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
* Short hand for generating binary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
+ *
+ * @param f function that accepts the 2 non-null evaluation result names of children
+ * and returns Java code to compute the output.
*/
protected def nullSafeCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
- f: (String, String, String) => String): String = {
+ f: (String, String) => String): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
- val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive)
+ val resultCode = f(eval1.primitive, eval2.primitive)
s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 3020e7fc96..e451c7ffbd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -122,18 +122,16 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
- override def eval(input: InternalRow): Any = {
- val baseValue = child.eval(input).asInstanceOf[InternalRow]
- if (baseValue == null) null else baseValue(ordinal)
- }
+ protected override def nullSafeEval(input: Any): Any =
+ input.asInstanceOf[InternalRow](ordinal)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, (result, eval) => {
+ nullSafeCodeGen(ctx, ev, eval => {
s"""
if ($eval.isNullAt($ordinal)) {
${ev.isNull} = true;
} else {
- $result = ${ctx.getColumn(eval, dataType, ordinal)};
+ ${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)};
}
"""
})
@@ -152,12 +150,9 @@ case class GetArrayStructFields(
override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def nullable: Boolean = child.nullable || containsNull || field.nullable
- override def eval(input: InternalRow): Any = {
- val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
- if (baseValue == null) null else {
- baseValue.map { row =>
- if (row == null) null else row(ordinal)
- }
+ protected override def nullSafeEval(input: Any): Any = {
+ input.asInstanceOf[Seq[InternalRow]].map { row =>
+ if (row == null) null else row(ordinal)
}
}
@@ -165,7 +160,7 @@ case class GetArrayStructFields(
val arraySeqClass = "scala.collection.mutable.ArraySeq"
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
- nullSafeCodeGen(ctx, ev, (result, eval) => {
+ nullSafeCodeGen(ctx, ev, eval => {
s"""
final int n = $eval.size();
final $arraySeqClass<Object> values = new $arraySeqClass<Object>(n);
@@ -175,7 +170,7 @@ case class GetArrayStructFields(
values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)});
}
}
- $result = (${ctx.javaType(dataType)}) values;
+ ${ev.primitive} = (${ctx.javaType(dataType)}) values;
"""
})
}
@@ -193,22 +188,6 @@ abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValu
/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = true
override def toString: String = s"$child[$ordinal]"
-
- override def eval(input: InternalRow): Any = {
- val value = child.eval(input)
- if (value == null) {
- null
- } else {
- val o = ordinal.eval(input)
- if (o == null) {
- null
- } else {
- evalNotNull(value, o)
- }
- }
- }
-
- protected def evalNotNull(value: Any, ordinal: Any): Any
}
/**
@@ -219,7 +198,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
- protected def evalNotNull(value: Any, ordinal: Any) = {
+ protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
val baseValue = value.asInstanceOf[Seq[_]]
@@ -232,13 +211,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
+ nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
final int index = (int)$eval2;
if (index >= $eval1.size() || index < 0) {
${ev.isNull} = true;
} else {
- $result = (${ctx.boxedType(dataType)})$eval1.apply(index);
+ ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index);
}
"""
})
@@ -253,16 +232,16 @@ case class GetMapValue(child: Expression, ordinal: Expression)
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
- protected def evalNotNull(value: Any, ordinal: Any) = {
+ protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
val baseValue = value.asInstanceOf[Map[Any, _]]
baseValue.get(ordinal).orNull
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => {
+ nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
if ($eval1.contains($eval2)) {
- $result = (${ctx.boxedType(dataType)})$eval1.apply($eval2);
+ ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply($eval2);
} else {
${ev.isNull} = true;
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 4fbf4c8700..dca6642665 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -26,18 +26,6 @@ abstract class UnaryArithmetic extends UnaryExpression {
self: Product =>
override def dataType: DataType = child.dataType
-
- override def eval(input: InternalRow): Any = {
- val evalE = child.eval(input)
- if (evalE == null) {
- null
- } else {
- evalInternal(evalE)
- }
- }
-
- protected def evalInternal(evalE: Any): Any =
- sys.error(s"UnaryArithmetics must override either eval or evalInternal")
}
case class UnaryMinus(child: Expression) extends UnaryArithmetic {
@@ -53,7 +41,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
}
- protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
+ protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
}
case class UnaryPositive(child: Expression) extends UnaryArithmetic {
@@ -62,7 +50,7 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
defineCodeGen(ctx, ev, c => c)
- protected override def evalInternal(evalE: Any) = evalE
+ protected override def nullSafeEval(input: Any): Any = input
}
/**
@@ -74,7 +62,7 @@ case class Abs(child: Expression) extends UnaryArithmetic {
private lazy val numeric = TypeUtils.getNumeric(dataType)
- protected override def evalInternal(evalE: Any) = numeric.abs(evalE)
+ protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
}
abstract class BinaryArithmetic extends BinaryOperator {
@@ -94,20 +82,6 @@ abstract class BinaryArithmetic extends BinaryOperator {
protected def checkTypesInternal(t: DataType): TypeCheckResult
- override def eval(input: InternalRow): Any = {
- val evalE1 = left.eval(input)
- if(evalE1 == null) {
- null
- } else {
- val evalE2 = right.eval(input)
- if (evalE2 == null) {
- null
- } else {
- evalInternal(evalE1, evalE2)
- }
- }
- }
-
/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String =
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
@@ -122,9 +96,6 @@ abstract class BinaryArithmetic extends BinaryOperator {
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
-
- protected def evalInternal(evalE1: Any, evalE2: Any): Any =
- sys.error(s"BinaryArithmetics must override either eval or evalInternal")
}
private[sql] object BinaryArithmetic {
@@ -143,7 +114,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
private lazy val numeric = TypeUtils.getNumeric(dataType)
- protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.plus(evalE1, evalE2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
}
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -158,7 +129,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
private lazy val numeric = TypeUtils.getNumeric(dataType)
- protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.minus(evalE1, evalE2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
}
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -173,7 +144,7 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
private lazy val numeric = TypeUtils.getNumeric(dataType)
- protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.times(evalE1, evalE2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -194,15 +165,15 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
override def eval(input: InternalRow): Any = {
- val evalE2 = right.eval(input)
- if (evalE2 == null || evalE2 == 0) {
+ val input2 = right.eval(input)
+ if (input2 == null || input2 == 0) {
null
} else {
- val evalE1 = left.eval(input)
- if (evalE1 == null) {
+ val input1 = left.eval(input)
+ if (input1 == null) {
null
} else {
- div(evalE1, evalE2)
+ div(input1, input2)
}
}
}
@@ -260,15 +231,15 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
override def eval(input: InternalRow): Any = {
- val evalE2 = right.eval(input)
- if (evalE2 == null || evalE2 == 0) {
+ val input2 = right.eval(input)
+ if (input2 == null || input2 == 0) {
null
} else {
- val evalE1 = left.eval(input)
- if (evalE1 == null) {
+ val input1 = left.eval(input)
+ if (input1 == null) {
null
} else {
- integral.rem(evalE1, evalE2)
+ integral.rem(input1, input2)
}
}
}
@@ -317,17 +288,17 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
private lazy val ordering = TypeUtils.getOrdering(dataType)
override def eval(input: InternalRow): Any = {
- val evalE1 = left.eval(input)
- val evalE2 = right.eval(input)
- if (evalE1 == null) {
- evalE2
- } else if (evalE2 == null) {
- evalE1
+ val input1 = left.eval(input)
+ val input2 = right.eval(input)
+ if (input1 == null) {
+ input2
+ } else if (input2 == null) {
+ input1
} else {
- if (ordering.compare(evalE1, evalE2) < 0) {
- evalE2
+ if (ordering.compare(input1, input2) < 0) {
+ input2
} else {
- evalE1
+ input1
}
}
}
@@ -371,17 +342,17 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
private lazy val ordering = TypeUtils.getOrdering(dataType)
override def eval(input: InternalRow): Any = {
- val evalE1 = left.eval(input)
- val evalE2 = right.eval(input)
- if (evalE1 == null) {
- evalE2
- } else if (evalE2 == null) {
- evalE1
+ val input1 = left.eval(input)
+ val input2 = right.eval(input)
+ if (input1 == null) {
+ input2
+ } else if (input2 == null) {
+ input1
} else {
- if (ordering.compare(evalE1, evalE2) < 0) {
- evalE1
+ if (ordering.compare(input1, input2) < 0) {
+ input1
} else {
- evalE2
+ input2
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
index 9002dda7bf..2d47124d24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
@@ -45,7 +45,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
}
- protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2)
}
/**
@@ -70,7 +70,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
}
- protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2)
}
/**
@@ -95,7 +95,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
}
- protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2)
}
/**
@@ -122,5 +122,5 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)")
}
- protected override def evalInternal(evalE: Any) = not(evalE)
+ protected override def nullSafeEval(input: Any): Any = not(input)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index f5c2dde191..2fa74b4ffc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -30,14 +30,8 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
override def dataType: DataType = LongType
override def toString: String = s"UnscaledValue($child)"
- override def eval(input: InternalRow): Any = {
- val childResult = child.eval(input)
- if (childResult == null) {
- null
- } else {
- childResult.asInstanceOf[Decimal].toUnscaledLong
- }
- }
+ protected override def nullSafeEval(input: Any): Any =
+ input.asInstanceOf[Decimal].toUnscaledLong
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
@@ -54,26 +48,15 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
override def dataType: DataType = DecimalType(precision, scale)
override def toString: String = s"MakeDecimal($child,$precision,$scale)"
- override def eval(input: InternalRow): Decimal = {
- val childResult = child.eval(input)
- if (childResult == null) {
- null
- } else {
- new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale)
- }
- }
+ protected override def nullSafeEval(input: Any): Any =
+ Decimal(input.asInstanceOf[Long], precision, scale)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val eval = child.gen(ctx)
- eval.code + s"""
- boolean ${ev.isNull} = ${eval.isNull};
- ${ctx.decimalType} ${ev.primitive} = null;
-
- if (!${ev.isNull}) {
- ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull(
- ${eval.primitive}, $precision, $scale);
+ nullSafeCodeGen(ctx, ev, eval => {
+ s"""
+ ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull($eval, $precision, $scale);
${ev.isNull} = ${ev.primitive} == null;
- }
"""
+ })
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 9250045398..9dca8513c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -61,21 +61,16 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
override def nullable: Boolean = true
override def toString: String = s"$name($child)"
- override def eval(input: InternalRow): Any = {
- val evalE = child.eval(input)
- if (evalE == null) {
- null
- } else {
- val result = f(evalE.asInstanceOf[Double])
- if (result.isNaN) null else result
- }
+ protected override def nullSafeEval(input: Any): Any = {
+ val result = f(input.asInstanceOf[Double])
+ if (result.isNaN) null else result
}
// name of function in java.lang.Math
def funcName: String = name.toLowerCase
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, (result, eval) => {
+ nullSafeCodeGen(ctx, ev, eval => {
s"""
${ev.primitive} = java.lang.Math.${funcName}($eval);
if (Double.valueOf(${ev.primitive}).isNaN()) {
@@ -101,19 +96,9 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
override def dataType: DataType = DoubleType
- override def eval(input: InternalRow): Any = {
- val evalE1 = left.eval(input)
- if (evalE1 == null) {
- null
- } else {
- val evalE2 = right.eval(input)
- if (evalE2 == null) {
- null
- } else {
- val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double])
- if (result.isNaN) null else result
- }
- }
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val result = f(input1.asInstanceOf[Double], input2.asInstanceOf[Double])
+ if (result.isNaN) null else result
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -194,39 +179,29 @@ case class Factorial(child: Expression) extends UnaryExpression with ExpectsInpu
override def dataType: DataType = LongType
- override def foldable: Boolean = child.foldable
-
// If the value not in the range of [0, 20], it still will be null, so set it to be true here.
override def nullable: Boolean = true
- override def eval(input: InternalRow): Any = {
- val evalE = child.eval(input)
- if (evalE == null) {
+ protected override def nullSafeEval(input: Any): Any = {
+ val value = input.asInstanceOf[jl.Integer]
+ if (value > 20 || value < 0) {
null
} else {
- val input = evalE.asInstanceOf[jl.Integer]
- if (input > 20 || input < 0) {
- null
- } else {
- Factorial.factorial(input)
- }
+ Factorial.factorial(value)
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val eval = child.gen(ctx)
- eval.code + s"""
- boolean ${ev.isNull} = ${eval.isNull};
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- if (${eval.primitive} > 20 || ${eval.primitive} < 0) {
+ nullSafeCodeGen(ctx, ev, eval => {
+ s"""
+ if ($eval > 20 || $eval < 0) {
${ev.isNull} = true;
} else {
${ev.primitive} =
- org.apache.spark.sql.catalyst.expressions.Factorial.factorial(${eval.primitive});
+ org.apache.spark.sql.catalyst.expressions.Factorial.factorial($eval);
}
- }
- """
+ """
+ })
}
}
@@ -235,17 +210,14 @@ case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")
case class Log2(child: Expression)
extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val eval = child.gen(ctx)
- eval.code + s"""
- boolean ${ev.isNull} = ${eval.isNull};
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2);
+ nullSafeCodeGen(ctx, ev, eval => {
+ s"""
+ ${ev.primitive} = java.lang.Math.log($eval) / java.lang.Math.log(2);
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
- }
- """
+ """
+ })
}
}
@@ -283,14 +255,8 @@ case class Bin(child: Expression)
override def inputTypes: Seq[DataType] = Seq(LongType)
override def dataType: DataType = StringType
- override def eval(input: InternalRow): Any = {
- val evalE = child.eval(input)
- if (evalE == null) {
- null
- } else {
- UTF8String.fromString(jl.Long.toBinaryString(evalE.asInstanceOf[Long]))
- }
- }
+ protected override def nullSafeEval(input: Any): Any =
+ UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long]))
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, (c) =>
@@ -326,17 +292,10 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes
override def dataType: DataType = StringType
- override def eval(input: InternalRow): Any = {
- val num = child.eval(input)
- if (num == null) {
- null
- } else {
- child.dataType match {
- case LongType => hex(num.asInstanceOf[Long])
- case BinaryType => hex(num.asInstanceOf[Array[Byte]])
- case StringType => hex(num.asInstanceOf[UTF8String].getBytes)
- }
- }
+ protected override def nullSafeEval(num: Any): Any = child.dataType match {
+ case LongType => hex(num.asInstanceOf[Long])
+ case BinaryType => hex(num.asInstanceOf[Array[Byte]])
+ case StringType => hex(num.asInstanceOf[UTF8String].getBytes)
}
private[this] def hex(bytes: Array[Byte]): UTF8String = {
@@ -377,14 +336,8 @@ case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTyp
override def nullable: Boolean = true
override def dataType: DataType = BinaryType
- override def eval(input: InternalRow): Any = {
- val num = child.eval(input)
- if (num == null) {
- null
- } else {
- unhex(num.asInstanceOf[UTF8String].getBytes)
- }
- }
+ protected override def nullSafeEval(num: Any): Any =
+ unhex(num.asInstanceOf[UTF8String].getBytes)
private[this] def unhex(bytes: Array[Byte]): Array[Byte] = {
val out = new Array[Byte]((bytes.length + 1) >> 1)
@@ -429,21 +382,10 @@ case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTyp
case class Atan2(left: Expression, right: Expression)
extends BinaryMathExpression(math.atan2, "ATAN2") {
- override def eval(input: InternalRow): Any = {
- val evalE1 = left.eval(input)
- if (evalE1 == null) {
- null
- } else {
- val evalE2 = right.eval(input)
- if (evalE2 == null) {
- null
- } else {
- // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
- val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0,
- evalE2.asInstanceOf[Double] + 0.0)
- if (result.isNaN) null else result
- }
- }
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
+ val result = math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0)
+ if (result.isNaN) null else result
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -480,25 +422,15 @@ case class ShiftLeft(left: Expression, right: Expression)
override def dataType: DataType = left.dataType
- override def eval(input: InternalRow): Any = {
- val valueLeft = left.eval(input)
- if (valueLeft != null) {
- val valueRight = right.eval(input)
- if (valueRight != null) {
- valueLeft match {
- case l: jl.Long => l << valueRight.asInstanceOf[jl.Integer]
- case i: jl.Integer => i << valueRight.asInstanceOf[jl.Integer]
- }
- } else {
- null
- }
- } else {
- null
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ input1 match {
+ case l: jl.Long => l << input2.asInstanceOf[jl.Integer]
+ case i: jl.Integer => i << input2.asInstanceOf[jl.Integer]
}
}
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;")
+ defineCodeGen(ctx, ev, (left, right) => s"$left << $right")
}
}
@@ -516,25 +448,15 @@ case class ShiftRight(left: Expression, right: Expression)
override def dataType: DataType = left.dataType
- override def eval(input: InternalRow): Any = {
- val valueLeft = left.eval(input)
- if (valueLeft != null) {
- val valueRight = right.eval(input)
- if (valueRight != null) {
- valueLeft match {
- case l: jl.Long => l >> valueRight.asInstanceOf[jl.Integer]
- case i: jl.Integer => i >> valueRight.asInstanceOf[jl.Integer]
- }
- } else {
- null
- }
- } else {
- null
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ input1 match {
+ case l: jl.Long => l >> input2.asInstanceOf[jl.Integer]
+ case i: jl.Integer => i >> input2.asInstanceOf[jl.Integer]
}
}
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;")
+ defineCodeGen(ctx, ev, (left, right) => s"$left >> $right")
}
}
@@ -552,25 +474,15 @@ case class ShiftRightUnsigned(left: Expression, right: Expression)
override def dataType: DataType = left.dataType
- override def eval(input: InternalRow): Any = {
- val valueLeft = left.eval(input)
- if (valueLeft != null) {
- val valueRight = right.eval(input)
- if (valueRight != null) {
- valueLeft match {
- case l: jl.Long => l >>> valueRight.asInstanceOf[jl.Integer]
- case i: jl.Integer => i >>> valueRight.asInstanceOf[jl.Integer]
- }
- } else {
- null
- }
- } else {
- null
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ input1 match {
+ case l: jl.Long => l >>> input2.asInstanceOf[jl.Integer]
+ case i: jl.Integer => i >>> input2.asInstanceOf[jl.Integer]
}
}
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;")
+ defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index e008af3966..3b59cd431b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -37,14 +37,8 @@ case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes
override def inputTypes: Seq[DataType] = Seq(BinaryType)
- override def eval(input: InternalRow): Any = {
- val value = child.eval(input)
- if (value == null) {
- null
- } else {
- UTF8String.fromString(DigestUtils.md5Hex(value.asInstanceOf[Array[Byte]]))
- }
- }
+ protected override def nullSafeEval(input: Any): Any =
+ UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]]))
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c =>
@@ -67,76 +61,56 @@ case class Sha2(left: Expression, right: Expression)
override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
- override def eval(input: InternalRow): Any = {
- val evalE1 = left.eval(input)
- if (evalE1 == null) {
- null
- } else {
- val evalE2 = right.eval(input)
- if (evalE2 == null) {
- null
- } else {
- val bitLength = evalE2.asInstanceOf[Int]
- val input = evalE1.asInstanceOf[Array[Byte]]
- bitLength match {
- case 224 =>
- // DigestUtils doesn't support SHA-224 now
- try {
- val md = MessageDigest.getInstance("SHA-224")
- md.update(input)
- UTF8String.fromBytes(md.digest())
- } catch {
- // SHA-224 is not supported on the system, return null
- case noa: NoSuchAlgorithmException => null
- }
- case 256 | 0 =>
- UTF8String.fromString(DigestUtils.sha256Hex(input))
- case 384 =>
- UTF8String.fromString(DigestUtils.sha384Hex(input))
- case 512 =>
- UTF8String.fromString(DigestUtils.sha512Hex(input))
- case _ => null
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val bitLength = input2.asInstanceOf[Int]
+ val input = input1.asInstanceOf[Array[Byte]]
+ bitLength match {
+ case 224 =>
+ // DigestUtils doesn't support SHA-224 now
+ try {
+ val md = MessageDigest.getInstance("SHA-224")
+ md.update(input)
+ UTF8String.fromBytes(md.digest())
+ } catch {
+ // SHA-224 is not supported on the system, return null
+ case noa: NoSuchAlgorithmException => null
}
- }
+ case 256 | 0 =>
+ UTF8String.fromString(DigestUtils.sha256Hex(input))
+ case 384 =>
+ UTF8String.fromString(DigestUtils.sha384Hex(input))
+ case 512 =>
+ UTF8String.fromString(DigestUtils.sha512Hex(input))
+ case _ => null
}
}
+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val eval1 = left.gen(ctx)
- val eval2 = right.gen(ctx)
val digestUtils = "org.apache.commons.codec.digest.DigestUtils"
-
- s"""
- ${eval1.code}
- boolean ${ev.isNull} = ${eval1.isNull};
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${eval2.code}
- if (!${eval2.isNull}) {
- if (${eval2.primitive} == 224) {
- try {
- java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
- md.update(${eval1.primitive});
- ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest());
- } catch (java.security.NoSuchAlgorithmException e) {
- ${ev.isNull} = true;
- }
- } else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) {
- ${ev.primitive} =
- ${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive}));
- } else if (${eval2.primitive} == 384) {
- ${ev.primitive} =
- ${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive}));
- } else if (${eval2.primitive} == 512) {
- ${ev.primitive} =
- ${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive}));
- } else {
+ nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+ s"""
+ if ($eval2 == 224) {
+ try {
+ java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
+ md.update($eval1);
+ ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest());
+ } catch (java.security.NoSuchAlgorithmException e) {
${ev.isNull} = true;
}
+ } else if ($eval2 == 256 || $eval2 == 0) {
+ ${ev.primitive} =
+ ${ctx.stringType}.fromString($digestUtils.sha256Hex($eval1));
+ } else if ($eval2 == 384) {
+ ${ev.primitive} =
+ ${ctx.stringType}.fromString($digestUtils.sha384Hex($eval1));
+ } else if ($eval2 == 512) {
+ ${ev.primitive} =
+ ${ctx.stringType}.fromString($digestUtils.sha512Hex($eval1));
} else {
${ev.isNull} = true;
}
- }
- """
+ """
+ })
}
}
@@ -150,19 +124,12 @@ case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputType
override def inputTypes: Seq[DataType] = Seq(BinaryType)
- override def eval(input: InternalRow): Any = {
- val value = child.eval(input)
- if (value == null) {
- null
- } else {
- UTF8String.fromString(DigestUtils.shaHex(value.asInstanceOf[Array[Byte]]))
- }
- }
+ protected override def nullSafeEval(input: Any): Any =
+ UTF8String.fromString(DigestUtils.shaHex(input.asInstanceOf[Array[Byte]]))
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c =>
- "org.apache.spark.unsafe.types.UTF8String.fromString" +
- s"(org.apache.commons.codec.digest.DigestUtils.shaHex($c))"
+ s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))"
)
}
}
@@ -177,30 +144,20 @@ case class Crc32(child: Expression) extends UnaryExpression with ExpectsInputTyp
override def inputTypes: Seq[DataType] = Seq(BinaryType)
- override def eval(input: InternalRow): Any = {
- val value = child.eval(input)
- if (value == null) {
- null
- } else {
- val checksum = new CRC32
- checksum.update(value.asInstanceOf[Array[Byte]], 0, value.asInstanceOf[Array[Byte]].length)
- checksum.getValue
- }
+ protected override def nullSafeEval(input: Any): Any = {
+ val checksum = new CRC32
+ checksum.update(input.asInstanceOf[Array[Byte]], 0, input.asInstanceOf[Array[Byte]].length)
+ checksum.getValue
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val value = child.gen(ctx)
val CRC32 = "java.util.zip.CRC32"
- s"""
- ${value.code}
- boolean ${ev.isNull} = ${value.isNull};
- long ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${CRC32} checksum = new ${CRC32}();
- checksum.update(${value.primitive}, 0, ${value.primitive}.length);
+ nullSafeCodeGen(ctx, ev, value => {
+ s"""
+ $CRC32 checksum = new $CRC32();
+ checksum.update($value, 0, $value.length);
${ev.primitive} = checksum.getValue();
- }
- """
+ """
+ })
}
-
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 0b479f466c..402a0aa232 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -74,12 +74,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex
override def inputTypes: Seq[DataType] = Seq(BooleanType)
- override def eval(input: InternalRow): Any = {
- child.eval(input) match {
- case null => null
- case b: Boolean => !b
- }
- }
+ protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean]
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"!($c)")
@@ -105,17 +100,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
* Optimized version of In clause, when all filter values of In clause are
* static.
*/
-case class InSet(value: Expression, hset: Set[Any])
- extends Predicate {
-
- override def children: Seq[Expression] = value :: Nil
+case class InSet(child: Expression, hset: Set[Any])
+ extends UnaryExpression with Predicate {
- override def foldable: Boolean = value.foldable
override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
- override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}"
+ override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"
override def eval(input: InternalRow): Any = {
- hset.contains(value.eval(input))
+ hset.contains(child.eval(input))
}
}
@@ -127,15 +119,15 @@ case class And(left: Expression, right: Expression)
override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
override def eval(input: InternalRow): Any = {
- val l = left.eval(input)
- if (l == false) {
+ val input1 = left.eval(input)
+ if (input1 == false) {
false
} else {
- val r = right.eval(input)
- if (r == false) {
+ val input2 = right.eval(input)
+ if (input2 == false) {
false
} else {
- if (l != null && r != null) {
+ if (input1 != null && input2 != null) {
true
} else {
null
@@ -176,15 +168,15 @@ case class Or(left: Expression, right: Expression)
override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
override def eval(input: InternalRow): Any = {
- val l = left.eval(input)
- if (l == true) {
+ val input1 = left.eval(input)
+ if (input1 == true) {
true
} else {
- val r = right.eval(input)
- if (r == true) {
+ val input2 = right.eval(input)
+ if (input2 == true) {
true
} else {
- if (l != null && r != null) {
+ if (input1 != null && input2 != null) {
false
} else {
null
@@ -232,20 +224,6 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
protected def checkTypesInternal(t: DataType): TypeCheckResult
- override def eval(input: InternalRow): Any = {
- val evalE1 = left.eval(input)
- if (evalE1 == null) {
- null
- } else {
- val evalE2 = right.eval(input)
- if (evalE2 == null) {
- null
- } else {
- evalInternal(evalE1, evalE2)
- }
- }
- }
-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isPrimitiveType(left.dataType)) {
// faster version
@@ -254,9 +232,6 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0")
}
}
-
- protected def evalInternal(evalE1: Any, evalE2: Any): Any =
- sys.error(s"BinaryComparisons must override either eval or evalInternal")
}
private[sql] object BinaryComparison {
@@ -277,9 +252,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess
- protected override def evalInternal(l: Any, r: Any) = {
- if (left.dataType != BinaryType) l == r
- else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ if (left.dataType != BinaryType) input1 == input2
+ else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -295,15 +270,18 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess
override def eval(input: InternalRow): Any = {
- val l = left.eval(input)
- val r = right.eval(input)
- if (l == null && r == null) {
+ val input1 = left.eval(input)
+ val input2 = right.eval(input)
+ if (input1 == null && input2 == null) {
true
- } else if (l == null || r == null) {
+ } else if (input1 == null || input2 == null) {
false
} else {
- if (left.dataType != BinaryType) l == r
- else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
+ if (left.dataType != BinaryType) {
+ input1 == input2
+ } else {
+ java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
+ }
}
}
@@ -327,7 +305,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
- protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
}
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
@@ -338,7 +316,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
- protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
}
case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
@@ -349,7 +327,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
- protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
}
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
@@ -360,5 +338,5 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
- protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 5d51a4ca65..9b44fb1ed5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -135,6 +135,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
*/
case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
+ override def nullable: Boolean = left.nullable
override def dataType: DataType = left.dataType
override def eval(input: InternalRow): Any = {
@@ -183,12 +184,8 @@ case class CountSet(child: Expression) extends UnaryExpression {
override def dataType: DataType = LongType
- override def eval(input: InternalRow): Any = {
- val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]]
- if (childEval != null) {
- childEval.size.toLong
- }
- }
+ protected override def nullSafeEval(input: Any): Any =
+ input.asInstanceOf[OpenHashSet[Any]].size.toLong
override def toString: String = s"$child.count()"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 1a14a7a449..6e6a7fb171 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -31,7 +31,6 @@ trait StringRegexExpression extends ExpectsInputTypes {
def escape(v: String): String
def matches(regex: Pattern, str: String): Boolean
- override def nullable: Boolean = left.nullable || right.nullable
override def dataType: DataType = BooleanType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
@@ -50,22 +49,12 @@ trait StringRegexExpression extends ExpectsInputTypes {
protected def pattern(str: String) = if (cache == null) compile(str) else cache
- override def eval(input: InternalRow): Any = {
- val l = left.eval(input)
- if (l == null) {
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val regex = pattern(input2.asInstanceOf[UTF8String].toString())
+ if(regex == null) {
null
} else {
- val r = right.eval(input)
- if(r == null) {
- null
- } else {
- val regex = pattern(r.asInstanceOf[UTF8String].toString())
- if(regex == null) {
- null
- } else {
- matches(regex, l.asInstanceOf[UTF8String].toString())
- }
- }
+ matches(regex, input1.asInstanceOf[UTF8String].toString())
}
}
}
@@ -120,14 +109,8 @@ trait CaseConversionExpression extends ExpectsInputTypes {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType)
- override def eval(input: InternalRow): Any = {
- val evaluated = child.eval(input)
- if (evaluated == null) {
- null
- } else {
- convert(evaluated.asInstanceOf[UTF8String])
- }
- }
+ protected override def nullSafeEval(input: Any): Any =
+ convert(input.asInstanceOf[UTF8String])
}
/**
@@ -160,20 +143,10 @@ trait StringComparison extends ExpectsInputTypes {
def compare(l: UTF8String, r: UTF8String): Boolean
- override def nullable: Boolean = left.nullable || right.nullable
-
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
- override def eval(input: InternalRow): Any = {
- val leftEval = left.eval(input)
- if(leftEval == null) {
- null
- } else {
- val rightEval = right.eval(input)
- if (rightEval == null) null
- else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String])
- }
- }
+ protected override def nullSafeEval(input1: Any, input2: Any): Any =
+ compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String])
override def toString: String = s"$nodeName($left, $right)"
}
@@ -288,10 +261,8 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType)
- override def eval(input: InternalRow): Any = {
- val string = child.eval(input)
- if (string == null) null else string.asInstanceOf[UTF8String].length
- }
+ protected override def nullSafeEval(string: Any): Any =
+ string.asInstanceOf[UTF8String].length
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"($c).length()")
@@ -310,24 +281,13 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
override def dataType: DataType = IntegerType
- override def eval(input: InternalRow): Any = {
- val leftValue = left.eval(input)
- if (leftValue == null) {
- null
- } else {
- val rightValue = right.eval(input)
- if(rightValue == null) {
- null
- } else {
- StringUtils.getLevenshteinDistance(leftValue.toString, rightValue.toString)
- }
- }
- }
+ protected override def nullSafeEval(input1: Any, input2: Any): Any =
+ StringUtils.getLevenshteinDistance(input1.toString, input2.toString)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val stringUtils = classOf[StringUtils].getName
- nullSafeCodeGen(ctx, ev, (res, left, right) =>
- s"$res = $stringUtils.getLevenshteinDistance($left.toString(), $right.toString());")
+ defineCodeGen(ctx, ev, (left, right) =>
+ s"$stringUtils.getLevenshteinDistance($left.toString(), $right.toString())")
}
}
@@ -338,17 +298,12 @@ case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTyp
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType)
- override def eval(input: InternalRow): Any = {
- val string = child.eval(input)
- if (string == null) {
- null
+ protected override def nullSafeEval(string: Any): Any = {
+ val bytes = string.asInstanceOf[UTF8String].getBytes
+ if (bytes.length > 0) {
+ bytes(0).asInstanceOf[Int]
} else {
- val bytes = string.asInstanceOf[UTF8String].getBytes
- if (bytes.length > 0) {
- bytes(0).asInstanceOf[Int]
- } else {
- 0
- }
+ 0
}
}
}
@@ -360,15 +315,10 @@ case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTy
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType)
- override def eval(input: InternalRow): Any = {
- val bytes = child.eval(input)
- if (bytes == null) {
- null
- } else {
- UTF8String.fromBytes(
- org.apache.commons.codec.binary.Base64.encodeBase64(
- bytes.asInstanceOf[Array[Byte]]))
- }
+ protected override def nullSafeEval(bytes: Any): Any = {
+ UTF8String.fromBytes(
+ org.apache.commons.codec.binary.Base64.encodeBase64(
+ bytes.asInstanceOf[Array[Byte]]))
}
}
@@ -379,14 +329,8 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType)
- override def eval(input: InternalRow): Any = {
- val string = child.eval(input)
- if (string == null) {
- null
- } else {
- org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString)
- }
- }
+ protected override def nullSafeEval(string: Any): Any =
+ org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString)
}
/**
@@ -402,19 +346,9 @@ case class Decode(bin: Expression, charset: Expression)
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType)
- override def eval(input: InternalRow): Any = {
- val l = bin.eval(input)
- if (l == null) {
- null
- } else {
- val r = charset.eval(input)
- if (r == null) {
- null
- } else {
- val fromCharset = r.asInstanceOf[UTF8String].toString
- UTF8String.fromString(new String(l.asInstanceOf[Array[Byte]], fromCharset))
- }
- }
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val fromCharset = input2.asInstanceOf[UTF8String].toString
+ UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset))
}
}
@@ -431,19 +365,9 @@ case class Encode(value: Expression, charset: Expression)
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
- override def eval(input: InternalRow): Any = {
- val l = value.eval(input)
- if (l == null) {
- null
- } else {
- val r = charset.eval(input)
- if (r == null) {
- null
- } else {
- val toCharset = r.asInstanceOf[UTF8String].toString
- l.asInstanceOf[UTF8String].toString.getBytes(toCharset)
- }
- }
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val toCharset = input2.asInstanceOf[UTF8String].toString
+ input1.asInstanceOf[UTF8String].toString.getBytes(toCharset)
}
}