aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-06-25 22:16:53 -0700
committerDavies Liu <davies@databricks.com>2015-06-25 22:16:53 -0700
commit40360112c417b5432564f4bcb8a9100f4066b55e (patch)
treebc716f5bf1cd5a9af01ea353d7b8afd345932d04 /sql
parent47c874babe7779c7a2f32e0b891503ef6bebcab0 (diff)
downloadspark-40360112c417b5432564f4bcb8a9100f4066b55e.tar.gz
spark-40360112c417b5432564f4bcb8a9100f4066b55e.tar.bz2
spark-40360112c417b5432564f4bcb8a9100f4066b55e.zip
[SPARK-8620] [SQL] cleanup CodeGenContext
fix docs, remove nativeTypes , use java type to get boxed type ,default value, etc. to avoid handle `DateType` and `TimestampType` as int and long again and again. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7010 from cloud-fan/cg and squashes the following commits: aa01cf9 [Wenchen Fan] cleanup CodeGenContext
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala130
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala34
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala1
4 files changed, 82 insertions, 88 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8bd7fc18a8..8d66968a2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -467,11 +467,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
defineCodeGen(ctx, ev, c => s"!$c.isZero()")
case (dt: NumericType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c != 0")
-
- case (_: DecimalType, IntegerType) =>
- defineCodeGen(ctx, ev, c => s"($c).toInt()")
case (_: DecimalType, dt: NumericType) =>
- defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
+ defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()")
case (_: NumericType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 47c5455435..e20e3a9dca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -59,6 +59,14 @@ class CodeGenContext {
val stringType: String = classOf[UTF8String].getName
val decimalType: String = classOf[Decimal].getName
+ final val JAVA_BOOLEAN = "boolean"
+ final val JAVA_BYTE = "byte"
+ final val JAVA_SHORT = "short"
+ final val JAVA_INT = "int"
+ final val JAVA_LONG = "long"
+ final val JAVA_FLOAT = "float"
+ final val JAVA_DOUBLE = "double"
+
private val curId = new java.util.concurrent.atomic.AtomicInteger()
/**
@@ -72,98 +80,94 @@ class CodeGenContext {
}
/**
- * Return the code to access a column for given DataType
+ * Returns the code to access a column in Row for a given DataType.
*/
def getColumn(dataType: DataType, ordinal: Int): String = {
- if (isNativeType(dataType)) {
- s"i.${accessorForType(dataType)}($ordinal)"
+ val jt = javaType(dataType)
+ if (isPrimitiveType(jt)) {
+ s"i.get${primitiveTypeName(jt)}($ordinal)"
} else {
- s"(${boxedType(dataType)})i.apply($ordinal)"
+ s"($jt)i.apply($ordinal)"
}
}
/**
- * Return the code to update a column in Row for given DataType
+ * Returns the code to update a column in Row for a given DataType.
*/
def setColumn(dataType: DataType, ordinal: Int, value: String): String = {
- if (isNativeType(dataType)) {
- s"${mutatorForType(dataType)}($ordinal, $value)"
+ val jt = javaType(dataType)
+ if (isPrimitiveType(jt)) {
+ s"set${primitiveTypeName(jt)}($ordinal, $value)"
} else {
s"update($ordinal, $value)"
}
}
/**
- * Return the name of accessor in Row for a DataType
+ * Returns the name used in accessor and setter for a Java primitive type.
*/
- def accessorForType(dt: DataType): String = dt match {
- case IntegerType => "getInt"
- case other => s"get${boxedType(dt)}"
+ def primitiveTypeName(jt: String): String = jt match {
+ case JAVA_INT => "Int"
+ case _ => boxedType(jt)
}
- /**
- * Return the name of mutator in Row for a DataType
- */
- def mutatorForType(dt: DataType): String = dt match {
- case IntegerType => "setInt"
- case other => s"set${boxedType(dt)}"
- }
+ def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt))
/**
- * Return the Java type for a DataType
+ * Returns the Java type for a DataType.
*/
def javaType(dt: DataType): String = dt match {
- case IntegerType => "int"
- case LongType => "long"
- case ShortType => "short"
- case ByteType => "byte"
- case DoubleType => "double"
- case FloatType => "float"
- case BooleanType => "boolean"
+ case BooleanType => JAVA_BOOLEAN
+ case ByteType => JAVA_BYTE
+ case ShortType => JAVA_SHORT
+ case IntegerType => JAVA_INT
+ case LongType => JAVA_LONG
+ case FloatType => JAVA_FLOAT
+ case DoubleType => JAVA_DOUBLE
case dt: DecimalType => decimalType
case BinaryType => "byte[]"
case StringType => stringType
- case DateType => "int"
- case TimestampType => "long"
+ case DateType => JAVA_INT
+ case TimestampType => JAVA_LONG
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case _ => "Object"
}
/**
- * Return the boxed type in Java
+ * Returns the boxed type in Java.
*/
- def boxedType(dt: DataType): String = dt match {
- case IntegerType => "Integer"
- case LongType => "Long"
- case ShortType => "Short"
- case ByteType => "Byte"
- case DoubleType => "Double"
- case FloatType => "Float"
- case BooleanType => "Boolean"
- case DateType => "Integer"
- case TimestampType => "Long"
- case _ => javaType(dt)
+ def boxedType(jt: String): String = jt match {
+ case JAVA_BOOLEAN => "Boolean"
+ case JAVA_BYTE => "Byte"
+ case JAVA_SHORT => "Short"
+ case JAVA_INT => "Integer"
+ case JAVA_LONG => "Long"
+ case JAVA_FLOAT => "Float"
+ case JAVA_DOUBLE => "Double"
+ case other => other
}
+ def boxedType(dt: DataType): String = boxedType(javaType(dt))
+
/**
- * Return the representation of default value for given DataType
+ * Returns the representation of default value for a given Java Type.
*/
- def defaultValue(dt: DataType): String = dt match {
- case BooleanType => "false"
- case FloatType => "-1.0f"
- case ShortType => "(short)-1"
- case LongType => "-1L"
- case ByteType => "(byte)-1"
- case DoubleType => "-1.0"
- case IntegerType => "-1"
- case DateType => "-1"
- case TimestampType => "-1L"
+ def defaultValue(jt: String): String = jt match {
+ case JAVA_BOOLEAN => "false"
+ case JAVA_BYTE => "(byte)-1"
+ case JAVA_SHORT => "(short)-1"
+ case JAVA_INT => "-1"
+ case JAVA_LONG => "-1L"
+ case JAVA_FLOAT => "-1.0f"
+ case JAVA_DOUBLE => "-1.0"
case _ => "null"
}
+ def defaultValue(dt: DataType): String = defaultValue(javaType(dt))
+
/**
- * Generate code for equal expression in Java
+ * Generates code for equal expression in Java.
*/
def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
@@ -172,7 +176,7 @@ class CodeGenContext {
}
/**
- * Generate code for compare expression in Java
+ * Generates code for compare expression in Java.
*/
def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
// java boolean doesn't support > or < operator
@@ -184,25 +188,17 @@ class CodeGenContext {
}
/**
- * List of data types that have special accessors and setters in [[InternalRow]].
+ * List of java data types that have special accessors and setters in [[InternalRow]].
*/
- val nativeTypes =
- Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType)
+ val primitiveTypes =
+ Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE)
/**
- * Returns true if the data type has a special accessor and setter in [[InternalRow]].
+ * Returns true if the Java type has a special accessor and setter in [[InternalRow]].
*/
- def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt)
+ def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
- /**
- * List of data types who's Java type is primitive type
- */
- val primitiveTypes = nativeTypes ++ Seq(DateType, TimestampType)
-
- /**
- * Returns true if the Java type is primitive type
- */
- def isPrimitiveType(dt: DataType): Boolean = primitiveTypes.contains(dt)
+ def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index e362625469..624e1cf4e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -72,54 +72,56 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}"
}.mkString("\n ")
- val specificAccessorFunctions = ctx.nativeTypes.map { dataType =>
+ val specificAccessorFunctions = ctx.primitiveTypes.map { jt =>
val cases = expressions.zipWithIndex.flatMap {
- case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
- List(s"case $i: return c$i;")
- case _ => Nil
+ case (e, i) if ctx.javaType(e.dataType) == jt =>
+ Some(s"case $i: return c$i;")
+ case _ => None
}.mkString("\n ")
if (cases.length > 0) {
+ val getter = "get" + ctx.primitiveTypeName(jt)
s"""
@Override
- public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
+ public $jt $getter(int i) {
if (isNullAt(i)) {
- return ${ctx.defaultValue(dataType)};
+ return ${ctx.defaultValue(jt)};
}
switch (i) {
$cases
}
throw new IllegalArgumentException("Invalid index: " + i
- + " in ${ctx.accessorForType(dataType)}");
+ + " in $getter");
}"""
} else {
""
}
- }.mkString("\n")
+ }.filter(_.length > 0).mkString("\n")
- val specificMutatorFunctions = ctx.nativeTypes.map { dataType =>
+ val specificMutatorFunctions = ctx.primitiveTypes.map { jt =>
val cases = expressions.zipWithIndex.flatMap {
- case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
- List(s"case $i: { c$i = value; return; }")
- case _ => Nil
+ case (e, i) if ctx.javaType(e.dataType) == jt =>
+ Some(s"case $i: { c$i = value; return; }")
+ case _ => None
}.mkString("\n ")
if (cases.length > 0) {
+ val setter = "set" + ctx.primitiveTypeName(jt)
s"""
@Override
- public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) {
+ public void $setter(int i, $jt value) {
nullBits[i] = false;
switch (i) {
$cases
}
throw new IllegalArgumentException("Invalid index: " + i +
- " in ${ctx.mutatorForType(dataType)}");
+ " in $setter}");
}"""
} else {
""
}
- }.mkString("\n")
+ }.filter(_.length > 0).mkString("\n")
val hashValues = expressions.zipWithIndex.map { case (e, i) =>
- val col = newTermName(s"c$i")
+ val col = s"c$i"
val nonNull = e.dataType match {
case BooleanType => s"$col ? 0 : 1"
case ByteType | ShortType | IntegerType | DateType => s"$col"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 44416e79cd..a6225fdafe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.regex.Pattern
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.expressions.Substring
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String