about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
author	Liang-Chi Hsieh <viirya@appier.com>	2015-07-17 09:38:08 -0700
committer	Reynold Xin <rxin@databricks.com>	2015-07-17 09:38:08 -0700
commit	eba6a1af4c8ffb21934a59a61a419d625f37cceb (patch)
tree	963d47659b358384adf73a4e8040a328d1a04c80 /sql
parent	305e77cd83f3dbe680a920d5329c2e8c58452d5b (diff)
download	spark-eba6a1af4c8ffb21934a59a61a419d625f37cceb.tar.gz
	spark-eba6a1af4c8ffb21934a59a61a419d625f37cceb.tar.bz2
	spark-eba6a1af4c8ffb21934a59a61a419d625f37cceb.zip
[SPARK-8945][SQL] Add add and subtract expressions for IntervalType
JIRA: https://issues.apache.org/jira/browse/SPARK-8945 Add add and subtract expressions for IntervalType. Author: Liang-Chi Hsieh <viirya@appier.com> This patch had conflicts when merged, resolved by Committer: Reynold Xin <rxin@databricks.com> Closes #7398 from viirya/interval_add_subtract and squashes the following commits: acd1f1e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract 5abae28 [Liang-Chi Hsieh] For comments. 6f5b72e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract dbe3906 [Liang-Chi Hsieh] For comments. 13a2fc5 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract 83ec129 [Liang-Chi Hsieh] Remove intervalMethod. acfe1ab [Liang-Chi Hsieh] Fix scala style. d3e9d0e [Liang-Chi Hsieh] Add add and subtract expressions for IntervalType.
Diffstat (limited to 'sql')
-rw-r--r--	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala	60
-rw-r--r--	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala	4
-rw-r--r--	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala	3
-rw-r--r--	sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala	6
-rw-r--r--	sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala	6
-rw-r--r--	sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala	17
6 files changed, 82 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 382cbe3b84..1616d1bc0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -21,11 +21,12 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.Interval
case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
- override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
override def dataType: DataType = child.dataType
@@ -36,15 +37,22 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
+ case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
}
- protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
+ protected override def nullSafeEval(input: Any): Any = {
+ if (dataType.isInstanceOf[IntervalType]) {
+ input.asInstanceOf[Interval].negate()
+ } else {
+ numeric.negate(input)
+ }
+ }
}
case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def prettyName: String = "positive"
- override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
override def dataType: DataType = child.dataType
@@ -95,32 +103,66 @@ private[sql] object BinaryArithmetic {
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
- override def inputType: AbstractDataType = NumericType
+ override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
override def symbol: String = "+"
- override def decimalMethod: String = "$plus"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType)
- protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ if (dataType.isInstanceOf[IntervalType]) {
+ input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval])
+ } else {
+ numeric.plus(input1, input2)
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
+ case dt: DecimalType =>
+ defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)")
+ case ByteType | ShortType =>
+ defineCodeGen(ctx, ev,
+ (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
+ case IntervalType =>
+ defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
+ case _ =>
+ defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
+ }
}
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
- override def inputType: AbstractDataType = NumericType
+ override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
override def symbol: String = "-"
- override def decimalMethod: String = "$minus"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType)
- protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
+ protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+ if (dataType.isInstanceOf[IntervalType]) {
+ input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval])
+ } else {
+ numeric.minus(input1, input2)
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
+ case dt: DecimalType =>
+ defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)")
+ case ByteType | ShortType =>
+ defineCodeGen(ctx, ev,
+ (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
+ case IntervalType =>
+ defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
+ case _ =>
+ defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
+ }
}
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 45dc146488..7c388bc346 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -27,7 +27,7 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types._
// These classes are here to avoid issues with serialization and integration with quasiquotes.
@@ -69,6 +69,7 @@ class CodeGenContext {
mutableStates += ((javaType, variableName, initialValue))
}
+ final val intervalType: String = classOf[Interval].getName
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
@@ -137,6 +138,7 @@ class CodeGenContext {
case dt: DecimalType => "Decimal"
case BinaryType => "byte[]"
case StringType => "UTF8String"
+ case IntervalType => intervalType
case _: StructType => "InternalRow"
case _: ArrayType => s"scala.collection.Seq"
case _: MapType => s"scala.collection.Map"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 3a7a7ae440..e1fdb29541 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types._
object Literal {
def apply(v: Any): Literal = v match {
@@ -42,6 +42,7 @@ object Literal {
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case a: Array[Byte] => Literal(a, BinaryType)
+ case i: Interval => Literal(i, IntervalType)
case null => Literal(null, NullType)
case _ =>
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 076d7b5a51..40bf4b299c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -91,6 +91,12 @@ private[sql] object TypeCollection {
TimestampType, DateType,
StringType, BinaryType)
+ /**
+ * Types that include numeric types and interval type. They are only used in unary_minus,
+ * unary_positive, add and subtract operations.
+ */
+ val NumericAndInterval = TypeCollection(NumericType, IntervalType)
+
def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index ed0d20e7de..ad15136ee9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}
test("check types for unary arithmetic") {
- assertError(UnaryMinus('stringField), "expected to be of type numeric")
+ assertError(UnaryMinus('stringField), "type (numeric or interval)")
assertError(Abs('stringField), "expected to be of type numeric")
assertError(BitwiseNot('stringField), "expected to be of type integral")
}
@@ -78,8 +78,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
assertErrorForDifferingTypes(MinOf('intField, 'booleanField))
- assertError(Add('booleanField, 'booleanField), "accepts numeric type")
- assertError(Subtract('booleanField, 'booleanField), "accepts numeric type")
+ assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type")
+ assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type")
assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 231440892b..5b8b70ed5a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1492,4 +1492,21 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Currently we don't yet support nanosecond
checkIntervalParseError("select interval 23 nanosecond")
}
+
+ test("SPARK-8945: add and subtract expressions for interval type") {
+ import org.apache.spark.unsafe.types.Interval
+
+ val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i")
+ checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))
+
+ checkAnswer(df.select(df("i") + new Interval(2, 123)),
+ Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123)))
+
+ checkAnswer(df.select(df("i") - new Interval(2, 123)),
+ Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123)))
+
+ // unary minus
+ checkAnswer(df.select(-df("i")),
+ Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123))))
+ }
}