Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala      |  5
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                            | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                         | 47
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api.scala                               |  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java                       |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala             | 21
7 files changed, 94 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 191d16fb10..4def65b01f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -57,6 +57,7 @@ trait ScalaReflection {
case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
+ case (s: Array[_], arrayType: ArrayType) => s.toSeq
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
}
@@ -140,7 +141,9 @@ trait ScalaReflection {
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
- sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+ val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
+ Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
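With this change, schemaFor maps Array[T] (other than Array[Byte]) to an ArrayType just as it does for Seq[T], and convertToCatalyst turns Scala arrays into sequences. A minimal sketch of the resulting behaviour; the case class Record and its fields are illustrative only, not part of this patch:

    import org.apache.spark.sql.catalyst.ScalaReflection

    // Hypothetical case class with an Array field.
    case class Record(id: Int, scores: Array[Double])

    // Previously this raised "Only Array[Byte] supported now, use Seq instead";
    // it now produces an ArrayType field for `scores`.
    val schema = ScalaReflection.schemaFor[Record]
    // schema.dataType is a StructType whose "scores" field has type
    // ArrayType(DoubleType, containsNull = false); schema.nullable is true.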
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 5138942a55..4a66716e0a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -60,6 +60,7 @@ case class OptionalData(
case class ComplexData(
arrayField: Seq[Int],
+ arrayField1: Array[Int],
arrayFieldContainsNull: Seq[java.lang.Integer],
mapField: Map[Int, Long],
mapFieldValueContainsNull: Map[Int, java.lang.Long],
@@ -132,6 +133,10 @@ class ScalaReflectionSuite extends FunSuite {
ArrayType(IntegerType, containsNull = false),
nullable = true),
StructField(
+ "arrayField1",
+ ArrayType(IntegerType, containsNull = false),
+ nullable = true),
+ StructField(
"arrayFieldContainsNull",
ArrayType(IntegerType, containsNull = true),
nullable = true),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 7f9a91a032..9be2a03afa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -22,15 +22,19 @@ import scala.language.implicitConversions
import org.apache.spark.sql.api.scala.dsl.lit
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.types._
object Column {
- def unapply(col: Column): Option[Expression] = Some(col.expr)
-
+ /**
+ * Creates a [[Column]] based on the given column name.
+ * Same as [[api.scala.dsl.col]] and [[api.java.dsl.col]].
+ */
def apply(colName: String): Column = new Column(colName)
+
+ /** For internal pattern matching. */
+ private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr)
}
@@ -438,7 +442,7 @@ class Column(
* @param ordinal
* @return
*/
- override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal))
+ override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
/**
* An expression that gets a field by name in a [[StructField]].
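With unapply now internal, Column(colName) and the dsl col(colName) helper are the public ways to build a column by name, and getItem is expressed with the plain Literal expression. An illustrative sketch; the column names used here are assumed:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.api.scala.dsl._

    val byName: Column = Column("name")                 // same as col("name")
    val firstScore: Column = col("scores").getItem(0)   // builds GetItem(expr, Literal(0))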
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index ceb5f86bef..050366aea8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -118,8 +118,8 @@ class DataFrame protected[sql](
/** Resolves a column name into a Catalyst [[NamedExpression]]. */
protected[sql] def resolve(colName: String): NamedExpression = {
- logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(
- throw new RuntimeException(s"""Cannot resolve column name "$colName""""))
+ logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(throw new RuntimeException(
+ s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})"""))
}
/** Left here for compatibility reasons. */
@@ -131,6 +131,29 @@ class DataFrame protected[sql](
*/
def toDataFrame: DataFrame = this
+ /**
+ * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
+ * from an RDD of tuples into a [[DataFrame]] with meaningful names. For example:
+ * {{{
+ * val rdd: RDD[(Int, String)] = ...
+ * rdd.toDataFrame // this implicit conversion creates a DataFrame with column names _1 and _2
+ * rdd.toDataFrame("id", "name") // this creates a DataFrame with column names "id" and "name"
+ * }}}
+ */
+ @scala.annotation.varargs
+ def toDataFrame(colName: String, colNames: String*): DataFrame = {
+ val newNames = colName +: colNames
+ require(schema.size == newNames.size,
+ "The number of columns doesn't match.\n" +
+ "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
+ "New column names: " + newNames.mkString(", "))
+
+ val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) =>
+ apply(oldName).as(newName)
+ }
+ select(newCols :_*)
+ }
+
/** Returns the schema of this [[DataFrame]]. */
override def schema: StructType = queryExecution.analyzed.schema
@@ -227,7 +250,7 @@ class DataFrame protected[sql](
}
/**
- * Selects a single column and return it as a [[Column]].
+ * Selects a column based on the column name and returns it as a [[Column]].
*/
override def apply(colName: String): Column = colName match {
case "*" =>
@@ -467,6 +490,12 @@ class DataFrame protected[sql](
}
/**
+ * Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
+ * and then flattening the results.
+ */
+ override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
+
+ /**
* Returns a new RDD by applying a function to each partition of this DataFrame.
*/
override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
@@ -474,6 +503,16 @@ class DataFrame protected[sql](
}
/**
+ * Applies a function `f` to all rows.
+ */
+ override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
+
+ /**
+ * Applies a function `f` to each partition of this [[DataFrame]].
+ */
+ override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
+
+ /**
* Returns the first `n` rows in the [[DataFrame]].
*/
override def take(n: Int): Array[Row] = head(n)
@@ -520,7 +559,7 @@ class DataFrame protected[sql](
/////////////////////////////////////////////////////////////////////////////
/**
- * Return the content of the [[DataFrame]] as a [[RDD]] of [[Row]]s.
+ * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
*/
override def rdd: RDD[Row] = {
val schema = this.schema
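Taken together, the new DataFrame methods support positional renaming and direct RDD-style traversal. A hedged usage sketch, assuming an existing DataFrame df whose second column is a string:

    // Rename all columns positionally; the require() fails if the number of names differs.
    val renamed = df.toDataFrame("id", "name")

    // flatMap, foreach and foreachPartition simply delegate to the underlying RDD[Row].
    val tokens = renamed.flatMap(row => row.getString(1).split(" "))
    renamed.foreach(row => println(row))
    renamed.foreachPartition(rows => rows.foreach(println))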
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
index 5eeaf17d71..59634082f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -44,8 +44,14 @@ private[sql] trait RDDApi[T] {
def map[R: ClassTag](f: T => R): RDD[R]
+ def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R]
+
def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R]
+ def foreach(f: T => Unit): Unit
+
+ def foreachPartition(f: Iterator[T] => Unit): Unit
+
def take(n: Int): Array[T]
def collect(): Array[T]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
index 74d7649e08..16702afdb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
@@ -33,6 +33,13 @@ public class dsl {
private static package$ scalaDsl = package$.MODULE$;
/**
+ * Returns a {@link Column} based on the given column name.
+ */
+ public static Column col(String colName) {
+ return new Column(colName);
+ }
+
+ /**
* Creates a column of literal value.
*/
public static Column lit(Object literalValue) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
index 9f2d1427d4..dc851fc504 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.api.scala
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
@@ -37,6 +38,21 @@ package object dsl {
/** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+// /**
+// * An implicit conversion that turns an RDD of Products into a [[DataFrame]].
+// *
+// * This method requires an implicit SQLContext in scope. For example:
+// * {{{
+// * implicit val sqlContext: SQLContext = ...
+// * val rdd: RDD[(Int, String)] = ...
+// * rdd.toDataFrame // triggers the implicit here
+// * }}}
+// */
+// implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext)
+// : DataFrame = {
+// context.createDataFrame(rdd)
+// }
+
/** Converts $"col name" into an [[Column]]. */
implicit class StringToColumn(val sc: StringContext) extends AnyVal {
def $(args: Any*): ColumnName = {
@@ -47,6 +63,11 @@ package object dsl {
private[this] implicit def toColumn(expr: Expression): Column = new Column(expr)
/**
+ * Returns a [[Column]] based on the given column name.
+ */
+ def col(colName: String): Column = new Column(colName)
+
+ /**
* Creates a [[Column]] of literal value.
*/
def lit(literal: Any): Column = {
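The new col function complements lit and the existing $-interpolation. A short sketch, assuming a DataFrame df with an "age" column:

    import org.apache.spark.sql.api.scala.dsl._

    df.select(col("age") + lit(1))   // col(...) is the new name-based constructor
    df.select($"age" + 1)            // equivalent, via the existing StringContext syntax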