aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala120
1 files changed, 86 insertions, 34 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index ed812e0679..f3d42fc0b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -23,8 +23,10 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
-
-case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns -a.")
+case class UnaryMinus(child: Expression) extends UnaryExpression
+ with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -58,7 +60,10 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
override def sql: String = s"(-${child.sql})"
}
-case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns a.")
+case class UnaryPositive(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def prettyName: String = "positive"
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -77,9 +82,10 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
* A function that get the absolute value of the numeric value.
*/
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the absolute value of the numeric value",
- extended = "> SELECT _FUNC_('-1');\n1")
-case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ usage = "_FUNC_(expr) - Returns the absolute value of the numeric value.",
+ extended = "> SELECT _FUNC_('-1');\n 1")
+case class Abs(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
@@ -123,7 +129,9 @@ private[sql] object BinaryArithmetic {
def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
}
-case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns a+b.")
+case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
@@ -152,7 +160,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
}
}
-case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns a-b.")
+case class Subtract(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
@@ -181,7 +192,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
}
}
-case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Multiplies a by b.")
+case class Multiply(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -193,7 +207,11 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}
-case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Divides a by b.",
+ extended = "> SELECT 3 _FUNC_ 2;\n 1.5")
+case class Divide(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -237,25 +255,42 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
} else {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
- s"""
- ${eval2.code}
- boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
- if (${eval2.isNull} || $isZero) {
- ${ev.isNull} = true;
- } else {
- ${eval1.code}
- if (${eval1.isNull}) {
+ if (!left.nullable && !right.nullable) {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if ($isZero) {
${ev.isNull} = true;
} else {
+ ${eval1.code}
${ev.value} = $divide;
}
- }
- """
+ """
+ } else {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if (${eval2.isNull} || $isZero) {
+ ${ev.isNull} = true;
+ } else {
+ ${eval1.code}
+ if (${eval1.isNull}) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = $divide;
+ }
+ }
+ """
+ }
}
}
-case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns the remainder when dividing a by b.")
+case class Remainder(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -299,21 +334,35 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
} else {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
- s"""
- ${eval2.code}
- boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
- if (${eval2.isNull} || $isZero) {
- ${ev.isNull} = true;
- } else {
- ${eval1.code}
- if (${eval1.isNull}) {
+ if (!left.nullable && !right.nullable) {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if ($isZero) {
${ev.isNull} = true;
} else {
+ ${eval1.code}
${ev.value} = $remainder;
}
- }
- """
+ """
+ } else {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if (${eval2.isNull} || $isZero) {
+ ${ev.isNull} = true;
+ } else {
+ ${eval1.code}
+ if (${eval1.isNull}) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = $remainder;
+ }
+ }
+ """
+ }
}
}
@@ -429,7 +478,10 @@ case class MinOf(left: Expression, right: Expression)
override def symbol: String = "min"
}
-case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "_FUNC_(a, b) - Returns the positive modulo",
+ extended = "> SELECT _FUNC_(10,3);\n 1")
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
override def toString: String = s"pmod($left, $right)"