author     Davies Liu <davies@databricks.com>          2015-04-15 13:06:38 -0700
committer  Michael Armbrust <michael@databricks.com>   2015-04-15 13:06:38 -0700
commit     85842760dc4616577162f44cc0fa9db9bd23bd9c (patch)
tree       3f0d8c9e0b9cb75c6fed3e2e3d6b5302a384d600 /sql/core
parent     785f95586b951d7b05481ee925fb95c20c4d6b6f (diff)
[SPARK-6638] [SQL] Improve performance of StringType in SQL
This PR changes the internal representation of StringType from java.lang.String to UTF8String, which is backed by Array[Byte]. It should not break any public API; Row.getString() will still return java.lang.String. This is the first step toward improving the performance of String in SQL.

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #5350 from davies/string and squashes the following commits:

3b7bfa8 [Davies Liu] fix schema of AddJar
2772f0d [Davies Liu] fix new test failure
6d776a9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
59025c8 [Davies Liu] address comments from @marmbrus
341ec2c [Davies Liu] turn off scala style check in UTF8StringSuite
744788f [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
b04a19c [Davies Liu] add comment for getString/setString
08d897b [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
5116b43 [Davies Liu] rollback unrelated changes
1314a37 [Davies Liu] address comments from Yin
867bf50 [Davies Liu] fix String filter push down
13d9d42 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
2089d24 [Davies Liu] add hashcode check back
ac18ae6 [Davies Liu] address comment
fd11364 [Davies Liu] optimize UTF8String
8d17f21 [Davies Liu] fix hive compatibility tests
e5fa5b8 [Davies Liu] remove clone in UTF8String
28f3d81 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
28d6f32 [Davies Liu] refactor
537631c [Davies Liu] some comment about Date
9f4c194 [Davies Liu] convert data type for data source
956b0a4 [Davies Liu] fix hive tests
73e4363 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string
9dc32d1 [Davies Liu] fix some hive tests
23a766c [Davies Liu] refactor
8b45864 [Davies Liu] fix codegen with UTF8String
bb52e44 [Davies Liu] fix scala style
c7dd4d2 [Davies Liu] fix some catalyst tests
38c303e [Davies Liu] fix python sql tests
5f9e120 [Davies Liu] fix sql tests
6b499ac [Davies Liu] fix style
a85fb27 [Davies Liu] refactor
d32abd1 [Davies Liu] fix utf8 for python api
4699c3a [Davies Liu] use Array[Byte] in UTF8String
21f67c6 [Davies Liu] cleanup
685fd07 [Davies Liu] use UTF8String instead of String for StringType
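For orientation, the sketch below illustrates the idea of a byte-backed string type with only the operations this diff relies on (apply on String/Array[Byte], getBytes, toString, compareTo, equals/hashCode). It is an assumption-laden sketch, not the actual org.apache.spark.sql.types.UTF8String added by this patch.

// Hypothetical sketch of a UTF-8 string wrapper backed by Array[Byte].
// The real UTF8String in this patch differs in detail.
class UTF8StringSketch(private val bytes: Array[Byte])
  extends Ordered[UTF8StringSketch] with Serializable {

  // Return the underlying UTF-8 encoded bytes without copying.
  def getBytes: Array[Byte] = bytes

  // Decode to java.lang.String only when the public API needs it,
  // e.g. Row.getString().
  override def toString: String = new String(bytes, "utf-8")

  // Byte-wise unsigned comparison; for UTF-8 this matches code point
  // order, so no decoding is required.
  override def compare(other: UTF8StringSketch): Int = {
    val a = bytes
    val b = other.getBytes
    var i = 0
    val len = math.min(a.length, b.length)
    while (i < len) {
      val res = (a(i) & 0xff) - (b(i) & 0xff)
      if (res != 0) return res
      i += 1
    }
    a.length - b.length
  }

  override def equals(other: Any): Boolean = other match {
    case s: UTF8StringSketch => java.util.Arrays.equals(bytes, s.getBytes)
    case _ => false
  }

  override def hashCode(): Int = java.util.Arrays.hashCode(bytes)
}

object UTF8StringSketch {
  def apply(s: String): UTF8StringSketch = new UTF8StringSketch(s.getBytes("utf-8"))
  def apply(bytes: Array[Byte]): UTF8StringSketch = new UTF8StringSketch(bytes)
}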
Diffstat (limited to 'sql/core')
 sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                 |  1
 sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala       |  6
 sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala        | 20
 sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala      | 31
 sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala         | 13
 sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala    |  2
 sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala       |  4
 sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala               |  4
 sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala          |  2
 sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala                  |  5
 sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala          |  8
 sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala               |  2
 sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala   | 19
 sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala     | 12
 sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala|  7
 sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala         | 11
 sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala | 37
 sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala         | 10
 sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala                   |  2
 sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala              | 10
 sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala   |  8
 sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala |  4
 sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala     | 10
 23 files changed, 142 insertions(+), 86 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index b237fe684c..89a4faf35e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1195,6 +1195,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
case FloatType => true
case DateType => true
case TimestampType => true
+ case StringType => true
case ArrayType(_, _) => true
case MapType(_, _, _) => true
case StructType(_) => true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 87a6631da8..b0f983c180 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -216,13 +216,13 @@ private[sql] class IntColumnStats extends ColumnStats {
}
private[sql] class StringColumnStats extends ColumnStats {
- protected var upper: String = null
- protected var lower: String = null
+ protected var upper: UTF8String = null
+ protected var lower: UTF8String = null
override def gatherStats(row: Row, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
- val value = row.getString(ordinal)
+ val value = row(ordinal).asInstanceOf[UTF8String]
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
sizeInBytes += STRING.actualSize(row, ordinal)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index c47497e066..1b9e0df2dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import java.sql.{Date, Timestamp}
+import java.sql.Timestamp
import scala.reflect.runtime.universe.TypeTag
@@ -312,26 +312,28 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
row.getString(ordinal).getBytes("utf-8").length + 4
}
- override def append(v: String, buffer: ByteBuffer): Unit = {
- val stringBytes = v.getBytes("utf-8")
+ override def append(v: UTF8String, buffer: ByteBuffer): Unit = {
+ val stringBytes = v.getBytes
buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length)
}
- override def extract(buffer: ByteBuffer): String = {
+ override def extract(buffer: ByteBuffer): UTF8String = {
val length = buffer.getInt()
val stringBytes = new Array[Byte](length)
buffer.get(stringBytes, 0, length)
- new String(stringBytes, "utf-8")
+ UTF8String(stringBytes)
}
- override def setField(row: MutableRow, ordinal: Int, value: String): Unit = {
- row.setString(ordinal, value)
+ override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
+ row.update(ordinal, value)
}
- override def getField(row: Row, ordinal: Int): String = row.getString(ordinal)
+ override def getField(row: Row, ordinal: Int): UTF8String = {
+ row(ordinal).asInstanceOf[UTF8String]
+ }
override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
- to.setString(toOrdinal, from.getString(fromOrdinal))
+ to.update(toOrdinal, from(fromOrdinal))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 656bdd7212..1fd387eec7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{Row, SQLContext}
/**
* :: DeveloperApi ::
@@ -54,6 +54,33 @@ object RDDConversions {
}
}
}
+
+ /**
+ * Convert the objects inside Row into the types that Catalyst expects.
+ */
+ def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = {
+ data.mapPartitions { iterator =>
+ if (iterator.isEmpty) {
+ Iterator.empty
+ } else {
+ val bufferedIterator = iterator.buffered
+ val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray)
+ val schemaFields = schema.fields.toArray
+ val converters = schemaFields.map {
+ f => CatalystTypeConverters.createToCatalystConverter(f.dataType)
+ }
+ bufferedIterator.map { r =>
+ var i = 0
+ while (i < mutableRow.length) {
+ mutableRow(i) = converters(i)(r(i))
+ i += 1
+ }
+
+ mutableRow
+ }
+ }
+ }
+ }
}
/** Logical plan node for scanning data from an RDD. */
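As a usage sketch only (the SparkContext `sc`, the schema, and the data below are assumed for illustration, and RDDConversions is Spark-internal code), the new rowToRowRdd helper maps rows holding external types to rows holding Catalyst's internal types, e.g. java.lang.String becomes UTF8String:

// Hypothetical usage of RDDConversions.rowToRowRdd as added in this patch.
// Assumes an existing SparkContext `sc`; the schema and data are made up.
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.RDDConversions
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType)))

// External representation: plain java.lang.String values.
val external = sc.parallelize(Seq(Row("alice", 1), Row("bob", 2)))

// Internal representation: StringType fields are converted to UTF8String
// by the per-field converters created from the schema.
val internal = RDDConversions.rowToRowRdd(external, schema)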
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index fad7a281dc..99f24910fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -20,12 +20,13 @@ package org.apache.spark.sql.execution
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.{BooleanType, StructField, StructType, StringType}
-import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
/**
* A logical command that is executed for its side-effects. `RunnableCommand`s are
@@ -61,7 +62,11 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan {
override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray
- override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1)
+ override def execute(): RDD[Row] = {
+ val converted = sideEffectResult.map(r =>
+ CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row])
+ sqlContext.sparkContext.parallelize(converted, 1)
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index e916e68e58..710787096e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -164,7 +164,7 @@ package object debug {
case (_: Long, LongType) =>
case (_: Int, IntegerType) =>
- case (_: String, StringType) =>
+ case (_: UTF8String, StringType) =>
case (_: Float, FloatType) =>
case (_: Byte, ByteType) =>
case (_: Short, ShortType) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 5b308d88d4..7a43bfd8bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -140,6 +140,7 @@ object EvaluatePython {
case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
case (date: Int, DateType) => DateUtils.toJavaDate(date)
+ case (s: UTF8String, StringType) => s.toString
// Pyrolite can handle Timestamp and Decimal
case (other, _) => other
@@ -192,7 +193,8 @@ object EvaluatePython {
case (c: Long, IntegerType) => c.toInt
case (c: Int, LongType) => c.toLong
case (c: Double, FloatType) => c.toFloat
- case (c, StringType) if !c.isInstanceOf[String] => c.toString
+ case (c: String, StringType) => UTF8String(c)
+ case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString)
case (c, _) => c
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 463e1dcc26..b9022fcd9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -233,7 +233,7 @@ private[sql] class JDBCRDD(
* Converts value to SQL expression.
*/
private def compileValue(value: Any): Any = value match {
- case stringValue: String => s"'${escapeSql(stringValue)}'"
+ case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'"
case _ => value
}
@@ -349,12 +349,14 @@ private[sql] class JDBCRDD(
val pos = i + 1
conversions(i) match {
case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
+ // TODO(davies): convert Date into Int
case DateConversion => mutableRow.update(i, rs.getDate(pos))
case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos))
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
+ // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
case StringConversion => mutableRow.setString(i, rs.getString(pos))
case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos))
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index 4fa84dc076..99b755c9f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -130,6 +130,8 @@ private[sql] case class JDBCRelation(
extends BaseRelation
with PrunedFilteredScan {
+ override val needConversion: Boolean = false
+
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index 34f864f5fd..d4e0abc040 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -18,11 +18,8 @@
package org.apache.spark.sql
import java.sql.{Connection, DriverManager, PreparedStatement}
-import org.apache.spark.{Logging, Partition}
-import org.apache.spark.sql._
-import org.apache.spark.sql.sources.LogicalRelation
-import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartition}
+import org.apache.spark.Logging
import org.apache.spark.sql.types._
package object jdbc {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index f4c99b4b56..e3352d0278 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -20,12 +20,12 @@ package org.apache.spark.sql.json
import java.io.IOException
import org.apache.hadoop.fs.Path
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Row
-
-import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
private[sql] class DefaultSource
@@ -113,6 +113,8 @@ private[sql] case class JSONRelation(
// TODO: Support partitioned JSON relation.
private def baseRDD = sqlContext.sparkContext.textFile(path)
+ override val needConversion: Boolean = false
+
override val schema = userSpecifiedSchema.getOrElse(
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index b1e8521383..29de7401dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -409,7 +409,7 @@ private[sql] object JsonRDD extends Logging {
null
} else {
desiredType match {
- case StringType => toString(value)
+ case StringType => UTF8String(toString(value))
case _ if value == null || value == "" => null // guard the non string type
case IntegerType => value.asInstanceOf[IntegerType.JvmType]
case LongType => toLong(value)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 43ca359b51..bc108e37df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -219,8 +219,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit =
updateField(fieldIndex, value.getBytes)
- protected[parquet] def updateString(fieldIndex: Int, value: String): Unit =
- updateField(fieldIndex, value)
+ protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit =
+ updateField(fieldIndex, UTF8String(value))
protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit =
updateField(fieldIndex, readTimestamp(value))
@@ -418,8 +418,8 @@ private[parquet] class CatalystPrimitiveRowConverter(
override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit =
current.update(fieldIndex, value.getBytes)
- override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit =
- current.setString(fieldIndex, value)
+ override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit =
+ current.update(fieldIndex, UTF8String(value))
override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit =
current.update(fieldIndex, readTimestamp(value))
@@ -475,19 +475,18 @@ private[parquet] class CatalystPrimitiveConverter(
private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int)
extends CatalystPrimitiveConverter(parent, fieldIndex) {
- private[this] var dict: Array[String] = null
+ private[this] var dict: Array[Array[Byte]] = null
override def hasDictionarySupport: Boolean = true
override def setDictionary(dictionary: Dictionary):Unit =
- dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).toStringUsingUTF8}
-
+ dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes }
override def addValueFromDictionary(dictionaryId: Int): Unit =
parent.updateString(fieldIndex, dict(dictionaryId))
override def addBinary(value: Binary): Unit =
- parent.updateString(fieldIndex, value.toStringUsingUTF8)
+ parent.updateString(fieldIndex, value.getBytes)
}
private[parquet] object CatalystArrayConverter {
@@ -714,9 +713,9 @@ private[parquet] class CatalystNativeArrayConverter(
elements += 1
}
- override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = {
+ override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = {
checkGrowBuffer()
- buffer(elements) = value.asInstanceOf[NativeType]
+ buffer(elements) = UTF8String(value).asInstanceOf[NativeType]
elements += 1
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
index 0357dcc468..5eb1c6abc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -55,7 +55,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
- Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+ Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
@@ -76,7 +76,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
- Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+ Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
@@ -94,7 +94,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -111,7 +111,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -128,7 +128,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -145,7 +145,7 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 5a1b15490d..e05a4c20b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -198,10 +198,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
if (value != null) {
schema match {
case StringType => writer.addBinary(
- Binary.fromByteArray(
- value.asInstanceOf[String].getBytes("utf-8")
- )
- )
+ Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes))
case BinaryType => writer.addBinary(
Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
case IntegerType => writer.addInteger(value.asInstanceOf[Int])
@@ -349,7 +346,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
index: Int): Unit = {
ctype match {
case StringType => writer.addBinary(
- Binary.fromByteArray(record(index).asInstanceOf[String].getBytes("utf-8")))
+ Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes))
case BinaryType => writer.addBinary(
Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
case IntegerType => writer.addInteger(record.getInt(index))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 20fdf5e58e..af7b3c81ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -33,7 +33,6 @@ import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext}
-
import parquet.filter2.predicate.FilterApi
import parquet.format.converter.ParquetMetadataConverter
import parquet.hadoop.metadata.CompressionCodecName
@@ -45,13 +44,13 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD}
-import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions}
import org.apache.spark.sql.parquet.ParquetTypesConverter._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _}
import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode}
-import org.apache.spark.{Logging, Partition => SparkPartition, SerializableWritable, SparkException, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, SparkException, TaskContext, Partition => SparkPartition}
/**
* Allows creation of Parquet based tables using the syntax:
@@ -409,6 +408,9 @@ private[sql] case class ParquetRelation2(
file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
}
+ // Skip type conversion
+ override val needConversion: Boolean = false
+
// TODO Should calculate per scan size
// It's common that a query only scans a fraction of a large Parquet file. Returning size of the
// whole Parquet file disables some optimizations in this case (e.g. broadcast join).
@@ -550,7 +552,8 @@ private[sql] case class ParquetRelation2(
baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) =>
val partValues = selectedPartitions.collectFirst {
- case p if split.getPath.getParent.toString == p.path => p.values
+ case p if split.getPath.getParent.toString == p.path =>
+ CatalystTypeConverters.convertToCatalyst(p.values).asInstanceOf[Row]
}.get
val requiredPartOrdinal = partitionKeyLocations.keys.toSeq
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 34d048e426..b3d71f687a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{UTF8String, StringType}
import org.apache.spark.sql.{Row, Strategy, execution, sources}
/**
@@ -53,7 +54,7 @@ private[sql] object DataSourceStrategy extends Strategy {
(a, _) => t.buildScan(a)) :: Nil
case l @ LogicalRelation(t: TableScan) =>
- execution.PhysicalRDD(l.output, t.buildScan()) :: Nil
+ createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil
case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty =>
@@ -102,20 +103,30 @@ private[sql] object DataSourceStrategy extends Strategy {
projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above.
.map(relation.attributeMap) // Match original case of attributes.
- val scan =
- execution.PhysicalRDD(
- projectList.map(_.toAttribute),
+ val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute),
scanBuilder(requestedColumns, pushedFilters))
filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)
} else {
val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq
- val scan =
- execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters))
+ val scan = createPhysicalRDD(relation.relation, requestedColumns,
+ scanBuilder(requestedColumns, pushedFilters))
execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan))
}
}
+ private[this] def createPhysicalRDD(
+ relation: BaseRelation,
+ output: Seq[Attribute],
+ rdd: RDD[Row]): SparkPlan = {
+ val converted = if (relation.needConversion) {
+ execution.RDDConversions.rowToRowRdd(rdd, relation.schema)
+ } else {
+ rdd
+ }
+ execution.PhysicalRDD(output, converted)
+ }
+
/**
* Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s,
* and convert them.
@@ -167,14 +178,14 @@ private[sql] object DataSourceStrategy extends Strategy {
case expressions.Not(child) =>
translate(child).map(sources.Not)
- case expressions.StartsWith(a: Attribute, Literal(v: String, StringType)) =>
- Some(sources.StringStartsWith(a.name, v))
+ case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
+ Some(sources.StringStartsWith(a.name, v.toString))
- case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) =>
- Some(sources.StringEndsWith(a.name, v))
+ case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
+ Some(sources.StringEndsWith(a.name, v.toString))
- case expressions.Contains(a: Attribute, Literal(v: String, StringType)) =>
- Some(sources.StringContains(a.name, v))
+ case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) =>
+ Some(sources.StringContains(a.name, v.toString))
case _ => None
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 8f9946a5a8..ca53dcdb92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -126,6 +126,16 @@ abstract class BaseRelation {
* could lead to execution plans that are suboptimal (i.e. broadcasting a very large table).
*/
def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes
+
+ /**
+ * Whether it is necessary to convert the objects in Row to the internal representation, for example:
+ * java.lang.String -> UTF8String
+ * java.math.BigDecimal -> Decimal
+ *
+ * Note: The internal representation is not stable across releases and thus data sources outside
+ * of Spark SQL should leave this as true.
+ */
+ def needConversion: Boolean = true
}
/**
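To show how needConversion is meant to be used by third-party sources (a hedged sketch; the relation class, column name, and data below are invented), an external relation that produces ordinary java.lang.String values simply keeps the default needConversion = true and lets DataSourceStrategy.createPhysicalRDD convert its rows, whereas the built-in JSON, JDBC, and Parquet relations in this patch override it to false because they already emit internal types:

// Hypothetical external data source; names and data are made up.
// It returns ordinary java.lang.String values, so it relies on the
// default needConversion = true and lets Spark convert them to the
// internal representation (e.g. UTF8String) before execution.
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

case class WordsRelation(words: Seq[String])(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  override val schema: StructType = StructType(StructField("word", StringType) :: Nil)

  // needConversion is intentionally not overridden: the internal
  // representation is not stable across releases, so external sources
  // should keep the default (true).

  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(words.map(Row(_)))
}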
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 36465cc2fa..bf6cf1321a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -30,7 +30,7 @@ class RowSuite extends FunSuite {
test("create row") {
val expected = new GenericMutableRow(4)
expected.update(0, 2147483647)
- expected.update(1, "this is a string")
+ expected.setString(1, "this is a string")
expected.update(2, false)
expected.update(3, null)
val actual1 = Row(2147483647, "this is a string", false, null)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0174aaee94..4c48dca444 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,18 +17,14 @@
package org.apache.spark.sql
-import org.apache.spark.sql.execution.GeneratedAggregate
-import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.TestData._
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types._
-
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
+import org.apache.spark.sql.types._
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 5f08834f73..c86ef338fc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -65,7 +65,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
checkActualSize(BOOLEAN, true, 1)
- checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
+ checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length)
checkActualSize(DATE, 0, 4)
checkActualSize(TIMESTAMP, new Timestamp(0L), 12)
@@ -108,8 +108,8 @@ class ColumnTypeSuite extends FunSuite with Logging {
testNativeColumnType[StringType.type](
STRING,
- (buffer: ByteBuffer, string: String) => {
- val bytes = string.getBytes("utf-8")
+ (buffer: ByteBuffer, string: UTF8String) => {
+ val bytes = string.getBytes
buffer.putInt(bytes.length)
buffer.put(bytes)
},
@@ -117,7 +117,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes)
- new String(bytes, "utf-8")
+ UTF8String(bytes)
})
testColumnType[BinaryType.type, Array[Byte]](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index b301818a00..f76314b9da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -24,7 +24,7 @@ import scala.util.Random
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{Decimal, DataType, NativeType}
+import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, NativeType}
object ColumnarTestUtils {
def makeNullRow(length: Int): GenericMutableRow = {
@@ -48,7 +48,7 @@ object ColumnarTestUtils {
case FLOAT => Random.nextFloat()
case DOUBLE => Random.nextDouble()
case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
- case STRING => Random.nextString(Random.nextInt(32))
+ case STRING => UTF8String(Random.nextString(Random.nextInt(32)))
case BOOLEAN => Random.nextBoolean()
case BINARY => randomBytes(Random.nextInt(32))
case DATE => Random.nextInt()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 60c8c00bda..3b47b8adf3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -74,7 +74,7 @@ case class AllDataTypesScan(
i.toDouble,
new java.math.BigDecimal(i),
new java.math.BigDecimal(i),
- new Date((i + 1) * 8640000),
+ new Date(1970, 1, 1),
new Timestamp(20000 + i),
s"varchar_$i",
Seq(i, i + 1),
@@ -82,7 +82,7 @@ case class AllDataTypesScan(
Map(i -> i.toString),
Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
Row(i, i.toString),
- Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+ Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1)))))
}
}
}
@@ -103,7 +103,7 @@ class TableScanSuite extends DataSourceTest {
i.toDouble,
new java.math.BigDecimal(i),
new java.math.BigDecimal(i),
- new Date((i + 1) * 8640000),
+ new Date(1970, 1, 1),
new Timestamp(20000 + i),
s"varchar_$i",
Seq(i, i + 1),
@@ -111,7 +111,7 @@ class TableScanSuite extends DataSourceTest {
Map(i -> i.toString),
Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
Row(i, i.toString),
- Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+ Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1)))))
}.toSeq
before {
@@ -266,7 +266,7 @@ class TableScanSuite extends DataSourceTest {
sqlTest(
"SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema",
- (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq)
+ (1 to 10).map(i => Row(Seq(new Date(1970, 1, i + 1)))).toSeq)
test("Caching") {
// Cached Query Execution