aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-04-15 13:06:38 -0700
committerMichael Armbrust <michael@databricks.com>2015-04-15 13:06:38 -0700
commit85842760dc4616577162f44cc0fa9db9bd23bd9c (patch)
tree3f0d8c9e0b9cb75c6fed3e2e3d6b5302a384d600 /sql/catalyst
parent785f95586b951d7b05481ee925fb95c20c4d6b6f (diff)
downloadspark-85842760dc4616577162f44cc0fa9db9bd23bd9c.tar.gz
spark-85842760dc4616577162f44cc0fa9db9bd23bd9c.tar.bz2
spark-85842760dc4616577162f44cc0fa9db9bd23bd9c.zip
[SPARK-6638] [SQL] Improve performance of StringType in SQL
This PR change the internal representation for StringType from java.lang.String to UTF8String, which is implemented use ArrayByte. This PR should not break any public API, Row.getString() will still return java.lang.String. This is the first step of improve the performance of String in SQL. cc rxin Author: Davies Liu <davies@databricks.com> Closes #5350 from davies/string and squashes the following commits: 3b7bfa8 [Davies Liu] fix schema of AddJar 2772f0d [Davies Liu] fix new test failure 6d776a9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 59025c8 [Davies Liu] address comments from @marmbrus 341ec2c [Davies Liu] turn off scala style check in UTF8StringSuite 744788f [Davies Liu] Merge branch 'master' of github.com:apache/spark into string b04a19c [Davies Liu] add comment for getString/setString 08d897b [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 5116b43 [Davies Liu] rollback unrelated changes 1314a37 [Davies Liu] address comments from Yin 867bf50 [Davies Liu] fix String filter push down 13d9d42 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 2089d24 [Davies Liu] add hashcode check back ac18ae6 [Davies Liu] address comment fd11364 [Davies Liu] optimize UTF8String 8d17f21 [Davies Liu] fix hive compatibility tests e5fa5b8 [Davies Liu] remove clone in UTF8String 28f3d81 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 28d6f32 [Davies Liu] refactor 537631c [Davies Liu] some comment about Date 9f4c194 [Davies Liu] convert data type for data source 956b0a4 [Davies Liu] fix hive tests 73e4363 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 9dc32d1 [Davies Liu] fix some hive tests 23a766c [Davies Liu] refactor 8b45864 [Davies Liu] fix codegen with UTF8String bb52e44 [Davies Liu] fix scala style c7dd4d2 [Davies Liu] fix some catalyst tests 38c303e [Davies Liu] fix python sql tests 5f9e120 [Davies Liu] fix sql tests 6b499ac [Davies Liu] fix style a85fb27 [Davies Liu] refactor d32abd1 [Davies Liu] fix utf8 for python api 4699c3a [Davies Liu] use Array[Byte] in UTF8String 21f67c6 [Davies Liu] cleanup 685fd07 [Davies Liu] use UTF8String instead of String for StringType
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala37
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala36
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala32
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala46
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala14
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala90
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala21
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala214
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala6
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala90
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala70
20 files changed, 543 insertions, 157 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index d794f034f5..ac8a782976 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.util.hashing.MurmurHash3
import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.types.{StructType, DateUtils}
+import org.apache.spark.sql.types.StructType
object Row {
/**
@@ -257,6 +257,7 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
+ // TODO(davies): This is not the right default implementation, we use Int as Date internally
def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 91976fef6d..d4f9fdacda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -77,6 +77,9 @@ object CatalystTypeConverters {
}
new GenericRowWithSchema(ar, structType)
+ case (d: String, _) =>
+ UTF8String(d)
+
case (d: BigDecimal, _) =>
Decimal(d)
@@ -175,6 +178,11 @@ object CatalystTypeConverters {
case other => other
}
+ case dataType: StringType => (item: Any) => extractOption(item) match {
+ case s: String => UTF8String(s)
+ case other => other
+ }
+
case _ =>
(item: Any) => extractOption(item) match {
case d: BigDecimal => Decimal(d)
@@ -184,6 +192,26 @@ object CatalystTypeConverters {
}
}
+ /**
+ * Converts Scala objects to catalyst rows / types.
+ *
+ * Note: This should be called before do evaluation on Row
+ * (It does not support UDT)
+ * This is used to create an RDD or test results with correct types for Catalyst.
+ */
+ def convertToCatalyst(a: Any): Any = a match {
+ case s: String => UTF8String(s)
+ case d: java.sql.Date => DateUtils.fromJavaDate(d)
+ case d: BigDecimal => Decimal(d)
+ case d: java.math.BigDecimal => Decimal(d)
+ case seq: Seq[Any] => seq.map(convertToCatalyst)
+ case r: Row => Row(r.toSeq.map(convertToCatalyst): _*)
+ case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray
+ case m: Map[Any, Any] =>
+ m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
+ case other => other
+ }
+
/**
* Converts Catalyst types used internally in rows to standard Scala types
* This method is slow, and for batch conversion you should be using converter
@@ -211,6 +239,9 @@ object CatalystTypeConverters {
case (i: Int, DateType) =>
DateUtils.toJavaDate(i)
+ case (s: UTF8String, StringType) =>
+ s.toString()
+
case (other, _) =>
other
}
@@ -262,6 +293,12 @@ object CatalystTypeConverters {
case other => other
}
+ case StringType =>
+ (item: Any) => item match {
+ case s: UTF8String => s.toString()
+ case other => other
+ }
+
case other =>
(item: Any) => item
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 01d5c15122..d9521953ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -138,6 +138,7 @@ trait ScalaReflection {
// The data type can be determined without ambiguity.
case obj: BooleanType.JvmType => BooleanType
case obj: BinaryType.JvmType => BinaryType
+ case obj: String => StringType
case obj: StringType.JvmType => StringType
case obj: ByteType.JvmType => ByteType
case obj: ShortType.JvmType => ShortType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 3aeb964994..35c7f00d4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -115,7 +115,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
- val stringNaN = Literal.create("NaN", StringType)
+ val stringNaN = Literal("NaN")
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
@@ -563,6 +563,10 @@ trait HiveTypeCoercion {
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
+ // Compatible with Hive
+ case Substring(e, start, len) if e.dataType != StringType =>
+ Substring(Cast(e, StringType), start, len)
+
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 31f1a5fdc7..adf941ab2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.types._
/** Cast the child expression to the target data type. */
@@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
- case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
- case DateType => buildCast[Int](_, d => DateUtils.toString(d))
- case TimestampType => buildCast[Timestamp](_, timestampToString)
- case _ => buildCast[Any](_, _.toString)
+ case BinaryType => buildCast[Array[Byte]](_, UTF8String(_))
+ case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d)))
+ case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t)))
+ case _ => buildCast[Any](_, o => UTF8String(o.toString))
}
// BinaryConverter
private[this] def castToBinary(from: DataType): Any => Any = from match {
- case StringType => buildCast[String](_, _.getBytes("UTF-8"))
+ case StringType => buildCast[UTF8String](_, _.getBytes)
}
// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, _.length() != 0)
+ buildCast[UTF8String](_, _.length() != 0)
case TimestampType =>
buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0)
case DateType =>
@@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// TimestampConverter
private[this] def castToTimestamp(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => {
+ buildCast[UTF8String](_, utfs => {
// Throw away extra if more than 9 decimal places
+ val s = utfs.toString
val periodIdx = s.indexOf(".")
var n = s
if (periodIdx != -1 && n.length() - periodIdx > 9) {
@@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DateConverter
private[this] def castToDate(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s =>
- try DateUtils.fromJavaDate(Date.valueOf(s))
+ buildCast[UTF8String](_, s =>
+ try DateUtils.fromJavaDate(Date.valueOf(s.toString))
catch { case _: java.lang.IllegalArgumentException => null }
)
case TimestampType =>
@@ -245,7 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toLong catch {
+ buildCast[UTF8String](_, s => try s.toString.toLong catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toInt catch {
+ buildCast[UTF8String](_, s => try s.toString.toInt catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toShort catch {
+ buildCast[UTF8String](_, s => try s.toString.toShort catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toByte catch {
+ buildCast[UTF8String](_, s => try s.toString.toByte catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
+ buildCast[UTF8String](_, s => try {
+ changePrecision(Decimal(s.toString.toDouble), target)
+ } catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -348,7 +350,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toDouble catch {
+ buildCast[UTF8String](_, s => try s.toString.toDouble catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -364,7 +366,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toFloat catch {
+ buildCast[UTF8String](_, s => try s.toString.toFloat catch {
case _: NumberFormatException => null
})
case BooleanType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 47b6f358ed..3475ed05f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -230,13 +230,17 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
new GenericRow(newValues)
}
- override def update(ordinal: Int, value: Any): Unit = {
- if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
+ override def update(ordinal: Int, value: Any) {
+ if (value == null) {
+ setNullAt(ordinal)
+ } else {
+ values(ordinal).update(value)
+ }
}
- override def setString(ordinal: Int, value: String): Unit = update(ordinal, value)
+ override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value))
- override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String]
+ override def getString(ordinal: Int): String = apply(ordinal).toString
override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d141354a0f..be2c101d63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val $primitiveTerm: ${termForType(dataType)} = $value
""".children
- case expressions.Literal(value: String, dataType) =>
+ case expressions.Literal(value: UTF8String, dataType) =>
q"""
val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
+ val $primitiveTerm: ${termForType(dataType)} =
+ org.apache.spark.sql.types.UTF8String(${value.getBytes})
""".children
case expressions.Literal(value: Int, dataType) =>
@@ -243,11 +244,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
if($nullTerm)
${defaultPrimitive(StringType)}
else
- new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
+ org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
""".children
case Cast(child @ DateType(), StringType) =>
- child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType)
+ child.castOrNull(c =>
+ q"""org.apache.spark.sql.types.UTF8String(
+ org.apache.spark.sql.types.DateUtils.toString($c))""",
+ StringType)
case Cast(child @ NumericType(), IntegerType) =>
child.castOrNull(c => q"$c.toInt", IntegerType)
@@ -272,9 +276,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
if($nullTerm)
${defaultPrimitive(StringType)}
else
- ${eval.primitiveTerm}.toString
+ org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
""".children
+ case EqualTo(e1: BinaryType, e2: BinaryType) =>
+ (e1, e2).evaluateAs (BooleanType) {
+ case (eval1, eval2) =>
+ q"""
+ java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]],
+ $eval2.asInstanceOf[Array[Byte]])
+ """
+ }
+
case EqualTo(e1, e2) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }
@@ -597,7 +610,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val localLogger = log
val localLoggerTree = reify { localLogger }
q"""
- $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm))
+ $localLoggerTree.debug(
+ ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString))
""" :: Nil
} else {
Nil
@@ -608,6 +622,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
dataType match {
+ case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]"
case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
}
@@ -619,6 +634,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
ordinal: Int,
value: TermName) = {
dataType match {
+ case StringType => q"$destinationRow.update($ordinal, $value)"
case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
case _ => q"$destinationRow.update($ordinal, $value)"
}
@@ -642,13 +658,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
- case StringType => "String"
+ case StringType => "org.apache.spark.sql.types.UTF8String"
}
protected def defaultPrimitive(dt: DataType) = dt match {
case BooleanType => ru.Literal(Constant(false))
case FloatType => ru.Literal(Constant(-1.0.toFloat))
- case StringType => ru.Literal(Constant("<uninit>"))
+ case StringType => q"""org.apache.spark.sql.types.UTF8String("<uninit>")"""
case ShortType => ru.Literal(Constant(-1.toShort))
case LongType => ru.Literal(Constant(-1L))
case ByteType => ru.Literal(Constant(-1.toByte))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 69397a73a8..6f572ff959 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -111,36 +111,54 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val specificAccessorFunctions = NativeType.all.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
- case (e, i) if e.dataType == dataType =>
+ // getString() is not used by expressions
+ case (e, i) if e.dataType == dataType && dataType != StringType =>
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
q"if(i == $i) return $elementName" :: Nil
case _ => Nil
}
-
- q"""
- override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = {
- ..$ifStatements;
- $accessorFailure
- }"""
+ dataType match {
+ // Row() need this interface to compile
+ case StringType =>
+ q"""
+ override def getString(i: Int): String = {
+ $accessorFailure
+ }"""
+ case other =>
+ q"""
+ override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
}
val specificMutatorFunctions = NativeType.all.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
- case (e, i) if e.dataType == dataType =>
+ // setString() is not used by expressions
+ case (e, i) if e.dataType == dataType && dataType != StringType =>
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil
case _ => Nil
}
-
- q"""
- override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = {
- ..$ifStatements;
- $accessorFailure
- }"""
+ dataType match {
+ case StringType =>
+ // MutableRow() need this interface to compile
+ q"""
+ override def setString(i: Int, value: String) {
+ $accessorFailure
+ }"""
+ case other =>
+ q"""
+ override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
}
val hashValues = expressions.zipWithIndex.map { case (e,i) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 860b72fad3..67caadb839 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
-import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
import org.apache.spark.sql.types._
/**
@@ -85,8 +85,11 @@ case class UserDefinedGenerator(
override protected def makeOutput(): Seq[Attribute] = schema
override def eval(input: Row): TraversableOnce[Row] = {
+ // TODO(davies): improve this
+ // Convert the objects into Scala Type before calling function, we need schema to support UDT
+ val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
val inputRow = new InterpretedProjection(children)
- function(inputRow(input))
+ function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row])
}
override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 0e2d593e94..18cba4cc46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types._
object Literal {
@@ -29,7 +30,7 @@ object Literal {
case f: Float => Literal(f, FloatType)
case b: Byte => Literal(b, ByteType)
case s: Short => Literal(s, ShortType)
- case s: String => Literal(s, StringType)
+ case s: String => Literal(UTF8String(s), StringType)
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
@@ -42,7 +43,9 @@ object Literal {
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
}
- def create(v: Any, dataType: DataType): Literal = Literal(v, dataType)
+ def create(v: Any, dataType: DataType): Literal = {
+ Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7e47cb3fff..fcd6352079 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -179,8 +179,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
val r = right.eval(input)
if (r == null) null
else if (left.dataType != BinaryType) l == r
- else BinaryType.ordering.compare(
- l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0
+ else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 0a275b8408..1b62e17ff4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.{StructType, NativeType}
+import org.apache.spark.sql.types.{UTF8String, StructType, NativeType}
/**
@@ -37,6 +37,7 @@ trait MutableRow extends Row {
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)
+ // TODO(davies): add setDate() and setDecimal()
}
/**
@@ -114,9 +115,15 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
}
override def getString(i: Int): String = {
- values(i).asInstanceOf[String]
+ values(i) match {
+ case null => null
+ case s: String => s
+ case utf8: UTF8String => utf8.toString
+ }
}
+ // TODO(davies): add getDate and getDecimal
+
// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
var result: Int = 37
@@ -189,8 +196,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
- override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value }
-
+ override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)}
override def setNullAt(i: Int): Unit = { values(i) = null }
override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index acfbbace60..d597bf7ce7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -19,11 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.regex.Pattern
-import scala.collection.IndexedSeqOptimized
-
-
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, StringType}
+import org.apache.spark.sql.types._
trait StringRegexExpression {
self: BinaryExpression =>
@@ -60,38 +57,17 @@ trait StringRegexExpression {
if(r == null) {
null
} else {
- val regex = pattern(r.asInstanceOf[String])
+ val regex = pattern(r.asInstanceOf[UTF8String].toString)
if(regex == null) {
null
} else {
- matches(regex, l.asInstanceOf[String])
+ matches(regex, l.asInstanceOf[UTF8String].toString)
}
}
}
}
}
-trait CaseConversionExpression {
- self: UnaryExpression =>
-
- type EvaluatedType = Any
-
- def convert(v: String): String
-
- override def foldable: Boolean = child.foldable
- def nullable: Boolean = child.nullable
- def dataType: DataType = StringType
-
- override def eval(input: Row): Any = {
- val evaluated = child.eval(input)
- if (evaluated == null) {
- null
- } else {
- convert(evaluated.toString)
- }
- }
-}
-
/**
* Simple RegEx pattern matching function
*/
@@ -134,12 +110,33 @@ case class RLike(left: Expression, right: Expression)
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
}
+trait CaseConversionExpression {
+ self: UnaryExpression =>
+
+ type EvaluatedType = Any
+
+ def convert(v: UTF8String): UTF8String
+
+ override def foldable: Boolean = child.foldable
+ def nullable: Boolean = child.nullable
+ def dataType: DataType = StringType
+
+ override def eval(input: Row): Any = {
+ val evaluated = child.eval(input)
+ if (evaluated == null) {
+ null
+ } else {
+ convert(evaluated.asInstanceOf[UTF8String])
+ }
+ }
+}
+
/**
* A function that converts the characters of a string to uppercase.
*/
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
- override def convert(v: String): String = v.toUpperCase()
+ override def convert(v: UTF8String): UTF8String = v.toUpperCase
override def toString: String = s"Upper($child)"
}
@@ -149,7 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
*/
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
- override def convert(v: String): String = v.toLowerCase()
+ override def convert(v: UTF8String): UTF8String = v.toLowerCase
override def toString: String = s"Lower($child)"
}
@@ -162,15 +159,16 @@ trait StringComparison {
override def nullable: Boolean = left.nullable || right.nullable
- def compare(l: String, r: String): Boolean
+ def compare(l: UTF8String, r: UTF8String): Boolean
override def eval(input: Row): Any = {
- val leftEval = left.eval(input).asInstanceOf[String]
+ val leftEval = left.eval(input)
if(leftEval == null) {
null
} else {
- val rightEval = right.eval(input).asInstanceOf[String]
- if (rightEval == null) null else compare(leftEval, rightEval)
+ val rightEval = right.eval(input)
+ if (rightEval == null) null
+ else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String])
}
}
@@ -184,7 +182,7 @@ trait StringComparison {
*/
case class Contains(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.contains(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
}
/**
@@ -192,7 +190,7 @@ case class Contains(left: Expression, right: Expression)
*/
case class StartsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.startsWith(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
}
/**
@@ -200,7 +198,7 @@ case class StartsWith(left: Expression, right: Expression)
*/
case class EndsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.endsWith(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
}
/**
@@ -224,9 +222,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
override def children: Seq[Expression] = str :: pos :: len :: Nil
@inline
- def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int)
- (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = {
- val len = str.length
+ def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = {
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
// negative indices for start positions. If a start index i is greater than 0, it
// refers to element i-1 in the sequence. If a start index i is less than 0, it refers
@@ -235,7 +231,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
val start = startPos match {
case pos if pos > 0 => pos - 1
- case neg if neg < 0 => len + neg
+ case neg if neg < 0 => length() + neg
case _ => 0
}
@@ -244,12 +240,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
case x => start + x
}
- str.slice(start, end)
+ (start, end)
}
override def eval(input: Row): Any = {
val string = str.eval(input)
-
val po = pos.eval(input)
val ln = len.eval(input)
@@ -257,11 +252,14 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
null
} else {
val start = po.asInstanceOf[Int]
- val length = ln.asInstanceOf[Int]
-
+ val length = ln.asInstanceOf[Int]
string match {
- case ba: Array[Byte] => slice(ba, start, length)
- case other => slice(other.toString, start, length)
+ case ba: Array[Byte] =>
+ val (st, end) = slicePos(start, length, () => ba.length)
+ ba.slice(st, end)
+ case s: UTF8String =>
+ val (st, end) = slicePos(start, length, () => s.length)
+ s.slice(st, end)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 93e69d409c..7c80634d2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -198,14 +198,19 @@ object LikeSimplification extends Rule[LogicalPlan] {
val equalTo = "([^_%]*)".r
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") =>
- StartsWith(l, Literal(pattern))
- case Like(l, Literal(endsWith(pattern), StringType)) =>
- EndsWith(l, Literal(pattern))
- case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") =>
- Contains(l, Literal(pattern))
- case Like(l, Literal(equalTo(pattern), StringType)) =>
- EqualTo(l, Literal(pattern))
+ case Like(l, Literal(utf, StringType)) =>
+ utf.toString match {
+ case startsWith(pattern) if !pattern.endsWith("\\") =>
+ StartsWith(l, Literal(pattern))
+ case endsWith(pattern) =>
+ EndsWith(l, Literal(pattern))
+ case contains(pattern) if !pattern.endsWith("\\") =>
+ Contains(l, Literal(pattern))
+ case equalTo(pattern) =>
+ EqualTo(l, Literal(pattern))
+ case _ =>
+ Like(l, Literal.create(utf, StringType))
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
index 504fb05842..d36a49159b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
@@ -40,6 +40,7 @@ object DateUtils {
millisToDays(d.getTime)
}
+ // we should use the exact day as Int, for example, (year, month, day) -> day
def millisToDays(millisLocal: Long): Int = {
((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
new file mode 100644
index 0000000000..fc02ba6c9c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
@@ -0,0 +1,214 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.types
+
+import java.util.Arrays
+
+/**
+ * A UTF-8 String, as internal representation of StringType in SparkSQL
+ *
+ * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison,
+ * search, see http://en.wikipedia.org/wiki/UTF-8 for details.
+ *
+ * Note: This is not designed for general use cases, should not be used outside SQL.
+ */
+
+final class UTF8String extends Ordered[UTF8String] with Serializable {
+
+ private[this] var bytes: Array[Byte] = _
+
+ /**
+ * Update the UTF8String with String.
+ */
+ def set(str: String): UTF8String = {
+ bytes = str.getBytes("utf-8")
+ this
+ }
+
+ /**
+ * Update the UTF8String with Array[Byte], which should be encoded in UTF-8
+ */
+ def set(bytes: Array[Byte]): UTF8String = {
+ this.bytes = bytes
+ this
+ }
+
+ /**
+ * Return the number of bytes for a code point with the first byte as `b`
+ * @param b The first byte of a code point
+ */
+ @inline
+ private[this] def numOfBytes(b: Byte): Int = {
+ val offset = (b & 0xFF) - 192
+ if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1
+ }
+
+ /**
+ * Return the number of code points in it.
+ *
+ * This is only used by Substring() when `start` is negative.
+ */
+ def length(): Int = {
+ var len = 0
+ var i: Int = 0
+ while (i < bytes.length) {
+ i += numOfBytes(bytes(i))
+ len += 1
+ }
+ len
+ }
+
+ def getBytes: Array[Byte] = {
+ bytes
+ }
+
+ /**
+ * Return a substring of this,
+ * @param start the position of first code point
+ * @param until the position after last code point
+ */
+ def slice(start: Int, until: Int): UTF8String = {
+ if (until <= start || start >= bytes.length || bytes == null) {
+ new UTF8String
+ }
+
+ var c = 0
+ var i: Int = 0
+ while (c < start && i < bytes.length) {
+ i += numOfBytes(bytes(i))
+ c += 1
+ }
+ var j = i
+ while (c < until && j < bytes.length) {
+ j += numOfBytes(bytes(j))
+ c += 1
+ }
+ UTF8String(Arrays.copyOfRange(bytes, i, j))
+ }
+
+ def contains(sub: UTF8String): Boolean = {
+ val b = sub.getBytes
+ if (b.length == 0) {
+ return true
+ }
+ var i: Int = 0
+ while (i <= bytes.length - b.length) {
+ // In worst case, it's O(N*K), but should works fine with SQL
+ if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) {
+ return true
+ }
+ i += 1
+ }
+ false
+ }
+
+ def startsWith(prefix: UTF8String): Boolean = {
+ val b = prefix.getBytes
+ if (b.length > bytes.length) {
+ return false
+ }
+ Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b)
+ }
+
+ def endsWith(suffix: UTF8String): Boolean = {
+ val b = suffix.getBytes
+ if (b.length > bytes.length) {
+ return false
+ }
+ Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b)
+ }
+
+ def toUpperCase(): UTF8String = {
+ // upper case depends on locale, fallback to String.
+ UTF8String(toString().toUpperCase)
+ }
+
+ def toLowerCase(): UTF8String = {
+ // lower case depends on locale, fallback to String.
+ UTF8String(toString().toLowerCase)
+ }
+
+ override def toString(): String = {
+ new String(bytes, "utf-8")
+ }
+
+ override def clone(): UTF8String = new UTF8String().set(this.bytes)
+
+ override def compare(other: UTF8String): Int = {
+ var i: Int = 0
+ val b = other.getBytes
+ while (i < bytes.length && i < b.length) {
+ val res = bytes(i).compareTo(b(i))
+ if (res != 0) return res
+ i += 1
+ }
+ bytes.length - b.length
+ }
+
+ override def compareTo(other: UTF8String): Int = {
+ compare(other)
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case s: UTF8String =>
+ Arrays.equals(bytes, s.getBytes)
+ case s: String =>
+ // This is only used for Catalyst unit tests
+ // fail fast
+ bytes.length >= s.length && length() == s.length && toString() == s
+ case _ =>
+ false
+ }
+
+ override def hashCode(): Int = {
+ Arrays.hashCode(bytes)
+ }
+}
+
+object UTF8String {
+ // number of tailing bytes in a UTF8 sequence for a code point
+ // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1
+ private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4,
+ 5, 5, 5, 5,
+ 6, 6, 6, 6)
+
+ /**
+ * Create a UTF-8 String from String
+ */
+ def apply(s: String): UTF8String = {
+ if (s != null) {
+ new UTF8String().set(s)
+ } else{
+ null
+ }
+ }
+
+ /**
+ * Create a UTF-8 String from Array[Byte], which should be encoded in UTF-8
+ */
+ def apply(bytes: Array[Byte]): UTF8String = {
+ if (bytes != null) {
+ new UTF8String().set(bytes)
+ } else {
+ null
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index cdf2bc68d9..c6fb22c26b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -350,7 +350,7 @@ class StringType private() extends NativeType with PrimitiveType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "StringType$" in byte code.
// Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = String
+ private[sql] type JvmType = UTF8String
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
@@ -1196,8 +1196,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
/**
* Convert the user type to a SQL datum
*
- * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst,
- * where we need to convert Any to UserType.
+ * TODO: Can we make this take obj: UserType? The issue is in
+ * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
*/
def serialize(obj: Any): Any
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index d4362a91d9..76298f03c9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -25,8 +25,9 @@ import org.scalactic.TripleEqualsSupport.Spread
import org.scalatest.FunSuite
import org.scalatest.Matchers._
-import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
@@ -59,6 +60,10 @@ class ExpressionEvaluationBaseSuite extends FunSuite {
class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
+ def create_row(values: Any*): Row = {
+ new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
+ }
+
test("literals") {
checkEvaluation(Literal(1), 1)
checkEvaluation(Literal(true), true)
@@ -265,24 +270,23 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("LIKE Non-literal Regular Expression") {
val regEx = 'a.string.at(0)
- checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null)))
- checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef")))
- checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b")))
- checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b")))
- checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b")))
- checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a%")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**")))
- checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
- checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
- checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))
- checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a_b")))
- checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b")))
- checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b")))
-
- checkEvaluation(Literal.create(null, StringType) like regEx, null,
- new GenericRow(Array[Any]("bc%")))
+ checkEvaluation("abcd" like regEx, null, create_row(null))
+ checkEvaluation("abdef" like regEx, true, create_row("abdef"))
+ checkEvaluation("a_%b" like regEx, true, create_row("a\\__b"))
+ checkEvaluation("addb" like regEx, true, create_row("a_%b"))
+ checkEvaluation("addb" like regEx, false, create_row("a\\__b"))
+ checkEvaluation("addb" like regEx, false, create_row("a%\\%b"))
+ checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b"))
+ checkEvaluation("addb" like regEx, true, create_row("a%"))
+ checkEvaluation("addb" like regEx, false, create_row("**"))
+ checkEvaluation("abc" like regEx, true, create_row("a%"))
+ checkEvaluation("abc" like regEx, false, create_row("b%"))
+ checkEvaluation("abc" like regEx, false, create_row("bc%"))
+ checkEvaluation("a\nb" like regEx, true, create_row("a_b"))
+ checkEvaluation("ab" like regEx, true, create_row("a%b"))
+ checkEvaluation("a\nb" like regEx, true, create_row("a%b"))
+
+ checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%"))
}
test("RLIKE literal Regular Expression") {
@@ -313,14 +317,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("RLIKE Non-literal Regular Expression") {
val regEx = 'a.string.at(0)
- checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef")))
- checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c")))
- checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo")))
- checkEvaluation("fo\no" rlike regEx, true, new GenericRow(Array[Any]("^fo\no$")))
- checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n")))
+ checkEvaluation("abdef" rlike regEx, true, create_row("abdef"))
+ checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c"))
+ checkEvaluation("fofo" rlike regEx, true, create_row("^fo"))
+ checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$"))
+ checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n"))
intercept[java.util.regex.PatternSyntaxException] {
- evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**")))
+ evaluate("abbbbc" rlike regEx, create_row("**"))
}
}
@@ -763,7 +767,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("null checking") {
- val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
+ val row = create_row("^Ba*n", null, true, null)
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
val c3 = 'a.boolean.at(2)
@@ -803,7 +807,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("case when") {
- val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c"))
+ val row = create_row(null, false, true, "a", "b", "c")
val c1 = 'a.boolean.at(0)
val c2 = 'a.boolean.at(1)
val c3 = 'a.boolean.at(2)
@@ -846,13 +850,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("complex type") {
- val row = new GenericRow(Array[Any](
- "^Ba*n", // 0
- null.asInstanceOf[String], // 1
- new GenericRow(Array[Any]("aa", "bb")), // 2
- Map("aa"->"bb"), // 3
- Seq("aa", "bb") // 4
- ))
+ val row = create_row(
+ "^Ba*n", // 0
+ null.asInstanceOf[UTF8String], // 1
+ create_row("aa", "bb"), // 2
+ Map("aa"->"bb"), // 3
+ Seq("aa", "bb") // 4
+ )
val typeS = StructType(
StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
@@ -909,7 +913,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("arithmetic") {
- val row = new GenericRow(Array[Any](1, 2, 3, null))
+ val row = create_row(1, 2, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
@@ -934,7 +938,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("fractional arithmetic") {
- val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null))
+ val row = create_row(1.1, 2.0, 3.1, null)
val c1 = 'a.double.at(0)
val c2 = 'a.double.at(1)
val c3 = 'a.double.at(2)
@@ -958,7 +962,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("BinaryComparison") {
- val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null))
+ val row = create_row(1, 2, 3, null, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
@@ -988,7 +992,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("StringComparison") {
- val row = new GenericRow(Array[Any]("abc", null))
+ val row = create_row("abc", null)
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
@@ -1009,7 +1013,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("Substring") {
- val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte)))
+ val row = create_row("example", "example".toArray.map(_.toByte))
val s = 'a.string.at(0)
@@ -1053,7 +1057,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
// substring(null, _, _) -> null
checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)),
- null, new GenericRow(Array[Any](null)))
+ null, create_row(null))
// substring(_, null, _) -> null
checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)),
@@ -1102,20 +1106,20 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("SQRT") {
val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))
val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
- val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble)))
+ val rowSequence = inputSequence.map(l => create_row(l.toDouble))
val d = 'a.double.at(0)
for ((row, expected) <- rowSequence zip expectedResults) {
checkEvaluation(Sqrt(d), expected, row)
}
- checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, new GenericRow(Array[Any](null)))
+ checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
checkEvaluation(Sqrt(-1), null, EmptyRow)
checkEvaluation(Sqrt(-1.5), null, EmptyRow)
}
test("Bitwise operations") {
- val row = new GenericRow(Array[Any](1, 2, 3, null))
+ val row = create_row(1, 2, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
index 275ea2627e..bcc0c404d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen._
/**
@@ -43,7 +43,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
}
val actual = plan(inputRow)
- val expectedRow = new GenericRow(Array[Any](expected))
+ val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala
new file mode 100644
index 0000000000..a22aa6f244
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala
@@ -0,0 +1,70 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.types
+
+import org.scalatest.FunSuite
+
+// scalastyle:off
+class UTF8StringSuite extends FunSuite {
+ test("basic") {
+ def check(str: String, len: Int) {
+
+ assert(UTF8String(str).length == len)
+ assert(UTF8String(str.getBytes("utf8")).length() == len)
+
+ assert(UTF8String(str) == str)
+ assert(UTF8String(str.getBytes("utf8")) == str)
+ assert(UTF8String(str).toString == str)
+ assert(UTF8String(str.getBytes("utf8")).toString == str)
+ assert(UTF8String(str.getBytes("utf8")) == UTF8String(str))
+
+ assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode())
+ }
+
+ check("hello", 5)
+ check("世 界", 3)
+ }
+
+ test("contains") {
+ assert(UTF8String("hello").contains(UTF8String("ello")))
+ assert(!UTF8String("hello").contains(UTF8String("vello")))
+ assert(UTF8String("大千世界").contains(UTF8String("千世")))
+ assert(!UTF8String("大千世界").contains(UTF8String("世千")))
+ }
+
+ test("prefix") {
+ assert(UTF8String("hello").startsWith(UTF8String("hell")))
+ assert(!UTF8String("hello").startsWith(UTF8String("ell")))
+ assert(UTF8String("大千世界").startsWith(UTF8String("大千")))
+ assert(!UTF8String("大千世界").startsWith(UTF8String("千")))
+ }
+
+ test("suffix") {
+ assert(UTF8String("hello").endsWith(UTF8String("ello")))
+ assert(!UTF8String("hello").endsWith(UTF8String("ellov")))
+ assert(UTF8String("大千世界").endsWith(UTF8String("世界")))
+ assert(!UTF8String("大千世界").endsWith(UTF8String("世")))
+ }
+
+ test("slice") {
+ assert(UTF8String("hello").slice(1, 3) == UTF8String("el"))
+ assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大"))
+ assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世"))
+ assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界"))
+ }
+}