diff options
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 53 |
1 files changed, 51 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3540014c3e..1e7296664b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,7 +21,8 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -161,6 +162,18 @@ package object dsl { def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) + def star(names: String*): Expression = names match { + case Seq() => UnresolvedStar(None) + case target => UnresolvedStar(Option(target)) + } + + def callFunction[T, U]( + func: T => U, + returnType: DataType, + argument: Expression): Expression = { + val function = Literal.create(func, ObjectType(classOf[T => U])) + Invoke(function, "apply", returnType, argument :: Nil) + } implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? @@ -231,6 +244,12 @@ package object dsl { AttributeReference(s, structType, nullable = true)() def struct(attrs: AttributeReference*): AttributeReference = struct(StructType.fromAttributes(attrs)) + + /** Create a function. */ + def function(exprs: Expression*): UnresolvedFunction = + UnresolvedFunction(s, exprs, isDistinct = false) + def distinctFunction(exprs: Expression*): UnresolvedFunction = + UnresolvedFunction(s, exprs, isDistinct = true) } implicit class DslAttribute(a: AttributeReference) { @@ -243,11 +262,33 @@ package object dsl { object expressions extends ExpressionConversions // scalastyle:ignore object plans { // scalastyle:ignore + def table(ref: String): LogicalPlan = + UnresolvedRelation(TableIdentifier(ref), None) + + def table(db: String, ref: String): LogicalPlan = + UnresolvedRelation(TableIdentifier(ref, Option(db)), None) + implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) { - def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) + def select(exprs: Expression*): LogicalPlan = { + val namedExpressions = exprs.map { + case e: NamedExpression => e + case e => UnresolvedAlias(e) + } + Project(namedExpressions, logicalPlan) + } def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) + def filter[T : Encoder](func: T => Boolean): LogicalPlan = { + val deserialized = logicalPlan.deserialize[T] + val condition = expressions.callFunction(func, BooleanType, deserialized.output.head) + Filter(condition, deserialized).serialize[T] + } + + def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan) + + def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) def join( @@ -296,6 +337,14 @@ package object dsl { analysis.UnresolvedRelation(TableIdentifier(tableName)), Map.empty, logicalPlan, overwrite, false) + def as(alias: String): LogicalPlan = logicalPlan match { + case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) + case plan => SubqueryAlias(alias, plan) + } + + def distribute(exprs: Expression*): LogicalPlan = + RepartitionByExpression(exprs, logicalPlan) + def analyze: LogicalPlan = EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan)) } |