From eba6a1af4c8ffb21934a59a61a419d625f37cceb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 17 Jul 2015 09:38:08 -0700 Subject: [SPARK-8945][SQL] Add add and subtract expressions for IntervalType JIRA: https://issues.apache.org/jira/browse/SPARK-8945 Add add and subtract expressions for IntervalType. Author: Liang-Chi Hsieh This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #7398 from viirya/interval_add_subtract and squashes the following commits: acd1f1e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract 5abae28 [Liang-Chi Hsieh] For comments. 6f5b72e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract dbe3906 [Liang-Chi Hsieh] For comments. 13a2fc5 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract 83ec129 [Liang-Chi Hsieh] Remove intervalMethod. acfe1ab [Liang-Chi Hsieh] Fix scala style. d3e9d0e [Liang-Chi Hsieh] Add add and subtract expressions for IntervalType. --- .../sql/catalyst/expressions/arithmetic.scala | 60 ++++++++++++++++++---- .../expressions/codegen/CodeGenerator.scala | 4 +- .../spark/sql/catalyst/expressions/literals.scala | 3 +- .../apache/spark/sql/types/AbstractDataType.scala | 6 +++ .../analysis/ExpressionTypeCheckingSuite.scala | 6 +-- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 17 ++++++ 6 files changed, 82 insertions(+), 14 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 382cbe3b84..1616d1bc0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -21,11 +21,12 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.Interval case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) override def dataType: DataType = child.dataType @@ -36,15 +37,22 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") + case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } - protected override def nullSafeEval(input: Any): Any = numeric.negate(input) + protected override def nullSafeEval(input: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input.asInstanceOf[Interval].negate() + } else { + numeric.negate(input) + } + } } case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def prettyName: String = "positive" - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) override def dataType: DataType = child.dataType @@ -95,32 +103,66 @@ private[sql] object BinaryArithmetic { case class Add(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = NumericType + override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "+" - override def decimalMethod: String = "$plus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval]) + } else { + numeric.plus(input1, input2) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") + case ByteType | ShortType => + defineCodeGen(ctx, ev, + (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case IntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = NumericType + override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "-" - override def decimalMethod: String = "$minus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval]) + } else { + numeric.minus(input1, input2) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") + case ByteType | ShortType => + defineCodeGen(ctx, ev, + (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case IntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 45dc146488..7c388bc346 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -27,7 +27,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types._ // These classes are here to avoid issues with serialization and integration with quasiquotes. @@ -69,6 +69,7 @@ class CodeGenContext { mutableStates += ((javaType, variableName, initialValue)) } + final val intervalType: String = classOf[Interval].getName final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -137,6 +138,7 @@ class CodeGenContext { case dt: DecimalType => "Decimal" case BinaryType => "byte[]" case StringType => "UTF8String" + case IntervalType => intervalType case _: StructType => "InternalRow" case _: ArrayType => s"scala.collection.Seq" case _: MapType => s"scala.collection.Map" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 3a7a7ae440..e1fdb29541 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types._ object Literal { def apply(v: Any): Literal = v match { @@ -42,6 +42,7 @@ object Literal { case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) + case i: Interval => Literal(i, IntervalType) case null => Literal(null, NullType) case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 076d7b5a51..40bf4b299c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -91,6 +91,12 @@ private[sql] object TypeCollection { TimestampType, DateType, StringType, BinaryType) + /** + * Types that include numeric types and interval type. They are only used in unary_minus, + * unary_positive, add and subtract operations. + */ + val NumericAndInterval = TypeCollection(NumericType, IntervalType) + def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ed0d20e7de..ad15136ee9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "expected to be of type numeric") + assertError(UnaryMinus('stringField), "type (numeric or interval)") assertError(Abs('stringField), "expected to be of type numeric") assertError(BitwiseNot('stringField), "expected to be of type integral") } @@ -78,8 +78,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - assertError(Add('booleanField, 'booleanField), "accepts numeric type") - assertError(Subtract('booleanField, 'booleanField), "accepts numeric type") + assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type") + assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type") assertError(Multiply('booleanField, 'booleanField), "accepts numeric type") assertError(Divide('booleanField, 'booleanField), "accepts numeric type") assertError(Remainder('booleanField, 'booleanField), "accepts numeric type") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 231440892b..5b8b70ed5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1492,4 +1492,21 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Currently we don't yet support nanosecond checkIntervalParseError("select interval 23 nanosecond") } + + test("SPARK-8945: add and subtract expressions for interval type") { + import org.apache.spark.unsafe.types.Interval + + val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") + checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123))) + + checkAnswer(df.select(df("i") + new Interval(2, 123)), + Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123))) + + checkAnswer(df.select(df("i") - new Interval(2, 123)), + Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123))) + + // unary minus + checkAnswer(df.select(-df("i")), + Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))) + } } -- cgit v1.2.3