Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 43
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala | 101
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala | 38
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 217
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala | 47
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala) | 29
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala | 100
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala | 173
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 28
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala) | 39
11 files changed, 304 insertions, 516 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c25161ee81..9cbb7c2ffd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -146,6 +146,10 @@ trait ScalaReflection {
* row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
+ *
+ * When used on a primitive type, the constructor will instead default to extracting the value
+ * from ordinal 0 (since there are no names to map to). The actual location can be moved by
+ * calling unbind/bind with a new schema.
*/
def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None)
@@ -159,8 +163,14 @@ trait ScalaReflection {
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))
+ /** Returns the current path with a field at ordinal extracted. */
+ def addToPathOrdinal(ordinal: Int, dataType: DataType) =
+ path
+ .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal))
+ .getOrElse(BoundReference(ordinal, dataType, false))
+
/** Returns the current path or throws an error. */
- def getPath = path.getOrElse(sys.error("Constructors must start at a class type"))
+ def getPath = path.getOrElse(BoundReference(0, dataTypeFor(tpe), true))
tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] =>
@@ -387,12 +397,17 @@ trait ScalaReflection {
val className: String = t.erasure.typeSymbol.asClass.fullName
val cls = Utils.classForName(className)
- val arguments = params.head.map { p =>
+ val arguments = params.head.zipWithIndex.map { case (p, i) =>
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
- val dataType = dataTypeFor(fieldType)
+ val dataType = schemaFor(fieldType).dataType
- constructorFor(fieldType, Some(addToPath(fieldName)))
+ // For tuples, we grab the inner fields by ordinal instead of name.
+ if (className startsWith "scala.Tuple") {
+ constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+ } else {
+ constructorFor(fieldType, Some(addToPath(fieldName)))
+ }
}
val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
@@ -413,7 +428,10 @@ trait ScalaReflection {
/** Returns expressions for extracting all the fields from the given type. */
def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
ScalaReflectionLock.synchronized {
- extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct]
+ extractorFor(inputObject, typeTag[T].tpe) match {
+ case s: CreateNamedStruct => s
+ case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil)
+ }
}
}
@@ -602,6 +620,21 @@ trait ScalaReflection {
case t if t <:< localTypeOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
+ case t if t <:< definitions.IntTpe =>
+ BoundReference(0, IntegerType, false)
+ case t if t <:< definitions.LongTpe =>
+ BoundReference(0, LongType, false)
+ case t if t <:< definitions.DoubleTpe =>
+ BoundReference(0, DoubleType, false)
+ case t if t <:< definitions.FloatTpe =>
+ BoundReference(0, FloatType, false)
+ case t if t <:< definitions.ShortTpe =>
+ BoundReference(0, ShortType, false)
+ case t if t <:< definitions.ByteTpe =>
+ BoundReference(0, ByteType, false)
+ case t if t <:< definitions.BooleanTpe =>
+ BoundReference(0, BooleanType, false)
+
case other =>
throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
}
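The primitive cases added above mean that an encoder for a bare Int (or Long, Double, and so on) extracts its value from ordinal 0 under the default column name `value`, exactly as the updated constructorFor doc describes. A minimal round-trip sketch, mirroring what the new ExpressionEncoderSuite exercises (illustrative, not part of this diff):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // A bare Int encodes to a single-column row named "value" and is
    // decoded back through BoundReference(0, IntegerType, false).
    val intEnc = ExpressionEncoder[Int]()
    val row    = intEnc.toRow(42)                   // InternalRow with schema [value: int]
    val attrs  = intEnc.schema.toAttributes
    val bound  = intEnc.resolve(attrs).bind(attrs)  // resolve names, then bind ordinals
    assert(bound.fromRow(row) == 42)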
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
deleted file mode 100644
index b484b8fde6..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.sql.types.{ObjectType, StructType}
-
-/**
- * A generic encoder for JVM objects.
- *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- * extract the values from a raw object.
- * @param clsTag A classtag for `T`.
- */
-case class ClassEncoder[T](
- schema: StructType,
- extractExpressions: Seq[Expression],
- constructExpression: Expression,
- clsTag: ClassTag[T])
- extends Encoder[T] {
-
- @transient
- private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
- private val inputRow = new GenericMutableRow(1)
-
- @transient
- private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
- private val dataType = ObjectType(clsTag.runtimeClass)
-
- override def toRow(t: T): InternalRow = {
- inputRow(0) = t
- extractProjection(inputRow)
- }
-
- override def fromRow(row: InternalRow): T = {
- constructProjection(row).get(0, dataType).asInstanceOf[T]
- }
-
- override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
- val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
- val analyzedPlan = SimpleAnalyzer.execute(plan)
- val resolvedExpression = analyzedPlan.expressions.head.children.head
- val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
-
- copy(constructExpression = boundExpression)
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = {
- val positionToAttribute = AttributeMap.toIndex(oldSchema)
- val attributeToNewPosition = AttributeMap.byIndex(newSchema)
- copy(constructExpression = constructExpression transform {
- case r: BoundReference =>
- r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
- })
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = {
- var remaining = schema
- copy(constructExpression = constructExpression transform {
- case u: UnresolvedAttribute =>
- val pos = remaining.head
- remaining = remaining.drop(1)
- pos
- })
- }
-
- protected val attrs = extractExpressions.map(_.collect {
- case a: Attribute => s"#${a.exprId}"
- case b: BoundReference => s"[${b.ordinal}]"
- }.headOption.getOrElse(""))
-
-
- protected val schemaString =
- schema
- .zip(attrs)
- .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
-
- override def toString: String = s"class[$schemaString]"
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index efb872ddb8..329a132d3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -18,10 +18,9 @@
package org.apache.spark.sql.catalyst.encoders
+
import scala.reflect.ClassTag
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
/**
@@ -30,44 +29,11 @@ import org.apache.spark.sql.types.StructType
* Encoders are not intended to be thread-safe and thus they are allowed to avoid internal locking
* and reuse internal buffers to improve performance.
*/
-trait Encoder[T] {
+trait Encoder[T] extends Serializable {
/** Returns the schema of encoding this type of object as a Row. */
def schema: StructType
/** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
def clsTag: ClassTag[T]
-
- /**
- * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
- * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
- * copy the result before making another call if required.
- */
- def toRow(t: T): InternalRow
-
- /**
- * Returns an object of type `T`, extracting the required values from the provided row. Note that
- * you must `bind` an encoder to a specific schema before you can call this function.
- */
- def fromRow(row: InternalRow): T
-
- /**
- * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the
- * given schema.
- */
- def bind(schema: Seq[Attribute]): Encoder[T]
-
- /**
- * Binds this encoder to the given schema positionally. In this binding, the first reference to
- * any input is mapped to `schema(0)`, and so on for each input that is encountered.
- */
- def bindOrdinals(schema: Seq[Attribute]): Encoder[T]
-
- /**
- * Given an encoder that has already been bound to a given schema, returns a new encoder that
- * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
- * when you are trying to use an encoder on grouping keys that were orriginally part of a larger
- * row, but now you have projected out only the key expressions.
- */
- def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T]
}
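With toRow/fromRow and the bind family removed, the public trait is now pure metadata; all conversion and binding lives on ExpressionEncoder. A hedged sketch of what remains reachable through the trait (illustrative, not part of this diff):

    import org.apache.spark.sql.catalyst.encoders.{Encoder, ExpressionEncoder}

    // The slimmed trait exposes only the row schema and a ClassTag.
    val enc: Encoder[Long] = ExpressionEncoder[Long]()
    val rowSchema = enc.schema   // StructType with one LongType field named "value"
    val tag       = enc.clsTag   // ClassTag[Long], usable to allocate an Array[Long]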
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
new file mode 100644
index 0000000000..c287aebeee
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.util.Utils
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
+
+/**
+ * A factory for constructing encoders that convert objects and primitives to and from the
+ * internal row format using catalyst expressions and code generation. By default, the
+ * expressions used to retrieve values from an input row when producing an object will be created as
+ * follows:
+ * - Classes will have their subfields extracted by name using [[UnresolvedAttribute]] expressions
+ * and [[UnresolvedExtractValue]] expressions.
+ * - Tuples will have their subfields extracted by position using [[BoundReference]] expressions.
+ * - Primitives will have their values extracted from the first ordinal with a schema that defaults
+ * to the name `value`.
+ */
+object ExpressionEncoder {
+ def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = {
+ // We convert the not-serializable TypeTag into StructType and ClassTag.
+ val mirror = typeTag[T].mirror
+ val cls = mirror.runtimeClass(typeTag[T].tpe)
+
+ val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+ val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
+ val constructExpression = ScalaReflection.constructorFor[T]
+
+ new ExpressionEncoder[T](
+ extractExpression.dataType,
+ flat,
+ extractExpression.flatten,
+ constructExpression,
+ ClassTag[T](cls))
+ }
+
+ /**
+ * Given a set of N encoders, constructs a new encoder that produces objects as items in an
+ * N-tuple. Note that these encoders should first be bound correctly to the combined input
+ * schema.
+ */
+ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+ val schema =
+ StructType(
+ encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)})
+ val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
+ val extractExpressions = encoders.map {
+ case e if e.flat => e.extractExpressions.head
+ case other => CreateStruct(other.extractExpressions)
+ }
+ val constructExpression =
+ NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls))
+
+ new ExpressionEncoder[Any](
+ schema,
+ false,
+ extractExpressions,
+ constructExpression,
+ ClassTag.apply(cls))
+ }
+
+ /** A helper for producing encoders of Tuple2 from other encoders. */
+ def tuple[T1, T2](
+ e1: ExpressionEncoder[T1],
+ e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
+ tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]]
+}
+
+/**
+ * A generic encoder for JVM objects.
+ *
+ * @param schema The schema after converting `T` to a Spark SQL row.
+ * @param extractExpressions A set of expressions, one for each top-level field that can be used to
+ * extract the values from a raw object.
+ * @param clsTag A classtag for `T`.
+ */
+case class ExpressionEncoder[T](
+ schema: StructType,
+ flat: Boolean,
+ extractExpressions: Seq[Expression],
+ constructExpression: Expression,
+ clsTag: ClassTag[T])
+ extends Encoder[T] {
+
+ if (flat) require(extractExpressions.size == 1)
+
+ @transient
+ private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+ private val inputRow = new GenericMutableRow(1)
+
+ @transient
+ private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+
+ /**
+ * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
+ * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
+ * copy the result before making another call if required.
+ */
+ def toRow(t: T): InternalRow = {
+ inputRow(0) = t
+ extractProjection(inputRow)
+ }
+
+ /**
+ * Returns an object of type `T`, extracting the required values from the provided row. Note that
+ * you must `resolve` and `bind` an encoder to a specific schema before you can call this
+ * function.
+ */
+ def fromRow(row: InternalRow): T = try {
+ constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
+ } catch {
+ case e: Exception =>
+ throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
+ }
+
+ /**
+ * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
+ * given schema.
+ */
+ def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+ val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
+ val analyzedPlan = SimpleAnalyzer.execute(plan)
+ copy(constructExpression = analyzedPlan.expressions.head.children.head)
+ }
+
+ /**
+ * Returns a copy of this encoder where the expressions used to construct an object from an input
+ * row have been bound to the ordinals of the given schema. Note that `resolve` must be
+ * called before `bind`.
+ */
+ def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+ copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
+ }
+
+ /**
+ * Replaces any bound references in the schema with the attributes at the corresponding ordinal
+ * in the provided schema. This can be used to "relocate" a given encoder to pull values from
+ * a different schema than it was initially bound to. It can also be used to assign attributes
+ * to ordinal-based extraction (i.e. because the input data was a tuple).
+ */
+ def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+ val positionToAttribute = AttributeMap.toIndex(schema)
+ copy(constructExpression = constructExpression transform {
+ case b: BoundReference => positionToAttribute(b.ordinal)
+ })
+ }
+
+ /**
+ * Given an encoder that has already been bound to a given schema, returns a new encoder
+ * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
+ * when you are trying to use an encoder on grouping keys that were originally part of a larger
+ * row, but now you have projected out only the key expressions.
+ */
+ def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
+ val positionToAttribute = AttributeMap.toIndex(oldSchema)
+ val attributeToNewPosition = AttributeMap.byIndex(newSchema)
+ copy(constructExpression = constructExpression transform {
+ case r: BoundReference =>
+ r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
+ })
+ }
+
+ /**
+ * Returns a copy of this encoder where the expressions used to create an object given an
+ * input row have been modified to pull the object out from a nested struct, instead of the
+ * top level fields.
+ */
+ def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
+ copy(constructExpression = constructExpression transform {
+ case u: Attribute if u != input =>
+ UnresolvedExtractValue(input, Literal(u.name))
+ case b: BoundReference if b != input =>
+ GetStructField(
+ input,
+ StructField(s"i[${b.ordinal}]", b.dataType),
+ b.ordinal)
+ })
+ }
+
+ protected val attrs = extractExpressions.flatMap(_.collect {
+ case _: UnresolvedAttribute => ""
+ case a: Attribute => s"#${a.exprId}"
+ case b: BoundReference => s"[${b.ordinal}]"
+ })
+
+ protected val schemaString =
+ schema
+ .zip(attrs)
+ .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
+
+ override def toString: String = s"class[$schemaString]"
+}
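A hedged sketch of composing encoders with the tuple factory above (illustrative; as the scaladoc notes, the component encoders must still be resolved and bound against the combined schema before fromRow is usable):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    val intEnc  = ExpressionEncoder[Int]()
    val longEnc = ExpressionEncoder[Long]()

    // Yields an encoder with schema struct<_1: int, _2: bigint> whose
    // constructExpression builds a scala.Tuple2 through NewInstance.
    val pairEnc: ExpressionEncoder[(Int, Long)] = ExpressionEncoder.tuple(intEnc, longEnc)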
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
deleted file mode 100644
index 34f5e6c030..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{typeTag, TypeTag}
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{ObjectType, StructType}
-
-/**
- * A factory for constructing encoders that convert Scala's product type to/from the Spark SQL
- * internal binary representation.
- */
-object ProductEncoder {
- def apply[T <: Product : TypeTag]: ClassEncoder[T] = {
- // We convert the not-serializable TypeTag into StructType and ClassTag.
- val mirror = typeTag[T].mirror
- val cls = mirror.runtimeClass(typeTag[T].tpe)
-
- val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
- val constructExpression = ScalaReflection.constructorFor[T]
-
- new ClassEncoder[T](
- extractExpression.dataType,
- extractExpression.flatten,
- constructExpression,
- ClassTag[T](cls))
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index e9cc00a2b6..0b42130a01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -31,13 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
* internal binary representation.
*/
object RowEncoder {
- def apply(schema: StructType): ClassEncoder[Row] = {
+ def apply(schema: StructType): ExpressionEncoder[Row] = {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
val extractExpressions = extractorsFor(inputObject, schema)
val constructExpression = constructorFor(schema)
- new ClassEncoder[Row](
+ new ExpressionEncoder[Row](
schema,
+ flat = false,
extractExpressions.asInstanceOf[CreateStruct].children,
constructExpression,
ClassTag(cls))
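RowEncoder now hands back an ExpressionEncoder[Row] built from an explicit schema; a usage sketch (field names and values are illustrative, not from this diff):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)))

    // flat = false above: external Rows always encode as a multi-column struct.
    val rowEnc      = RowEncoder(schema)
    val internalRow = rowEnc.toRow(Row("alice", 30))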
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index 52f8383fac..d4642a5006 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -15,29 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.encoders
+package org.apache.spark.sql.catalyst
-import org.apache.spark.SparkFunSuite
-
-class PrimitiveEncoderSuite extends SparkFunSuite {
- test("long encoder") {
- val enc = new LongEncoder()
- val row = enc.toRow(10)
- assert(row.getLong(0) == 10)
- assert(enc.fromRow(row) == 10)
- }
-
- test("int encoder") {
- val enc = new IntEncoder()
- val row = enc.toRow(10)
- assert(row.getInt(0) == 10)
- assert(enc.fromRow(row) == 10)
- }
-
- test("string encoder") {
- val enc = new StringEncoder()
- val row = enc.toRow("test")
- assert(row.getString(0) == "test")
- assert(enc.fromRow(row) == "test")
+package object encoders {
+ private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
+ case e: ExpressionEncoder[A] => e
+ case _ => sys.error(s"Only expression encoders are supported today")
}
}
+
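A sketch of how encoderFor is meant to be called (the implicit value here is illustrative; the real call sites are the basicOperators factories below):

    import org.apache.spark.sql.catalyst.encoders._

    // Given implicit Encoder evidence, encoderFor recovers the concrete
    // ExpressionEncoder, or fails fast for any other implementation.
    implicit val intEnc: Encoder[Int] = ExpressionEncoder[Int]()
    val concrete: ExpressionEncoder[Int] = encoderFor[Int]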
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
deleted file mode 100644
index a93f2d7c61..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.sql.types._
-
-/** An encoder for primitive Long types. */
-case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] {
- private val row = UnsafeRow.createFromByteArray(64, 1)
-
- override def clsTag: ClassTag[Long] = ClassTag.Long
- override def schema: StructType =
- StructType(StructField(fieldName, LongType) :: Nil)
-
- override def fromRow(row: InternalRow): Long = row.getLong(ordinal)
-
- override def toRow(t: Long): InternalRow = {
- row.setLong(ordinal, t)
- row
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this
- override def bind(schema: Seq[Attribute]): Encoder[Long] = this
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this
-}
-
-/** An encoder for primitive Integer types. */
-case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] {
- private val row = UnsafeRow.createFromByteArray(64, 1)
-
- override def clsTag: ClassTag[Int] = ClassTag.Int
- override def schema: StructType =
- StructType(StructField(fieldName, IntegerType) :: Nil)
-
- override def fromRow(row: InternalRow): Int = row.getInt(ordinal)
-
- override def toRow(t: Int): InternalRow = {
- row.setInt(ordinal, t)
- row
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this
- override def bind(schema: Seq[Attribute]): Encoder[Int] = this
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this
-}
-
-/** An encoder for String types. */
-case class StringEncoder(
- fieldName: String = "value",
- ordinal: Int = 0) extends Encoder[String] {
-
- val record = new SpecificMutableRow(StringType :: Nil)
-
- @transient
- lazy val projection =
- GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil)
-
- override def schema: StructType =
- StructType(
- StructField("value", StringType, nullable = false) :: Nil)
-
- override def clsTag: ClassTag[String] = scala.reflect.classTag[String]
-
-
- override final def fromRow(row: InternalRow): String = {
- row.getString(ordinal)
- }
-
- override final def toRow(value: String): InternalRow = {
- val utf8String = UTF8String.fromString(value)
- record(0) = utf8String
- // TODO: this is a bit of a hack to produce UnsafeRows
- projection(record)
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this
- override def bind(schema: Seq[Attribute]): Encoder[String] = this
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
deleted file mode 100644
index a48eeda7d2..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
+++ /dev/null
@@ -1,173 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.types.{StructField, StructType}
-
-// Most of this file is codegen.
-// scalastyle:off
-
-/**
- * A set of composite encoders that take sub encoders and map each of their objects to a
- * Scala tuple. Note that currently the implementation is fairly limited and only supports going
- * from an internal row to a tuple.
- */
-object TupleEncoder {
-
- /** Code generator for composite tuple encoders. */
- def main(args: Array[String]): Unit = {
- (2 to 5).foreach { i =>
- val types = (1 to i).map(t => s"T$t").mkString(", ")
- val tupleType = s"($types)"
- val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ")
- val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ")
- val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ")
-
- println(
- s"""
- |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] {
- | val schema = StructType(Array($fields))
- |
- | def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType]
- |
- | def fromRow(row: InternalRow): $tupleType = {
- | ($fromRow)
- | }
- |
- | override def toRow(t: $tupleType): InternalRow =
- | throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
- |
- | override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = {
- | this
- | }
- |
- | override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] =
- | throw new UnsupportedOperationException("Tuple Encoders only support bind.")
- |
- |
- | override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] =
- | throw new UnsupportedOperationException("Tuple Encoders only support bind.")
- |}
- """.stripMargin)
- }
- }
-}
-
-class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema)))
-
- def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)]
-
- def fromRow(row: InternalRow): (T1, T2) = {
- (e1.fromRow(row), e2.fromRow(row))
- }
-
- override def toRow(t: (T1, T2)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema)))
-
- def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)]
-
- def fromRow(row: InternalRow): (T1, T2, T3) = {
- (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row))
- }
-
- override def toRow(t: (T1, T2, T3)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema)))
-
- def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, T3, T4)]
-
- def fromRow(row: InternalRow): (T1, T2, T3, T4) = {
- (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row))
- }
-
- override def toRow(t: (T1, T2, T3, T4)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema)))
-
- def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)]
-
- def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = {
- (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row))
- }
-
- override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 21a55a5371..d2d3db0a44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
import org.apache.spark.sql.catalyst.plans._
@@ -450,8 +450,8 @@ case object OneRowRelation extends LeafNode {
*/
case class MapPartitions[T, U](
func: Iterator[T] => Iterator[U],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def missingInput: AttributeSet = AttributeSet.empty
@@ -460,8 +460,8 @@ case class MapPartitions[T, U](
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumn {
def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
- val attrs = implicitly[Encoder[U]].schema.toAttributes
- new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child)
+ val attrs = encoderFor[U].schema.toAttributes
+ new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
}
}
@@ -472,8 +472,8 @@ object AppendColumn {
*/
case class AppendColumn[T, U](
func: T => U,
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
newColumns: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output ++ newColumns
@@ -488,11 +488,11 @@ object MapGroups {
child: LogicalPlan): MapGroups[K, T, U] = {
new MapGroups(
func,
- implicitly[Encoder[K]],
- implicitly[Encoder[T]],
- implicitly[Encoder[U]],
+ encoderFor[K],
+ encoderFor[T],
+ encoderFor[U],
groupingAttributes,
- implicitly[Encoder[U]].schema.toAttributes,
+ encoderFor[U].schema.toAttributes,
child)
}
}
@@ -504,9 +504,9 @@ object MapGroups {
*/
case class MapGroups[K, T, U](
func: (K, Iterator[T]) => Iterator[U],
- kEncoder: Encoder[K],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ kEncoder: ExpressionEncoder[K],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
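A hedged sketch of the factory pattern used throughout this file: the `T : Encoder` context bounds supply implicit evidence, and encoderFor narrows it to the ExpressionEncoder the operators carry (the LocalRelation child and implicit values are illustrative):

    import org.apache.spark.sql.catalyst.encoders._
    import org.apache.spark.sql.catalyst.plans.logical.{AppendColumn, LocalRelation}

    implicit val intEnc: Encoder[Int]   = ExpressionEncoder[Int]()
    implicit val longEnc: Encoder[Long] = ExpressionEncoder[Long]()

    val child    = LocalRelation(encoderFor[Int].schema.toAttributes)
    val appended = AppendColumn[Int, Long]((i: Int) => i.toLong, child)
    // appended.output == child.output ++ encoderFor[Long].schema.toAttributes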
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 008d0bea8a..a374da4da1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -47,7 +47,16 @@ case class RepeatedData(
case class SpecificCollection(l: List[Int])
-class ProductEncoderSuite extends SparkFunSuite {
+class ExpressionEncoderSuite extends SparkFunSuite {
+
+ encodeDecodeTest(1)
+ encodeDecodeTest(1L)
+ encodeDecodeTest(1.toDouble)
+ encodeDecodeTest(1.toFloat)
+ encodeDecodeTest(true)
+ encodeDecodeTest(false)
+ encodeDecodeTest(1.toShort)
+ encodeDecodeTest(1.toByte)
encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
@@ -210,24 +219,24 @@ class ProductEncoderSuite extends SparkFunSuite {
{ (l, r) => l._2.toString == r._2.toString }
/** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */
- protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T) =
+ protected def encodeDecodeTest[T : TypeTag](inputData: T) =
encodeDecodeTestCustom[T](inputData)((l, r) => l == r)
/**
* Constructs a test that round-trips `t` through an encoder, checking the results to ensure it
* matches the original.
*/
- protected def encodeDecodeTestCustom[T <: Product : TypeTag](
+ protected def encodeDecodeTestCustom[T : TypeTag](
inputData: T)(
c: (T, T) => Boolean) = {
- test(s"encode/decode: $inputData") {
- val encoder = try ProductEncoder[T] catch {
+ test(s"encode/decode: $inputData - ${inputData.getClass.getName}") {
+ val encoder = try ExpressionEncoder[T]() catch {
case e: Exception =>
fail(s"Exception thrown generating encoder", e)
}
val convertedData = encoder.toRow(inputData)
val schema = encoder.schema.toAttributes
- val boundEncoder = encoder.bind(schema)
+ val boundEncoder = encoder.resolve(schema).bind(schema)
val convertedBack = try boundEncoder.fromRow(convertedData) catch {
case e: Exception =>
fail(
@@ -236,15 +245,19 @@ class ProductEncoderSuite extends SparkFunSuite {
|Schema: ${schema.mkString(",")}
|${encoder.schema.treeString}
|
- |Construct Expressions:
- |${boundEncoder.constructExpression.treeString}
+ |Encoder:
+ |$boundEncoder
|
""".stripMargin, e)
}
if (!c(inputData, convertedBack)) {
- val types =
- convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
+ val types = convertedBack match {
+ case c: Product =>
+ c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
+ case other => other.getClass.getName
+ }
+
val encodedData = try {
convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
@@ -269,11 +282,7 @@ class ProductEncoderSuite extends SparkFunSuite {
|${encoder.schema.treeString}
|
|Extract Expressions:
- |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")}
- |
- |Construct Expressions:
- |${boundEncoder.constructExpression.treeString}
- |
+ |$boundEncoder
""".stripMargin)
}
}