author     Davies Liu <davies@databricks.com>          2015-04-15 13:06:38 -0700
committer  Michael Armbrust <michael@databricks.com>   2015-04-15 13:06:38 -0700
commit     85842760dc4616577162f44cc0fa9db9bd23bd9c (patch)
tree       3f0d8c9e0b9cb75c6fed3e2e3d6b5302a384d600
parent     785f95586b951d7b05481ee925fb95c20c4d6b6f (diff)
download   spark-85842760dc4616577162f44cc0fa9db9bd23bd9c.tar.gz
           spark-85842760dc4616577162f44cc0fa9db9bd23bd9c.tar.bz2
           spark-85842760dc4616577162f44cc0fa9db9bd23bd9c.zip
[SPARK-6638] [SQL] Improve performance of StringType in SQL
This PR changes the internal representation of StringType from java.lang.String to UTF8String, which is implemented using Array[Byte]. It should not break any public API; Row.getString() will still return java.lang.String. This is the first step in improving the performance of String in SQL.

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #5350 from davies/string and squashes the following commits:

3b7bfa8 [Davies Liu] fix schema of AddJar
2772f0d [Davies Liu] fix new test failure
6d776a9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
59025c8 [Davies Liu] address comments from @marmbrus
341ec2c [Davies Liu] turn off scala style check in UTF8StringSuite
744788f [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
b04a19c [Davies Liu] add comment for getString/setString
08d897b [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
5116b43 [Davies Liu] rollback unrelated changes
1314a37 [Davies Liu] address comments from Yin
867bf50 [Davies Liu] fix String filter push down
13d9d42 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
2089d24 [Davies Liu] add hashcode check back
ac18ae6 [Davies Liu] address comment
fd11364 [Davies Liu] optimize UTF8String
8d17f21 [Davies Liu] fix hive compatibility tests
e5fa5b8 [Davies Liu] remove clone in UTF8String
28f3d81 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
28d6f32 [Davies Liu] refactor
537631c [Davies Liu] some comment about Date
9f4c194 [Davies Liu] convert data type for data source
956b0a4 [Davies Liu] fix hive tests
73e4363 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
9dc32d1 [Davies Liu] fix some hive tests
23a766c [Davies Liu] refactor
8b45864 [Davies Liu] fix codegen with UTF8String
bb52e44 [Davies Liu] fix scala style
c7dd4d2 [Davies Liu] fix some catalyst tests
38c303e [Davies Liu] fix python sql tests
5f9e120 [Davies Liu] fix sql tests
6b499ac [Davies Liu] fix style
a85fb27 [Davies Liu] refactor
d32abd1 [Davies Liu] fix utf8 for python api
4699c3a [Davies Liu] use Array[Byte] in UTF8String
21f67c6 [Davies Liu] cleanup
685fd07 [Davies Liu] use UTF8String instead of String for StringType
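For context, here is a minimal standalone sketch of the idea behind the change (the classes below are hypothetical simplifications, not the Spark code itself): StringType values are carried internally as UTF-8 encoded bytes, while Row.getString() keeps handing back java.lang.String.

// Hypothetical, simplified stand-ins for UTF8String and Row, illustrating the idea only.
object Utf8RepresentationSketch {
  final class Utf8Str(val bytes: Array[Byte]) {          // internal representation: raw UTF-8 bytes
    override def toString: String = new String(bytes, "UTF-8")
    override def equals(o: Any): Boolean = o match {
      case s: Utf8Str => java.util.Arrays.equals(bytes, s.bytes)
      case _          => false
    }
    override def hashCode(): Int = java.util.Arrays.hashCode(bytes)
  }
  object Utf8Str {
    def apply(s: String): Utf8Str = new Utf8Str(s.getBytes("UTF-8"))
  }

  // A toy row: stores Utf8Str internally, but getString() still returns java.lang.String.
  final class SimpleRow(values: Array[Any]) {
    def getString(i: Int): String = values(i) match {
      case null       => null
      case u: Utf8Str => u.toString
      case s: String  => s
    }
  }

  def main(args: Array[String]): Unit = {
    val row = new SimpleRow(Array(Utf8Str("hello")))
    println(row.getString(0))   // prints "hello" as a plain java.lang.String
  }
}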
-rw-r--r--  python/pyspark/sql/dataframe.py | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala | 37
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 36
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 32
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 46
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 90
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 21
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala | 214
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala | 6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 90
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala | 70
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 31
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala | 37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala | 10
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 22
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala | 13
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala | 17
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala | 10
-rw-r--r--  sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala | 4
-rw-r--r--  sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala | 36
50 files changed, 742 insertions, 298 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ef91a9c4f5..f2c3b74a18 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -456,7 +456,7 @@ class DataFrame(object):
One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
- [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
+ [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
"""
if joinExprs is None:
@@ -637,9 +637,9 @@ class DataFrame(object):
>>> df.groupBy().avg().collect()
[Row(AVG(age)=3.5)]
>>> df.groupBy('name').agg({'age': 'mean'}).collect()
- [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
+ [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
>>> df.groupBy(df.name).avg().collect()
- [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
+ [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
@@ -867,11 +867,11 @@ class GroupedData(object):
>>> gdf = df.groupBy(df.name)
>>> gdf.agg({"*": "count"}).collect()
- [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]
+ [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
>>> from pyspark.sql import functions as F
>>> gdf.agg(F.min(df.age)).collect()
- [Row(MIN(age)=5), Row(MIN(age)=2)]
+ [Row(MIN(age)=2), Row(MIN(age)=5)]
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index d794f034f5..ac8a782976 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.util.hashing.MurmurHash3
import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.types.{StructType, DateUtils}
+import org.apache.spark.sql.types.StructType
object Row {
/**
@@ -257,6 +257,7 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
+ // TODO(davies): This is not the right default implementation, we use Int as Date internally
def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 91976fef6d..d4f9fdacda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -77,6 +77,9 @@ object CatalystTypeConverters {
}
new GenericRowWithSchema(ar, structType)
+ case (d: String, _) =>
+ UTF8String(d)
+
case (d: BigDecimal, _) =>
Decimal(d)
@@ -175,6 +178,11 @@ object CatalystTypeConverters {
case other => other
}
+ case dataType: StringType => (item: Any) => extractOption(item) match {
+ case s: String => UTF8String(s)
+ case other => other
+ }
+
case _ =>
(item: Any) => extractOption(item) match {
case d: BigDecimal => Decimal(d)
@@ -184,6 +192,26 @@ object CatalystTypeConverters {
}
}
+ /**
+ * Converts Scala objects to catalyst rows / types.
+ *
+ * Note: This should be called before do evaluation on Row
+ * (It does not support UDT)
+ * This is used to create an RDD or test results with correct types for Catalyst.
+ */
+ def convertToCatalyst(a: Any): Any = a match {
+ case s: String => UTF8String(s)
+ case d: java.sql.Date => DateUtils.fromJavaDate(d)
+ case d: BigDecimal => Decimal(d)
+ case d: java.math.BigDecimal => Decimal(d)
+ case seq: Seq[Any] => seq.map(convertToCatalyst)
+ case r: Row => Row(r.toSeq.map(convertToCatalyst): _*)
+ case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray
+ case m: Map[Any, Any] =>
+ m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
+ case other => other
+ }
+
/**
* Converts Catalyst types used internally in rows to standard Scala types
* This method is slow, and for batch conversion you should be using converter
@@ -211,6 +239,9 @@ object CatalystTypeConverters {
case (i: Int, DateType) =>
DateUtils.toJavaDate(i)
+ case (s: UTF8String, StringType) =>
+ s.toString()
+
case (other, _) =>
other
}
@@ -262,6 +293,12 @@ object CatalystTypeConverters {
case other => other
}
+ case StringType =>
+ (item: Any) => item match {
+ case s: UTF8String => s.toString()
+ case other => other
+ }
+
case other =>
(item: Any) => item
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 01d5c15122..d9521953ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -138,6 +138,7 @@ trait ScalaReflection {
// The data type can be determined without ambiguity.
case obj: BooleanType.JvmType => BooleanType
case obj: BinaryType.JvmType => BinaryType
+ case obj: String => StringType
case obj: StringType.JvmType => StringType
case obj: ByteType.JvmType => ByteType
case obj: ShortType.JvmType => ShortType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 3aeb964994..35c7f00d4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -115,7 +115,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
- val stringNaN = Literal.create("NaN", StringType)
+ val stringNaN = Literal("NaN")
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
@@ -563,6 +563,10 @@ trait HiveTypeCoercion {
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
+ // Compatible with Hive
+ case Substring(e, start, len) if e.dataType != StringType =>
+ Substring(Cast(e, StringType), start, len)
+
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 31f1a5fdc7..adf941ab2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.types._
/** Cast the child expression to the target data type. */
@@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
- case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
- case DateType => buildCast[Int](_, d => DateUtils.toString(d))
- case TimestampType => buildCast[Timestamp](_, timestampToString)
- case _ => buildCast[Any](_, _.toString)
+ case BinaryType => buildCast[Array[Byte]](_, UTF8String(_))
+ case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d)))
+ case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t)))
+ case _ => buildCast[Any](_, o => UTF8String(o.toString))
}
// BinaryConverter
private[this] def castToBinary(from: DataType): Any => Any = from match {
- case StringType => buildCast[String](_, _.getBytes("UTF-8"))
+ case StringType => buildCast[UTF8String](_, _.getBytes)
}
// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, _.length() != 0)
+ buildCast[UTF8String](_, _.length() != 0)
case TimestampType =>
buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0)
case DateType =>
@@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// TimestampConverter
private[this] def castToTimestamp(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => {
+ buildCast[UTF8String](_, utfs => {
// Throw away extra if more than 9 decimal places
+ val s = utfs.toString
val periodIdx = s.indexOf(".")
var n = s
if (periodIdx != -1 && n.length() - periodIdx > 9) {
@@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DateConverter
private[this] def castToDate(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s =>
- try DateUtils.fromJavaDate(Date.valueOf(s))
+ buildCast[UTF8String](_, s =>
+ try DateUtils.fromJavaDate(Date.valueOf(s.toString))
catch { case _: java.lang.IllegalArgumentException => null }
)
case TimestampType =>
@@ -245,7 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toLong catch {
+ buildCast[UTF8String](_, s => try s.toString.toLong catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toInt catch {
+ buildCast[UTF8String](_, s => try s.toString.toInt catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toShort catch {
+ buildCast[UTF8String](_, s => try s.toString.toShort catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toByte catch {
+ buildCast[UTF8String](_, s => try s.toString.toByte catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
+ buildCast[UTF8String](_, s => try {
+ changePrecision(Decimal(s.toString.toDouble), target)
+ } catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -348,7 +350,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toDouble catch {
+ buildCast[UTF8String](_, s => try s.toString.toDouble catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -364,7 +366,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toFloat catch {
+ buildCast[UTF8String](_, s => try s.toString.toFloat catch {
case _: NumberFormatException => null
})
case BooleanType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 47b6f358ed..3475ed05f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -230,13 +230,17 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
new GenericRow(newValues)
}
- override def update(ordinal: Int, value: Any): Unit = {
- if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
+ override def update(ordinal: Int, value: Any) {
+ if (value == null) {
+ setNullAt(ordinal)
+ } else {
+ values(ordinal).update(value)
+ }
}
- override def setString(ordinal: Int, value: String): Unit = update(ordinal, value)
+ override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value))
- override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String]
+ override def getString(ordinal: Int): String = apply(ordinal).toString
override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d141354a0f..be2c101d63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val $primitiveTerm: ${termForType(dataType)} = $value
""".children
- case expressions.Literal(value: String, dataType) =>
+ case expressions.Literal(value: UTF8String, dataType) =>
q"""
val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
+ val $primitiveTerm: ${termForType(dataType)} =
+ org.apache.spark.sql.types.UTF8String(${value.getBytes})
""".children
case expressions.Literal(value: Int, dataType) =>
@@ -243,11 +244,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
if($nullTerm)
${defaultPrimitive(StringType)}
else
- new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
+ org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
""".children
case Cast(child @ DateType(), StringType) =>
- child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType)
+ child.castOrNull(c =>
+ q"""org.apache.spark.sql.types.UTF8String(
+ org.apache.spark.sql.types.DateUtils.toString($c))""",
+ StringType)
case Cast(child @ NumericType(), IntegerType) =>
child.castOrNull(c => q"$c.toInt", IntegerType)
@@ -272,9 +276,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
if($nullTerm)
${defaultPrimitive(StringType)}
else
- ${eval.primitiveTerm}.toString
+ org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
""".children
+ case EqualTo(e1: BinaryType, e2: BinaryType) =>
+ (e1, e2).evaluateAs (BooleanType) {
+ case (eval1, eval2) =>
+ q"""
+ java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]],
+ $eval2.asInstanceOf[Array[Byte]])
+ """
+ }
+
case EqualTo(e1, e2) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }
@@ -597,7 +610,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val localLogger = log
val localLoggerTree = reify { localLogger }
q"""
- $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm))
+ $localLoggerTree.debug(
+ ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString))
""" :: Nil
} else {
Nil
@@ -608,6 +622,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
dataType match {
+ case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]"
case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
}
@@ -619,6 +634,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
ordinal: Int,
value: TermName) = {
dataType match {
+ case StringType => q"$destinationRow.update($ordinal, $value)"
case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
case _ => q"$destinationRow.update($ordinal, $value)"
}
@@ -642,13 +658,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
- case StringType => "String"
+ case StringType => "org.apache.spark.sql.types.UTF8String"
}
protected def defaultPrimitive(dt: DataType) = dt match {
case BooleanType => ru.Literal(Constant(false))
case FloatType => ru.Literal(Constant(-1.0.toFloat))
- case StringType => ru.Literal(Constant("<uninit>"))
+ case StringType => q"""org.apache.spark.sql.types.UTF8String("<uninit>")"""
case ShortType => ru.Literal(Constant(-1.toShort))
case LongType => ru.Literal(Constant(-1L))
case ByteType => ru.Literal(Constant(-1.toByte))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 69397a73a8..6f572ff959 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -111,36 +111,54 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val specificAccessorFunctions = NativeType.all.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
- case (e, i) if e.dataType == dataType =>
+ // getString() is not used by expressions
+ case (e, i) if e.dataType == dataType && dataType != StringType =>
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
q"if(i == $i) return $elementName" :: Nil
case _ => Nil
}
-
- q"""
- override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = {
- ..$ifStatements;
- $accessorFailure
- }"""
+ dataType match {
+ // Row() need this interface to compile
+ case StringType =>
+ q"""
+ override def getString(i: Int): String = {
+ $accessorFailure
+ }"""
+ case other =>
+ q"""
+ override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
}
val specificMutatorFunctions = NativeType.all.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
- case (e, i) if e.dataType == dataType =>
+ // setString() is not used by expressions
+ case (e, i) if e.dataType == dataType && dataType != StringType =>
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil
case _ => Nil
}
-
- q"""
- override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = {
- ..$ifStatements;
- $accessorFailure
- }"""
+ dataType match {
+ case StringType =>
+ // MutableRow() need this interface to compile
+ q"""
+ override def setString(i: Int, value: String) {
+ $accessorFailure
+ }"""
+ case other =>
+ q"""
+ override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
}
val hashValues = expressions.zipWithIndex.map { case (e,i) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 860b72fad3..67caadb839 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
-import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
import org.apache.spark.sql.types._
/**
@@ -85,8 +85,11 @@ case class UserDefinedGenerator(
override protected def makeOutput(): Seq[Attribute] = schema
override def eval(input: Row): TraversableOnce[Row] = {
+ // TODO(davies): improve this
+ // Convert the objects into Scala Type before calling function, we need schema to support UDT
+ val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
val inputRow = new InterpretedProjection(children)
- function(inputRow(input))
+ function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row])
}
override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 0e2d593e94..18cba4cc46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types._
object Literal {
@@ -29,7 +30,7 @@ object Literal {
case f: Float => Literal(f, FloatType)
case b: Byte => Literal(b, ByteType)
case s: Short => Literal(s, ShortType)
- case s: String => Literal(s, StringType)
+ case s: String => Literal(UTF8String(s), StringType)
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
@@ -42,7 +43,9 @@ object Literal {
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
}
- def create(v: Any, dataType: DataType): Literal = Literal(v, dataType)
+ def create(v: Any, dataType: DataType): Literal = {
+ Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7e47cb3fff..fcd6352079 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -179,8 +179,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
val r = right.eval(input)
if (r == null) null
else if (left.dataType != BinaryType) l == r
- else BinaryType.ordering.compare(
- l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0
+ else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 0a275b8408..1b62e17ff4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.{StructType, NativeType}
+import org.apache.spark.sql.types.{UTF8String, StructType, NativeType}
/**
@@ -37,6 +37,7 @@ trait MutableRow extends Row {
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)
+ // TODO(davies): add setDate() and setDecimal()
}
/**
@@ -114,9 +115,15 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
}
override def getString(i: Int): String = {
- values(i).asInstanceOf[String]
+ values(i) match {
+ case null => null
+ case s: String => s
+ case utf8: UTF8String => utf8.toString
+ }
}
+ // TODO(davies): add getDate and getDecimal
+
// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
var result: Int = 37
@@ -189,8 +196,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
- override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value }
-
+ override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)}
override def setNullAt(i: Int): Unit = { values(i) = null }
override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index acfbbace60..d597bf7ce7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -19,11 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.regex.Pattern
-import scala.collection.IndexedSeqOptimized
-
-
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, StringType}
+import org.apache.spark.sql.types._
trait StringRegexExpression {
self: BinaryExpression =>
@@ -60,38 +57,17 @@ trait StringRegexExpression {
if(r == null) {
null
} else {
- val regex = pattern(r.asInstanceOf[String])
+ val regex = pattern(r.asInstanceOf[UTF8String].toString)
if(regex == null) {
null
} else {
- matches(regex, l.asInstanceOf[String])
+ matches(regex, l.asInstanceOf[UTF8String].toString)
}
}
}
}
}
-trait CaseConversionExpression {
- self: UnaryExpression =>
-
- type EvaluatedType = Any
-
- def convert(v: String): String
-
- override def foldable: Boolean = child.foldable
- def nullable: Boolean = child.nullable
- def dataType: DataType = StringType
-
- override def eval(input: Row): Any = {
- val evaluated = child.eval(input)
- if (evaluated == null) {
- null
- } else {
- convert(evaluated.toString)
- }
- }
-}
-
/**
* Simple RegEx pattern matching function
*/
@@ -134,12 +110,33 @@ case class RLike(left: Expression, right: Expression)
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
}
+trait CaseConversionExpression {
+ self: UnaryExpression =>
+
+ type EvaluatedType = Any
+
+ def convert(v: UTF8String): UTF8String
+
+ override def foldable: Boolean = child.foldable
+ def nullable: Boolean = child.nullable
+ def dataType: DataType = StringType
+
+ override def eval(input: Row): Any = {
+ val evaluated = child.eval(input)
+ if (evaluated == null) {
+ null
+ } else {
+ convert(evaluated.asInstanceOf[UTF8String])
+ }
+ }
+}
+
/**
* A function that converts the characters of a string to uppercase.
*/
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
- override def convert(v: String): String = v.toUpperCase()
+ override def convert(v: UTF8String): UTF8String = v.toUpperCase
override def toString: String = s"Upper($child)"
}
@@ -149,7 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
*/
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
- override def convert(v: String): String = v.toLowerCase()
+ override def convert(v: UTF8String): UTF8String = v.toLowerCase
override def toString: String = s"Lower($child)"
}
@@ -162,15 +159,16 @@ trait StringComparison {
override def nullable: Boolean = left.nullable || right.nullable
- def compare(l: String, r: String): Boolean
+ def compare(l: UTF8String, r: UTF8String): Boolean
override def eval(input: Row): Any = {
- val leftEval = left.eval(input).asInstanceOf[String]
+ val leftEval = left.eval(input)
if(leftEval == null) {
null
} else {
- val rightEval = right.eval(input).asInstanceOf[String]
- if (rightEval == null) null else compare(leftEval, rightEval)
+ val rightEval = right.eval(input)
+ if (rightEval == null) null
+ else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String])
}
}
@@ -184,7 +182,7 @@ trait StringComparison {
*/
case class Contains(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.contains(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
}
/**
@@ -192,7 +190,7 @@ case class Contains(left: Expression, right: Expression)
*/
case class StartsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.startsWith(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
}
/**
@@ -200,7 +198,7 @@ case class StartsWith(left: Expression, right: Expression)
*/
case class EndsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.endsWith(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
}
/**
@@ -224,9 +222,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
override def children: Seq[Expression] = str :: pos :: len :: Nil
@inline
- def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int)
- (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = {
- val len = str.length
+ def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = {
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
// negative indices for start positions. If a start index i is greater than 0, it
// refers to element i-1 in the sequence. If a start index i is less than 0, it refers
@@ -235,7 +231,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
val start = startPos match {
case pos if pos > 0 => pos - 1
- case neg if neg < 0 => len + neg
+ case neg if neg < 0 => length() + neg
case _ => 0
}
@@ -244,12 +240,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
case x => start + x
}
- str.slice(start, end)
+ (start, end)
}
override def eval(input: Row): Any = {
val string = str.eval(input)
-
val po = pos.eval(input)
val ln = len.eval(input)
@@ -257,11 +252,14 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
null
} else {
val start = po.asInstanceOf[Int]
- val length = ln.asInstanceOf[Int]
-
+ val length = ln.asInstanceOf[Int]
string match {
- case ba: Array[Byte] => slice(ba, start, length)
- case other => slice(other.toString, start, length)
+ case ba: Array[Byte] =>
+ val (st, end) = slicePos(start, length, () => ba.length)
+ ba.slice(st, end)
+ case s: UTF8String =>
+ val (st, end) = slicePos(start, length, () => s.length)
+ s.slice(st, end)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 93e69d409c..7c80634d2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -198,14 +198,19 @@ object LikeSimplification extends Rule[LogicalPlan] {
val equalTo = "([^_%]*)".r
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") =>
- StartsWith(l, Literal(pattern))
- case Like(l, Literal(endsWith(pattern), StringType)) =>
- EndsWith(l, Literal(pattern))
- case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") =>
- Contains(l, Literal(pattern))
- case Like(l, Literal(equalTo(pattern), StringType)) =>
- EqualTo(l, Literal(pattern))
+ case Like(l, Literal(utf, StringType)) =>
+ utf.toString match {
+ case startsWith(pattern) if !pattern.endsWith("\\") =>
+ StartsWith(l, Literal(pattern))
+ case endsWith(pattern) =>
+ EndsWith(l, Literal(pattern))
+ case contains(pattern) if !pattern.endsWith("\\") =>
+ Contains(l, Literal(pattern))
+ case equalTo(pattern) =>
+ EqualTo(l, Literal(pattern))
+ case _ =>
+ Like(l, Literal.create(utf, StringType))
+ }
}
}
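As a reading aid, here is a standalone sketch of the pattern classification the rewritten LikeSimplification rule performs once the UTF8String literal is converted back to a String. Only the equalTo regex appears in the hunk above; the other three regexes and the classify helper are illustrative assumptions, not the Spark code.

object LikeSimplificationSketch {
  // Assumed shapes of the four pattern regexes (only equalTo is visible in the hunk above).
  private val startsWith = "([^_%]+)%".r
  private val endsWith   = "%([^_%]+)".r
  private val contains   = "%([^_%]+)%".r
  private val equalTo    = "([^_%]*)".r

  // Illustrative stand-in for the rule body: map a LIKE pattern string to the simpler
  // predicate it would be rewritten into, or leave it as a plain LIKE.
  def classify(pattern: String): String = pattern match {
    case startsWith(p) if !p.endsWith("\\") => s"StartsWith('$p')"
    case endsWith(p)                        => s"EndsWith('$p')"
    case contains(p) if !p.endsWith("\\")   => s"Contains('$p')"
    case equalTo(p)                         => s"EqualTo('$p')"
    case _                                  => "Like(unchanged)"
  }

  def main(args: Array[String]): Unit = {
    Seq("abc%", "%abc", "%abc%", "abc", "a_b%c").foreach { p =>
      println(s"$p -> ${classify(p)}")   // e.g. abc% -> StartsWith('abc'), a_b%c -> Like(unchanged)
    }
  }
}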
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
index 504fb05842..d36a49159b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
@@ -40,6 +40,7 @@ object DateUtils {
millisToDays(d.getTime)
}
+ // we should use the exact day as Int, for example, (year, month, day) -> day
def millisToDays(millisLocal: Long): Int = {
((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
new file mode 100644
index 0000000000..fc02ba6c9c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
@@ -0,0 +1,214 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.types
+
+import java.util.Arrays
+
+/**
+ * A UTF-8 String, as internal representation of StringType in SparkSQL
+ *
+ * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison,
+ * search, see http://en.wikipedia.org/wiki/UTF-8 for details.
+ *
+ * Note: This is not designed for general use cases, should not be used outside SQL.
+ */
+
+final class UTF8String extends Ordered[UTF8String] with Serializable {
+
+ private[this] var bytes: Array[Byte] = _
+
+ /**
+ * Update the UTF8String with String.
+ */
+ def set(str: String): UTF8String = {
+ bytes = str.getBytes("utf-8")
+ this
+ }
+
+ /**
+ * Update the UTF8String with Array[Byte], which should be encoded in UTF-8
+ */
+ def set(bytes: Array[Byte]): UTF8String = {
+ this.bytes = bytes
+ this
+ }
+
+ /**
+ * Return the number of bytes for a code point with the first byte as `b`
+ * @param b The first byte of a code point
+ */
+ @inline
+ private[this] def numOfBytes(b: Byte): Int = {
+ val offset = (b & 0xFF) - 192
+ if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1
+ }
+
+ /**
+ * Return the number of code points in it.
+ *
+ * This is only used by Substring() when `start` is negative.
+ */
+ def length(): Int = {
+ var len = 0
+ var i: Int = 0
+ while (i < bytes.length) {
+ i += numOfBytes(bytes(i))
+ len += 1
+ }
+ len
+ }
+
+ def getBytes: Array[Byte] = {
+ bytes
+ }
+
+ /**
+ * Return a substring of this,
+ * @param start the position of first code point
+ * @param until the position after last code point
+ */
+ def slice(start: Int, until: Int): UTF8String = {
+ if (until <= start || start >= bytes.length || bytes == null) {
+ new UTF8String
+ }
+
+ var c = 0
+ var i: Int = 0
+ while (c < start && i < bytes.length) {
+ i += numOfBytes(bytes(i))
+ c += 1
+ }
+ var j = i
+ while (c < until && j < bytes.length) {
+ j += numOfBytes(bytes(j))
+ c += 1
+ }
+ UTF8String(Arrays.copyOfRange(bytes, i, j))
+ }
+
+ def contains(sub: UTF8String): Boolean = {
+ val b = sub.getBytes
+ if (b.length == 0) {
+ return true
+ }
+ var i: Int = 0
+ while (i <= bytes.length - b.length) {
+ // In worst case, it's O(N*K), but should works fine with SQL
+ if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) {
+ return true
+ }
+ i += 1
+ }
+ false
+ }
+
+ def startsWith(prefix: UTF8String): Boolean = {
+ val b = prefix.getBytes
+ if (b.length > bytes.length) {
+ return false
+ }
+ Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b)
+ }
+
+ def endsWith(suffix: UTF8String): Boolean = {
+ val b = suffix.getBytes
+ if (b.length > bytes.length) {
+ return false
+ }
+ Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b)
+ }
+
+ def toUpperCase(): UTF8String = {
+ // upper case depends on locale, fallback to String.
+ UTF8String(toString().toUpperCase)
+ }
+
+ def toLowerCase(): UTF8String = {
+ // lower case depends on locale, fallback to String.
+ UTF8String(toString().toLowerCase)
+ }
+
+ override def toString(): String = {
+ new String(bytes, "utf-8")
+ }
+
+ override def clone(): UTF8String = new UTF8String().set(this.bytes)
+
+ override def compare(other: UTF8String): Int = {
+ var i: Int = 0
+ val b = other.getBytes
+ while (i < bytes.length && i < b.length) {
+ val res = bytes(i).compareTo(b(i))
+ if (res != 0) return res
+ i += 1
+ }
+ bytes.length - b.length
+ }
+
+ override def compareTo(other: UTF8String): Int = {
+ compare(other)
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case s: UTF8String =>
+ Arrays.equals(bytes, s.getBytes)
+ case s: String =>
+ // This is only used for Catalyst unit tests
+ // fail fast
+ bytes.length >= s.length && length() == s.length && toString() == s
+ case _ =>
+ false
+ }
+
+ override def hashCode(): Int = {
+ Arrays.hashCode(bytes)
+ }
+}
+
+object UTF8String {
+ // number of tailing bytes in a UTF8 sequence for a code point
+ // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1
+ private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4,
+ 5, 5, 5, 5,
+ 6, 6, 6, 6)
+
+ /**
+ * Create a UTF-8 String from String
+ */
+ def apply(s: String): UTF8String = {
+ if (s != null) {
+ new UTF8String().set(s)
+ } else{
+ null
+ }
+ }
+
+ /**
+ * Create a UTF-8 String from Array[Byte], which should be encoded in UTF-8
+ */
+ def apply(bytes: Array[Byte]): UTF8String = {
+ if (bytes != null) {
+ new UTF8String().set(bytes)
+ } else {
+ null
+ }
+ }
+}
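A brief usage sketch of the UTF8String class added above. The expected results follow from the implementation in this file (lengths and slice positions count code points, not bytes); the snippet itself is illustrative and assumes the class is visible at org.apache.spark.sql.types.UTF8String as in this patch.

import org.apache.spark.sql.types.UTF8String

object UTF8StringUsageSketch {
  def main(args: Array[String]): Unit = {
    val s = UTF8String("世 界")                       // built from a java.lang.String
    val b = UTF8String("世 界".getBytes("utf-8"))     // or from UTF-8 encoded bytes

    println(s.length())                               // 3 code points, not 7 bytes
    println(s.slice(0, 1))                            // "世": slice positions are code points
    println(s.contains(UTF8String("界")))             // true
    println(s.startsWith(UTF8String("世")))           // true
    println(s == b)                                   // true: equality is byte-wise
    println(s.compare(UTF8String("世 界!")) < 0)      // true: a strict prefix sorts first
  }
}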
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index cdf2bc68d9..c6fb22c26b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -350,7 +350,7 @@ class StringType private() extends NativeType with PrimitiveType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "StringType$" in byte code.
// Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = String
+ private[sql] type JvmType = UTF8String
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
@@ -1196,8 +1196,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
/**
* Convert the user type to a SQL datum
*
- * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst,
- * where we need to convert Any to UserType.
+ * TODO: Can we make this take obj: UserType? The issue is in
+ * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
*/
def serialize(obj: Any): Any
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index d4362a91d9..76298f03c9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -25,8 +25,9 @@ import org.scalactic.TripleEqualsSupport.Spread
import org.scalatest.FunSuite
import org.scalatest.Matchers._
-import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
@@ -59,6 +60,10 @@ class ExpressionEvaluationBaseSuite extends FunSuite {
class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
+ def create_row(values: Any*): Row = {
+ new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
+ }
+
test("literals") {
checkEvaluation(Literal(1), 1)
checkEvaluation(Literal(true), true)
@@ -265,24 +270,23 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("LIKE Non-literal Regular Expression") {
val regEx = 'a.string.at(0)
- checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null)))
- checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef")))
- checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b")))
- checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b")))
- checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b")))
- checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a%")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**")))
- checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
- checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
- checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))
- checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a_b")))
- checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b")))
- checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b")))
-
- checkEvaluation(Literal.create(null, StringType) like regEx, null,
- new GenericRow(Array[Any]("bc%")))
+ checkEvaluation("abcd" like regEx, null, create_row(null))
+ checkEvaluation("abdef" like regEx, true, create_row("abdef"))
+ checkEvaluation("a_%b" like regEx, true, create_row("a\\__b"))
+ checkEvaluation("addb" like regEx, true, create_row("a_%b"))
+ checkEvaluation("addb" like regEx, false, create_row("a\\__b"))
+ checkEvaluation("addb" like regEx, false, create_row("a%\\%b"))
+ checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b"))
+ checkEvaluation("addb" like regEx, true, create_row("a%"))
+ checkEvaluation("addb" like regEx, false, create_row("**"))
+ checkEvaluation("abc" like regEx, true, create_row("a%"))
+ checkEvaluation("abc" like regEx, false, create_row("b%"))
+ checkEvaluation("abc" like regEx, false, create_row("bc%"))
+ checkEvaluation("a\nb" like regEx, true, create_row("a_b"))
+ checkEvaluation("ab" like regEx, true, create_row("a%b"))
+ checkEvaluation("a\nb" like regEx, true, create_row("a%b"))
+
+ checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%"))
}
test("RLIKE literal Regular Expression") {
@@ -313,14 +317,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("RLIKE Non-literal Regular Expression") {
val regEx = 'a.string.at(0)
- checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef")))
- checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c")))
- checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo")))
- checkEvaluation("fo\no" rlike regEx, true, new GenericRow(Array[Any]("^fo\no$")))
- checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n")))
+ checkEvaluation("abdef" rlike regEx, true, create_row("abdef"))
+ checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c"))
+ checkEvaluation("fofo" rlike regEx, true, create_row("^fo"))
+ checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$"))
+ checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n"))
intercept[java.util.regex.PatternSyntaxException] {
- evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**")))
+ evaluate("abbbbc" rlike regEx, create_row("**"))
}
}
@@ -763,7 +767,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("null checking") {
- val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
+ val row = create_row("^Ba*n", null, true, null)
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
val c3 = 'a.boolean.at(2)
@@ -803,7 +807,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("case when") {
- val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c"))
+ val row = create_row(null, false, true, "a", "b", "c")
val c1 = 'a.boolean.at(0)
val c2 = 'a.boolean.at(1)
val c3 = 'a.boolean.at(2)
@@ -846,13 +850,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("complex type") {
- val row = new GenericRow(Array[Any](
- "^Ba*n", // 0
- null.asInstanceOf[String], // 1
- new GenericRow(Array[Any]("aa", "bb")), // 2
- Map("aa"->"bb"), // 3
- Seq("aa", "bb") // 4
- ))
+ val row = create_row(
+ "^Ba*n", // 0
+ null.asInstanceOf[UTF8String], // 1
+ create_row("aa", "bb"), // 2
+ Map("aa"->"bb"), // 3
+ Seq("aa", "bb") // 4
+ )
val typeS = StructType(
StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
@@ -909,7 +913,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("arithmetic") {
- val row = new GenericRow(Array[Any](1, 2, 3, null))
+ val row = create_row(1, 2, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
@@ -934,7 +938,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("fractional arithmetic") {
- val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null))
+ val row = create_row(1.1, 2.0, 3.1, null)
val c1 = 'a.double.at(0)
val c2 = 'a.double.at(1)
val c3 = 'a.double.at(2)
@@ -958,7 +962,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("BinaryComparison") {
- val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null))
+ val row = create_row(1, 2, 3, null, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
@@ -988,7 +992,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("StringComparison") {
- val row = new GenericRow(Array[Any]("abc", null))
+ val row = create_row("abc", null)
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
@@ -1009,7 +1013,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("Substring") {
- val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte)))
+ val row = create_row("example", "example".toArray.map(_.toByte))
val s = 'a.string.at(0)
@@ -1053,7 +1057,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
// substring(null, _, _) -> null
checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)),
- null, new GenericRow(Array[Any](null)))
+ null, create_row(null))
// substring(_, null, _) -> null
checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)),
@@ -1102,20 +1106,20 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("SQRT") {
val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))
val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
- val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble)))
+ val rowSequence = inputSequence.map(l => create_row(l.toDouble))
val d = 'a.double.at(0)
for ((row, expected) <- rowSequence zip expectedResults) {
checkEvaluation(Sqrt(d), expected, row)
}
- checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, new GenericRow(Array[Any](null)))
+ checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
checkEvaluation(Sqrt(-1), null, EmptyRow)
checkEvaluation(Sqrt(-1.5), null, EmptyRow)
}
test("Bitwise operations") {
- val row = new GenericRow(Array[Any](1, 2, 3, null))
+ val row = create_row(1, 2, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
index 275ea2627e..bcc0c404d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen._
/**
@@ -43,7 +43,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
}
val actual = plan(inputRow)
- val expectedRow = new GenericRow(Array[Any](expected))
+ val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala
new file mode 100644
index 0000000000..a22aa6f244
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala
@@ -0,0 +1,70 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.types
+
+import org.scalatest.FunSuite
+
+// scalastyle:off
+class UTF8StringSuite extends FunSuite {
+ test("basic") {
+ def check(str: String, len: Int) {
+
+ assert(UTF8String(str).length == len)
+ assert(UTF8String(str.getBytes("utf8")).length() == len)
+
+ assert(UTF8String(str) == str)
+ assert(UTF8String(str.getBytes("utf8")) == str)
+ assert(UTF8String(str).toString == str)
+ assert(UTF8String(str.getBytes("utf8")).toString == str)
+ assert(UTF8String(str.getBytes("utf8")) == UTF8String(str))
+
+ assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode())
+ }
+
+ check("hello", 5)
+ check("世 界", 3)
+ }
+
+ test("contains") {
+ assert(UTF8String("hello").contains(UTF8String("ello")))
+ assert(!UTF8String("hello").contains(UTF8String("vello")))
+ assert(UTF8String("大千世界").contains(UTF8String("千世")))
+ assert(!UTF8String("大千世界").contains(UTF8String("世千")))
+ }
+
+ test("prefix") {
+ assert(UTF8String("hello").startsWith(UTF8String("hell")))
+ assert(!UTF8String("hello").startsWith(UTF8String("ell")))
+ assert(UTF8String("大千世界").startsWith(UTF8String("大千")))
+ assert(!UTF8String("大千世界").startsWith(UTF8String("千")))
+ }
+
+ test("suffix") {
+ assert(UTF8String("hello").endsWith(UTF8String("ello")))
+ assert(!UTF8String("hello").endsWith(UTF8String("ellov")))
+ assert(UTF8String("大千世界").endsWith(UTF8String("世界")))
+ assert(!UTF8String("大千世界").endsWith(UTF8String("世")))
+ }
+
+ test("slice") {
+ assert(UTF8String("hello").slice(1, 3) == UTF8String("el"))
+ assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大"))
+ assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世"))
+ assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界"))
+ }
+}
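The suite above fixes the semantics the rest of the patch depends on: length, slice, startsWith and contains are defined over characters, not raw UTF-8 bytes, and a UTF8String built from a String must equal one built from that String's UTF-8 bytes. As a minimal illustration of the character-counting idea (not the committed implementation), every code point contributes exactly one non-continuation byte, so counting bytes that are not of the form 10xxxxxx yields the character length:
// Illustrative only: the UTF8String introduced by this patch may count
// characters differently.
def utf8CharLength(bytes: Array[Byte]): Int =
  bytes.count(b => (b & 0xC0) != 0x80)   // skip 10xxxxxx continuation bytes
assert(utf8CharLength("hello".getBytes("utf-8")) == 5)  // 5 bytes, 5 chars
assert(utf8CharLength("世 界".getBytes("utf-8")) == 3)   // 7 bytes, 3 chars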
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index b237fe684c..89a4faf35e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1195,6 +1195,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
case FloatType => true
case DateType => true
case TimestampType => true
+ case StringType => true
case ArrayType(_, _) => true
case MapType(_, _, _) => true
case StructType(_) => true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 87a6631da8..b0f983c180 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -216,13 +216,13 @@ private[sql] class IntColumnStats extends ColumnStats {
}
private[sql] class StringColumnStats extends ColumnStats {
- protected var upper: String = null
- protected var lower: String = null
+ protected var upper: UTF8String = null
+ protected var lower: UTF8String = null
override def gatherStats(row: Row, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
- val value = row.getString(ordinal)
+ val value = row(ordinal).asInstanceOf[UTF8String]
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
sizeInBytes += STRING.actualSize(row, ordinal)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index c47497e066..1b9e0df2dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import java.sql.{Date, Timestamp}
+import java.sql.Timestamp
import scala.reflect.runtime.universe.TypeTag
@@ -312,26 +312,28 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
row.getString(ordinal).getBytes("utf-8").length + 4
}
- override def append(v: String, buffer: ByteBuffer): Unit = {
- val stringBytes = v.getBytes("utf-8")
+ override def append(v: UTF8String, buffer: ByteBuffer): Unit = {
+ val stringBytes = v.getBytes
buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length)
}
- override def extract(buffer: ByteBuffer): String = {
+ override def extract(buffer: ByteBuffer): UTF8String = {
val length = buffer.getInt()
val stringBytes = new Array[Byte](length)
buffer.get(stringBytes, 0, length)
- new String(stringBytes, "utf-8")
+ UTF8String(stringBytes)
}
- override def setField(row: MutableRow, ordinal: Int, value: String): Unit = {
- row.setString(ordinal, value)
+ override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
+ row.update(ordinal, value)
}
- override def getField(row: Row, ordinal: Int): String = row.getString(ordinal)
+ override def getField(row: Row, ordinal: Int): UTF8String = {
+ row(ordinal).asInstanceOf[UTF8String]
+ }
override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
- to.setString(toOrdinal, from.getString(fromOrdinal))
+ to.update(toOrdinal, from(fromOrdinal))
}
}
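With append and extract now speaking UTF8String, a string value round-trips through the column buffer as raw UTF-8 bytes without materializing a java.lang.String. A rough sketch of that round trip (STRING is internal to the sql package; the buffer size and sample value are illustrative):
// Hypothetical round trip through the UTF8String-based STRING codec.
val buffer = java.nio.ByteBuffer.allocate(64)
STRING.append(UTF8String("héllo"), buffer)
buffer.rewind()
assert(STRING.extract(buffer) == UTF8String("héllo"))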
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 656bdd7212..1fd387eec7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{Row, SQLContext}
/**
* :: DeveloperApi ::
@@ -54,6 +54,33 @@ object RDDConversions {
}
}
}
+
+ /**
+ * Convert the objects inside Row into the types Catalyst expects.
+ */
+ def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = {
+ data.mapPartitions { iterator =>
+ if (iterator.isEmpty) {
+ Iterator.empty
+ } else {
+ val bufferedIterator = iterator.buffered
+ val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray)
+ val schemaFields = schema.fields.toArray
+ val converters = schemaFields.map {
+ f => CatalystTypeConverters.createToCatalystConverter(f.dataType)
+ }
+ bufferedIterator.map { r =>
+ var i = 0
+ while (i < mutableRow.length) {
+ mutableRow(i) = converters(i)(r(i))
+ i += 1
+ }
+
+ mutableRow
+ }
+ }
+ }
+ }
}
/** Logical plan node for scanning data from an RDD. */
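The new rowToRowRdd helper allocates one mutable row and one converter per field for the whole partition, so the per-row work is a tight while loop over the converters. A hedged usage sketch (RDDConversions and the converters are internal to sql; the SparkContext sc, sample schema and data are illustrative):
// Hypothetical use: lift external rows (java.lang.String for StringType)
// into rows carrying Catalyst's internal UTF8String.
import org.apache.spark.sql.types.{StringType, StructField, StructType}
val schema = StructType(StructField("name", StringType, nullable = true) :: Nil)
val external: RDD[Row] = sc.parallelize(Seq(Row("alice"), Row("bob")))
val internal: RDD[Row] = RDDConversions.rowToRowRdd(external, schema)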
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index fad7a281dc..99f24910fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -20,12 +20,13 @@ package org.apache.spark.sql.execution
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.{BooleanType, StructField, StructType, StringType}
-import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
/**
* A logical command that is executed for its side-effects. `RunnableCommand`s are
@@ -61,7 +62,11 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan {
override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray
- override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1)
+ override def execute(): RDD[Row] = {
+ val converted = sideEffectResult.map(r =>
+ CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row])
+ sqlContext.sparkContext.parallelize(converted, 1)
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index e916e68e58..710787096e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -164,7 +164,7 @@ package object debug {
case (_: Long, LongType) =>
case (_: Int, IntegerType) =>
- case (_: String, StringType) =>
+ case (_: UTF8String, StringType) =>
case (_: Float, FloatType) =>
case (_: Byte, ByteType) =>
case (_: Short, ShortType) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 5b308d88d4..7a43bfd8bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -140,6 +140,7 @@ object EvaluatePython {
case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
case (date: Int, DateType) => DateUtils.toJavaDate(date)
+ case (s: UTF8String, StringType) => s.toString
// Pyrolite can handle Timestamp and Decimal
case (other, _) => other
@@ -192,7 +193,8 @@ object EvaluatePython {
case (c: Long, IntegerType) => c.toInt
case (c: Int, LongType) => c.toLong
case (c: Double, FloatType) => c.toFloat
- case (c, StringType) if !c.isInstanceOf[String] => c.toString
+ case (c: String, StringType) => UTF8String(c)
+ case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString)
case (c, _) => c
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 463e1dcc26..b9022fcd9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -233,7 +233,7 @@ private[sql] class JDBCRDD(
* Converts value to SQL expression.
*/
private def compileValue(value: Any): Any = value match {
- case stringValue: String => s"'${escapeSql(stringValue)}'"
+ case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'"
case _ => value
}
@@ -349,12 +349,14 @@ private[sql] class JDBCRDD(
val pos = i + 1
conversions(i) match {
case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
+ // TODO(davies): convert Date into Int
case DateConversion => mutableRow.update(i, rs.getDate(pos))
case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos))
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
+ // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
case StringConversion => mutableRow.setString(i, rs.getString(pos))
case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos))
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index 4fa84dc076..99b755c9f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -130,6 +130,8 @@ private[sql] case class JDBCRelation(
extends BaseRelation
with PrunedFilteredScan {
+ override val needConversion: Boolean = false
+
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index 34f864f5fd..d4e0abc040 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -18,11 +18,8 @@
package org.apache.spark.sql
import java.sql.{Connection, DriverManager, PreparedStatement}
-import org.apache.spark.{Logging, Partition}
-import org.apache.spark.sql._
-import org.apache.spark.sql.sources.LogicalRelation
-import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartition}
+import org.apache.spark.Logging
import org.apache.spark.sql.types._
package object jdbc {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index f4c99b4b56..e3352d0278 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -20,12 +20,12 @@ package org.apache.spark.sql.json
import java.io.IOException
import org.apache.hadoop.fs.Path
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Row
-
-import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
private[sql] class DefaultSource
@@ -113,6 +113,8 @@ private[sql] case class JSONRelation(
// TODO: Support partitioned JSON relation.
private def baseRDD = sqlContext.sparkContext.textFile(path)
+ override val needConversion: Boolean = false
+
override val schema = userSpecifiedSchema.getOrElse(
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index b1e8521383..29de7401dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -409,7 +409,7 @@ private[sql] object JsonRDD extends Logging {
null
} else {
desiredType match {
- case StringType => toString(value)
+ case StringType => UTF8String(toString(value))
case _ if value == null || value == "" => null // guard the non string type
case IntegerType => value.asInstanceOf[IntegerType.JvmType]
case LongType => toLong(value)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 43ca359b51..bc108e37df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -219,8 +219,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit =
updateField(fieldIndex, value.getBytes)
- protected[parquet] def updateString(fieldIndex: Int, value: String): Unit =
- updateField(fieldIndex, value)
+ protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit =
+ updateField(fieldIndex, UTF8String(value))
protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit =
updateField(fieldIndex, readTimestamp(value))
@@ -418,8 +418,8 @@ private[parquet] class CatalystPrimitiveRowConverter(
override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit =
current.update(fieldIndex, value.getBytes)
- override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit =
- current.setString(fieldIndex, value)
+ override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit =
+ current.update(fieldIndex, UTF8String(value))
override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit =
current.update(fieldIndex, readTimestamp(value))
@@ -475,19 +475,18 @@ private[parquet] class CatalystPrimitiveConverter(
private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int)
extends CatalystPrimitiveConverter(parent, fieldIndex) {
- private[this] var dict: Array[String] = null
+ private[this] var dict: Array[Array[Byte]] = null
override def hasDictionarySupport: Boolean = true
override def setDictionary(dictionary: Dictionary):Unit =
- dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).toStringUsingUTF8}
-
+ dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes }
override def addValueFromDictionary(dictionaryId: Int): Unit =
parent.updateString(fieldIndex, dict(dictionaryId))
override def addBinary(value: Binary): Unit =
- parent.updateString(fieldIndex, value.toStringUsingUTF8)
+ parent.updateString(fieldIndex, value.getBytes)
}
private[parquet] object CatalystArrayConverter {
@@ -714,9 +713,9 @@ private[parquet] class CatalystNativeArrayConverter(
elements += 1
}
- override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = {
+ override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = {
checkGrowBuffer()
- buffer(elements) = value.asInstanceOf[NativeType]
+ buffer(elements) = UTF8String(value).asInstanceOf[NativeType]
elements += 1
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
index 0357dcc468..5eb1c6abc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -55,7 +55,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
- Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+ Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
@@ -76,7 +76,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
- Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+ Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
@@ -94,7 +94,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -111,7 +111,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -128,7 +128,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -145,7 +145,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 5a1b15490d..e05a4c20b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -198,10 +198,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
if (value != null) {
schema match {
case StringType => writer.addBinary(
- Binary.fromByteArray(
- value.asInstanceOf[String].getBytes("utf-8")
- )
- )
+ Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes))
case BinaryType => writer.addBinary(
Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
case IntegerType => writer.addInteger(value.asInstanceOf[Int])
@@ -349,7 +346,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
index: Int): Unit = {
ctype match {
case StringType => writer.addBinary(
- Binary.fromByteArray(record(index).asInstanceOf[String].getBytes("utf-8")))
+ Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes))
case BinaryType => writer.addBinary(
Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
case IntegerType => writer.addInteger(record.getInt(index))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 20fdf5e58e..af7b3c81ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -33,7 +33,6 @@ import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext}
-
import parquet.filter2.predicate.FilterApi
import parquet.format.converter.ParquetMetadataConverter
import parquet.hadoop.metadata.CompressionCodecName
@@ -45,13 +44,13 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD}
-import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions}
import org.apache.spark.sql.parquet.ParquetTypesConverter._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _}
import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode}
-import org.apache.spark.{Logging, Partition => SparkPartition, SerializableWritable, SparkException, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, SparkException, TaskContext, Partition => SparkPartition}
/**
* Allows creation of Parquet based tables using the syntax:
@@ -409,6 +408,9 @@ private[sql] case class ParquetRelation2(
file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
}
+ // Skip type conversion
+ override val needConversion: Boolean = false
+
// TODO Should calculate per scan size
// It's common that a query only scans a fraction of a large Parquet file. Returning size of the
// whole Parquet file disables some optimizations in this case (e.g. broadcast join).
@@ -550,7 +552,8 @@ private[sql] case class ParquetRelation2(
baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) =>
val partValues = selectedPartitions.collectFirst {
- case p if split.getPath.getParent.toString == p.path => p.values
+ case p if split.getPath.getParent.toString == p.path =>
+ CatalystTypeConverters.convertToCatalyst(p.values).asInstanceOf[Row]
}.get
val requiredPartOrdinal = partitionKeyLocations.keys.toSeq
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 34d048e426..b3d71f687a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{UTF8String, StringType}
import org.apache.spark.sql.{Row, Strategy, execution, sources}
/**
@@ -53,7 +54,7 @@ private[sql] object DataSourceStrategy extends Strategy {
(a, _) => t.buildScan(a)) :: Nil
case l @ LogicalRelation(t: TableScan) =>
- execution.PhysicalRDD(l.output, t.buildScan()) :: Nil
+ createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil
case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty =>
@@ -102,20 +103,30 @@ private[sql] object DataSourceStrategy extends Strategy {
projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above.
.map(relation.attributeMap) // Match original case of attributes.
- val scan =
- execution.PhysicalRDD(
- projectList.map(_.toAttribute),
+ val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute),
scanBuilder(requestedColumns, pushedFilters))
filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)
} else {
val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq
- val scan =
- execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters))
+ val scan = createPhysicalRDD(relation.relation, requestedColumns,
+ scanBuilder(requestedColumns, pushedFilters))
execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan))
}
}
+ private[this] def createPhysicalRDD(
+ relation: BaseRelation,
+ output: Seq[Attribute],
+ rdd: RDD[Row]): SparkPlan = {
+ val converted = if (relation.needConversion) {
+ execution.RDDConversions.rowToRowRdd(rdd, relation.schema)
+ } else {
+ rdd
+ }
+ execution.PhysicalRDD(output, converted)
+ }
+
/**
* Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s,
* and convert them.
@@ -167,14 +178,14 @@ private[sql] object DataSourceStrategy extends Strategy {
case expressions.Not(child) =>
translate(child).map(sources.Not)
- case expressions.StartsWith(a: Attribute, Literal(v: String, StringType)) =>
- Some(sources.StringStartsWith(a.name, v))
+ case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
+ Some(sources.StringStartsWith(a.name, v.toString))
- case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) =>
- Some(sources.StringEndsWith(a.name, v))
+ case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
+ Some(sources.StringEndsWith(a.name, v.toString))
- case expressions.Contains(a: Attribute, Literal(v: String, StringType)) =>
- Some(sources.StringContains(a.name, v))
+ case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) =>
+ Some(sources.StringContains(a.name, v.toString))
case _ => None
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 8f9946a5a8..ca53dcdb92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -126,6 +126,16 @@ abstract class BaseRelation {
* could lead to execution plans that are suboptimal (i.e. broadcasting a very large table).
*/
def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes
+
+ /**
+ * Whether it needs to convert the objects in a Row to the internal representation, for example:
+ * java.lang.String -> UTF8String
+ * java.math.BigDecimal -> Decimal
+ *
+ * Note: The internal representation is not stable across releases and thus data sources outside
+ * of Spark SQL should leave this as true.
+ */
+ def needConversion: Boolean = true
}
/**
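Relations that already produce Catalyst's internal types can override this to false, as the JDBC, JSON and Parquet relations in this patch now do, and DataSourceStrategy will then skip the row conversion. A minimal sketch under that assumption (the relation name and one-row scan are hypothetical):
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType, UTF8String}
// Hypothetical relation whose scan already yields internal types
// (UTF8String for strings), so the extra conversion pass is unnecessary.
case class AlreadyInternalRelation(sqlContext: SQLContext)
  extends BaseRelation with TableScan {
  override val schema = StructType(StructField("s", StringType, true) :: Nil)
  override val needConversion: Boolean = false
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(Seq(Row(UTF8String("already internal"))))
}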
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 36465cc2fa..bf6cf1321a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -30,7 +30,7 @@ class RowSuite extends FunSuite {
test("create row") {
val expected = new GenericMutableRow(4)
expected.update(0, 2147483647)
- expected.update(1, "this is a string")
+ expected.setString(1, "this is a string")
expected.update(2, false)
expected.update(3, null)
val actual1 = Row(2147483647, "this is a string", false, null)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0174aaee94..4c48dca444 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,18 +17,14 @@
package org.apache.spark.sql
-import org.apache.spark.sql.execution.GeneratedAggregate
-import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.TestData._
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types._
-
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
+import org.apache.spark.sql.types._
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 5f08834f73..c86ef338fc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -65,7 +65,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
checkActualSize(BOOLEAN, true, 1)
- checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
+ checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length)
checkActualSize(DATE, 0, 4)
checkActualSize(TIMESTAMP, new Timestamp(0L), 12)
@@ -108,8 +108,8 @@ class ColumnTypeSuite extends FunSuite with Logging {
testNativeColumnType[StringType.type](
STRING,
- (buffer: ByteBuffer, string: String) => {
- val bytes = string.getBytes("utf-8")
+ (buffer: ByteBuffer, string: UTF8String) => {
+ val bytes = string.getBytes
buffer.putInt(bytes.length)
buffer.put(bytes)
},
@@ -117,7 +117,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes)
- new String(bytes, "utf-8")
+ UTF8String(bytes)
})
testColumnType[BinaryType.type, Array[Byte]](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index b301818a00..f76314b9da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -24,7 +24,7 @@ import scala.util.Random
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{Decimal, DataType, NativeType}
+import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, NativeType}
object ColumnarTestUtils {
def makeNullRow(length: Int): GenericMutableRow = {
@@ -48,7 +48,7 @@ object ColumnarTestUtils {
case FLOAT => Random.nextFloat()
case DOUBLE => Random.nextDouble()
case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
- case STRING => Random.nextString(Random.nextInt(32))
+ case STRING => UTF8String(Random.nextString(Random.nextInt(32)))
case BOOLEAN => Random.nextBoolean()
case BINARY => randomBytes(Random.nextInt(32))
case DATE => Random.nextInt()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 60c8c00bda..3b47b8adf3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -74,7 +74,7 @@ case class AllDataTypesScan(
i.toDouble,
new java.math.BigDecimal(i),
new java.math.BigDecimal(i),
- new Date((i + 1) * 8640000),
+ new Date(1970, 1, 1),
new Timestamp(20000 + i),
s"varchar_$i",
Seq(i, i + 1),
@@ -82,7 +82,7 @@ case class AllDataTypesScan(
Map(i -> i.toString),
Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
Row(i, i.toString),
- Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+ Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1)))))
}
}
}
@@ -103,7 +103,7 @@ class TableScanSuite extends DataSourceTest {
i.toDouble,
new java.math.BigDecimal(i),
new java.math.BigDecimal(i),
- new Date((i + 1) * 8640000),
+ new Date(1970, 1, 1),
new Timestamp(20000 + i),
s"varchar_$i",
Seq(i, i + 1),
@@ -111,7 +111,7 @@ class TableScanSuite extends DataSourceTest {
Map(i -> i.toString),
Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
Row(i, i.toString),
- Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+ Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1)))))
}.toSeq
before {
@@ -266,7 +266,7 @@ class TableScanSuite extends DataSourceTest {
sqlTest(
"SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema",
- (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq)
+ (1 to 10).map(i => Row(Seq(new Date(1970, 1, i + 1)))).toSeq)
test("Caching") {
// Cached Query Execution
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 921c6194c7..74ae984f34 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -34,7 +34,7 @@ import scala.collection.JavaConversions._
* 1. The Underlying data type in catalyst and in Hive
* In catalyst:
* Primitive =>
- * java.lang.String
+ * UTF8String
* int / scala.Int
* boolean / scala.Boolean
* float / scala.Float
@@ -239,9 +239,10 @@ private[hive] trait HiveInspectors {
*/
def unwrap(data: Any, oi: ObjectInspector): Any = oi match {
case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null
- case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString
+ case poi: WritableConstantStringObjectInspector =>
+ UTF8String(poi.getWritableConstantValue.toString)
case poi: WritableConstantHiveVarcharObjectInspector =>
- poi.getWritableConstantValue.getHiveVarchar.getValue
+ UTF8String(poi.getWritableConstantValue.getHiveVarchar.getValue)
case poi: WritableConstantHiveDecimalObjectInspector =>
HiveShim.toCatalystDecimal(
PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector,
@@ -284,10 +285,13 @@ private[hive] trait HiveInspectors {
case pi: PrimitiveObjectInspector => pi match {
// We think HiveVarchar is also a String
case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
- hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue
- case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue
+ UTF8String(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue)
+ case hvoi: HiveVarcharObjectInspector =>
+ UTF8String(hvoi.getPrimitiveJavaObject(data).getValue)
case x: StringObjectInspector if x.preferWritable() =>
- x.getPrimitiveWritableObject(data).toString
+ UTF8String(x.getPrimitiveWritableObject(data).toString)
+ case x: StringObjectInspector =>
+ UTF8String(x.getPrimitiveJavaObject(data))
case x: IntObjectInspector if x.preferWritable() => x.get(data)
case x: BooleanObjectInspector if x.preferWritable() => x.get(data)
case x: FloatObjectInspector if x.preferWritable() => x.get(data)
@@ -340,7 +344,9 @@ private[hive] trait HiveInspectors {
*/
protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match {
case _: JavaHiveVarcharObjectInspector =>
- (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)
+ (o: Any) =>
+ val s = o.asInstanceOf[UTF8String].toString
+ new HiveVarchar(s, s.size)
case _: JavaHiveDecimalObjectInspector =>
(o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal)
@@ -409,7 +415,7 @@ private[hive] trait HiveInspectors {
case x: PrimitiveObjectInspector => x match {
// TODO we don't support the HiveVarcharObjectInspector yet.
case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a)
- case _: StringObjectInspector => a.asInstanceOf[java.lang.String]
+ case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString()
case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a)
case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer]
case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 1ccb0c279c..a6f4fbe8ab 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -17,24 +17,21 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.catalyst.expressions.Row
-
import scala.collection.JavaConversions._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+import org.apache.spark.sql.catalyst.expressions.{Row, _}
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.sources.DescribeCommand
-import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
-import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _}
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, CreateTableUsing}
+import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand}
import org.apache.spark.sql.types.StringType
@@ -131,7 +128,7 @@ private[hive] trait HiveStrategies {
val partitionValues = part.getValues
var i = 0
while (i < partitionValues.size()) {
- inputData(i) = partitionValues(i)
+ inputData(i) = CatalystTypeConverters.convertToCatalyst(partitionValues(i))
i += 1
}
pruningCondition(inputData)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 8efed7f029..cab0fdd357 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.hive.execution
-import java.io.{BufferedReader, InputStreamReader}
-import java.io.{DataInputStream, DataOutputStream, EOFException}
+import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader}
import java.util.Properties
import scala.collection.JavaConversions._
@@ -28,12 +27,13 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.types.DataType
-import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
import org.apache.spark.sql.hive.HiveShim._
+import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
+import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils
/**
@@ -121,14 +121,13 @@ case class ScriptTransformation(
if (outputSerde == null) {
val prevLine = curLine
curLine = reader.readLine()
-
if (!ioschema.schemaLess) {
- new GenericRow(
- prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
+ new GenericRow(CatalystTypeConverters.convertToCatalyst(
+ prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")))
.asInstanceOf[Array[Any]])
} else {
- new GenericRow(
- prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
+ new GenericRow(CatalystTypeConverters.convertToCatalyst(
+ prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2))
.asInstanceOf[Array[Any]])
}
} else {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 902a12785e..a40a1e5311 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -22,11 +22,11 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Row}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
/**
* Analyzes the given table in the current database to generate statistics, which will be
@@ -76,6 +76,12 @@ case class DropTable(
private[hive]
case class AddJar(path: String) extends RunnableCommand {
+ override val output: Seq[Attribute] = {
+ val schema = StructType(
+ StructField("result", IntegerType, false) :: Nil)
+ schema.toAttributes
+ }
+
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
hiveContext.runSqlHive(s"ADD JAR $path")
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index 0ed93c2c5b..33e96eaabf 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -41,7 +41,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory}
import org.apache.hadoop.io.{NullWritable, Writable}
import org.apache.hadoop.mapred.InputFormat
-import org.apache.spark.sql.types.{Decimal, DecimalType}
+import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType}
private[hive] case class HiveFunctionWrapper(functionClassName: String)
extends java.io.Serializable {
@@ -135,7 +135,7 @@ private[hive] object HiveShim {
PrimitiveCategory.VOID, null)
def getStringWritable(value: Any): hadoopIo.Text =
- if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])
+ if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString)
def getIntWritable(value: Any): hadoopIo.IntWritable =
if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index 7577309900..d331c210e8 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -17,37 +17,35 @@
package org.apache.spark.sql.hive
-import java.util
-import java.util.{ArrayList => JArrayList}
-import java.util.Properties
import java.rmi.server.UID
+import java.util.{Properties, ArrayList => JArrayList}
import scala.collection.JavaConversions._
import scala.language.implicitConversions
+import com.esotericsoftware.kryo.Kryo
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.io.{NullWritable, Writable}
-import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.hive.common.StatsSetupConst
-import org.apache.hadoop.hive.common.`type`.{HiveDecimal}
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.Context
-import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition}
+import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
+import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory
import org.apache.hadoop.hive.serde.serdeConstants
-import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory}
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
-import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, PrimitiveObjectInspector, ObjectInspector}
-import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
-import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector}
+import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory}
+import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo}
+import org.apache.hadoop.io.{NullWritable, Writable}
+import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.Logging
-import org.apache.spark.sql.types.{Decimal, DecimalType}
-
+import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String}
/**
* This class provides the UDF creation and also the UDF instance serialization and
@@ -63,18 +61,14 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String)
// for Serialization
def this() = this(null)
- import java.io.{OutputStream, InputStream}
- import com.esotericsoftware.kryo.Kryo
import org.apache.spark.util.Utils._
- import org.apache.hadoop.hive.ql.exec.Utilities
- import org.apache.hadoop.hive.ql.exec.UDF
@transient
private val methodDeSerialize = {
val method = classOf[Utilities].getDeclaredMethod(
"deserializeObjectByKryo",
classOf[Kryo],
- classOf[InputStream],
+ classOf[java.io.InputStream],
classOf[Class[_]])
method.setAccessible(true)
@@ -87,7 +81,7 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String)
"serializeObjectByKryo",
classOf[Kryo],
classOf[Object],
- classOf[OutputStream])
+ classOf[java.io.OutputStream])
method.setAccessible(true)
method
@@ -224,7 +218,7 @@ private[hive] object HiveShim {
TypeInfoFactory.voidTypeInfo, null)
def getStringWritable(value: Any): hadoopIo.Text =
- if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])
+ if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString)
def getIntWritable(value: Any): hadoopIo.IntWritable =
if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])