From 201236628a344194f7c20ba8e9afeeaefbe9318c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 24 Feb 2015 10:52:18 -0800 Subject: [SPARK-5532][SQL] Repartition should not use external rdd representation Author: Michael Armbrust Closes #4738 from marmbrus/udtRepart and squashes the following commits: c06d7b5 [Michael Armbrust] fix compilation 91c8829 [Michael Armbrust] [SQL][SPARK-5532] Repartition should not use external rdd representation --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 5 +++-- .../org/apache/spark/sql/execution/debug/package.scala | 1 + .../org/apache/spark/sql/UserDefinedTypeSuite.scala | 16 +++++++++++++++- 3 files changed, 19 insertions(+), 3 deletions(-) (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 27ac398063..04bf5d9b0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -799,14 +799,15 @@ class DataFrame protected[sql]( * Returns the number of rows in the [[DataFrame]]. * @group action */ - override def count(): Long = groupBy().count().rdd.collect().head.getLong(0) + override def count(): Long = groupBy().count().collect().head.getLong(0) /** * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. * @group rdd */ override def repartition(numPartitions: Int): DataFrame = { - sqlContext.createDataFrame(rdd.repartition(numPartitions), schema) + sqlContext.createDataFrame( + queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 73162b22fa..ffe388cfa9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -168,6 +168,7 @@ package object debug { case (_: Short, ShortType) => case (_: Boolean, BooleanType) => case (_: Double, DoubleType) => + case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType) case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 5f21d990e2..9c098df24c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql +import java.io.File + import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -91,4 +92,17 @@ class UserDefinedTypeSuite extends QueryTest { sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } + + + test("UDTs with Parquet") { + val tempDir = File.createTempFile("parquet", "test") + tempDir.delete() + pointsRDD.saveAsParquetFile(tempDir.getCanonicalPath) + } + + test("Repartition UDTs with Parquet") { + val tempDir = File.createTempFile("parquet", "test") + tempDir.delete() + pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath) + } } -- cgit v1.2.3