author    Takuya UESHIN <ueshin@happy-camper.st>  2014-07-23 14:47:23 -0700
committer Michael Armbrust <michael@databricks.com>  2014-07-23 14:47:23 -0700
commit    1b790cf7755cace0d89ac5777717e6df3be7356f (patch)
tree      6458839315193b58b224d77b3984501b0b9c6fe9 /sql
parent    f776bc98878428940b5130c0d7d9b7ee452c0bd3 (diff)
[SPARK-2588][SQL] Add some more DSLs.
Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #1491 from ueshin/issues/SPARK-2588 and squashes the following commits:

43d0a46 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-2588
1023ea0 [Takuya UESHIN] Modify tests to use DSLs.
2310bf1 [Takuya UESHIN] Add some more DSLs.
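For orientation, a rough sketch of how the new expression DSL reads at a call site. The import paths are assumed for this era of the codebase, and the column names ('name, 'tags, 'props) are purely illustrative, not part of the patch:

    import org.apache.spark.sql.catalyst.dsl.expressions._   // assumed DSL import
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.types._

    val name  = 'name.string                          // existing attribute builder
    val tags  = 'tags.array(StringType)               // new: ArrayType attribute
    val props = 'props.map(StringType, StringType)    // new: MapType attribute

    val predicate =
      name.isNotNull &&
      (name contains "Spark") &&
      (name startsWith "S") &&
      (name in (Literal("Spark", StringType), Literal("Shark", StringType)))

    val firstTag = tags.getItem(0)       // GetItem via the DSL; the Int becomes a Literal implicitly
    val version  = props.getItem("v")    // map lookup by key
    val prefix   = name.substr(0, 3)     // Substring(name, 0, 3); omitting len defaults it to Int.MaxValue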
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala                            | 29
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala  | 59
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala                                       | 15
3 files changed, 70 insertions(+), 33 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 1b503b957d..15c98efbca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -79,8 +79,24 @@ package object dsl {
def === (other: Expression) = EqualTo(expr, other)
def !== (other: Expression) = Not(EqualTo(expr, other))
+ def in(list: Expression*) = In(expr, list)
+
def like(other: Expression) = Like(expr, other)
def rlike(other: Expression) = RLike(expr, other)
+ def contains(other: Expression) = Contains(expr, other)
+ def startsWith(other: Expression) = StartsWith(expr, other)
+ def endsWith(other: Expression) = EndsWith(expr, other)
+ def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
+ Substring(expr, pos, len)
+ def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
+ Substring(expr, pos, len)
+
+ def isNull = IsNull(expr)
+ def isNotNull = IsNotNull(expr)
+
+ def getItem(ordinal: Expression) = GetItem(expr, ordinal)
+ def getField(fieldName: String) = GetField(expr, fieldName)
+
def cast(to: DataType) = Cast(expr, to)
def asc = SortOrder(expr, Ascending)
@@ -112,6 +128,7 @@ package object dsl {
def sumDistinct(e: Expression) = SumDistinct(e)
def count(e: Expression) = Count(e)
def countDistinct(e: Expression*) = CountDistinct(e)
+ def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd)
def avg(e: Expression) = Average(e)
def first(e: Expression) = First(e)
def min(e: Expression) = Min(e)
@@ -163,6 +180,18 @@ package object dsl {
/** Creates a new AttributeReference of type binary */
def binary = AttributeReference(s, BinaryType, nullable = true)()
+
+ /** Creates a new AttributeReference of type array */
+ def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)()
+
+ /** Creates a new AttributeReference of type map */
+ def map(keyType: DataType, valueType: DataType): AttributeReference =
+ map(MapType(keyType, valueType))
+ def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)()
+
+ /** Creates a new AttributeReference of type struct */
+ def struct(fields: StructField*): AttributeReference = struct(StructType(fields))
+ def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)()
}
implicit class DslAttribute(a: AttributeReference) {
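A small sketch of the new typed attribute builders and approxCountDistinct added above; the attribute names ('events, 'kv, 'person, 'userId) are hypothetical and the import paths are assumed:

    import org.apache.spark.sql.catalyst.dsl.expressions._   // assumed DSL import
    import org.apache.spark.sql.catalyst.types._

    val events = 'events.array(IntegerType)           // AttributeReference of ArrayType(IntegerType)
    val kv     = 'kv.map(StringType, IntegerType)     // MapType(StringType, IntegerType)
    val person = 'person.struct(                      // StructType built from the given fields
      StructField("name", StringType, nullable = true),
      StructField("age", IntegerType, nullable = true))

    // Like countDistinct, but approximate; rsd is the relative error and defaults to 0.05.
    val approxUsers = approxCountDistinct('userId.string, rsd = 0.01)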
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index db1ae29d40..c3f5c26fdb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -301,17 +301,17 @@ class ExpressionEvaluationSuite extends FunSuite {
val c3 = 'a.boolean.at(2)
val c4 = 'a.boolean.at(3)
- checkEvaluation(IsNull(c1), false, row)
- checkEvaluation(IsNotNull(c1), true, row)
+ checkEvaluation(c1.isNull, false, row)
+ checkEvaluation(c1.isNotNull, true, row)
- checkEvaluation(IsNull(c2), true, row)
- checkEvaluation(IsNotNull(c2), false, row)
+ checkEvaluation(c2.isNull, true, row)
+ checkEvaluation(c2.isNotNull, false, row)
- checkEvaluation(IsNull(Literal(1, ShortType)), false)
- checkEvaluation(IsNotNull(Literal(1, ShortType)), true)
+ checkEvaluation(Literal(1, ShortType).isNull, false)
+ checkEvaluation(Literal(1, ShortType).isNotNull, true)
- checkEvaluation(IsNull(Literal(null, ShortType)), true)
- checkEvaluation(IsNotNull(Literal(null, ShortType)), false)
+ checkEvaluation(Literal(null, ShortType).isNull, true)
+ checkEvaluation(Literal(null, ShortType).isNotNull, false)
checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
@@ -326,11 +326,11 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(If(Literal(false, BooleanType),
Literal("a", StringType), Literal("b", StringType)), "b", row)
- checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
- checkEvaluation(In(Literal("^Ba*n", StringType),
- Literal("^Ba*n", StringType) :: Nil), true, row)
- checkEvaluation(In(Literal("^Ba*n", StringType),
- Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
+ checkEvaluation(c1 in (c1, c2), true, row)
+ checkEvaluation(
+ Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row)
+ checkEvaluation(
+ Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row)
}
test("case when") {
@@ -420,6 +420,10 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(GetField(Literal(null, typeS), "a").nullable === true)
assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true)
+
+ checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
+ checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
+ checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row)
}
test("arithmetic") {
@@ -472,20 +476,20 @@ class ExpressionEvaluationSuite extends FunSuite {
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
- checkEvaluation(Contains(c1, "b"), true, row)
- checkEvaluation(Contains(c1, "x"), false, row)
- checkEvaluation(Contains(c2, "b"), null, row)
- checkEvaluation(Contains(c1, Literal(null, StringType)), null, row)
+ checkEvaluation(c1 contains "b", true, row)
+ checkEvaluation(c1 contains "x", false, row)
+ checkEvaluation(c2 contains "b", null, row)
+ checkEvaluation(c1 contains Literal(null, StringType), null, row)
- checkEvaluation(StartsWith(c1, "a"), true, row)
- checkEvaluation(StartsWith(c1, "b"), false, row)
- checkEvaluation(StartsWith(c2, "a"), null, row)
- checkEvaluation(StartsWith(c1, Literal(null, StringType)), null, row)
+ checkEvaluation(c1 startsWith "a", true, row)
+ checkEvaluation(c1 startsWith "b", false, row)
+ checkEvaluation(c2 startsWith "a", null, row)
+ checkEvaluation(c1 startsWith Literal(null, StringType), null, row)
- checkEvaluation(EndsWith(c1, "c"), true, row)
- checkEvaluation(EndsWith(c1, "b"), false, row)
- checkEvaluation(EndsWith(c2, "b"), null, row)
- checkEvaluation(EndsWith(c1, Literal(null, StringType)), null, row)
+ checkEvaluation(c1 endsWith "c", true, row)
+ checkEvaluation(c1 endsWith "b", false, row)
+ checkEvaluation(c2 endsWith "b", null, row)
+ checkEvaluation(c1 endsWith Literal(null, StringType), null, row)
}
test("Substring") {
@@ -542,5 +546,10 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false)
assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true)
assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true)
+
+ checkEvaluation(s.substr(0, 2), "ex", row)
+ checkEvaluation(s.substr(0), "example", row)
+ checkEvaluation(s.substring(0, 2), "ex", row)
+ checkEvaluation(s.substring(0), "example", row)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index c8ea01c4e1..1a6a6c1747 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test._
/* Implicits */
@@ -41,15 +40,15 @@ class DslQuerySuite extends QueryTest {
test("agg") {
checkAnswer(
- testData2.groupBy('a)('a, Sum('b)),
+ testData2.groupBy('a)('a, sum('b)),
Seq((1,3),(2,3),(3,3))
)
checkAnswer(
- testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)),
+ testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)),
9
)
checkAnswer(
- testData2.aggregate(Sum('b)),
+ testData2.aggregate(sum('b)),
9
)
}
@@ -104,19 +103,19 @@ class DslQuerySuite extends QueryTest {
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
checkAnswer(
- arrayData.orderBy(GetItem('data, 0).asc),
+ arrayData.orderBy('data.getItem(0).asc),
arrayData.collect().sortBy(_.data(0)).toSeq)
checkAnswer(
- arrayData.orderBy(GetItem('data, 0).desc),
+ arrayData.orderBy('data.getItem(0).desc),
arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
checkAnswer(
- mapData.orderBy(GetItem('data, 1).asc),
+ mapData.orderBy('data.getItem(1).asc),
mapData.collect().sortBy(_.data(1)).toSeq)
checkAnswer(
- mapData.orderBy(GetItem('data, 1).desc),
+ mapData.orderBy('data.getItem(1).desc),
mapData.collect().sortBy(_.data(1)).reverse.toSeq)
}