aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala23
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala27
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala89
3 files changed, 135 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index be67605c45..be0d75a830 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -389,6 +390,15 @@ object ScalaReflection extends ScalaReflection {
Nil,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
+ case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+ val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+ .asInstanceOf[UserDefinedType[_]]
+ val obj = NewInstance(
+ udt.getClass,
+ Nil,
+ dataType = ObjectType(udt.getClass))
+ Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
}
}
@@ -603,6 +613,15 @@ object ScalaReflection extends ScalaReflection {
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+ case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+ val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+ .asInstanceOf[UserDefinedType[_]]
+ val obj = NewInstance(
+ udt.getClass,
+ Nil,
+ dataType = ObjectType(udt.getClass))
+ Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+
case other =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
@@ -671,6 +690,10 @@ object ScalaReflection extends ScalaReflection {
case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
Schema(udt, nullable = true)
+ case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+ val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+ .asInstanceOf[UserDefinedType[_]]
+ Schema(udt, nullable = true)
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index a8397aa5e5..44e135cbf8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
import scala.collection.Map
import scala.reflect.ClassTag
+import org.apache.spark.SparkException
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -55,10 +56,19 @@ object RowEncoder {
case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)
case udt: UserDefinedType[_] =>
+ val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
+ // Pick the UDT class to instantiate: the one named by the SQLUserDefinedType
+ // annotation on the user class when present, otherwise the one registered for the
+ // user class name in UDTRegistration.
+ val udtClass: Class[_] = if (annotation != null) {
+   annotation.udt()
+ } else {
+   UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
+     // Neither the annotation nor the registry knows this user class.
+     throw new SparkException(s"${udt.userClass.getName} is not annotated with " +
+       "SQLUserDefinedType nor registered with UDTRegistration.")
+   }
+ }
val obj = NewInstance(
- udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ udtClass,
Nil,
- dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+ dataType = ObjectType(udtClass), false)
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
case TimestampType =>
@@ -187,10 +197,19 @@ object RowEncoder {
FloatType | DoubleType | BinaryType | CalendarIntervalType => input
case udt: UserDefinedType[_] =>
+ val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
+ // Pick the UDT class to instantiate: the one named by the SQLUserDefinedType
+ // annotation on the user class when present, otherwise the one registered for the
+ // user class name in UDTRegistration.
+ val udtClass: Class[_] = if (annotation != null) {
+   annotation.udt()
+ } else {
+   UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
+     // Neither the annotation nor the registry knows this user class.
+     throw new SparkException(s"${udt.userClass.getName} is not annotated with " +
+       "SQLUserDefinedType nor registered with UDTRegistration.")
+   }
+ }
val obj = NewInstance(
- udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ udtClass,
Nil,
- dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+ dataType = ObjectType(udtClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
case TimestampType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala
new file mode 100644
index 0000000000..0f24e51ed2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * This object keeps the mappings between user classes and their User Defined Types (UDTs).
+ * Previously we used the annotation `SQLUserDefinedType` to register UDTs for user classes.
+ * However, doing so adds a Spark SQL dependency to user classes. This object provides an
+ * alternative approach to register UDTs for user classes.
+ */
+private[spark]
+object UDTRegistration extends Serializable with Logging {
+
+  /**
+   * Maps the name of a user class to the class name of its UserDefinedType.
+   * Starts out with the `org.apache.spark.ml.linalg` vector and matrix classes.
+   */
+  private lazy val udtMap: mutable.Map[String, String] = mutable.Map(
+    "org.apache.spark.ml.linalg.Vector" -> "org.apache.spark.ml.linalg.VectorUDT",
+    "org.apache.spark.ml.linalg.DenseVector" -> "org.apache.spark.ml.linalg.VectorUDT",
+    "org.apache.spark.ml.linalg.SparseVector" -> "org.apache.spark.ml.linalg.VectorUDT",
+    "org.apache.spark.ml.linalg.Matrix" -> "org.apache.spark.ml.linalg.MatrixUDT",
+    "org.apache.spark.ml.linalg.DenseMatrix" -> "org.apache.spark.ml.linalg.MatrixUDT",
+    "org.apache.spark.ml.linalg.SparseMatrix" -> "org.apache.spark.ml.linalg.MatrixUDT")
+
+  /**
+   * Tells whether a UDT has been registered for the given user class.
+   * @param userClassName the name of the user class
+   * @return true if the user class has a registered UDT, false otherwise
+   */
+  def exists(userClassName: String): Boolean = udtMap.contains(userClassName)
+
+  /**
+   * Registers a UserDefinedType for a user class. If the user class already has a
+   * registered UserDefinedType, the existing mapping is kept and a warning is logged.
+   * @param userClass the name of the user class
+   * @param udtClass the name of the UserDefinedType class for the given userClass
+   */
+  def register(userClass: String, udtClass: String): Unit = {
+    if (!udtMap.contains(userClass)) {
+      // Only class names are known here, so whether udtClass really is a
+      // UserDefinedType cannot be verified yet; getUDTFor performs that check.
+      udtMap += (userClass -> udtClass)
+    } else {
+      logWarning(s"Cannot register UDT for ${userClass}, which is already registered.")
+    }
+  }
+
+  /**
+   * Looks up and loads the UserDefinedType class registered for a user class name.
+   * @param userClass class name of the user class
+   * @return Some(Class of the UserDefinedType) if one is registered, None otherwise
+   * @throws SparkException if the registered class cannot be loaded, or is not a
+   *                        subclass of UserDefinedType
+   */
+  def getUDTFor(userClass: String): Option[Class[_]] = {
+    udtMap.get(userClass).map { registeredName =>
+      if (!Utils.classIsLoadable(registeredName)) {
+        throw new SparkException(
+          s"Can not load in UserDefinedType ${registeredName} for user class ${userClass}.")
+      }
+      val loaded = Utils.classForName(registeredName)
+      if (!classOf[UserDefinedType[_]].isAssignableFrom(loaded)) {
+        throw new SparkException(
+          s"${loaded.getName} is not an UserDefinedType. Please make sure registering " +
+            s"an UserDefinedType for ${userClass}")
+      }
+      loaded
+    }
+  }
+}