aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala2
2 files changed, 9 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index c5960eb390..e83650fc8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -73,6 +73,13 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes
private lazy val numeric = TypeUtils.getNumeric(dataType)
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
+ case dt: DecimalType =>
+ defineCodeGen(ctx, ev, c => s"$c.abs()")
+ case dt: NumericType =>
+ defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))")
+ }
+
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index a85af9e04a..bc689810bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -278,6 +278,8 @@ final class Decimal extends Ordered[Decimal] with Serializable {
Decimal(-longVal, precision, scale)
}
}
+
+ def abs: Decimal = if (this.compare(Decimal(0)) < 0) this.unary_- else this
}
object Decimal {