Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala   |  9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                     |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala           |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala           | 64
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala |  1
5 files changed, 79 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index e1b5992a36..5dd19dd12d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -71,6 +71,8 @@ object DataType {
case JSortedObject(
("class", JString(udtClass)),
+ ("pyClass", _),
+ ("sqlType", _),
("type", JString("udt"))) =>
Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
}
@@ -593,6 +595,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
/** Underlying storage type for this UDT */
def sqlType: DataType
+ /** Paired Python UDT class, if one exists. */
+ def pyUDT: String = null
+
/**
* Convert the user type to a SQL datum
*
@@ -606,7 +611,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
override private[sql] def jsonValue: JValue = {
("type" -> "udt") ~
- ("class" -> this.getClass.getName)
+ ("class" -> this.getClass.getName) ~
+ ("pyClass" -> pyUDT) ~
+ ("sqlType" -> sqlType.jsonValue)
}
/**
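
With these additions, a UDT's JSON metadata carries the paired Python class and the underlying SQL type alongside the Scala class name, and the parser above accepts the two new fields. A minimal sketch of the resulting shape, assuming the ExamplePointUDT introduced later in this diff and the json helper DataType defines over jsonValue (ExamplePointUDT is private[sql], so this would run from test code inside that package):

import org.apache.spark.sql.test.ExamplePointUDT

val udt = new ExamplePointUDT
// Expected to render roughly as (field order may vary):
// {"type":"udt",
//  "class":"org.apache.spark.sql.test.ExamplePointUDT",
//  "pyClass":"pyspark.tests.ExamplePointUDT",
//  "sqlType":{"type":"array","elementType":"double","containsNull":false}}
println(udt.json)
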
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 9e61d18f7e..84eaf401f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types.UserDefinedType
import org.apache.spark.sql.execution.{SparkStrategies, _}
import org.apache.spark.sql.json._
import org.apache.spark.sql.parquet.ParquetRelation
@@ -483,6 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
case ArrayType(_, _) => true
case MapType(_, _, _) => true
case StructType(_) => true
+ case udt: UserDefinedType[_] => needsConversion(udt.sqlType)
case other => false
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 997669051e..a83cf5d441 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -135,6 +135,8 @@ object EvaluatePython {
case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
}.asJava
+ case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
+
case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal
// Pyrolite can handle Timestamp
@@ -177,6 +179,9 @@ object EvaluatePython {
case (c: java.util.Calendar, TimestampType) =>
new java.sql.Timestamp(c.getTime().getTime())
+ case (_, udt: UserDefinedType[_]) =>
+ fromJava(obj, udt.sqlType)
+
case (c: Int, ByteType) => c.toByte
case (c: Long, ByteType) => c.toByte
case (c: Int, ShortType) => c.toShort
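
The two new cases let EvaluatePython handle UDT columns by deferring to the UDT itself: outbound values are first serialized to the UDT's sqlType and then converted as usual, while inbound values are converted according to sqlType, leaving reconstruction of the user class to the UDT's deserialize. A rough sketch of that contract using the ExamplePointUDT added below, rather than a direct call into EvaluatePython:

import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT}

val udt = new ExamplePointUDT
val p   = new ExamplePoint(1.0, 2.0)

// Outbound (toJava path): reduce the user type to its sqlType representation,
// here ArrayType(DoubleType, containsNull = false).
val forPython = udt.serialize(p)            // Seq(1.0, 2.0)

// Inbound (fromJava path): data arrives shaped like sqlType; the UDT can
// rebuild the user type from it.
val restored = udt.deserialize(forPython)   // an ExamplePoint with x = 1.0, y = 2.0
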
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
new file mode 100644
index 0000000000..b9569e96c0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * An example class to demonstrate UDT in Scala, Java, and Python.
+ * @param x x coordinate
+ * @param y y coordinate
+ */
+@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
+private[sql] class ExamplePoint(val x: Double, val y: Double)
+
+/**
+ * User-defined type for [[ExamplePoint]].
+ */
+private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
+
+ override def sqlType: DataType = ArrayType(DoubleType, false)
+
+ override def pyUDT: String = "pyspark.tests.ExamplePointUDT"
+
+ override def serialize(obj: Any): Seq[Double] = {
+ obj match {
+ case p: ExamplePoint =>
+ Seq(p.x, p.y)
+ }
+ }
+
+ override def deserialize(datum: Any): ExamplePoint = {
+ datum match {
+ case values: Seq[_] =>
+ val xy = values.asInstanceOf[Seq[Double]]
+ assert(xy.length == 2)
+ new ExamplePoint(xy(0), xy(1))
+ case values: util.ArrayList[_] =>
+ val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
+ new ExamplePoint(xy(0), xy(1))
+ }
+ }
+
+ override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
+}
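
Because ExamplePoint carries the @SQLUserDefinedType annotation, Scala schema inference can map fields of that type to ExamplePointUDT automatically. A rough usage sketch against the 1.2-era SchemaRDD API (LabeledPoint2D is a hypothetical case class, and both ExamplePoint classes are private[sql], so this would sit in a test under org.apache.spark.sql):

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.test.ExamplePoint

val sc = new SparkContext("local", "udt-example")
val sqlContext = new SQLContext(sc)
import sqlContext.createSchemaRDD           // implicit RDD[Product] -> SchemaRDD

// Hypothetical case class with a UDT-typed column.
case class LabeledPoint2D(label: Double, point: ExamplePoint)

val points = sc.parallelize(Seq(
  LabeledPoint2D(1.0, new ExamplePoint(0.5, 1.5)),
  LabeledPoint2D(0.0, new ExamplePoint(2.0, 3.0))))

points.registerTempTable("points")          // point column is inferred as ExamplePointUDT
sqlContext.sql("SELECT label FROM points").collect()
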
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index 1bc15146f0..3fa4a7c648 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.UserDefinedType
-
protected[sql] object DataTypeConversions {
/**