Diffstat (limited to 'sql/catalyst')
 sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala | 3
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala | 37
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 1
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 6
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 36
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala | 12
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 32
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 46
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala | 7
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 7
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 3
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 14
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 90
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 21
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala | 1
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala | 214
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala | 6
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 90
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala | 4
 sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala | 70
 20 files changed, 543 insertions(+), 157 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index d794f034f5..ac8a782976 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.util.hashing.MurmurHash3
import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.types.{StructType, DateUtils}
+import org.apache.spark.sql.types.StructType
object Row {
/**
@@ -257,6 +257,7 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
+ // TODO(davies): This is not the right default implementation; we use Int for Date internally
def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 91976fef6d..d4f9fdacda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -77,6 +77,9 @@ object CatalystTypeConverters {
}
new GenericRowWithSchema(ar, structType)
+ case (d: String, _) =>
+ UTF8String(d)
+
case (d: BigDecimal, _) =>
Decimal(d)
@@ -175,6 +178,11 @@ object CatalystTypeConverters {
case other => other
}
+ case dataType: StringType => (item: Any) => extractOption(item) match {
+ case s: String => UTF8String(s)
+ case other => other
+ }
+
case _ =>
(item: Any) => extractOption(item) match {
case d: BigDecimal => Decimal(d)
@@ -184,6 +192,26 @@ object CatalystTypeConverters {
}
}
+ /**
+ * Converts Scala objects to catalyst rows / types.
+ *
+ * Note: This should be called before doing evaluation on a Row
+ * (it does not support UDT).
+ * This is used to create an RDD or test results with the correct types for Catalyst.
+ */
+ def convertToCatalyst(a: Any): Any = a match {
+ case s: String => UTF8String(s)
+ case d: java.sql.Date => DateUtils.fromJavaDate(d)
+ case d: BigDecimal => Decimal(d)
+ case d: java.math.BigDecimal => Decimal(d)
+ case seq: Seq[Any] => seq.map(convertToCatalyst)
+ case r: Row => Row(r.toSeq.map(convertToCatalyst): _*)
+ case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray
+ case m: Map[Any, Any] =>
+ m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
+ case other => other
+ }
+
/**
* Converts Catalyst types used internally in rows to standard Scala types
* This method is slow, and for batch conversion you should be using converter
@@ -211,6 +239,9 @@ object CatalystTypeConverters {
case (i: Int, DateType) =>
DateUtils.toJavaDate(i)
+ case (s: UTF8String, StringType) =>
+ s.toString()
+
case (other, _) =>
other
}
@@ -262,6 +293,12 @@ object CatalystTypeConverters {
case other => other
}
+ case StringType =>
+ (item: Any) => item match {
+ case s: UTF8String => s.toString()
+ case other => other
+ }
+
case other =>
(item: Any) => item
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 01d5c15122..d9521953ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -138,6 +138,7 @@ trait ScalaReflection {
// The data type can be determined without ambiguity.
case obj: BooleanType.JvmType => BooleanType
case obj: BinaryType.JvmType => BinaryType
+ case obj: String => StringType
case obj: StringType.JvmType => StringType
case obj: ByteType.JvmType => ByteType
case obj: ShortType.JvmType => ShortType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 3aeb964994..35c7f00d4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -115,7 +115,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
- val stringNaN = Literal.create("NaN", StringType)
+ val stringNaN = Literal("NaN")
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
@@ -563,6 +563,10 @@ trait HiveTypeCoercion {
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
+ // Compatible with Hive
+ case Substring(e, start, len) if e.dataType != StringType =>
+ Substring(Cast(e, StringType), start, len)
+
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 31f1a5fdc7..adf941ab2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.types._
/** Cast the child expression to the target data type. */
@@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
- case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
- case DateType => buildCast[Int](_, d => DateUtils.toString(d))
- case TimestampType => buildCast[Timestamp](_, timestampToString)
- case _ => buildCast[Any](_, _.toString)
+ case BinaryType => buildCast[Array[Byte]](_, UTF8String(_))
+ case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d)))
+ case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t)))
+ case _ => buildCast[Any](_, o => UTF8String(o.toString))
}
// BinaryConverter
private[this] def castToBinary(from: DataType): Any => Any = from match {
- case StringType => buildCast[String](_, _.getBytes("UTF-8"))
+ case StringType => buildCast[UTF8String](_, _.getBytes)
}
// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, _.length() != 0)
+ buildCast[UTF8String](_, _.length() != 0)
case TimestampType =>
buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0)
case DateType =>
@@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// TimestampConverter
private[this] def castToTimestamp(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => {
+ buildCast[UTF8String](_, utfs => {
// Throw away extra if more than 9 decimal places
+ val s = utfs.toString
val periodIdx = s.indexOf(".")
var n = s
if (periodIdx != -1 && n.length() - periodIdx > 9) {
@@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DateConverter
private[this] def castToDate(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s =>
- try DateUtils.fromJavaDate(Date.valueOf(s))
+ buildCast[UTF8String](_, s =>
+ try DateUtils.fromJavaDate(Date.valueOf(s.toString))
catch { case _: java.lang.IllegalArgumentException => null }
)
case TimestampType =>
@@ -245,7 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toLong catch {
+ buildCast[UTF8String](_, s => try s.toString.toLong catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toInt catch {
+ buildCast[UTF8String](_, s => try s.toString.toInt catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toShort catch {
+ buildCast[UTF8String](_, s => try s.toString.toShort catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toByte catch {
+ buildCast[UTF8String](_, s => try s.toString.toByte catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
+ buildCast[UTF8String](_, s => try {
+ changePrecision(Decimal(s.toString.toDouble), target)
+ } catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -348,7 +350,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toDouble catch {
+ buildCast[UTF8String](_, s => try s.toString.toDouble catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -364,7 +366,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toFloat catch {
+ buildCast[UTF8String](_, s => try s.toString.toFloat catch {
case _: NumberFormatException => null
})
case BooleanType =>
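
All of the string-to-numeric converters above follow the same shape: unwrap the UTF8String with toString, parse, and map parse failures to null. A standalone sketch of that pattern (castStringToInt is an illustrative name, not part of the patch):

    import org.apache.spark.sql.types.UTF8String

    def castStringToInt(s: UTF8String): Any =
      try s.toString.toInt catch {
        case _: NumberFormatException => null   // SQL casts yield NULL on malformed input
      }

    assert(castStringToInt(UTF8String("42")) == 42)
    assert(castStringToInt(UTF8String("forty-two")) == null)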
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 47b6f358ed..3475ed05f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -230,13 +230,17 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
new GenericRow(newValues)
}
- override def update(ordinal: Int, value: Any): Unit = {
- if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
+ override def update(ordinal: Int, value: Any) {
+ if (value == null) {
+ setNullAt(ordinal)
+ } else {
+ values(ordinal).update(value)
+ }
}
- override def setString(ordinal: Int, value: String): Unit = update(ordinal, value)
+ override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value))
- override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String]
+ override def getString(ordinal: Int): String = apply(ordinal).toString
override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
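
The net effect is that strings round-trip through the mutable row as UTF8String internally but as String at the accessor boundary; a sketch, assuming the Seq[DataType] convenience constructor:

    import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
    import org.apache.spark.sql.types.StringType

    val row = new SpecificMutableRow(Seq(StringType))
    row.setString(0, "hi")             // stored internally as UTF8String("hi")
    assert(row.getString(0) == "hi")   // unwrapped via toString on the way out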
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d141354a0f..be2c101d63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val $primitiveTerm: ${termForType(dataType)} = $value
""".children
- case expressions.Literal(value: String, dataType) =>
+ case expressions.Literal(value: UTF8String, dataType) =>
q"""
val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
+ val $primitiveTerm: ${termForType(dataType)} =
+ org.apache.spark.sql.types.UTF8String(${value.getBytes})
""".children
case expressions.Literal(value: Int, dataType) =>
@@ -243,11 +244,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
if($nullTerm)
${defaultPrimitive(StringType)}
else
- new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
+ org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
""".children
case Cast(child @ DateType(), StringType) =>
- child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType)
+ child.castOrNull(c =>
+ q"""org.apache.spark.sql.types.UTF8String(
+ org.apache.spark.sql.types.DateUtils.toString($c))""",
+ StringType)
case Cast(child @ NumericType(), IntegerType) =>
child.castOrNull(c => q"$c.toInt", IntegerType)
@@ -272,9 +276,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
if($nullTerm)
${defaultPrimitive(StringType)}
else
- ${eval.primitiveTerm}.toString
+ org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
""".children
+ case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) =>
+ (e1, e2).evaluateAs (BooleanType) {
+ case (eval1, eval2) =>
+ q"""
+ java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]],
+ $eval2.asInstanceOf[Array[Byte]])
+ """
+ }
+
case EqualTo(e1, e2) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }
@@ -597,7 +610,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val localLogger = log
val localLoggerTree = reify { localLogger }
q"""
- $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm))
+ $localLoggerTree.debug(
+ ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString))
""" :: Nil
} else {
Nil
@@ -608,6 +622,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
dataType match {
+ case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]"
case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
}
@@ -619,6 +634,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
ordinal: Int,
value: TermName) = {
dataType match {
+ case StringType => q"$destinationRow.update($ordinal, $value)"
case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
case _ => q"$destinationRow.update($ordinal, $value)"
}
@@ -642,13 +658,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
- case StringType => "String"
+ case StringType => "org.apache.spark.sql.types.UTF8String"
}
protected def defaultPrimitive(dt: DataType) = dt match {
case BooleanType => ru.Literal(Constant(false))
case FloatType => ru.Literal(Constant(-1.0.toFloat))
- case StringType => ru.Literal(Constant("<uninit>"))
+ case StringType => q"""org.apache.spark.sql.types.UTF8String("<uninit>")"""
case ShortType => ru.Literal(Constant(-1.toShort))
case LongType => ru.Literal(Constant(-1L))
case ByteType => ru.Literal(Constant(-1.toByte))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 69397a73a8..6f572ff959 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -111,36 +111,54 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val specificAccessorFunctions = NativeType.all.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
- case (e, i) if e.dataType == dataType =>
+ // getString() is not used by expressions
+ case (e, i) if e.dataType == dataType && dataType != StringType =>
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
q"if(i == $i) return $elementName" :: Nil
case _ => Nil
}
-
- q"""
- override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = {
- ..$ifStatements;
- $accessorFailure
- }"""
+ dataType match {
+ // Row() needs this interface to compile
+ case StringType =>
+ q"""
+ override def getString(i: Int): String = {
+ $accessorFailure
+ }"""
+ case other =>
+ q"""
+ override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
}
val specificMutatorFunctions = NativeType.all.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
- case (e, i) if e.dataType == dataType =>
+ // setString() is not used by expressions
+ case (e, i) if e.dataType == dataType && dataType != StringType =>
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil
case _ => Nil
}
-
- q"""
- override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = {
- ..$ifStatements;
- $accessorFailure
- }"""
+ dataType match {
+ case StringType =>
+ // MutableRow() needs this interface to compile
+ q"""
+ override def setString(i: Int, value: String) {
+ $accessorFailure
+ }"""
+ case other =>
+ q"""
+ override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
}
val hashValues = expressions.zipWithIndex.map { case (e,i) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 860b72fad3..67caadb839 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
-import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
import org.apache.spark.sql.types._
/**
@@ -85,8 +85,11 @@ case class UserDefinedGenerator(
override protected def makeOutput(): Seq[Attribute] = schema
override def eval(input: Row): TraversableOnce[Row] = {
+ // TODO(davies): improve this
+ // Convert the objects into Scala types before calling the function; we need the schema to support UDT
+ val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
val inputRow = new InterpretedProjection(children)
- function(inputRow(input))
+ function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row])
}
override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
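
In effect, the generator's user function now sees external Scala values even though the projected input row carries internal ones; roughly:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types._

    val schema = StructType(StructField("c", StringType, true) :: Nil)
    val internalRow = Row(UTF8String("a"))   // what InterpretedProjection yields
    val scalaRow =
      CatalystTypeConverters.convertToScala(internalRow, schema).asInstanceOf[Row]
    assert(scalaRow.getString(0) == "a")     // the function receives a plain String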
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 0e2d593e94..18cba4cc46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types._
object Literal {
@@ -29,7 +30,7 @@ object Literal {
case f: Float => Literal(f, FloatType)
case b: Byte => Literal(b, ByteType)
case s: Short => Literal(s, ShortType)
- case s: String => Literal(s, StringType)
+ case s: String => Literal(UTF8String(s), StringType)
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
@@ -42,7 +43,9 @@ object Literal {
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
}
- def create(v: Any, dataType: DataType): Literal = Literal(v, dataType)
+ def create(v: Any, dataType: DataType): Literal = {
+ Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
+ }
}
/**
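
Both construction paths now store the internal representation, so for example:

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.types.{StringType, UTF8String}

    val l1 = Literal("foo")                      // matched above: Literal(UTF8String("foo"), StringType)
    val l2 = Literal.create("foo", StringType)   // converted via convertToCatalyst
    assert(l1.value == UTF8String("foo"))
    assert(l1.value == l2.value)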
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7e47cb3fff..fcd6352079 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -179,8 +179,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
val r = right.eval(input)
if (r == null) null
else if (left.dataType != BinaryType) l == r
- else BinaryType.ordering.compare(
- l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0
+ else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}
}
}
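
The switch to java.util.Arrays.equals matters because == on two Array[Byte] instances is reference equality on the JVM, and an element-wise check is cheaper than running the full ordering comparison when only equality is needed:

    val a = Array[Byte](1, 2, 3)
    val b = Array[Byte](1, 2, 3)
    assert(a != b)                          // distinct array instances are not ==
    assert(java.util.Arrays.equals(a, b))   // element-wise equality, what EqualTo needs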
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 0a275b8408..1b62e17ff4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.{StructType, NativeType}
+import org.apache.spark.sql.types.{UTF8String, StructType, NativeType}
/**
@@ -37,6 +37,7 @@ trait MutableRow extends Row {
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)
+ // TODO(davies): add setDate() and setDecimal()
}
/**
@@ -114,9 +115,15 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
}
override def getString(i: Int): String = {
- values(i).asInstanceOf[String]
+ values(i) match {
+ case null => null
+ case s: String => s
+ case utf8: UTF8String => utf8.toString
+ }
}
+ // TODO(davies): add getDate and getDecimal
+
// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
var result: Int = 37
@@ -189,8 +196,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
- override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value }
-
+ override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value) }
override def setNullAt(i: Int): Unit = { values(i) = null }
override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index acfbbace60..d597bf7ce7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -19,11 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.regex.Pattern
-import scala.collection.IndexedSeqOptimized
-
-
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, StringType}
+import org.apache.spark.sql.types._
trait StringRegexExpression {
self: BinaryExpression =>
@@ -60,38 +57,17 @@ trait StringRegexExpression {
if(r == null) {
null
} else {
- val regex = pattern(r.asInstanceOf[String])
+ val regex = pattern(r.asInstanceOf[UTF8String].toString)
if(regex == null) {
null
} else {
- matches(regex, l.asInstanceOf[String])
+ matches(regex, l.asInstanceOf[UTF8String].toString)
}
}
}
}
}
-trait CaseConversionExpression {
- self: UnaryExpression =>
-
- type EvaluatedType = Any
-
- def convert(v: String): String
-
- override def foldable: Boolean = child.foldable
- def nullable: Boolean = child.nullable
- def dataType: DataType = StringType
-
- override def eval(input: Row): Any = {
- val evaluated = child.eval(input)
- if (evaluated == null) {
- null
- } else {
- convert(evaluated.toString)
- }
- }
-}
-
/**
* Simple RegEx pattern matching function
*/
@@ -134,12 +110,33 @@ case class RLike(left: Expression, right: Expression)
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
}
+trait CaseConversionExpression {
+ self: UnaryExpression =>
+
+ type EvaluatedType = Any
+
+ def convert(v: UTF8String): UTF8String
+
+ override def foldable: Boolean = child.foldable
+ def nullable: Boolean = child.nullable
+ def dataType: DataType = StringType
+
+ override def eval(input: Row): Any = {
+ val evaluated = child.eval(input)
+ if (evaluated == null) {
+ null
+ } else {
+ convert(evaluated.asInstanceOf[UTF8String])
+ }
+ }
+}
+
/**
* A function that converts the characters of a string to uppercase.
*/
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
- override def convert(v: String): String = v.toUpperCase()
+ override def convert(v: UTF8String): UTF8String = v.toUpperCase
override def toString: String = s"Upper($child)"
}
@@ -149,7 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
*/
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
- override def convert(v: String): String = v.toLowerCase()
+ override def convert(v: UTF8String): UTF8String = v.toLowerCase
override def toString: String = s"Lower($child)"
}
@@ -162,15 +159,16 @@ trait StringComparison {
override def nullable: Boolean = left.nullable || right.nullable
- def compare(l: String, r: String): Boolean
+ def compare(l: UTF8String, r: UTF8String): Boolean
override def eval(input: Row): Any = {
- val leftEval = left.eval(input).asInstanceOf[String]
+ val leftEval = left.eval(input)
if(leftEval == null) {
null
} else {
- val rightEval = right.eval(input).asInstanceOf[String]
- if (rightEval == null) null else compare(leftEval, rightEval)
+ val rightEval = right.eval(input)
+ if (rightEval == null) null
+ else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String])
}
}
@@ -184,7 +182,7 @@ trait StringComparison {
*/
case class Contains(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.contains(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
}
/**
@@ -192,7 +190,7 @@ case class Contains(left: Expression, right: Expression)
*/
case class StartsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.startsWith(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
}
/**
@@ -200,7 +198,7 @@ case class StartsWith(left: Expression, right: Expression)
*/
case class EndsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.endsWith(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
}
/**
@@ -224,9 +222,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
override def children: Seq[Expression] = str :: pos :: len :: Nil
@inline
- def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int)
- (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = {
- val len = str.length
+ def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = {
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
// negative indices for start positions. If a start index i is greater than 0, it
// refers to element i-1 in the sequence. If a start index i is less than 0, it refers
@@ -235,7 +231,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
val start = startPos match {
case pos if pos > 0 => pos - 1
- case neg if neg < 0 => len + neg
+ case neg if neg < 0 => length() + neg
case _ => 0
}
@@ -244,12 +240,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
case x => start + x
}
- str.slice(start, end)
+ (start, end)
}
override def eval(input: Row): Any = {
val string = str.eval(input)
-
val po = pos.eval(input)
val ln = len.eval(input)
@@ -257,11 +252,14 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
null
} else {
val start = po.asInstanceOf[Int]
- val length = ln.asInstanceOf[Int]
-
+ val length = ln.asInstanceOf[Int]
string match {
- case ba: Array[Byte] => slice(ba, start, length)
- case other => slice(other.toString, start, length)
+ case ba: Array[Byte] =>
+ val (st, end) = slicePos(start, length, () => ba.length)
+ ba.slice(st, end)
+ case s: UTF8String =>
+ val (st, end) = slicePos(start, length, () => s.length)
+ s.slice(st, end)
}
}
}
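
A worked example of the indexing rules slicePos implements (one-based positive starts, negative starts counted from the end), expressed against the new UTF8String.slice:

    import org.apache.spark.sql.types.UTF8String

    // SUBSTR("example", 2, 3): one-based start 2 -> zero-based 1; end = 1 + 3 = 4
    assert(UTF8String("example").slice(1, 4) == UTF8String("xam"))
    // SUBSTR("example", -3, 2): start = length + (-3) = 4; end = 4 + 2 = 6
    assert(UTF8String("example").slice(4, 6) == UTF8String("pl"))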
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 93e69d409c..7c80634d2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -198,14 +198,19 @@ object LikeSimplification extends Rule[LogicalPlan] {
val equalTo = "([^_%]*)".r
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") =>
- StartsWith(l, Literal(pattern))
- case Like(l, Literal(endsWith(pattern), StringType)) =>
- EndsWith(l, Literal(pattern))
- case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") =>
- Contains(l, Literal(pattern))
- case Like(l, Literal(equalTo(pattern), StringType)) =>
- EqualTo(l, Literal(pattern))
+ case Like(l, Literal(utf, StringType)) =>
+ utf.toString match {
+ case startsWith(pattern) if !pattern.endsWith("\\") =>
+ StartsWith(l, Literal(pattern))
+ case endsWith(pattern) =>
+ EndsWith(l, Literal(pattern))
+ case contains(pattern) if !pattern.endsWith("\\") =>
+ Contains(l, Literal(pattern))
+ case equalTo(pattern) =>
+ EqualTo(l, Literal(pattern))
+ case _ =>
+ Like(l, Literal.create(utf, StringType))
+ }
}
}
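
For example, with the pattern extractors above (only equalTo's definition is visible in this hunk; the others presumably capture a literal prefix, suffix, or infix in the same style):

    // assumed shape of the elided startsWith extractor
    val startsWith = "([^_%]*)%".r

    "abc%" match {
      case startsWith(pattern) if !pattern.endsWith("\\") =>
        assert(pattern == "abc")   // the rule rewrites this case to StartsWith(l, Literal("abc"))
    }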
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
index 504fb05842..d36a49159b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
@@ -40,6 +40,7 @@ object DateUtils {
millisToDays(d.getTime)
}
+ // We should use the exact day count as an Int, for example (year, month, day) -> day
def millisToDays(millisLocal: Long): Int = {
((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt
}
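
That is, a date is stored as a day count since the Unix epoch; a minimal sketch of the arithmetic with the timezone offset fixed at zero:

    val MILLIS_PER_DAY = 86400000L
    def millisToDaysUtc(millisLocal: Long): Int = (millisLocal / MILLIS_PER_DAY).toInt

    assert(millisToDaysUtc(0L) == 0)               // 1970-01-01 is day 0
    assert(millisToDaysUtc(MILLIS_PER_DAY) == 1)   // 1970-01-02 is day 1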
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
new file mode 100644
index 0000000000..fc02ba6c9c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
@@ -0,0 +1,214 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.types
+
+import java.util.Arrays
+
+/**
+ * A UTF-8 String, used as the internal representation of StringType in Spark SQL.
+ *
+ * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison and
+ * search; see http://en.wikipedia.org/wiki/UTF-8 for details.
+ *
+ * Note: This is not designed for general use; it should not be used outside SQL.
+ */
+
+final class UTF8String extends Ordered[UTF8String] with Serializable {
+
+ private[this] var bytes: Array[Byte] = _
+
+ /**
+ * Update the UTF8String with a String.
+ */
+ def set(str: String): UTF8String = {
+ bytes = str.getBytes("utf-8")
+ this
+ }
+
+ /**
+ * Update the UTF8String with an Array[Byte], which must be encoded in UTF-8
+ */
+ def set(bytes: Array[Byte]): UTF8String = {
+ this.bytes = bytes
+ this
+ }
+
+ /**
+ * Return the number of bytes of the code point whose first byte is `b`
+ * @param b The first byte of a code point
+ */
+ @inline
+ private[this] def numOfBytes(b: Byte): Int = {
+ val offset = (b & 0xFF) - 192
+ if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1
+ }
+
+ /**
+ * Return the number of code points in this string.
+ *
+ * This is only used by Substring() when `start` is negative.
+ */
+ def length(): Int = {
+ var len = 0
+ var i: Int = 0
+ while (i < bytes.length) {
+ i += numOfBytes(bytes(i))
+ len += 1
+ }
+ len
+ }
+
+ def getBytes: Array[Byte] = {
+ bytes
+ }
+
+ /**
+ * Return a substring of this string.
+ * @param start the position of the first code point
+ * @param until the position after the last code point
+ */
+ def slice(start: Int, until: Int): UTF8String = {
+ if (bytes == null || until <= start || start >= bytes.length) {
+ return new UTF8String
+ }
+
+ var c = 0
+ var i: Int = 0
+ while (c < start && i < bytes.length) {
+ i += numOfBytes(bytes(i))
+ c += 1
+ }
+ var j = i
+ while (c < until && j < bytes.length) {
+ j += numOfBytes(bytes(j))
+ c += 1
+ }
+ UTF8String(Arrays.copyOfRange(bytes, i, j))
+ }
+
+ def contains(sub: UTF8String): Boolean = {
+ val b = sub.getBytes
+ if (b.length == 0) {
+ return true
+ }
+ var i: Int = 0
+ while (i <= bytes.length - b.length) {
+ // In the worst case this is O(N*K), but it should work fine for SQL
+ if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) {
+ return true
+ }
+ i += 1
+ }
+ false
+ }
+
+ def startsWith(prefix: UTF8String): Boolean = {
+ val b = prefix.getBytes
+ if (b.length > bytes.length) {
+ return false
+ }
+ Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b)
+ }
+
+ def endsWith(suffix: UTF8String): Boolean = {
+ val b = suffix.getBytes
+ if (b.length > bytes.length) {
+ return false
+ }
+ Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b)
+ }
+
+ def toUpperCase(): UTF8String = {
+ // Upper-casing depends on the locale, so fall back to String.
+ UTF8String(toString().toUpperCase)
+ }
+
+ def toLowerCase(): UTF8String = {
+ // Lower-casing depends on the locale, so fall back to String.
+ UTF8String(toString().toLowerCase)
+ }
+
+ override def toString(): String = {
+ new String(bytes, "utf-8")
+ }
+
+ override def clone(): UTF8String = new UTF8String().set(this.bytes)
+
+ override def compare(other: UTF8String): Int = {
+ var i: Int = 0
+ val b = other.getBytes
+ while (i < bytes.length && i < b.length) {
+ val res = (bytes(i) & 0xFF) - (b(i) & 0xFF) // compare as unsigned bytes, matching byte-wise UTF-8 order
+ if (res != 0) return res
+ i += 1
+ }
+ bytes.length - b.length
+ }
+
+ override def compareTo(other: UTF8String): Int = {
+ compare(other)
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case s: UTF8String =>
+ Arrays.equals(bytes, s.getBytes)
+ case s: String =>
+ // This is only used for Catalyst unit tests
+ // fail fast
+ bytes.length >= s.length && length() == s.length && toString() == s
+ case _ =>
+ false
+ }
+
+ override def hashCode(): Int = {
+ Arrays.hashCode(bytes)
+ }
+}
+
+object UTF8String {
+ // number of bytes of a code point in a UTF-8 sequence, indexed by (first byte - 192);
+ // see http://en.wikipedia.org/wiki/UTF-8, first-byte values 192-255
+ private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4,
+ 5, 5, 5, 5,
+ 6, 6, 6, 6)
+
+ /**
+ * Create a UTF-8 String from a String
+ */
+ def apply(s: String): UTF8String = {
+ if (s != null) {
+ new UTF8String().set(s)
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Create a UTF-8 String from an Array[Byte], which must be encoded in UTF-8
+ */
+ def apply(bytes: Array[Byte]): UTF8String = {
+ if (bytes != null) {
+ new UTF8String().set(bytes)
+ } else {
+ null
+ }
+ }
+}
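
To make the dispatch table concrete: the first byte of a UTF-8 sequence encodes the width of the whole code point, so length() and slice() can skip code points without decoding them. For example:

    import org.apache.spark.sql.types.UTF8String

    // '世' (U+4E16) encodes to three bytes: 0xE4 0xB8 0x96. The first byte gives
    // offset 0xE4 - 192 = 36, which lands in the "3" region of bytesOfCodePointInUTF8.
    val bytes = "世".getBytes("utf-8")
    assert(bytes.length == 3)
    assert((bytes(0) & 0xFF) == 0xE4)
    assert(UTF8String("世").length() == 1)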
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index cdf2bc68d9..c6fb22c26b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -350,7 +350,7 @@ class StringType private() extends NativeType with PrimitiveType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "StringType$" in byte code.
// Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = String
+ private[sql] type JvmType = UTF8String
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
@@ -1196,8 +1196,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
/**
* Convert the user type to a SQL datum
*
- * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst,
- * where we need to convert Any to UserType.
+ * TODO: Can we make this take obj: UserType? The issue is in
+ * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
*/
def serialize(obj: Any): Any
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index d4362a91d9..76298f03c9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -25,8 +25,9 @@ import org.scalactic.TripleEqualsSupport.Spread
import org.scalatest.FunSuite
import org.scalatest.Matchers._
-import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
@@ -59,6 +60,10 @@ class ExpressionEvaluationBaseSuite extends FunSuite {
class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
+ def create_row(values: Any*): Row = {
+ new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
+ }
+
test("literals") {
checkEvaluation(Literal(1), 1)
checkEvaluation(Literal(true), true)
@@ -265,24 +270,23 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("LIKE Non-literal Regular Expression") {
val regEx = 'a.string.at(0)
- checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null)))
- checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef")))
- checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b")))
- checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b")))
- checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b")))
- checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a%")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**")))
- checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
- checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
- checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))
- checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a_b")))
- checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b")))
- checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b")))
-
- checkEvaluation(Literal.create(null, StringType) like regEx, null,
- new GenericRow(Array[Any]("bc%")))
+ checkEvaluation("abcd" like regEx, null, create_row(null))
+ checkEvaluation("abdef" like regEx, true, create_row("abdef"))
+ checkEvaluation("a_%b" like regEx, true, create_row("a\\__b"))
+ checkEvaluation("addb" like regEx, true, create_row("a_%b"))
+ checkEvaluation("addb" like regEx, false, create_row("a\\__b"))
+ checkEvaluation("addb" like regEx, false, create_row("a%\\%b"))
+ checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b"))
+ checkEvaluation("addb" like regEx, true, create_row("a%"))
+ checkEvaluation("addb" like regEx, false, create_row("**"))
+ checkEvaluation("abc" like regEx, true, create_row("a%"))
+ checkEvaluation("abc" like regEx, false, create_row("b%"))
+ checkEvaluation("abc" like regEx, false, create_row("bc%"))
+ checkEvaluation("a\nb" like regEx, true, create_row("a_b"))
+ checkEvaluation("ab" like regEx, true, create_row("a%b"))
+ checkEvaluation("a\nb" like regEx, true, create_row("a%b"))
+
+ checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%"))
}
test("RLIKE literal Regular Expression") {
@@ -313,14 +317,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("RLIKE Non-literal Regular Expression") {
val regEx = 'a.string.at(0)
- checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef")))
- checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c")))
- checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo")))
- checkEvaluation("fo\no" rlike regEx, true, new GenericRow(Array[Any]("^fo\no$")))
- checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n")))
+ checkEvaluation("abdef" rlike regEx, true, create_row("abdef"))
+ checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c"))
+ checkEvaluation("fofo" rlike regEx, true, create_row("^fo"))
+ checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$"))
+ checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n"))
intercept[java.util.regex.PatternSyntaxException] {
- evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**")))
+ evaluate("abbbbc" rlike regEx, create_row("**"))
}
}
@@ -763,7 +767,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("null checking") {
- val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
+ val row = create_row("^Ba*n", null, true, null)
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
val c3 = 'a.boolean.at(2)
@@ -803,7 +807,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("case when") {
- val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c"))
+ val row = create_row(null, false, true, "a", "b", "c")
val c1 = 'a.boolean.at(0)
val c2 = 'a.boolean.at(1)
val c3 = 'a.boolean.at(2)
@@ -846,13 +850,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("complex type") {
- val row = new GenericRow(Array[Any](
- "^Ba*n", // 0
- null.asInstanceOf[String], // 1
- new GenericRow(Array[Any]("aa", "bb")), // 2
- Map("aa"->"bb"), // 3
- Seq("aa", "bb") // 4
- ))
+ val row = create_row(
+ "^Ba*n", // 0
+ null.asInstanceOf[UTF8String], // 1
+ create_row("aa", "bb"), // 2
+ Map("aa"->"bb"), // 3
+ Seq("aa", "bb") // 4
+ )
val typeS = StructType(
StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
@@ -909,7 +913,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("arithmetic") {
- val row = new GenericRow(Array[Any](1, 2, 3, null))
+ val row = create_row(1, 2, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
@@ -934,7 +938,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("fractional arithmetic") {
- val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null))
+ val row = create_row(1.1, 2.0, 3.1, null)
val c1 = 'a.double.at(0)
val c2 = 'a.double.at(1)
val c3 = 'a.double.at(2)
@@ -958,7 +962,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("BinaryComparison") {
- val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null))
+ val row = create_row(1, 2, 3, null, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
@@ -988,7 +992,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("StringComparison") {
- val row = new GenericRow(Array[Any]("abc", null))
+ val row = create_row("abc", null)
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
@@ -1009,7 +1013,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("Substring") {
- val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte)))
+ val row = create_row("example", "example".toArray.map(_.toByte))
val s = 'a.string.at(0)
@@ -1053,7 +1057,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
// substring(null, _, _) -> null
checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)),
- null, new GenericRow(Array[Any](null)))
+ null, create_row(null))
// substring(_, null, _) -> null
checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)),
@@ -1102,20 +1106,20 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("SQRT") {
val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))
val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
- val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble)))
+ val rowSequence = inputSequence.map(l => create_row(l.toDouble))
val d = 'a.double.at(0)
for ((row, expected) <- rowSequence zip expectedResults) {
checkEvaluation(Sqrt(d), expected, row)
}
- checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, new GenericRow(Array[Any](null)))
+ checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
checkEvaluation(Sqrt(-1), null, EmptyRow)
checkEvaluation(Sqrt(-1.5), null, EmptyRow)
}
test("Bitwise operations") {
- val row = new GenericRow(Array[Any](1, 2, 3, null))
+ val row = create_row(1, 2, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
index 275ea2627e..bcc0c404d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen._
/**
@@ -43,7 +43,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
}
val actual = plan(inputRow)
- val expectedRow = new GenericRow(Array[Any](expected))
+ val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala
new file mode 100644
index 0000000000..a22aa6f244
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala
@@ -0,0 +1,70 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.types
+
+import org.scalatest.FunSuite
+
+// scalastyle:off
+class UTF8StringSuite extends FunSuite {
+ test("basic") {
+ def check(str: String, len: Int) {
+ assert(UTF8String(str).length == len)
+ assert(UTF8String(str.getBytes("utf8")).length() == len)
+
+ assert(UTF8String(str) == str)
+ assert(UTF8String(str.getBytes("utf8")) == str)
+ assert(UTF8String(str).toString == str)
+ assert(UTF8String(str.getBytes("utf8")).toString == str)
+ assert(UTF8String(str.getBytes("utf8")) == UTF8String(str))
+
+ assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode())
+ }
+
+ check("hello", 5)
+ check("世 界", 3)
+ }
+
+ test("contains") {
+ assert(UTF8String("hello").contains(UTF8String("ello")))
+ assert(!UTF8String("hello").contains(UTF8String("vello")))
+ assert(UTF8String("大千世界").contains(UTF8String("千世")))
+ assert(!UTF8String("大千世界").contains(UTF8String("世千")))
+ }
+
+ test("prefix") {
+ assert(UTF8String("hello").startsWith(UTF8String("hell")))
+ assert(!UTF8String("hello").startsWith(UTF8String("ell")))
+ assert(UTF8String("大千世界").startsWith(UTF8String("大千")))
+ assert(!UTF8String("大千世界").startsWith(UTF8String("千")))
+ }
+
+ test("suffix") {
+ assert(UTF8String("hello").endsWith(UTF8String("ello")))
+ assert(!UTF8String("hello").endsWith(UTF8String("ellov")))
+ assert(UTF8String("大千世界").endsWith(UTF8String("世界")))
+ assert(!UTF8String("大千世界").endsWith(UTF8String("世")))
+ }
+
+ test("slice") {
+ assert(UTF8String("hello").slice(1, 3) == UTF8String("el"))
+ assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大"))
+ assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世"))
+ assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界"))
+ }
+}