author    Takeshi Yamamuro <yamamuro@apache.org>          2017-03-05 03:53:19 -0800
committer Herman van Hovell <hvanhovell@databricks.com>   2017-03-05 03:53:19 -0800
commit    14bb398fae974137c3e38162cefc088e12838258 (patch)
tree      cb83b0f4b81c86a8a22237d818170247bbc9e825 /sql
parent    f48461ab2bdb91cd00efa5a5ec4b0b2bc361e7a2 (diff)
[SPARK-19254][SQL] Support Seq, Map, and Struct in functions.lit
## What changes were proposed in this pull request?

This PR adds support for Seq, Map, and Struct values in functions.lit. It introduces a new interface, `typedLit`, which takes a `TypeTag` so that the element types of parameterized Scala types survive erasure.

## How was this patch tested?

Added tests in `LiteralExpressionSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>
Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes #16610 from maropu/SPARK-19254.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala               | 12
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala | 90
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                       | 25
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala                           | 14
4 files changed, 121 insertions, 20 deletions
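
A minimal sketch of what the patch enables, for orientation before the per-file diffs (the DataFrame `df` and the surrounding SparkSession are assumed here, not part of this patch):

```scala
import org.apache.spark.sql.functions.typedLit

// lit() cannot see through type erasure for parameterized Scala types;
// typedLit() carries a TypeTag, so these literals now work:
df.select(
  typedLit(Seq(1, 2, 3)),            // array<int>
  typedLit(Map("a" -> 1, "b" -> 2)), // map<string,int>
  typedLit(("a", 2, 1.0)))           // struct<_1:string,_2:int,_3:double>
```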
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e66fb89339..eaeaf08c37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -32,11 +32,13 @@ import java.util.Objects
import javax.xml.bind.DatatypeConverter
import scala.math.{BigDecimal, BigInt}
+import scala.reflect.runtime.universe.TypeTag
+import scala.util.Try
import org.json4s.JsonAST._
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -153,6 +155,14 @@ object Literal {
Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
}
+ def create[T : TypeTag](v: T): Literal = Try {
+ val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T]
+ val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
+ Literal(convert(v), dataType)
+ }.getOrElse {
+ Literal(v)
+ }
+
/**
* Create a literal with default value for given DataType
*/
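
The `Literal.create` overload added above infers the Catalyst schema from the `TypeTag` and converts the value with a matching converter, falling back to the erasure-based `Literal(v)` if reflection fails. A hedged sketch of the resulting behavior (the dataType comments follow ScalaReflection's usual mapping and are illustrative, not output quoted from this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal

// The schema comes from the TypeTag, so nested Scala types resolve cleanly:
Literal.create(Seq(1, 2, 3)).dataType   // ArrayType(IntegerType, containsNull = false)
Literal.create(Map("a" -> 1)).dataType  // MapType(StringType, IntegerType, ...)
Literal.create((1, "a")).dataType       // StructType with fields _1: int, _2: string
```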
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 15e8e6c057..a9e0eb0e37 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions
import java.nio.charset.StandardCharsets
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -75,6 +77,9 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
test("boolean literals") {
checkEvaluation(Literal(true), true)
checkEvaluation(Literal(false), false)
+
+ checkEvaluation(Literal.create(true), true)
+ checkEvaluation(Literal.create(false), false)
}
test("int literals") {
@@ -83,36 +88,60 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal(d.toLong), d.toLong)
checkEvaluation(Literal(d.toShort), d.toShort)
checkEvaluation(Literal(d.toByte), d.toByte)
+
+ checkEvaluation(Literal.create(d), d)
+ checkEvaluation(Literal.create(d.toLong), d.toLong)
+ checkEvaluation(Literal.create(d.toShort), d.toShort)
+ checkEvaluation(Literal.create(d.toByte), d.toByte)
}
checkEvaluation(Literal(Long.MinValue), Long.MinValue)
checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)
+
+ checkEvaluation(Literal.create(Long.MinValue), Long.MinValue)
+ checkEvaluation(Literal.create(Long.MaxValue), Long.MaxValue)
}
test("double literals") {
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d =>
checkEvaluation(Literal(d), d)
checkEvaluation(Literal(d.toFloat), d.toFloat)
+
+ checkEvaluation(Literal.create(d), d)
+ checkEvaluation(Literal.create(d.toFloat), d.toFloat)
}
checkEvaluation(Literal(Double.MinValue), Double.MinValue)
checkEvaluation(Literal(Double.MaxValue), Double.MaxValue)
checkEvaluation(Literal(Float.MinValue), Float.MinValue)
checkEvaluation(Literal(Float.MaxValue), Float.MaxValue)
+ checkEvaluation(Literal.create(Double.MinValue), Double.MinValue)
+ checkEvaluation(Literal.create(Double.MaxValue), Double.MaxValue)
+ checkEvaluation(Literal.create(Float.MinValue), Float.MinValue)
+ checkEvaluation(Literal.create(Float.MaxValue), Float.MaxValue)
+
}
test("string literals") {
checkEvaluation(Literal(""), "")
checkEvaluation(Literal("test"), "test")
checkEvaluation(Literal("\u0000"), "\u0000")
+
+ checkEvaluation(Literal.create(""), "")
+ checkEvaluation(Literal.create("test"), "test")
+ checkEvaluation(Literal.create("\u0000"), "\u0000")
}
test("sum two literals") {
checkEvaluation(Add(Literal(1), Literal(1)), 2)
+ checkEvaluation(Add(Literal.create(1), Literal.create(1)), 2)
}
test("binary literals") {
checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0))
checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2))
+
+ checkEvaluation(Literal.create(new Array[Byte](0)), new Array[Byte](0))
+ checkEvaluation(Literal.create(new Array[Byte](2)), new Array[Byte](2))
}
test("decimal") {
@@ -124,24 +153,63 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
Decimal((d * 1000L).toLong, 10, 3))
checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d))
checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d))
+
+ checkEvaluation(Literal.create(Decimal(d)), Decimal(d))
+ checkEvaluation(Literal.create(Decimal(d.toInt)), Decimal(d.toInt))
+ checkEvaluation(Literal.create(Decimal(d.toLong)), Decimal(d.toLong))
+ checkEvaluation(Literal.create(Decimal((d * 1000L).toLong, 10, 3)),
+ Decimal((d * 1000L).toLong, 10, 3))
+ checkEvaluation(Literal.create(BigDecimal(d.toString)), Decimal(d))
+ checkEvaluation(Literal.create(new java.math.BigDecimal(d.toString)), Decimal(d))
+
}
}
+ private def toCatalyst[T: TypeTag](value: T): Any = {
+ val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T]
+ CatalystTypeConverters.createToCatalystConverter(dataType)(value)
+ }
+
test("array") {
- def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = {
- val toCatalyst = (a: Array[_], elementType: DataType) => {
- CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a)
- }
- checkEvaluation(Literal(a), toCatalyst(a, elementType))
+ def checkArrayLiteral[T: TypeTag](a: Array[T]): Unit = {
+ checkEvaluation(Literal(a), toCatalyst(a))
+ checkEvaluation(Literal.create(a), toCatalyst(a))
+ }
+ checkArrayLiteral(Array(1, 2, 3))
+ checkArrayLiteral(Array("a", "b", "c"))
+ checkArrayLiteral(Array(1.0, 4.0))
+ checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR))
+ }
+
+ test("seq") {
+ def checkSeqLiteral[T: TypeTag](a: Seq[T], elementType: DataType): Unit = {
+ checkEvaluation(Literal.create(a), toCatalyst(a))
}
- checkArrayLiteral(Array(1, 2, 3), IntegerType)
- checkArrayLiteral(Array("a", "b", "c"), StringType)
- checkArrayLiteral(Array(1.0, 4.0), DoubleType)
- checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
+ checkSeqLiteral(Seq(1, 2, 3), IntegerType)
+ checkSeqLiteral(Seq("a", "b", "c"), StringType)
+ checkSeqLiteral(Seq(1.0, 4.0), DoubleType)
+ checkSeqLiteral(Seq(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
CalendarIntervalType)
}
- test("unsupported types (map and struct) in literals") {
+ test("map") {
+ def checkMapLiteral[T: TypeTag](m: T): Unit = {
+ checkEvaluation(Literal.create(m), toCatalyst(m))
+ }
+ checkMapLiteral(Map("a" -> 1, "b" -> 2, "c" -> 3))
+ checkMapLiteral(Map("1" -> 1.0, "2" -> 2.0, "3" -> 3.0))
+ }
+
+ test("struct") {
+ def checkStructLiteral[T: TypeTag](s: T): Unit = {
+ checkEvaluation(Literal.create(s), toCatalyst(s))
+ }
+ checkStructLiteral((1, 3.0, "abcde"))
+ checkStructLiteral(("de", 1, 2.0f))
+ checkStructLiteral((1, ("fgh", 3.0)))
+ }
+
+ test("unsupported types (map and struct) in Literal.apply") {
def checkUnsupportedTypeInLiteral(v: Any): Unit = {
val errMsgMap = intercept[RuntimeException] {
Literal(v)
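
The tests' `toCatalyst` helper mirrors the reflection step inside `Literal.create`. A sketch of the schema inference both rely on (types shown per `ScalaReflection.schemaFor`'s standard mapping; nullability flags are indicative):

```scala
import org.apache.spark.sql.catalyst.ScalaReflection

ScalaReflection.schemaFor[Seq[Int]].dataType
// ArrayType(IntegerType, containsNull = false)
ScalaReflection.schemaFor[Map[String, Double]].dataType
// MapType(StringType, DoubleType, valueContainsNull = false)
ScalaReflection.schemaFor[(Int, String)].dataType
// StructType(StructField("_1", IntegerType, ...) :: StructField("_2", StringType, ...) :: Nil)
```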
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 24ed906d33..2247010ac3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -91,15 +91,24 @@ object functions {
* @group normal_funcs
* @since 1.3.0
*/
- def lit(literal: Any): Column = {
- literal match {
- case c: Column => return c
- case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name)
- case _ => // continue
- }
+ def lit(literal: Any): Column = typedLit(literal)
- val literalExpr = Literal(literal)
- Column(literalExpr)
+ /**
+ * Creates a [[Column]] of literal value.
+ *
+ * The passed in object is returned directly if it is already a [[Column]].
+ * If the object is a Scala Symbol, it is converted into a [[Column]] also.
+ * Otherwise, a new [[Column]] is created to represent the literal value.
+ * The difference between this function and [[lit]] is that this function
+ * can handle parameterized Scala types, e.g. List, Seq, and Map.
+ *
+ * @group normal_funcs
+ * @since 2.2.0
+ */
+ def typedLit[T : TypeTag](literal: T): Column = literal match {
+ case c: Column => c
+ case s: Symbol => new ColumnName(s.name)
+ case _ => Column(Literal.create(literal))
}
//////////////////////////////////////////////////////////////////////////////////////////////
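
After this rewrite, `lit` simply delegates to `typedLit`, which preserves the old pass-through behavior for Columns and Symbols. A brief usage sketch (assumes `spark.implicits._` is in scope for the `$` syntax):

```scala
import org.apache.spark.sql.functions._

typedLit($"x")          // an existing Column is returned unchanged
typedLit('x)            // a Scala Symbol becomes a ColumnName
typedLit(1)             // simple values behave exactly like lit(1)
typedLit(Seq(1, 2, 3))  // parameterized types go through Literal.create
```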
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index ee280a313c..b0f398dab7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -712,4 +712,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)),
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
}
+
+ test("typedLit") {
+ val df = Seq(Tuple1(0)).toDF("a")
+ // Only check the types `lit` cannot handle
+ checkAnswer(
+ df.select(typedLit(Seq(1, 2, 3))),
+ Row(Seq(1, 2, 3)) :: Nil)
+ checkAnswer(
+ df.select(typedLit(Map("a" -> 1, "b" -> 2))),
+ Row(Map("a" -> 1, "b" -> 2)) :: Nil)
+ checkAnswer(
+ df.select(typedLit(("a", 2, 1.0))),
+ Row(Row("a", 2, 1.0)) :: Nil)
+ }
}