path: root/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 185
1 file changed, 162 insertions(+), 23 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 703ea4d149..e216945fbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,8 +22,10 @@ import java.io.CharArrayWriter
import scala.collection.JavaConverters._
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
import com.fasterxml.jackson.core.JsonFactory
+import org.apache.commons.lang3.StringUtils
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
@@ -39,11 +41,12 @@ import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.usePrettyExpression
-import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.execution.python.EvaluatePython
+import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, StreamingRelation}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -150,10 +153,10 @@ private[sql] object Dataset {
* @since 1.6.0
*/
class Dataset[T] private[sql](
- @transient override val sqlContext: SQLContext,
- @DeveloperApi @transient override val queryExecution: QueryExecution,
+ @transient val sqlContext: SQLContext,
+ @DeveloperApi @transient val queryExecution: QueryExecution,
encoder: Encoder[T])
- extends Queryable with Serializable {
+ extends Serializable {
queryExecution.assertAnalyzed()
@@ -224,7 +227,7 @@ class Dataset[T] private[sql](
* @param _numRows Number of rows to show
* @param truncate Whether to truncate long strings and align cells right
*/
- override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
+ private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
val numRows = _numRows.max(0)
val takeResult = take(numRows + 1)
val hasMoreData = takeResult.length > numRows
@@ -249,7 +252,75 @@ class Dataset[T] private[sql](
}: Seq[String]
}
- formatString ( rows, numRows, hasMoreData, truncate )
+ val sb = new StringBuilder
+ val numCols = schema.fieldNames.length
+
+ // Initialise the width of each column to a minimum value of '3'
+ val colWidths = Array.fill(numCols)(3)
+
+ // Compute the width of each column
+ for (row <- rows) {
+ for ((cell, i) <- row.zipWithIndex) {
+ colWidths(i) = math.max(colWidths(i), cell.length)
+ }
+ }
+
+ // Create the separator line
+ val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
+
+ // column names
+ rows.head.zipWithIndex.map { case (cell, i) =>
+ if (truncate) {
+ StringUtils.leftPad(cell, colWidths(i))
+ } else {
+ StringUtils.rightPad(cell, colWidths(i))
+ }
+ }.addString(sb, "|", "|", "|\n")
+
+ sb.append(sep)
+
+ // data
+ rows.tail.map {
+ _.zipWithIndex.map { case (cell, i) =>
+ if (truncate) {
+ StringUtils.leftPad(cell.toString, colWidths(i))
+ } else {
+ StringUtils.rightPad(cell.toString, colWidths(i))
+ }
+ }.addString(sb, "|", "|", "|\n")
+ }
+
+ sb.append(sep)
+
+ // For Data that has more than "numRows" records
+ if (hasMoreData) {
+ val rowsString = if (numRows == 1) "row" else "rows"
+ sb.append(s"only showing top $numRows $rowsString\n")
+ }
+
+ sb.toString()
+ }
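The layout above pads each cell to its column's width (never less than 3 characters), left-padding when truncate is on so cells align right. A minimal standalone sketch of that padding scheme, assuming commons-lang3 is on the classpath; the row data and names (cells, widths, lines) are illustrative, not part of the patch:

import org.apache.commons.lang3.StringUtils

val cells  = Seq(Seq("name", "age"), Seq("Alice", "1"), Seq("Christopher", "30"))
// Each column is at least 3 characters wide, then grows to fit its widest cell.
val widths = cells.transpose.map(col => math.max(3, col.map(_.length).max))
val sep    = widths.map("-" * _).mkString("+", "+", "+")
val lines  = cells.map { row =>
  row.zip(widths).map { case (c, w) => StringUtils.leftPad(c, w) }.mkString("|", "|", "|")
}
println((Seq(sep, lines.head, sep) ++ lines.tail ++ Seq(sep)).mkString("\n"))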
+
+ override def toString: String = {
+ try {
+ val builder = new StringBuilder
+ val fields = schema.take(2).map {
+ case f => s"${f.name}: ${f.dataType.simpleString(2)}"
+ }
+ builder.append("[")
+ builder.append(fields.mkString(", "))
+ if (schema.length > 2) {
+ if (schema.length - fields.size == 1) {
+ builder.append(" ... 1 more field")
+ } else {
+ builder.append(" ... " + (schema.length - 2) + " more fields")
+ }
+ }
+ builder.append("]").toString()
+ } catch {
+ case NonFatal(e) =>
+ s"Invalid tree; ${e.getMessage}:\n$queryExecution"
+ }
}
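For illustration, the new toString spells out only the first two fields of a wide schema and counts the rest; a sketch of the expected shape, assuming an in-scope sqlContext:

val wide = sqlContext.range(1).selectExpr("id as a", "'x' as b", "1.0 as c", "id as d")
// Only the first two fields are printed; the remainder are summarized:
// wide.toString == "[a: bigint, b: string ... 2 more fields]"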
/**
@@ -325,7 +396,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
// scalastyle:off println
- override def printSchema(): Unit = println(schema.treeString)
+ def printSchema(): Unit = println(schema.treeString)
// scalastyle:on println
/**
@@ -334,7 +405,7 @@ class Dataset[T] private[sql](
* @group basic
* @since 1.6.0
*/
- override def explain(extended: Boolean): Unit = {
+ def explain(extended: Boolean): Unit = {
val explain = ExplainCommand(queryExecution.logical, extended = extended)
sqlContext.executePlan(explain).executedPlan.executeCollect().foreach {
// scalastyle:off println
@@ -349,7 +420,7 @@ class Dataset[T] private[sql](
* @group basic
* @since 1.6.0
*/
- override def explain(): Unit = explain(extended = false)
+ def explain(): Unit = explain(extended = false)
/**
* Returns all column names and their data types as an array.
@@ -379,6 +450,22 @@ class Dataset[T] private[sql](
def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation]
/**
+ * Returns true if this [[Dataset]] contains one or more sources that continuously
+ * return data as it arrives. A [[Dataset]] that reads data from a streaming source
+ * must be executed as a [[ContinuousQuery]] using the `startStream()` method in
+ * [[DataFrameWriter]]. Methods that return a single answer (e.g., `count()` or
+ * `collect()`) will throw an [[AnalysisException]] when there is a streaming
+ * source present.
+ *
+ * @group basic
+ * @since 2.0.0
+ */
+ @Experimental
+ def isStreaming: Boolean = logicalPlan.find { n =>
+ n.isInstanceOf[StreamingRelation] || n.isInstanceOf[StreamingExecutionRelation]
+ }.isDefined
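A hedged usage sketch (a batch Dataset built from a local collection is used for illustration; a Dataset wired to a streaming source would report true):

import sqlContext.implicits._

val batch = Seq(1, 2, 3).toDS()
assert(!batch.isStreaming)   // no StreamingRelation in this logical plan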
+
+ /**
* Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated,
* and all cells will be aligned right. For example:
* {{{
@@ -678,7 +765,8 @@ class Dataset[T] private[sql](
implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
- withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
+
+ withTypedPlan {
Project(
leftData :: rightData :: Nil,
joined.analyzed)
@@ -1404,6 +1492,8 @@ class Dataset[T] private[sql](
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
*
+ * For Java API, use [[randomSplitAsList]].
+ *
* @group typedrel
* @since 2.0.0
*/
@@ -1422,6 +1512,20 @@ class Dataset[T] private[sql](
}
/**
+ * Returns a Java list that contains the randomly split [[Dataset]]s with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @param seed Seed for sampling.
+ *
+ * @group typedrel
+ * @since 2.0.0
+ */
+ def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = {
+ val values = randomSplit(weights, seed)
+ java.util.Arrays.asList(values : _*)
+ }
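A usage sketch of the Java-friendly wrapper, assuming an in-scope sqlContext; the 80/20 weights and the seed are illustrative values:

val parts = sqlContext.range(100).randomSplitAsList(Array(0.8, 0.2), seed = 42L)
assert(parts.size == 2)      // a java.util.List, convenient to consume from Java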
+
+ /**
* Randomly splits this [[Dataset]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
@@ -1790,7 +1894,13 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
+ def filter(func: T => Boolean): Dataset[T] = {
+ val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
+ val condition = Invoke(function, "apply", BooleanType, deserialized.output)
+ val filter = Filter(condition, deserialized)
+ withTypedPlan(CatalystSerde.serialize[T](filter))
+ }
/**
* :: Experimental ::
@@ -1801,7 +1911,13 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
+ def filter(func: FilterFunction[T]): Dataset[T] = {
+ val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
+ val condition = Invoke(function, "call", BooleanType, deserialized.output)
+ val filter = Filter(condition, deserialized)
+ withTypedPlan(CatalystSerde.serialize[T](filter))
+ }
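Both overloads now build the same deserialize, Invoke, Filter, re-serialize plan instead of going through mapPartitions. A usage sketch, assuming sqlContext.implicits._ is imported:

import org.apache.spark.api.java.function.FilterFunction

val nums  = Seq(1, 2, 3, 4).toDS()
val evens = nums.filter((x: Int) => x % 2 == 0)              // Scala closure overload
val same  = nums.filter(new FilterFunction[Int] {
  override def call(x: Int): Boolean = x % 2 == 0            // Java-friendly overload
})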
/**
* :: Experimental ::
@@ -1812,7 +1928,9 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+ def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
+ MapElements[T, U](func, logicalPlan)
+ }
/**
* :: Experimental ::
@@ -1823,8 +1941,10 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
- map(t => func.call(t))(encoder)
+ def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ implicit val uEnc = encoder
+ withTypedPlan(MapElements[T, U](func, logicalPlan))
+ }
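A usage sketch of the Java-friendly map overload; here ds is assumed to be a Dataset[String], and Encoders.INT supplies the explicit encoder that is bound as the implicit uEnc above:

import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.Encoders

val lengths = ds.map(
  new MapFunction[String, Integer] {
    override def call(s: String): Integer = s.length         // boxed via Predef.int2Integer
  },
  Encoders.INT)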
/**
* :: Experimental ::
@@ -1987,6 +2107,24 @@ class Dataset[T] private[sql](
}
/**
+ * Returns an iterator that contains all of the rows in this [[Dataset]].
+ *
+ * The iterator will consume as much memory as the largest partition in this [[Dataset]].
+ *
+ * Note: this results in multiple Spark jobs, and if the input Dataset is the result
+ * of a wide transformation (e.g. a join with different partitioners), then to avoid
+ * recomputing, the input Dataset should be cached first.
+ *
+ * @group action
+ * @since 2.0.0
+ */
+ def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ =>
+ withNewExecutionId {
+ queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava
+ }
+ }
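A hedged usage sketch; per the note above, ds is assumed to be cached so each partition is not recomputed as it is pulled to the driver:

import scala.collection.JavaConverters._

ds.cache()
val it = ds.toLocalIterator().asScala        // java.util.Iterator[T] -> scala Iterator[T]
it.foreach(println)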
+
+ /**
* Returns the number of rows in the [[Dataset]].
* @group action
* @since 1.6.0
@@ -2007,7 +2145,7 @@ class Dataset[T] private[sql](
/**
* Returns a new [[Dataset]] partitioned by the given partitioning expressions into
- * `numPartitions`. The resulting Datasetis hash partitioned.
+ * `numPartitions`. The resulting Dataset is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
@@ -2230,6 +2368,12 @@ class Dataset[T] private[sql](
}
}
+ protected[sql] def toPythonIterator(): Int = {
+ withNewExecutionId {
+ PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
+ }
+ }
+
////////////////////////////////////////////////////////////////////////////
// Private Helpers
////////////////////////////////////////////////////////////////////////////
@@ -2300,12 +2444,7 @@ class Dataset[T] private[sql](
}
/** A convenient function to wrap a logical plan and produce a Dataset. */
- @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
- new Dataset[T](sqlContext, logicalPlan, encoder)
+ @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+ Dataset(sqlContext, logicalPlan)
}
-
- private[sql] def withTypedPlan[R](
- other: Dataset[_], encoder: Encoder[R])(
- f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
- new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
}