Diffstat (limited to 'sql')
7 files changed, 94 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 191d16fb10..4def65b01f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -57,6 +57,7 @@ trait ScalaReflection {
     case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
     case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
     case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
+    case (s: Array[_], arrayType: ArrayType) => s.toSeq
     case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
       convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
     }
@@ -140,7 +141,9 @@ trait ScalaReflection {
       // Need to decide if we actually need a special type here.
       case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
       case t if t <:< typeOf[Array[_]] =>
-        sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+        val TypeRef(_, _, Seq(elementType)) = t
+        val Schema(dataType, nullable) = schemaFor(elementType)
+        Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
       case t if t <:< typeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, nullable) = schemaFor(elementType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 5138942a55..4a66716e0a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -60,6 +60,7 @@ case class OptionalData(
 
 case class ComplexData(
     arrayField: Seq[Int],
+    arrayField1: Array[Int],
     arrayFieldContainsNull: Seq[java.lang.Integer],
     mapField: Map[Int, Long],
     mapFieldValueContainsNull: Map[Int, java.lang.Long],
@@ -132,6 +133,10 @@ class ScalaReflectionSuite extends FunSuite {
         ArrayType(IntegerType, containsNull = false),
         nullable = true),
       StructField(
+        "arrayField1",
+        ArrayType(IntegerType, containsNull = false),
+        nullable = true),
+      StructField(
         "arrayFieldContainsNull",
         ArrayType(IntegerType, containsNull = true),
         nullable = true),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 7f9a91a032..9be2a03afa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -22,15 +22,19 @@ import scala.language.implicitConversions
 
 import org.apache.spark.sql.api.scala.dsl.lit
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
 import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
 import org.apache.spark.sql.types._
 
 
 object Column {
-  def unapply(col: Column): Option[Expression] = Some(col.expr)
-
+  /**
+   * Creates a [[Column]] based on the given column name.
+   * Same as [[api.scala.dsl.col]] and [[api.java.dsl.col]].
+   */
   def apply(colName: String): Column = new Column(colName)
+
+  /** For internal pattern matching. */
+  private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr)
 }
 
@@ -438,7 +442,7 @@ class Column(
    * @param ordinal
    * @return
    */
-  override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal))
+  override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
 
   /**
    * An expression that gets a field by name in a [[StructField]].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index ceb5f86bef..050366aea8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -118,8 +118,8 @@ class DataFrame protected[sql](
 
   /** Resolves a column name into a Catalyst [[NamedExpression]]. */
   protected[sql] def resolve(colName: String): NamedExpression = {
-    logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(
-      throw new RuntimeException(s"""Cannot resolve column name "$colName""""))
+    logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(throw new RuntimeException(
+      s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})"""))
   }
 
   /** Left here for compatibility reasons. */
@@ -131,6 +131,29 @@ class DataFrame protected[sql](
    */
   def toDataFrame: DataFrame = this
 
+  /**
+   * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
+   * from a RDD of tuples into a [[DataFrame]] with meaningful names. For example:
+   * {{{
+   *   val rdd: RDD[(Int, String)] = ...
+   *   rdd.toDataFrame  // this implicit conversion creates a DataFrame with column name _1 and _2
+   *   rdd.toDataFrame("id", "name")  // this creates a DataFrame with column name "id" and "name"
+   * }}}
+   */
+  @scala.annotation.varargs
+  def toDataFrame(colName: String, colNames: String*): DataFrame = {
+    val newNames = colName +: colNames
+    require(schema.size == newNames.size,
+      "The number of columns doesn't match.\n" +
+        "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
+        "New column names: " + newNames.mkString(", "))
+
+    val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) =>
+      apply(oldName).as(newName)
+    }
+    select(newCols :_*)
+  }
+
   /** Returns the schema of this [[DataFrame]]. */
   override def schema: StructType = queryExecution.analyzed.schema
 
@@ -227,7 +250,7 @@ class DataFrame protected[sql](
   }
 
   /**
-   * Selects a single column and return it as a [[Column]].
+   * Selects column based on the column name and return it as a [[Column]].
    */
   override def apply(colName: String): Column = colName match {
     case "*" =>
@@ -467,6 +490,12 @@ class DataFrame protected[sql](
   }
 
   /**
+   * Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
+   * and then flattening the results.
+   */
+  override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
+
+  /**
    * Returns a new RDD by applying a function to each partition of this DataFrame.
    */
   override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
@@ -474,6 +503,16 @@ class DataFrame protected[sql](
   }
 
   /**
+   * Applies a function `f` to all rows.
+   */
+  override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
+
+  /**
+   * Applies a function f to each partition of this [[DataFrame]].
+   */
+  override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
+
+  /**
    * Returns the first `n` rows in the [[DataFrame]].
    */
   override def take(n: Int): Array[Row] = head(n)
@@ -520,7 +559,7 @@ class DataFrame protected[sql](
   /////////////////////////////////////////////////////////////////////////////
 
   /**
-   * Return the content of the [[DataFrame]] as a [[RDD]] of [[Row]]s.
+   * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
    */
   override def rdd: RDD[Row] = {
     val schema = this.schema
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
index 5eeaf17d71..59634082f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -44,8 +44,14 @@ private[sql] trait RDDApi[T] {
 
   def map[R: ClassTag](f: T => R): RDD[R]
 
+  def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R]
+
   def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R]
 
+  def foreach(f: T => Unit): Unit
+
+  def foreachPartition(f: Iterator[T] => Unit): Unit
+
   def take(n: Int): Array[T]
 
   def collect(): Array[T]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
index 74d7649e08..16702afdb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
@@ -33,6 +33,13 @@ public class dsl {
   private static package$ scalaDsl = package$.MODULE$;
 
   /**
+   * Returns a {@link Column} based on the given column name.
+   */
+  public static Column col(String colName) {
+    return new Column(colName);
+  }
+
+  /**
    * Creates a column of literal value.
    */
   public static Column lit(Object literalValue) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
index 9f2d1427d4..dc851fc504 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.api.scala
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{TypeTag, typeTag}
 
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions._
@@ -37,6 +38,21 @@ package object dsl {
   /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
   implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
 
+//  /**
+//   * An implicit conversion that turns a RDD of product into a [[DataFrame]].
+//   *
+//   * This method requires an implicit SQLContext in scope. For example:
+//   * {{{
+//   *   implicit val sqlContext: SQLContext = ...
+//   *   val rdd: RDD[(Int, String)] = ...
+//   *   rdd.toDataFrame  // triggers the implicit here
+//   * }}}
+//   */
+//  implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext)
+//    : DataFrame = {
+//    context.createDataFrame(rdd)
+//  }
+
   /** Converts $"col name" into an [[Column]]. */
   implicit class StringToColumn(val sc: StringContext) extends AnyVal {
     def $(args: Any*): ColumnName = {
@@ -47,6 +63,11 @@ package object dsl {
   private[this] implicit def toColumn(expr: Expression): Column = new Column(expr)
 
   /**
+   * Returns a [[Column]] based on the given column name.
+   */
+  def col(colName: String): Column = new Column(colName)
+
+  /**
    * Creates a [[Column]] of literal value.
    */
   def lit(literal: Any): Column = {
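
Usage note: the varargs toDataFrame overload added to DataFrame.scala above renames all columns positionally, and its require() fails fast when the number of new names differs from the schema. A minimal sketch, assuming an existing SparkContext sc and SQLContext sqlContext (the rddToDataFrame implicit is still commented out in this commit, so createDataFrame is called explicitly):

    val rdd = sc.parallelize(Seq((1, "alice"), (2, "bob")))
    val df = sqlContext.createDataFrame(rdd)  // tuple columns default to _1, _2
    val named = df.toDataFrame("id", "name")  // columns renamed to id, name
    // named.toDataFrame("id") would throw an IllegalArgumentException from the
    // require(), listing both the old and the new column names in the message.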
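
The col function is now exposed in three equivalent spellings: Column(colName) on the companion object, col in the Scala DSL, and col in the Java DSL, while the new flatMap, foreach and foreachPartition methods on DataFrame delegate straight to rdd. A sketch, assuming a DataFrame df whose first column is a string:

    import org.apache.spark.sql.api.scala.dsl._

    val c = col("text")  // equivalent to Column("text") or $"text"
    val words = df.flatMap(row => row.getString(0).split(" "))
    df.foreach(row => println(row))
    df.foreachPartition(rows => println(rows.length))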
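
With the ScalaReflection change above, schemaFor maps any Array[T] field to an ArrayType whose containsNull follows the element type's nullability; previously anything other than Array[Byte] hit sys.error. A sketch mirroring the arrayField1: Array[Int] case added to ScalaReflectionSuite, where Record is a hypothetical class and schemaFor is assumed to be reachable via the catalyst ScalaReflection object as in the suite:

    import org.apache.spark.sql.catalyst.ScalaReflection._

    case class Record(id: Int, scores: Array[Double])

    // Expected: the scores field becomes ArrayType(DoubleType, containsNull = false),
    // just like Array[Int] -> ArrayType(IntegerType, containsNull = false) in the test.
    val Schema(dataType, nullable) = schemaFor[Record]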