diff options
Diffstat (limited to 'sql/catalyst')
11 files changed, 304 insertions, 516 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c25161ee81..9cbb7c2ffd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -146,6 +146,10 @@ trait ScalaReflection { * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes * of the same name as the constructor arguments. Nested classes will have their fields accessed * using UnresolvedExtractValue. + * + * When used on a primitive type, the constructor will instead default to extracting the value + * from ordinal 0 (since there are no names to map to). The actual location can be moved by + * calling unbind/bind with a new schema. */ def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) @@ -159,8 +163,14 @@ trait ScalaReflection { .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) + /** Returns the current path with a field at ordinal extracted. */ + def addToPathOrdinal(ordinal: Int, dataType: DataType) = + path + .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal)) + .getOrElse(BoundReference(ordinal, dataType, false)) + /** Returns the current path or throws an error. 
*/ - def getPath = path.getOrElse(sys.error("Constructors must start at a class type")) + def getPath = path.getOrElse(BoundReference(0, dataTypeFor(tpe), true)) tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => @@ -387,12 +397,17 @@ trait ScalaReflection { val className: String = t.erasure.typeSymbol.asClass.fullName val cls = Utils.classForName(className) - val arguments = params.head.map { p => + val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val dataType = dataTypeFor(fieldType) + val dataType = schemaFor(fieldType).dataType - constructorFor(fieldType, Some(addToPath(fieldName))) + // For tuples, we grab the inner fields by ordinal instead of name. + if (className startsWith "scala.Tuple") { + constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) + } else { + constructorFor(fieldType, Some(addToPath(fieldName))) + } } val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) @@ -413,7 +428,10 @@ trait ScalaReflection { /** Returns expressions for extracting all the fields from the given type. 
*/ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { ScalaReflectionLock.synchronized { - extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct] + extractorFor(inputObject, typeTag[T].tpe) match { + case s: CreateNamedStruct => s + case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil) + } } } @@ -602,6 +620,21 @@ trait ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) + case t if t <:< definitions.IntTpe => + BoundReference(0, IntegerType, false) + case t if t <:< definitions.LongTpe => + BoundReference(0, LongType, false) + case t if t <:< definitions.DoubleTpe => + BoundReference(0, DoubleType, false) + case t if t <:< definitions.FloatTpe => + BoundReference(0, FloatType, false) + case t if t <:< definitions.ShortTpe => + BoundReference(0, ShortType, false) + case t if t <:< definitions.ByteTpe => + BoundReference(0, ByteType, false) + case t if t <:< definitions.BooleanTpe => + BoundReference(0, BooleanType, false) + case other => throw new UnsupportedOperationException(s"Extractor for type $other is not supported") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala deleted file mode 100644 index b484b8fde6..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} -import org.apache.spark.sql.types.{ObjectType, StructType} - -/** - * A generic encoder for JVM objects. - * - * @param schema The schema after converting `T` to a Spark SQL row. - * @param extractExpressions A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object. - * @param clsTag A classtag for `T`. 
- */ -case class ClassEncoder[T]( - schema: StructType, - extractExpressions: Seq[Expression], - constructExpression: Expression, - clsTag: ClassTag[T]) - extends Encoder[T] { - - @transient - private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) - private val inputRow = new GenericMutableRow(1) - - @transient - private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) - private val dataType = ObjectType(clsTag.runtimeClass) - - override def toRow(t: T): InternalRow = { - inputRow(0) = t - extractProjection(inputRow) - } - - override def fromRow(row: InternalRow): T = { - constructProjection(row).get(0, dataType).asInstanceOf[T] - } - - override def bind(schema: Seq[Attribute]): ClassEncoder[T] = { - val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema)) - val analyzedPlan = SimpleAnalyzer.execute(plan) - val resolvedExpression = analyzedPlan.expressions.head.children.head - val boundExpression = BindReferences.bindReference(resolvedExpression, schema) - - copy(constructExpression = boundExpression) - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(oldSchema) - val attributeToNewPosition = AttributeMap.byIndex(newSchema) - copy(constructExpression = constructExpression transform { - case r: BoundReference => - r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) - }) - } - - override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = { - var remaining = schema - copy(constructExpression = constructExpression transform { - case u: UnresolvedAttribute => - val pos = remaining.head - remaining = remaining.drop(1) - pos - }) - } - - protected val attrs = extractExpressions.map(_.collect { - case a: Attribute => s"#${a.exprId}" - case b: BoundReference => s"[${b.ordinal}]" - }.headOption.getOrElse("")) - - - protected val 
schemaString = - schema - .zip(attrs) - .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ") - - override def toString: String = s"class[$schemaString]" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index efb872ddb8..329a132d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql.catalyst.encoders + import scala.reflect.ClassTag -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType /** @@ -30,44 +29,11 @@ import org.apache.spark.sql.types.StructType * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking * and reuse internal buffers to improve performance. */ -trait Encoder[T] { +trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ def schema: StructType /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ def clsTag: ClassTag[T] - - /** - * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to - * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should - * copy the result before making another call if required. - */ - def toRow(t: T): InternalRow - - /** - * Returns an object of type `T`, extracting the required values from the provided row. Note that - * you must `bind` an encoder to a specific schema before you can call this function. - */ - def fromRow(row: InternalRow): T - - /** - * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the - * given schema. 
- */ - def bind(schema: Seq[Attribute]): Encoder[T] - - /** - * Binds this encoder to the given schema positionally. In this binding, the first reference to - * any input is mapped to `schema(0)`, and so on for each input that is encountered. - */ - def bindOrdinals(schema: Seq[Attribute]): Encoder[T] - - /** - * Given an encoder that has already been bound to a given schema, returns a new encoder that - * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, - * when you are trying to use an encoder on grouping keys that were orriginally part of a larger - * row, but now you have projected out only the key expressions. - */ - def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala new file mode 100644 index 0000000000..c287aebeee --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.util.Utils + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag} + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType} + +/** + * A factory for constructing encoders that convert objects and primitives to and from the + * internal row format using catalyst expressions and code generation. By default, the + * expressions used to retrieve values from an input row when producing an object will be created as + * follows: + * - Classes will have their sub fields extracted by name using [[UnresolvedAttribute]] expressions + * and [[UnresolvedExtractValue]] expressions. + * - Tuples will have their subfields extracted by position using [[BoundReference]] expressions. + * - Primitives will have their values extracted from the first ordinal with a schema that defaults + * to the name `value`. + */ +object ExpressionEncoder { + def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. 
+ val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(typeTag[T].tpe) + + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val extractExpression = ScalaReflection.extractorsFor[T](inputObject) + val constructExpression = ScalaReflection.constructorFor[T] + + new ExpressionEncoder[T]( + extractExpression.dataType, + flat, + extractExpression.flatten, + constructExpression, + ClassTag[T](cls)) + } + + /** + * Given a set of N encoders, constructs a new encoder that produces objects as items in an + * N-tuple. Note that these encoders should first be bound correctly to the combined input + * schema. + */ + def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + val schema = + StructType( + encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)}) + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") + val extractExpressions = encoders.map { + case e if e.flat => e.extractExpressions.head + case other => CreateStruct(other.extractExpressions) + } + val constructExpression = + NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls)) + + new ExpressionEncoder[Any]( + schema, + false, + extractExpressions, + constructExpression, + ClassTag.apply(cls)) + } + + /** A helper for producing encoders of Tuple2 from other encoders. */ + def tuple[T1, T2]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = + tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]] +} + +/** + * A generic encoder for JVM objects. + * + * @param schema The schema after converting `T` to a Spark SQL row. + * @param extractExpressions A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object. + * @param clsTag A classtag for `T`. 
+ */ +case class ExpressionEncoder[T]( + schema: StructType, + flat: Boolean, + extractExpressions: Seq[Expression], + constructExpression: Expression, + clsTag: ClassTag[T]) + extends Encoder[T] { + + if (flat) require(extractExpressions.size == 1) + + @transient + private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) + private val inputRow = new GenericMutableRow(1) + + @transient + private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) + + /** + * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to + * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should + * copy the result before making another call if required. + */ + def toRow(t: T): InternalRow = { + inputRow(0) = t + extractProjection(inputRow) + } + + /** + * Returns an object of type `T`, extracting the required values from the provided row. Note that + * you must `resolve` and `bind` an encoder to a specific schema before you can call this + * function. + */ + def fromRow(row: InternalRow): T = try { + constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] + } catch { + case e: Exception => + throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e) + } + + /** + * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the + * given schema. + */ + def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { + val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema)) + val analyzedPlan = SimpleAnalyzer.execute(plan) + copy(constructExpression = analyzedPlan.expressions.head.children.head) + } + + /** + * Returns a copy of this encoder where the expressions used to construct an object from an input + * row have been bound to the ordinals of the given schema. Note that you need to first call + * resolve before bind. 
+ */ + def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { + copy(constructExpression = BindReferences.bindReference(constructExpression, schema)) + } + + /** + * Replaces any bound references in the schema with the attributes at the corresponding ordinal + * in the provided schema. This can be used to "relocate" a given encoder to pull values from + * a different schema than it was initially bound to. It can also be used to assign attributes + * to ordinal based extraction (i.e. because the input data was a tuple). + */ + def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = { + val positionToAttribute = AttributeMap.toIndex(schema) + copy(constructExpression = constructExpression transform { + case b: BoundReference => positionToAttribute(b.ordinal) + }) + } + + /** + * Given an encoder that has already been bound to a given schema, returns a new encoder + * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, + * when you are trying to use an encoder on grouping keys that were originally part of a larger + * row, but now you have projected out only the key expressions. + */ + def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = { + val positionToAttribute = AttributeMap.toIndex(oldSchema) + val attributeToNewPosition = AttributeMap.byIndex(newSchema) + copy(constructExpression = constructExpression transform { + case r: BoundReference => + r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) + }) + } + + /** + * Returns a copy of this encoder where the expressions used to create an object given an + * input row have been modified to pull the object out from a nested struct, instead of the + * top level fields. 
+ */ + def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = { + copy(constructExpression = constructExpression transform { + case u: Attribute if u != input => + UnresolvedExtractValue(input, Literal(u.name)) + case b: BoundReference if b != input => + GetStructField( + input, + StructField(s"i[${b.ordinal}]", b.dataType), + b.ordinal) + }) + } + + protected val attrs = extractExpressions.flatMap(_.collect { + case _: UnresolvedAttribute => "" + case a: Attribute => s"#${a.exprId}" + case b: BoundReference => s"[${b.ordinal}]" + }) + + protected val schemaString = + schema + .zip(attrs) + .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ") + + override def toString: String = s"class[$schemaString]" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala deleted file mode 100644 index 34f5e6c030..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag} - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{ObjectType, StructType} - -/** - * A factory for constructing encoders that convert Scala's product type to/from the Spark SQL - * internal binary representation. - */ -object ProductEncoder { - def apply[T <: Product : TypeTag]: ClassEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. - val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(typeTag[T].tpe) - - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpression = ScalaReflection.extractorsFor[T](inputObject) - val constructExpression = ScalaReflection.constructorFor[T] - - new ClassEncoder[T]( - extractExpression.dataType, - extractExpression.flatten, - constructExpression, - ClassTag[T](cls)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e9cc00a2b6..0b42130a01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -31,13 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String * internal binary representation. 
*/ object RowEncoder { - def apply(schema: StructType): ClassEncoder[Row] = { + def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) val extractExpressions = extractorsFor(inputObject, schema) val constructExpression = constructorFor(schema) - new ClassEncoder[Row]( + new ExpressionEncoder[Row]( schema, + flat = false, extractExpressions.asInstanceOf[CreateStruct].children, constructExpression, ClassTag(cls)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala index 52f8383fac..d4642a5006 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala @@ -15,29 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.encoders +package org.apache.spark.sql.catalyst -import org.apache.spark.SparkFunSuite - -class PrimitiveEncoderSuite extends SparkFunSuite { - test("long encoder") { - val enc = new LongEncoder() - val row = enc.toRow(10) - assert(row.getLong(0) == 10) - assert(enc.fromRow(row) == 10) - } - - test("int encoder") { - val enc = new IntEncoder() - val row = enc.toRow(10) - assert(row.getInt(0) == 10) - assert(enc.fromRow(row) == 10) - } - - test("string encoder") { - val enc = new StringEncoder() - val row = enc.toRow("test") - assert(row.getString(0) == "test") - assert(enc.fromRow(row) == "test") +package object encoders { + private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { + case e: ExpressionEncoder[A] => e + case _ => sys.error(s"Only expression encoders are supported today") } } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala deleted file mode 100644 index a93f2d7c61..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.sql.types._ - -/** An encoder for primitive Long types. 
*/ -case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] { - private val row = UnsafeRow.createFromByteArray(64, 1) - - override def clsTag: ClassTag[Long] = ClassTag.Long - override def schema: StructType = - StructType(StructField(fieldName, LongType) :: Nil) - - override def fromRow(row: InternalRow): Long = row.getLong(ordinal) - - override def toRow(t: Long): InternalRow = { - row.setLong(ordinal, t) - row - } - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this - override def bind(schema: Seq[Attribute]): Encoder[Long] = this - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this -} - -/** An encoder for primitive Integer types. */ -case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] { - private val row = UnsafeRow.createFromByteArray(64, 1) - - override def clsTag: ClassTag[Int] = ClassTag.Int - override def schema: StructType = - StructType(StructField(fieldName, IntegerType) :: Nil) - - override def fromRow(row: InternalRow): Int = row.getInt(ordinal) - - override def toRow(t: Int): InternalRow = { - row.setInt(ordinal, t) - row - } - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this - override def bind(schema: Seq[Attribute]): Encoder[Int] = this - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this -} - -/** An encoder for String types. 
*/ -case class StringEncoder( - fieldName: String = "value", - ordinal: Int = 0) extends Encoder[String] { - - val record = new SpecificMutableRow(StringType :: Nil) - - @transient - lazy val projection = - GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil) - - override def schema: StructType = - StructType( - StructField("value", StringType, nullable = false) :: Nil) - - override def clsTag: ClassTag[String] = scala.reflect.classTag[String] - - - override final def fromRow(row: InternalRow): String = { - row.getString(ordinal) - } - - override final def toRow(value: String): InternalRow = { - val utf8String = UTF8String.fromString(value) - record(0) = utf8String - // TODO: this is a bit of a hack to produce UnsafeRows - projection(record) - } - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this - override def bind(schema: Seq[Attribute]): Encoder[String] = this - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala deleted file mode 100644 index a48eeda7d2..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - - -import scala.reflect.ClassTag - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types.{StructField, StructType} - -// Most of this file is codegen. -// scalastyle:off - -/** - * A set of composite encoders that take sub encoders and map each of their objects to a - * Scala tuple. Note that currently the implementation is fairly limited and only supports going - * from an internal row to a tuple. - */ -object TupleEncoder { - - /** Code generator for composite tuple encoders. 
*/ - def main(args: Array[String]): Unit = { - (2 to 5).foreach { i => - val types = (1 to i).map(t => s"T$t").mkString(", ") - val tupleType = s"($types)" - val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ") - val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ") - val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ") - - println( - s""" - |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] { - | val schema = StructType(Array($fields)) - | - | def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType] - | - | def fromRow(row: InternalRow): $tupleType = { - | ($fromRow) - | } - | - | override def toRow(t: $tupleType): InternalRow = - | throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - | - | override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = { - | this - | } - | - | override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] = - | throw new UnsupportedOperationException("Tuple Encoders only support bind.") - | - | - | override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] = - | throw new UnsupportedOperationException("Tuple Encoders only support bind.") - |} - """.stripMargin) - } - } -} - -class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema))) - - def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)] - - def fromRow(row: InternalRow): (T1, T2) = { - (e1.fromRow(row), e2.fromRow(row)) - } - - override def toRow(t: (T1, T2)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] = - throw new UnsupportedOperationException("Tuple 
Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} - - -class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema))) - - def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)] - - def fromRow(row: InternalRow): (T1, T2, T3) = { - (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row)) - } - - override def toRow(t: (T1, T2, T3)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} - - -class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema))) - - def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, T3, T4)] - - def fromRow(row: InternalRow): (T1, T2, T3, T4) = { - (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row)) - } - - override def toRow(t: (T1, T2, T3, T4)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: 
Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} - - -class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema))) - - def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)] - - def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = { - (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row)) - } - - override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 21a55a5371..d2d3db0a44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical 
-import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Utils import org.apache.spark.sql.catalyst.plans._ @@ -450,8 +450,8 @@ case object OneRowRelation extends LeafNode { */ case class MapPartitions[T, U]( func: Iterator[T] => Iterator[U], - tEncoder: Encoder[T], - uEncoder: Encoder[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def missingInput: AttributeSet = AttributeSet.empty @@ -460,8 +460,8 @@ case class MapPartitions[T, U]( /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumn { def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = { - val attrs = implicitly[Encoder[U]].schema.toAttributes - new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child) + val attrs = encoderFor[U].schema.toAttributes + new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child) } } @@ -472,8 +472,8 @@ object AppendColumn { */ case class AppendColumn[T, U]( func: T => U, - tEncoder: Encoder[T], - uEncoder: Encoder[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], newColumns: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output ++ newColumns @@ -488,11 +488,11 @@ object MapGroups { child: LogicalPlan): MapGroups[K, T, U] = { new MapGroups( func, - implicitly[Encoder[K]], - implicitly[Encoder[T]], - implicitly[Encoder[U]], + encoderFor[K], + encoderFor[T], + encoderFor[U], groupingAttributes, - implicitly[Encoder[U]].schema.toAttributes, + encoderFor[U].schema.toAttributes, child) } } @@ -504,9 +504,9 @@ object MapGroups { */ case class MapGroups[K, T, U]( func: (K, Iterator[T]) => Iterator[U], - kEncoder: Encoder[K], - tEncoder: Encoder[T], - 
uEncoder: Encoder[U], + kEncoder: ExpressionEncoder[K], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], groupingAttributes: Seq[Attribute], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 008d0bea8a..a374da4da1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -47,7 +47,16 @@ case class RepeatedData( case class SpecificCollection(l: List[Int]) -class ProductEncoderSuite extends SparkFunSuite { +class ExpressionEncoderSuite extends SparkFunSuite { + + encodeDecodeTest(1) + encodeDecodeTest(1L) + encodeDecodeTest(1.toDouble) + encodeDecodeTest(1.toFloat) + encodeDecodeTest(true) + encodeDecodeTest(false) + encodeDecodeTest(1.toShort) + encodeDecodeTest(1.toByte) encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) @@ -210,24 +219,24 @@ class ProductEncoderSuite extends SparkFunSuite { { (l, r) => l._2.toString == r._2.toString } /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */ - protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T) = + protected def encodeDecodeTest[T : TypeTag](inputData: T) = encodeDecodeTestCustom[T](inputData)((l, r) => l == r) /** * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it * matches the original. 
*/ - protected def encodeDecodeTestCustom[T <: Product : TypeTag]( + protected def encodeDecodeTestCustom[T : TypeTag]( inputData: T)( c: (T, T) => Boolean) = { - test(s"encode/decode: $inputData") { - val encoder = try ProductEncoder[T] catch { + test(s"encode/decode: $inputData - ${inputData.getClass.getName}") { + val encoder = try ExpressionEncoder[T]() catch { case e: Exception => fail(s"Exception thrown generating encoder", e) } val convertedData = encoder.toRow(inputData) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.bind(schema) + val boundEncoder = encoder.resolve(schema).bind(schema) val convertedBack = try boundEncoder.fromRow(convertedData) catch { case e: Exception => fail( @@ -236,15 +245,19 @@ class ProductEncoderSuite extends SparkFunSuite { |Schema: ${schema.mkString(",")} |${encoder.schema.treeString} | - |Construct Expressions: - |${boundEncoder.constructExpression.treeString} + |Encoder: + |$boundEncoder | """.stripMargin, e) } if (!c(inputData, convertedBack)) { - val types = - convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") + val types = convertedBack match { + case c: Product => + c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") + case other => other.getClass.getName + } + val encodedData = try { convertedData.toSeq(encoder.schema).zip(encoder.schema).map { @@ -269,11 +282,7 @@ class ProductEncoderSuite extends SparkFunSuite { |${encoder.schema.treeString} | |Extract Expressions: - |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")} - | - |Construct Expressions: - |${boundEncoder.constructExpression.treeString} - | + |$boundEncoder """.stripMargin) } } |