path: root/sql/catalyst
author    Michael Armbrust <michael@databricks.com>  2015-10-22 15:20:17 -0700
committer Reynold Xin <rxin@databricks.com>  2015-10-22 15:20:17 -0700
commit    53e83a3a77cafc2ccd0764ecdb8b3ba735bc51fc (patch)
tree      9e10bf6e96c5faaf51d52790acdd9adc71145b54 /sql/catalyst
parent    188ea348fdcf877d86f3c433cd15f6468fe3b42a (diff)
[SPARK-11116][SQL] First Draft of Dataset API
*This PR adds a new experimental API to Spark, tentatively named Datasets.*

A `Dataset` is a strongly-typed collection of objects that can be transformed in parallel using functional or relational operations. Example usage is as follows:

### Functional
```scala
> val ds: Dataset[Int] = Seq(1, 2, 3).toDS()

> ds.filter(_ % 1 == 0).collect()
res1: Array[Int] = Array(1, 2, 3)
```

### Relational
```scala
scala> ds.toDF().show()
+-----+
|value|
+-----+
|    1|
|    2|
|    3|
+-----+

> ds.select(expr("value + 1").as[Int]).collect()
res11: Array[Int] = Array(2, 3, 4)
```

## Comparison to RDDs
A `Dataset` differs from an `RDD` in the following ways:
 - The creation of a `Dataset` requires the presence of an explicit `Encoder` that can be used to serialize the object into a binary format. Encoders are also capable of mapping the schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime reflection based serialization.
 - Internally, a `Dataset` is represented by a Catalyst logical plan and the data is stored in the encoded form. This representation allows for additional logical operations and enables many operations (sorting, shuffling, etc.) to be performed without deserializing to an object.

A `Dataset` can be converted to an `RDD` by calling the `.rdd` method.

## Comparison to DataFrames
A `Dataset` can be thought of as a specialized DataFrame, where the elements map to a specific JVM object type, instead of to a generic `Row` container. A DataFrame can be transformed into a specific Dataset by calling `df.as[ElementType]`. Similarly, you can transform a strongly-typed `Dataset` into a generic DataFrame by calling `ds.toDF()`.

## Implementation Status and TODOs
This is a rough cut at the least controversial parts of the API. The primary purpose here is to get something committed so that we can better parallelize further work and get early feedback on the API. The following is being deferred to future PRs:
 - Joins and Aggregations (prototype here https://github.com/apache/spark/commit/f11f91e6f08c8cf389b8388b626cd29eec32d937)
 - Support for Java

Additionally, the responsibility for binding an encoder to a given schema is currently handled in a fairly ad-hoc fashion. This is an internal detail, and what we are doing today works for the cases we care about. However, as we add more APIs we'll probably need to do this in a more principled way (i.e. separate resolution from binding as we do in DataFrames).

## COMPATIBILITY NOTE
Long term we plan to make `DataFrame` extend `Dataset[Row]`. However, making this change to the class hierarchy would break the function signatures for the existing function operations (map, flatMap, etc). As such, this class should be considered a preview of the final API. Changes will be made to the interface after Spark 1.6.

Author: Michael Armbrust <michael@databricks.com>

Closes #9190 from marmbrus/dataset-infra.
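For reference, the DataFrame/Dataset conversions described above can be sketched end to end. This is illustrative only: it assumes a Spark 1.6-style `SQLContext` named `sqlContext` with its implicits imported, and a made-up `Person` case class.

```scala
import org.apache.spark.sql.{DataFrame, Dataset}
import sqlContext.implicits._

case class Person(name: String, age: Int)

// Build a generic DataFrame from local data, then move between the typed and
// untyped views as described in the commit message above.
val df: DataFrame = Seq(Person("Ann", 30), Person("Bob", 25)).toDF()
val ds: Dataset[Person] = df.as[Person]   // DataFrame -> strongly typed Dataset
val back: DataFrame = ds.toDF()           // and back to a generic DataFrame
```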
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                    8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala             38
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala                  19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala           12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala          100
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala                  173
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala           7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala           4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala     8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala               12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala      72
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala    43
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala      21
13 files changed, 495 insertions, 22 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 27c96f4122..713c6b547d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -411,9 +411,9 @@ trait ScalaReflection {
}
/** Returns expressions for extracting all the fields from the given type. */
- def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = {
+ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
ScalaReflectionLock.synchronized {
- extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateStruct].children
+ extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct]
}
}
@@ -497,11 +497,11 @@ trait ScalaReflection {
}
}
- CreateStruct(params.head.map { p =>
+ CreateNamedStruct(params.head.flatMap { p =>
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
- extractorFor(fieldValue, fieldType)
+ expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
})
case t if t <:< localTypeOf[Array[_]] =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
index 54096f18cb..b484b8fde6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
import scala.reflect.ClassTag
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
@@ -41,9 +41,11 @@ case class ClassEncoder[T](
clsTag: ClassTag[T])
extends Encoder[T] {
- private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+ @transient
+ private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
private val inputRow = new GenericMutableRow(1)
+ @transient
private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
private val dataType = ObjectType(clsTag.runtimeClass)
@@ -64,4 +66,36 @@ case class ClassEncoder[T](
copy(constructExpression = boundExpression)
}
+
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = {
+ val positionToAttribute = AttributeMap.toIndex(oldSchema)
+ val attributeToNewPosition = AttributeMap.byIndex(newSchema)
+ copy(constructExpression = constructExpression transform {
+ case r: BoundReference =>
+ r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
+ })
+ }
+
+ override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = {
+ var remaining = schema
+ copy(constructExpression = constructExpression transform {
+ case u: UnresolvedAttribute =>
+ val pos = remaining.head
+ remaining = remaining.drop(1)
+ pos
+ })
+ }
+
+ protected val attrs = extractExpressions.map(_.collect {
+ case a: Attribute => s"#${a.exprId}"
+ case b: BoundReference => s"[${b.ordinal}]"
+ }.headOption.getOrElse(""))
+
+
+ protected val schemaString =
+ schema
+ .zip(attrs)
+ .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
+
+ override def toString: String = s"class[$schemaString]"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index bdb1c0959d..efb872ddb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.types.StructType
* and reuse internal buffers to improve performance.
*/
trait Encoder[T] {
+
/** Returns the schema of encoding this type of object as a Row. */
def schema: StructType
@@ -46,13 +47,27 @@ trait Encoder[T] {
/**
* Returns an object of type `T`, extracting the required values from the provided row. Note that
- * you must bind the encoder to a specific schema before you can call this function.
+ * you must `bind` an encoder to a specific schema before you can call this function.
*/
def fromRow(row: InternalRow): T
/**
* Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the
- * given schema
+ * given schema.
*/
def bind(schema: Seq[Attribute]): Encoder[T]
+
+ /**
+ * Binds this encoder to the given schema positionally. In this binding, the first reference to
+ * any input is mapped to `schema(0)`, and so on for each input that is encountered.
+ */
+ def bindOrdinals(schema: Seq[Attribute]): Encoder[T]
+
+ /**
+   * Given an encoder that has already been bound to a given schema, returns a new encoder
+   * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
+   * when you are trying to use an encoder on grouping keys that were originally part of a larger
+ * row, but now you have projected out only the key expressions.
+ */
+ def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T]
}
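As a rough sketch of how this contract is exercised (mirroring the round trip done in `ProductEncoderSuite` further down; the `Person` case class is hypothetical):

```scala
import org.apache.spark.sql.catalyst.encoders.ProductEncoder

case class Person(name: String, age: Int)

// Sketch only: round trip a JVM object through the new Encoder contract.
val encoder = ProductEncoder[Person]
// fromRow requires the construction expressions to be bound to a concrete schema first.
val bound = encoder.bind(encoder.schema.toAttributes)

val row  = encoder.toRow(Person("Ann", 30))  // JVM object -> encoded InternalRow
val back = bound.fromRow(row)                // encoded InternalRow -> Person("Ann", 30)
```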
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index 4f7ce455ad..34f5e6c030 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -31,15 +31,17 @@ import org.apache.spark.sql.types.{ObjectType, StructType}
object ProductEncoder {
def apply[T <: Product : TypeTag]: ClassEncoder[T] = {
// We convert the not-serializable TypeTag into StructType and ClassTag.
- val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
val mirror = typeTag[T].mirror
val cls = mirror.runtimeClass(typeTag[T].tpe)
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- val extractExpressions = ScalaReflection.extractorsFor[T](inputObject)
+ val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
val constructExpression = ScalaReflection.constructorFor[T]
- new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls))
- }
-
+ new ClassEncoder[T](
+ extractExpression.dataType,
+ extractExpression.flatten,
+ constructExpression,
+ ClassTag[T](cls))
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
new file mode 100644
index 0000000000..a93f2d7c61
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.sql.types._
+
+/** An encoder for primitive Long types. */
+case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] {
+ private val row = UnsafeRow.createFromByteArray(64, 1)
+
+ override def clsTag: ClassTag[Long] = ClassTag.Long
+ override def schema: StructType =
+ StructType(StructField(fieldName, LongType) :: Nil)
+
+ override def fromRow(row: InternalRow): Long = row.getLong(ordinal)
+
+ override def toRow(t: Long): InternalRow = {
+ row.setLong(ordinal, t)
+ row
+ }
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this
+ override def bind(schema: Seq[Attribute]): Encoder[Long] = this
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this
+}
+
+/** An encoder for primitive Integer types. */
+case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] {
+ private val row = UnsafeRow.createFromByteArray(64, 1)
+
+ override def clsTag: ClassTag[Int] = ClassTag.Int
+ override def schema: StructType =
+ StructType(StructField(fieldName, IntegerType) :: Nil)
+
+ override def fromRow(row: InternalRow): Int = row.getInt(ordinal)
+
+ override def toRow(t: Int): InternalRow = {
+ row.setInt(ordinal, t)
+ row
+ }
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this
+ override def bind(schema: Seq[Attribute]): Encoder[Int] = this
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this
+}
+
+/** An encoder for String types. */
+case class StringEncoder(
+ fieldName: String = "value",
+ ordinal: Int = 0) extends Encoder[String] {
+
+ val record = new SpecificMutableRow(StringType :: Nil)
+
+ @transient
+ lazy val projection =
+ GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil)
+
+ override def schema: StructType =
+ StructType(
+ StructField("value", StringType, nullable = false) :: Nil)
+
+ override def clsTag: ClassTag[String] = scala.reflect.classTag[String]
+
+
+ override final def fromRow(row: InternalRow): String = {
+ row.getString(ordinal)
+ }
+
+ override final def toRow(value: String): InternalRow = {
+ val utf8String = UTF8String.fromString(value)
+ record(0) = utf8String
+ // TODO: this is a bit of a hack to produce UnsafeRows
+ projection(record)
+ }
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this
+ override def bind(schema: Seq[Attribute]): Encoder[String] = this
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
new file mode 100644
index 0000000000..a48eeda7d2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.types.{StructField, StructType}
+
+// Most of this file is codegen.
+// scalastyle:off
+
+/**
+ * A set of composite encoders that take sub encoders and map each of their objects to a
+ * Scala tuple. Note that currently the implementation is fairly limited and only supports going
+ * from an internal row to a tuple.
+ */
+object TupleEncoder {
+
+ /** Code generator for composite tuple encoders. */
+ def main(args: Array[String]): Unit = {
+ (2 to 5).foreach { i =>
+ val types = (1 to i).map(t => s"T$t").mkString(", ")
+ val tupleType = s"($types)"
+ val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ")
+ val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ")
+ val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ")
+
+ println(
+ s"""
+ |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] {
+ | val schema = StructType(Array($fields))
+ |
+ | def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType]
+ |
+ | def fromRow(row: InternalRow): $tupleType = {
+ | ($fromRow)
+ | }
+ |
+ | override def toRow(t: $tupleType): InternalRow =
+ | throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
+ |
+ | override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = {
+ | this
+ | }
+ |
+ | override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] =
+ | throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+ |
+ |
+ | override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] =
+ | throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+ |}
+ """.stripMargin)
+ }
+ }
+}
+
+class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] {
+ val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema)))
+
+ def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)]
+
+ def fromRow(row: InternalRow): (T1, T2) = {
+ (e1.fromRow(row), e2.fromRow(row))
+ }
+
+ override def toRow(t: (T1, T2)): InternalRow =
+ throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
+
+ override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = {
+ this
+ }
+
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+}
+
+
+class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] {
+ val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema)))
+
+ def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)]
+
+ def fromRow(row: InternalRow): (T1, T2, T3) = {
+ (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row))
+ }
+
+ override def toRow(t: (T1, T2, T3)): InternalRow =
+ throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
+
+ override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = {
+ this
+ }
+
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+}
+
+
+class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] {
+ val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema)))
+
+ def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, T3, T4)]
+
+ def fromRow(row: InternalRow): (T1, T2, T3, T4) = {
+ (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row))
+ }
+
+ override def toRow(t: (T1, T2, T3, T4)): InternalRow =
+ throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
+
+ override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = {
+ this
+ }
+
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+}
+
+
+class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] {
+ val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema)))
+
+ def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)]
+
+ def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = {
+ (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row))
+ }
+
+ override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow =
+ throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
+
+ override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = {
+ this
+ }
+
+ override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+
+
+ override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
+ throw new UnsupportedOperationException("Tuple Encoders only support bind.")
+}
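Since only `fromRow` is implemented, a composite tuple encoder is mainly useful for decoding rows produced elsewhere. A hedged sketch, with the input row built by hand purely for illustration:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.unsafe.types.UTF8String

// Decode an InternalRow into an (Int, String) pair by composing the primitive
// encoders, each pointed at its ordinal in the row.
val pairEncoder = new Tuple2Encoder(IntEncoder(ordinal = 0), StringEncoder(ordinal = 1))

val row = InternalRow.fromSeq(Seq(1, UTF8String.fromString("a")))
val pair: (Int, String) = pairEncoder.fromRow(row)   // (1, "a")
```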
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 96a11e352e..ef3cc554b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,6 +26,13 @@ object AttributeMap {
def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
}
+
+ /** Given a schema, constructs an [[AttributeMap]] from [[Attribute]] to ordinal */
+ def byIndex(schema: Seq[Attribute]): AttributeMap[Int] = apply(schema.zipWithIndex)
+
+ /** Given a schema, constructs a map from ordinal to Attribute. */
+ def toIndex(schema: Seq[Attribute]): Map[Int, Attribute] =
+ schema.zipWithIndex.map { case (a, i) => i -> a }.toMap
}
class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 5345696570..3831535574 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -31,6 +31,10 @@ protected class AttributeEquals(val a: Attribute) {
}
object AttributeSet {
+ /** Returns an empty [[AttributeSet]]. */
+ val empty = apply(Iterable.empty)
+
+ /** Constructs a new [[AttributeSet]] that contains a single [[Attribute]]. */
def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a)))
/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index a5f02e2463..059e45bd68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -125,6 +125,14 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
*/
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
+ /**
+ * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this
+ * StructType.
+ */
+ def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
+ case (v, n) => Alias(v, n.toString)()
+ }
+
private lazy val (nameExprs, valExprs) =
children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
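To illustrate what `flatten` returns, a minimal sketch (the literal struct here is made up for the example):

```scala
import org.apache.spark.sql.catalyst.expressions._

// flatten turns the (name, value) children of a CreateNamedStruct into one
// Alias per field; ProductEncoder uses this to build its extraction expressions.
val struct = CreateNamedStruct(Seq(Literal("a"), Literal(1), Literal("b"), Literal(2)))
val cols: Seq[NamedExpression] = struct.flatten   // roughly Seq(1 AS a, 2 AS b)
```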
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 30b7f8d376..f1fa13daa7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{StructField, StructType}
/**
* A set of classes that can be used to represent trees of relational expressions. A key goal of
@@ -80,4 +81,15 @@ package object expressions {
/** Uses the given row to store the output of the projection. */
def target(row: MutableRow): MutableProjection
}
+
+
+ /**
+ * Helper functions for working with `Seq[Attribute]`.
+ */
+ implicit class AttributeSeq(attrs: Seq[Attribute]) {
+ /** Creates a StructType with a schema matching this `Seq[Attribute]`. */
+ def toStructType: StructType = {
+ StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
+ }
+ }
}
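A minimal sketch of the new implicit in use (the attributes are constructed by hand here for illustration; in practice they would typically come from a plan's `output`):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// The implicit AttributeSeq wrapper converts a Seq[Attribute] back into a StructType.
val attrs: Seq[Attribute] = Seq(
  AttributeReference("name", StringType)(),
  AttributeReference("age", IntegerType)())

val structType: StructType = attrs.toStructType
```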
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index ae9482c10f..21a55a5371 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
import org.apache.spark.sql.catalyst.plans._
@@ -417,7 +418,7 @@ case class Distinct(child: LogicalPlan) extends UnaryNode {
}
/**
- * Return a new RDD that has exactly `numPartitions` partitions. Differs from
+ * Returns a new RDD that has exactly `numPartitions` partitions. Differs from
* [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user
* asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer
* of the output requires some specific ordering or distribution of the data.
@@ -443,3 +444,72 @@ case object OneRowRelation extends LeafNode {
override def statistics: Statistics = Statistics(sizeInBytes = 1)
}
+/**
+ * A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are
+ * used respectively to decode/encode from the JVM object representation expected by `func`.
+ */
+case class MapPartitions[T, U](
+ func: Iterator[T] => Iterator[U],
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ output: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode {
+ override def missingInput: AttributeSet = AttributeSet.empty
+}
+
+/** Factory for constructing new `AppendColumn` nodes. */
+object AppendColumn {
+ def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
+ val attrs = implicitly[Encoder[U]].schema.toAttributes
+ new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child)
+ }
+}
+
+/**
+ * A relation produced by applying `func` to each partition of the `child`, concatenating the
+ * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to
+ * decode/encode from the JVM object representation expected by `func`.
+ */
+case class AppendColumn[T, U](
+ func: T => U,
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ newColumns: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output ++ newColumns
+ override def missingInput: AttributeSet = super.missingInput -- newColumns
+}
+
+/** Factory for constructing new `MapGroups` nodes. */
+object MapGroups {
+ def apply[K : Encoder, T : Encoder, U : Encoder](
+ func: (K, Iterator[T]) => Iterator[U],
+ groupingAttributes: Seq[Attribute],
+ child: LogicalPlan): MapGroups[K, T, U] = {
+ new MapGroups(
+ func,
+ implicitly[Encoder[K]],
+ implicitly[Encoder[T]],
+ implicitly[Encoder[U]],
+ groupingAttributes,
+ implicitly[Encoder[U]].schema.toAttributes,
+ child)
+ }
+}
+
+/**
+ * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
+ * Func is invoked with an object representation of the grouping key and an iterator containing the
+ * object representation of all the rows with that key.
+ */
+case class MapGroups[K, T, U](
+ func: (K, Iterator[T]) => Iterator[U],
+ kEncoder: Encoder[K],
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ groupingAttributes: Seq[Attribute],
+ output: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode {
+ override def missingInput: AttributeSet = AttributeSet.empty
+}
+
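A hedged sketch of how the factory objects above are intended to be used (the implicit encoder and the `LocalRelation` child are stand-ins chosen for illustration):

```scala
import org.apache.spark.sql.catalyst.encoders.{Encoder, IntEncoder}
import org.apache.spark.sql.catalyst.plans.logical._

// An implicit encoder satisfies the context bounds on AppendColumn.apply.
implicit val intEncoder: Encoder[Int] = IntEncoder()

// Build an AppendColumn node that computes a derived Int column from each
// input object; LocalRelation stands in for an arbitrary child plan.
val child = LocalRelation(intEncoder.schema.toAttributes)
val appended = AppendColumn((i: Int) => i * 2, child)
// appended.output is child.output plus one new column (named "value" by the
// default IntEncoder schema).
```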
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
new file mode 100644
index 0000000000..52f8383fac
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.SparkFunSuite
+
+class PrimitiveEncoderSuite extends SparkFunSuite {
+ test("long encoder") {
+ val enc = new LongEncoder()
+ val row = enc.toRow(10)
+ assert(row.getLong(0) == 10)
+ assert(enc.fromRow(row) == 10)
+ }
+
+ test("int encoder") {
+ val enc = new IntEncoder()
+ val row = enc.toRow(10)
+ assert(row.getInt(0) == 10)
+ assert(enc.fromRow(row) == 10)
+ }
+
+ test("string encoder") {
+ val enc = new StringEncoder()
+ val row = enc.toRow("test")
+ assert(row.getString(0) == "test")
+ assert(enc.fromRow(row) == "test")
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
index 02e43ddb35..7735acbcba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -248,12 +248,16 @@ class ProductEncoderSuite extends SparkFunSuite {
val types =
convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
- val encodedData = convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
- case (a: ArrayData, StructField(_, at: ArrayType, _, _)) =>
- a.toArray[Any](at.elementType).toSeq
- case (other, _) =>
- other
- }.mkString("[", ",", "]")
+ val encodedData = try {
+ convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
+ case (a: ArrayData, StructField(_, at: ArrayType, _, _)) =>
+ a.toArray[Any](at.elementType).toSeq
+ case (other, _) =>
+ other
+ }.mkString("[", ",", "]")
+ } catch {
+ case e: Throwable => s"Failed to toSeq: $e"
+ }
fail(
s"""Encoded/Decoded data does not match input data
@@ -272,8 +276,9 @@ class ProductEncoderSuite extends SparkFunSuite {
|Construct Expressions:
|${boundEncoder.constructExpression.treeString}
|
- """.stripMargin)
+ """.stripMargin)
+ }
}
- }
+
}
}