diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-06-25 22:44:26 -0700 |
---|---|---|
committer | Davies Liu <davies@databricks.com> | 2015-06-25 22:44:26 -0700 |
commit | 1a79f0eb8da7e850c443383b3bb24e0bf8e1e7cb (patch) | |
tree | 85061123f7de88e73e317304ee8fec392f54e7db /sql | |
parent | 40360112c417b5432564f4bcb8a9100f4066b55e (diff) | |
download | spark-1a79f0eb8da7e850c443383b3bb24e0bf8e1e7cb.tar.gz spark-1a79f0eb8da7e850c443383b3bb24e0bf8e1e7cb.tar.bz2 spark-1a79f0eb8da7e850c443383b3bb24e0bf8e1e7cb.zip |
[SPARK-8635] [SQL] improve performance of CatalystTypeConverters
In `CatalystTypeConverters.createToCatalystConverter`, we add special handling for primitive types. We can apply this strategy to more places to improve performance.
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #7018 from cloud-fan/converter and squashes the following commits:
8b16630 [Wenchen Fan] another fix
326c82c [Wenchen Fan] optimize type converter
Diffstat (limited to 'sql')
8 files changed, 48 insertions, 33 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 429fc4077b..012f8bbecb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -52,6 +52,13 @@ object CatalystTypeConverters { } } + private def isWholePrimitive(dt: DataType): Boolean = dt match { + case dt if isPrimitive(dt) => true + case ArrayType(elementType, _) => isWholePrimitive(elementType) + case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType) + case _ => false + } + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { val converter = dataType match { case udt: UserDefinedType[_] => UDTConverter(udt) @@ -148,6 +155,8 @@ object CatalystTypeConverters { private[this] val elementConverter = getConverterForType(elementType) + private[this] val isNoChange = isWholePrimitive(elementType) + override def toCatalystImpl(scalaValue: Any): Seq[Any] = { scalaValue match { case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) @@ -166,8 +175,10 @@ object CatalystTypeConverters { override def toScala(catalystValue: Seq[Any]): Seq[Any] = { if (catalystValue == null) { null + } else if (isNoChange) { + catalystValue } else { - catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala) + catalystValue.map(elementConverter.toScala) } } @@ -183,6 +194,8 @@ object CatalystTypeConverters { private[this] val keyConverter = getConverterForType(keyType) private[this] val valueConverter = getConverterForType(valueType) + private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType) + override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match { case m: Map[_, _] => m.map { case (k, v) => @@ -203,6 +216,8 @@ object CatalystTypeConverters { override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = { if (catalystValue == null) { null + } else if (isNoChange) { + catalystValue } else { catalystValue.map { case (k, v) => keyConverter.toScala(k) -> valueConverter.toScala(v) @@ -258,16 +273,13 @@ object CatalystTypeConverters { toScala(row(column).asInstanceOf[InternalRow]) } - private object StringConverter extends CatalystTypeConverter[Any, String, Any] { + private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 } - override def toScala(catalystValue: Any): String = catalystValue match { - case null => null - case str: String => str - case utf8: UTF8String => utf8.toString() - } + override def toScala(catalystValue: UTF8String): String = + if (catalystValue == null) null else catalystValue.toString override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString } @@ -275,7 +287,8 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue) override def toScala(catalystValue: Any): Date = if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int]) - override def toScalaImpl(row: InternalRow, column: Int): Date = toScala(row.getInt(column)) + override def toScalaImpl(row: InternalRow, column: Int): Date = + DateTimeUtils.toJavaDate(row.getInt(column)) } private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] { @@ -285,7 +298,7 @@ object CatalystTypeConverters { if (catalystValue == null) null else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long]) override def toScalaImpl(row: InternalRow, column: Int): Timestamp = - toScala(row.getLong(column)) + DateTimeUtils.toJavaTimestamp(row.getLong(column)) } private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { @@ -296,10 +309,7 @@ object CatalystTypeConverters { } override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = - row.get(column) match { - case d: JavaBigDecimal => d - case d: Decimal => d.toJavaBigDecimal - } + row.get(column).asInstanceOf[Decimal].toJavaBigDecimal } private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { @@ -363,6 +373,19 @@ object CatalystTypeConverters { } /** + * Creates a converter function that will convert Catalyst types to Scala type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + identity + } else { + getConverterForType(dataType).toScala + } + } + + /** * Converts Scala objects to Catalyst rows / types. * * Note: This should be called before do evaluation on Row @@ -389,15 +412,6 @@ object CatalystTypeConverters { * produced by createToScalaConverter. */ def convertToScala(catalystValue: Any, dataType: DataType): Any = { - getConverterForType(dataType).toScala(catalystValue) - } - - /** - * Creates a converter function that will convert Catalyst types to Scala type. - * Typical use case would be converting a collection of rows that have the same schema. You will - * call this function once to get a converter, and apply it to every row. - */ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { - getConverterForType(dataType).toScala + createToScalaConverter(dataType)(catalystValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 3992f1f59d..55df72f102 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types.DataType @@ -39,7 +38,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) - lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) + val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _) s"""case $x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f3f0f53053..0db4df34f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1418,12 +1418,14 @@ class DataFrame private[sql]( lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - queryExecution.executedPlan.execute().mapPartitions { rows => + internalRowRdd.mapPartitions { rows => val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]) } } + private[sql] def internalRowRdd = queryExecution.executedPlan.execute() + /** * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 8df1da037c..3ebbf96090 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -90,7 +90,7 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) } - val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 93383e5a62..252c611d02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -81,7 +81,7 @@ private[sql] object StatFunctions extends Logging { s"with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) - df.select(columns: _*).rdd.aggregate(new CovarianceCounter)( + df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)( seqOp = (counter, row) => { counter.add(row.getDouble(0), row.getDouble(1)) }, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index a8f56f4767..ce16e050c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -313,7 +313,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { output: Seq[Attribute], rdd: RDD[Row]): RDD[InternalRow] = { if (relation.relation.needConversion) { - execution.RDDConversions.rowToRowRdd(rdd.asInstanceOf[RDD[Row]], output.map(_.dataType)) + execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { rdd.map(_.asInstanceOf[InternalRow]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index fb6173f58e..dbb369cf45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -154,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) + df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => @@ -220,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) + df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 79eac930e5..de0ed0c042 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -88,9 +88,9 @@ case class AllDataTypesScan( UTF8String.fromString(s"varchar_$i"), Seq(i, i + 1), Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), - Map(i -> i.toString), + Map(i -> UTF8String.fromString(i.toString)), Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), - Row(i, i.toString), + Row(i, UTF8String.fromString(i.toString)), Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) } |