aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala53
1 files changed, 51 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 3540014c3e..1e7296664b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -21,7 +21,8 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -161,6 +162,18 @@ package object dsl {
def lower(e: Expression): Expression = Lower(e)
def sqrt(e: Expression): Expression = Sqrt(e)
def abs(e: Expression): Expression = Abs(e)
+ def star(names: String*): Expression = names match {
+ case Seq() => UnresolvedStar(None)
+ case target => UnresolvedStar(Option(target))
+ }
+
+ def callFunction[T, U](
+ func: T => U,
+ returnType: DataType,
+ argument: Expression): Expression = {
+ val function = Literal.create(func, ObjectType(classOf[T => U]))
+ Invoke(function, "apply", returnType, argument :: Nil)
+ }
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
@@ -231,6 +244,12 @@ package object dsl {
AttributeReference(s, structType, nullable = true)()
def struct(attrs: AttributeReference*): AttributeReference =
struct(StructType.fromAttributes(attrs))
+
+ /** Create a function. */
+ def function(exprs: Expression*): UnresolvedFunction =
+ UnresolvedFunction(s, exprs, isDistinct = false)
+ def distinctFunction(exprs: Expression*): UnresolvedFunction =
+ UnresolvedFunction(s, exprs, isDistinct = true)
}
implicit class DslAttribute(a: AttributeReference) {
@@ -243,11 +262,33 @@ package object dsl {
object expressions extends ExpressionConversions // scalastyle:ignore
object plans { // scalastyle:ignore
+ def table(ref: String): LogicalPlan =
+ UnresolvedRelation(TableIdentifier(ref), None)
+
+ def table(db: String, ref: String): LogicalPlan =
+ UnresolvedRelation(TableIdentifier(ref, Option(db)), None)
+
implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) {
- def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan)
+ def select(exprs: Expression*): LogicalPlan = {
+ val namedExpressions = exprs.map {
+ case e: NamedExpression => e
+ case e => UnresolvedAlias(e)
+ }
+ Project(namedExpressions, logicalPlan)
+ }
def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
+ def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
+ val deserialized = logicalPlan.deserialize[T]
+ val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
+ Filter(condition, deserialized).serialize[T]
+ }
+
+ def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)
+
+ def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan)
+
def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)
def join(
@@ -296,6 +337,14 @@ package object dsl {
analysis.UnresolvedRelation(TableIdentifier(tableName)),
Map.empty, logicalPlan, overwrite, false)
+ def as(alias: String): LogicalPlan = logicalPlan match {
+ case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias))
+ case plan => SubqueryAlias(alias, plan)
+ }
+
+ def distribute(exprs: Expression*): LogicalPlan =
+ RepartitionByExpression(exprs, logicalPlan)
+
def analyze: LogicalPlan =
EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan))
}