diff options
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 120 |
1 files changed, 86 insertions, 34 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ed812e0679..f3d42fc0b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -23,8 +23,10 @@ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval - -case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { +@ExpressionDescription( + usage = "_FUNC_(a) - Returns -a.") +case class UnaryMinus(child: Expression) extends UnaryExpression + with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -58,7 +60,10 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def sql: String = s"(-${child.sql})" } -case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { +@ExpressionDescription( + usage = "_FUNC_(a) - Returns a.") +case class UnaryPositive(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def prettyName: String = "positive" override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -77,9 +82,10 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects * A function that get the absolute value of the numeric value. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", - extended = "> SELECT _FUNC_('-1');\n1") -case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { + usage = "_FUNC_(expr) - Returns the absolute value of the numeric value.", + extended = "> SELECT _FUNC_('-1');\n 1") +case class Abs(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) @@ -123,7 +129,9 @@ private[sql] object BinaryArithmetic { def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } -case class Add(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "a _FUNC_ b - Returns a+b.") +case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -152,7 +160,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { } } -case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "a _FUNC_ b - Returns a-b.") +case class Subtract(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -181,7 +192,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti } } -case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "a _FUNC_ b - Multiplies a by b.") +case class Multiply(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -193,7 +207,11 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "a _FUNC_ b - Divides a by b.", + extended = "> SELECT 3 _FUNC_ 2;\n 1.5") +case class Divide(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -237,25 +255,42 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } - s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { + if (!left.nullable && !right.nullable) { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if ($isZero) { ${ev.isNull} = true; } else { + ${eval1.code} ${ev.value} = $divide; } - } - """ + """ + } else { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { + ${ev.isNull} = true; + } else { + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = $divide; + } + } + """ + } } } -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "a _FUNC_ b - Returns the remainder when dividing a by b.") +case class Remainder(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -299,21 +334,35 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } - s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { + if (!left.nullable && !right.nullable) { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if ($isZero) { ${ev.isNull} = true; } else { + ${eval1.code} ${ev.value} = $remainder; } - } - """ + """ + } else { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { + ${ev.isNull} = true; + } else { + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = $remainder; + } + } + """ + } } } @@ -429,7 +478,10 @@ case class MinOf(left: Expression, right: Expression) override def symbol: String = "min" } -case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "_FUNC_(a, b) - Returns the positive modulo", + extended = "> SELECT _FUNC_(10,3);\n 1") +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def toString: String = s"pmod($left, $right)" |