-rw-r--r-- python/pyspark/sql/dataframe.py | 10
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala | 3
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala | 37
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 1
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 6
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 36
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala | 12
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 32
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 46
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala | 7
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 7
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 3
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 14
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 90
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 21
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala | 1
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala | 214
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala | 6
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 90
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala | 4
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala | 70
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 1
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala | 20
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 31
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala | 13
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala | 5
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala | 8
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala | 19
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala | 12
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala | 7
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala | 11
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala | 37
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala | 10
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 10
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala | 8
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala | 10
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 22
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala | 13
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala | 17
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala | 10
-rw-r--r-- sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala | 4
-rw-r--r-- sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala | 36
50 files changed, 742 insertions, 298 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ef91a9c4f5..f2c3b74a18 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -456,7 +456,7 @@ class DataFrame(object):
One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
- [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
+ [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
"""
if joinExprs is None:
@@ -637,9 +637,9 @@ class DataFrame(object):
>>> df.groupBy().avg().collect()
[Row(AVG(age)=3.5)]
>>> df.groupBy('name').agg({'age': 'mean'}).collect()
- [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
+ [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
>>> df.groupBy(df.name).avg().collect()
- [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
+ [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
@@ -867,11 +867,11 @@ class GroupedData(object):
>>> gdf = df.groupBy(df.name)
>>> gdf.agg({"*": "count"}).collect()
- [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]
+ [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
>>> from pyspark.sql import functions as F
>>> gdf.agg(F.min(df.age)).collect()
- [Row(MIN(age)=5), Row(MIN(age)=2)]
+ [Row(MIN(age)=2), Row(MIN(age)=5)]
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index d794f034f5..ac8a782976 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.util.hashing.MurmurHash3
import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.types.{StructType, DateUtils}
+import org.apache.spark.sql.types.StructType
object Row {
/**
@@ -257,6 +257,7 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
+ // TODO(davies): This is not the right default implementation; we use Int for Date internally
def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 91976fef6d..d4f9fdacda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -77,6 +77,9 @@ object CatalystTypeConverters {
}
new GenericRowWithSchema(ar, structType)
+ case (d: String, _) =>
+ UTF8String(d)
+
case (d: BigDecimal, _) =>
Decimal(d)
@@ -175,6 +178,11 @@ object CatalystTypeConverters {
case other => other
}
+ case dataType: StringType => (item: Any) => extractOption(item) match {
+ case s: String => UTF8String(s)
+ case other => other
+ }
+
case _ =>
(item: Any) => extractOption(item) match {
case d: BigDecimal => Decimal(d)
@@ -184,6 +192,26 @@ object CatalystTypeConverters {
}
}
+ /**
+ * Converts Scala objects to catalyst rows / types.
+ *
+ * Note: This should be called before doing evaluation on a Row
+ * (it does not support UDT).
+ * This is used to create an RDD or test results with the correct types for Catalyst.
+ */
+ def convertToCatalyst(a: Any): Any = a match {
+ case s: String => UTF8String(s)
+ case d: java.sql.Date => DateUtils.fromJavaDate(d)
+ case d: BigDecimal => Decimal(d)
+ case d: java.math.BigDecimal => Decimal(d)
+ case seq: Seq[Any] => seq.map(convertToCatalyst)
+ case r: Row => Row(r.toSeq.map(convertToCatalyst): _*)
+ case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray
+ case m: Map[Any, Any] =>
+ m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
+ case other => other
+ }
+
/**
* Converts Catalyst types used internally in rows to standard Scala types
* This method is slow, and for batch conversion you should be using converter
@@ -211,6 +239,9 @@ object CatalystTypeConverters {
case (i: Int, DateType) =>
DateUtils.toJavaDate(i)
+ case (s: UTF8String, StringType) =>
+ s.toString()
+
case (other, _) =>
other
}
@@ -262,6 +293,12 @@ object CatalystTypeConverters {
case other => other
}
+ case StringType =>
+ (item: Any) => item match {
+ case s: UTF8String => s.toString()
+ case other => other
+ }
+
case other =>
(item: Any) => item
}
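For orientation, a brief REPL sketch (not part of the patch) of the conversions added above; convertToCatalyst is defined in this file, and convertToScala is assumed to take (value, dataType), matching its use later in this diff (generators.scala):
// Sketch only: round-tripping a String through the converters added above.
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{StringType, UTF8String}
CatalystTypeConverters.convertToCatalyst("hello")        // UTF8String("hello")
CatalystTypeConverters.convertToCatalyst(Seq("a", "b"))  // Seq(UTF8String("a"), UTF8String("b"))
CatalystTypeConverters.convertToScala(UTF8String("hello"), StringType)  // "hello" as a plain String again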
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 01d5c15122..d9521953ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -138,6 +138,7 @@ trait ScalaReflection {
// The data type can be determined without ambiguity.
case obj: BooleanType.JvmType => BooleanType
case obj: BinaryType.JvmType => BinaryType
+ case obj: String => StringType
case obj: StringType.JvmType => StringType
case obj: ByteType.JvmType => ByteType
case obj: ShortType.JvmType => ShortType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 3aeb964994..35c7f00d4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -115,7 +115,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
- val stringNaN = Literal.create("NaN", StringType)
+ val stringNaN = Literal("NaN")
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
@@ -563,6 +563,10 @@ trait HiveTypeCoercion {
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
+ // Compatible with Hive
+ case Substring(e, start, len) if e.dataType != StringType =>
+ Substring(Cast(e, StringType), start, len)
+
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 31f1a5fdc7..adf941ab2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.types._
/** Cast the child expression to the target data type. */
@@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
- case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
- case DateType => buildCast[Int](_, d => DateUtils.toString(d))
- case TimestampType => buildCast[Timestamp](_, timestampToString)
- case _ => buildCast[Any](_, _.toString)
+ case BinaryType => buildCast[Array[Byte]](_, UTF8String(_))
+ case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d)))
+ case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t)))
+ case _ => buildCast[Any](_, o => UTF8String(o.toString))
}
// BinaryConverter
private[this] def castToBinary(from: DataType): Any => Any = from match {
- case StringType => buildCast[String](_, _.getBytes("UTF-8"))
+ case StringType => buildCast[UTF8String](_, _.getBytes)
}
// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, _.length() != 0)
+ buildCast[UTF8String](_, _.length() != 0)
case TimestampType =>
buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0)
case DateType =>
@@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// TimestampConverter
private[this] def castToTimestamp(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => {
+ buildCast[UTF8String](_, utfs => {
// Throw away extra if more than 9 decimal places
+ val s = utfs.toString
val periodIdx = s.indexOf(".")
var n = s
if (periodIdx != -1 && n.length() - periodIdx > 9) {
@@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DateConverter
private[this] def castToDate(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s =>
- try DateUtils.fromJavaDate(Date.valueOf(s))
+ buildCast[UTF8String](_, s =>
+ try DateUtils.fromJavaDate(Date.valueOf(s.toString))
catch { case _: java.lang.IllegalArgumentException => null }
)
case TimestampType =>
@@ -245,7 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toLong catch {
+ buildCast[UTF8String](_, s => try s.toString.toLong catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toInt catch {
+ buildCast[UTF8String](_, s => try s.toString.toInt catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toShort catch {
+ buildCast[UTF8String](_, s => try s.toString.toShort catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toByte catch {
+ buildCast[UTF8String](_, s => try s.toString.toByte catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
+ buildCast[UTF8String](_, s => try {
+ changePrecision(Decimal(s.toString.toDouble), target)
+ } catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -348,7 +350,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toDouble catch {
+ buildCast[UTF8String](_, s => try s.toString.toDouble catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -364,7 +366,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[String](_, s => try s.toFloat catch {
+ buildCast[UTF8String](_, s => try s.toString.toFloat catch {
case _: NumberFormatException => null
})
case BooleanType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 47b6f358ed..3475ed05f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -230,13 +230,17 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
new GenericRow(newValues)
}
- override def update(ordinal: Int, value: Any): Unit = {
- if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
+ override def update(ordinal: Int, value: Any) {
+ if (value == null) {
+ setNullAt(ordinal)
+ } else {
+ values(ordinal).update(value)
+ }
}
- override def setString(ordinal: Int, value: String): Unit = update(ordinal, value)
+ override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value))
- override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String]
+ override def getString(ordinal: Int): String = apply(ordinal).toString
override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d141354a0f..be2c101d63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val $primitiveTerm: ${termForType(dataType)} = $value
""".children
- case expressions.Literal(value: String, dataType) =>
+ case expressions.Literal(value: UTF8String, dataType) =>
q"""
val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
+ val $primitiveTerm: ${termForType(dataType)} =
+ org.apache.spark.sql.types.UTF8String(${value.getBytes})
""".children
case expressions.Literal(value: Int, dataType) =>
@@ -243,11 +244,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
if($nullTerm)
${defaultPrimitive(StringType)}
else
- new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
+ org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
""".children
case Cast(child @ DateType(), StringType) =>
- child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType)
+ child.castOrNull(c =>
+ q"""org.apache.spark.sql.types.UTF8String(
+ org.apache.spark.sql.types.DateUtils.toString($c))""",
+ StringType)
case Cast(child @ NumericType(), IntegerType) =>
child.castOrNull(c => q"$c.toInt", IntegerType)
@@ -272,9 +276,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
if($nullTerm)
${defaultPrimitive(StringType)}
else
- ${eval.primitiveTerm}.toString
+ org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
""".children
+ case EqualTo(e1: BinaryType, e2: BinaryType) =>
+ (e1, e2).evaluateAs (BooleanType) {
+ case (eval1, eval2) =>
+ q"""
+ java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]],
+ $eval2.asInstanceOf[Array[Byte]])
+ """
+ }
+
case EqualTo(e1, e2) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }
@@ -597,7 +610,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val localLogger = log
val localLoggerTree = reify { localLogger }
q"""
- $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm))
+ $localLoggerTree.debug(
+ ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString))
""" :: Nil
} else {
Nil
@@ -608,6 +622,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
dataType match {
+ case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]"
case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
}
@@ -619,6 +634,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
ordinal: Int,
value: TermName) = {
dataType match {
+ case StringType => q"$destinationRow.update($ordinal, $value)"
case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
case _ => q"$destinationRow.update($ordinal, $value)"
}
@@ -642,13 +658,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
- case StringType => "String"
+ case StringType => "org.apache.spark.sql.types.UTF8String"
}
protected def defaultPrimitive(dt: DataType) = dt match {
case BooleanType => ru.Literal(Constant(false))
case FloatType => ru.Literal(Constant(-1.0.toFloat))
- case StringType => ru.Literal(Constant("<uninit>"))
+ case StringType => q"""org.apache.spark.sql.types.UTF8String("<uninit>")"""
case ShortType => ru.Literal(Constant(-1.toShort))
case LongType => ru.Literal(Constant(-1L))
case ByteType => ru.Literal(Constant(-1.toByte))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 69397a73a8..6f572ff959 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -111,36 +111,54 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val specificAccessorFunctions = NativeType.all.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
- case (e, i) if e.dataType == dataType =>
+ // getString() is not used by expressions
+ case (e, i) if e.dataType == dataType && dataType != StringType =>
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
q"if(i == $i) return $elementName" :: Nil
case _ => Nil
}
-
- q"""
- override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = {
- ..$ifStatements;
- $accessorFailure
- }"""
+ dataType match {
+ // Row() needs this interface to compile
+ case StringType =>
+ q"""
+ override def getString(i: Int): String = {
+ $accessorFailure
+ }"""
+ case other =>
+ q"""
+ override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
}
val specificMutatorFunctions = NativeType.all.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
- case (e, i) if e.dataType == dataType =>
+ // setString() is not used by expressions
+ case (e, i) if e.dataType == dataType && dataType != StringType =>
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil
case _ => Nil
}
-
- q"""
- override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = {
- ..$ifStatements;
- $accessorFailure
- }"""
+ dataType match {
+ case StringType =>
+ // MutableRow() needs this interface to compile
+ q"""
+ override def setString(i: Int, value: String) {
+ $accessorFailure
+ }"""
+ case other =>
+ q"""
+ override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
}
val hashValues = expressions.zipWithIndex.map { case (e,i) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 860b72fad3..67caadb839 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
-import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
import org.apache.spark.sql.types._
/**
@@ -85,8 +85,11 @@ case class UserDefinedGenerator(
override protected def makeOutput(): Seq[Attribute] = schema
override def eval(input: Row): TraversableOnce[Row] = {
+ // TODO(davies): improve this
+ // Convert the objects into Scala types before calling the function; we need the schema to support UDT
+ val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
val inputRow = new InterpretedProjection(children)
- function(inputRow(input))
+ function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row])
}
override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 0e2d593e94..18cba4cc46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types._
object Literal {
@@ -29,7 +30,7 @@ object Literal {
case f: Float => Literal(f, FloatType)
case b: Byte => Literal(b, ByteType)
case s: Short => Literal(s, ShortType)
- case s: String => Literal(s, StringType)
+ case s: String => Literal(UTF8String(s), StringType)
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
@@ -42,7 +43,9 @@ object Literal {
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
}
- def create(v: Any, dataType: DataType): Literal = Literal(v, dataType)
+ def create(v: Any, dataType: DataType): Literal = {
+ Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7e47cb3fff..fcd6352079 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -179,8 +179,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
val r = right.eval(input)
if (r == null) null
else if (left.dataType != BinaryType) l == r
- else BinaryType.ordering.compare(
- l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0
+ else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 0a275b8408..1b62e17ff4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.{StructType, NativeType}
+import org.apache.spark.sql.types.{UTF8String, StructType, NativeType}
/**
@@ -37,6 +37,7 @@ trait MutableRow extends Row {
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)
+ // TODO(davies): add setDate() and setDecimal()
}
/**
@@ -114,9 +115,15 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
}
override def getString(i: Int): String = {
- values(i).asInstanceOf[String]
+ values(i) match {
+ case null => null
+ case s: String => s
+ case utf8: UTF8String => utf8.toString
+ }
}
+ // TODO(davies): add getDate and getDecimal
+
// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
var result: Int = 37
@@ -189,8 +196,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
- override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value }
-
+ override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)}
override def setNullAt(i: Int): Unit = { values(i) = null }
override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index acfbbace60..d597bf7ce7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -19,11 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.regex.Pattern
-import scala.collection.IndexedSeqOptimized
-
-
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, StringType}
+import org.apache.spark.sql.types._
trait StringRegexExpression {
self: BinaryExpression =>
@@ -60,38 +57,17 @@ trait StringRegexExpression {
if(r == null) {
null
} else {
- val regex = pattern(r.asInstanceOf[String])
+ val regex = pattern(r.asInstanceOf[UTF8String].toString)
if(regex == null) {
null
} else {
- matches(regex, l.asInstanceOf[String])
+ matches(regex, l.asInstanceOf[UTF8String].toString)
}
}
}
}
}
-trait CaseConversionExpression {
- self: UnaryExpression =>
-
- type EvaluatedType = Any
-
- def convert(v: String): String
-
- override def foldable: Boolean = child.foldable
- def nullable: Boolean = child.nullable
- def dataType: DataType = StringType
-
- override def eval(input: Row): Any = {
- val evaluated = child.eval(input)
- if (evaluated == null) {
- null
- } else {
- convert(evaluated.toString)
- }
- }
-}
-
/**
* Simple RegEx pattern matching function
*/
@@ -134,12 +110,33 @@ case class RLike(left: Expression, right: Expression)
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
}
+trait CaseConversionExpression {
+ self: UnaryExpression =>
+
+ type EvaluatedType = Any
+
+ def convert(v: UTF8String): UTF8String
+
+ override def foldable: Boolean = child.foldable
+ def nullable: Boolean = child.nullable
+ def dataType: DataType = StringType
+
+ override def eval(input: Row): Any = {
+ val evaluated = child.eval(input)
+ if (evaluated == null) {
+ null
+ } else {
+ convert(evaluated.asInstanceOf[UTF8String])
+ }
+ }
+}
+
/**
* A function that converts the characters of a string to uppercase.
*/
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
- override def convert(v: String): String = v.toUpperCase()
+ override def convert(v: UTF8String): UTF8String = v.toUpperCase
override def toString: String = s"Upper($child)"
}
@@ -149,7 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
*/
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
- override def convert(v: String): String = v.toLowerCase()
+ override def convert(v: UTF8String): UTF8String = v.toLowerCase
override def toString: String = s"Lower($child)"
}
@@ -162,15 +159,16 @@ trait StringComparison {
override def nullable: Boolean = left.nullable || right.nullable
- def compare(l: String, r: String): Boolean
+ def compare(l: UTF8String, r: UTF8String): Boolean
override def eval(input: Row): Any = {
- val leftEval = left.eval(input).asInstanceOf[String]
+ val leftEval = left.eval(input)
if(leftEval == null) {
null
} else {
- val rightEval = right.eval(input).asInstanceOf[String]
- if (rightEval == null) null else compare(leftEval, rightEval)
+ val rightEval = right.eval(input)
+ if (rightEval == null) null
+ else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String])
}
}
@@ -184,7 +182,7 @@ trait StringComparison {
*/
case class Contains(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.contains(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
}
/**
@@ -192,7 +190,7 @@ case class Contains(left: Expression, right: Expression)
*/
case class StartsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.startsWith(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
}
/**
@@ -200,7 +198,7 @@ case class StartsWith(left: Expression, right: Expression)
*/
case class EndsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
- override def compare(l: String, r: String): Boolean = l.endsWith(r)
+ override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
}
/**
@@ -224,9 +222,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
override def children: Seq[Expression] = str :: pos :: len :: Nil
@inline
- def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int)
- (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = {
- val len = str.length
+ def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = {
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
// negative indices for start positions. If a start index i is greater than 0, it
// refers to element i-1 in the sequence. If a start index i is less than 0, it refers
@@ -235,7 +231,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
val start = startPos match {
case pos if pos > 0 => pos - 1
- case neg if neg < 0 => len + neg
+ case neg if neg < 0 => length() + neg
case _ => 0
}
@@ -244,12 +240,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
case x => start + x
}
- str.slice(start, end)
+ (start, end)
}
override def eval(input: Row): Any = {
val string = str.eval(input)
-
val po = pos.eval(input)
val ln = len.eval(input)
@@ -257,11 +252,14 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
null
} else {
val start = po.asInstanceOf[Int]
- val length = ln.asInstanceOf[Int]
-
+ val length = ln.asInstanceOf[Int]
string match {
- case ba: Array[Byte] => slice(ba, start, length)
- case other => slice(other.toString, start, length)
+ case ba: Array[Byte] =>
+ val (st, end) = slicePos(start, length, () => ba.length)
+ ba.slice(st, end)
+ case s: UTF8String =>
+ val (st, end) = slicePos(start, length, () => s.length)
+ s.slice(st, end)
}
}
}
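For context, a standalone REPL sketch (not part of the patch) of how the slicePos helper above maps Hive/SQL SUBSTR positions to zero-based indices; the end-bound logic is elided in this hunk, so it is approximated here as start + len:
// Sketch of the start-index arithmetic shown in slicePos above.
def startIndex(startPos: Int, length: () => Int): Int = startPos match {
  case pos if pos > 0 => pos - 1         // one-based positive index
  case neg if neg < 0 => length() + neg  // negative index counts from the end
  case _ => 0                            // zero behaves like position 1
}
val s = "example"
s.slice(startIndex(1, () => s.length), startIndex(1, () => s.length) + 3)    // "exa"
s.slice(startIndex(-3, () => s.length), startIndex(-3, () => s.length) + 3)  // "ple"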
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 93e69d409c..7c80634d2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -198,14 +198,19 @@ object LikeSimplification extends Rule[LogicalPlan] {
val equalTo = "([^_%]*)".r
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") =>
- StartsWith(l, Literal(pattern))
- case Like(l, Literal(endsWith(pattern), StringType)) =>
- EndsWith(l, Literal(pattern))
- case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") =>
- Contains(l, Literal(pattern))
- case Like(l, Literal(equalTo(pattern), StringType)) =>
- EqualTo(l, Literal(pattern))
+ case Like(l, Literal(utf, StringType)) =>
+ utf.toString match {
+ case startsWith(pattern) if !pattern.endsWith("\\") =>
+ StartsWith(l, Literal(pattern))
+ case endsWith(pattern) =>
+ EndsWith(l, Literal(pattern))
+ case contains(pattern) if !pattern.endsWith("\\") =>
+ Contains(l, Literal(pattern))
+ case equalTo(pattern) =>
+ EqualTo(l, Literal(pattern))
+ case _ =>
+ Like(l, Literal.create(utf, StringType))
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
index 504fb05842..d36a49159b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
@@ -40,6 +40,7 @@ object DateUtils {
millisToDays(d.getTime)
}
+ // A date is represented as the exact day as an Int, i.e. (year, month, day) -> day number
def millisToDays(millisLocal: Long): Int = {
((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
new file mode 100644
index 0000000000..fc02ba6c9c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
@@ -0,0 +1,214 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.types
+
+import java.util.Arrays
+
+/**
+ * A UTF-8 String, used as the internal representation of StringType in Spark SQL.
+ *
+ * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison and
+ * search; see http://en.wikipedia.org/wiki/UTF-8 for details.
+ *
+ * Note: This is not designed for general use; it should not be used outside SQL.
+ */
+
+final class UTF8String extends Ordered[UTF8String] with Serializable {
+
+ private[this] var bytes: Array[Byte] = _
+
+ /**
+ * Update the UTF8String with String.
+ */
+ def set(str: String): UTF8String = {
+ bytes = str.getBytes("utf-8")
+ this
+ }
+
+ /**
+ * Update the UTF8String with Array[Byte], which should be encoded in UTF-8
+ */
+ def set(bytes: Array[Byte]): UTF8String = {
+ this.bytes = bytes
+ this
+ }
+
+ /**
+ * Return the number of bytes for a code point with the first byte as `b`
+ * @param b The first byte of a code point
+ */
+ @inline
+ private[this] def numOfBytes(b: Byte): Int = {
+ val offset = (b & 0xFF) - 192
+ if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1
+ }
+
+ /**
+ * Return the number of code points in it.
+ *
+ * This is only used by Substring() when `start` is negative.
+ */
+ def length(): Int = {
+ var len = 0
+ var i: Int = 0
+ while (i < bytes.length) {
+ i += numOfBytes(bytes(i))
+ len += 1
+ }
+ len
+ }
+
+ def getBytes: Array[Byte] = {
+ bytes
+ }
+
+ /**
+ * Return a substring of this string.
+ * @param start the position of the first code point
+ * @param until the position after the last code point
+ */
+ def slice(start: Int, until: Int): UTF8String = {
+ if (bytes == null || until <= start || start >= bytes.length) {
+ return new UTF8String
+ }
+
+ var c = 0
+ var i: Int = 0
+ while (c < start && i < bytes.length) {
+ i += numOfBytes(bytes(i))
+ c += 1
+ }
+ var j = i
+ while (c < until && j < bytes.length) {
+ j += numOfBytes(bytes(j))
+ c += 1
+ }
+ UTF8String(Arrays.copyOfRange(bytes, i, j))
+ }
+
+ def contains(sub: UTF8String): Boolean = {
+ val b = sub.getBytes
+ if (b.length == 0) {
+ return true
+ }
+ var i: Int = 0
+ while (i <= bytes.length - b.length) {
+ // In the worst case this is O(N*K), but it should work fine for SQL
+ if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) {
+ return true
+ }
+ i += 1
+ }
+ false
+ }
+
+ def startsWith(prefix: UTF8String): Boolean = {
+ val b = prefix.getBytes
+ if (b.length > bytes.length) {
+ return false
+ }
+ Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b)
+ }
+
+ def endsWith(suffix: UTF8String): Boolean = {
+ val b = suffix.getBytes
+ if (b.length > bytes.length) {
+ return false
+ }
+ Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b)
+ }
+
+ def toUpperCase(): UTF8String = {
+ // Upper-casing depends on the locale, so fall back to String.
+ UTF8String(toString().toUpperCase)
+ }
+
+ def toLowerCase(): UTF8String = {
+ // Lower-casing depends on the locale, so fall back to String.
+ UTF8String(toString().toLowerCase)
+ }
+
+ override def toString(): String = {
+ new String(bytes, "utf-8")
+ }
+
+ override def clone(): UTF8String = new UTF8String().set(this.bytes)
+
+ override def compare(other: UTF8String): Int = {
+ var i: Int = 0
+ val b = other.getBytes
+ while (i < bytes.length && i < b.length) {
+ val res = bytes(i).compareTo(b(i))
+ if (res != 0) return res
+ i += 1
+ }
+ bytes.length - b.length
+ }
+
+ override def compareTo(other: UTF8String): Int = {
+ compare(other)
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case s: UTF8String =>
+ Arrays.equals(bytes, s.getBytes)
+ case s: String =>
+ // This is only used for Catalyst unit tests
+ // fail fast
+ bytes.length >= s.length && length() == s.length && toString() == s
+ case _ =>
+ false
+ }
+
+ override def hashCode(): Int = {
+ Arrays.hashCode(bytes)
+ }
+}
+
+object UTF8String {
+ // Total number of bytes for a code point, indexed by (first byte - 192).
+ // See http://en.wikipedia.org/wiki/UTF-8 (first-byte values 192 to 255).
+ private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4,
+ 5, 5, 5, 5,
+ 6, 6, 6, 6)
+
+ /**
+ * Create a UTF-8 String from String
+ */
+ def apply(s: String): UTF8String = {
+ if (s != null) {
+ new UTF8String().set(s)
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Create a UTF-8 String from Array[Byte], which should be encoded in UTF-8
+ */
+ def apply(bytes: Array[Byte]): UTF8String = {
+ if (bytes != null) {
+ new UTF8String().set(bytes)
+ } else {
+ null
+ }
+ }
+}
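For readers new to the type, a brief REPL sketch (not part of the patch) based only on the API defined above; note that length() and slice() count code points rather than bytes:
// Sketch only: exercises the UTF8String API introduced above.
import org.apache.spark.sql.types.UTF8String
val s = UTF8String("大千世界")                   // 4 code points, 12 bytes in UTF-8
s.length()                                       // 4
s.slice(1, 3)                                    // UTF8String("千世")
s.startsWith(UTF8String("大千"))                 // true
s.contains(UTF8String("千世"))                   // true
UTF8String("hello".getBytes("utf-8")) == UTF8String("hello")  // true: equality compares the bytes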
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index cdf2bc68d9..c6fb22c26b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -350,7 +350,7 @@ class StringType private() extends NativeType with PrimitiveType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "StringType$" in byte code.
// Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = String
+ private[sql] type JvmType = UTF8String
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
@@ -1196,8 +1196,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
/**
* Convert the user type to a SQL datum
*
- * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst,
- * where we need to convert Any to UserType.
+ * TODO: Can we make this take obj: UserType? The issue is in
+ * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
*/
def serialize(obj: Any): Any
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index d4362a91d9..76298f03c9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -25,8 +25,9 @@ import org.scalactic.TripleEqualsSupport.Spread
import org.scalatest.FunSuite
import org.scalatest.Matchers._
-import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
@@ -59,6 +60,10 @@ class ExpressionEvaluationBaseSuite extends FunSuite {
class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
+ def create_row(values: Any*): Row = {
+ new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
+ }
+
test("literals") {
checkEvaluation(Literal(1), 1)
checkEvaluation(Literal(true), true)
@@ -265,24 +270,23 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("LIKE Non-literal Regular Expression") {
val regEx = 'a.string.at(0)
- checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null)))
- checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef")))
- checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b")))
- checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b")))
- checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b")))
- checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a%")))
- checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**")))
- checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
- checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
- checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))
- checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a_b")))
- checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b")))
- checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b")))
-
- checkEvaluation(Literal.create(null, StringType) like regEx, null,
- new GenericRow(Array[Any]("bc%")))
+ checkEvaluation("abcd" like regEx, null, create_row(null))
+ checkEvaluation("abdef" like regEx, true, create_row("abdef"))
+ checkEvaluation("a_%b" like regEx, true, create_row("a\\__b"))
+ checkEvaluation("addb" like regEx, true, create_row("a_%b"))
+ checkEvaluation("addb" like regEx, false, create_row("a\\__b"))
+ checkEvaluation("addb" like regEx, false, create_row("a%\\%b"))
+ checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b"))
+ checkEvaluation("addb" like regEx, true, create_row("a%"))
+ checkEvaluation("addb" like regEx, false, create_row("**"))
+ checkEvaluation("abc" like regEx, true, create_row("a%"))
+ checkEvaluation("abc" like regEx, false, create_row("b%"))
+ checkEvaluation("abc" like regEx, false, create_row("bc%"))
+ checkEvaluation("a\nb" like regEx, true, create_row("a_b"))
+ checkEvaluation("ab" like regEx, true, create_row("a%b"))
+ checkEvaluation("a\nb" like regEx, true, create_row("a%b"))
+
+ checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%"))
}
test("RLIKE literal Regular Expression") {
@@ -313,14 +317,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("RLIKE Non-literal Regular Expression") {
val regEx = 'a.string.at(0)
- checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef")))
- checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c")))
- checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo")))
- checkEvaluation("fo\no" rlike regEx, true, new GenericRow(Array[Any]("^fo\no$")))
- checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n")))
+ checkEvaluation("abdef" rlike regEx, true, create_row("abdef"))
+ checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c"))
+ checkEvaluation("fofo" rlike regEx, true, create_row("^fo"))
+ checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$"))
+ checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n"))
intercept[java.util.regex.PatternSyntaxException] {
- evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**")))
+ evaluate("abbbbc" rlike regEx, create_row("**"))
}
}
@@ -763,7 +767,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("null checking") {
- val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
+ val row = create_row("^Ba*n", null, true, null)
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
val c3 = 'a.boolean.at(2)
@@ -803,7 +807,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("case when") {
- val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c"))
+ val row = create_row(null, false, true, "a", "b", "c")
val c1 = 'a.boolean.at(0)
val c2 = 'a.boolean.at(1)
val c3 = 'a.boolean.at(2)
@@ -846,13 +850,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("complex type") {
- val row = new GenericRow(Array[Any](
- "^Ba*n", // 0
- null.asInstanceOf[String], // 1
- new GenericRow(Array[Any]("aa", "bb")), // 2
- Map("aa"->"bb"), // 3
- Seq("aa", "bb") // 4
- ))
+ val row = create_row(
+ "^Ba*n", // 0
+ null.asInstanceOf[UTF8String], // 1
+ create_row("aa", "bb"), // 2
+ Map("aa"->"bb"), // 3
+ Seq("aa", "bb") // 4
+ )
val typeS = StructType(
StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
@@ -909,7 +913,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("arithmetic") {
- val row = new GenericRow(Array[Any](1, 2, 3, null))
+ val row = create_row(1, 2, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
@@ -934,7 +938,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("fractional arithmetic") {
- val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null))
+ val row = create_row(1.1, 2.0, 3.1, null)
val c1 = 'a.double.at(0)
val c2 = 'a.double.at(1)
val c3 = 'a.double.at(2)
@@ -958,7 +962,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("BinaryComparison") {
- val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null))
+ val row = create_row(1, 2, 3, null, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
@@ -988,7 +992,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("StringComparison") {
- val row = new GenericRow(Array[Any]("abc", null))
+ val row = create_row("abc", null)
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
@@ -1009,7 +1013,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
test("Substring") {
- val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte)))
+ val row = create_row("example", "example".toArray.map(_.toByte))
val s = 'a.string.at(0)
@@ -1053,7 +1057,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
// substring(null, _, _) -> null
checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)),
- null, new GenericRow(Array[Any](null)))
+ null, create_row(null))
// substring(_, null, _) -> null
checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)),
@@ -1102,20 +1106,20 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("SQRT") {
val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))
val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
- val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble)))
+ val rowSequence = inputSequence.map(l => create_row(l.toDouble))
val d = 'a.double.at(0)
for ((row, expected) <- rowSequence zip expectedResults) {
checkEvaluation(Sqrt(d), expected, row)
}
- checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, new GenericRow(Array[Any](null)))
+ checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
checkEvaluation(Sqrt(-1), null, EmptyRow)
checkEvaluation(Sqrt(-1.5), null, EmptyRow)
}
test("Bitwise operations") {
- val row = new GenericRow(Array[Any](1, 2, 3, null))
+ val row = create_row(1, 2, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
index 275ea2627e..bcc0c404d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen._
/**
@@ -43,7 +43,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
}
val actual = plan(inputRow)
- val expectedRow = new GenericRow(Array[Any](expected))
+ val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala
new file mode 100644
index 0000000000..a22aa6f244
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala
@@ -0,0 +1,70 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.types
+
+import org.scalatest.FunSuite
+
+// scalastyle:off
+class UTF8StringSuite extends FunSuite {
+ test("basic") {
+ def check(str: String, len: Int) {
+
+ assert(UTF8String(str).length == len)
+ assert(UTF8String(str.getBytes("utf8")).length() == len)
+
+ assert(UTF8String(str) == str)
+ assert(UTF8String(str.getBytes("utf8")) == str)
+ assert(UTF8String(str).toString == str)
+ assert(UTF8String(str.getBytes("utf8")).toString == str)
+ assert(UTF8String(str.getBytes("utf8")) == UTF8String(str))
+
+ assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode())
+ }
+
+ check("hello", 5)
+ check("世 界", 3)
+ }
+
+ test("contains") {
+ assert(UTF8String("hello").contains(UTF8String("ello")))
+ assert(!UTF8String("hello").contains(UTF8String("vello")))
+ assert(UTF8String("大千世界").contains(UTF8String("千世")))
+ assert(!UTF8String("大千世界").contains(UTF8String("世千")))
+ }
+
+ test("prefix") {
+ assert(UTF8String("hello").startsWith(UTF8String("hell")))
+ assert(!UTF8String("hello").startsWith(UTF8String("ell")))
+ assert(UTF8String("大千世界").startsWith(UTF8String("大千")))
+ assert(!UTF8String("大千世界").startsWith(UTF8String("千")))
+ }
+
+ test("suffix") {
+ assert(UTF8String("hello").endsWith(UTF8String("ello")))
+ assert(!UTF8String("hello").endsWith(UTF8String("ellov")))
+ assert(UTF8String("大千世界").endsWith(UTF8String("世界")))
+ assert(!UTF8String("大千世界").endsWith(UTF8String("世")))
+ }
+
+ test("slice") {
+ assert(UTF8String("hello").slice(1, 3) == UTF8String("el"))
+ assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大"))
+ assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世"))
+ assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界"))
+ }
+}
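
The check("世 界", 3) case above only holds because UTF8String counts characters (code points) rather than bytes. A minimal plain-JVM sketch, independent of Spark and with an object name invented purely for illustration, showing how the two counts diverge once multi-byte characters are involved:

import java.nio.charset.StandardCharsets

object Utf8LengthDemo {
  def main(args: Array[String]): Unit = {
    val s = "世 界"
    val bytes = s.getBytes(StandardCharsets.UTF_8)
    // Three characters, but seven bytes once encoded as UTF-8 (3 + 1 + 3).
    println(s"characters = ${s.length}, utf-8 bytes = ${bytes.length}")
  }
}
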
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index b237fe684c..89a4faf35e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1195,6 +1195,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
case FloatType => true
case DateType => true
case TimestampType => true
+ case StringType => true
case ArrayType(_, _) => true
case MapType(_, _, _) => true
case StructType(_) => true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 87a6631da8..b0f983c180 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -216,13 +216,13 @@ private[sql] class IntColumnStats extends ColumnStats {
}
private[sql] class StringColumnStats extends ColumnStats {
- protected var upper: String = null
- protected var lower: String = null
+ protected var upper: UTF8String = null
+ protected var lower: UTF8String = null
override def gatherStats(row: Row, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
- val value = row.getString(ordinal)
+ val value = row(ordinal).asInstanceOf[UTF8String]
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
sizeInBytes += STRING.actualSize(row, ordinal)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index c47497e066..1b9e0df2dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import java.sql.{Date, Timestamp}
+import java.sql.Timestamp
import scala.reflect.runtime.universe.TypeTag
@@ -312,26 +312,28 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
row.getString(ordinal).getBytes("utf-8").length + 4
}
- override def append(v: String, buffer: ByteBuffer): Unit = {
- val stringBytes = v.getBytes("utf-8")
+ override def append(v: UTF8String, buffer: ByteBuffer): Unit = {
+ val stringBytes = v.getBytes
buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length)
}
- override def extract(buffer: ByteBuffer): String = {
+ override def extract(buffer: ByteBuffer): UTF8String = {
val length = buffer.getInt()
val stringBytes = new Array[Byte](length)
buffer.get(stringBytes, 0, length)
- new String(stringBytes, "utf-8")
+ UTF8String(stringBytes)
}
- override def setField(row: MutableRow, ordinal: Int, value: String): Unit = {
- row.setString(ordinal, value)
+ override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
+ row.update(ordinal, value)
}
- override def getField(row: Row, ordinal: Int): String = row.getString(ordinal)
+ override def getField(row: Row, ordinal: Int): UTF8String = {
+ row(ordinal).asInstanceOf[UTF8String]
+ }
override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
- to.setString(toOrdinal, from.getString(fromOrdinal))
+ to.update(toOrdinal, from(fromOrdinal))
}
}
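
The new append/extract pair for the STRING column type above stores each value as a 4-byte length followed by the raw UTF-8 bytes. A self-contained sketch of that length-prefixed layout, using plain java.lang.String instead of UTF8String so it runs without Spark (the object and method names are illustrative only):

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

object LengthPrefixedStrings {
  def write(s: String, buffer: ByteBuffer): Unit = {
    val bytes = s.getBytes(StandardCharsets.UTF_8)
    // 4-byte length prefix, then the encoded bytes.
    buffer.putInt(bytes.length).put(bytes)
  }

  def read(buffer: ByteBuffer): String = {
    val length = buffer.getInt()
    val bytes = new Array[Byte](length)
    buffer.get(bytes)
    new String(bytes, StandardCharsets.UTF_8)
  }

  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(64)
    write("hello", buf)
    write("世 界", buf)
    buf.flip()
    println(read(buf)) // hello
    println(read(buf)) // 世 界
  }
}
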
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 656bdd7212..1fd387eec7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{Row, SQLContext}
/**
* :: DeveloperApi ::
@@ -54,6 +54,33 @@ object RDDConversions {
}
}
}
+
+ /**
+ * Converts the objects inside each Row into the types Catalyst expects.
+ */
+ def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = {
+ data.mapPartitions { iterator =>
+ if (iterator.isEmpty) {
+ Iterator.empty
+ } else {
+ val bufferedIterator = iterator.buffered
+ val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray)
+ val schemaFields = schema.fields.toArray
+ val converters = schemaFields.map {
+ f => CatalystTypeConverters.createToCatalystConverter(f.dataType)
+ }
+ bufferedIterator.map { r =>
+ var i = 0
+ while (i < mutableRow.length) {
+ mutableRow(i) = converters(i)(r(i))
+ i += 1
+ }
+
+ mutableRow
+ }
+ }
+ }
+ }
}
/** Logical plan node for scanning data from an RDD. */
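
rowToRowRdd above builds one converter per schema field up front and then reuses a single mutable row while streaming over each partition. A stripped-down sketch of that pattern using plain Scala collections instead of RDDs; the Converter alias and the two example field converters are invented for illustration and are not part of the Spark API:

object ConverterPerFieldSketch {
  type Converter = Any => Any

  // One converter per field, chosen once from the "schema".
  val converters: Array[Converter] = Array(
    (v: Any) => v.asInstanceOf[String].toUpperCase, // e.g. external string -> internal form
    (v: Any) => v.asInstanceOf[Int].toLong          // e.g. Int -> Long
  )

  def convertAll(rows: Iterator[Array[Any]]): Iterator[Array[Any]] = {
    // The same array is reused for every row, mirroring the mutable-row reuse
    // above; callers must consume each row before pulling the next one.
    val mutableRow = new Array[Any](converters.length)
    rows.map { r =>
      var i = 0
      while (i < mutableRow.length) {
        mutableRow(i) = converters(i)(r(i))
        i += 1
      }
      mutableRow
    }
  }

  def main(args: Array[String]): Unit = {
    val input = Iterator(Array[Any]("alice", 1), Array[Any]("bob", 2))
    convertAll(input).foreach(row => println(row.mkString(", ")))
  }
}
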
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index fad7a281dc..99f24910fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -20,12 +20,13 @@ package org.apache.spark.sql.execution
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.{BooleanType, StructField, StructType, StringType}
-import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
/**
* A logical command that is executed for its side-effects. `RunnableCommand`s are
@@ -61,7 +62,11 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan {
override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray
- override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1)
+ override def execute(): RDD[Row] = {
+ val converted = sideEffectResult.map(r =>
+ CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row])
+ sqlContext.sparkContext.parallelize(converted, 1)
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index e916e68e58..710787096e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -164,7 +164,7 @@ package object debug {
case (_: Long, LongType) =>
case (_: Int, IntegerType) =>
- case (_: String, StringType) =>
+ case (_: UTF8String, StringType) =>
case (_: Float, FloatType) =>
case (_: Byte, ByteType) =>
case (_: Short, ShortType) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 5b308d88d4..7a43bfd8bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -140,6 +140,7 @@ object EvaluatePython {
case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
case (date: Int, DateType) => DateUtils.toJavaDate(date)
+ case (s: UTF8String, StringType) => s.toString
// Pyrolite can handle Timestamp and Decimal
case (other, _) => other
@@ -192,7 +193,8 @@ object EvaluatePython {
case (c: Long, IntegerType) => c.toInt
case (c: Int, LongType) => c.toLong
case (c: Double, FloatType) => c.toFloat
- case (c, StringType) if !c.isInstanceOf[String] => c.toString
+ case (c: String, StringType) => UTF8String(c)
+ case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString)
case (c, _) => c
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 463e1dcc26..b9022fcd9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -233,7 +233,7 @@ private[sql] class JDBCRDD(
* Converts value to SQL expression.
*/
private def compileValue(value: Any): Any = value match {
- case stringValue: String => s"'${escapeSql(stringValue)}'"
+ case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'"
case _ => value
}
@@ -349,12 +349,14 @@ private[sql] class JDBCRDD(
val pos = i + 1
conversions(i) match {
case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
+ // TODO(davies): convert Date into Int
case DateConversion => mutableRow.update(i, rs.getDate(pos))
case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos))
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
+ // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
case StringConversion => mutableRow.setString(i, rs.getString(pos))
case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos))
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index 4fa84dc076..99b755c9f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -130,6 +130,8 @@ private[sql] case class JDBCRelation(
extends BaseRelation
with PrunedFilteredScan {
+ override val needConversion: Boolean = false
+
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index 34f864f5fd..d4e0abc040 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -18,11 +18,8 @@
package org.apache.spark.sql
import java.sql.{Connection, DriverManager, PreparedStatement}
-import org.apache.spark.{Logging, Partition}
-import org.apache.spark.sql._
-import org.apache.spark.sql.sources.LogicalRelation
-import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartition}
+import org.apache.spark.Logging
import org.apache.spark.sql.types._
package object jdbc {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index f4c99b4b56..e3352d0278 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -20,12 +20,12 @@ package org.apache.spark.sql.json
import java.io.IOException
import org.apache.hadoop.fs.Path
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Row
-
-import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
private[sql] class DefaultSource
@@ -113,6 +113,8 @@ private[sql] case class JSONRelation(
// TODO: Support partitioned JSON relation.
private def baseRDD = sqlContext.sparkContext.textFile(path)
+ override val needConversion: Boolean = false
+
override val schema = userSpecifiedSchema.getOrElse(
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index b1e8521383..29de7401dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -409,7 +409,7 @@ private[sql] object JsonRDD extends Logging {
null
} else {
desiredType match {
- case StringType => toString(value)
+ case StringType => UTF8String(toString(value))
case _ if value == null || value == "" => null // guard the non string type
case IntegerType => value.asInstanceOf[IntegerType.JvmType]
case LongType => toLong(value)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 43ca359b51..bc108e37df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -219,8 +219,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit =
updateField(fieldIndex, value.getBytes)
- protected[parquet] def updateString(fieldIndex: Int, value: String): Unit =
- updateField(fieldIndex, value)
+ protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit =
+ updateField(fieldIndex, UTF8String(value))
protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit =
updateField(fieldIndex, readTimestamp(value))
@@ -418,8 +418,8 @@ private[parquet] class CatalystPrimitiveRowConverter(
override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit =
current.update(fieldIndex, value.getBytes)
- override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit =
- current.setString(fieldIndex, value)
+ override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit =
+ current.update(fieldIndex, UTF8String(value))
override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit =
current.update(fieldIndex, readTimestamp(value))
@@ -475,19 +475,18 @@ private[parquet] class CatalystPrimitiveConverter(
private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int)
extends CatalystPrimitiveConverter(parent, fieldIndex) {
- private[this] var dict: Array[String] = null
+ private[this] var dict: Array[Array[Byte]] = null
override def hasDictionarySupport: Boolean = true
override def setDictionary(dictionary: Dictionary):Unit =
- dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).toStringUsingUTF8}
-
+ dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes }
override def addValueFromDictionary(dictionaryId: Int): Unit =
parent.updateString(fieldIndex, dict(dictionaryId))
override def addBinary(value: Binary): Unit =
- parent.updateString(fieldIndex, value.toStringUsingUTF8)
+ parent.updateString(fieldIndex, value.getBytes)
}
private[parquet] object CatalystArrayConverter {
@@ -714,9 +713,9 @@ private[parquet] class CatalystNativeArrayConverter(
elements += 1
}
- override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = {
+ override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = {
checkGrowBuffer()
- buffer(elements) = value.asInstanceOf[NativeType]
+ buffer(elements) = UTF8String(value).asInstanceOf[NativeType]
elements += 1
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
index 0357dcc468..5eb1c6abc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -55,7 +55,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
- Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+ Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
@@ -76,7 +76,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
- Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+ Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
@@ -94,7 +94,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -111,7 +111,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -128,7 +128,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -145,7 +145,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 5a1b15490d..e05a4c20b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -198,10 +198,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
if (value != null) {
schema match {
case StringType => writer.addBinary(
- Binary.fromByteArray(
- value.asInstanceOf[String].getBytes("utf-8")
- )
- )
+ Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes))
case BinaryType => writer.addBinary(
Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
case IntegerType => writer.addInteger(value.asInstanceOf[Int])
@@ -349,7 +346,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
index: Int): Unit = {
ctype match {
case StringType => writer.addBinary(
- Binary.fromByteArray(record(index).asInstanceOf[String].getBytes("utf-8")))
+ Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes))
case BinaryType => writer.addBinary(
Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
case IntegerType => writer.addInteger(record.getInt(index))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 20fdf5e58e..af7b3c81ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -33,7 +33,6 @@ import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext}
-
import parquet.filter2.predicate.FilterApi
import parquet.format.converter.ParquetMetadataConverter
import parquet.hadoop.metadata.CompressionCodecName
@@ -45,13 +44,13 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD}
-import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions}
import org.apache.spark.sql.parquet.ParquetTypesConverter._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _}
import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode}
-import org.apache.spark.{Logging, Partition => SparkPartition, SerializableWritable, SparkException, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, SparkException, TaskContext, Partition => SparkPartition}
/**
* Allows creation of Parquet based tables using the syntax:
@@ -409,6 +408,9 @@ private[sql] case class ParquetRelation2(
file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
}
+ // Skip type conversion
+ override val needConversion: Boolean = false
+
// TODO Should calculate per scan size
// It's common that a query only scans a fraction of a large Parquet file. Returning size of the
// whole Parquet file disables some optimizations in this case (e.g. broadcast join).
@@ -550,7 +552,8 @@ private[sql] case class ParquetRelation2(
baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) =>
val partValues = selectedPartitions.collectFirst {
- case p if split.getPath.getParent.toString == p.path => p.values
+ case p if split.getPath.getParent.toString == p.path =>
+ CatalystTypeConverters.convertToCatalyst(p.values).asInstanceOf[Row]
}.get
val requiredPartOrdinal = partitionKeyLocations.keys.toSeq
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 34d048e426..b3d71f687a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{UTF8String, StringType}
import org.apache.spark.sql.{Row, Strategy, execution, sources}
/**
@@ -53,7 +54,7 @@ private[sql] object DataSourceStrategy extends Strategy {
(a, _) => t.buildScan(a)) :: Nil
case l @ LogicalRelation(t: TableScan) =>
- execution.PhysicalRDD(l.output, t.buildScan()) :: Nil
+ createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil
case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty =>
@@ -102,20 +103,30 @@ private[sql] object DataSourceStrategy extends Strategy {
projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above.
.map(relation.attributeMap) // Match original case of attributes.
- val scan =
- execution.PhysicalRDD(
- projectList.map(_.toAttribute),
+ val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute),
scanBuilder(requestedColumns, pushedFilters))
filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)
} else {
val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq
- val scan =
- execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters))
+ val scan = createPhysicalRDD(relation.relation, requestedColumns,
+ scanBuilder(requestedColumns, pushedFilters))
execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan))
}
}
+ private[this] def createPhysicalRDD(
+ relation: BaseRelation,
+ output: Seq[Attribute],
+ rdd: RDD[Row]): SparkPlan = {
+ val converted = if (relation.needConversion) {
+ execution.RDDConversions.rowToRowRdd(rdd, relation.schema)
+ } else {
+ rdd
+ }
+ execution.PhysicalRDD(output, converted)
+ }
+
/**
* Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s,
* and convert them.
@@ -167,14 +178,14 @@ private[sql] object DataSourceStrategy extends Strategy {
case expressions.Not(child) =>
translate(child).map(sources.Not)
- case expressions.StartsWith(a: Attribute, Literal(v: String, StringType)) =>
- Some(sources.StringStartsWith(a.name, v))
+ case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
+ Some(sources.StringStartsWith(a.name, v.toString))
- case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) =>
- Some(sources.StringEndsWith(a.name, v))
+ case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
+ Some(sources.StringEndsWith(a.name, v.toString))
- case expressions.Contains(a: Attribute, Literal(v: String, StringType)) =>
- Some(sources.StringContains(a.name, v))
+ case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) =>
+ Some(sources.StringContains(a.name, v.toString))
case _ => None
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 8f9946a5a8..ca53dcdb92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -126,6 +126,16 @@ abstract class BaseRelation {
* could lead to execution plans that are suboptimal (i.e. broadcasting a very large table).
*/
def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes
+
+ /**
+ * Whether it needs to convert the objects in a Row to the internal representation, for example:
+ * java.lang.String -> UTF8String
+ * java.math.BigDecimal -> Decimal
+ *
+ * Note: The internal representation is not stable across releases and thus data sources outside
+ * of Spark SQL should leave this as true.
+ */
+ def needConversion: Boolean = true
}
/**
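
needConversion defaults to true so that external relations can keep returning ordinary JVM values (java.lang.String, java.math.BigDecimal, ...) and have Spark SQL convert them; only the built-in relations that already emit the internal representation (the JDBC, JSON, and Parquet relations changed above) override it to false. A rough sketch of an external relation against the 1.x data source API, with a class name and table contents invented for illustration:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class NamesRelation(override val sqlContext: SQLContext) extends BaseRelation with TableScan {
  override def schema: StructType = StructType(StructField("name", StringType) :: Nil)

  // Rows carry plain java.lang.String values, so the default
  // needConversion = true is left in place and Spark SQL converts
  // them to UTF8String internally.
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(Seq(Row("alice"), Row("bob")))
}
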
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 36465cc2fa..bf6cf1321a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -30,7 +30,7 @@ class RowSuite extends FunSuite {
test("create row") {
val expected = new GenericMutableRow(4)
expected.update(0, 2147483647)
- expected.update(1, "this is a string")
+ expected.setString(1, "this is a string")
expected.update(2, false)
expected.update(3, null)
val actual1 = Row(2147483647, "this is a string", false, null)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0174aaee94..4c48dca444 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,18 +17,14 @@
package org.apache.spark.sql
-import org.apache.spark.sql.execution.GeneratedAggregate
-import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.TestData._
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types._
-
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
+import org.apache.spark.sql.types._
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 5f08834f73..c86ef338fc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -65,7 +65,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
checkActualSize(BOOLEAN, true, 1)
- checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
+ checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length)
checkActualSize(DATE, 0, 4)
checkActualSize(TIMESTAMP, new Timestamp(0L), 12)
@@ -108,8 +108,8 @@ class ColumnTypeSuite extends FunSuite with Logging {
testNativeColumnType[StringType.type](
STRING,
- (buffer: ByteBuffer, string: String) => {
- val bytes = string.getBytes("utf-8")
+ (buffer: ByteBuffer, string: UTF8String) => {
+ val bytes = string.getBytes
buffer.putInt(bytes.length)
buffer.put(bytes)
},
@@ -117,7 +117,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes)
- new String(bytes, "utf-8")
+ UTF8String(bytes)
})
testColumnType[BinaryType.type, Array[Byte]](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index b301818a00..f76314b9da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -24,7 +24,7 @@ import scala.util.Random
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{Decimal, DataType, NativeType}
+import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, NativeType}
object ColumnarTestUtils {
def makeNullRow(length: Int): GenericMutableRow = {
@@ -48,7 +48,7 @@ object ColumnarTestUtils {
case FLOAT => Random.nextFloat()
case DOUBLE => Random.nextDouble()
case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
- case STRING => Random.nextString(Random.nextInt(32))
+ case STRING => UTF8String(Random.nextString(Random.nextInt(32)))
case BOOLEAN => Random.nextBoolean()
case BINARY => randomBytes(Random.nextInt(32))
case DATE => Random.nextInt()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 60c8c00bda..3b47b8adf3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -74,7 +74,7 @@ case class AllDataTypesScan(
i.toDouble,
new java.math.BigDecimal(i),
new java.math.BigDecimal(i),
- new Date((i + 1) * 8640000),
+ new Date(1970, 1, 1),
new Timestamp(20000 + i),
s"varchar_$i",
Seq(i, i + 1),
@@ -82,7 +82,7 @@ case class AllDataTypesScan(
Map(i -> i.toString),
Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
Row(i, i.toString),
- Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+ Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1)))))
}
}
}
@@ -103,7 +103,7 @@ class TableScanSuite extends DataSourceTest {
i.toDouble,
new java.math.BigDecimal(i),
new java.math.BigDecimal(i),
- new Date((i + 1) * 8640000),
+ new Date(1970, 1, 1),
new Timestamp(20000 + i),
s"varchar_$i",
Seq(i, i + 1),
@@ -111,7 +111,7 @@ class TableScanSuite extends DataSourceTest {
Map(i -> i.toString),
Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
Row(i, i.toString),
- Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+ Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1)))))
}.toSeq
before {
@@ -266,7 +266,7 @@ class TableScanSuite extends DataSourceTest {
sqlTest(
"SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema",
- (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq)
+ (1 to 10).map(i => Row(Seq(new Date(1970, 1, i + 1)))).toSeq)
test("Caching") {
// Cached Query Execution
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 921c6194c7..74ae984f34 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -34,7 +34,7 @@ import scala.collection.JavaConversions._
* 1. The Underlying data type in catalyst and in Hive
* In catalyst:
* Primitive =>
- * java.lang.String
+ * UTF8String
* int / scala.Int
* boolean / scala.Boolean
* float / scala.Float
@@ -239,9 +239,10 @@ private[hive] trait HiveInspectors {
*/
def unwrap(data: Any, oi: ObjectInspector): Any = oi match {
case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null
- case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString
+ case poi: WritableConstantStringObjectInspector =>
+ UTF8String(poi.getWritableConstantValue.toString)
case poi: WritableConstantHiveVarcharObjectInspector =>
- poi.getWritableConstantValue.getHiveVarchar.getValue
+ UTF8String(poi.getWritableConstantValue.getHiveVarchar.getValue)
case poi: WritableConstantHiveDecimalObjectInspector =>
HiveShim.toCatalystDecimal(
PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector,
@@ -284,10 +285,13 @@ private[hive] trait HiveInspectors {
case pi: PrimitiveObjectInspector => pi match {
// We think HiveVarchar is also a String
case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
- hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue
- case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue
+ UTF8String(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue)
+ case hvoi: HiveVarcharObjectInspector =>
+ UTF8String(hvoi.getPrimitiveJavaObject(data).getValue)
case x: StringObjectInspector if x.preferWritable() =>
- x.getPrimitiveWritableObject(data).toString
+ UTF8String(x.getPrimitiveWritableObject(data).toString)
+ case x: StringObjectInspector =>
+ UTF8String(x.getPrimitiveJavaObject(data))
case x: IntObjectInspector if x.preferWritable() => x.get(data)
case x: BooleanObjectInspector if x.preferWritable() => x.get(data)
case x: FloatObjectInspector if x.preferWritable() => x.get(data)
@@ -340,7 +344,9 @@ private[hive] trait HiveInspectors {
*/
protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match {
case _: JavaHiveVarcharObjectInspector =>
- (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)
+ (o: Any) =>
+ val s = o.asInstanceOf[UTF8String].toString
+ new HiveVarchar(s, s.size)
case _: JavaHiveDecimalObjectInspector =>
(o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal)
@@ -409,7 +415,7 @@ private[hive] trait HiveInspectors {
case x: PrimitiveObjectInspector => x match {
// TODO we don't support the HiveVarcharObjectInspector yet.
case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a)
- case _: StringObjectInspector => a.asInstanceOf[java.lang.String]
+ case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString()
case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a)
case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer]
case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 1ccb0c279c..a6f4fbe8ab 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -17,24 +17,21 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.catalyst.expressions.Row
-
import scala.collection.JavaConversions._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+import org.apache.spark.sql.catalyst.expressions.{Row, _}
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.sources.DescribeCommand
-import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
-import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _}
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, CreateTableUsing}
+import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand}
import org.apache.spark.sql.types.StringType
@@ -131,7 +128,7 @@ private[hive] trait HiveStrategies {
val partitionValues = part.getValues
var i = 0
while (i < partitionValues.size()) {
- inputData(i) = partitionValues(i)
+ inputData(i) = CatalystTypeConverters.convertToCatalyst(partitionValues(i))
i += 1
}
pruningCondition(inputData)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 8efed7f029..cab0fdd357 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.hive.execution
-import java.io.{BufferedReader, InputStreamReader}
-import java.io.{DataInputStream, DataOutputStream, EOFException}
+import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader}
import java.util.Properties
import scala.collection.JavaConversions._
@@ -28,12 +27,13 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.types.DataType
-import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
import org.apache.spark.sql.hive.HiveShim._
+import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
+import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils
/**
@@ -121,14 +121,13 @@ case class ScriptTransformation(
if (outputSerde == null) {
val prevLine = curLine
curLine = reader.readLine()
-
if (!ioschema.schemaLess) {
- new GenericRow(
- prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
+ new GenericRow(CatalystTypeConverters.convertToCatalyst(
+ prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")))
.asInstanceOf[Array[Any]])
} else {
- new GenericRow(
- prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
+ new GenericRow(CatalystTypeConverters.convertToCatalyst(
+ prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2))
.asInstanceOf[Array[Any]])
}
} else {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 902a12785e..a40a1e5311 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -22,11 +22,11 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Row}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
/**
* Analyzes the given table in the current database to generate statistics, which will be
@@ -76,6 +76,12 @@ case class DropTable(
private[hive]
case class AddJar(path: String) extends RunnableCommand {
+ override val output: Seq[Attribute] = {
+ val schema = StructType(
+ StructField("result", IntegerType, false) :: Nil)
+ schema.toAttributes
+ }
+
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
hiveContext.runSqlHive(s"ADD JAR $path")
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index 0ed93c2c5b..33e96eaabf 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -41,7 +41,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory}
import org.apache.hadoop.io.{NullWritable, Writable}
import org.apache.hadoop.mapred.InputFormat
-import org.apache.spark.sql.types.{Decimal, DecimalType}
+import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType}
private[hive] case class HiveFunctionWrapper(functionClassName: String)
extends java.io.Serializable {
@@ -135,7 +135,7 @@ private[hive] object HiveShim {
PrimitiveCategory.VOID, null)
def getStringWritable(value: Any): hadoopIo.Text =
- if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])
+ if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString)
def getIntWritable(value: Any): hadoopIo.IntWritable =
if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index 7577309900..d331c210e8 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -17,37 +17,35 @@
package org.apache.spark.sql.hive
-import java.util
-import java.util.{ArrayList => JArrayList}
-import java.util.Properties
import java.rmi.server.UID
+import java.util.{Properties, ArrayList => JArrayList}
import scala.collection.JavaConversions._
import scala.language.implicitConversions
+import com.esotericsoftware.kryo.Kryo
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.io.{NullWritable, Writable}
-import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.hive.common.StatsSetupConst
-import org.apache.hadoop.hive.common.`type`.{HiveDecimal}
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.Context
-import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition}
+import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
+import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory
import org.apache.hadoop.hive.serde.serdeConstants
-import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory}
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
-import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, PrimitiveObjectInspector, ObjectInspector}
-import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
-import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector}
+import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory}
+import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo}
+import org.apache.hadoop.io.{NullWritable, Writable}
+import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.Logging
-import org.apache.spark.sql.types.{Decimal, DecimalType}
-
+import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String}
/**
* This class provides the UDF creation and also the UDF instance serialization and
@@ -63,18 +61,14 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String)
// for Serialization
def this() = this(null)
- import java.io.{OutputStream, InputStream}
- import com.esotericsoftware.kryo.Kryo
import org.apache.spark.util.Utils._
- import org.apache.hadoop.hive.ql.exec.Utilities
- import org.apache.hadoop.hive.ql.exec.UDF
@transient
private val methodDeSerialize = {
val method = classOf[Utilities].getDeclaredMethod(
"deserializeObjectByKryo",
classOf[Kryo],
- classOf[InputStream],
+ classOf[java.io.InputStream],
classOf[Class[_]])
method.setAccessible(true)
@@ -87,7 +81,7 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String)
"serializeObjectByKryo",
classOf[Kryo],
classOf[Object],
- classOf[OutputStream])
+ classOf[java.io.OutputStream])
method.setAccessible(true)
method
@@ -224,7 +218,7 @@ private[hive] object HiveShim {
TypeInfoFactory.voidTypeInfo, null)
def getStringWritable(value: Any): hadoopIo.Text =
- if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])
+ if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString)
def getIntWritable(value: Any): hadoopIo.IntWritable =
if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])