-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala | 38
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala | 100
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala | 173
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 72
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala | 43
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 15
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 392
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala | 68
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala | 141
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 79
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala | 103
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 124
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 8
26 files changed, 1501 insertions, 23 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 27c96f4122..713c6b547d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -411,9 +411,9 @@ trait ScalaReflection {
}
/** Returns expressions for extracting all the fields from the given type. */
- def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = {
+ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
ScalaReflectionLock.synchronized {
- extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateStruct].children
+ extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct]
}
}
@@ -497,11 +497,11 @@ trait ScalaReflection {
}
}
- CreateStruct(params.head.map { p =>
+ CreateNamedStruct(params.head.flatMap { p =>
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
- extractorFor(fieldValue, fieldType)
+ expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
})
case t if t <:< localTypeOf[Array[_]] =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
index 54096f18cb..b484b8fde6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
import scala.reflect.ClassTag
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
@@ -41,9 +41,11 @@ case class ClassEncoder[T](
clsTag: ClassTag[T])
extends Encoder[T] {
- private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+ @transient
+ private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
private val inputRow = new GenericMutableRow(1)
+ @transient
private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
private val dataType = ObjectType(clsTag.runtimeClass)
@@ -64,4 +66,36 @@ case class ClassEncoder[T](
copy(constructExpression = boundExpression)
}
+
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = {
+ val positionToAttribute = AttributeMap.toIndex(oldSchema)
+ val attributeToNewPosition = AttributeMap.byIndex(newSchema)
+ copy(constructExpression = constructExpression transform {
+ case r: BoundReference =>
+ r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
+ })
+ }
+
+ override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = {
+ var remaining = schema
+ copy(constructExpression = constructExpression transform {
+ case u: UnresolvedAttribute =>
+ val pos = remaining.head
+ remaining = remaining.drop(1)
+ pos
+ })
+ }
+
+ protected val attrs = extractExpressions.map(_.collect {
+ case a: Attribute => s"#${a.exprId}"
+ case b: BoundReference => s"[${b.ordinal}]"
+ }.headOption.getOrElse(""))
+
+
+ protected val schemaString =
+ schema
+ .zip(attrs)
+ .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
+
+ override def toString: String = s"class[$schemaString]"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index bdb1c0959d..efb872ddb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.types.StructType
* and reuse internal buffers to improve performance.
*/
trait Encoder[T] {
+
/** Returns the schema of encoding this type of object as a Row. */
def schema: StructType
@@ -46,13 +47,27 @@ trait Encoder[T] {
/**
* Returns an object of type `T`, extracting the required values from the provided row. Note that
- * you must bind the encoder to a specific schema before you can call this function.
+ * you must `bind` an encoder to a specific schema before you can call this function.
*/
def fromRow(row: InternalRow): T
/**
* Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the
- * given schema
+ * given schema.
*/
def bind(schema: Seq[Attribute]): Encoder[T]
+
+ /**
+ * Binds this encoder to the given schema positionally. In this binding, the first reference to
+ * any input is mapped to `schema(0)`, and so on for each input that is encountered.
+ */
+ def bindOrdinals(schema: Seq[Attribute]): Encoder[T]
+
+ /**
+ * Given an encoder that has already been bound to a given schema, returns a new encoder that
+ * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
+ * when you are trying to use an encoder on grouping keys that were orriginally part of a larger
+ * row, but now you have projected out only the key expressions.
+ */
+ def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T]
}
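
A rough usage sketch (not part of the patch) of the Encoder contract defined above: a ProductEncoder serializes a hypothetical case class to an InternalRow with `toRow`, is bound to its own schema's attributes with `bind`, and then reconstructs the object with `fromRow`. The `Point` class is invented for illustration, and exact binding behavior may differ from this sketch.

import org.apache.spark.sql.catalyst.encoders.ProductEncoder

case class Point(x: Int, y: Int)  // hypothetical example class

val encoder = ProductEncoder[Point]
val row = encoder.toRow(Point(1, 2))                    // object -> InternalRow (reuses an internal buffer)
val bound = encoder.bind(encoder.schema.toAttributes)   // resolve the fromRow references to a concrete schema
assert(bound.fromRow(row.copy()) == Point(1, 2))        // InternalRow -> object
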
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index 4f7ce455ad..34f5e6c030 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -31,15 +31,17 @@ import org.apache.spark.sql.types.{ObjectType, StructType}
object ProductEncoder {
def apply[T <: Product : TypeTag]: ClassEncoder[T] = {
// We convert the not-serializable TypeTag into StructType and ClassTag.
- val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
val mirror = typeTag[T].mirror
val cls = mirror.runtimeClass(typeTag[T].tpe)
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- val extractExpressions = ScalaReflection.extractorsFor[T](inputObject)
+ val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
val constructExpression = ScalaReflection.constructorFor[T]
- new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls))
- }
-
+ new ClassEncoder[T](
+ extractExpression.dataType,
+ extractExpression.flatten,
+ constructExpression,
+ ClassTag[T](cls))
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
new file mode 100644
index 0000000000..a93f2d7c61
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.sql.types._
+
+/** An encoder for primitive Long types. */
+case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] {
+ private val row = UnsafeRow.createFromByteArray(64, 1)
+
+ override def clsTag: ClassTag[Long] = ClassTag.Long
+ override def schema: StructType =
+ StructType(StructField(fieldName, LongType) :: Nil)
+
+ override def fromRow(row: InternalRow): Long = row.getLong(ordinal)
+
+ override def toRow(t: Long): InternalRow = {
+ row.setLong(ordinal, t)
+ row
+ }
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this
+ override def bind(schema: Seq[Attribute]): Encoder[Long] = this
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this
+}
+
+/** An encoder for primitive Integer types. */
+case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] {
+ private val row = UnsafeRow.createFromByteArray(64, 1)
+
+ override def clsTag: ClassTag[Int] = ClassTag.Int
+ override def schema: StructType =
+ StructType(StructField(fieldName, IntegerType) :: Nil)
+
+ override def fromRow(row: InternalRow): Int = row.getInt(ordinal)
+
+ override def toRow(t: Int): InternalRow = {
+ row.setInt(ordinal, t)
+ row
+ }
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this
+ override def bind(schema: Seq[Attribute]): Encoder[Int] = this
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this
+}
+
+/** An encoder for String types. */
+case class StringEncoder(
+ fieldName: String = "value",
+ ordinal: Int = 0) extends Encoder[String] {
+
+ val record = new SpecificMutableRow(StringType :: Nil)
+
+ @transient
+ lazy val projection =
+ GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil)
+
+ override def schema: StructType =
+ StructType(
+ StructField("value", StringType, nullable = false) :: Nil)
+
+ override def clsTag: ClassTag[String] = scala.reflect.classTag[String]
+
+
+ override final def fromRow(row: InternalRow): String = {
+ row.getString(ordinal)
+ }
+
+ override final def toRow(value: String): InternalRow = {
+ val utf8String = UTF8String.fromString(value)
+ record(0) = utf8String
+ // TODO: this is a bit of a hack to produce UnsafeRows
+ projection(record)
+ }
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this
+ override def bind(schema: Seq[Attribute]): Encoder[String] = this
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
new file mode 100644
index 0000000000..a48eeda7d2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.types.{StructField, StructType}
+
+// Most of this file is codegen.
+// scalastyle:off
+
+/**
+ * A set of composite encoders that take sub encoders and map each of their objects to a
+ * Scala tuple. Note that currently the implementation is fairly limited and only supports going
+ * from an internal row to a tuple.
+ */
+object TupleEncoder {
+
+ /** Code generator for composite tuple encoders. */
+ def main(args: Array[String]): Unit = {
+ (2 to 5).foreach { i =>
+ val types = (1 to i).map(t => s"T$t").mkString(", ")
+ val tupleType = s"($types)"
+ val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ")
+ val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ")
+ val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ")
+
+ println(
+ s"""
+ |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] {
+ | val schema = StructType(Array($fields))
+ |
+ | def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType]
+ |
+ | def fromRow(row: InternalRow): $tupleType = {
+ | ($fromRow)
+ | }
+ |
+ | override def toRow(t: $tupleType): InternalRow =
+ | throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
+ |
+ | override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = {
+ | this
+ | }
+ |
+ | override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] =
+ | throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+ |
+ |
+ | override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] =
+ | throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+ |}
+ """.stripMargin)
+ }
+ }
+}
+
+class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] {
+ val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema)))
+
+ def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)]
+
+ def fromRow(row: InternalRow): (T1, T2) = {
+ (e1.fromRow(row), e2.fromRow(row))
+ }
+
+ override def toRow(t: (T1, T2)): InternalRow =
+ throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
+
+ override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = {
+ this
+ }
+
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+}
+
+
+class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] {
+ val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema)))
+
+ def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)]
+
+ def fromRow(row: InternalRow): (T1, T2, T3) = {
+ (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row))
+ }
+
+ override def toRow(t: (T1, T2, T3)): InternalRow =
+ throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
+
+ override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = {
+ this
+ }
+
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+}
+
+
+class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] {
+ val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema)))
+
+ def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, T3, T4)]
+
+ def fromRow(row: InternalRow): (T1, T2, T3, T4) = {
+ (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row))
+ }
+
+ override def toRow(t: (T1, T2, T3, T4)): InternalRow =
+ throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
+
+ override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = {
+ this
+ }
+
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+}
+
+
+class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] {
+ val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema)))
+
+ def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)]
+
+ def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = {
+ (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row))
+ }
+
+ override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow =
+ throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
+
+ override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = {
+ this
+ }
+
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+}
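
A brief illustration (not part of the patch) of composing the primitive encoders from the previous file with a Tuple2Encoder: each sub-encoder reads its own ordinal from the same row, and only the `fromRow` direction is supported. The row is built with `InternalRow.apply` purely for demonstration.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{IntEncoder, LongEncoder, Tuple2Encoder}

val pairEncoder = new Tuple2Encoder(new IntEncoder(ordinal = 0), new LongEncoder(ordinal = 1))
val row = InternalRow(1, 2L)                 // a two-column row: [int, long]
assert(pairEncoder.fromRow(row) == (1, 2L))  // each sub-encoder decodes its ordinal into the tuple
// pairEncoder.toRow(...) would throw UnsupportedOperationException, as documented above.
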
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 96a11e352e..ef3cc554b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,6 +26,13 @@ object AttributeMap {
def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
}
+
+ /** Given a schema, constructs an [[AttributeMap]] from [[Attribute]] to ordinal */
+ def byIndex(schema: Seq[Attribute]): AttributeMap[Int] = apply(schema.zipWithIndex)
+
+ /** Given a schema, constructs a map from ordinal to Attribute. */
+ def toIndex(schema: Seq[Attribute]): Map[Int, Attribute] =
+ schema.zipWithIndex.map { case (a, i) => i -> a }.toMap
}
class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 5345696570..3831535574 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -31,6 +31,10 @@ protected class AttributeEquals(val a: Attribute) {
}
object AttributeSet {
+ /** Returns an empty [[AttributeSet]]. */
+ val empty = apply(Iterable.empty)
+
+ /** Constructs a new [[AttributeSet]] that contains a single [[Attribute]]. */
def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a)))
/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index a5f02e2463..059e45bd68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -125,6 +125,14 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
*/
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
+ /**
+ * Returns aliased [[Expression Expressions]] that could be used to construct a flattened version of this
+ * StructType.
+ */
+ def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
+ case (v, n) => Alias(v, n.toString)()
+ }
+
private lazy val (nameExprs, valExprs) =
children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
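
A small sketch (not part of the patch) of what `flatten` produces: the (name, value) pairs of a CreateNamedStruct become top-level aliased expressions, which is how ProductEncoder above builds its flat extraction schema. The literal values here are arbitrary.

import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Literal}

val struct = CreateNamedStruct(Seq(Literal("a"), Literal(1), Literal("b"), Literal(2L)))
val flatCols = struct.flatten                  // two aliased expressions, named "a" and "b"
assert(flatCols.map(_.name) == Seq("a", "b"))
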
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 30b7f8d376..f1fa13daa7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{StructField, StructType}
/**
* A set of classes that can be used to represent trees of relational expressions. A key goal of
@@ -80,4 +81,15 @@ package object expressions {
/** Uses the given row to store the output of the projection. */
def target(row: MutableRow): MutableProjection
}
+
+
+ /**
+ * Helper functions for working with `Seq[Attribute]`.
+ */
+ implicit class AttributeSeq(attrs: Seq[Attribute]) {
+ /** Creates a StructType with a schema matching this `Seq[Attribute]`. */
+ def toStructType: StructType = {
+ StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index ae9482c10f..21a55a5371 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
import org.apache.spark.sql.catalyst.plans._
@@ -417,7 +418,7 @@ case class Distinct(child: LogicalPlan) extends UnaryNode {
}
/**
- * Return a new RDD that has exactly `numPartitions` partitions. Differs from
+ * Returns a new RDD that has exactly `numPartitions` partitions. Differs from
* [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user
* asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer
* of the output requires some specific ordering or distribution of the data.
@@ -443,3 +444,72 @@ case object OneRowRelation extends LeafNode {
override def statistics: Statistics = Statistics(sizeInBytes = 1)
}
+/**
+ * A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are
+ * used respectively to decode/encode from the JVM object representation expected by `func`.
+ */
+case class MapPartitions[T, U](
+ func: Iterator[T] => Iterator[U],
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ output: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode {
+ override def missingInput: AttributeSet = AttributeSet.empty
+}
+
+/** Factory for constructing new `AppendColumn` nodes. */
+object AppendColumn {
+ def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
+ val attrs = implicitly[Encoder[U]].schema.toAttributes
+ new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child)
+ }
+}
+
+/**
+ * A relation produced by applying `func` to each partition of the `child`, concatenating the
+ * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to
+ * decode/encode from the JVM object representation expected by `func`.
+ */
+case class AppendColumn[T, U](
+ func: T => U,
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ newColumns: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output ++ newColumns
+ override def missingInput: AttributeSet = super.missingInput -- newColumns
+}
+
+/** Factory for constructing new `MapGroups` nodes. */
+object MapGroups {
+ def apply[K : Encoder, T : Encoder, U : Encoder](
+ func: (K, Iterator[T]) => Iterator[U],
+ groupingAttributes: Seq[Attribute],
+ child: LogicalPlan): MapGroups[K, T, U] = {
+ new MapGroups(
+ func,
+ implicitly[Encoder[K]],
+ implicitly[Encoder[T]],
+ implicitly[Encoder[U]],
+ groupingAttributes,
+ implicitly[Encoder[U]].schema.toAttributes,
+ child)
+ }
+}
+
+/**
+ * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
+ * Func is invoked with an object representation of the grouping key and an iterator containing the
+ * object representation of all the rows with that key.
+ */
+case class MapGroups[K, T, U](
+ func: (K, Iterator[T]) => Iterator[U],
+ kEncoder: Encoder[K],
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ groupingAttributes: Seq[Attribute],
+ output: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode {
+ override def missingInput: AttributeSet = AttributeSet.empty
+}
+
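
To make the flow concrete, here is a hedged sketch (not part of the patch) of the `AppendColumn` factory above, which `Dataset.groupBy` later in this patch relies on: the encoded result of the key function is appended to the child's output as extra columns.

import org.apache.spark.sql.catalyst.encoders.{Encoder, IntEncoder, StringEncoder}
import org.apache.spark.sql.catalyst.plans.logical.{AppendColumn, LocalRelation}

implicit val stringEnc: Encoder[String] = new StringEncoder()
implicit val intEnc: Encoder[Int] = new IntEncoder()

val child = LocalRelation(stringEnc.schema.toAttributes, Nil)  // a single string column, standing in for the child plan
val withKey = AppendColumn((s: String) => s.length, child)     // grouping key = string length, encoded as an int column
assert(withKey.output == child.output ++ withKey.newColumns)
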
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
new file mode 100644
index 0000000000..52f8383fac
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.SparkFunSuite
+
+class PrimitiveEncoderSuite extends SparkFunSuite {
+ test("long encoder") {
+ val enc = new LongEncoder()
+ val row = enc.toRow(10)
+ assert(row.getLong(0) == 10)
+ assert(enc.fromRow(row) == 10)
+ }
+
+ test("int encoder") {
+ val enc = new IntEncoder()
+ val row = enc.toRow(10)
+ assert(row.getInt(0) == 10)
+ assert(enc.fromRow(row) == 10)
+ }
+
+ test("string encoder") {
+ val enc = new StringEncoder()
+ val row = enc.toRow("test")
+ assert(row.getString(0) == "test")
+ assert(enc.fromRow(row) == "test")
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
index 02e43ddb35..7735acbcba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -248,12 +248,16 @@ class ProductEncoderSuite extends SparkFunSuite {
val types =
convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
- val encodedData = convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
- case (a: ArrayData, StructField(_, at: ArrayType, _, _)) =>
- a.toArray[Any](at.elementType).toSeq
- case (other, _) =>
- other
- }.mkString("[", ",", "]")
+ val encodedData = try {
+ convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
+ case (a: ArrayData, StructField(_, at: ArrayType, _, _)) =>
+ a.toArray[Any](at.elementType).toSeq
+ case (other, _) =>
+ other
+ }.mkString("[", ",", "]")
+ } catch {
+ case e: Throwable => s"Failed to toSeq: $e"
+ }
fail(
s"""Encoded/Decoded data does not match input data
@@ -272,8 +276,9 @@ class ProductEncoderSuite extends SparkFunSuite {
|Construct Expressions:
|${boundEncoder.constructExpression.treeString}
|
- """.stripMargin)
+ """.stripMargin)
+ }
}
- }
+
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 37d559c8e4..de11a1699a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql
+
import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._
@@ -36,6 +38,11 @@ private[sql] object Column {
def unapply(col: Column): Option[Expression] = Some(col.expr)
}
+/**
+ * A [[Column]] where an [[Encoder]] has been given for the expected return type.
+ * @since 1.6.0
+ */
+class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr)
/**
* :: Experimental ::
@@ -70,6 +77,14 @@ class Column(protected[sql] val expr: Expression) extends Logging {
override def hashCode: Int = this.expr.hashCode
/**
+ * Provides a type hint about the expected return value of this column. This information can
+ * be used by operations such as `select` on a [[Dataset]] to automatically convert the
+ * results into the correct JVM types.
+ * @since 1.6.0
+ */
+ def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr)
+
+ /**
* Extracts a value or values from a complex type.
* The following types of extraction are supported:
* - Given an Array, an integer ordinal can be used to retrieve a single value.
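
A hedged usage sketch of the typed column API above, assuming a SQLContext in scope with `import sqlContext.implicits._` (which this patch extends below): `as[T]` attaches an encoder so that `Dataset.select` can return typed results rather than generic Rows.

val ds = Seq(("a", 1), ("b", 2)).toDS()
val typed: Dataset[(String, Int)] = ds.select($"_1".as[String], $"_2".as[Int])
// typed.collect() yields ("a", 1) and ("b", 2) as Scala tuples
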
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 2f10aa9f3c..bf25bcde20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
@@ -259,6 +260,16 @@ class DataFrame private[sql](
def toDF(): DataFrame = this
/**
+ * :: Experimental ::
+ * Converts this [[DataFrame]] to a strongly-typed [[Dataset]] containing objects of the
+ * specified type, `U`.
+ * @group basic
+ * @since 1.6.0
+ */
+ @Experimental
+ def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution)
+
+ /**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
* from a RDD of tuples into a [[DataFrame]] with meaningful names. For example:
* {{{
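
A hedged sketch of the new `as[U]` conversion above, using a hypothetical case class and assuming the usual `sqlContext.implicits._` in scope:

case class Person(name: String, age: Long)   // hypothetical example class
val df = Seq(Person("Ann", 30L), Person("Bob", 25L)).toDF()
val people: Dataset[Person] = df.as[Person]  // same logical plan, now decoded into Person objects
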
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
new file mode 100644
index 0000000000..96213c7630
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -0,0 +1,392 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.encoders._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel
+ * using functional or relational operations.
+ *
+ * A [[Dataset]] differs from an [[RDD]] in the following ways:
+ * - Internally, a [[Dataset]] is represented by a Catalyst logical plan and the data is stored
+ * in the encoded form. This representation allows for additional logical operations and
+ * enables many operations (sorting, shuffling, etc.) to be performed without deserializing to
+ * an object.
+ * - The creation of a [[Dataset]] requires the presence of an explicit [[Encoder]] that can be
+ * used to serialize the object into a binary format. Encoders are also capable of mapping the
+ * schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime
+ * reflection based serialization. Operations that change the type of object stored in the
+ * dataset also need an encoder for the new type.
+ *
+ * A [[Dataset]] can be thought of as a specialized DataFrame, where the elements map to a specific
+ * JVM object type, instead of to a generic [[Row]] container. A DataFrame can be transformed into
+ * a specific Dataset by calling `df.as[ElementType]`. Similarly, you can transform a strongly-typed
+ * [[Dataset]] to a generic DataFrame by calling `ds.toDF()`.
+ *
+ * COMPATIBILITY NOTE: Long term we plan to make [[DataFrame]] extend `Dataset[Row]`. However,
+ * making this change to the class hierarchy would break the function signatures for the existing
+ * functional operations (map, flatMap, etc). As such, this class should be considered a preview
+ * of the final API. Changes will be made to the interface after Spark 1.6.
+ *
+ * @since 1.6.0
+ */
+@Experimental
+class Dataset[T] private[sql](
+ @transient val sqlContext: SQLContext,
+ @transient val queryExecution: QueryExecution)(
+ implicit val encoder: Encoder[T]) extends Serializable {
+
+ private implicit def classTag = encoder.clsTag
+
+ private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
+ this(sqlContext, new QueryExecution(sqlContext, plan))
+
+ /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
+ def schema: StructType = encoder.schema
+
+ /* ************* *
+ * Conversions *
+ * ************* */
+
+ /**
+ * Returns a new `Dataset` where each record has been mapped onto the specified type.
+ * TODO: should bind here...
+ * TODO: document binding rules
+ * @since 1.6.0
+ */
+ def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]])
+
+ /**
+ * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
+ * the same name after two Datasets have been joined.
+ */
+ def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _))
+
+ /**
+ * Converts this strongly typed collection of data to a generic DataFrame. In contrast to the
+ * strongly typed objects that Dataset operations work on, a DataFrame returns generic [[Row]]
+ * objects that allow fields to be accessed by ordinal or name.
+ */
+ def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)
+
+
+ /**
+ * Returns this Dataset.
+ * @since 1.6.0
+ */
+ def toDS(): Dataset[T] = this
+
+ /**
+ * Converts this Dataset to an RDD.
+ * @since 1.6.0
+ */
+ def rdd: RDD[T] = {
+ val tEnc = implicitly[Encoder[T]]
+ val input = queryExecution.analyzed.output
+ queryExecution.toRdd.mapPartitions { iter =>
+ val bound = tEnc.bind(input)
+ iter.map(bound.fromRow)
+ }
+ }
+
+ /* *********************** *
+ * Functional Operations *
+ * *********************** */
+
+ /**
+ * Concise syntax for chaining custom transformations.
+ * {{{
+ * def featurize(ds: Dataset[T]) = ...
+ *
+ * dataset
+ * .transform(featurize)
+ * .transform(...)
+ * }}}
+ *
+ * @since 1.6.0
+ */
+ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
+
+ /**
+ * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+ * @since 1.6.0
+ */
+ def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
+
+ /**
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * @since 1.6.0
+ */
+ def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+
+ /**
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
+ * @since 1.6.0
+ */
+ def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
+ new Dataset(
+ sqlContext,
+ MapPartitions[T, U](
+ func,
+ implicitly[Encoder[T]],
+ implicitly[Encoder[U]],
+ implicitly[Encoder[U]].schema.toAttributes,
+ logicalPlan))
+ }
+
+ def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
+ mapPartitions(_.flatMap(func))
+
+ /* ************** *
+ * Side effects *
+ * ************** */
+
+ /**
+ * Runs `func` on each element of this Dataset.
+ * @since 1.6.0
+ */
+ def foreach(func: T => Unit): Unit = rdd.foreach(func)
+
+ /**
+ * Runs `func` on each partition of this Dataset.
+ * @since 1.6.0
+ */
+ def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
+
+ /* ************* *
+ * Aggregation *
+ * ************* */
+
+ /**
+ * Reduces the elements of this Dataset using the specified binary function. The given function
+ * must be commutative and associative or the result may be non-deterministic.
+ * @since 1.6.0
+ */
+ def reduce(func: (T, T) => T): T = rdd.reduce(func)
+
+ /**
+ * Aggregates the elements of each partition, and then the results for all the partitions, using a
+ * given associative and commutative function and a neutral "zero value".
+ *
+ * This behaves somewhat differently than the fold operations implemented for non-distributed
+ * collections in functional languages like Scala. This fold operation may be applied to
+ * partitions individually, and then those results will be folded into the final result.
+ * If op is not commutative, then the result may differ from that of a fold applied to a
+ * non-distributed collection.
+ * @since 1.6.0
+ */
+ def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)
+
+ /**
+ * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+ * @since 1.6.0
+ */
+ def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
+ val inputPlan = queryExecution.analyzed
+ val withGroupingKey = AppendColumn(func, inputPlan)
+ val executed = sqlContext.executePlan(withGroupingKey)
+
+ new GroupedDataset(
+ implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns),
+ implicitly[Encoder[T]].bind(inputPlan.output),
+ executed,
+ inputPlan.output,
+ withGroupingKey.newColumns)
+ }
+
+ /* ****************** *
+ * Typed Relational *
+ * ****************** */
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
+ *
+ * {{{
+ * val ds = Seq(1, 2, 3).toDS()
+ * val newDS = ds.select(expr("value + 1").as[Int])
+ * }}}
+ * @since 1.6.0
+ */
+ def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = {
+ new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
+ }
+
+ // Codegen
+ // scalastyle:off
+
+ /** sbt scalaShell; println(Seq(1).toDS().genSelect) */
+ private def genSelect: String = {
+ (2 to 5).map { n =>
+ val types = (1 to n).map(i =>s"U$i").mkString(", ")
+ val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ")
+ val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ")
+ val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ")
+ s"""
+ |/**
+ | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ | * @since 1.6.0
+ | */
+ |def select[$types]($args): Dataset[($types)] = {
+ | implicit val te = new Tuple${n}Encoder($encoders)
+ | new Dataset[($types)](sqlContext,
+ | Project(
+ | $schema :: Nil,
+ | logicalPlan))
+ |}
+ |
+ """.stripMargin
+ }.mkString("\n")
+ }
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = {
+ implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder)
+ new Dataset[(U1, U2)](sqlContext,
+ Project(
+ Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil,
+ logicalPlan))
+ }
+
+
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = {
+ implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder)
+ new Dataset[(U1, U2, U3)](sqlContext,
+ Project(
+ Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil,
+ logicalPlan))
+ }
+
+
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = {
+ implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder)
+ new Dataset[(U1, U2, U3, U4)](sqlContext,
+ Project(
+ Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil,
+ logicalPlan))
+ }
+
+
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = {
+ implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder)
+ new Dataset[(U1, U2, U3, U4, U5)](sqlContext,
+ Project(
+ Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil,
+ logicalPlan))
+ }
+
+ // scalastyle:on
+
+ /* **************** *
+ * Set operations *
+ * **************** */
+
+ /**
+ * Returns a new [[Dataset]] that contains only the unique elements of this [[Dataset]].
+ *
+ * Note that equality checking is performed directly on the encoded representation of the data
+ * and thus is not affected by a custom `equals` function defined on `T`.
+ * @since 1.6.0
+ */
+ def distinct: Dataset[T] = withPlan(Distinct)
+
+ /**
+ * Returns a new [[Dataset]] that contains only the elements of this [[Dataset]] that are also
+ * present in `other`.
+ *
+ * Note that equality checking is performed directly on the encoded representation of the data
+ * and thus is not affected by a custom `equals` function defined on `T`.
+ * @since 1.6.0
+ */
+ def intersect(other: Dataset[T]): Dataset[T] =
+ withPlan[T](other)(Intersect)
+
+ /**
+ * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]]
+ * combined.
+ *
+ * Note that this function is not a typical set union operation, in that it does not eliminate
+ * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
+ * @since 1.6.0
+ */
+ def union(other: Dataset[T]): Dataset[T] =
+ withPlan[T](other)(Union)
+
+ /**
+ * Returns a new [[Dataset]] where any elements present in `other` have been removed.
+ *
+ * Note that equality checking is performed directly on the encoded representation of the data
+ * and thus is not affected by a custom `equals` function defined on `T`.
+ * @since 1.6.0
+ */
+ def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except)
+
+ /* ************************** *
+ * Gather to Driver Actions *
+ * ************************** */
+
+ /** Returns the first element in this [[Dataset]]. */
+ def first(): T = rdd.first()
+
+ /** Collects the elements to an Array. */
+ def collect(): Array[T] = rdd.collect()
+
+ /** Returns the first `num` elements of this [[Dataset]] as an Array. */
+ def take(num: Int): Array[T] = rdd.take(num)
+
+ /* ******************** *
+ * Internal Functions *
+ * ******************** */
+
+ private[sql] def logicalPlan = queryExecution.analyzed
+
+ private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
+ new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)))
+
+ private[sql] def withPlan[R : Encoder](
+ other: Dataset[_])(
+ f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
+ new Dataset[R](
+ sqlContext,
+ sqlContext.executePlan(
+ f(logicalPlan, other.logicalPlan)))
+}
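
An end-to-end sketch (not part of the patch) tying the functional operators above together, assuming a SQLContext named `sqlContext` with its implicits imported:

import sqlContext.implicits._

val ds = Seq(1, 2, 3, 4).toDS()             // Dataset[Int] via the new implicit encoders
val doubled = ds.map(_ * 2)                 // functional transformation, yields 2, 4, 6, 8
val filtered = doubled.filter(_ % 4 == 0)   // keeps 4 and 8
assert(filtered.collect().toSeq.sorted == Seq(4, 8))
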
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
new file mode 100644
index 0000000000..17817cbcc5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
@@ -0,0 +1,30 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+/**
+ * A container for a [[Dataset]], used for implicit conversions.
+ *
+ * @since 1.6.0
+ */
+private[sql] case class DatasetHolder[T](df: Dataset[T]) {
+
+ // This is declared with parentheses to prevent the Scala compiler from treating
+ // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset.
+ def toDS(): Dataset[T] = df
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
new file mode 100644
index 0000000000..89a16dd8b0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.QueryExecution
+
+/**
+ * A [[Dataset]] that has been logically grouped by a user-specified grouping key. Users should not
+ * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing
+ * [[Dataset]].
+ */
+class GroupedDataset[K, T] private[sql](
+ private val kEncoder: Encoder[K],
+ private val tEncoder: Encoder[T],
+ queryExecution: QueryExecution,
+ private val dataAttributes: Seq[Attribute],
+ private val groupingAttributes: Seq[Attribute]) extends Serializable {
+
+ private implicit def kEnc = kEncoder
+ private implicit def tEnc = tEncoder
+ private def logicalPlan = queryExecution.analyzed
+ private def sqlContext = queryExecution.sqlContext
+
+ /**
+ * Returns a [[Dataset]] that contains each unique key.
+ */
+ def keys: Dataset[K] = {
+ new Dataset[K](
+ sqlContext,
+ Distinct(
+ Project(groupingAttributes, logicalPlan)))
+ }
+
+ /**
+ * Applies the given function to each group of data. For each unique group, the function will
+ * be passed the group key and an iterator that contains all of the elements in the group. The
+ * function can return an iterator containing elements of an arbitrary type which will be returned
+ * as a new [[Dataset]].
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+ * constraints of their cluster.
+ */
+ def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = {
+ new Dataset[U](
+ sqlContext,
+ MapGroups(f, groupingAttributes, logicalPlan))
+ }
+}
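
A hedged sketch of `groupBy` followed by `mapGroups` (same `sqlContext.implicits._` assumption as before): the key function's result becomes the grouping key, and each group is reduced to a single (key, count) pair.

val words = Seq("a", "ab", "abc", "b").toDS()
val byLength = words.groupBy(_.length)                                 // GroupedDataset[Int, String]
val counts = byLength.mapGroups { (len, iter) => Iterator((len, iter.size)) }
// counts: Dataset[(Int, Int)] containing (1, 2), (2, 1) and (3, 1) in some order
val keys = byLength.keys                                               // Dataset[Int] of distinct lengths
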
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a107639947..5e7198f974 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -21,6 +21,7 @@ import java.beans.{BeanInfo, Introspector}
import java.util.Properties
import java.util.concurrent.atomic.AtomicReference
+
import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag
@@ -33,6 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
@@ -487,6 +489,16 @@ class SQLContext private[sql](
DataFrame(this, logicalPlan)
}
+
+ def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
+ val enc = implicitly[Encoder[T]]
+ val attributes = enc.schema.toAttributes
+ val encoded = data.map(d => enc.toRow(d).copy())
+ val plan = new LocalRelation(attributes, encoded)
+
+ new Dataset[T](this, plan)
+ }
+
/**
* Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be
* converted to Catalyst rows.
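
A brief hedged example of `createDataset` above, given an implicit encoder for the element type (here the String encoder provided by the implicits added below):

import sqlContext.implicits._

val ds: Dataset[String] = sqlContext.createDataset(Seq("spark", "sql"))
assert(ds.collect().toSeq == Seq("spark", "sql"))
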
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index bf03c61088..af8474df0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,6 +17,10 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.encoders._
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
@@ -30,9 +34,19 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* A collection of implicit methods for converting common Scala objects into [[DataFrame]]s.
*/
-private[sql] abstract class SQLImplicits {
+abstract class SQLImplicits {
protected def _sqlContext: SQLContext
+ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
+
+ implicit def newIntEncoder: Encoder[Int] = new IntEncoder()
+ implicit def newLongEncoder: Encoder[Long] = new LongEncoder()
+ implicit def newStringEncoder: Encoder[String] = new StringEncoder()
+
+ implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = {
+ DatasetHolder(_sqlContext.createDataset(s))
+ }
+
/**
* An implicit conversion that turns a Scala `Symbol` into a [[Column]].
* @since 1.3.0
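
With these implicits in scope, local collections convert directly to Datasets. A hedged sketch, where `Record` is a hypothetical example class:

import sqlContext.implicits._

case class Record(key: String, value: Int)                // hypothetical example class
val primitives = Seq(1L, 2L, 3L).toDS()                   // uses newLongEncoder
val records = Seq(Record("a", 1), Record("b", 2)).toDS()  // uses newProductEncoder
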
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
new file mode 100644
index 0000000000..10742cf734
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateOrdering}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, Ascending, Expression}
+
+object GroupedIterator {
+ def apply(
+ input: Iterator[InternalRow],
+ keyExpressions: Seq[Expression],
+ inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
+ if (input.hasNext) {
+ new GroupedIterator(input, keyExpressions, inputSchema)
+ } else {
+ Iterator.empty
+ }
+ }
+}
+
+/**
+ * Iterates over a presorted set of rows, chunking it up by the grouping expressions. Each call to
+ * `next()` returns a pair containing the key of the current group and an iterator that will return
+ * all the elements of that group. Iterators for each group are lazily constructed by extracting
+ * rows from the input iterator, so full groups are never materialized by this class.
+ *
+ * Example input:
+ * {{{
+ * Input: [a, 1], [b, 2], [b, 3]
+ * Grouping: x#1
+ * InputSchema: x#1, y#2
+ * }}}
+ *
+ * Result:
+ * {{{
+ * First call to next(): ([a], Iterator([a, 1]))
+ * Second call to next(): ([b], Iterator([b, 2], [b, 3]))
+ * }}}
+ *
+ * Note that, to keep the implementation simple, this class does not handle empty input. Use the
+ * companion factory to construct new instances.
+ *
+ * @param input An iterator of rows. This iterator must be sorted by the groupingExpressions,
+ * otherwise the same group may appear more than once.
+ * @param groupingExpressions The set of expressions used to do grouping. The result of evaluating
+ * these expressions will be returned as the first part of each call
+ * to `next()`.
+ * @param inputSchema The schema of the rows in the `input` iterator.
+ */
+class GroupedIterator private(
+ input: Iterator[InternalRow],
+ groupingExpressions: Seq[Expression],
+ inputSchema: Seq[Attribute])
+ extends Iterator[(InternalRow, Iterator[InternalRow])] {
+
+ /** Compares two input rows and returns 0 if they are in the same group. */
+ val sortOrder = groupingExpressions.map(SortOrder(_, Ascending))
+ val keyOrdering = GenerateOrdering.generate(sortOrder, inputSchema)
+
+ /** Creates a row containing only the key for a given input row. */
+ val keyProjection = GenerateUnsafeProjection.generate(groupingExpressions, inputSchema)
+
+ /**
+ * Holds null or the row that will be returned on next call to `next()` in the inner iterator.
+ */
+ var currentRow = input.next()
+
+ /** Holds a copy of an input row that is in the current group. */
+ var currentGroup = currentRow.copy()
+ var currentIterator: Iterator[InternalRow] = null
+ assert(keyOrdering.compare(currentGroup, currentRow) == 0)
+
+ // Returns true if we already have the next group's iterator, or if fetching a new one succeeds.
+ def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator
+
+ def next(): (InternalRow, Iterator[InternalRow]) = {
+ assert(hasNext) // Ensure we have fetched the next iterator.
+ val ret = (keyProjection(currentGroup), currentIterator)
+ currentIterator = null
+ ret
+ }
+
+ def fetchNextGroupIterator(): Boolean = {
+ if (currentRow != null || input.hasNext) {
+ val inputIterator = new Iterator[InternalRow] {
+ // Returns true if we have a row and it is in the current group, or if fetching a new row
+ // succeeds.
+ def hasNext = {
+ (currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) ||
+ fetchNextRowInGroup()
+ }
+
+ def fetchNextRowInGroup(): Boolean = {
+ if (currentRow != null || input.hasNext) {
+ currentRow = input.next()
+ if (keyOrdering.compare(currentGroup, currentRow) == 0) {
+ // The row is in the current group. Continue the inner iterator.
+ true
+ } else {
+ // We got a row, but it's not in the right group. End this inner iterator and prepare
+ // for the next group.
+ currentIterator = null
+ currentGroup = currentRow.copy()
+ false
+ }
+ } else {
+ // There is no more input so we are done.
+ false
+ }
+ }
+
+ def next(): InternalRow = {
+ assert(hasNext) // Ensure we have fetched the next row.
+ val res = currentRow
+ currentRow = null
+ res
+ }
+ }
+ currentIterator = inputIterator
+ true
+ } else {
+ false
+ }
+ }
+}
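For intuition, a standalone sketch of the same chunking pattern over a presorted iterator follows; chunkBySortedKey is an illustrative name, and unlike GroupedIterator it buffers each group rather than exposing a lazy inner iterator:
{{{
import scala.collection.mutable.ArrayBuffer

def chunkBySortedKey[K, V](input: Iterator[V])(key: V => K): Iterator[(K, Seq[V])] =
  new Iterator[(K, Seq[V])] {
    private val buffered = input.buffered
    def hasNext: Boolean = buffered.hasNext
    def next(): (K, Seq[V]) = {
      val k = key(buffered.head)
      val group = ArrayBuffer[V]()
      // Pull rows while they still have the current key; a row with a different key
      // marks the start of the next group and is left in the iterator.
      while (buffered.hasNext && key(buffered.head) == k) {
        group += buffered.next()
      }
      (k, group.toList)
    }
  }

// chunkBySortedKey(Iterator(("a", 1), ("b", 2), ("b", 3)))(_._1).toList
// => List((a,List((a,1))), (b,List((b,2), (b,3))))
}}}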
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 79bd1a4180..637deff4e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -372,6 +372,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Distinct(child) =>
throw new IllegalStateException(
"logical distinct operator should have been replaced by aggregate in the optimizer")
+
+ case logical.MapPartitions(f, tEnc, uEnc, output, child) =>
+ execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil
+ case logical.AppendColumn(f, tEnc, uEnc, newCol, child) =>
+ execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
+ case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
+ execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
+
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
execution.Exchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil
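A hedged illustration of how these strategies are reached from the Dataset API exercised in the test suites below (the data is illustrative, and `sqlContext.implicits._` is assumed to be in scope):
{{{
val ds = Seq(("a", 1), ("b", 2), ("b", 3)).toDS()

// Expected to plan through logical.MapPartitions -> execution.MapPartitions.
val bumped = ds.map { case (k, v) => (k, v + 1) }

// groupBy(f) presumably plans an AppendColumn to materialize the computed key, and
// the following mapGroups plans a MapGroups over the clustered, sorted child.
val summed = ds.groupBy(_._1).mapGroups { case (key, rows) =>
  Iterator((key, rows.map(_._2).sum))
}
}}}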
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index dc38fe59fe..2bb3dba5bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.MutablePair
@@ -311,3 +313,80 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl
protected override def doExecute(): RDD[InternalRow] = child.execute()
}
+
+/**
+ * Applies the given function to each partition's iterator of input rows and encodes the results.
+ */
+case class MapPartitions[T, U](
+ func: Iterator[T] => Iterator[U],
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ output: Seq[Attribute],
+ child: SparkPlan) extends UnaryNode {
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions { iter =>
+ val tBoundEncoder = tEncoder.bind(child.output)
+ func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow)
+ }
+ }
+}
+
+/**
+ * Applies the given function to each input row, appending the encoded result at the end of the row.
+ */
+case class AppendColumns[T, U](
+ func: T => U,
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ newColumns: Seq[Attribute],
+ child: SparkPlan) extends UnaryNode {
+
+ override def output: Seq[Attribute] = child.output ++ newColumns
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions { iter =>
+ val tBoundEncoder = tEncoder.bind(child.output)
+ val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema)
+ iter.map { row =>
+ val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row)))
+ combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow
+ }
+ }
+ }
+}
+
+/**
+ * Groups the input rows together and calls the function with each group and an iterator containing
+ * all elements in the group. The result of this function is encoded and flattened before
+ * being output.
+ */
+case class MapGroups[K, T, U](
+ func: (K, Iterator[T]) => Iterator[U],
+ kEncoder: Encoder[K],
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ groupingAttributes: Seq[Attribute],
+ output: Seq[Attribute],
+ child: SparkPlan) extends UnaryNode {
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(groupingAttributes) :: Nil
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions { iter =>
+ val grouped = GroupedIterator(iter, groupingAttributes, child.output)
+ val groupKeyEncoder = kEncoder.bind(groupingAttributes)
+
+ grouped.flatMap { case (key, rowIter) =>
+ val result = func(
+ groupKeyEncoder.fromRow(key),
+ rowIter.map(tEncoder.fromRow))
+ result.map(uEncoder.toRow)
+ }
+ }
+ }
+}
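For intuition, here is a hedged, collections-only analogue of the AppendColumns + MapGroups pipeline (illustrative names, no encoders or shuffles): every element is first paired with its computed key, then elements sharing a key are brought together and the user function runs once per group:
{{{
def appendKeyThenMapGroups[K, T, U](
    data: Seq[T])(key: T => K)(f: (K, Iterator[T]) => Iterator[U]): Seq[U] = {
  // AppendColumns: keep each element, adding its computed key alongside it.
  val withKey = data.map(t => (key(t), t))
  // Clustering + MapGroups: group by the appended key, run the user function once
  // per group, and flatten its iterator of results.
  withKey.groupBy(_._1).toSeq.flatMap { case (k, rows) =>
    f(k, rows.iterator.map(_._2))
  }
}

// appendKeyThenMapGroups(Seq(("a", 10), ("a", 20), ("b", 1)))(_._1) {
//   case (k, it) => Iterator((k, it.map(_._2).sum))
// }
// => Seq((a,30), (b,1))  (group order is not guaranteed)
}}}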
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
new file mode 100644
index 0000000000..32443557fb
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.postfixOps
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+case class IntClass(value: Int)
+
+class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("toDS") {
+ val data = Seq(1, 2, 3, 4, 5, 6)
+ checkAnswer(
+ data.toDS(),
+ data: _*)
+ }
+
+ test("as case class / collect") {
+ val ds = Seq(1, 2, 3).toDS().as[IntClass]
+ checkAnswer(
+ ds,
+ IntClass(1), IntClass(2), IntClass(3))
+
+ assert(ds.collect().head == IntClass(1))
+ }
+
+ test("map") {
+ val ds = Seq(1, 2, 3).toDS()
+ checkAnswer(
+ ds.map(_ + 1),
+ 2, 3, 4)
+ }
+
+ test("filter") {
+ val ds = Seq(1, 2, 3, 4).toDS()
+ checkAnswer(
+ ds.filter(_ % 2 == 0),
+ 2, 4)
+ }
+
+ test("foreach") {
+ val ds = Seq(1, 2, 3).toDS()
+ val acc = sparkContext.accumulator(0)
+ ds.foreach(acc +=)
+ assert(acc.value == 6)
+ }
+
+ test("foreachPartition") {
+ val ds = Seq(1, 2, 3).toDS()
+ val acc = sparkContext.accumulator(0)
+ ds.foreachPartition(_.foreach(acc +=))
+ assert(acc.value == 6)
+ }
+
+ test("reduce") {
+ val ds = Seq(1, 2, 3).toDS()
+ assert(ds.reduce(_ + _) == 6)
+ }
+
+ test("fold") {
+ val ds = Seq(1, 2, 3).toDS()
+ assert(ds.fold(0)(_ + _) == 6)
+ }
+
+ test("groupBy function, keys") {
+ val ds = Seq(1, 2, 3, 4, 5).toDS()
+ val grouped = ds.groupBy(_ % 2)
+ checkAnswer(
+ grouped.keys,
+ 0, 1)
+ }
+
+ test("groupBy function, mapGroups") {
+ val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
+ val grouped = ds.groupBy(_ % 2)
+ val agged = grouped.mapGroups { case (g, iter) =>
+ val name = if (g == 0) "even" else "odd"
+ Iterator((name, iter.size))
+ }
+
+ checkAnswer(
+ agged,
+ ("even", 5), ("odd", 6))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
new file mode 100644
index 0000000000..08496249c6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.postfixOps
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+case class ClassData(a: String, b: Int)
+
+class DatasetSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("toDS") {
+ val data = Seq(("a", 1) , ("b", 2), ("c", 3))
+ checkAnswer(
+ data.toDS(),
+ data: _*)
+ }
+
+ test("as case class / collect") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData]
+ checkAnswer(
+ ds,
+ ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))
+ assert(ds.collect().head == ClassData("a", 1))
+ }
+
+ test("as case class - reordered fields by name") {
+ val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData]
+ assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)))
+ }
+
+ test("map") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.map(v => (v._1, v._2 + 1)),
+ ("a", 2), ("b", 3), ("c", 4))
+ }
+
+ test("select") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(expr("_2 + 1").as[Int]),
+ 2, 3, 4)
+ }
+
+ test("select 3") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(
+ expr("_1").as[String],
+ expr("_2").as[Int],
+ expr("_2 + 1").as[Int]),
+ ("a", 1, 2), ("b", 2, 3), ("c", 3, 4))
+ }
+
+ test("filter") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.filter(_._1 == "b"),
+ ("b", 2))
+ }
+
+ test("foreach") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ val acc = sparkContext.accumulator(0)
+ ds.foreach(v => acc += v._2)
+ assert(acc.value == 6)
+ }
+
+ test("foreachPartition") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ val acc = sparkContext.accumulator(0)
+ ds.foreachPartition(_.foreach(v => acc += v._2))
+ assert(acc.value == 6)
+ }
+
+ test("reduce") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
+ }
+
+ test("fold") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
+ }
+
+ test("groupBy function, keys") {
+ val ds = Seq(("a", 1), ("b", 1)).toDS()
+ val grouped = ds.groupBy(v => (1, v._2))
+ checkAnswer(
+ grouped.keys,
+ (1, 1))
+ }
+
+ test("groupBy function, mapGroups") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ val grouped = ds.groupBy(v => (v._1, "word"))
+ val agged = grouped.mapGroups { case (g, iter) =>
+ Iterator((g._1, iter.map(_._2).sum))
+ }
+
+ checkAnswer(
+ agged,
+ ("a", 30), ("b", 3), ("c", 1))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index e3c5a42667..aba567512f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql
import java.util.{Locale, TimeZone}
import scala.collection.JavaConverters._
+import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation
+import org.apache.spark.sql.catalyst.encoders.{Encoder, ProductEncoder}
abstract class QueryTest extends PlanTest {
@@ -53,6 +55,12 @@ abstract class QueryTest extends PlanTest {
}
}
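+ /**
+ * Evaluates a [[Dataset]] and makes sure its result matches the expected answer, by converting
+ * both sides to [[DataFrame]]s and comparing the collected rows.
+ */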
+ protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = {
+ checkAnswer(
+ ds.toDF(),
+ sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+ }
+
/**
* Runs the plan and makes sure the answer matches the expected result.
* @param df the [[DataFrame]] to be executed