-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java                      |   6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala       | 558
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala       |  19
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala  |  62
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala       |   2
5 files changed, 382 insertions, 265 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index da22180563..599e9cfd23 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -55,9 +55,9 @@ public class JavaHashingTFSuite {
@Test
public void hashingTF() {
JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
- RowFactory.create(0, "Hi I heard about Spark"),
- RowFactory.create(0, "I wish Java could use case classes"),
- RowFactory.create(1, "Logistic regression models are neat")
+ RowFactory.create(0.0, "Hi I heard about Spark"),
+ RowFactory.create(0.0, "I wish Java could use case classes"),
+ RowFactory.create(1.0, "Logistic regression models are neat")
));
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 1c0ddb5093..2e7b4c236d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -18,7 +18,10 @@
package org.apache.spark.sql.catalyst
import java.lang.{Iterable => JavaIterable}
+import java.math.{BigDecimal => JavaBigDecimal}
+import java.sql.Date
import java.util.{Map => JavaMap}
+import javax.annotation.Nullable
import scala.collection.mutable.HashMap
@@ -34,197 +37,338 @@ object CatalystTypeConverters {
// Since the map values can be mutable, we explicitly import scala.collection.Map here.
import scala.collection.Map
+ private def isPrimitive(dataType: DataType): Boolean = {
+ dataType match {
+ case BooleanType => true
+ case ByteType => true
+ case ShortType => true
+ case IntegerType => true
+ case LongType => true
+ case FloatType => true
+ case DoubleType => true
+ case _ => false
+ }
+ }
+
+ private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
+ val converter = dataType match {
+ case udt: UserDefinedType[_] => UDTConverter(udt)
+ case arrayType: ArrayType => ArrayConverter(arrayType.elementType)
+ case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType)
+ case structType: StructType => StructConverter(structType)
+ case StringType => StringConverter
+ case DateType => DateConverter
+ case dt: DecimalType => BigDecimalConverter
+ case BooleanType => BooleanConverter
+ case ByteType => ByteConverter
+ case ShortType => ShortConverter
+ case IntegerType => IntConverter
+ case LongType => LongConverter
+ case FloatType => FloatConverter
+ case DoubleType => DoubleConverter
+ case _ => IdentityConverter
+ }
+ converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
+ }
+
/**
- * Converts Scala objects to catalyst rows / types. This method is slow, and for batch
- * conversion you should be using converter produced by createToCatalystConverter.
- * Note: This is always called after schemaFor has been called.
- * This ordering is important for UDT registration.
+ * Converts a Scala type to its Catalyst equivalent (and vice versa).
+ *
+ * @tparam ScalaInputType The type of Scala values that can be converted to Catalyst.
+ * @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala.
+ * @tparam CatalystType The internal Catalyst type used to represent values of this Scala type.
*/
- def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
- // Check UDT first since UDTs can override other types
- case (obj, udt: UserDefinedType[_]) =>
- udt.serialize(obj)
-
- case (o: Option[_], _) =>
- o.map(convertToCatalyst(_, dataType)).orNull
-
- case (s: Seq[_], arrayType: ArrayType) =>
- s.map(convertToCatalyst(_, arrayType.elementType))
-
- case (jit: JavaIterable[_], arrayType: ArrayType) => {
- val iter = jit.iterator
- var listOfItems: List[Any] = List()
- while (iter.hasNext) {
- val item = iter.next()
- listOfItems :+= convertToCatalyst(item, arrayType.elementType)
+ private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType]
+ extends Serializable {
+
+ /**
+ * Converts a Scala type to its Catalyst equivalent while automatically handling nulls
+ * and Options.
+ */
+ final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = {
+ if (maybeScalaValue == null) {
+ null.asInstanceOf[CatalystType]
+ } else if (maybeScalaValue.isInstanceOf[Option[ScalaInputType]]) {
+ val opt = maybeScalaValue.asInstanceOf[Option[ScalaInputType]]
+ if (opt.isDefined) {
+ toCatalystImpl(opt.get)
+ } else {
+ null.asInstanceOf[CatalystType]
+ }
+ } else {
+ toCatalystImpl(maybeScalaValue.asInstanceOf[ScalaInputType])
}
- listOfItems
}
- case (s: Array[_], arrayType: ArrayType) =>
- s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
+ /**
+ * Given a Catalyst row, convert the value at column `column` to its Scala equivalent.
+ */
+ final def toScala(row: Row, column: Int): ScalaOutputType = {
+ if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column)
+ }
+
+ /**
+ * Convert a Catalyst value to its Scala equivalent.
+ */
+ def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType
+
+ /**
+ * Converts a Scala value to its Catalyst equivalent.
+ * @param scalaValue the Scala value, guaranteed not to be null.
+ * @return the Catalyst value.
+ */
+ protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType
+
+ /**
+ * Given a Catalyst row, convert the value at column `column` to its Scala equivalent.
+ * This method will only be called on non-null columns.
+ */
+ protected def toScalaImpl(row: Row, column: Int): ScalaOutputType
+ }
- case (m: Map[_, _], mapType: MapType) =>
- m.map { case (k, v) =>
- convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
- }
+ private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] {
+ override def toCatalystImpl(scalaValue: Any): Any = scalaValue
+ override def toScala(catalystValue: Any): Any = catalystValue
+ override def toScalaImpl(row: Row, column: Int): Any = row(column)
+ }
- case (jmap: JavaMap[_, _], mapType: MapType) =>
- val iter = jmap.entrySet.iterator
- var listOfEntries: List[(Any, Any)] = List()
- while (iter.hasNext) {
- val entry = iter.next()
- listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType),
- convertToCatalyst(entry.getValue, mapType.valueType))
+ private case class UDTConverter(
+ udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
+ override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
+ override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue)
+ override def toScalaImpl(row: Row, column: Int): Any = toScala(row(column))
+ }
+
+ /** Converter for arrays, sequences, and Java iterables. */
+ private case class ArrayConverter(
+ elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] {
+
+ private[this] val elementConverter = getConverterForType(elementType)
+
+ override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
+ scalaValue match {
+ case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
+ case s: Seq[_] => s.map(elementConverter.toCatalyst)
+ case i: JavaIterable[_] =>
+ val iter = i.iterator
+ var convertedIterable: List[Any] = List()
+ while (iter.hasNext) {
+ val item = iter.next()
+ convertedIterable :+= elementConverter.toCatalyst(item)
+ }
+ convertedIterable
}
- listOfEntries.toMap
-
- case (p: Product, structType: StructType) =>
- val ar = new Array[Any](structType.size)
- val iter = p.productIterator
- var idx = 0
- while (idx < structType.size) {
- ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType)
- idx += 1
+ }
+
+ override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
+ if (catalystValue == null) {
+ null
+ } else {
+ catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala)
}
- new GenericRowWithSchema(ar, structType)
+ }
- case (d: String, _) =>
- UTF8String(d)
+ override def toScalaImpl(row: Row, column: Int): Seq[Any] =
+ toScala(row(column).asInstanceOf[Seq[Any]])
+ }
+
+ private case class MapConverter(
+ keyType: DataType,
+ valueType: DataType)
+ extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] {
- case (d: BigDecimal, _) =>
- Decimal(d)
+ private[this] val keyConverter = getConverterForType(keyType)
+ private[this] val valueConverter = getConverterForType(valueType)
- case (d: java.math.BigDecimal, _) =>
- Decimal(d)
+ override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match {
+ case m: Map[_, _] =>
+ m.map { case (k, v) =>
+ keyConverter.toCatalyst(k) -> valueConverter.toCatalyst(v)
+ }
- case (d: java.sql.Date, _) =>
- DateUtils.fromJavaDate(d)
+ case jmap: JavaMap[_, _] =>
+ val iter = jmap.entrySet.iterator
+ val convertedMap: HashMap[Any, Any] = HashMap()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val key = keyConverter.toCatalyst(entry.getKey)
+ convertedMap(key) = valueConverter.toCatalyst(entry.getValue)
+ }
+ convertedMap
+ }
- case (r: Row, structType: StructType) =>
- val converters = structType.fields.map {
- f => (item: Any) => convertToCatalyst(item, f.dataType)
+ override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = {
+ if (catalystValue == null) {
+ null
+ } else {
+ catalystValue.map { case (k, v) =>
+ keyConverter.toScala(k) -> valueConverter.toScala(v)
+ }
}
- convertRowWithConverters(r, structType, converters)
+ }
- case (other, _) =>
- other
+ override def toScalaImpl(row: Row, column: Int): Map[Any, Any] =
+ toScala(row(column).asInstanceOf[Map[Any, Any]])
}
- /**
- * Creates a converter function that will convert Scala objects to the specified catalyst type.
- * Typical use case would be converting a collection of rows that have the same schema. You will
- * call this function once to get a converter, and apply it to every row.
- */
- private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = {
- def extractOption(item: Any): Any = item match {
- case opt: Option[_] => opt.orNull
- case other => other
- }
+ private case class StructConverter(
+ structType: StructType) extends CatalystTypeConverter[Any, Row, Row] {
- dataType match {
- // Check UDT first since UDTs can override other types
- case udt: UserDefinedType[_] =>
- (item) => extractOption(item) match {
- case null => null
- case other => udt.serialize(other)
- }
+ private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) }
- case arrayType: ArrayType =>
- val elementConverter = createToCatalystConverter(arrayType.elementType)
- (item: Any) => {
- extractOption(item) match {
- case a: Array[_] => a.toSeq.map(elementConverter)
- case s: Seq[_] => s.map(elementConverter)
- case i: JavaIterable[_] => {
- val iter = i.iterator
- var convertedIterable: List[Any] = List()
- while (iter.hasNext) {
- val item = iter.next()
- convertedIterable :+= elementConverter(item)
- }
- convertedIterable
- }
- case null => null
- }
+ override def toCatalystImpl(scalaValue: Any): Row = scalaValue match {
+ case row: Row =>
+ val ar = new Array[Any](row.size)
+ var idx = 0
+ while (idx < row.size) {
+ ar(idx) = converters(idx).toCatalyst(row(idx))
+ idx += 1
}
-
- case mapType: MapType =>
- val keyConverter = createToCatalystConverter(mapType.keyType)
- val valueConverter = createToCatalystConverter(mapType.valueType)
- (item: Any) => {
- extractOption(item) match {
- case m: Map[_, _] =>
- m.map { case (k, v) =>
- keyConverter(k) -> valueConverter(v)
- }
-
- case jmap: JavaMap[_, _] =>
- val iter = jmap.entrySet.iterator
- val convertedMap: HashMap[Any, Any] = HashMap()
- while (iter.hasNext) {
- val entry = iter.next()
- convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue)
- }
- convertedMap
-
- case null => null
- }
+ new GenericRowWithSchema(ar, structType)
+
+ case p: Product =>
+ val ar = new Array[Any](structType.size)
+ val iter = p.productIterator
+ var idx = 0
+ while (idx < structType.size) {
+ ar(idx) = converters(idx).toCatalyst(iter.next())
+ idx += 1
}
+ new GenericRowWithSchema(ar, structType)
+ }
- case structType: StructType =>
- val converters = structType.fields.map(f => createToCatalystConverter(f.dataType))
- (item: Any) => {
- extractOption(item) match {
- case r: Row =>
- convertRowWithConverters(r, structType, converters)
-
- case p: Product =>
- val ar = new Array[Any](structType.size)
- val iter = p.productIterator
- var idx = 0
- while (idx < structType.size) {
- ar(idx) = converters(idx)(iter.next())
- idx += 1
- }
- new GenericRowWithSchema(ar, structType)
-
- case null =>
- null
- }
+ override def toScala(row: Row): Row = {
+ if (row == null) {
+ null
+ } else {
+ val ar = new Array[Any](row.size)
+ var idx = 0
+ while (idx < row.size) {
+ ar(idx) = converters(idx).toScala(row, idx)
+ idx += 1
}
-
- case dateType: DateType => (item: Any) => extractOption(item) match {
- case d: java.sql.Date => DateUtils.fromJavaDate(d)
- case other => other
+ new GenericRowWithSchema(ar, structType)
}
+ }
- case dataType: StringType => (item: Any) => extractOption(item) match {
- case s: String => UTF8String(s)
- case other => other
- }
+ override def toScalaImpl(row: Row, column: Int): Row = toScala(row(column).asInstanceOf[Row])
+ }
+
+ private object StringConverter extends CatalystTypeConverter[Any, String, Any] {
+ override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match {
+ case str: String => UTF8String(str)
+ case utf8: UTF8String => utf8
+ }
+ override def toScala(catalystValue: Any): String = catalystValue match {
+ case null => null
+ case str: String => str
+ case utf8: UTF8String => utf8.toString()
+ }
+ override def toScalaImpl(row: Row, column: Int): String = row(column).toString
+ }
+
+ private object DateConverter extends CatalystTypeConverter[Date, Date, Any] {
+ override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue)
+ override def toScala(catalystValue: Any): Date =
+ if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int])
+ override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column))
+ }
+
+ private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
+ override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match {
+ case d: BigDecimal => Decimal(d)
+ case d: JavaBigDecimal => Decimal(d)
+ case d: Decimal => d
+ }
+ override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
+ override def toScalaImpl(row: Row, column: Int): JavaBigDecimal = row.get(column) match {
+ case d: JavaBigDecimal => d
+ case d: Decimal => d.toJavaBigDecimal
+ }
+ }
+
+ private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
+ final override def toScala(catalystValue: Any): Any = catalystValue
+ final override def toCatalystImpl(scalaValue: T): Any = scalaValue
+ }
+
+ private object BooleanConverter extends PrimitiveConverter[Boolean] {
+ override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column)
+ }
+
+ private object ByteConverter extends PrimitiveConverter[Byte] {
+ override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column)
+ }
+
+ private object ShortConverter extends PrimitiveConverter[Short] {
+ override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column)
+ }
+
+ private object IntConverter extends PrimitiveConverter[Int] {
+ override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column)
+ }
+
+ private object LongConverter extends PrimitiveConverter[Long] {
+ override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column)
+ }
+
+ private object FloatConverter extends PrimitiveConverter[Float] {
+ override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column)
+ }
- case _ =>
- (item: Any) => extractOption(item) match {
- case d: BigDecimal => Decimal(d)
- case d: java.math.BigDecimal => Decimal(d)
- case other => other
+ private object DoubleConverter extends PrimitiveConverter[Double] {
+ override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column)
+ }
+
+ /**
+ * Converts Scala objects to catalyst rows / types. This method is slow, and for batch
+ * conversion you should be using the converter produced by createToCatalystConverter.
+ * Note: This is always called after schemaFor has been called.
+ * This ordering is important for UDT registration.
+ */
+ def convertToCatalyst(scalaValue: Any, dataType: DataType): Any = {
+ getConverterForType(dataType).toCatalyst(scalaValue)
+ }
+
+ /**
+ * Creates a converter function that will convert Scala objects to the specified Catalyst type.
+ * Typical use case would be converting a collection of rows that have the same schema. You will
+ * call this function once to get a converter, and apply it to every row.
+ */
+ private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = {
+ if (isPrimitive(dataType)) {
+ // Although the `else` branch here is capable of handling inbound conversion of primitives,
+ // we add some special-case handling for those types here. The motivation for this relates to
+ // Java method invocation costs: if we have rows that consist entirely of primitive columns,
+ // then returning the same conversion function for all of the columns means that the call site
+ // will be monomorphic instead of polymorphic. In microbenchmarks, this actually resulted in
+ // a measurable performance impact. Note that this optimization will be unnecessary if we
+ // use code generation to construct Scala Row -> Catalyst Row converters.
+ def convert(maybeScalaValue: Any): Any = {
+ if (maybeScalaValue.isInstanceOf[Option[Any]]) {
+ maybeScalaValue.asInstanceOf[Option[Any]].orNull
+ } else {
+ maybeScalaValue
}
+ }
+ convert
+ } else {
+ getConverterForType(dataType).toCatalyst
}
}
/**
- * Converts Scala objects to catalyst rows / types.
+ * Converts Scala objects to Catalyst rows / types.
*
* Note: This should be called before doing evaluation on a Row
* (It does not support UDT)
* This is used to create an RDD or test results with correct types for Catalyst.
*/
def convertToCatalyst(a: Any): Any = a match {
- case s: String => UTF8String(s)
- case d: java.sql.Date => DateUtils.fromJavaDate(d)
- case d: BigDecimal => Decimal(d)
- case d: java.math.BigDecimal => Decimal(d)
+ case s: String => StringConverter.toCatalyst(s)
+ case d: Date => DateConverter.toCatalyst(d)
+ case d: BigDecimal => BigDecimalConverter.toCatalyst(d)
+ case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d)
case seq: Seq[Any] => seq.map(convertToCatalyst)
case r: Row => Row(r.toSeq.map(convertToCatalyst): _*)
case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray
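A brief sketch of the schema-less convertToCatalyst(a: Any) shown above, which dispatches on the runtime type and recurses into rows and collections (illustrative usage, not part of the patch):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.CatalystTypeConverters

    // Strings become UTF8Strings, java.sql.Dates become Int day counts, and
    // BigDecimals become Decimals; Seqs, Arrays, and nested Rows recurse.
    val converted = CatalystTypeConverters.convertToCatalyst(
      Row("spark", java.sql.Date.valueOf("2015-04-01"), BigDecimal("1.5")))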
@@ -238,33 +382,8 @@ object CatalystTypeConverters {
* This method is slow, and for batch conversion you should be using the converter
* produced by createToScalaConverter.
*/
- def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
- // Check UDT first since UDTs can override other types
- case (d, udt: UserDefinedType[_]) =>
- udt.deserialize(d)
-
- case (s: Seq[_], arrayType: ArrayType) =>
- s.map(convertToScala(_, arrayType.elementType))
-
- case (m: Map[_, _], mapType: MapType) =>
- m.map { case (k, v) =>
- convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
- }
-
- case (r: Row, s: StructType) =>
- convertRowToScala(r, s)
-
- case (d: Decimal, _: DecimalType) =>
- d.toJavaBigDecimal
-
- case (i: Int, DateType) =>
- DateUtils.toJavaDate(i)
-
- case (s: UTF8String, StringType) =>
- s.toString()
-
- case (other, _) =>
- other
+ def convertToScala(catalystValue: Any, dataType: DataType): Any = {
+ getConverterForType(dataType).toScala(catalystValue)
}
/**
@@ -272,82 +391,7 @@ object CatalystTypeConverters {
* Typical use case would be converting a collection of rows that have the same schema. You will
* call this function once to get a converter, and apply it to every row.
*/
- private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match {
- // Check UDT first since UDTs can override other types
- case udt: UserDefinedType[_] =>
- (item: Any) => if (item == null) null else udt.deserialize(item)
-
- case arrayType: ArrayType =>
- val elementConverter = createToScalaConverter(arrayType.elementType)
- (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter)
-
- case mapType: MapType =>
- val keyConverter = createToScalaConverter(mapType.keyType)
- val valueConverter = createToScalaConverter(mapType.valueType)
- (item: Any) => if (item == null) {
- null
- } else {
- item.asInstanceOf[Map[_, _]].map { case (k, v) =>
- keyConverter(k) -> valueConverter(v)
- }
- }
-
- case s: StructType =>
- val converters = s.fields.map(f => createToScalaConverter(f.dataType))
- (item: Any) => {
- if (item == null) {
- null
- } else {
- convertRowWithConverters(item.asInstanceOf[Row], s, converters)
- }
- }
-
- case _: DecimalType =>
- (item: Any) => item match {
- case d: Decimal => d.toJavaBigDecimal
- case other => other
- }
-
- case DateType =>
- (item: Any) => item match {
- case i: Int => DateUtils.toJavaDate(i)
- case other => other
- }
-
- case StringType =>
- (item: Any) => item match {
- case s: UTF8String => s.toString()
- case other => other
- }
-
- case other =>
- (item: Any) => item
- }
-
- def convertRowToScala(r: Row, schema: StructType): Row = {
- val ar = new Array[Any](r.size)
- var idx = 0
- while (idx < r.size) {
- ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType)
- idx += 1
- }
- new GenericRowWithSchema(ar, schema)
- }
-
- /**
- * Converts a row by applying the provided set of converter functions. It is used for both
- * toScala and toCatalyst conversions.
- */
- private[sql] def convertRowWithConverters(
- row: Row,
- schema: StructType,
- converters: Array[Any => Any]): Row = {
- val ar = new Array[Any](row.size)
- var idx = 0
- while (idx < row.size) {
- ar(idx) = converters(idx)(row(idx))
- idx += 1
- }
- new GenericRowWithSchema(ar, schema)
+ private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
+ getConverterForType(dataType).toScala
}
}
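With the refactoring above, batch conversion boils down to building one converter function per schema and reusing it for every row. A minimal usage sketch (both helpers are private[sql], so this assumes code inside the org.apache.spark.sql package, as in the new test suite below; the value names are hypothetical):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))

    // Build the converters once, then apply them to many rows.
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
    val toScala = CatalystTypeConverters.createToScalaConverter(schema)

    val catalystRow = toCatalyst(Row(Some(1), "spark"))  // Option unwrapped, String -> UTF8String
    val scalaRow = toScala(catalystRow)                  // UTF8String -> String, yields Row(1, "spark")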
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 634138010f..b6191eafba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -71,12 +71,23 @@ case class UserDefinedGenerator(
children: Seq[Expression])
extends Generator {
+ @transient private[this] var inputRow: InterpretedProjection = _
+ @transient private[this] var convertToScala: (Row) => Row = _
+
+ private def initializeConverters(): Unit = {
+ inputRow = new InterpretedProjection(children)
+ convertToScala = {
+ val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
+ CatalystTypeConverters.createToScalaConverter(inputSchema)
+ }.asInstanceOf[(Row => Row)]
+ }
+
override def eval(input: Row): TraversableOnce[Row] = {
- // TODO(davies): improve this
+ if (inputRow == null) {
+ initializeConverters()
+ }
// Convert the objects into Scala Type before calling function, we need schema to support UDT
- val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
- val inputRow = new InterpretedProjection(children)
- function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row])
+ function(convertToScala(inputRow(input)))
}
override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
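The change above caches the projection and the Row => Row converter on first use instead of rebuilding both for every input row. A standalone sketch of the same lazy-initialization pattern (hypothetical class; createToScalaConverter is private[sql], so this assumes package-internal code):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.StructType

    class CachedRowConverter(inputSchema: StructType) extends Serializable {
      // Transient so it is rebuilt lazily after deserialization on executors.
      @transient private[this] var convertToScala: Row => Row = _

      def convert(row: Row): Row = {
        if (convertToScala == null) {
          convertToScala =
            CatalystTypeConverters.createToScalaConverter(inputSchema).asInstanceOf[Row => Row]
        }
        convertToScala(row)
      }
    }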
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
new file mode 100644
index 0000000000..df0f04563e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+
+class CatalystTypeConvertersSuite extends SparkFunSuite {
+
+ private val simpleTypes: Seq[DataType] = Seq(
+ StringType,
+ DateType,
+ BooleanType,
+ ByteType,
+ ShortType,
+ IntegerType,
+ LongType,
+ FloatType,
+ DoubleType)
+
+ test("null handling in rows") {
+ val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t)))
+ val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
+ val convertToScala = CatalystTypeConverters.createToScalaConverter(schema)
+
+ val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null))
+ assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow)
+ }
+
+ test("null handling for individual values") {
+ for (dataType <- simpleTypes) {
+ assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null)
+ }
+ }
+
+ test("option handling in convertToCatalyst") {
+ // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with
+ // createToCatalystConverter but it may not actually matter as this is only called internally
+ // in a handful of places where we don't expect to receive Options.
+ assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123))
+ }
+
+ test("option handling in createToCatalystConverter") {
+ assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)
+ }
+}
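A possible companion test in the same style, covering the slow one-off conversion paths (a sketch under the same imports, not part of this patch):

    test("null handling in one-off conversions") {
      assert(CatalystTypeConverters.convertToScala(null, StringType) === null)
      assert(CatalystTypeConverters.convertToCatalyst(null, IntegerType) === null)
      assert(CatalystTypeConverters.convertToCatalyst(None, IntegerType) === null)
    }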
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 56591d9dba..055453e688 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -173,7 +173,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
new Timestamp(i),
(1 to i).toSeq,
(0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
- Row((i - 0.25).toFloat, (1 to i).toSeq))
+ Row((i - 0.25).toFloat, Seq(true, false, null)))
}
createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
// Cache the table.
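The replacement literal above exercises null elements inside a nested array column. Assuming that field is declared as an array of booleans (which the new literal suggests), element-level null handling in the new ArrayConverter path looks like this (illustrative sketch):

    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.{ArrayType, BooleanType}

    val arrayType = ArrayType(BooleanType, containsNull = true)
    // Non-null elements go through the Boolean converter; null elements stay null.
    val converted = CatalystTypeConverters.convertToCatalyst(Seq(true, false, null), arrayType)
    // converted == Seq(true, false, null)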