diff options
4 files changed, 27 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 04e306da23..97b123ec2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -120,6 +120,7 @@ object FunctionRegistry { expression[Log2]("log2"), expression[Pow]("pow"), expression[Pow]("power"), + expression[UnaryPositive]("positive"), expression[Rint]("rint"), expression[Signum]("sign"), expression[Signum]("signum"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8b78c50000..167e460d5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -58,6 +58,15 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = numeric.negate(evalE) } +case class UnaryPositive(child: Expression) extends UnaryArithmetic { + override def toString: String = s"positive($child)" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + defineCodeGen(ctx, ev, c => c) + + protected override def evalInternal(evalE: Any) = evalE +} + case class Sqrt(child: Expression) extends UnaryArithmetic { override def dataType: DataType = DoubleType override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f8f1efcc7e..9132a786f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -52,6 +52,7 @@ object DefaultOptimizer extends Optimizer { LikeSimplification, BooleanSimplification, PushPredicateThroughJoin, + RemovePositive, SimplifyFilters, SimplifyCasts, SimplifyCaseConversionExpressions) :: @@ -633,6 +634,15 @@ object SimplifyCasts extends Rule[LogicalPlan] { } /** + * Removes [[UnaryPositive]] identify function + */ +object RemovePositive extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case UnaryPositive(child) => child + } +} + +/** * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index faa1d1193b..e2daaf6b73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -262,4 +262,11 @@ class MathExpressionsSuite extends QueryTest { ctx.sql("SELECT negative(1), negative(0), negative(-1)"), Row(-1, 0, 1)) } + + test("positive") { + val df = Seq((1, -1, "abc")).toDF("a", "b", "c") + checkAnswer(df.selectExpr("positive(a)"), Row(1)) + checkAnswer(df.selectExpr("positive(b)"), Row(-1)) + checkAnswer(df.selectExpr("positive(c)"), Row("abc")) + } } |