Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 185 |
1 file changed, 162 insertions(+), 23 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 703ea4d149..e216945fbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,8 +22,10 @@ import java.io.CharArrayWriter
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
 
 import com.fasterxml.jackson.core.JsonFactory
+import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
@@ -39,11 +41,12 @@ import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
-import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
 import org.apache.spark.sql.execution.python.EvaluatePython
+import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, StreamingRelation}
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -150,10 +153,10 @@ private[sql] object Dataset {
  * @since 1.6.0
  */
 class Dataset[T] private[sql](
-    @transient override val sqlContext: SQLContext,
-    @DeveloperApi @transient override val queryExecution: QueryExecution,
+    @transient val sqlContext: SQLContext,
+    @DeveloperApi @transient val queryExecution: QueryExecution,
     encoder: Encoder[T])
-  extends Queryable with Serializable {
+  extends Serializable {
 
   queryExecution.assertAnalyzed()
 
@@ -224,7 +227,7 @@ class Dataset[T] private[sql](
    * @param _numRows Number of rows to show
    * @param truncate Whether truncate long strings and align cells right
    */
-  override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
+  private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
     val numRows = _numRows.max(0)
     val takeResult = take(numRows + 1)
     val hasMoreData = takeResult.length > numRows
@@ -249,7 +252,75 @@ class Dataset[T] private[sql](
       }: Seq[String]
     }
 
-    formatString ( rows, numRows, hasMoreData, truncate )
+    val sb = new StringBuilder
+    val numCols = schema.fieldNames.length
+
+    // Initialise the width of each column to a minimum value of '3'
+    val colWidths = Array.fill(numCols)(3)
+
+    // Compute the width of each column
+    for (row <- rows) {
+      for ((cell, i) <- row.zipWithIndex) {
+        colWidths(i) = math.max(colWidths(i), cell.length)
+      }
+    }
+
+    // Create SeparateLine
+    val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
+
+    // column names
+    rows.head.zipWithIndex.map { case (cell, i) =>
+      if (truncate) {
+        StringUtils.leftPad(cell, colWidths(i))
+      } else {
+        StringUtils.rightPad(cell, colWidths(i))
+      }
+    }.addString(sb, "|", "|", "|\n")
+
+    sb.append(sep)
+
+    // data
+    rows.tail.map {
+      _.zipWithIndex.map { case (cell, i) =>
+        if (truncate) {
+          StringUtils.leftPad(cell.toString, colWidths(i))
+        } else {
+          StringUtils.rightPad(cell.toString, colWidths(i))
+        }
+      }.addString(sb, "|", "|", "|\n")
+    }
+
+    sb.append(sep)
+
+    // For Data that has more than "numRows" records
+    if (hasMoreData) {
+      val rowsString = if (numRows == 1) "row" else "rows"
+      sb.append(s"only showing top $numRows $rowsString\n")
+    }
+
+    sb.toString()
+  }
+
+  override def toString: String = {
+    try {
+      val builder = new StringBuilder
+      val fields = schema.take(2).map {
+        case f => s"${f.name}: ${f.dataType.simpleString(2)}"
+      }
+      builder.append("[")
+      builder.append(fields.mkString(", "))
+      if (schema.length > 2) {
+        if (schema.length - fields.size == 1) {
+          builder.append(" ... 1 more field")
+        } else {
+          builder.append(" ... " + (schema.length - 2) + " more fields")
+        }
+      }
+      builder.append("]").toString()
+    } catch {
+      case NonFatal(e) =>
+        s"Invalid tree; ${e.getMessage}:\n$queryExecution"
+    }
   }
 
   /**
@@ -325,7 +396,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   // scalastyle:off println
-  override def printSchema(): Unit = println(schema.treeString)
+  def printSchema(): Unit = println(schema.treeString)
   // scalastyle:on println
 
   /**
@@ -334,7 +405,7 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 1.6.0
    */
-  override def explain(extended: Boolean): Unit = {
+  def explain(extended: Boolean): Unit = {
     val explain = ExplainCommand(queryExecution.logical, extended = extended)
     sqlContext.executePlan(explain).executedPlan.executeCollect().foreach {
       // scalastyle:off println
@@ -349,7 +420,7 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 1.6.0
    */
-  override def explain(): Unit = explain(extended = false)
+  def explain(): Unit = explain(extended = false)
 
   /**
    * Returns all column names and their data types as an array.
@@ -379,6 +450,22 @@ class Dataset[T] private[sql](
   def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation]
 
   /**
+   * Returns true if this [[Dataset]] contains one or more sources that continuously
+   * return data as it arrives. A [[Dataset]] that reads data from a streaming source
+   * must be executed as a [[ContinuousQuery]] using the `startStream()` method in
+   * [[DataFrameWriter]]. Methods that return a single answer, (e.g., `count()` or
+   * `collect()`) will throw an [[AnalysisException]] when there is a streaming
+   * source present.
+   *
+   * @group basic
+   * @since 2.0.0
+   */
+  @Experimental
+  def isStreaming: Boolean = logicalPlan.find { n =>
+    n.isInstanceOf[StreamingRelation] || n.isInstanceOf[StreamingExecutionRelation]
+  }.isDefined
+
+  /**
    * Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated,
    * and all cells will be aligned right. For example:
    * {{{
@@ -678,7 +765,8 @@ class Dataset[T] private[sql](
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
 
-    withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
+
+    withTypedPlan {
       Project(
         leftData :: rightData :: Nil,
         joined.analyzed)
@@ -1404,6 +1492,8 @@ class Dataset[T] private[sql](
    * @param weights weights for splits, will be normalized if they don't sum to 1.
    * @param seed Seed for sampling.
    *
+   * For Java API, use [[randomSplitAsList]].
+   *
    * @group typedrel
    * @since 2.0.0
    */
@@ -1422,6 +1512,20 @@ class Dataset[T] private[sql](
   }
 
   /**
+   * Returns a Java list that contains randomly split [[Dataset]] with the provided weights.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1.
+   * @param seed Seed for sampling.
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = {
+    val values = randomSplit(weights, seed)
+    java.util.Arrays.asList(values : _*)
+  }
+
+  /**
   * Randomly splits this [[Dataset]] with the provided weights.
   *
   * @param weights weights for splits, will be normalized if they don't sum to 1.
@@ -1790,7 +1894,13 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
+  def filter(func: T => Boolean): Dataset[T] = {
+    val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+    val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
+    val condition = Invoke(function, "apply", BooleanType, deserialized.output)
+    val filter = Filter(condition, deserialized)
+    withTypedPlan(CatalystSerde.serialize[T](filter))
+  }
 
   /**
    * :: Experimental ::
@@ -1801,7 +1911,13 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
+  def filter(func: FilterFunction[T]): Dataset[T] = {
+    val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+    val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
+    val condition = Invoke(function, "call", BooleanType, deserialized.output)
+    val filter = Filter(condition, deserialized)
+    withTypedPlan(CatalystSerde.serialize[T](filter))
+  }
 
   /**
    * :: Experimental ::
@@ -1812,7 +1928,9 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+  def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
+    MapElements[T, U](func, logicalPlan)
+  }
 
   /**
    * :: Experimental ::
@@ -1823,8 +1941,10 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
-    map(t => func.call(t))(encoder)
+  def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+    implicit val uEnc = encoder
+    withTypedPlan(MapElements[T, U](func, logicalPlan))
+  }
 
   /**
    * :: Experimental ::
@@ -1987,6 +2107,24 @@ class Dataset[T] private[sql](
   }
 
   /**
+   * Return an iterator that contains all of [[Row]]s in this [[Dataset]].
+   *
+   * The iterator will consume as much memory as the largest partition in this [[Dataset]].
+   *
+   * Note: this results in multiple Spark jobs, and if the input Dataset is the result
+   * of a wide transformation (e.g. join with different partitioners), to avoid
+   * recomputing the input Dataset should be cached first.
+   *
+   * @group action
+   * @since 2.0.0
+   */
+  def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ =>
+    withNewExecutionId {
+      queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava
+    }
+  }
+
+  /**
    * Returns the number of rows in the [[Dataset]].
    * @group action
    * @since 1.6.0
@@ -2007,7 +2145,7 @@ class Dataset[T] private[sql](
 
   /**
    * Returns a new [[Dataset]] partitioned by the given partitioning expressions into
-   * `numPartitions`. The resulting Datasetis hash partitioned.
+   * `numPartitions`. The resulting Dataset is hash partitioned.
    *
    * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
    *
@@ -2230,6 +2368,12 @@ class Dataset[T] private[sql](
     }
   }
 
+  protected[sql] def toPythonIterator(): Int = {
+    withNewExecutionId {
+      PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
+    }
+  }
+
   ////////////////////////////////////////////////////////////////////////////
   // Private Helpers
   ////////////////////////////////////////////////////////////////////////////
@@ -2300,12 +2444,7 @@ class Dataset[T] private[sql](
   }
 
   /** A convenient function to wrap a logical plan and produce a Dataset. */
-  @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
-    new Dataset[T](sqlContext, logicalPlan, encoder)
+  @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+    Dataset(sqlContext, logicalPlan)
   }
-
-  private[sql] def withTypedPlan[R](
-      other: Dataset[_], encoder: Encoder[R])(
-      f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
-    new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
 }
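The new showString and toString bodies in the diff above fold the old Queryable formatting helpers into Dataset itself: per-cell column widths, header and data rows padded with Commons Lang StringUtils, and a schema summary that prints at most two fields. A minimal sketch of how this surfaces at the call site, assuming a spark-shell style session where sqlContext and its implicits are in scope (the case class and values are illustrative only):

    import sqlContext.implicits._

    // Hypothetical three-column Dataset used only for illustration.
    case class Person(name: String, age: Int, city: String)
    val people = Seq(Person("Alice", 29, "Berlin"), Person("Bob", 31, "Paris")).toDS()

    // toString keeps the first two fields and summarizes the rest:
    // prints "[name: string, age: int ... 1 more field]"
    println(people.toString)

    // show() routes through showString: with truncate = true cells are
    // left-padded (right-aligned); with truncate = false they are right-padded.
    people.show(20, truncate = true)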
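isStreaming simply scans the logical plan for StreamingRelation or StreamingExecutionRelation nodes, so it can be checked before any action runs. A hedged sketch, same assumed session as above; the read.stream(...) call reflects the streaming reader API at this stage of 2.0 development and is an assumption here:

    // A plain batch Dataset contains no streaming sources.
    val batch = sqlContext.range(0, 10)
    assert(!batch.isStreaming)

    // A Dataset backed by a streaming source reports true; it must be started via
    // DataFrameWriter.startStream() rather than evaluated with count()/collect().
    // NOTE: the reader call below is an assumption for illustration.
    val streamed = sqlContext.read.format("text").stream("/tmp/streaming-input")
    assert(streamed.isStreaming)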
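randomSplitAsList wraps randomSplit for Java callers; both normalize the weights when they do not sum to 1 and take an explicit seed for reproducibility. A small usage sketch, same assumed session:

    val ds = sqlContext.range(0, 1000)

    // Scala API: an Array of Datasets, destructured here into train/test.
    val Array(train, test) = ds.randomSplit(Array(0.8, 0.2), seed = 42L)

    // Java-friendly API: the same split returned as a java.util.List.
    val splits = ds.randomSplitAsList(Array(0.8, 0.2), 42L)
    assert(splits.size() == 2)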
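The typed filter and map overloads above no longer delegate to mapPartitions; the user function is embedded in the plan, as a Filter over an Invoke of the deserialized object or as a MapElements node, so the optimizer can see the serialize/deserialize boundaries. Nothing changes at the call site, but the difference shows up in explain() output (the plan node names below are indicative, not verbatim):

    import sqlContext.implicits._

    case class Event(user: String, clicks: Int)
    val events = Seq(Event("a", 3), Event("b", 0)).toDS()

    // Same user-facing API as before.
    val active = events.filter(_.clicks > 0).map(_.user)

    // explain(true) should now show deserialize / Filter(Invoke(...)) /
    // MapElements / serialize style nodes rather than opaque MapPartitions calls.
    active.explain(true)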
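toLocalIterator pulls results back one partition at a time, so driver memory is bounded by the largest partition rather than the whole result, at the cost of one job per partition; as the scaladoc above notes, cache the Dataset first when its lineage is expensive to recompute. A usage sketch, same assumed session:

    import scala.collection.JavaConverters._

    val big = sqlContext.range(0, 1000000)

    // Cache so the per-partition jobs do not recompute the full lineage each time.
    big.cache()

    // Iterate on the driver without materializing everything at once.
    val it = big.toLocalIterator().asScala
    it.take(5).foreach(println)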