/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s * for classes whose fields are entirely defined by constructor params but should not be * case classes. */ trait DefinedByConstructorParams /** * A default version of ScalaReflection that uses the runtime universe. */ object ScalaReflection extends ScalaReflection { val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe // Since we are creating a runtime mirror using the class loader of current thread, // we need to use def at here. So, every time we call mirror, it is using the // class loader of the current thread. // SPARK-13640: Synchronize this because universe.runtimeMirror is not thread-safe in Scala 2.10. override def mirror: universe.Mirror = ScalaReflectionLock.synchronized { universe.runtimeMirror(Thread.currentThread().getContextClassLoader) } import universe._ /* * Retrieves the runtime class corresponding to the provided type. */ override def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) } /** * Support for generating catalyst schemas for scala objects. Note that unlike its companion * object, this trait able to work in both the runtime and the compile time (macro) universe. */ trait ScalaReflection { /** The universe we work in (runtime or macro) */ val universe: scala.reflect.api.Universe /** The mirror used to access types in the universe */ def mirror: universe.Mirror import universe._ // The Predef.Map is scala.collection.immutable.Map. // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map /** * Returns the parameter names and types for the primary constructor of this class. * * Note that it only works for scala classes with primary constructor, and currently doesn't * support inner class. */ def getConstructorParameters(cls: Class[_]): Seq[(String, Type)] = { val classSymbol = mirror.staticClass(cls.getName) val t = classSymbol.selfType getConstructorParameters(t) } /** * Returns the parameter names for the primary constructor of this class. * * Logically we should call `getConstructorParameters` and throw away the parameter types to get * parameter names, however there are some weird scala reflection problems and this method is a * workaround to avoid getting parameter types. */ def getConstructorParameterNames(cls: Class[_]): Seq[String] = { val classSymbol = mirror.staticClass(cls.getName) val t = classSymbol.selfType constructParams(t).map(_.name.toString) } def getClassFromType(tpe: Type): Class[_] = ??? /** * Return the Scala Type for `T` in the current classloader mirror. * * Use this method instead of the convenience method `universe.typeOf`, which * assumes that all types can be found in the classloader that loaded scala-reflect classes. * That's not necessarily the case when running using Eclipse launchers or even * Sbt console or test (without `fork := true`). * * @see SPARK-5281 */ // SPARK-13640: Synchronize this because WeakTypeTag.tpe is not thread-safe in Scala 2.10. def localTypeOf[T: WeakTypeTag]: `Type` = ScalaReflectionLock.synchronized { val tag = implicitly[WeakTypeTag[T]] tag.in(mirror).tpe.normalize } /** * Returns the full class name for a type. The returned name is the canonical * Scala name, where each component is separated by a period. It is NOT the * Java-equivalent runtime name (no dollar signs). * * In simple cases, both the Scala and Java names are the same, however when Scala * generates constructs that do not map to a Java equivalent, such as singleton objects * or nested classes in package objects, it uses the dollar sign ($) to create * synthetic classes, emulating behaviour in Java bytecode. */ def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } /** * Returns classes of input parameters of scala function object. */ def getParameterTypes(func: AnyRef): Seq[Class[_]] = { val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) assert(methods.length == 1) methods.head.getParameterTypes } /** * Returns the parameter names and types for the primary constructor of this type. * * Note that it only works for scala classes with primary constructor, and currently doesn't * support inner class. */ def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { val formalTypeArgs = tpe.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = tpe constructParams(tpe).map { p => p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) } } protected def constructParams(tpe: Type): Seq[Symbol] = { val constructorSymbol = tpe.member(nme.CONSTRUCTOR) val params = if (constructorSymbol.isMethod) { constructorSymbol.asMethod.paramss } else { // Find the primary constructor, and use its parameter ordering. val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find( s => s.isMethod && s.asMethod.isPrimaryConstructor) if (primaryConstructorSymbol.isEmpty) { sys.error("Internal SQL error: Product object did not have a primary constructor.") } else { primaryConstructorSymbol.get.asMethod.paramss } } params.flatten } // The Predef.Map is scala.collection.immutable.Map. // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map /** * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping * to a native type, an ObjectType is returned. Special handling is also used for Arrays including * those that hold primitive types. * * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type * system. As a result, ObjectType will be returned for things like boxed Integers */ def dataTypeFor[T : WeakTypeTag]: DataType = dataTypeFor(localTypeOf[T]) private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { tpe match { case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType case t if t <:< definitions.DoubleTpe => DoubleType case t if t <:< definitions.FloatTpe => FloatType case t if t <:< definitions.ShortTpe => ShortType case t if t <:< definitions.ByteTpe => ByteType case t if t <:< definitions.BooleanTpe => BooleanType case t if t <:< localTypeOf[Array[Byte]] => BinaryType case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT case _ => val className = getClassNameFromType(tpe) className match { case "scala.Array" => val TypeRef(_, _, Seq(elementType)) = tpe arrayClassFor(elementType) case other => val clazz = getClassFromType(tpe) ObjectType(clazz) } } } /** * Given a type `T` this function constructs and ObjectType that holds a class of type * Array[T]. Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] case t if t <:< definitions.FloatTpe => classOf[Array[Float]] case t if t <:< definitions.ShortTpe => classOf[Array[Short]] case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] case other => // There is probably a better way to do this, but I couldn't find it... val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls java.lang.reflect.Array.newInstance(elementType, 1).getClass } ObjectType(cls) } /** * Returns true if the value of this data type is same between internal and external. */ def isNativeType(dt: DataType): Boolean = dt match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType | CalendarIntervalType => true case _ => false } /** * Returns an expression that can be used to deserialize an input row to an object of type `T` * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes * of the same name as the constructor arguments. Nested classes will have their fields accessed * using UnresolvedExtractValue. * * When used on a primitive type, the constructor will instead default to extracting the value * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ def deserializerFor[T : WeakTypeTag]: Expression = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil deserializerFor(tpe, None, walkedTypePath) } private def deserializerFor( tpe: `Type`, path: Option[Expression], walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { val newPath = path .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) upCastToExpectedType(newPath, dataType, walkedTypePath) } /** Returns the current path with a field at ordinal extracted. */ def addToPathOrdinal( ordinal: Int, dataType: DataType, walkedTypePath: Seq[String]): Expression = { val newPath = path .map(p => GetStructField(p, ordinal)) .getOrElse(GetColumnByOrdinal(ordinal, dataType)) upCastToExpectedType(newPath, dataType, walkedTypePath) } /** Returns the current path or `GetColumnByOrdinal`. */ def getPath: Expression = { val dataType = schemaFor(tpe).dataType if (path.isDefined) { path.get } else { upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) } } /** * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff * and lost the required data type, which may lead to runtime error if the real type doesn't * match the encoder's schema. * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type * is [a: int, b: long], then we will hit runtime error and say that we can't construct class * `Data` with int and long, because we lost the information that `b` should be a string. * * This method help us "remember" the required data type by adding a `UpCast`. Note that we * don't need to cast struct type because there must be `UnresolvedExtractValue` or * `GetStructField` wrapping it, thus we only need to handle leaf type. */ def upCastToExpectedType( expr: Expression, expected: DataType, walkedTypePath: Seq[String]): Expression = expected match { case _: StructType => expr case _ => UpCast(expr, expected, walkedTypePath) } tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", getPath :: Nil) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", getPath :: Nil) case t if t <:< localTypeOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String])) case t if t <:< localTypeOf[java.math.BigDecimal] => Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) case t if t <:< localTypeOf[BigDecimal] => Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) case t if t <:< localTypeOf[java.math.BigInteger] => Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger])) case t if t <:< localTypeOf[scala.math.BigInt] => Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt])) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t // TODO: add runtime null check for primitive array val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => Some("toIntArray") case t if t <:< definitions.LongTpe => Some("toLongArray") case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") case t if t <:< definitions.FloatTpe => Some("toFloatArray") case t if t <:< definitions.ShortTpe => Some("toShortArray") case t if t <:< definitions.ByteTpe => Some("toByteArray") case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") case _ => None } primitiveMethod.map { method => Invoke(getPath, method, arrayClassFor(elementType)) }.getOrElse { val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath Invoke( MapObjects( p => deserializerFor(elementType, Some(p), newTypePath), getPath, schemaFor(elementType).dataType), "array", arrayClassFor(elementType)) } case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val mapFunction: Expression => Expression = p => { val converter = deserializerFor(elementType, Some(p), newTypePath) if (nullable) { converter } else { AssertNotNull(converter, newTypePath) } } val array = Invoke( MapObjects(mapFunction, getPath, dataType), "array", ObjectType(classOf[Array[Any]])) StaticInvoke( scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", array :: Nil) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t val keyData = Invoke( MapObjects( p => deserializerFor(keyType, Some(p), walkedTypePath), Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), schemaFor(keyType).dataType), "array", ObjectType(classOf[Array[Any]])) val valueData = Invoke( MapObjects( p => deserializerFor(valueType, Some(p), walkedTypePath), Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), schemaFor(valueType).dataType), "array", ObjectType(classOf[Array[Any]])) StaticInvoke( ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() .asInstanceOf[UserDefinedType[_]] val obj = NewInstance( udt.getClass, Nil, dataType = ObjectType(udt.getClass)) Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) val cls = getClassFromType(tpe) val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, Some(addToPathOrdinal(i, dataType, newTypePath)), newTypePath) } else { val constructor = deserializerFor( fieldType, Some(addToPath(fieldName, dataType, newTypePath)), newTypePath) if (!nullable) { AssertNotNull(constructor, newTypePath) } else { constructor } } } val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) if (path.nonEmpty) { expressions.If( IsNull(getPath), expressions.Literal.create(null, ObjectType(cls)), newInstance ) } else { newInstance } } } /** * Returns an expression for serializing an object of type T to an internal row. * * If the given type is not supported, i.e. there is no encoder can be built for this type, * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain * the type path walked so far and which class we are not supporting. * There are 4 kinds of type path: * * the root type: `root class: "abc.xyz.MyClass"` * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"` * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ def serializerFor[T : WeakTypeTag](inputObject: Expression): CreateNamedStruct = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil serializerFor(inputObject, tpe, walkedTypePath) match { case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } } /** Helper for extracting internal fields from a case class. */ private def serializerFor( inputObject: Expression, tpe: `Type`, walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { dataTypeFor(elementType) match { case dt: ObjectType => val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath MapObjects(serializerFor(_, elementType, newPath), input, dt) case dt => NewInstance( classOf[GenericArrayData], input :: Nil, dataType = ArrayType(dt, schemaFor(elementType).nullable)) } } tpe match { case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newPath = s"""- option value class: "$className"""" +: walkedTypePath val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) serializerFor(unwrapped, optType, newPath) // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the // case "localTypeOf[Seq[_]]" case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t val keyClsName = getClassNameFromType(keyType) val valueClsName = getClassNameFromType(valueType) val keyPath = s"""- map key class: "$keyClsName"""" +: walkedTypePath val valuePath = s"""- map value class: "$valueClsName"""" +: walkedTypePath ExternalMapToCatalyst( inputObject, dataTypeFor(keyType), serializerFor(_, keyType, keyPath), dataTypeFor(valueType), serializerFor(_, valueType, valuePath)) case t if t <:< localTypeOf[String] => StaticInvoke( classOf[UTF8String], StringType, "fromString", inputObject :: Nil) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) case t if t <:< localTypeOf[java.math.BigDecimal] => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) case t if t <:< localTypeOf[java.math.BigInteger] => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, "apply", inputObject :: Nil) case t if t <:< localTypeOf[scala.math.BigInt] => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, "apply", inputObject :: Nil) case t if t <:< localTypeOf[java.lang.Integer] => Invoke(inputObject, "intValue", IntegerType) case t if t <:< localTypeOf[java.lang.Long] => Invoke(inputObject, "longValue", LongType) case t if t <:< localTypeOf[java.lang.Double] => Invoke(inputObject, "doubleValue", DoubleType) case t if t <:< localTypeOf[java.lang.Float] => Invoke(inputObject, "floatValue", FloatType) case t if t <:< localTypeOf[java.lang.Short] => Invoke(inputObject, "shortValue", ShortType) case t if t <:< localTypeOf[java.lang.Byte] => Invoke(inputObject, "byteValue", ByteType) case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t) .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "serialize", udt, inputObject :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() .asInstanceOf[UserDefinedType[_]] val obj = NewInstance( udt.getClass, Nil, dataType = ObjectType(udt.getClass)) Invoke(obj, "serialize", udt, inputObject :: Nil) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => if (javaKeywords.contains(fieldName)) { throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " + "cannot be used as field name\n" + walkedTypePath.mkString("\n")) } val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil }) val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) case other => throw new UnsupportedOperationException( s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } } /** * Returns the parameter values for the primary constructor of this class. */ def getConstructorParameterValues(obj: DefinedByConstructorParams): Seq[AnyRef] = { getConstructorParameterNames(obj.getClass).map { name => obj.getClass.getMethod(name).invoke(obj) } } case class Schema(dataType: DataType, nullable: Boolean) /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: WeakTypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => s.toAttributes } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor[T: WeakTypeTag]: Schema = schemaFor(localTypeOf[T]) /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { tpe match { case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() Schema(udt, nullable = true) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() .asInstanceOf[UserDefinedType[_]] Schema(udt, nullable = true) case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true) case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.math.BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.math.BigInteger] => Schema(DecimalType.BigIntDecimal, nullable = true) case t if t <:< localTypeOf[scala.math.BigInt] => Schema(DecimalType.BigIntDecimal, nullable = true) case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true) case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true) case t if t <:< localTypeOf[java.lang.Float] => Schema(FloatType, nullable = true) case t if t <:< localTypeOf[java.lang.Short] => Schema(ShortType, nullable = true) case t if t <:< localTypeOf[java.lang.Byte] => Schema(ByteType, nullable = true) case t if t <:< localTypeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) Schema(StructType( params.map { case (fieldName, fieldType) => val Schema(dataType, nullable) = schemaFor(fieldType) StructField(fieldName, dataType, nullable) }), nullable = true) case other => throw new UnsupportedOperationException(s"Schema for type $other is not supported") } } /** * Whether the fields of the given type is defined entirely by its constructor parameters. */ def definedByConstructorParams(tpe: Type): Boolean = { tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams] } private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class", "const", "continue", "default", "do", "double", "else", "extends", "false", "final", "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int", "interface", "long", "native", "new", "null", "package", "private", "protected", "public", "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw", "throws", "transient", "true", "try", "void", "volatile", "while") }