aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2015-12-08 10:25:57 -0800
committerMichael Armbrust <michael@databricks.com>2015-12-08 10:25:57 -0800
commit5d96a710a5ed543ec81e383620fc3b2a808b26a1 (patch)
treefd82e4613e93cc75859d3052c9730d62b119e844 /sql
parentc0b13d5565c45ae2acbe8cfb17319c92b6a634e4 (diff)
downloadspark-5d96a710a5ed543ec81e383620fc3b2a808b26a1.tar.gz
spark-5d96a710a5ed543ec81e383620fc3b2a808b26a1.tar.bz2
spark-5d96a710a5ed543ec81e383620fc3b2a808b26a1.zip
[SPARK-12188][SQL] Code refactoring and comment correction in Dataset APIs
This PR contains the following updates: - Created a new private variable `boundTEncoder` that can be shared by multiple functions, `RDD`, `select` and `collect`. - Replaced all the `queryExecution.analyzed` by the function call `logicalPlan` - A few API comments are using wrong class names (e.g., `DataFrame`) or parameter names (e.g., `n`) - A few API descriptions are wrong. (e.g., `mapPartitions`) marmbrus rxin cloud-fan Could you take a look and check if they are appropriate? Thank you! Author: gatorsmile <gatorsmile@gmail.com> Closes #10184 from gatorsmile/datasetClean.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala80
1 files changed, 40 insertions, 40 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d6bb1d2ad8..3bd18a14f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -67,15 +67,21 @@ class Dataset[T] private[sql](
tEncoder: Encoder[T]) extends Queryable with Serializable {
/**
- * An unresolved version of the internal encoder for the type of this dataset. This one is marked
- * implicit so that we can use it when constructing new [[Dataset]] objects that have the same
- * object type (that will be possibly resolved to a different schema).
+ * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
+ * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
+ * same object type (that will be possibly resolved to a different schema).
*/
private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
- unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes)
+ unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
+
+ /**
+ * The encoder where the expressions used to construct an object from an input row have been
+ * bound to the ordinals of the given schema.
+ */
+ private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
private implicit def classTag = resolvedTEncoder.clsTag
@@ -89,7 +95,7 @@ class Dataset[T] private[sql](
override def schema: StructType = resolvedTEncoder.schema
/**
- * Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format.
+ * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format.
* @since 1.6.0
*/
override def printSchema(): Unit = toDF().printSchema()
@@ -111,7 +117,7 @@ class Dataset[T] private[sql](
* ************* */
/**
- * Returns a new `Dataset` where each record has been mapped on to the specified type. The
+ * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The
* method used to map columns depend on the type of `U`:
* - When `U` is a class, fields for the class will be mapped to columns of the same name
* (case sensitivity is determined by `spark.sql.caseSensitive`)
@@ -145,7 +151,7 @@ class Dataset[T] private[sql](
def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)
/**
- * Returns this Dataset.
+ * Returns this [[Dataset]].
* @since 1.6.0
*/
// This is declared with parentheses to prevent the Scala compiler from treating
@@ -153,15 +159,12 @@ class Dataset[T] private[sql](
def toDS(): Dataset[T] = this
/**
- * Converts this Dataset to an RDD.
+ * Converts this [[Dataset]] to an [[RDD]].
* @since 1.6.0
*/
def rdd: RDD[T] = {
- val tEnc = resolvedTEncoder
- val input = queryExecution.analyzed.output
queryExecution.toRdd.mapPartitions { iter =>
- val bound = tEnc.bind(input)
- iter.map(bound.fromRow)
+ iter.map(boundTEncoder.fromRow)
}
}
@@ -189,7 +192,7 @@ class Dataset[T] private[sql](
def show(numRows: Int): Unit = show(numRows, truncate = true)
/**
- * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters
+ * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters
* will be truncated, and all cells will be aligned right.
*
* @since 1.6.0
@@ -197,7 +200,7 @@ class Dataset[T] private[sql](
def show(): Unit = show(20)
/**
- * Displays the top 20 rows of [[DataFrame]] in a tabular form.
+ * Displays the top 20 rows of [[Dataset]] in a tabular form.
*
* @param truncate Whether truncate long strings. If true, strings more than 20 characters will
* be truncated and all cells will be aligned right
@@ -207,7 +210,7 @@ class Dataset[T] private[sql](
def show(truncate: Boolean): Unit = show(20, truncate)
/**
- * Displays the [[DataFrame]] in a tabular form. For example:
+ * Displays the [[Dataset]] in a tabular form. For example:
* {{{
* year month AVG('Adj Close) MAX('Adj Close)
* 1980 12 0.503218 0.595103
@@ -291,7 +294,7 @@ class Dataset[T] private[sql](
/**
* (Scala-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
* @since 1.6.0
*/
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
@@ -307,7 +310,7 @@ class Dataset[T] private[sql](
/**
* (Java-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
* @since 1.6.0
*/
def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
@@ -341,28 +344,28 @@ class Dataset[T] private[sql](
/**
* (Scala-specific)
- * Runs `func` on each element of this Dataset.
+ * Runs `func` on each element of this [[Dataset]].
* @since 1.6.0
*/
def foreach(func: T => Unit): Unit = rdd.foreach(func)
/**
* (Java-specific)
- * Runs `func` on each element of this Dataset.
+ * Runs `func` on each element of this [[Dataset]].
* @since 1.6.0
*/
def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
/**
* (Scala-specific)
- * Runs `func` on each partition of this Dataset.
+ * Runs `func` on each partition of this [[Dataset]].
* @since 1.6.0
*/
def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
/**
* (Java-specific)
- * Runs `func` on each partition of this Dataset.
+ * Runs `func` on each partition of this [[Dataset]].
* @since 1.6.0
*/
def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
@@ -374,7 +377,7 @@ class Dataset[T] private[sql](
/**
* (Scala-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given function
+ * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
@@ -382,7 +385,7 @@ class Dataset[T] private[sql](
/**
* (Java-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given function
+ * Reduces the elements of this Dataset using the specified binary function. The given `func`
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
@@ -390,11 +393,11 @@ class Dataset[T] private[sql](
/**
* (Scala-specific)
- * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+ * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
* @since 1.6.0
*/
def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
- val inputPlan = queryExecution.analyzed
+ val inputPlan = logicalPlan
val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
val executed = sqlContext.executePlan(withGroupingKey)
@@ -429,18 +432,18 @@ class Dataset[T] private[sql](
/**
* (Java-specific)
- * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+ * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
* @since 1.6.0
*/
- def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
- groupBy(f.call(_))(encoder)
+ def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+ groupBy(func.call(_))(encoder)
/* ****************** *
* Typed Relational *
* ****************** */
/**
- * Selects a set of column based expressions.
+ * Returns a new [[DataFrame]] by selecting a set of column based expressions.
* {{{
* df.select($"colA", $"colB" + 1)
* }}}
@@ -464,8 +467,8 @@ class Dataset[T] private[sql](
sqlContext,
Project(
c1.withInputType(
- resolvedTEncoder.bind(queryExecution.analyzed.output),
- queryExecution.analyzed.output).named :: Nil,
+ boundTEncoder,
+ logicalPlan.output).named :: Nil,
logicalPlan))
}
@@ -477,7 +480,7 @@ class Dataset[T] private[sql](
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
- columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named)
+ columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
@@ -654,7 +657,7 @@ class Dataset[T] private[sql](
* Returns an array that contains all the elements in this [[Dataset]].
*
* Running collect requires moving all the data into the application's driver process, and
- * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+ * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
* @since 1.6.0
@@ -662,17 +665,14 @@ class Dataset[T] private[sql](
def collect(): Array[T] = {
// This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
// to convert the rows into objects of type T.
- val tEnc = resolvedTEncoder
- val input = queryExecution.analyzed.output
- val bound = tEnc.bind(input)
- queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
+ queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
}
/**
* Returns an array that contains all the elements in this [[Dataset]].
*
* Running collect requires moving all the data into the application's driver process, and
- * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+ * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
* @since 1.6.0
@@ -683,7 +683,7 @@ class Dataset[T] private[sql](
* Returns the first `num` elements of this [[Dataset]] as an array.
*
* Running take requires moving data into the application's driver process, and doing so with
- * a very large `n` can crash the driver process with OutOfMemoryError.
+ * a very large `num` can crash the driver process with OutOfMemoryError.
* @since 1.6.0
*/
def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
@@ -692,7 +692,7 @@ class Dataset[T] private[sql](
* Returns the first `num` elements of this [[Dataset]] as an array.
*
* Running take requires moving data into the application's driver process, and doing so with
- * a very large `n` can crash the driver process with OutOfMemoryError.
+ * a very large `num` can crash the driver process with OutOfMemoryError.
* @since 1.6.0
*/
def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)