about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author: Reynold Xin <rxin@databricks.com> 2015-06-29 22:32:43 -0700
committer: Reynold Xin <rxin@databricks.com> 2015-06-29 22:32:43 -0700
commit: f79410c49b2225b2acdc58293574860230987775 (patch)
tree: aef488f78e8329f7d9a3e44b82fdbb4a8081dab8 /sql
parent: ea775b0662b952849ac7fe2026fc3fd4714c37e3 (diff)
download: spark-f79410c49b2225b2acdc58293574860230987775.tar.gz
spark-f79410c49b2225b2acdc58293574860230987775.tar.bz2
spark-f79410c49b2225b2acdc58293574860230987775.zip
[SPARK-8721][SQL] Rename ExpectsInputTypes => AutoCastInputTypes.
Author: Reynold Xin <rxin@databricks.com>

Closes #7109 from rxin/auto-cast and squashes the following commits:

a914cc3 [Reynold Xin] [SPARK-8721][SQL] Rename ExpectsInputTypes => AutoCastInputTypes.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 8
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 2
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala | 118
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala | 6
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 6
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 10
6 files changed, 71 insertions, 79 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 976fa57cb9..c3d68197d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -116,7 +116,7 @@ trait HiveTypeCoercion {
IfCoercion ::
Division ::
PropagateTypes ::
- ExpectedInputConversion ::
+ AddCastForAutoCastInputTypes ::
Nil
/**
@@ -709,15 +709,15 @@ trait HiveTypeCoercion {
/**
* Casts types according to the expected input types for Expressions that have the trait
- * `ExpectsInputTypes`.
+ * [[AutoCastInputTypes]].
*/
- object ExpectedInputConversion extends Rule[LogicalPlan] {
+ object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes =>
+ case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes =>
val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map {
case (child, actual, expected) =>
if (actual == expected) child else Cast(child, expected)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index f59db3d5df..e5dc7b9b5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -261,7 +261,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
* Expressions that require a specific `DataType` as input should implement this trait
* so that the proper type conversions can be performed in the analyzer.
*/
-trait ExpectsInputTypes {
+trait AutoCastInputTypes {
self: Expression =>
def expectedChildTypes: Seq[DataType]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 4b57ddd9c5..a022f3727b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -56,7 +56,7 @@ abstract class LeafMathExpression(c: Double, name: String)
* @param name The short name of the function
*/
abstract class UnaryMathExpression(f: Double => Double, name: String)
- extends UnaryExpression with Serializable with ExpectsInputTypes {
+ extends UnaryExpression with Serializable with AutoCastInputTypes {
self: Product =>
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
@@ -99,7 +99,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
* @param name The short name of the function
*/
abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
- extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product =>
+ extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product =>
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
@@ -211,19 +211,11 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
}
case class Bin(child: Expression)
- extends UnaryExpression with Serializable with ExpectsInputTypes {
-
- val name: String = "BIN"
-
- override def foldable: Boolean = child.foldable
- override def nullable: Boolean = true
- override def toString: String = s"$name($child)"
+ extends UnaryExpression with Serializable with AutoCastInputTypes {
override def expectedChildTypes: Seq[DataType] = Seq(LongType)
override def dataType: DataType = StringType
- def funcName: String = name.toLowerCase
-
override def eval(input: InternalRow): Any = {
val evalE = child.eval(input)
if (evalE == null) {
@@ -239,61 +231,13 @@ case class Bin(child: Expression)
}
}
-////////////////////////////////////////////////////////////////////////////////////////////////////
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// Binary math functions
-////////////////////////////////////////////////////////////////////////////////////////////////////
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-
-case class Atan2(left: Expression, right: Expression)
- extends BinaryMathExpression(math.atan2, "ATAN2") {
-
- override def eval(input: InternalRow): Any = {
- val evalE1 = left.eval(input)
- if (evalE1 == null) {
- null
- } else {
- val evalE2 = right.eval(input)
- if (evalE2 == null) {
- null
- } else {
- // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
- val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0,
- evalE2.asInstanceOf[Double] + 0.0)
- if (result.isNaN) null else result
- }
- }
- }
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
- if (Double.valueOf(${ev.primitive}).isNaN()) {
- ${ev.isNull} = true;
- }
- """
- }
-}
-
-case class Pow(left: Expression, right: Expression)
- extends BinaryMathExpression(math.pow, "POWER") {
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
- if (Double.valueOf(${ev.primitive}).isNaN()) {
- ${ev.isNull} = true;
- }
- """
- }
-}
/**
* If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format.
- * Otherwise if the number is a STRING,
- * it converts each character into its hexadecimal representation and returns the resulting STRING.
- * Negative numbers would be treated as two's complement.
+ * Otherwise if the number is a STRING, it converts each character into its hex representation
+ * and returns the resulting STRING. Negative numbers would be treated as two's complement.
*/
-case class Hex(child: Expression)
- extends UnaryExpression with Serializable {
+case class Hex(child: Expression) extends UnaryExpression with Serializable {
override def dataType: DataType = StringType
@@ -337,7 +281,7 @@ case class Hex(child: Expression)
private def doHex(bytes: Array[Byte], length: Int): UTF8String = {
val value = new Array[Byte](length * 2)
var i = 0
- while(i < length) {
+ while (i < length) {
value(i * 2) = Character.toUpperCase(Character.forDigit(
(bytes(i) & 0xF0) >>> 4, 16)).toByte
value(i * 2 + 1) = Character.toUpperCase(Character.forDigit(
@@ -362,6 +306,54 @@ case class Hex(child: Expression)
}
}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// Binary math functions
+////////////////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+case class Atan2(left: Expression, right: Expression)
+ extends BinaryMathExpression(math.atan2, "ATAN2") {
+
+ override def eval(input: InternalRow): Any = {
+ val evalE1 = left.eval(input)
+ if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = right.eval(input)
+ if (evalE2 == null) {
+ null
+ } else {
+ // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
+ val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0,
+ evalE2.asInstanceOf[Double] + 0.0)
+ if (result.isNaN) null else result
+ }
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
+ if (Double.valueOf(${ev.primitive}).isNaN()) {
+ ${ev.isNull} = true;
+ }
+ """
+ }
+}
+
+case class Pow(left: Expression, right: Expression)
+ extends BinaryMathExpression(math.pow, "POWER") {
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
+ if (Double.valueOf(${ev.primitive}).isNaN()) {
+ ${ev.isNull} = true;
+ }
+ """
+ }
+}
+
case class Hypot(left: Expression, right: Expression)
extends BinaryMathExpression(math.hypot, "HYPOT")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 9a39165a1f..27805bff29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -31,7 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String
* For input of type [[BinaryType]]
*/
case class Md5(child: Expression)
- extends UnaryExpression with ExpectsInputTypes {
+ extends UnaryExpression with AutoCastInputTypes {
override def dataType: DataType = StringType
@@ -61,7 +61,7 @@ case class Md5(child: Expression)
* the hash length is not one of the permitted values, the return value is NULL.
*/
case class Sha2(left: Expression, right: Expression)
- extends BinaryExpression with Serializable with ExpectsInputTypes {
+ extends BinaryExpression with Serializable with AutoCastInputTypes {
override def dataType: DataType = StringType
@@ -146,7 +146,7 @@ case class Sha2(left: Expression, right: Expression)
* A function that calculates a sha1 hash value and returns it as a hex string
* For input of type [[BinaryType]] or [[StringType]]
*/
-case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTypes {
override def dataType: DataType = StringType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 3a12d03ba6..386cf6a8df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -70,7 +70,7 @@ trait PredicateHelper {
}
-case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes {
+case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes {
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
override def toString: String = s"NOT $child"
@@ -123,7 +123,7 @@ case class InSet(value: Expression, hset: Set[Any])
}
case class And(left: Expression, right: Expression)
- extends BinaryExpression with Predicate with ExpectsInputTypes {
+ extends BinaryExpression with Predicate with AutoCastInputTypes {
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
@@ -172,7 +172,7 @@ case class And(left: Expression, right: Expression)
}
case class Or(left: Expression, right: Expression)
- extends BinaryExpression with Predicate with ExpectsInputTypes {
+ extends BinaryExpression with Predicate with AutoCastInputTypes {
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index a6225fdafe..ce184e4f32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-trait StringRegexExpression extends ExpectsInputTypes {
+trait StringRegexExpression extends AutoCastInputTypes {
self: BinaryExpression =>
def escape(v: String): String
@@ -111,7 +111,7 @@ case class RLike(left: Expression, right: Expression)
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
}
-trait CaseConversionExpression extends ExpectsInputTypes {
+trait CaseConversionExpression extends AutoCastInputTypes {
self: UnaryExpression =>
def convert(v: UTF8String): UTF8String
@@ -158,7 +158,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
}
/** A base trait for functions that compare two strings, returning a boolean. */
-trait StringComparison extends ExpectsInputTypes {
+trait StringComparison extends AutoCastInputTypes {
self: BinaryExpression =>
def compare(l: UTF8String, r: UTF8String): Boolean
@@ -221,7 +221,7 @@ case class EndsWith(left: Expression, right: Expression)
* Defined for String and Binary types.
*/
case class Substring(str: Expression, pos: Expression, len: Expression)
- extends Expression with ExpectsInputTypes {
+ extends Expression with AutoCastInputTypes {
def this(str: Expression, pos: Expression) = {
this(str, pos, Literal(Integer.MAX_VALUE))
@@ -295,7 +295,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
/**
* A function that return the length of the given string expression.
*/
-case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes {
override def dataType: DataType = IntegerType
override def expectedChildTypes: Seq[DataType] = Seq(StringType)