author     Wenchen Fan <cloud0fan@outlook.com>       2015-06-25 22:44:26 -0700
committer  Davies Liu <davies@databricks.com>        2015-06-25 22:44:26 -0700
commit     1a79f0eb8da7e850c443383b3bb24e0bf8e1e7cb
tree       85061123f7de88e73e317304ee8fec392f54e7db
parent     40360112c417b5432564f4bcb8a9100f4066b55e
[SPARK-8635] [SQL] improve performance of CatalystTypeConverters
In `CatalystTypeConverters.createToCatalystConverter`, we add special handling for primitive types. We can apply this strategy to more places to improve performance.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7018 from cloud-fan/converter and squashes the following commits:

8b16630 [Wenchen Fan] another fix
326c82c [Wenchen Fan] optimize type converter
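The gist of the change, as a minimal self-contained Scala sketch (this is not Spark's code; ToyDataType, ConverterSketch and the other names below are illustrative only): when a type and everything nested inside it is primitive, the converter can simply be the identity function, so no per-element work happens on the hot path. The real patch applies the same test (isWholePrimitive) inside the array and map converters and in createToScalaConverter.

// A toy model of the short-circuit this commit applies: detect data types
// whose values need no per-element conversion and hand them back unchanged,
// instead of mapping a converter over every element of every value.
sealed trait ToyDataType
case object ToyIntType extends ToyDataType
case object ToyStringType extends ToyDataType
case class ToyArrayType(elementType: ToyDataType) extends ToyDataType

object ConverterSketch {
  private def isPrimitive(dt: ToyDataType): Boolean = dt == ToyIntType

  // Analogue of the commit's isWholePrimitive: true when the whole value tree
  // consists of primitives, so conversion would be the identity anyway.
  private def isWholePrimitive(dt: ToyDataType): Boolean = dt match {
    case dt if isPrimitive(dt)     => true
    case ToyArrayType(elementType) => isWholePrimitive(elementType)
    case _                         => false
  }

  // Stand-in for real per-element conversion work.
  private val slowConvert: Any => Any = {
    case s: Seq[_] => s.map(_.toString)
    case other     => other
  }

  // Build the converter once per type; wholly-primitive types get `identity`.
  def toScalaConverter(dt: ToyDataType): Any => Any =
    if (isWholePrimitive(dt)) identity else slowConvert

  def main(args: Array[String]): Unit = {
    val convert = toScalaConverter(ToyArrayType(ToyIntType))
    println(convert(Seq(1, 2, 3))) // fast path: Seq(1, 2, 3) returned as-is
  }
}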
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala  60
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala     3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                             4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala          2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala          2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala            2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala                      4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala                4
8 files changed, 48 insertions, 33 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 429fc4077b..012f8bbecb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -52,6 +52,13 @@ object CatalystTypeConverters {
}
}
+ private def isWholePrimitive(dt: DataType): Boolean = dt match {
+ case dt if isPrimitive(dt) => true
+ case ArrayType(elementType, _) => isWholePrimitive(elementType)
+ case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType)
+ case _ => false
+ }
+
private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
val converter = dataType match {
case udt: UserDefinedType[_] => UDTConverter(udt)
@@ -148,6 +155,8 @@ object CatalystTypeConverters {
private[this] val elementConverter = getConverterForType(elementType)
+ private[this] val isNoChange = isWholePrimitive(elementType)
+
override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
scalaValue match {
case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
@@ -166,8 +175,10 @@ object CatalystTypeConverters {
override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
if (catalystValue == null) {
null
+ } else if (isNoChange) {
+ catalystValue
} else {
- catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala)
+ catalystValue.map(elementConverter.toScala)
}
}
@@ -183,6 +194,8 @@ object CatalystTypeConverters {
private[this] val keyConverter = getConverterForType(keyType)
private[this] val valueConverter = getConverterForType(valueType)
+ private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType)
+
override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match {
case m: Map[_, _] =>
m.map { case (k, v) =>
@@ -203,6 +216,8 @@ object CatalystTypeConverters {
override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = {
if (catalystValue == null) {
null
+ } else if (isNoChange) {
+ catalystValue
} else {
catalystValue.map { case (k, v) =>
keyConverter.toScala(k) -> valueConverter.toScala(v)
@@ -258,16 +273,13 @@ object CatalystTypeConverters {
toScala(row(column).asInstanceOf[InternalRow])
}
- private object StringConverter extends CatalystTypeConverter[Any, String, Any] {
+ private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] {
override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match {
case str: String => UTF8String.fromString(str)
case utf8: UTF8String => utf8
}
- override def toScala(catalystValue: Any): String = catalystValue match {
- case null => null
- case str: String => str
- case utf8: UTF8String => utf8.toString()
- }
+ override def toScala(catalystValue: UTF8String): String =
+ if (catalystValue == null) null else catalystValue.toString
override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString
}
@@ -275,7 +287,8 @@ object CatalystTypeConverters {
override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue)
override def toScala(catalystValue: Any): Date =
if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int])
- override def toScalaImpl(row: InternalRow, column: Int): Date = toScala(row.getInt(column))
+ override def toScalaImpl(row: InternalRow, column: Int): Date =
+ DateTimeUtils.toJavaDate(row.getInt(column))
}
private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] {
@@ -285,7 +298,7 @@ object CatalystTypeConverters {
if (catalystValue == null) null
else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long])
override def toScalaImpl(row: InternalRow, column: Int): Timestamp =
- toScala(row.getLong(column))
+ DateTimeUtils.toJavaTimestamp(row.getLong(column))
}
private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
@@ -296,10 +309,7 @@ object CatalystTypeConverters {
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
- row.get(column) match {
- case d: JavaBigDecimal => d
- case d: Decimal => d.toJavaBigDecimal
- }
+ row.get(column).asInstanceOf[Decimal].toJavaBigDecimal
}
private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
@@ -363,6 +373,19 @@ object CatalystTypeConverters {
}
/**
+ * Creates a converter function that will convert Catalyst types to Scala type.
+ * Typical use case would be converting a collection of rows that have the same schema. You will
+ * call this function once to get a converter, and apply it to every row.
+ */
+ private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
+ if (isPrimitive(dataType)) {
+ identity
+ } else {
+ getConverterForType(dataType).toScala
+ }
+ }
+
+ /**
* Converts Scala objects to Catalyst rows / types.
*
* Note: This should be called before do evaluation on Row
@@ -389,15 +412,6 @@ object CatalystTypeConverters {
* produced by createToScalaConverter.
*/
def convertToScala(catalystValue: Any, dataType: DataType): Any = {
- getConverterForType(dataType).toScala(catalystValue)
- }
-
- /**
- * Creates a converter function that will convert Catalyst types to Scala type.
- * Typical use case would be converting a collection of rows that have the same schema. You will
- * call this function once to get a converter, and apply it to every row.
- */
- private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
- getConverterForType(dataType).toScala
+ createToScalaConverter(dataType)(catalystValue)
}
}
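The new createToScalaConverter doc comment above describes a create-once, apply-per-row pattern; the hypothetical usage sketch below mirrors the DataFrame.rdd change later in this patch. Because the method is private[sql], the sketch assumes it is compiled inside the org.apache.spark.sql package, and it assumes InternalRow lives in org.apache.spark.sql.catalyst at this revision; ToScalaConverterUsage and toRows are made-up names.

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.types.StructType

object ToScalaConverterUsage {
  // Build the converter once for the schema, then reuse it for every row,
  // as the doc comment on createToScalaConverter recommends.
  def toRows(schema: StructType, internalRows: Iterator[InternalRow]): Iterator[Row] = {
    val converter = CatalystTypeConverters.createToScalaConverter(schema)
    internalRows.map(converter(_).asInstanceOf[Row])
  }
}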
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 3992f1f59d..55df72f102 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.DataType
@@ -39,7 +38,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
(1 to 22).map { x =>
val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _)
val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _)
- lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _)
+ val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _)
val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _)
s"""case $x =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index f3f0f53053..0db4df34f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1418,12 +1418,14 @@ class DataFrame private[sql](
lazy val rdd: RDD[Row] = {
// use a local variable to make sure the map closure doesn't capture the whole DataFrame
val schema = this.schema
- queryExecution.executedPlan.execute().mapPartitions { rows =>
+ internalRowRdd.mapPartitions { rows =>
val converter = CatalystTypeConverters.createToScalaConverter(schema)
rows.map(converter(_).asInstanceOf[Row])
}
}
+ private[sql] def internalRowRdd = queryExecution.executedPlan.execute()
+
/**
* Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s.
* @group rdd
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 8df1da037c..3ebbf96090 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -90,7 +90,7 @@ private[sql] object FrequentItems extends Logging {
(name, originalSchema.fields(index).dataType)
}
- val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
+ val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)(
seqOp = (counts, row) => {
var i = 0
while (i < numCols) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 93383e5a62..252c611d02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -81,7 +81,7 @@ private[sql] object StatFunctions extends Logging {
s"with dataType ${data.get.dataType} not supported.")
}
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
- df.select(columns: _*).rdd.aggregate(new CovarianceCounter)(
+ df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)(
seqOp = (counter, row) => {
counter.add(row.getDouble(0), row.getDouble(1))
},
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index a8f56f4767..ce16e050c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -313,7 +313,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
output: Seq[Attribute],
rdd: RDD[Row]): RDD[InternalRow] = {
if (relation.relation.needConversion) {
- execution.RDDConversions.rowToRowRdd(rdd.asInstanceOf[RDD[Row]], output.map(_.dataType))
+ execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
} else {
rdd.map(_.asInstanceOf[InternalRow])
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index fb6173f58e..dbb369cf45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -154,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
writerContainer.driverSideSetup()
try {
- df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
+ df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _)
writerContainer.commitJob()
relation.refresh()
} catch { case cause: Throwable =>
@@ -220,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
writerContainer.driverSideSetup()
try {
- df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
+ df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _)
writerContainer.commitJob()
relation.refresh()
} catch { case cause: Throwable =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 79eac930e5..de0ed0c042 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -88,9 +88,9 @@ case class AllDataTypesScan(
UTF8String.fromString(s"varchar_$i"),
Seq(i, i + 1),
Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
- Map(i -> i.toString),
+ Map(i -> UTF8String.fromString(i.toString)),
Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
- Row(i, i.toString),
+ Row(i, UTF8String.fromString(i.toString)),
Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
}