Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/pom.xml  66
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala  328
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala  185
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala  107
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala  37
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala  275
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala  54
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala  25
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala  109
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala  224
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala  57
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala  83
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala  79
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala  196
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala  127
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala  29
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala  214
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala  41
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala  34
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala  49
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala  265
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala  89
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala  96
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala  116
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala  73
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala  156
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala  75
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala  51
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala  213
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala  29
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala  167
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala  64
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/package.scala  24
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala  117
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala  128
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala  26
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala  28
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala  132
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala  38
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala  47
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala  158
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala  46
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala  25
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala  201
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala  33
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala  79
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/package.scala  24
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala  364
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala  38
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala  137
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala  24
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala  122
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala  49
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/AnalysisSuite.scala  41
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala  175
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala  115
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/HiveTypeCoercionSuite.scala  74
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/RuleExecutorSuite.scala  57
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/TreeNodeSuite.scala  81
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala  176
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala  222
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala  44
62 files changed, 6538 insertions, 0 deletions
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
new file mode 100644
index 0000000000..740f1fdc83
--- /dev/null
+++ b/sql/catalyst/pom.xml
@@ -0,0 +1,66 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>1.0.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-catalyst_2.10</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project Catalyst</name>
+ <url>http://spark.apache.org/</url>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.typesafe</groupId>
+ <artifactId>scalalogging-slf4j_${scala.binary.version}</artifactId>
+ <version>1.0.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scalacheck</groupId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
new file mode 100644
index 0000000000..d3b1070a58
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import scala.util.matching.Regex
+import scala.util.parsing.combinator._
+import scala.util.parsing.input.CharArrayReader.EofCh
+import lexical._
+import syntactical._
+import token._
+
+import analysis._
+import expressions._
+import plans._
+import plans.logical._
+import types._
+
+/**
+ * A very simple SQL parser. Based loosely on:
+ * https://github.com/stephentu/scala-sql-parser/blob/master/src/main/scala/parser.scala
+ *
+ * Limitations:
+ * - Only supports a very limited subset of SQL.
+ * - Keywords must be capital.
+ *
+ * This is currently included mostly for illustrative purposes. Users wanting more complete support
+ * for a SQL-like language should check out the HiveQL support in the sql/hive subproject.
+ */
+class SqlParser extends StandardTokenParsers {
+
+ def apply(input: String): LogicalPlan = {
+ phrase(query)(new lexical.Scanner(input)) match {
+ case Success(r, x) => r
+ case x => sys.error(x.toString)
+ }
+ }
+
+ protected case class Keyword(str: String)
+ protected implicit def asParser(k: Keyword): Parser[String] = k.str
+
+ protected class SqlLexical extends StdLexical {
+ case class FloatLit(chars: String) extends Token {
+ override def toString = chars
+ }
+ override lazy val token: Parser[Token] = (
+ identChar ~ rep( identChar | digit ) ^^
+ { case first ~ rest => processIdent(first :: rest mkString "") }
+ | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ {
+ case i ~ None => NumericLit(i mkString "")
+ case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString(""))
+ }
+ | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^
+ { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") }
+ | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^
+ { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") }
+ | EofCh ^^^ EOF
+ | '\'' ~> failure("unclosed string literal")
+ | '\"' ~> failure("unclosed string literal")
+ | delim
+ | failure("illegal character")
+ )
+
+ override def identChar = letter | elem('.') | elem('_')
+
+ override def whitespace: Parser[Any] = rep(
+ whitespaceChar
+ | '/' ~ '*' ~ comment
+ | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') )
+ | '#' ~ rep( chrExcept(EofCh, '\n') )
+ | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') )
+ | '/' ~ '*' ~ failure("unclosed comment")
+ )
+ }
+
+ override val lexical = new SqlLexical
+
+ protected val ALL = Keyword("ALL")
+ protected val AND = Keyword("AND")
+ protected val AS = Keyword("AS")
+ protected val ASC = Keyword("ASC")
+ protected val AVG = Keyword("AVG")
+ protected val BY = Keyword("BY")
+ protected val CAST = Keyword("CAST")
+ protected val COUNT = Keyword("COUNT")
+ protected val DESC = Keyword("DESC")
+ protected val DISTINCT = Keyword("DISTINCT")
+ protected val FALSE = Keyword("FALSE")
+ protected val FIRST = Keyword("FIRST")
+ protected val FROM = Keyword("FROM")
+ protected val FULL = Keyword("FULL")
+ protected val GROUP = Keyword("GROUP")
+ protected val HAVING = Keyword("HAVING")
+ protected val IF = Keyword("IF")
+ protected val IN = Keyword("IN")
+ protected val INNER = Keyword("INNER")
+ protected val IS = Keyword("IS")
+ protected val JOIN = Keyword("JOIN")
+ protected val LEFT = Keyword("LEFT")
+ protected val LIMIT = Keyword("LIMIT")
+ protected val NOT = Keyword("NOT")
+ protected val NULL = Keyword("NULL")
+ protected val ON = Keyword("ON")
+ protected val OR = Keyword("OR")
+ protected val ORDER = Keyword("ORDER")
+ protected val OUTER = Keyword("OUTER")
+ protected val RIGHT = Keyword("RIGHT")
+ protected val SELECT = Keyword("SELECT")
+ protected val STRING = Keyword("STRING")
+ protected val SUM = Keyword("SUM")
+ protected val TRUE = Keyword("TRUE")
+ protected val UNION = Keyword("UNION")
+ protected val WHERE = Keyword("WHERE")
+
+ // Use reflection to find the reserved words defined in this class.
+ protected val reservedWords =
+ this.getClass
+ .getMethods
+ .filter(_.getReturnType == classOf[Keyword])
+ .map(_.invoke(this).asInstanceOf[Keyword])
+
+ lexical.reserved ++= reservedWords.map(_.str)
+
+ lexical.delimiters += (
+ "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
+ ",", ";", "%", "{", "}", ":"
+ )
+
+ protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
+ exprs.zipWithIndex.map {
+ case (ne: NamedExpression, _) => ne
+ case (e, i) => Alias(e, s"c$i")()
+ }
+ }
+
+ protected lazy val query: Parser[LogicalPlan] =
+ select * (
+ UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } |
+ UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
+ )
+
+ protected lazy val select: Parser[LogicalPlan] =
+ SELECT ~> opt(DISTINCT) ~ projections ~
+ opt(from) ~ opt(filter) ~
+ opt(grouping) ~
+ opt(having) ~
+ opt(orderBy) ~
+ opt(limit) <~ opt(";") ^^ {
+ case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l =>
+ val base = r.getOrElse(NoRelation)
+ val withFilter = f.map(f => Filter(f, base)).getOrElse(base)
+ val withProjection =
+ g.map {g =>
+ Aggregate(assignAliases(g), assignAliases(p), withFilter)
+ }.getOrElse(Project(assignAliases(p), withFilter))
+ val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
+ val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct)
+ val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving)
+ val withLimit = l.map { l => StopAfter(l, withOrder) }.getOrElse(withOrder)
+ withLimit
+ }
+
+ protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",")
+
+ protected lazy val projection: Parser[Expression] =
+ expression ~ (opt(AS) ~> opt(ident)) ^^ {
+ case e ~ None => e
+ case e ~ Some(a) => Alias(e, a)()
+ }
+
+ protected lazy val from: Parser[LogicalPlan] = FROM ~> relations
+
+ // Based very loosely on the MySQL grammar.
+ // http://dev.mysql.com/doc/refman/5.0/en/join.html
+ protected lazy val relations: Parser[LogicalPlan] =
+ relation ~ "," ~ relation ^^ { case r1 ~ _ ~ r2 => Join(r1, r2, Inner, None) } |
+ relation
+
+ protected lazy val relation: Parser[LogicalPlan] =
+ joinedRelation |
+ relationFactor
+
+ protected lazy val relationFactor: Parser[LogicalPlan] =
+ ident ~ (opt(AS) ~> opt(ident)) ^^ {
+ case ident ~ alias => UnresolvedRelation(alias, ident)
+ } |
+ "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) }
+
+ protected lazy val joinedRelation: Parser[LogicalPlan] =
+ relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ {
+ case r1 ~ jt ~ _ ~ r2 ~ cond =>
+ Join(r1, r2, joinType = jt.getOrElse(Inner), cond)
+ }
+
+ protected lazy val joinConditions: Parser[Expression] =
+ ON ~> expression
+
+ protected lazy val joinType: Parser[JoinType] =
+ INNER ^^^ Inner |
+ LEFT ~ opt(OUTER) ^^^ LeftOuter |
+ RIGHT ~ opt(OUTER) ^^^ RightOuter |
+ FULL ~ opt(OUTER) ^^^ FullOuter
+
+ protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e }
+
+ protected lazy val orderBy: Parser[Seq[SortOrder]] =
+ ORDER ~> BY ~> ordering
+
+ protected lazy val ordering: Parser[Seq[SortOrder]] =
+ rep1sep(singleOrder, ",") |
+ rep1sep(expression, ",") ~ opt(direction) ^^ {
+ case exps ~ None => exps.map(SortOrder(_, Ascending))
+ case exps ~ Some(d) => exps.map(SortOrder(_, d))
+ }
+
+ protected lazy val singleOrder: Parser[SortOrder] =
+ expression ~ direction ^^ { case e ~ o => SortOrder(e,o) }
+
+ protected lazy val direction: Parser[SortDirection] =
+ ASC ^^^ Ascending |
+ DESC ^^^ Descending
+
+ protected lazy val grouping: Parser[Seq[Expression]] =
+ GROUP ~> BY ~> rep1sep(expression, ",")
+
+ protected lazy val having: Parser[Expression] =
+ HAVING ~> expression
+
+ protected lazy val limit: Parser[Expression] =
+ LIMIT ~> expression
+
+ protected lazy val expression: Parser[Expression] = orExpression
+
+ protected lazy val orExpression: Parser[Expression] =
+ andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1,e2) })
+
+ protected lazy val andExpression: Parser[Expression] =
+ comparisionExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) })
+
+ protected lazy val comparisionExpression: Parser[Expression] =
+ termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Equals(e1, e2) } |
+ termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } |
+ termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } |
+ termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } |
+ termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } |
+ termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } |
+ termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } |
+ termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ {
+ case e1 ~ _ ~ _ ~ e2 => In(e1, e2)
+ } |
+ termExpression ~ NOT ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ {
+ case e1 ~ _ ~ _ ~ _ ~ e2 => Not(In(e1, e2))
+ } |
+ termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } |
+ termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } |
+ NOT ~> termExpression ^^ {e => Not(e)} |
+ termExpression
+
+ protected lazy val termExpression: Parser[Expression] =
+ productExpression * (
+ "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1,e2) } |
+ "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1,e2) } )
+
+ protected lazy val productExpression: Parser[Expression] =
+ baseExpression * (
+ "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1,e2) } |
+ "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1,e2) } |
+ "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1,e2) }
+ )
+
+ protected lazy val function: Parser[Expression] =
+ SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } |
+ SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } |
+ COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } |
+ COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } |
+ COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
+ FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
+ AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
+ IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
+ case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
+ } |
+ ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ {
+ case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs)
+ }
+
+ protected lazy val cast: Parser[Expression] =
+ CAST ~> "(" ~> expression ~ AS ~ dataType <~ ")" ^^ { case exp ~ _ ~ t => Cast(exp, t) }
+
+ protected lazy val literal: Parser[Literal] =
+ numericLit ^^ {
+ case i if i.toLong > Int.MaxValue => Literal(i.toLong)
+ case i => Literal(i.toInt)
+ } |
+ NULL ^^^ Literal(null, NullType) |
+ floatLit ^^ {case f => Literal(f.toDouble) } |
+ stringLit ^^ {case s => Literal(s, StringType) }
+
+ protected lazy val floatLit: Parser[String] =
+ elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
+
+ protected lazy val baseExpression: Parser[Expression] =
+ TRUE ^^^ Literal(true, BooleanType) |
+ FALSE ^^^ Literal(false, BooleanType) |
+ cast |
+ "(" ~> expression <~ ")" |
+ function |
+ "-" ~> literal ^^ UnaryMinus |
+ ident ^^ UnresolvedAttribute |
+ "*" ^^^ Star(None) |
+ literal
+
+ protected lazy val dataType: Parser[DataType] =
+ STRING ^^^ StringType
+}
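As a rough sketch of how this parser is meant to be driven (the query text and the shape of the resulting plan below are illustrative, not taken from the patch):

    import org.apache.spark.sql.catalyst.SqlParser
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Keywords must be written in upper case, per the limitations noted in the scaladoc.
    val parser = new SqlParser
    val plan: LogicalPlan = parser("SELECT value FROM src WHERE key = 1")
    // `plan` is an unresolved tree, roughly Project(value, Filter(key = 1, UnresolvedRelation(src))),
    // which still has to pass through the Analyzer before attributes carry types.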
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
new file mode 100644
index 0000000000..9eb992ee58
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package analysis
+
+import expressions._
+import plans.logical._
+import rules._
+
+/**
+ * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
+ * when all relations are already filled in and the analyser needs only to resolve attribute
+ * references.
+ */
+object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true)
+
+/**
+ * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
+ * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and
+ * a [[FunctionRegistry]].
+ */
+class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean)
+ extends RuleExecutor[LogicalPlan] with HiveTypeCoercion {
+
+ // TODO: pass this in as a parameter.
+ val fixedPoint = FixedPoint(100)
+
+ val batches: Seq[Batch] = Seq(
+ Batch("MultiInstanceRelations", Once,
+ NewRelationInstances),
+ Batch("CaseInsensitiveAttributeReferences", Once,
+ (if (caseSensitive) Nil else LowercaseAttributeReferences :: Nil) : _*),
+ Batch("Resolution", fixedPoint,
+ ResolveReferences ::
+ ResolveRelations ::
+ NewRelationInstances ::
+ ImplicitGenerate ::
+ StarExpansion ::
+ ResolveFunctions ::
+ GlobalAggregates ::
+ typeCoercionRules :_*)
+ )
+
+ /**
+ * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
+ */
+ object ResolveRelations extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case UnresolvedRelation(databaseName, name, alias) =>
+ catalog.lookupRelation(databaseName, name, alias)
+ }
+ }
+
+ /**
+ * Makes attribute naming case insensitive by turning all UnresolvedAttributes to lowercase.
+ */
+ object LowercaseAttributeReferences extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case UnresolvedRelation(databaseName, name, alias) =>
+ UnresolvedRelation(databaseName, name, alias.map(_.toLowerCase))
+ case Subquery(alias, child) => Subquery(alias.toLowerCase, child)
+ case q: LogicalPlan => q transformExpressions {
+ case s: Star => s.copy(table = s.table.map(_.toLowerCase))
+ case UnresolvedAttribute(name) => UnresolvedAttribute(name.toLowerCase)
+ case Alias(c, name) => Alias(c, name.toLowerCase)()
+ }
+ }
+ }
+
+ /**
+ * Replaces [[UnresolvedAttribute]]s with concrete
+ * [[expressions.AttributeReference AttributeReferences]] from a logical plan node's children.
+ */
+ object ResolveReferences extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case q: LogicalPlan if q.childrenResolved =>
+ logger.trace(s"Attempting to resolve ${q.simpleString}")
+ q transformExpressions {
+ case u @ UnresolvedAttribute(name) =>
+ // Leave unchanged if resolution fails. Hopefully will be resolved next round.
+ val result = q.resolve(name).getOrElse(u)
+ logger.debug(s"Resolving $u to $result")
+ result
+ }
+ }
+ }
+
+ /**
+ * Replaces [[UnresolvedFunction]]s with concrete [[expressions.Expression Expressions]].
+ */
+ object ResolveFunctions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan =>
+ q transformExpressions {
+ case u @ UnresolvedFunction(name, children) if u.childrenResolved =>
+ registry.lookupFunction(name, children)
+ }
+ }
+ }
+
+ /**
+ * Turns projections that contain aggregate expressions into aggregations.
+ */
+ object GlobalAggregates extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Project(projectList, child) if containsAggregates(projectList) =>
+ Aggregate(Nil, projectList, child)
+ }
+
+ def containsAggregates(exprs: Seq[Expression]): Boolean = {
+ exprs.foreach(_.foreach {
+ case agg: AggregateExpression => return true
+ case _ =>
+ })
+ false
+ }
+ }
+
+ /**
+ * When a SELECT clause has only a single expression and that expression is a
+ * [[catalyst.expressions.Generator Generator]] we convert the
+ * [[catalyst.plans.logical.Project Project]] to a [[catalyst.plans.logical.Generate Generate]].
+ */
+ object ImplicitGenerate extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Project(Seq(Alias(g: Generator, _)), child) =>
+ Generate(g, join = false, outer = false, None, child)
+ }
+ }
+
+ /**
+ * Expands any references to [[Star]] (*) in project operators.
+ */
+ object StarExpansion extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // Wait until children are resolved
+ case p: LogicalPlan if !p.childrenResolved => p
+ // If the projection list contains Stars, expand it.
+ case p @ Project(projectList, child) if containsStar(projectList) =>
+ Project(
+ projectList.flatMap {
+ case s: Star => s.expand(child.output)
+ case o => o :: Nil
+ },
+ child)
+ case t: ScriptTransformation if containsStar(t.input) =>
+ t.copy(
+ input = t.input.flatMap {
+ case s: Star => s.expand(t.child.output)
+ case o => o :: Nil
+ }
+ )
+ // If the aggregate function argument contains Stars, expand it.
+ case a: Aggregate if containsStar(a.aggregateExpressions) =>
+ a.copy(
+ aggregateExpressions = a.aggregateExpressions.flatMap {
+ case s: Star => s.expand(a.child.output)
+ case o => o :: Nil
+ }
+ )
+ }
+
+ /**
+ * Returns true if `exprs` contains a [[Star]].
+ */
+ protected def containsStar(exprs: Seq[Expression]): Boolean =
+ exprs.collect { case _: Star => true }.nonEmpty
+ }
+}
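A minimal sketch of wiring the analyzer together; the `src` table, its assumed `key`/`value` columns, and the `srcPlan` parameter are illustrative, and the last line assumes RuleExecutor applies its batches via its apply method:

    import org.apache.spark.sql.catalyst.SqlParser
    import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry, SimpleCatalog}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // `srcPlan` stands in for an already-resolved plan whose output includes key and value.
    def analyzeExample(srcPlan: LogicalPlan): LogicalPlan = {
      val catalog = new SimpleCatalog
      catalog.registerTable(None, "src", srcPlan)

      val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = false)
      val parser = new SqlParser
      // ResolveRelations swaps UnresolvedRelation("src") for srcPlan, and ResolveReferences
      // binds 'key and 'value to AttributeReferences drawn from srcPlan.output.
      analyzer(parser("SELECT value FROM src WHERE key = 1"))
    }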
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
new file mode 100644
index 0000000000..71e4dcdb15
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package analysis
+
+import plans.logical.{LogicalPlan, Subquery}
+import scala.collection.mutable
+
+/**
+ * An interface for looking up relations by name. Used by an [[Analyzer]].
+ */
+trait Catalog {
+ def lookupRelation(
+ databaseName: Option[String],
+ tableName: String,
+ alias: Option[String] = None): LogicalPlan
+
+ def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit
+}
+
+class SimpleCatalog extends Catalog {
+ val tables = new mutable.HashMap[String, LogicalPlan]()
+
+ def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = {
+ tables += ((tableName, plan))
+ }
+
+ def dropTable(tableName: String) = tables -= tableName
+
+ def lookupRelation(
+ databaseName: Option[String],
+ tableName: String,
+ alias: Option[String] = None): LogicalPlan = {
+ val table = tables.get(tableName).getOrElse(sys.error(s"Table Not Found: $tableName"))
+
+ // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
+ // properly qualified with this alias.
+ alias.map(a => Subquery(a.toLowerCase, table)).getOrElse(table)
+ }
+}
+
+/**
+ * A trait that can be mixed in with other Catalogs allowing specific tables to be overridden with
+ * new logical plans. This can be used to bind query results to virtual tables, or replace tables
+ * with in-memory cached versions. Note that the set of overrides is stored in memory and thus
+ * lost when the JVM exits.
+ */
+trait OverrideCatalog extends Catalog {
+
+ // TODO: This doesn't work when the database changes...
+ val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]()
+
+ abstract override def lookupRelation(
+ databaseName: Option[String],
+ tableName: String,
+ alias: Option[String] = None): LogicalPlan = {
+
+ val overriddenTable = overrides.get((databaseName, tableName))
+
+ // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
+ // properly qualified with this alias.
+ val withAlias =
+ overriddenTable.map(r => alias.map(a => Subquery(a.toLowerCase, r)).getOrElse(r))
+
+ withAlias.getOrElse(super.lookupRelation(databaseName, tableName, alias))
+ }
+
+ override def registerTable(
+ databaseName: Option[String],
+ tableName: String,
+ plan: LogicalPlan): Unit = {
+ overrides.put((databaseName, tableName), plan)
+ }
+}
+
+/**
+ * A trivial catalog that returns an error when a relation is requested. Used for testing when all
+ * relations are already filled in and the analyser needs only to resolve attribute references.
+ */
+object EmptyCatalog extends Catalog {
+ def lookupRelation(
+ databaseName: Option[String],
+ tableName: String,
+ alias: Option[String] = None) = {
+ throw new UnsupportedOperationException
+ }
+
+ def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = {
+ throw new UnsupportedOperationException
+ }
+}
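As a sketch of how the override mechanism composes (the SessionCatalog name and the `events` table are hypothetical):

    import org.apache.spark.sql.catalyst.analysis.{OverrideCatalog, SimpleCatalog}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Mixing OverrideCatalog into a concrete Catalog lets individual tables be shadowed,
    // e.g. by an in-memory cached plan, without touching the underlying catalog.
    class SessionCatalog extends SimpleCatalog with OverrideCatalog

    def lookupCached(cachedEvents: LogicalPlan): LogicalPlan = {
      val catalog = new SessionCatalog
      // With the mixin, registerTable records the plan in the override map.
      catalog.registerTable(None, "events", cachedEvents)
      // lookupRelation consults the overrides first and falls back to SimpleCatalog otherwise;
      // the alias wraps the result in a Subquery so its attributes are qualified by "e".
      catalog.lookupRelation(None, "events", alias = Some("e"))
    }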
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
new file mode 100644
index 0000000000..a359eb5411
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package analysis
+
+import expressions._
+
+/** A catalog for looking up user defined functions, used by an [[Analyzer]]. */
+trait FunctionRegistry {
+ def lookupFunction(name: String, children: Seq[Expression]): Expression
+}
+
+/**
+ * A trivial catalog that returns an error when a function is requested. Used for testing when all
+ * functions are already filled in and the analyser needs only to resolve attribute references.
+ */
+object EmptyFunctionRegistry extends FunctionRegistry {
+ def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+ throw new UnsupportedOperationException
+ }
+}
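A concrete registry is not part of this patch; a minimal map-backed sketch might look like this:

    import scala.collection.mutable

    import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
    import org.apache.spark.sql.catalyst.expressions.Expression

    // Maps a function name to a builder that turns the argument expressions into an Expression.
    class SimpleFunctionRegistry extends FunctionRegistry {
      private val builders = new mutable.HashMap[String, Seq[Expression] => Expression]

      def registerFunction(name: String, builder: Seq[Expression] => Expression): Unit =
        builders(name) = builder

      def lookupFunction(name: String, children: Seq[Expression]): Expression =
        builders.getOrElse(name, sys.error(s"Undefined function: $name"))(children)
    }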
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
new file mode 100644
index 0000000000..a0105cd7cf
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -0,0 +1,275 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package analysis
+
+import expressions._
+import plans.logical._
+import rules._
+import types._
+
+/**
+ * A collection of [[catalyst.rules.Rule Rules]] that can be used to coerce differing types that
+ * participate in operations into compatible ones. Most of these rules are based on Hive semantics,
+ * but they do not introduce any dependencies on the hive codebase. For this reason they remain in
+ * Catalyst until we have a more standard set of coercions.
+ */
+trait HiveTypeCoercion {
+
+ val typeCoercionRules =
+ List(PropagateTypes, ConvertNaNs, WidenTypes, PromoteStrings, BooleanComparisons, BooleanCasts,
+ StringToIntegralCasts, FunctionArgumentConversion)
+
+ /**
+ * Applies any changes to [[catalyst.expressions.AttributeReference AttributeReference]] dataTypes
+ * that are made by other rules to instances higher in the query tree.
+ */
+ object PropagateTypes extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // No propagation required for leaf nodes.
+ case q: LogicalPlan if q.children.isEmpty => q
+
+ // Don't propagate types from unresolved children.
+ case q: LogicalPlan if !q.childrenResolved => q
+
+ case q: LogicalPlan => q transformExpressions {
+ case a: AttributeReference =>
+ q.inputSet.find(_.exprId == a.exprId) match {
+ // This can happen when an Attribute reference is born in a non-leaf node, for example
+ // due to a call to an external script like in the Transform operator.
+ // TODO: Perhaps those should actually be aliases?
+ case None => a
+ // Leave the same if the dataTypes match.
+ case Some(newType) if a.dataType == newType.dataType => a
+ case Some(newType) =>
+ logger.debug(s"Promoting $a to $newType in ${q.simpleString}}")
+ newType
+ }
+ }
+ }
+ }
+
+ /**
+ * Converts string "NaN"s that are in binary operators with NaN-able types (Float / Double) to
+ * the appropriate numeric equivalent.
+ */
+ object ConvertNaNs extends Rule[LogicalPlan] {
+ val stringNaN = Literal("NaN", StringType)
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressions {
+ // Skip nodes whose children have not been resolved yet.
+ case e if !e.childrenResolved => e
+
+ /* Double Conversions */
+ case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType =>
+ b.makeCopy(Array(b.right, Literal(Double.NaN)))
+ case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN =>
+ b.makeCopy(Array(Literal(Double.NaN), b.left))
+ case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN =>
+ b.makeCopy(Array(Literal(Double.NaN), b.left))
+
+ /* Float Conversions */
+ case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType =>
+ b.makeCopy(Array(b.right, Literal(Float.NaN)))
+ case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN =>
+ b.makeCopy(Array(Literal(Float.NaN), b.left))
+ case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN =>
+ b.makeCopy(Array(Literal(Float.NaN), b.left))
+ }
+ }
+ }
+
+ /**
+ * Widens numeric types and converts strings to numbers when appropriate.
+ *
+ * Loosely based on rules from "Hadoop: The Definitive Guide" 2nd edition, by Tom White
+ *
+ * The implicit conversion rules can be summarized as follows:
+ * - Any integral numeric type can be implicitly converted to a wider type.
+ * - All the integral numeric types, FLOAT, and (perhaps surprisingly) STRING can be implicitly
+ * converted to DOUBLE.
+ * - TINYINT, SMALLINT, and INT can all be converted to FLOAT.
+ * - BOOLEAN types cannot be converted to any other type.
+ *
+ * Additionally, all types when UNION-ed with strings will be promoted to strings.
+ * Other string conversions are handled by PromoteStrings.
+ */
+ object WidenTypes extends Rule[LogicalPlan] {
+ // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
+ // The conversion for integral and floating point types has a linear widening hierarchy:
+ val numericPrecedence =
+ Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
+ // Boolean is only wider than Void
+ val booleanPrecedence = Seq(NullType, BooleanType)
+ val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil
+
+ def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
+ // Try and find a promotion rule that contains both types in question.
+ val applicableConversion = allPromotions.find(p => p.contains(t1) && p.contains(t2))
+
+ // If found return the widest common type, otherwise None
+ applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
+ val castedInput = left.output.zip(right.output).map {
+ // When a string is found on one side, make the other side a string too.
+ case (l, r) if l.dataType == StringType && r.dataType != StringType =>
+ (l, Alias(Cast(r, StringType), r.name)())
+ case (l, r) if l.dataType != StringType && r.dataType == StringType =>
+ (Alias(Cast(l, StringType), l.name)(), r)
+
+ case (l, r) if l.dataType != r.dataType =>
+ logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}")
+ findTightestCommonType(l.dataType, r.dataType).map { widestType =>
+ val newLeft =
+ if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)()
+ val newRight =
+ if (r.dataType == widestType) r else Alias(Cast(r, widestType), r.name)()
+
+ (newLeft, newRight)
+ }.getOrElse((l, r)) // If there is no applicable conversion, leave expression unchanged.
+ case other => other
+ }
+
+ val (castedLeft, castedRight) = castedInput.unzip
+
+ val newLeft =
+ if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
+ logger.debug(s"Widening numeric types in union $castedLeft ${left.output}")
+ Project(castedLeft, left)
+ } else {
+ left
+ }
+
+ val newRight =
+ if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
+ logger.debug(s"Widening numeric types in union $castedRight ${right.output}")
+ Project(castedRight, right)
+ } else {
+ right
+ }
+
+ Union(newLeft, newRight)
+
+ // Also widen types for BinaryExpressions.
+ case q: LogicalPlan => q transformExpressions {
+ // Skip nodes whose children have not been resolved yet.
+ case e if !e.childrenResolved => e
+
+ case b: BinaryExpression if b.left.dataType != b.right.dataType =>
+ findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType =>
+ val newLeft =
+ if (b.left.dataType == widestType) b.left else Cast(b.left, widestType)
+ val newRight =
+ if (b.right.dataType == widestType) b.right else Cast(b.right, widestType)
+ b.makeCopy(Array(newLeft, newRight))
+ }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
+ }
+ }
+ }
+
+ /**
+ * Promotes strings that appear in arithmetic expressions.
+ */
+ object PromoteStrings extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Skip nodes whose children have not been resolved yet.
+ case e if !e.childrenResolved => e
+
+ case a: BinaryArithmetic if a.left.dataType == StringType =>
+ a.makeCopy(Array(Cast(a.left, DoubleType), a.right))
+ case a: BinaryArithmetic if a.right.dataType == StringType =>
+ a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
+
+ case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
+ p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
+ case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
+ p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
+
+ case Sum(e) if e.dataType == StringType =>
+ Sum(Cast(e, DoubleType))
+ case Average(e) if e.dataType == StringType =>
+ Average(Cast(e, DoubleType))
+ }
+ }
+
+ /**
+ * Changes Boolean values to Bytes so that expressions like true < false can be evaluated.
+ */
+ object BooleanComparisons extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Skip nodes whose children have not been resolved yet.
+ case e if !e.childrenResolved => e
+ // No need to change Equals operators as that actually makes sense for boolean types.
+ case e: Equals => e
+ // Otherwise turn them to Byte types so that there exists an ordering.
+ case p: BinaryComparison
+ if p.left.dataType == BooleanType && p.right.dataType == BooleanType =>
+ p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
+ }
+ }
+
+ /**
+ * Casts to/from [[catalyst.types.BooleanType BooleanType]] are transformed into comparisons since
+ * the JVM does not consider Booleans to be numeric types.
+ */
+ object BooleanCasts extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Skip nodes whose children have not been resolved yet.
+ case e if !e.childrenResolved => e
+
+ case Cast(e, BooleanType) => Not(Equals(e, Literal(0)))
+ case Cast(e, dataType) if e.dataType == BooleanType =>
+ Cast(If(e, Literal(1), Literal(0)), dataType)
+ }
+ }
+
+ /**
+ * When encountering a cast from a string representing a valid fractional number to an integral
+ * type, the JVM will throw a `java.lang.NumberFormatException`. Hive, in contrast, returns the
+ * truncated version of this number.
+ */
+ object StringToIntegralCasts extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Skip nodes whose children have not been resolved yet.
+ case e if !e.childrenResolved => e
+
+ case Cast(e @ StringType(), t: IntegralType) =>
+ Cast(Cast(e, DecimalType), t)
+ }
+ }
+
+ /**
+ * This ensures that the types for various functions are as expected.
+ */
+ object FunctionArgumentConversion extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Skip nodes whose children have not been resolved yet.
+ case e if !e.childrenResolved => e
+
+ // Promote SUM to largest types to prevent overflows.
+ case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
+ case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
+ case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))
+ }
+ }
+}
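To make the widening rule concrete, findTightestCommonType simply picks the widest of the two types along a shared precedence chain; the `Coercion` object below is only a way to reach the nested rule and is not part of the patch:

    import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
    import org.apache.spark.sql.catalyst.types._

    object Coercion extends HiveTypeCoercion

    // Both types sit on the numericPrecedence chain, so the wider one wins.
    Coercion.WidenTypes.findTightestCommonType(IntegerType, LongType)    // Some(LongType)
    Coercion.WidenTypes.findTightestCommonType(DoubleType, DecimalType)  // Some(DecimalType)
    // StringType is on no chain here, so there is no implicit widening;
    // strings are handled separately by PromoteStrings.
    Coercion.WidenTypes.findTightestCommonType(StringType, IntegerType)  // None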
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
new file mode 100644
index 0000000000..fe18cc466f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+package analysis
+
+import plans.logical.LogicalPlan
+import rules._
+
+/**
+ * A trait that should be mixed into query operators where a single instance might appear multiple
+ * times in a logical query plan. It is invalid to have multiple copies of the same attribute
+ * produced by distinct operators in a query tree as this breaks the guarantee that expression
+ * ids, which are used to differentiate attributes, are unique.
+ *
+ * Before analysis, all operators that include this trait will be asked to produce a new version
+ * of themselves with globally unique expression ids.
+ */
+trait MultiInstanceRelation {
+ def newInstance: this.type
+}
+
+/**
+ * If any MultiInstanceRelation appears more than once in the query plan then the plan is updated so
+ * that each instance has unique expression ids for the attributes produced.
+ */
+object NewRelationInstances extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ val localRelations = plan collect { case l: MultiInstanceRelation => l}
+ val multiAppearance = localRelations
+ .groupBy(identity[MultiInstanceRelation])
+ .filter { case (_, ls) => ls.size > 1 }
+ .map(_._1)
+ .toSet
+
+ plan transform {
+ case l: MultiInstanceRelation if multiAppearance contains l => l.newInstance
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
new file mode 100644
index 0000000000..375c99f48e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package catalyst
+
+/**
+ * Provides a logical query plan [[Analyzer]] and supporting classes for performing analysis.
+ * Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s
+ * into fully typed objects using information in a schema [[Catalog]].
+ */
+package object analysis
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
new file mode 100644
index 0000000000..2ed2af1352
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package analysis
+
+import expressions._
+import plans.logical.BaseRelation
+import trees.TreeNode
+
+/**
+ * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully
+ * resolved.
+ */
+class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: String) extends
+ errors.TreeNodeException(tree, s"Invalid call to $function on unresolved object", null)
+
+/**
+ * Holds the name of a relation that has yet to be looked up in a [[Catalog]].
+ */
+case class UnresolvedRelation(
+ databaseName: Option[String],
+ tableName: String,
+ alias: Option[String] = None) extends BaseRelation {
+ def output = Nil
+ override lazy val resolved = false
+}
+
+/**
+ * Holds the name of an attribute that has yet to be resolved.
+ */
+case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
+ def exprId = throw new UnresolvedException(this, "exprId")
+ def dataType = throw new UnresolvedException(this, "dataType")
+ def nullable = throw new UnresolvedException(this, "nullable")
+ def qualifiers = throw new UnresolvedException(this, "qualifiers")
+ override lazy val resolved = false
+
+ def newInstance = this
+ def withQualifiers(newQualifiers: Seq[String]) = this
+
+ override def toString: String = s"'$name"
+}
+
+case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
+ def exprId = throw new UnresolvedException(this, "exprId")
+ def dataType = throw new UnresolvedException(this, "dataType")
+ override def foldable = throw new UnresolvedException(this, "foldable")
+ def nullable = throw new UnresolvedException(this, "nullable")
+ def qualifiers = throw new UnresolvedException(this, "qualifiers")
+ def references = children.flatMap(_.references).toSet
+ override lazy val resolved = false
+ override def toString = s"'$name(${children.mkString(",")})"
+}
+
+/**
+ * Represents all of the input attributes to a given relational operator, for example in
+ * "SELECT * FROM ...".
+ *
+ * @param table an optional table that should be the target of the expansion. If omitted all
+ * tables' columns are produced.
+ */
+case class Star(
+ table: Option[String],
+ mapFunction: Attribute => Expression = identity[Attribute])
+ extends Attribute with trees.LeafNode[Expression] {
+
+ def name = throw new UnresolvedException(this, "name")
+ def exprId = throw new UnresolvedException(this, "exprId")
+ def dataType = throw new UnresolvedException(this, "dataType")
+ def nullable = throw new UnresolvedException(this, "nullable")
+ def qualifiers = throw new UnresolvedException(this, "qualifiers")
+ override lazy val resolved = false
+
+ def newInstance = this
+ def withQualifiers(newQualifiers: Seq[String]) = this
+
+ def expand(input: Seq[Attribute]): Seq[NamedExpression] = {
+ val expandedAttributes: Seq[Attribute] = table match {
+ // If there is no table specified, use all input attributes.
+ case None => input
+ // If there is a table, pick out attributes that are part of this table.
+ case Some(table) => input.filter(_.qualifiers contains table)
+ }
+ val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map {
+ case (n: NamedExpression, _) => n
+ case (e, originalAttribute) =>
+ Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers)
+ }
+ mappedAttributes
+ }
+
+ override def toString = table.map(_ + ".").getOrElse("") + "*"
+}
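A small sketch of the expansion behaviour, using the DSL defined later in this patch (the key/value attribute names are illustrative):

    import org.apache.spark.sql.catalyst.analysis.Star
    import org.apache.spark.sql.catalyst.dsl.expressions._

    val input = Seq('key.int, 'value.string)
    // With no target table, the star expands to every input attribute.
    Star(None).expand(input)           // Seq(key, value)
    // With a table qualifier that matches none of the input's qualifiers, nothing is produced.
    Star(Some("src")).expand(input)    // Seq()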
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
new file mode 100644
index 0000000000..cd8de9d52f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+
+import analysis.UnresolvedAttribute
+import expressions._
+import plans._
+import plans.logical._
+import types._
+
+/**
+ * Provides experimental support for generating catalyst schemas for scala objects.
+ */
+object ScalaReflection {
+ import scala.reflect.runtime.universe._
+
+ /** Returns a Sequence of attributes for the given case class type. */
+ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
+ case s: StructType =>
+ s.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)())
+ }
+
+ /** Returns a catalyst DataType for the given Scala Type using reflection. */
+ def schemaFor[T: TypeTag]: DataType = schemaFor(typeOf[T])
+
+ /** Returns a catalyst DataType for the given Scala Type using reflection. */
+ def schemaFor(tpe: `Type`): DataType = tpe match {
+ case t if t <:< typeOf[Product] =>
+ val params = t.member("<init>": TermName).asMethod.paramss
+ StructType(
+ params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true)))
+ case t if t <:< typeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ ArrayType(schemaFor(elementType))
+ case t if t <:< typeOf[String] => StringType
+ case t if t <:< definitions.IntTpe => IntegerType
+ case t if t <:< definitions.LongTpe => LongType
+ case t if t <:< definitions.DoubleTpe => DoubleType
+ case t if t <:< definitions.ShortTpe => ShortType
+ case t if t <:< definitions.ByteTpe => ByteType
+ }
+
+ implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
+
+ /**
+ * Implicitly added to Sequences of case class objects. Returns a catalyst logical relation
+   * for the data in the sequence.
+ */
+ def asRelation: LocalRelation = {
+ val output = attributesFor[A]
+ LocalRelation(output, data)
+ }
+ }
+}
+
+/**
+ * A collection of implicit conversions that create a DSL for constructing catalyst data structures.
+ *
+ * {{{
+ * scala> import catalyst.dsl._
+ *
+ * // Standard operators are added to expressions.
+ * scala> Literal(1) + Literal(1)
+ * res1: catalyst.expressions.Add = (1 + 1)
+ *
+ * // There is a conversion from 'symbols to unresolved attributes.
+ * scala> 'a.attr
+ * res2: catalyst.analysis.UnresolvedAttribute = 'a
+ *
+ * // These unresolved attributes can be used to create more complicated expressions.
+ * scala> 'a === 'b
+ * res3: catalyst.expressions.Equals = ('a = 'b)
+ *
+ * // SQL verbs can be used to construct logical query plans.
+ * scala> TestRelation('key.int, 'value.string).where('key === 1).select('value).analyze
+ * res4: catalyst.plans.logical.LogicalPlan =
+ * Project {value#1}
+ * Filter (key#0 = 1)
+ * TestRelation {key#0,value#1}
+ * }}}
+ */
+package object dsl {
+ trait ImplicitOperators {
+ def expr: Expression
+
+ def + (other: Expression) = Add(expr, other)
+ def - (other: Expression) = Subtract(expr, other)
+ def * (other: Expression) = Multiply(expr, other)
+ def / (other: Expression) = Divide(expr, other)
+
+ def && (other: Expression) = And(expr, other)
+ def || (other: Expression) = Or(expr, other)
+
+ def < (other: Expression) = LessThan(expr, other)
+ def <= (other: Expression) = LessThanOrEqual(expr, other)
+ def > (other: Expression) = GreaterThan(expr, other)
+ def >= (other: Expression) = GreaterThanOrEqual(expr, other)
+ def === (other: Expression) = Equals(expr, other)
+ def != (other: Expression) = Not(Equals(expr, other))
+
+ def asc = SortOrder(expr, Ascending)
+ def desc = SortOrder(expr, Descending)
+
+ def as(s: Symbol) = Alias(expr, s.name)()
+ }
+
+ trait ExpressionConversions {
+ implicit class DslExpression(e: Expression) extends ImplicitOperators {
+ def expr = e
+ }
+
+ implicit def intToLiteral(i: Int) = Literal(i)
+ implicit def longToLiteral(l: Long) = Literal(l)
+ implicit def floatToLiteral(f: Float) = Literal(f)
+ implicit def doubleToLiteral(d: Double) = Literal(d)
+ implicit def stringToLiteral(s: String) = Literal(s)
+
+ implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name)
+
+ implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
+ implicit class DslString(val s: String) extends ImplicitAttribute
+
+ abstract class ImplicitAttribute extends ImplicitOperators {
+ def s: String
+ def expr = attr
+ def attr = analysis.UnresolvedAttribute(s)
+
+      /** Creates a new typed attribute of type int. */
+ def int = AttributeReference(s, IntegerType, nullable = false)()
+
+      /** Creates a new typed attribute of type string. */
+ def string = AttributeReference(s, StringType, nullable = false)()
+ }
+
+ implicit class DslAttribute(a: AttributeReference) {
+ def notNull = a.withNullability(false)
+ def nullable = a.withNullability(true)
+
+ // Protobuf terminology
+ def required = a.withNullability(false)
+ }
+ }
+
+
+ object expressions extends ExpressionConversions // scalastyle:ignore
+
+ abstract class LogicalPlanFunctions {
+ def logicalPlan: LogicalPlan
+
+ def select(exprs: NamedExpression*) = Project(exprs, logicalPlan)
+
+ def where(condition: Expression) = Filter(condition, logicalPlan)
+
+ def join(
+ otherPlan: LogicalPlan,
+ joinType: JoinType = Inner,
+ condition: Option[Expression] = None) =
+ Join(logicalPlan, otherPlan, joinType, condition)
+
+ def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, logicalPlan)
+
+ def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = {
+ val aliasedExprs = aggregateExprs.map {
+ case ne: NamedExpression => ne
+ case e => Alias(e, e.toString)()
+ }
+ Aggregate(groupingExprs, aliasedExprs, logicalPlan)
+ }
+
+ def subquery(alias: Symbol) = Subquery(alias.name, logicalPlan)
+
+ def unionAll(otherPlan: LogicalPlan) = Union(logicalPlan, otherPlan)
+
+ def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) =
+ Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)
+
+ def sfilter(dynamicUdf: (DynamicRow) => Boolean) =
+ Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan)
+
+ def sample(
+ fraction: Double,
+ withReplacement: Boolean = true,
+ seed: Int = (math.random * 1000).toInt) =
+ Sample(fraction, withReplacement, seed, logicalPlan)
+
+ def generate(
+ generator: Generator,
+ join: Boolean = false,
+ outer: Boolean = false,
+ alias: Option[String] = None) =
+      Generate(generator, join, outer, alias, logicalPlan)
+
+ def insertInto(tableName: String, overwrite: Boolean = false) =
+ InsertIntoTable(
+ analysis.UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite)
+
+ def analyze = analysis.SimpleAnalyzer(logicalPlan)
+ }
+
+ object plans { // scalastyle:ignore
+ implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions {
+ def writeToFile(path: String) = WriteToFile(path, logicalPlan)
+ }
+ }
+}
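The package object above layers expression operators, literal conversions, and plan verbs on top of plain Scala values. A minimal sketch of how the pieces could compose, assuming the implicits from ScalaReflection and dsl are in scope; the Person case class and its field names are illustrative only:

    import org.apache.spark.sql.catalyst.ScalaReflection._
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._

    case class Person(name: String, age: Int)  // hypothetical schema

    // Seq[Person] -> LocalRelation via the CaseClassRelation implicit, then the
    // DslLogicalPlan verbs build an unresolved Project(Filter(LocalRelation)).
    val plan = Seq(Person("alice", 30), Person("bob", 17)).asRelation
      .where('age >= 18)   // Symbol -> UnresolvedAttribute, Int -> Literal
      .select('name)

    // plan.analyze would run SimpleAnalyzer to resolve the attribute references.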
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
new file mode 100644
index 0000000000..c253587f67
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+
+import trees._
+
+/**
+ * Functions for attaching and retrieving trees that are associated with errors.
+ */
+package object errors {
+
+ class TreeNodeException[TreeType <: TreeNode[_]]
+ (tree: TreeType, msg: String, cause: Throwable) extends Exception(msg, cause) {
+
+ // Yes, this is the same as a default parameter, but... those don't seem to work with SBT
+ // external project dependencies for some reason.
+ def this(tree: TreeType, msg: String) = this(tree, msg, null)
+
+ override def getMessage: String = {
+ val treeString = tree.toString
+ s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree"
+ }
+ }
+
+ /**
+ * Wraps any exceptions that are thrown while executing `f` in a
+ * [[catalyst.errors.TreeNodeException TreeNodeException]], attaching the provided `tree`.
+ */
+ def attachTree[TreeType <: TreeNode[_], A](tree: TreeType, msg: String = "")(f: => A): A = {
+ try f catch {
+ case e: Exception => throw new TreeNodeException(tree, msg, e)
+ }
+ }
+
+ /**
+ * Executes `f` which is expected to throw a
+ * [[catalyst.errors.TreeNodeException TreeNodeException]]. The first tree encountered in
+ * the stack of exceptions of type `TreeType` is returned.
+ */
+ def getTree[TreeType <: TreeNode[_]](f: => Unit): TreeType = ??? // TODO: Implement
+}
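attachTree is intended to tag a failure with the subtree being processed, as BindReferences does in the next file. A rough sketch of the pattern, assuming Literal and Divide evaluate as defined elsewhere in this patch:

    import org.apache.spark.sql.catalyst.errors._
    import org.apache.spark.sql.catalyst.expressions._

    val badExpr = Divide(Literal(1), Literal(0))  // integral division by zero

    // The ArithmeticException thrown during evaluation is rethrown as a
    // TreeNodeException whose message carries the expression tree.
    try {
      attachTree(badExpr, "Evaluating") { badExpr.apply(EmptyRow) }
    } catch {
      case e: TreeNodeException[_] => println(e.getMessage)
    }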
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
new file mode 100644
index 0000000000..3b6bac16ff
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import rules._
+import errors._
+
+import catalyst.plans.QueryPlan
+
+/**
+ * A bound reference points to a specific slot in the input tuple, allowing the actual value
+ * to be retrieved more efficiently. However, since operations like column pruning can change
+ * the layout of intermediate tuples, BindReferences should be run after all such transformations.
+ */
+case class BoundReference(ordinal: Int, baseReference: Attribute)
+ extends Attribute with trees.LeafNode[Expression] {
+
+ type EvaluatedType = Any
+
+ def nullable = baseReference.nullable
+ def dataType = baseReference.dataType
+ def exprId = baseReference.exprId
+ def qualifiers = baseReference.qualifiers
+ def name = baseReference.name
+
+ def newInstance = BoundReference(ordinal, baseReference.newInstance)
+ def withQualifiers(newQualifiers: Seq[String]) =
+ BoundReference(ordinal, baseReference.withQualifiers(newQualifiers))
+
+ override def toString = s"$baseReference:$ordinal"
+
+ override def apply(input: Row): Any = input(ordinal)
+}
+
+class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] {
+ import BindReferences._
+
+ def apply(plan: TreeNode): TreeNode = {
+ plan.transform {
+ case leafNode if leafNode.children.isEmpty => leafNode
+ case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e =>
+ bindReference(e, unaryNode.children.head.output)
+ }
+ }
+ }
+}
+
+object BindReferences extends Logging {
+ def bindReference(expression: Expression, input: Seq[Attribute]): Expression = {
+ expression.transform { case a: AttributeReference =>
+ attachTree(a, "Binding attribute") {
+ val ordinal = input.indexWhere(_.exprId == a.exprId)
+ if (ordinal == -1) {
+ // TODO: This fallback is required because some operators (such as ScriptTransform)
+ // produce new attributes that can't be bound. Likely the right thing to do is remove
+ // this rule and require all operators to explicitly bind to the input schema that
+ // they specify.
+ logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+ a
+ } else {
+ BoundReference(ordinal, a)
+ }
+ }
+ }
+ }
+}
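A small sketch of what binding does, using the AttributeReference constructor seen elsewhere in this patch; the attribute names are illustrative:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    val key   = AttributeReference("key", IntegerType, nullable = false)()
    val value = AttributeReference("value", StringType, nullable = false)()

    // Rewrites `value` into a positional lookup against the given input schema.
    val bound = BindReferences.bindReference(value, Seq(key, value))
    // bound is BoundReference(1, value); evaluating it reads slot 1 of the row.
    bound.apply(new GenericRow(Array[Any](1, "a")))  // "a"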
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
new file mode 100644
index 0000000000..608656d3a9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import types._
+
+/** Cast the child expression to the target data type. */
+case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
+ override def foldable = child.foldable
+ def nullable = child.nullable
+ override def toString = s"CAST($child, $dataType)"
+
+ type EvaluatedType = Any
+
+ lazy val castingFunction: Any => Any = (child.dataType, dataType) match {
+ case (BinaryType, StringType) => a: Any => new String(a.asInstanceOf[Array[Byte]])
+ case (StringType, BinaryType) => a: Any => a.asInstanceOf[String].getBytes
+ case (_, StringType) => a: Any => a.toString
+ case (StringType, IntegerType) => a: Any => castOrNull(a, _.toInt)
+ case (StringType, DoubleType) => a: Any => castOrNull(a, _.toDouble)
+ case (StringType, FloatType) => a: Any => castOrNull(a, _.toFloat)
+ case (StringType, LongType) => a: Any => castOrNull(a, _.toLong)
+ case (StringType, ShortType) => a: Any => castOrNull(a, _.toShort)
+ case (StringType, ByteType) => a: Any => castOrNull(a, _.toByte)
+ case (StringType, DecimalType) => a: Any => castOrNull(a, BigDecimal(_))
+ case (BooleanType, ByteType) => a: Any => a match {
+ case null => null
+ case true => 1.toByte
+ case false => 0.toByte
+ }
+ case (dt, IntegerType) =>
+ a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a)
+ case (dt, DoubleType) =>
+ a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a)
+ case (dt, FloatType) =>
+ a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toFloat(a)
+ case (dt, LongType) =>
+ a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toLong(a)
+ case (dt, ShortType) =>
+ a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toShort
+ case (dt, ByteType) =>
+ a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toByte
+ case (dt, DecimalType) =>
+ a: Any =>
+ BigDecimal(dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a))
+ }
+
+ @inline
+ protected def castOrNull[A](a: Any, f: String => A) =
+ try f(a.asInstanceOf[String]) catch {
+ case _: java.lang.NumberFormatException => null
+ }
+
+ override def apply(input: Row): Any = {
+ val evaluated = child.apply(input)
+ if (evaluated == null) {
+ null
+ } else {
+ castingFunction(evaluated)
+ }
+ }
+}
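The casting function is selected once from the (from, to) type pair and then applied per row, with nulls short-circuited in apply. A hedged sketch, assuming Literal infers the data type of the wrapped value and evaluates to it:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    Cast(Literal("123"), IntegerType).apply(EmptyRow)  // 123
    Cast(Literal("abc"), IntegerType).apply(EmptyRow)  // null: NumberFormatException is swallowed
    Cast(Literal(3), DoubleType).apply(EmptyRow)       // 3.0 via the NumericType branch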
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
new file mode 100644
index 0000000000..78aaaeebbd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import errors._
+import trees._
+import types._
+
+abstract class Expression extends TreeNode[Expression] {
+ self: Product =>
+
+ /** The narrowest possible type that is produced when this expression is evaluated. */
+ type EvaluatedType <: Any
+
+ def dataType: DataType
+
+ /**
+ * Returns true when an expression is a candidate for static evaluation before the query is
+ * executed.
+ *
+ * The following conditions are used to determine suitability for constant folding:
+ * - A [[expressions.Coalesce Coalesce]] is foldable if all of its children are foldable
+ * - A [[expressions.BinaryExpression BinaryExpression]] is foldable if its both left and right
+ * child are foldable
+ * - A [[expressions.Not Not]], [[expressions.IsNull IsNull]], or
+ * [[expressions.IsNotNull IsNotNull]] is foldable if its child is foldable.
+ * - A [[expressions.Literal]] is foldable.
+ * - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its
+ * child is foldable.
+ */
+ // TODO: Supporting more foldable expressions. For example, deterministic Hive UDFs.
+ def foldable: Boolean = false
+ def nullable: Boolean
+ def references: Set[Attribute]
+
+ /** Returns the result of evaluating this expression on a given input Row */
+ def apply(input: Row = null): EvaluatedType =
+ throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
+ /**
+ * Returns `true` if this expression and all its children have been resolved to a specific schema
+   * and `false` if it still contains any unresolved placeholders. Implementations of expressions
+ * should override this if the resolution of this type of expression involves more than just
+ * the resolution of its children.
+ */
+ lazy val resolved: Boolean = childrenResolved
+
+ /**
+ * Returns true if all the children of this expression have been resolved to a specific schema
+   * and false if any of them still contains unresolved placeholders.
+ */
+ def childrenResolved = !children.exists(!_.resolved)
+
+ /**
+ * A set of helper functions that return the correct descendant of [[scala.math.Numeric]] type
+   * and do any necessary casting of child evaluations.
+ */
+ @inline
+ def n1(e: Expression, i: Row, f: ((Numeric[Any], Any) => Any)): Any = {
+ val evalE = e.apply(i)
+ if (evalE == null) {
+ null
+ } else {
+ e.dataType match {
+ case n: NumericType =>
+ val castedFunction = f.asInstanceOf[(Numeric[n.JvmType], n.JvmType) => n.JvmType]
+ castedFunction(n.numeric, evalE.asInstanceOf[n.JvmType])
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+ }
+ }
+
+ @inline
+ protected final def n2(
+ i: Row,
+ e1: Expression,
+ e2: Expression,
+ f: ((Numeric[Any], Any, Any) => Any)): Any = {
+
+ if (e1.dataType != e2.dataType) {
+ throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
+ }
+
+ val evalE1 = e1.apply(i)
+    if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = e2.apply(i)
+ if (evalE2 == null) {
+ null
+ } else {
+ e1.dataType match {
+ case n: NumericType =>
+            f.asInstanceOf[(Numeric[n.JvmType], n.JvmType, n.JvmType) => n.JvmType](
+ n.numeric, evalE1.asInstanceOf[n.JvmType], evalE2.asInstanceOf[n.JvmType])
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+ }
+ }
+ }
+
+ @inline
+ protected final def f2(
+ i: Row,
+ e1: Expression,
+ e2: Expression,
+ f: ((Fractional[Any], Any, Any) => Any)): Any = {
+ if (e1.dataType != e2.dataType) {
+ throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
+ }
+
+ val evalE1 = e1.apply(i: Row)
+    if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = e2.apply(i: Row)
+ if (evalE2 == null) {
+ null
+ } else {
+ e1.dataType match {
+ case ft: FractionalType =>
+ f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType, ft.JvmType) => ft.JvmType](
+ ft.fractional, evalE1.asInstanceOf[ft.JvmType], evalE2.asInstanceOf[ft.JvmType])
+ case other => sys.error(s"Type $other does not support fractional operations")
+ }
+ }
+ }
+ }
+
+ @inline
+ protected final def i2(
+ i: Row,
+ e1: Expression,
+ e2: Expression,
+ f: ((Integral[Any], Any, Any) => Any)): Any = {
+ if (e1.dataType != e2.dataType) {
+ throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
+ }
+
+ val evalE1 = e1.apply(i)
+    if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = e2.apply(i)
+ if (evalE2 == null) {
+ null
+ } else {
+ e1.dataType match {
+ case i: IntegralType =>
+ f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType](
+ i.integral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+ }
+ }
+ }
+}
+
+abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
+ self: Product =>
+
+ def symbol: String
+
+ override def foldable = left.foldable && right.foldable
+
+ def references = left.references ++ right.references
+
+ override def toString = s"($left $symbol $right)"
+}
+
+abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
+ self: Product =>
+}
+
+abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
+ self: Product =>
+
+ def references = child.references
+}
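The n1/n2/f2/i2 helpers verify that both operands share a data type, propagate null, and dispatch through the Numeric, Fractional or Integral instance attached to that type. A sketch using the arithmetic expressions defined later in this patch:

    import org.apache.spark.sql.catalyst.expressions._

    // n2 requires both sides to have the same data type and propagates null.
    Add(Literal(1), Literal(2)).apply(EmptyRow)           // 3
    Multiply(Literal(2.0), Literal(4.0)).apply(EmptyRow)  // 8.0
    // Mixing Int and Double operands would make n2 throw a TreeNodeException,
    // so both literals are kept at the same type here.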
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
new file mode 100644
index 0000000000..8c407d2fdd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+package expressions
+
+/**
+ * Converts a [[Row]] to another Row given a sequence of expressions that define each column of the
+ * new row. If the schema of the input row is specified, then the given expression will be bound to
+ * that schema.
+ */
+class Projection(expressions: Seq[Expression]) extends (Row => Row) {
+ def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
+ this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+
+ protected val exprArray = expressions.toArray
+ def apply(input: Row): Row = {
+ val outputArray = new Array[Any](exprArray.size)
+ var i = 0
+ while (i < exprArray.size) {
+ outputArray(i) = exprArray(i).apply(input)
+ i += 1
+ }
+ new GenericRow(outputArray)
+ }
+}
+
+/**
+ * Converts a [[Row]] to another Row given a sequence of expressions that define each column of the
+ * new row. If the schema of the input row is specified, then the given expression will be bound to
+ * that schema.
+ *
+ * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
+ * each time an input row is added. This significantly reduces the cost of calculating the
+ * projection, but means that it is not safe to hold on to a reference to the returned row
+ * across calls.
+ */
+case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
+ def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
+ this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+
+ private[this] val exprArray = expressions.toArray
+ private[this] val mutableRow = new GenericMutableRow(exprArray.size)
+ def currentValue: Row = mutableRow
+
+ def apply(input: Row): Row = {
+ var i = 0
+ while (i < exprArray.size) {
+ mutableRow(i) = exprArray(i).apply(input)
+ i += 1
+ }
+ mutableRow
+ }
+}
+
+/**
+ * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to
+ * be instantiated once per thread and reused.
+ */
+class JoinedRow extends Row {
+ private[this] var row1: Row = _
+ private[this] var row2: Row = _
+
+  /** Updates this JoinedRow to point at two new base rows. Returns itself. */
+ def apply(r1: Row, r2: Row): Row = {
+ row1 = r1
+ row2 = r2
+ this
+ }
+
+ def iterator = row1.iterator ++ row2.iterator
+
+ def length = row1.length + row2.length
+
+ def apply(i: Int) =
+ if (i < row1.size) row1(i) else row2(i - row1.size)
+
+ def isNullAt(i: Int) = apply(i) == null
+
+ def getInt(i: Int): Int =
+ if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+
+ def getLong(i: Int): Long =
+ if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+
+ def getDouble(i: Int): Double =
+ if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+
+ def getBoolean(i: Int): Boolean =
+ if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+
+ def getShort(i: Int): Short =
+ if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+
+ def getByte(i: Int): Byte =
+ if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+
+ def getFloat(i: Int): Float =
+ if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+
+ def getString(i: Int): String =
+ if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+
+ def copy() = {
+ val totalSize = row1.size + row2.size
+ val copiedValues = new Array[Any](totalSize)
+ var i = 0
+    while (i < totalSize) {
+ copiedValues(i) = apply(i)
+ i += 1
+ }
+ new GenericRow(copiedValues)
+ }
+}
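The two projection classes differ only in whether the output row is allocated per call or reused. A sketch, assuming the bound-reference machinery from BoundAttribute.scala:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    val a = AttributeReference("a", IntegerType, nullable = false)()
    val input = new GenericRow(Array[Any](41))

    // Binds Add(a, 1) against the schema Seq(a), then evaluates it per input row.
    val project = new Projection(Seq(Add(a, Literal(1))), Seq(a))
    project(input)  // freshly allocated row containing 42

    // MutableProjection hands back the same underlying row on every call,
    // so the result must not be retained across inputs.
    val mutableProject = new MutableProjection(Seq(Add(a, Literal(1))), Seq(a))
    val out = mutableProject(input)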
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
new file mode 100644
index 0000000000..a5d0ecf964
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import types.DoubleType
+
+case object Rand extends LeafExpression {
+ def dataType = DoubleType
+ def nullable = false
+ def references = Set.empty
+ override def toString = "RAND()"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
new file mode 100644
index 0000000000..3529675468
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import types._
+
+/**
+ * Represents one row of output from a relational operator. Allows both generic access by ordinal,
+ * which will incur boxing overhead for primitives, as well as native primitive access.
+ *
+ * It is invalid to use the native primitive interface to retrieve a value that is null; instead, a
+ * user must check [[isNullAt]] before attempting to retrieve a value that might be null.
+ */
+trait Row extends Seq[Any] with Serializable {
+ def apply(i: Int): Any
+
+ def isNullAt(i: Int): Boolean
+
+ def getInt(i: Int): Int
+ def getLong(i: Int): Long
+ def getDouble(i: Int): Double
+ def getFloat(i: Int): Float
+ def getBoolean(i: Int): Boolean
+ def getShort(i: Int): Short
+ def getByte(i: Int): Byte
+ def getString(i: Int): String
+
+ override def toString() =
+ s"[${this.mkString(",")}]"
+
+ def copy(): Row
+}
+
+/**
+ * An extended interface to [[Row]] that allows the values for each column to be updated. Setting
+ * a value through a primitive function implicitly marks that column as not null.
+ */
+trait MutableRow extends Row {
+ def setNullAt(i: Int): Unit
+
+ def update(ordinal: Int, value: Any)
+
+ def setInt(ordinal: Int, value: Int)
+ def setLong(ordinal: Int, value: Long)
+ def setDouble(ordinal: Int, value: Double)
+ def setBoolean(ordinal: Int, value: Boolean)
+ def setShort(ordinal: Int, value: Short)
+ def setByte(ordinal: Int, value: Byte)
+ def setFloat(ordinal: Int, value: Float)
+ def setString(ordinal: Int, value: String)
+
+ /**
+ * EXPERIMENTAL
+ *
+ * Returns a mutable string builder for the specified column. A given row should return the
+ * result of any mutations made to the returned buffer next time getString is called for the same
+ * column.
+ */
+ def getStringBuilder(ordinal: Int): StringBuilder
+}
+
+/**
+ * A row with no data. Calling any methods will result in an error. Can be used as a placeholder.
+ */
+object EmptyRow extends Row {
+ def apply(i: Int): Any = throw new UnsupportedOperationException
+
+ def iterator = Iterator.empty
+ def length = 0
+ def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException
+
+ def getInt(i: Int): Int = throw new UnsupportedOperationException
+ def getLong(i: Int): Long = throw new UnsupportedOperationException
+ def getDouble(i: Int): Double = throw new UnsupportedOperationException
+ def getFloat(i: Int): Float = throw new UnsupportedOperationException
+ def getBoolean(i: Int): Boolean = throw new UnsupportedOperationException
+ def getShort(i: Int): Short = throw new UnsupportedOperationException
+ def getByte(i: Int): Byte = throw new UnsupportedOperationException
+ def getString(i: Int): String = throw new UnsupportedOperationException
+
+ def copy() = this
+}
+
+/**
+ * A row implementation that uses an array of objects as the underlying storage. Note that, while
+ * the array is not copied, and thus could technically be mutated after creation, this is not
+ * allowed.
+ */
+class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
+ /** No-arg constructor for serialization. */
+ def this() = this(null)
+
+ def this(size: Int) = this(new Array[Any](size))
+
+ def iterator = values.iterator
+
+ def length = values.length
+
+ def apply(i: Int) = values(i)
+
+ def isNullAt(i: Int) = values(i) == null
+
+ def getInt(i: Int): Int = {
+ if (values(i) == null) sys.error("Failed to check null bit for primitive int value.")
+ values(i).asInstanceOf[Int]
+ }
+
+ def getLong(i: Int): Long = {
+ if (values(i) == null) sys.error("Failed to check null bit for primitive long value.")
+ values(i).asInstanceOf[Long]
+ }
+
+ def getDouble(i: Int): Double = {
+ if (values(i) == null) sys.error("Failed to check null bit for primitive double value.")
+ values(i).asInstanceOf[Double]
+ }
+
+ def getFloat(i: Int): Float = {
+ if (values(i) == null) sys.error("Failed to check null bit for primitive float value.")
+ values(i).asInstanceOf[Float]
+ }
+
+ def getBoolean(i: Int): Boolean = {
+ if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.")
+ values(i).asInstanceOf[Boolean]
+ }
+
+ def getShort(i: Int): Short = {
+ if (values(i) == null) sys.error("Failed to check null bit for primitive short value.")
+ values(i).asInstanceOf[Short]
+ }
+
+ def getByte(i: Int): Byte = {
+ if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.")
+ values(i).asInstanceOf[Byte]
+ }
+
+ def getString(i: Int): String = {
+ if (values(i) == null) sys.error("Failed to check null bit for primitive String value.")
+ values(i).asInstanceOf[String]
+ }
+
+ def copy() = this
+}
+
+class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
+ /** No-arg constructor for serialization. */
+ def this() = this(0)
+
+ def getStringBuilder(ordinal: Int): StringBuilder = ???
+
+  override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value }
+  override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value }
+  override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value }
+  override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
+  override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
+  override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
+  override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value }
+
+  override def setNullAt(i: Int): Unit = { values(i) = null }
+
+  override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
+
+  override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value }
+
+ override def copy() = new GenericRow(values.clone())
+}
+
+
+class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
+ def compare(a: Row, b: Row): Int = {
+ var i = 0
+ while (i < ordering.size) {
+ val order = ordering(i)
+ val left = order.child.apply(a)
+ val right = order.child.apply(b)
+
+ if (left == null && right == null) {
+ // Both null, continue looking.
+ } else if (left == null) {
+ return if (order.direction == Ascending) -1 else 1
+ } else if (right == null) {
+ return if (order.direction == Ascending) 1 else -1
+ } else {
+ val comparison = order.dataType match {
+ case n: NativeType if order.direction == Ascending =>
+ n.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
+ case n: NativeType if order.direction == Descending =>
+ n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+ }
+ if (comparison != 0) return comparison
+ }
+ i += 1
+ }
+ return 0
+ }
+}
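GenericMutableRow is updated in place, while RowOrdering compares rows through the sort expressions. A sketch, assuming IntegerType is a NativeType with an Ordering (as the RowOrdering match requires) and using SortOrder/Ascending from the next file:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    // Mutable rows are written in place; copy() snapshots the current values.
    val row = new GenericMutableRow(2)
    row.setInt(0, 1)
    row.setString(1, "a")
    val snapshot = row.copy()

    // Sort rows by the first column, read through a bound reference.
    val key = AttributeReference("key", IntegerType, nullable = true)()
    val ordering = new RowOrdering(Seq(SortOrder(BoundReference(0, key), Ascending)))
    val rows = Seq(Array[Any](3), Array[Any](1), Array[Any](2)).map(new GenericRow(_))
    rows.sorted(ordering).map(_.getInt(0))  // 1, 2, 3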
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
new file mode 100644
index 0000000000..a3c7ca1acd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import types._
+
+case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
+ extends Expression {
+
+ type EvaluatedType = Any
+
+ def references = children.flatMap(_.references).toSet
+ def nullable = true
+
+ override def apply(input: Row): Any = {
+ children.size match {
+ case 1 => function.asInstanceOf[(Any) => Any](children(0).apply(input))
+ case 2 =>
+ function.asInstanceOf[(Any, Any) => Any](
+ children(0).apply(input),
+ children(1).apply(input))
+ }
+ }
+}
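ScalaUdf erases the wrapped function to (Any, ...) => Any and dispatches on the number of child expressions, so only the arities listed in apply are supported. A sketch with a single bound argument:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    val a = AttributeReference("a", IntegerType, nullable = false)()
    val plusOne = ScalaUdf((i: Int) => i + 1, IntegerType, Seq(BoundReference(0, a)))

    plusOne.apply(new GenericRow(Array[Any](41)))  // 42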
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
new file mode 100644
index 0000000000..171997b90e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+abstract sealed class SortDirection
+case object Ascending extends SortDirection
+case object Descending extends SortDirection
+
+/**
+ * An expression that can be used to sort a tuple. This class extends expression primarily so that
+ * transformations over expression will descend into its child.
+ */
+case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
+ def dataType = child.dataType
+ def nullable = child.nullable
+ override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
new file mode 100644
index 0000000000..2ad8d6f31d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import scala.language.dynamics
+
+import types._
+
+case object DynamicType extends DataType
+
+case class WrapDynamic(children: Seq[Attribute]) extends Expression {
+ type EvaluatedType = DynamicRow
+
+ def nullable = false
+ def references = children.toSet
+ def dataType = DynamicType
+
+ override def apply(input: Row): DynamicRow = input match {
+ // Avoid copy for generic rows.
+ case g: GenericRow => new DynamicRow(children, g.values)
+ case otherRowType => new DynamicRow(children, otherRowType.toArray)
+ }
+}
+
+class DynamicRow(val schema: Seq[Attribute], values: Array[Any])
+ extends GenericRow(values) with Dynamic {
+
+ def selectDynamic(attributeName: String): String = {
+ val ordinal = schema.indexWhere(_.name == attributeName)
+ values(ordinal).toString
+ }
+}
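WrapDynamic packages a row behind scala.Dynamic so a user-supplied closure can address columns by name, which is how the sfilter verb in the DSL is meant to be used. A sketch:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    val name = AttributeReference("name", StringType, nullable = false)()
    val age  = AttributeReference("age", IntegerType, nullable = false)()

    val dyn = WrapDynamic(Seq(name, age)).apply(new GenericRow(Array[Any]("alice", 30)))
    dyn.name        // "alice" via selectDynamic
    dyn.age.toInt   // 30; selectDynamic always returns the value as a String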
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
new file mode 100644
index 0000000000..2287a849e6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import catalyst.types._
+
+abstract class AggregateExpression extends Expression {
+ self: Product =>
+
+ /**
+ * Creates a new instance that can be used to compute this aggregate expression for a group
+   * of input rows.
+ */
+ def newInstance: AggregateFunction
+}
+
+/**
+ * Represents an aggregation that has been rewritten to be performed in two steps.
+ *
+ * @param finalEvaluation an aggregate expression that evaluates to the same final result as the
+ * original aggregation.
+ * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial
+ * data sets and are required to compute the `finalEvaluation`.
+ */
+case class SplitEvaluation(
+ finalEvaluation: Expression,
+ partialEvaluations: Seq[NamedExpression])
+
+/**
+ * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
+ * These partial evaluations can then be combined to compute the actual answer.
+ */
+abstract class PartialAggregate extends AggregateExpression {
+ self: Product =>
+
+ /**
+ * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation.
+ */
+ def asPartial: SplitEvaluation
+}
+
+/**
+ * A specific implementation of an aggregate function. Used to wrap a generic
+ * [[AggregateExpression]] with an algorithm that will be used to compute one specific result.
+ */
+abstract class AggregateFunction
+ extends AggregateExpression with Serializable with trees.LeafNode[Expression] {
+ self: Product =>
+
+ type EvaluatedType = Any
+
+ /** Base should return the generic aggregate expression that this function is computing */
+ val base: AggregateExpression
+ def references = base.references
+ def nullable = base.nullable
+ def dataType = base.dataType
+
+ def update(input: Row): Unit
+ override def apply(input: Row): Any
+
+ // Do we really need this?
+ def newInstance = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
+}
+
+case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+ def references = child.references
+ def nullable = false
+ def dataType = IntegerType
+ override def toString = s"COUNT($child)"
+
+ def asPartial: SplitEvaluation = {
+ val partialCount = Alias(Count(child), "PartialCount")()
+ SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
+ }
+
+ override def newInstance = new CountFunction(child, this)
+}
+
+case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
+ def children = expressions
+ def references = expressions.flatMap(_.references).toSet
+ def nullable = false
+ def dataType = IntegerType
+ override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
+ override def newInstance = new CountDistinctFunction(expressions, this)
+}
+
+case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+ def references = child.references
+ def nullable = false
+ def dataType = DoubleType
+ override def toString = s"AVG($child)"
+
+ override def asPartial: SplitEvaluation = {
+ val partialSum = Alias(Sum(child), "PartialSum")()
+ val partialCount = Alias(Count(child), "PartialCount")()
+ val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
+ val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
+
+ SplitEvaluation(
+ Divide(castedSum, castedCount),
+ partialCount :: partialSum :: Nil)
+ }
+
+ override def newInstance = new AverageFunction(child, this)
+}
+
+case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+ def references = child.references
+ def nullable = false
+ def dataType = child.dataType
+ override def toString = s"SUM($child)"
+
+ override def asPartial: SplitEvaluation = {
+ val partialSum = Alias(Sum(child), "PartialSum")()
+ SplitEvaluation(
+ Sum(partialSum.toAttribute),
+ partialSum :: Nil)
+ }
+
+ override def newInstance = new SumFunction(child, this)
+}
+
+case class SumDistinct(child: Expression)
+ extends AggregateExpression with trees.UnaryNode[Expression] {
+
+ def references = child.references
+ def nullable = false
+ def dataType = child.dataType
+ override def toString = s"SUM(DISTINCT $child)"
+
+ override def newInstance = new SumDistinctFunction(child, this)
+}
+
+case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+ def references = child.references
+ def nullable = child.nullable
+ def dataType = child.dataType
+ override def toString = s"FIRST($child)"
+
+ override def asPartial: SplitEvaluation = {
+ val partialFirst = Alias(First(child), "PartialFirst")()
+ SplitEvaluation(
+ First(partialFirst.toAttribute),
+ partialFirst :: Nil)
+ }
+ override def newInstance = new FirstFunction(child, this)
+}
+
+case class AverageFunction(expr: Expression, base: AggregateExpression)
+ extends AggregateFunction {
+
+ def this() = this(null, null) // Required for serialization.
+
+ private var count: Long = _
+ private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(EmptyRow))
+ private val sumAsDouble = Cast(sum, DoubleType)
+
+ private val addFunction = Add(sum, expr)
+
+ override def apply(input: Row): Any =
+ sumAsDouble.apply(EmptyRow).asInstanceOf[Double] / count.toDouble
+
+ def update(input: Row): Unit = {
+ count += 1
+ sum.update(addFunction, input)
+ }
+}
+
+case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+ def this() = this(null, null) // Required for serialization.
+
+ var count: Int = _
+
+ def update(input: Row): Unit = {
+ val evaluatedExpr = expr.map(_.apply(input))
+ if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
+ count += 1
+ }
+ }
+
+ override def apply(input: Row): Any = count
+}
+
+case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+ def this() = this(null, null) // Required for serialization.
+
+ private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(null))
+
+ private val addFunction = Add(sum, expr)
+
+ def update(input: Row): Unit = {
+ sum.update(addFunction, input)
+ }
+
+ override def apply(input: Row): Any = sum.apply(null)
+}
+
+case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
+ extends AggregateFunction {
+
+ def this() = this(null, null) // Required for serialization.
+
+ val seen = new scala.collection.mutable.HashSet[Any]()
+
+ def update(input: Row): Unit = {
+ val evaluatedExpr = expr.apply(input)
+ if (evaluatedExpr != null) {
+ seen += evaluatedExpr
+ }
+ }
+
+ override def apply(input: Row): Any =
+ seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
+}
+
+case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpression)
+ extends AggregateFunction {
+
+ def this() = this(null, null) // Required for serialization.
+
+ val seen = new scala.collection.mutable.HashSet[Any]()
+
+ def update(input: Row): Unit = {
+ val evaluatedExpr = expr.map(_.apply(input))
+ if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) {
+ seen += evaluatedExpr
+ }
+ }
+
+ override def apply(input: Row): Any = seen.size
+}
+
+case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+ def this() = this(null, null) // Required for serialization.
+
+ var result: Any = null
+
+ def update(input: Row): Unit = {
+ if (result == null) {
+ result = expr.apply(input)
+ }
+ }
+
+ override def apply(input: Row): Any = result
+}
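A declarative aggregate is split into partial and final steps by asPartial, while the *Function classes hold the running state for one group. A rough sketch of both sides, assuming MutableLiteral (defined in literals.scala, not shown here) stores the value computed on each update; the score column is illustrative and bound to slot 0:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    val score = AttributeReference("score", DoubleType, nullable = false)()

    // Two-phase rewrite: partial Sum/Count per partition, combined afterwards.
    val split = Average(score).asPartial
    // split.partialEvaluations: Alias(Sum(score)) and Alias(Count(score))
    // split.finalEvaluation:    Divide(Cast(Sum(partialSum), ...), Cast(Sum(partialCount), ...))

    // Imperative side: an AggregateFunction accumulating input rows.
    val avg = AverageFunction(BoundReference(0, score), Average(score))
    Seq(1.0, 2.0, 6.0).foreach(v => avg.update(new GenericRow(Array[Any](v))))
    avg.apply(EmptyRow)  // 3.0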
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
new file mode 100644
index 0000000000..db235645cd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import catalyst.analysis.UnresolvedException
+import catalyst.types._
+
+case class UnaryMinus(child: Expression) extends UnaryExpression {
+ type EvaluatedType = Any
+
+ def dataType = child.dataType
+ override def foldable = child.foldable
+ def nullable = child.nullable
+ override def toString = s"-$child"
+
+ override def apply(input: Row): Any = {
+ n1(child, input, _.negate(_))
+ }
+}
+
+abstract class BinaryArithmetic extends BinaryExpression {
+ self: Product =>
+
+ type EvaluatedType = Any
+
+ def nullable = left.nullable || right.nullable
+
+ override lazy val resolved =
+ left.resolved && right.resolved && left.dataType == right.dataType
+
+ def dataType = {
+ if (!resolved) {
+ throw new UnresolvedException(this,
+ s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
+ }
+ left.dataType
+ }
+}
+
+case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
+ def symbol = "+"
+
+ override def apply(input: Row): Any = n2(input, left, right, _.plus(_, _))
+}
+
+case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
+ def symbol = "-"
+
+ override def apply(input: Row): Any = n2(input, left, right, _.minus(_, _))
+}
+
+case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
+ def symbol = "*"
+
+ override def apply(input: Row): Any = n2(input, left, right, _.times(_, _))
+}
+
+case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
+ def symbol = "/"
+
+ override def apply(input: Row): Any = dataType match {
+ case _: FractionalType => f2(input, left, right, _.div(_, _))
+    case _: IntegralType => i2(input, left, right, _.quot(_, _))
+  }
+}
+
+case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+ def symbol = "%"
+
+ override def apply(input: Row): Any = i2(input, left, right, _.rem(_, _))
+}
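Divide is the one operator that dispatches on the resolved data type, using Integral.quot for whole-number types and Fractional.div otherwise. A sketch, assuming Literal infers IntegerType and DoubleType for these values:

    import org.apache.spark.sql.catalyst.expressions._

    Divide(Literal(7), Literal(2)).apply(EmptyRow)      // 3, integral quotient
    Divide(Literal(7.0), Literal(2.0)).apply(EmptyRow)  // 3.5, fractional division
    Remainder(Literal(7), Literal(2)).apply(EmptyRow)   // 1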
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
new file mode 100644
index 0000000000..d3feb6c461
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import types._
+
+/**
+ * Returns the item at `ordinal` in the Array `child`, or the value for the key `ordinal` in the Map `child`.
+ */
+case class GetItem(child: Expression, ordinal: Expression) extends Expression {
+ type EvaluatedType = Any
+
+ val children = child :: ordinal :: Nil
+ /** `Null` is returned for invalid ordinals. */
+ override def nullable = true
+ override def references = children.flatMap(_.references).toSet
+ def dataType = child.dataType match {
+ case ArrayType(dt) => dt
+ case MapType(_, vt) => vt
+ }
+ override lazy val resolved =
+ childrenResolved &&
+ (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
+
+ override def toString = s"$child[$ordinal]"
+
+ override def apply(input: Row): Any = {
+ if (child.dataType.isInstanceOf[ArrayType]) {
+ val baseValue = child.apply(input).asInstanceOf[Seq[_]]
+ val o = ordinal.apply(input).asInstanceOf[Int]
+ if (baseValue == null) {
+ null
+ } else if (o >= baseValue.size || o < 0) {
+ null
+ } else {
+ baseValue(o)
+ }
+ } else {
+ val baseValue = child.apply(input).asInstanceOf[Map[Any, _]]
+ val key = ordinal.apply(input)
+ if (baseValue == null) {
+ null
+ } else {
+ baseValue.get(key).orNull
+ }
+ }
+ }
+}
+
+/**
+ * Returns the value of the field `fieldName` in the Struct `child`.
+ */
+case class GetField(child: Expression, fieldName: String) extends UnaryExpression {
+ type EvaluatedType = Any
+
+ def dataType = field.dataType
+ def nullable = field.nullable
+
+ protected def structType = child.dataType match {
+ case s: StructType => s
+ case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
+ }
+
+ lazy val field =
+ structType.fields
+ .find(_.name == fieldName)
+ .getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}"))
+
+ lazy val ordinal = structType.fields.indexOf(field)
+
+ override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType]
+
+ override def apply(input: Row): Any = {
+ val baseValue = child.apply(input).asInstanceOf[Row]
+ if (baseValue == null) null else baseValue(ordinal)
+ }
+
+ override def toString = s"$child.$fieldName"
+}
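GetItem covers both array indexing and map lookup, returning null for out-of-range ordinals and missing keys. A sketch with the array sitting in slot 0 of the input row:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    val tags = AttributeReference("tags", ArrayType(StringType), nullable = true)()
    val row  = new GenericRow(Array[Any](Seq("a", "b", "c")))

    GetItem(BoundReference(0, tags), Literal(1)).apply(row)  // "b"
    GetItem(BoundReference(0, tags), Literal(9)).apply(row)  // null: ordinal out of range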
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
new file mode 100644
index 0000000000..c367de2a3e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import catalyst.types._
+
+/**
+ * An expression that produces zero or more rows given a single input row.
+ *
+ * Generators produce multiple output rows instead of a single value like other expressions,
+ * and thus they must have a schema to associate with the rows that are output.
+ *
+ * However, unlike row-producing relational operators, which are either leaves or determine their
+ * output schema functionally from their input, generators can contain other expressions that
+ * might be modified by rules. This structure means that they might be copied multiple times
+ * after first determining their output schema. If a new output schema were created for each
+ * copy, references further up the tree could be rendered invalid. As a result, generators must
+ * instead define a function `makeOutput`, which is called only once when the schema is first
+ * requested. The attributes produced by this function will be automatically copied any time
+ * rules result in changes to the Generator or its children.
+ */
+abstract class Generator extends Expression with (Row => TraversableOnce[Row]) {
+ self: Product =>
+
+ type EvaluatedType = TraversableOnce[Row]
+
+ lazy val dataType =
+ ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable))))
+
+ def nullable = false
+
+ def references = children.flatMap(_.references).toSet
+
+ /**
+ * Should be overridden by specific generators. Called only once for each instance to ensure
+ * that rule application does not change the output schema of a generator.
+ */
+ protected def makeOutput(): Seq[Attribute]
+
+ private var _output: Seq[Attribute] = null
+
+ def output: Seq[Attribute] = {
+ if (_output == null) {
+ _output = makeOutput()
+ }
+ _output
+ }
+
+ /** Should be implemented by child classes to produce the output rows for a specific Generator. */
+ def apply(input: Row): TraversableOnce[Row]
+
+ /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
+ override def makeCopy(newArgs: Array[AnyRef]): this.type = {
+ val copy = super.makeCopy(newArgs)
+ copy._output = _output
+ copy
+ }
+}
+
+/**
+ * Given an input array produces a sequence of rows for each value in the array.
+ */
+case class Explode(attributeNames: Seq[String], child: Expression)
+ extends Generator with trees.UnaryNode[Expression] {
+
+ override lazy val resolved =
+ child.resolved &&
+ (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
+
+ lazy val elementTypes = child.dataType match {
+ case ArrayType(et) => et :: Nil
+ case MapType(kt,vt) => kt :: vt :: Nil
+ }
+
+ // TODO: Move this pattern into Generator.
+ protected def makeOutput() =
+ if (attributeNames.size == elementTypes.size) {
+ attributeNames.zip(elementTypes).map {
+ case (n, t) => AttributeReference(n, t, nullable = true)()
+ }
+ } else {
+ elementTypes.zipWithIndex.map {
+ case (t, i) => AttributeReference(s"c_$i", t, nullable = true)()
+ }
+ }
+
+ override def apply(input: Row): TraversableOnce[Row] = {
+ child.dataType match {
+ case ArrayType(_) =>
+ val inputArray = child.apply(input).asInstanceOf[Seq[Any]]
+ if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v)))
+ case MapType(_, _) =>
+ val inputMap = child.apply(input).asInstanceOf[Map[Any,Any]]
+ if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) }
+ }
+ }
+
+ override def toString() = s"explode($child)"
+}
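+
+// A minimal usage sketch; the object below is illustrative only. With no attribute names
+// supplied, the single output column is named "c_0". The null Row is acceptable here only
+// because the child is a Literal, which ignores its input.
+object ExplodeExamples {
+ val rows = Explode(Nil, Literal(Seq("a", "b"), ArrayType(StringType))).apply(null) // two single-column rows
+}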
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
new file mode 100644
index 0000000000..229d8f7f7b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import types._
+
+object Literal {
+ def apply(v: Any): Literal = v match {
+ case i: Int => Literal(i, IntegerType)
+ case l: Long => Literal(l, LongType)
+ case d: Double => Literal(d, DoubleType)
+ case f: Float => Literal(f, FloatType)
+ case b: Byte => Literal(b, ByteType)
+ case s: Short => Literal(s, ShortType)
+ case s: String => Literal(s, StringType)
+ case b: Boolean => Literal(b, BooleanType)
+ case null => Literal(null, NullType)
+ }
+}
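+
+// A minimal usage sketch; the values below are illustrative only and show how the factory
+// infers a Catalyst data type from a plain Scala value.
+object LiteralExamples {
+ val one = Literal(1) // Literal(1, IntegerType)
+ val name = Literal("name") // Literal("name", StringType)
+ val none = Literal(null) // Literal(null, NullType)
+}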
+
+/**
+ * Extractor for retrieving Int literals.
+ */
+object IntegerLiteral {
+ def unapply(a: Any): Option[Int] = a match {
+ case Literal(a: Int, IntegerType) => Some(a)
+ case _ => None
+ }
+}
+
+case class Literal(value: Any, dataType: DataType) extends LeafExpression {
+
+ override def foldable = true
+ def nullable = value == null
+ def references = Set.empty
+
+ override def toString = if (value != null) value.toString else "null"
+
+ type EvaluatedType = Any
+ override def apply(input: Row): Any = value
+}
+
+// TODO: Specialize
+case class MutableLiteral(var value: Any, nullable: Boolean = true) extends LeafExpression {
+ type EvaluatedType = Any
+
+ val dataType = Literal(value).dataType
+
+ def references = Set.empty
+
+ def update(expression: Expression, input: Row) = {
+ value = expression.apply(input)
+ }
+
+ override def apply(input: Row) = value
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
new file mode 100644
index 0000000000..0a06e85325
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import catalyst.analysis.UnresolvedAttribute
+import types._
+
+object NamedExpression {
+ private val curId = new java.util.concurrent.atomic.AtomicLong()
+ def newExprId = ExprId(curId.getAndIncrement())
+}
+
+/**
+ * A globally unique (within this JVM) id for a given named expression.
+ * Used to identify which attribute output by a relation is being
+ * referenced in a subsequent computation.
+ */
+case class ExprId(id: Long)
+
+abstract class NamedExpression extends Expression {
+ self: Product =>
+
+ def name: String
+ def exprId: ExprId
+ def qualifiers: Seq[String]
+
+ def toAttribute: Attribute
+
+ protected def typeSuffix =
+ if (resolved) {
+ dataType match {
+ case LongType => "L"
+ case _ => ""
+ }
+ } else {
+ ""
+ }
+}
+
+abstract class Attribute extends NamedExpression {
+ self: Product =>
+
+ def withQualifiers(newQualifiers: Seq[String]): Attribute
+
+ def references = Set(this)
+ def toAttribute = this
+ def newInstance: Attribute
+}
+
+/**
+ * Used to assign a new name to a computation.
+ * For example, the SQL expression "1 + 1 AS a" could be represented as follows:
+ * Alias(Add(Literal(1), Literal(1)), "a")()
+ *
+ * @param child the computation being performed
+ * @param name the name to be associated with the result of computing [[child]].
+ * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this
+ * alias. Auto-assigned if left blank.
+ */
+case class Alias(child: Expression, name: String)
+ (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
+ extends NamedExpression with trees.UnaryNode[Expression] {
+
+ type EvaluatedType = Any
+
+ override def apply(input: Row) = child.apply(input)
+
+ def dataType = child.dataType
+ def nullable = child.nullable
+ def references = child.references
+
+ def toAttribute = {
+ if (resolved) {
+ AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers)
+ } else {
+ UnresolvedAttribute(name)
+ }
+ }
+
+ override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"
+
+ override protected final def otherCopyArgs = exprId :: qualifiers :: Nil
+}
+
+/**
+ * A reference to an attribute produced by another operator in the tree.
+ *
+ * @param name The name of this attribute, should only be used during analysis or for debugging.
+ * @param dataType The [[types.DataType DataType]] of this attribute.
+ * @param nullable True if null is a valid value for this attribute.
+ * @param exprId A globally unique id used to check if different AttributeReferences refer to the
+ * same attribute.
+ * @param qualifiers a list of strings that can be used to refer to this attribute in a fully
+ * qualified way. Consider the examples tableName.name, subQueryAlias.name.
+ * tableName and subQueryAlias are possible qualifiers.
+ */
+case class AttributeReference(name: String, dataType: DataType, nullable: Boolean = true)
+ (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
+ extends Attribute with trees.LeafNode[Expression] {
+
+ override def equals(other: Any) = other match {
+ case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
+ case _ => false
+ }
+
+ override def hashCode: Int = {
+ // See http://stackoverflow.com/questions/113511/hash-code-implementation
+ var h = 17
+ h = h * 37 + exprId.hashCode()
+ h = h * 37 + dataType.hashCode()
+ h
+ }
+
+ def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
+
+ /**
+ * Returns a copy of this [[AttributeReference]] with changed nullability.
+ */
+ def withNullability(newNullability: Boolean) = {
+ if (nullable == newNullability) {
+ this
+ } else {
+ AttributeReference(name, dataType, newNullability)(exprId, qualifiers)
+ }
+ }
+
+ /**
+ * Returns a copy of this [[AttributeReference]] with new qualifiers.
+ */
+ def withQualifiers(newQualifiers: Seq[String]) = {
+ if (newQualifiers == qualifiers) {
+ this
+ } else {
+ AttributeReference(name, dataType, nullable)(exprId, newQualifiers)
+ }
+ }
+
+ override def toString: String = s"$name#${exprId.id}$typeSuffix"
+}
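+
+// A minimal usage sketch; the object below is illustrative only. Equality is based on the
+// ExprId and data type, not on the name or nullability.
+object AttributeReferenceExamples {
+ val a1 = AttributeReference("a", IntegerType)()
+ val a2 = AttributeReference("a", IntegerType)()
+ val differentIds = a1 == a2 // false: each reference gets a fresh ExprId
+ val sameId = a1 == a1.withNullability(false) // true: same ExprId and data type
+}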
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
new file mode 100644
index 0000000000..e869a4d9b0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import catalyst.analysis.UnresolvedException
+
+case class Coalesce(children: Seq[Expression]) extends Expression {
+ type EvaluatedType = Any
+
+ /** Coalesce is nullable if all of its children are nullable, or if it has no children. */
+ def nullable = !children.exists(!_.nullable)
+
+ def references = children.flatMap(_.references).toSet
+ // Coalesce is foldable if all children are foldable.
+ override def foldable = !children.exists(!_.foldable)
+
+ // Only resolved if all the children are of the same type.
+ override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1)
+
+ override def toString = s"Coalesce(${children.mkString(",")})"
+
+ def dataType = if (resolved) {
+ children.head.dataType
+ } else {
+ throw new UnresolvedException(this, "Coalesce cannot have children of different types.")
+ }
+
+ override def apply(input: Row): Any = {
+ var i = 0
+ var result: Any = null
+ while(i < children.size && result == null) {
+ result = children(i).apply(input)
+ i += 1
+ }
+ result
+ }
+}
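+
+// A minimal usage sketch; the object below is illustrative only. Evaluation returns the
+// first non-null child value (in a resolved plan all children must share one data type).
+object CoalesceExamples {
+ val firstNonNull = Coalesce(Seq(Literal(null), Literal(2))).apply(null) // returns 2
+}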
+
+case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
+ def references = child.references
+ override def foldable = child.foldable
+ def nullable = false
+
+ override def apply(input: Row): Any = {
+ child.apply(input) == null
+ }
+}
+
+case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
+ def references = child.references
+ override def foldable = child.foldable
+ def nullable = false
+ override def toString = s"IS NOT NULL $child"
+
+ override def apply(input: Row): Any = {
+ child.apply(input) != null
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
new file mode 100644
index 0000000000..76554e160b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+
+/**
+ * A set of classes that can be used to represent trees of relational expressions. A key goal of
+ * the expression library is to hide the details of naming and scoping from developers who want to
+ * manipulate trees of relational operators. As such, the library defines a special type of
+ * expression, a [[NamedExpression]] in addition to the standard collection of expressions.
+ *
+ * ==Standard Expressions==
+ * A library of standard expressions (e.g., [[Add]], [[Equals]]), aggregates (e.g., SUM, COUNT),
+ * and other computations (e.g. UDFs). Each expression type is capable of determining its output
+ * schema as a function of its children's output schema.
+ *
+ * ==Named Expressions==
+ * Some expressions are named and thus can be referenced by later operators in the dataflow graph.
+ * The two types of named expressions are [[AttributeReference]]s and [[Alias]]es.
+ * [[AttributeReference]]s refer to attributes of the input tuple for a given operator and form
+ * the leaves of some expression trees. Aliases assign a name to intermediate computations.
+ * For example, in the SQL statement `SELECT a+b AS c FROM ...`, the expressions `a` and `b` would
+ * be represented by `AttributeReferences` and `c` would be represented by an `Alias`.
+ *
+ * During [[analysis]], all named expressions are assigned a globally unique expression id, which
+ * can be used for equality comparisons. While the original names are kept around for debugging
+ * purposes, they should never be used to check if two attributes refer to the same value, as
+ * plan transformations can result in the introduction of naming ambiguity. For example, consider
+ * a plan that contains subqueries, both of which are reading from the same table. If an
+ * optimization removes the subqueries, scoping information would be destroyed, eliminating the
+ * ability to reason about which subquery produced a given attribute.
+ *
+ * ==Evaluation==
+ * The result of expressions can be evaluated using the [[Evaluate]] object.
+ */
+package object expressions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
new file mode 100644
index 0000000000..561396eb43
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import types._
+import catalyst.analysis.UnresolvedException
+
+trait Predicate extends Expression {
+ self: Product =>
+
+ def dataType = BooleanType
+
+ type EvaluatedType = Any
+}
+
+trait PredicateHelper {
+ def splitConjunctivePredicates(condition: Expression): Seq[Expression] = condition match {
+ case And(cond1, cond2) => splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
+ case other => other :: Nil
+ }
+}
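+
+// A minimal usage sketch; the object below is illustrative only. Nested ANDs are flattened
+// into a single sequence of conjuncts.
+object PredicateHelperExamples extends PredicateHelper {
+ val p1 = Literal(true, BooleanType)
+ val p2 = Literal(false, BooleanType)
+ val p3 = Literal(true, BooleanType)
+ val conjuncts = splitConjunctivePredicates(And(And(p1, p2), p3)) // Seq(p1, p2, p3)
+}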
+
+abstract class BinaryPredicate extends BinaryExpression with Predicate {
+ self: Product =>
+ def nullable = left.nullable || right.nullable
+}
+
+case class Not(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
+ def references = child.references
+ override def foldable = child.foldable
+ def nullable = child.nullable
+ override def toString = s"NOT $child"
+
+ override def apply(input: Row): Any = {
+ child.apply(input) match {
+ case null => null
+ case b: Boolean => !b
+ }
+ }
+}
+
+/**
+ * Evaluates to `true` if `list` contains `value`.
+ */
+case class In(value: Expression, list: Seq[Expression]) extends Predicate {
+ def children = value +: list
+ def references = children.flatMap(_.references).toSet
+ def nullable = true // TODO: Figure out correct nullability semantics of IN.
+ override def toString = s"$value IN ${list.mkString("(", ",", ")")}"
+
+ override def apply(input: Row): Any = {
+ val evaluatedValue = value.apply(input)
+ list.exists(e => e.apply(input) == evaluatedValue)
+ }
+}
+
+case class And(left: Expression, right: Expression) extends BinaryPredicate {
+ def symbol = "&&"
+
+ override def apply(input: Row): Any = {
+ val l = left.apply(input)
+ val r = right.apply(input)
+ if (l == false || r == false) {
+ false
+ } else if (l == null || r == null ) {
+ null
+ } else {
+ true
+ }
+ }
+}
+
+case class Or(left: Expression, right: Expression) extends BinaryPredicate {
+ def symbol = "||"
+
+ override def apply(input: Row): Any = {
+ val l = left.apply(input)
+ val r = right.apply(input)
+ if (l == true || r == true) {
+ true
+ } else if (l == null || r == null) {
+ null
+ } else {
+ false
+ }
+ }
+}
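+
+// A minimal sketch of the three-valued logic above; the object is illustrative only and a
+// null Row is acceptable here because literals ignore their input.
+object BooleanLogicExamples {
+ val nullAndFalse = And(Literal(null, BooleanType), Literal(false, BooleanType)).apply(null) // false
+ val nullAndTrue = And(Literal(null, BooleanType), Literal(true, BooleanType)).apply(null) // null
+ val nullOrTrue = Or(Literal(null, BooleanType), Literal(true, BooleanType)).apply(null) // true
+}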
+
+abstract class BinaryComparison extends BinaryPredicate {
+ self: Product =>
+}
+
+case class Equals(left: Expression, right: Expression) extends BinaryComparison {
+ def symbol = "="
+ override def apply(input: Row): Any = {
+ val l = left.apply(input)
+ val r = right.apply(input)
+ if (l == null || r == null) null else l == r
+ }
+}
+
+case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
+ def symbol = "<"
+ override def apply(input: Row): Any = {
+ if (left.dataType == StringType && right.dataType == StringType) {
+ val l = left.apply(input)
+ val r = right.apply(input)
+ if(l == null || r == null) {
+ null
+ } else {
+ l.asInstanceOf[String] < r.asInstanceOf[String]
+ }
+ } else {
+ n2(input, left, right, _.lt(_, _))
+ }
+ }
+}
+
+case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+ def symbol = "<="
+ override def apply(input: Row): Any = {
+ if (left.dataType == StringType && right.dataType == StringType) {
+ val l = left.apply(input)
+ val r = right.apply(input)
+ if(l == null || r == null) {
+ null
+ } else {
+ l.asInstanceOf[String] <= r.asInstanceOf[String]
+ }
+ } else {
+ n2(input, left, right, _.lteq(_, _))
+ }
+ }
+}
+
+case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
+ def symbol = ">"
+ override def apply(input: Row): Any = {
+ if (left.dataType == StringType && right.dataType == StringType) {
+ val l = left.apply(input)
+ val r = right.apply(input)
+ if(l == null || r == null) {
+ null
+ } else {
+ l.asInstanceOf[String] > r.asInstanceOf[String]
+ }
+ } else {
+ n2(input, left, right, _.gt(_, _))
+ }
+ }
+}
+
+case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+ def symbol = ">="
+ override def apply(input: Row): Any = {
+ if (left.dataType == StringType && right.dataType == StringType) {
+ val l = left.apply(input)
+ val r = right.apply(input)
+ if(l == null || r == null) {
+ null
+ } else {
+ l.asInstanceOf[String] >= r.asInstanceOf[String]
+ }
+ } else {
+ n2(input, left, right, _.gteq(_, _))
+ }
+ }
+}
+
+case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
+ extends Expression {
+
+ def children = predicate :: trueValue :: falseValue :: Nil
+ def nullable = trueValue.nullable || falseValue.nullable
+ def references = children.flatMap(_.references).toSet
+ override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
+ def dataType = {
+ if (!resolved) {
+ throw new UnresolvedException(
+ this,
+ s"Can not resolve due to differing types ${trueValue.dataType}, ${falseValue.dataType}")
+ }
+ trueValue.dataType
+ }
+
+ type EvaluatedType = Any
+ override def apply(input: Row): Any = {
+ if (predicate(input).asInstanceOf[Boolean]) {
+ trueValue.apply(input)
+ } else {
+ falseValue.apply(input)
+ }
+ }
+
+ override def toString = s"if ($predicate) $trueValue else $falseValue"
+}
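+
+// A minimal usage sketch; the object below is illustrative only.
+object IfExamples {
+ val chosen = If(Equals(Literal(1), Literal(1)), Literal("yes"), Literal("no")).apply(null) // "yes"
+}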
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
new file mode 100644
index 0000000000..6e585236b1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import catalyst.types.BooleanType
+
+case class Like(left: Expression, right: Expression) extends BinaryExpression {
+ def dataType = BooleanType
+ def nullable = left.nullable // Right cannot be null.
+ def symbol = "LIKE"
+}
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
new file mode 100644
index 0000000000..4db2803173
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package optimizer
+
+import catalyst.expressions._
+import catalyst.plans.logical._
+import catalyst.rules._
+import catalyst.types.BooleanType
+import catalyst.plans.Inner
+
+object Optimizer extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Subqueries", Once,
+ EliminateSubqueries) ::
+ Batch("ConstantFolding", Once,
+ ConstantFolding,
+ BooleanSimplification,
+ SimplifyCasts) ::
+ Batch("Filter Pushdown", Once,
+ EliminateSubqueries,
+ CombineFilters,
+ PushPredicateThroughProject,
+ PushPredicateThroughInnerJoin) :: Nil
+}
+
+/**
+ * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are
+ * only required to provide scoping information for attributes and can be removed once analysis is
+ * complete.
+ */
+object EliminateSubqueries extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Subquery(_, child) => child
+ }
+}
+
+/**
+ * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
+ * equivalent [[catalyst.expressions.Literal Literal]] values.
+ */
+object ConstantFolding extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsDown {
+ // Skip redundant folding of literals.
+ case l: Literal => l
+ case e if e.foldable => Literal(e.apply(null), e.dataType)
+ }
+ }
+}
+
+/**
+ * Simplifies boolean expressions where the answer can be determined without evaluating both sides.
+ * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus
+ * is only safe when evaluating expressions does not result in side effects.
+ */
+object BooleanSimplification extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ case and @ And(left, right) => {
+ (left, right) match {
+ case (Literal(true, BooleanType), r) => r
+ case (l, Literal(true, BooleanType)) => l
+ case (Literal(false, BooleanType), _) => Literal(false)
+ case (_, Literal(false, BooleanType)) => Literal(false)
+ case (_, _) => and
+ }
+ }
+ case or @ Or(left, right) => {
+ (left, right) match {
+ case (Literal(true, BooleanType), _) => Literal(true)
+ case (_, Literal(true, BooleanType)) => Literal(true)
+ case (Literal(false, BooleanType), r) => r
+ case (l, Literal(false, BooleanType)) => l
+ case (_, _) => or
+ }
+ }
+ }
+ }
+}
+
+/**
+ * Combines two adjacent [[catalyst.plans.logical.Filter Filter]] operators into one, merging the
+ * conditions into one conjunctive predicate.
+ */
+object CombineFilters extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case ff@Filter(fc, nf@Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild)
+ }
+}
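+
+// A minimal sketch; the object below is illustrative only: two stacked filters collapse into
+// a single filter whose condition is the conjunction of both.
+object CombineFiltersExample {
+ val a = AttributeReference("a", BooleanType)()
+ val b = AttributeReference("b", BooleanType)()
+ val combined = CombineFilters(Filter(b, Filter(a, LocalRelation(a, b)))) // Filter(And(a, b), LocalRelation(a, b))
+}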
+
+/**
+ * Pushes [[catalyst.plans.logical.Filter Filter]] operators through
+ * [[catalyst.plans.logical.Project Project]] operators, in-lining any
+ * [[catalyst.expressions.Alias Aliases]] that were defined in the projection.
+ *
+ * This heuristic is valid assuming the expression evaluation cost is minimal.
+ */
+object PushPredicateThroughProject extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case filter@Filter(condition, project@Project(fields, grandChild)) =>
+ val sourceAliases = fields.collect { case a@Alias(c, _) => a.toAttribute -> c }.toMap
+ project.copy(child = filter.copy(
+ replaceAlias(condition, sourceAliases),
+ grandChild))
+ }
+
+ // Substitute any attribute references in the condition that refer to aliases defined in the
+ // projection with the expressions those aliases compute.
+ def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = {
+ condition transform {
+ case a: AttributeReference => sourceAliases.getOrElse(a, a)
+ }
+ }
+}
+
+/**
+ * Pushes down [[catalyst.plans.logical.Filter Filter]] operators where the `condition` can be
+ * evaluated using only the attributes of the left or right side of an inner join. Other
+ * [[catalyst.plans.logical.Filter Filter]] conditions are moved into the `condition` of the
+ * [[catalyst.plans.logical.Join Join]].
+ */
+object PushPredicateThroughInnerJoin extends Rule[LogicalPlan] with PredicateHelper {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case f @ Filter(filterCondition, Join(left, right, Inner, joinCondition)) =>
+ val allConditions =
+ splitConjunctivePredicates(filterCondition) ++
+ joinCondition.map(splitConjunctivePredicates).getOrElse(Nil)
+
+ // Split the predicates into those that can be evaluated on the left, right, and those that
+ // must be evaluated after the join.
+ val (rightConditions, leftOrJoinConditions) =
+ allConditions.partition(_.references subsetOf right.outputSet)
+ val (leftConditions, joinConditions) =
+ leftOrJoinConditions.partition(_.references subsetOf left.outputSet)
+
+ // Build the new left and right side, optionally with the pushed down filters.
+ val newLeft = leftConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+ val newRight = rightConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+ Join(newLeft, newRight, Inner, joinConditions.reduceLeftOption(And))
+ }
+}
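+
+// A minimal sketch; the object below is illustrative only. A conjunction over both sides of
+// an inner join is split, and each conjunct is pushed to the side that can evaluate it.
+object PushPredicateThroughInnerJoinExample {
+ val a = AttributeReference("a", BooleanType)()
+ val b = AttributeReference("b", BooleanType)()
+ val pushed = PushPredicateThroughInnerJoin(
+ Filter(And(a, b), Join(LocalRelation(a), LocalRelation(b), Inner, None)))
+ // => Join(Filter(a, LocalRelation(a)), Filter(b, LocalRelation(b)), Inner, None)
+}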
+
+/**
+ * Removes [[catalyst.expressions.Cast Casts]] that are unnecessary because the input is already
+ * the correct type.
+ */
+object SimplifyCasts extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case Cast(e, dataType) if e.dataType == dataType => e
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
new file mode 100644
index 0000000000..22f8ea005b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package planning
+
+
+import plans.logical.LogicalPlan
+import trees._
+
+/**
+ * Abstract class for transforming [[plans.logical.LogicalPlan LogicalPlan]]s into physical plans.
+ * Child classes are responsible for specifying a list of [[Strategy]] objects, each of which
+ * can return a list of possible physical plan options. If a given strategy is unable to plan all
+ * of the remaining operators in the tree, it can call [[planLater]], which returns a placeholder
+ * object that will be filled in using other available strategies.
+ *
+ * TODO: RIGHT NOW ONLY ONE PLAN IS RETURNED EVER...
+ * PLAN SPACE EXPLORATION WILL BE IMPLEMENTED LATER.
+ *
+ * @tparam PhysicalPlan The type of physical plan produced by this [[QueryPlanner]]
+ */
+abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {
+ /** A list of execution strategies that can be used by the planner */
+ def strategies: Seq[Strategy]
+
+ /**
+ * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can
+ * be used for execution. If this strategy does not apply to the given logical operation then an
+ * empty list should be returned.
+ */
+ abstract protected class Strategy extends Logging {
+ def apply(plan: LogicalPlan): Seq[PhysicalPlan]
+ }
+
+ /**
+ * Returns a placeholder for a physical plan that executes `plan`. This placeholder will be
+ * filled in automatically by the QueryPlanner using the other execution strategies that are
+ * available.
+ */
+ protected def planLater(plan: LogicalPlan) = apply(plan).next()
+
+ def apply(plan: LogicalPlan): Iterator[PhysicalPlan] = {
+ // Obviously a lot to do here still...
+ val iter = strategies.view.flatMap(_(plan)).toIterator
+ assert(iter.hasNext, s"No plan for $plan")
+ iter
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/package.scala
new file mode 100644
index 0000000000..64370ec7c0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/package.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+
+/**
+ * Contains classes for enumerating possible physical plans for a given logical query plan.
+ */
+package object planning
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
new file mode 100644
index 0000000000..613b028ca8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package planning
+
+import scala.annotation.tailrec
+
+import expressions._
+import plans.logical._
+
+/**
+ * A pattern that matches any number of filter operations on top of another relational operator.
+ * Adjacent filter operators are collected and their conditions are broken up and returned as a
+ * sequence of conjunctive predicates.
+ *
+ * @return A tuple containing a sequence of conjunctive predicates that should be used to filter the
+ * output and a relational operator.
+ */
+object FilteredOperation extends PredicateHelper {
+ type ReturnType = (Seq[Expression], LogicalPlan)
+
+ def unapply(plan: LogicalPlan): Option[ReturnType] = Some(collectFilters(Nil, plan))
+
+ @tailrec
+ private def collectFilters(filters: Seq[Expression], plan: LogicalPlan): ReturnType = plan match {
+ case Filter(condition, child) =>
+ collectFilters(filters ++ splitConjunctivePredicates(condition), child)
+ case other => (filters, other)
+ }
+}
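+
+// A minimal sketch; the object below is illustrative only. Stacked filters are collected
+// from the outside in, together with the relation underneath them.
+object FilteredOperationExample {
+ val a = AttributeReference("a", types.BooleanType)()
+ val b = AttributeReference("b", types.BooleanType)()
+ val FilteredOperation(conditions, base) = Filter(b, Filter(a, LocalRelation(a, b)))
+ // conditions == Seq(b, a), base == LocalRelation(a, b)
+}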
+
+/**
+ * A pattern that matches any number of project or filter operations on top of another relational
+ * operator. All filter operators are collected and their conditions are broken up and returned
+ * together with the top project operator. [[Alias Aliases]] are in-lined/substituted if necessary.
+ */
+object PhysicalOperation extends PredicateHelper {
+ type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)
+
+ def unapply(plan: LogicalPlan): Option[ReturnType] = {
+ val (fields, filters, child, _) = collectProjectsAndFilters(plan)
+ Some((fields.getOrElse(child.output), filters, child))
+ }
+
+ /**
+ * Collects projects and filters, in-lining/substituting aliases if necessary. Here are two
+ * examples for alias in-lining/substitution. Before:
+ * {{{
+ * SELECT c1 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10
+ * SELECT c1 AS c2 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10
+ * }}}
+ * After:
+ * {{{
+ * SELECT key AS c1 FROM t1 WHERE key > 10
+ * SELECT key AS c2 FROM t1 WHERE key > 10
+ * }}}
+ */
+ def collectProjectsAndFilters(plan: LogicalPlan):
+ (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) =
+ plan match {
+ case Project(fields, child) =>
+ val (_, filters, other, aliases) = collectProjectsAndFilters(child)
+ val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
+ (Some(substitutedFields), filters, other, collectAliases(substitutedFields))
+
+ case Filter(condition, child) =>
+ val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
+ val substitutedCondition = substitute(aliases)(condition)
+ (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
+
+ case other =>
+ (None, Nil, other, Map.empty)
+ }
+
+ def collectAliases(fields: Seq[Expression]) = fields.collect {
+ case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child
+ }.toMap
+
+ def substitute(aliases: Map[Attribute, Expression])(expr: Expression) = expr.transform {
+ case a @ Alias(ref: AttributeReference, name) =>
+ aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a)
+
+ case a: AttributeReference =>
+ aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a)
+ }
+}
+
+/**
+ * A pattern that collects all adjacent unions and returns their children as a Seq.
+ */
+object Unions {
+ def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match {
+ case u: Union => Some(collectUnionChildren(u))
+ case _ => None
+ }
+
+ private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
+ case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r)
+ case other => other :: Nil
+ }
+}
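+
+// A minimal sketch; the object below is illustrative only. Adjacent unions are flattened
+// into a single sequence of children.
+object UnionsExample {
+ val a = AttributeReference("a", types.IntegerType)()
+ val r1 = LocalRelation(a)
+ val r2 = LocalRelation(a)
+ val r3 = LocalRelation(a)
+ val Unions(children) = Union(Union(r1, r2), r3) // children == Seq(r1, r2, r3)
+}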
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
new file mode 100644
index 0000000000..20f230c5c4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+
+import catalyst.expressions.{SortOrder, Attribute, Expression}
+import catalyst.trees._
+
+abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] {
+ self: PlanType with Product =>
+
+ def output: Seq[Attribute]
+
+ /**
+ * Returns the set of attributes that are output by this node.
+ */
+ def outputSet: Set[Attribute] = output.toSet
+
+ /**
+ * Runs [[transform]] with `rule` on all expressions present in this query operator.
+ * Users should not expect a specific directionality. If a specific directionality is needed,
+ * transformExpressionsDown or transformExpressionsUp should be used.
+ * @param rule the rule to be applied to every expression in this operator.
+ */
+ def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
+ transformExpressionsDown(rule)
+ }
+
+ /**
+ * Runs [[transformDown]] with `rule` on all expressions present in this query operator.
+ * @param rule the rule to be applied to every expression in this operator.
+ */
+ def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = {
+ var changed = false
+
+ @inline def transformExpressionDown(e: Expression) = {
+ val newE = e.transformDown(rule)
+ if (newE.id != e.id && newE != e) {
+ changed = true
+ newE
+ } else {
+ e
+ }
+ }
+
+ val newArgs = productIterator.map {
+ case e: Expression => transformExpressionDown(e)
+ case Some(e: Expression) => Some(transformExpressionDown(e))
+ case m: Map[_,_] => m
+ case seq: Traversable[_] => seq.map {
+ case e: Expression => transformExpressionDown(e)
+ case other => other
+ }
+ case other: AnyRef => other
+ }.toArray
+
+ if (changed) makeCopy(newArgs) else this
+ }
+
+ /**
+ * Runs [[transformUp]] with `rule` on all expressions present in this query operator.
+ * @param rule the rule to be applied to every expression in this operator.
+ */
+ def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = {
+ var changed = false
+
+ @inline def transformExpressionUp(e: Expression) = {
+ val newE = e.transformUp(rule)
+ if (newE.id != e.id && newE != e) {
+ changed = true
+ newE
+ } else {
+ e
+ }
+ }
+
+ val newArgs = productIterator.map {
+ case e: Expression => transformExpressionUp(e)
+ case Some(e: Expression) => Some(transformExpressionUp(e))
+ case m: Map[_,_] => m
+ case seq: Traversable[_] => seq.map {
+ case e: Expression => transformExpressionUp(e)
+ case other => other
+ }
+ case other: AnyRef => other
+ }.toArray
+
+ if (changed) makeCopy(newArgs) else this
+ }
+
+ /** Returns the result of running [[transformExpressions]] on this node
+ * and all its children. */
+ def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
+ transform {
+ case q: QueryPlan[_] => q.transformExpressions(rule).asInstanceOf[PlanType]
+ }.asInstanceOf[this.type]
+ }
+
+ /** Returns all of the expressions present in this query plan operator. */
+ def expressions: Seq[Expression] = {
+ productIterator.flatMap {
+ case e: Expression => e :: Nil
+ case Some(e: Expression) => e :: Nil
+ case seq: Traversable[_] => seq.flatMap {
+ case e: Expression => e :: Nil
+ case other => Nil
+ }
+ case other => Nil
+ }.toSeq
+ }
+}
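+
+// For example (a hedged sketch; `plan` stands for any concrete query plan built from the
+// operators defined elsewhere in this patch), the following rewrites every occurrence of the
+// integer literal 1 anywhere in the operator tree:
+//
+//   plan transformAllExpressions {
+//     case expressions.Literal(1, dt) => expressions.Literal(2, dt)
+//   }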
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
new file mode 100644
index 0000000000..9f2283ad43
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+
+sealed abstract class JoinType
+case object Inner extends JoinType
+case object LeftOuter extends JoinType
+case object RightOuter extends JoinType
+case object FullOuter extends JoinType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala
new file mode 100644
index 0000000000..48ff45c3d3
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package logical
+
+abstract class BaseRelation extends LeafNode {
+ self: Product =>
+
+ def tableName: String
+ def isPartitioned: Boolean = false
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
new file mode 100644
index 0000000000..bc7b6871df
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package logical
+
+import catalyst.expressions._
+import catalyst.errors._
+import catalyst.types.StructType
+
+abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
+ self: Product =>
+
+ /**
+ * Returns the set of attributes that are referenced by this node
+ * during evaluation.
+ */
+ def references: Set[Attribute]
+
+ /**
+ * Returns the set of attributes that this node takes as
+ * input from its children.
+ */
+ lazy val inputSet: Set[Attribute] = children.flatMap(_.output).toSet
+
+ /**
+ * Returns true if this expression and all its children have been resolved to a specific schema
+ * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan
+ * can override this (e.g. [[catalyst.analysis.UnresolvedRelation UnresolvedRelation]] should
+ * return `false`).
+ */
+ lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved
+
+ /**
+ * Returns true if all children of this query plan have been resolved.
+ */
+ def childrenResolved = !children.exists(!_.resolved)
+
+ /**
+ * Optionally resolves the given string to a
+ * [[catalyst.expressions.NamedExpression NamedExpression]]. The attribute is expressed as
+ * a string in the following form: `[scope].AttributeName.[nested].[fields]...`.
+ */
+ def resolve(name: String): Option[NamedExpression] = {
+ val parts = name.split("\\.")
+ // Collect all attributes that are output by this node's children where either the first part
+ // matches the name or where the first part matches the scope and the second part matches the
+ // name. Return these matches along with any remaining parts, which represent dotted access to
+ // struct fields.
+ val options = children.flatMap(_.output).flatMap { option =>
+ // If the first part of the desired name matches a qualifier for this possible match, drop it.
+ val remainingParts = if (option.qualifiers contains parts.head) parts.drop(1) else parts
+ if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil
+ }
+
+ options.distinct match {
+ case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it.
+ // One match, but we also need to extract the requested nested field.
+ case (a, nestedFields) :: Nil =>
+ a.dataType match {
+ case StructType(fields) =>
+ Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
+ case _ => None // Don't know how to resolve these field references
+ }
+ case Nil => None // No matches.
+ case ambiguousReferences =>
+ throw new TreeNodeException(
+ this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
+ }
+ }
+}
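+
+// A minimal sketch; the object below is illustrative only and shows resolution of both an
+// unqualified and a qualified attribute name against a plan's child output.
+object ResolveExamples {
+ val key = AttributeReference("key", types.IntegerType)(qualifiers = Seq("t1"))
+ val plan = Project(Seq(key), LocalRelation(key))
+ val unqualified = plan.resolve("key") // Some(key)
+ val qualified = plan.resolve("t1.key") // Some(key)
+ val missing = plan.resolve("missing") // None
+}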
+
+/**
+ * A logical plan node with no children.
+ */
+abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
+ self: Product =>
+
+ // Leaf nodes by definition cannot reference any input attributes.
+ def references = Set.empty
+}
+
+/**
+ * A logical node that represents a non-query command to be executed by the system. For example,
+ * commands can be used by parsers to represent DDL operations.
+ */
+abstract class Command extends LeafNode {
+ self: Product =>
+ def output = Seq.empty
+}
+
+/**
+ * Returned for commands supported by a given parser, but not catalyst. In general these are DDL
+ * commands that are passed directly to another system.
+ */
+case class NativeCommand(cmd: String) extends Command
+
+/**
+ * Returned by a parser when the user only wants to see what query plan would be executed, without
+ * actually performing the execution.
+ */
+case class ExplainCommand(plan: LogicalPlan) extends Command
+
+/**
+ * A logical plan node with single child.
+ */
+abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] {
+ self: Product =>
+}
+
+/**
+ * A logical plan node with a left and right child.
+ */
+abstract class BinaryNode extends LogicalPlan with trees.BinaryNode[LogicalPlan] {
+ self: Product =>
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
new file mode 100644
index 0000000000..1a1a2b9b88
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package logical
+
+import expressions._
+
+/**
+ * Transforms the input by forking and running the specified script.
+ *
+ * @param input the set of expressions that should be passed to the script.
+ * @param script the command that should be executed.
+ * @param output the attributes that are produced by the script.
+ */
+case class ScriptTransformation(
+ input: Seq[Expression],
+ script: String,
+ output: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode {
+ def references = input.flatMap(_.references).toSet
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
new file mode 100644
index 0000000000..b5905a4456
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package logical
+
+import expressions._
+import rules._
+
+object LocalRelation {
+ def apply(output: Attribute*) =
+ new LocalRelation(output)
+}
+
+case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil)
+ extends LeafNode with analysis.MultiInstanceRelation {
+
+ // TODO: Validate schema compliance.
+ def loadData(newData: Seq[Product]) = new LocalRelation(output, data ++ newData)
+
+ /**
+ * Returns an identical copy of this relation with new exprIds for all attributes. Different
+ * attributes are required when a relation is going to be included multiple times in the same
+ * query.
+ */
+ override final def newInstance: this.type = {
+ LocalRelation(output.map(_.newInstance), data).asInstanceOf[this.type]
+ }
+
+ override protected def stringArgs = Iterator(output)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
new file mode 100644
index 0000000000..8e98aab736
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package logical
+
+import expressions._
+
+case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
+ def output = projectList.map(_.toAttribute)
+ def references = projectList.flatMap(_.references).toSet
+}
+
+/**
+ * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the
+ * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
+ * programming, with one important additional feature: it allows the input rows to be joined with
+ * their output.
+ * @param join when true, each output row is implicitly joined with the input tuple that produced
+ * it.
+ * @param outer when true, each input row will be output at least once, even if the output of the
+ * given `generator` is empty. `outer` has no effect when `join` is false.
+ * @param alias when set, this string is applied to the schema of the output of the transformation
+ * as a qualifier.
+ */
+case class Generate(
+ generator: Generator,
+ join: Boolean,
+ outer: Boolean,
+ alias: Option[String],
+ child: LogicalPlan)
+ extends UnaryNode {
+
+ protected def generatorOutput =
+ alias
+ .map(a => generator.output.map(_.withQualifiers(a :: Nil)))
+ .getOrElse(generator.output)
+
+ def output =
+ if (join) child.output ++ generatorOutput else generatorOutput
+
+ def references =
+ if (join) child.outputSet else generator.references
+}
+
+case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
+ def output = child.output
+ def references = condition.references
+}
+
+case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
+ // TODO: These aren't really the same attributes as nullability etc might change.
+ def output = left.output
+
+ override lazy val resolved =
+ childrenResolved &&
+ !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType }
+
+ def references = Set.empty
+}
+
+case class Join(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ joinType: JoinType,
+ condition: Option[Expression]) extends BinaryNode {
+
+ def references = condition.map(_.references).getOrElse(Set.empty)
+ def output = left.output ++ right.output
+}
+
+case class InsertIntoTable(
+ table: BaseRelation,
+ partition: Map[String, Option[String]],
+ child: LogicalPlan,
+ overwrite: Boolean)
+ extends LogicalPlan {
+ // The table being inserted into is a child for the purposes of transformations.
+ def children = table :: child :: Nil
+ def references = Set.empty
+ def output = child.output
+
+ override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
+ case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType
+ }
+}
+
+case class InsertIntoCreatedTable(
+ databaseName: Option[String],
+ tableName: String,
+ child: LogicalPlan) extends UnaryNode {
+ def references = Set.empty
+ def output = child.output
+}
+
+case class WriteToFile(
+ path: String,
+ child: LogicalPlan) extends UnaryNode {
+ def references = Set.empty
+ def output = child.output
+}
+
+case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode {
+ def output = child.output
+ def references = order.flatMap(_.references).toSet
+}
+
+case class Aggregate(
+ groupingExpressions: Seq[Expression],
+ aggregateExpressions: Seq[NamedExpression],
+ child: LogicalPlan)
+ extends UnaryNode {
+
+ def output = aggregateExpressions.map(_.toAttribute)
+ def references = child.references
+}
+
+case class StopAfter(limit: Expression, child: LogicalPlan) extends UnaryNode {
+ def output = child.output
+ def references = limit.references
+}
+
+case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
+ def output = child.output.map(_.withQualifiers(alias :: Nil))
+ def references = Set.empty
+}
+
+case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: LogicalPlan)
+ extends UnaryNode {
+
+ def output = child.output
+ def references = Set.empty
+}
+
+case class Distinct(child: LogicalPlan) extends UnaryNode {
+ def output = child.output
+ def references = child.outputSet
+}
+
+case object NoRelation extends LeafNode {
+ def output = Nil
+}
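As a sketch of how these operators compose (assuming the dsl implicits, the Literal and UnresolvedAttribute classes from the expression/analysis packages, and the === operator provided by the expression DSL), a query like SELECT key FROM src WHERE key = 1 LIMIT 10 corresponds roughly to the tree:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.Literal

    val src = LocalRelation('key.int, 'value.string)
    val plan =
      StopAfter(Literal(10),                                        // LIMIT 10
        Project(Seq(UnresolvedAttribute("key")),                    // SELECT key
          Filter(UnresolvedAttribute("key") === Literal(1), src)))  // WHERE key = 1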
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
new file mode 100644
index 0000000000..f7fcdc5fdb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package logical
+
+import expressions._
+
+/**
+ * Performs a physical redistribution of the data. Used when the consumer of the query
+ * result has expectations about the distribution and ordering of partitioned input data.
+ */
+abstract class RedistributeData extends UnaryNode {
+ self: Product =>
+
+ def output = child.output
+}
+
+case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan)
+ extends RedistributeData {
+
+ def references = sortExpressions.flatMap(_.references).toSet
+}
+
+case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan)
+ extends RedistributeData {
+
+ def references = partitionExpressions.flatMap(_.references).toSet
+}
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala
new file mode 100644
index 0000000000..a40ab4bbb1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+
+/**
+ * A collection of common abstractions for query plans as well as
+ * a base logical plan representation.
+ */
+package object plans
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
new file mode 100644
index 0000000000..2d8f3ad335
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package physical
+
+import expressions._
+import types._
+
+/**
+ * Specifies how tuples that share common expressions will be distributed when a query is executed
+ * in parallel on many machines. Distribution can be used to refer to two distinct physical
+ * properties:
+ * - Inter-node partitioning of data: In this case the distribution describes how tuples are
+ * partitioned across physical machines in a cluster. Knowing this property allows some
+ * operators (e.g., Aggregate) to perform partition local operations instead of global ones.
+ * - Intra-partition ordering of data: In this case the distribution describes guarantees made
+ * about how tuples are distributed within a single partition.
+ */
+sealed trait Distribution
+
+/**
+ * Represents a distribution where no promises are made about co-location of data.
+ */
+case object UnspecifiedDistribution extends Distribution
+
+/**
+ * Represents a distribution that only has a single partition and all tuples of the dataset
+ * are co-located.
+ */
+case object AllTuples extends Distribution
+
+/**
+ * Represents data where tuples that share the same values for the `clustering`
+ * [[catalyst.expressions.Expression Expressions]] will be co-located. Based on the context, this
+ * can mean such tuples are either co-located in the same partition or they will be contiguous
+ * within a single partition.
+ */
+case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
+ require(
+ clustering != Nil,
+ "The clustering expressions of a ClusteredDistribution should not be Nil. " +
+ "An AllTuples should be used to represent a distribution that only has " +
+ "a single partition.")
+}
+
+/**
+ * Represents data where tuples have been ordered according to the `ordering`
+ * [[catalyst.expressions.Expression Expressions]]. This is a strictly stronger guarantee than
+ * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for
+ * the ordering expressions are contiguous and will never be split across partitions.
+ */
+case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
+ require(
+ ordering != Nil,
+ "The ordering expressions of a OrderedDistribution should not be Nil. " +
+ "An AllTuples should be used to represent a distribution that only has " +
+ "a single partition.")
+
+ def clustering = ordering.map(_.child).toSet
+}
+
+sealed trait Partitioning {
+ /** Returns the number of partitions that the data is split across */
+ val numPartitions: Int
+
+ /**
+ * Returns true iff the guarantees made by this
+ * [[catalyst.plans.physical.Partitioning Partitioning]] are sufficient to satisfy
+ * the partitioning scheme mandated by the `required`
+ * [[catalyst.plans.physical.Distribution Distribution]], i.e. the current dataset does not
+ * need to be re-partitioned for the `required` Distribution (it is possible that tuples within
+ * a partition need to be reorganized).
+ */
+ def satisfies(required: Distribution): Boolean
+
+ /**
+ * Returns true iff all distribution guarantees made by this partitioning can also be made
+ * for the `other` specified partitioning.
+ * For example, two [[catalyst.plans.physical.HashPartitioning HashPartitioning]]s are
+ * only compatible if they have the same `numPartitions`.
+ */
+ def compatibleWith(other: Partitioning): Boolean
+}
+
+case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
+ override def satisfies(required: Distribution): Boolean = required match {
+ case UnspecifiedDistribution => true
+ case _ => false
+ }
+
+ override def compatibleWith(other: Partitioning): Boolean = other match {
+ case UnknownPartitioning(_) => true
+ case _ => false
+ }
+}
+
+case object SinglePartition extends Partitioning {
+ val numPartitions = 1
+
+ override def satisfies(required: Distribution): Boolean = true
+
+ override def compatibleWith(other: Partitioning) = other match {
+ case SinglePartition => true
+ case _ => false
+ }
+}
+
+case object BroadcastPartitioning extends Partitioning {
+ val numPartitions = 1
+
+ override def satisfies(required: Distribution): Boolean = true
+
+ override def compatibleWith(other: Partitioning) = other match {
+ case SinglePartition => true
+ case _ => false
+ }
+}
+
+/**
+ * Represents a partitioning where rows are split up across partitions based on the hash
+ * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
+ * in the same partition.
+ */
+case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
+ extends Expression
+ with Partitioning {
+
+ def children = expressions
+ def references = expressions.flatMap(_.references).toSet
+ def nullable = false
+ def dataType = IntegerType
+
+ lazy val clusteringSet = expressions.toSet
+
+ override def satisfies(required: Distribution): Boolean = required match {
+ case UnspecifiedDistribution => true
+ case ClusteredDistribution(requiredClustering) =>
+ clusteringSet.subsetOf(requiredClustering.toSet)
+ case _ => false
+ }
+
+ override def compatibleWith(other: Partitioning) = other match {
+ case BroadcastPartitioning => true
+ case h: HashPartitioning if h == this => true
+ case _ => false
+ }
+}
+
+/**
+ * Represents a partitioning where rows are split across partitions based on some total ordering of
+ * the expressions specified in `ordering`. When data is partitioned in this manner the following
+ * two conditions are guaranteed to hold:
+ * - All rows where the expressions in `ordering` evaluate to the same values will be in the same
+ * partition.
+ * - Each partition will have a `min` and `max` row, relative to the given ordering. All rows
+ * that are in between `min` and `max` in this `ordering` will reside in this partition.
+ */
+case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
+ extends Expression
+ with Partitioning {
+
+ def children = ordering
+ def references = ordering.flatMap(_.references).toSet
+ def nullable = false
+ def dataType = IntegerType
+
+ lazy val clusteringSet = ordering.map(_.child).toSet
+
+ override def satisfies(required: Distribution): Boolean = required match {
+ case UnspecifiedDistribution => true
+ case OrderedDistribution(requiredOrdering) =>
+ val minSize = Seq(requiredOrdering.size, ordering.size).min
+ requiredOrdering.take(minSize) == ordering.take(minSize)
+ case ClusteredDistribution(requiredClustering) =>
+ clusteringSet.subsetOf(requiredClustering.toSet)
+ case _ => false
+ }
+
+ override def compatibleWith(other: Partitioning) = other match {
+ case BroadcastPartitioning => true
+ case r: RangePartitioning if r == this => true
+ case _ => false
+ }
+}
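To make the satisfies contract concrete, a small sketch mirroring cases exercised in DistributionSuite further below (symbols become attributes via the dsl implicits):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.physical._

    val hashed = HashPartitioning(Seq('b, 'c), 10)
    hashed.satisfies(ClusteredDistribution(Seq('a, 'b, 'c)))   // true: {b, c} is a subset of {a, b, c}
    hashed.satisfies(OrderedDistribution(Seq('a.asc)))         // false: hashing guarantees no ordering

    val ranged = RangePartitioning(Seq('a.asc, 'b.asc), 10)
    ranged.satisfies(OrderedDistribution(Seq('a.asc)))         // true: a prefix of the sort order
    ranged.satisfies(AllTuples)                                 // false: the data still spans many partitions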
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala
new file mode 100644
index 0000000000..6ff4891a3f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package rules
+
+import trees._
+
+abstract class Rule[TreeType <: TreeNode[_]] extends Logging {
+
+ /** Name for this rule, automatically inferred based on class name. */
+ val ruleName: String = {
+ val className = getClass.getName
+ if (className endsWith "$") className.dropRight(1) else className
+ }
+
+ def apply(plan: TreeType): TreeType
+}
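A minimal concrete rule, as an illustrative sketch (RemoveTrueFilters is hypothetical, not part of this patch; Filter and Literal come from the logical plan and expression files above):

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}

    // Drops filters whose condition is a constant true, using TreeNode.transform.
    object RemoveTrueFilters extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan transform {
        case Filter(Literal(true, _), child) => child
      }
    }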
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
new file mode 100644
index 0000000000..68ae30cde1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package rules
+
+import trees._
+import util._
+
+abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
+
+ /**
+ * An execution strategy for rules that indicates the maximum number of executions. If the
+ * execution reaches a fixed point (i.e., converges) before maxIterations, it will stop.
+ */
+ abstract class Strategy { def maxIterations: Int }
+
+ /** A strategy that only runs once. */
+ case object Once extends Strategy { val maxIterations = 1 }
+
+ /** A strategy that runs until a fixed point is reached or maxIterations times, whichever comes first. */
+ case class FixedPoint(maxIterations: Int) extends Strategy
+
+ /** A batch of rules. */
+ protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*)
+
+ /** Defines a sequence of rule batches, to be overridden by the implementation. */
+ protected val batches: Seq[Batch]
+
+ /**
+ * Executes the batches of rules defined by the subclass. The batches are executed serially
+ * using the defined execution strategy. Within each batch, rules are also executed serially.
+ */
+ def apply(plan: TreeType): TreeType = {
+ var curPlan = plan
+
+ batches.foreach { batch =>
+ var iteration = 1
+ var lastPlan = curPlan
+ curPlan = batch.rules.foldLeft(curPlan) { case (curPlan, rule) => rule(curPlan) }
+
+ // Run until fixed point (or the max number of iterations as specified in the strategy).
+ while (iteration < batch.strategy.maxIterations && !curPlan.fastEquals(lastPlan)) {
+ lastPlan = curPlan
+ curPlan = batch.rules.foldLeft(curPlan) {
+ case (curPlan, rule) =>
+ val result = rule(curPlan)
+ if (!result.fastEquals(curPlan)) {
+ logger.debug(
+ s"""
+ |=== Applying Rule ${rule.ruleName} ===
+ |${sideBySide(curPlan.treeString, result.treeString).mkString("\n")}
+ """.stripMargin)
+ }
+
+ result
+ }
+ iteration += 1
+ }
+ }
+
+ curPlan
+ }
+}
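Batches wire such rules together; a sketch assuming the hypothetical RemoveTrueFilters rule above:

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    object SimpleOptimizer extends RuleExecutor[LogicalPlan] {
      val batches =
        Batch("Simplify filters", FixedPoint(100), RemoveTrueFilters) :: Nil
    }

    // SimpleOptimizer(plan) runs each batch in order, repeating a batch's rules until the plan
    // stops changing or the batch's maxIterations limit is hit.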
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/package.scala
new file mode 100644
index 0000000000..26ab543082
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/package.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+
+/**
+ * A framework for applying batches of rewrite rules to trees, possibly to a fixed point.
+ */
+package object rules
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
new file mode 100644
index 0000000000..76ede87e4e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package trees
+
+import errors._
+
+object TreeNode {
+ private val currentId = new java.util.concurrent.atomic.AtomicLong
+ protected def nextId() = currentId.getAndIncrement()
+}
+
+/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
+private class MutableInt(var i: Int)
+
+abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
+ self: BaseType with Product =>
+
+ /** Returns a Seq of the children of this node */
+ def children: Seq[BaseType]
+
+ /**
+ * A globally unique id for this specific instance. Not preserved across copies.
+ * Unlike `equals`, `id` can be used to differentiate distinct but structurally
+ * identical branches of a tree.
+ */
+ val id = TreeNode.nextId()
+
+ /**
+ * Returns true if other is the same [[catalyst.trees.TreeNode TreeNode]] instance. Unlike
+ * `equals` this function will return false for different instances of structurally identical
+ * trees.
+ */
+ def sameInstance(other: TreeNode[_]): Boolean = {
+ this.id == other.id
+ }
+
+ /**
+ * Faster version of equality which short-circuits when two tree nodes are the same instance.
+ * We don't just override Object.equals, as doing so prevents the Scala compiler from
+ * generating case class `equals` methods.
+ */
+ def fastEquals(other: TreeNode[_]): Boolean = {
+ sameInstance(other) || this == other
+ }
+
+ /**
+ * Runs the given function on this node and then recursively on [[children]].
+ * @param f the function to be applied to each node in the tree.
+ */
+ def foreach(f: BaseType => Unit): Unit = {
+ f(this)
+ children.foreach(_.foreach(f))
+ }
+
+ /**
+ * Returns a Seq containing the result of applying the given function to each
+ * node in this tree in a preorder traversal.
+ * @param f the function to be applied.
+ */
+ def map[A](f: BaseType => A): Seq[A] = {
+ val ret = new collection.mutable.ArrayBuffer[A]()
+ foreach(ret += f(_))
+ ret
+ }
+
+ /**
+ * Returns a Seq by applying a function to all nodes in this tree and using the elements of the
+ * resulting collections.
+ */
+ def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A] = {
+ val ret = new collection.mutable.ArrayBuffer[A]()
+ foreach(ret ++= f(_))
+ ret
+ }
+
+ /**
+ * Returns a Seq containing the result of applying a partial function to all elements in this
+ * tree on which the function is defined.
+ */
+ def collect[B](pf: PartialFunction[BaseType, B]): Seq[B] = {
+ val ret = new collection.mutable.ArrayBuffer[B]()
+ val lifted = pf.lift
+ foreach(node => lifted(node).foreach(ret.+=))
+ ret
+ }
+
+ /**
+ * Returns a copy of this node where `f` has been applied to all of the node's children.
+ */
+ def mapChildren(f: BaseType => BaseType): this.type = {
+ var changed = false
+ val newArgs = productIterator.map {
+ case arg: TreeNode[_] if children contains arg =>
+ val newChild = f(arg.asInstanceOf[BaseType])
+ if (newChild fastEquals arg) {
+ arg
+ } else {
+ changed = true
+ newChild
+ }
+ case nonChild: AnyRef => nonChild
+ case null => null
+ }.toArray
+ if (changed) makeCopy(newArgs) else this
+ }
+
+ /**
+ * Returns a copy of this node with the children replaced.
+ * TODO: Validate somewhere (in debug mode?) that children are ordered correctly.
+ */
+ def withNewChildren(newChildren: Seq[BaseType]): this.type = {
+ assert(newChildren.size == children.size, "Incorrect number of children")
+ var changed = false
+ val remainingNewChildren = newChildren.toBuffer
+ val remainingOldChildren = children.toBuffer
+ val newArgs = productIterator.map {
+ case arg: TreeNode[_] if children contains arg =>
+ val newChild = remainingNewChildren.remove(0)
+ val oldChild = remainingOldChildren.remove(0)
+ if (newChild fastEquals oldChild) {
+ oldChild
+ } else {
+ changed = true
+ newChild
+ }
+ case nonChild: AnyRef => nonChild
+ case null => null
+ }.toArray
+
+ if (changed) makeCopy(newArgs) else this
+ }
+
+ /**
+ * Returns a copy of this node where `rule` has been recursively applied to the tree.
+ * When `rule` does not apply to a given node it is left unchanged.
+ * Users should not expect a specific directionality. If a specific directionality is needed,
+ * transformDown or transformUp should be used.
+ * @param rule the function used to transform this node's children
+ */
+ def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = {
+ transformDown(rule)
+ }
+
+ /**
+ * Returns a copy of this node where `rule` has been recursively applied to it and all of its
+ * children (pre-order). When `rule` does not apply to a given node it is left unchanged.
+ * @param rule the function used to transform this node's children
+ */
+ def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
+ val afterRule = rule.applyOrElse(this, identity[BaseType])
+ // Check if unchanged and then possibly return old copy to avoid gc churn.
+ if (this fastEquals afterRule) {
+ transformChildrenDown(rule)
+ } else {
+ afterRule.transformChildrenDown(rule)
+ }
+ }
+
+ /**
+ * Returns a copy of this node where `rule` has been recursively applied to all the children of
+ * this node. When `rule` does not apply to a given node it is left unchanged.
+ * @param rule the function used to transform this node's children
+ */
+ def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = {
+ var changed = false
+ val newArgs = productIterator.map {
+ case arg: TreeNode[_] if children contains arg =>
+ val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
+ if (!(newChild fastEquals arg)) {
+ changed = true
+ newChild
+ } else {
+ arg
+ }
+ case m: Map[_,_] => m
+ case args: Traversable[_] => args.map {
+ case arg: TreeNode[_] if children contains arg =>
+ val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
+ if (!(newChild fastEquals arg)) {
+ changed = true
+ newChild
+ } else {
+ arg
+ }
+ case other => other
+ }
+ case nonChild: AnyRef => nonChild
+ case null => null
+ }.toArray
+ if (changed) makeCopy(newArgs) else this
+ }
+
+ /**
+ * Returns a copy of this node where `rule` has been recursively applied first to all of its
+ * children and then itself (post-order). When `rule` does not apply to a given node, it is left
+ * unchanged.
+ * @param rule the function used to transform this node's children
+ */
+ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
+ val afterRuleOnChildren = transformChildrenUp(rule)
+ if (this fastEquals afterRuleOnChildren) {
+ rule.applyOrElse(this, identity[BaseType])
+ } else {
+ rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
+ }
+ }
+
+ def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = {
+ var changed = false
+ val newArgs = productIterator.map {
+ case arg: TreeNode[_] if children contains arg =>
+ val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
+ if (!(newChild fastEquals arg)) {
+ changed = true
+ newChild
+ } else {
+ arg
+ }
+ case m: Map[_,_] => m
+ case args: Traversable[_] => args.map {
+ case arg: TreeNode[_] if children contains arg =>
+ val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
+ if (!(newChild fastEquals arg)) {
+ changed = true
+ newChild
+ } else {
+ arg
+ }
+ case other => other
+ }
+ case nonChild: AnyRef => nonChild
+ case null => null
+ }.toArray
+ if (changed) makeCopy(newArgs) else this
+ }
+
+ /**
+ * Args to the constructor that should be copied, but not transformed.
+ * These are appended to the transformed args automatically by makeCopy.
+ */
+ protected def otherCopyArgs: Seq[AnyRef] = Nil
+
+ /**
+ * Creates a copy of this type of tree node after a transformation.
+ * Must be overridden by child classes that have constructor arguments
+ * that are not present in the productIterator.
+ * @param newArgs the new product arguments.
+ */
+ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
+ try {
+ val defaultCtor = getClass.getConstructors.head
+ if (otherCopyArgs.isEmpty) {
+ defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
+ } else {
+ defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type]
+ }
+ } catch {
+ case e: java.lang.IllegalArgumentException =>
+ throw new TreeNodeException(
+ this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName?")
+ }
+ }
+
+ /** Returns the name of this type of TreeNode. Defaults to the class name. */
+ def nodeName = getClass.getSimpleName
+
+ /**
+ * The arguments that should be included in the arg string. Defaults to the `productIterator`.
+ */
+ protected def stringArgs = productIterator
+
+ /** Returns a string representing the arguments to this node, minus any children */
+ def argString: String = productIterator.flatMap {
+ case tn: TreeNode[_] if children contains tn => Nil
+ case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil
+ case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil
+ case set: Set[_] => set.mkString("{", ",", "}") :: Nil
+ case other => other :: Nil
+ }.mkString(", ")
+
+ /** String representation of this node without any children */
+ def simpleString = s"$nodeName $argString"
+
+ override def toString: String = treeString
+
+ /** Returns a string representation of the nodes in this tree */
+ def treeString = generateTreeString(0, new StringBuilder).toString
+
+ /**
+ * Returns a string representation of the nodes in this tree, where each operator is numbered.
+ * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees.
+ */
+ def numberedTreeString =
+ treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n")
+
+ /**
+ * Returns the tree node at the specified number.
+ * Numbers for each node can be found in the [[numberedTreeString]].
+ */
+ def apply(number: Int): BaseType = getNodeNumbered(new MutableInt(number))
+
+ protected def getNodeNumbered(number: MutableInt): BaseType = {
+ if (number.i < 0) {
+ null.asInstanceOf[BaseType]
+ } else if (number.i == 0) {
+ this
+ } else {
+ number.i -= 1
+ children.map(_.getNodeNumbered(number)).find(_ != null).getOrElse(null.asInstanceOf[BaseType])
+ }
+ }
+
+ /** Appends the string representation of this node and its children to the given StringBuilder. */
+ protected def generateTreeString(depth: Int, builder: StringBuilder): StringBuilder = {
+ builder.append(" " * depth)
+ builder.append(simpleString)
+ builder.append("\n")
+ children.foreach(_.generateTreeString(depth + 1, builder))
+ builder
+ }
+}
+
+/**
+ * A [[TreeNode]] that has two children, [[left]] and [[right]].
+ */
+trait BinaryNode[BaseType <: TreeNode[BaseType]] {
+ def left: BaseType
+ def right: BaseType
+
+ def children = Seq(left, right)
+}
+
+/**
+ * A [[TreeNode]] with no children.
+ */
+trait LeafNode[BaseType <: TreeNode[BaseType]] {
+ def children = Nil
+}
+
+/**
+ * A [[TreeNode]] with a single [[child]].
+ */
+trait UnaryNode[BaseType <: TreeNode[BaseType]] {
+ def child: BaseType
+ def children = child :: Nil
+}
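A short sketch of the collection-style helpers and node numbering, assuming the logical operators, Literal, and dsl implicits introduced earlier:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.logical._

    val plan = Project(Nil, Filter(Literal(true), LocalRelation('a.int)))

    plan.collect { case f: Filter => f.condition }    // Seq(Literal(true))
    plan.numberedTreeString                           // 00 Project ..., 01 Filter ..., 02 LocalRelation ...
    plan(2)                                           // the LocalRelation node, looked up by its number

    // transform rewrites matching nodes and leaves everything else untouched.
    val simplified = plan transform { case Filter(Literal(true, _), child) => child }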
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
new file mode 100644
index 0000000000..e2da1d2439
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+
+/**
+ * A library for easily manipulating trees of operators. Operators that extend TreeNode are
+ * granted the following interface:
+ * <ul>
+ * <li>Scala collection like methods (foreach, map, flatMap, collect, etc)</li>
+ * <li>
+ * transform - accepts a partial function that is used to generate a new tree. When the
+ * partial function can be applied to a given tree segment, that segment is replaced with the
+ * result. After attempting to apply the partial function to a given node, the transform
+ * function recursively attempts to apply the function to that node's children.
+ * </li>
+ * <li>debugging support - pretty printing, easy splicing of trees, etc.</li>
+ * </ul>
+ */
+package object trees {
+ // Since we want tree nodes to be lightweight, we create one logger for all treenode instances.
+ protected val logger = Logger("catalyst.trees")
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
new file mode 100644
index 0000000000..6eb2b62ecc
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package types
+
+import expressions.Expression
+
+abstract class DataType {
+ /** Matches any expression that evaluates to this DataType */
+ def unapply(a: Expression): Boolean = a match {
+ case e: Expression if e.dataType == this => true
+ case _ => false
+ }
+}
+
+case object NullType extends DataType
+
+abstract class NativeType extends DataType {
+ type JvmType
+ val ordering: Ordering[JvmType]
+}
+
+case object StringType extends NativeType {
+ type JvmType = String
+ val ordering = implicitly[Ordering[JvmType]]
+}
+case object BinaryType extends DataType {
+ type JvmType = Array[Byte]
+}
+case object BooleanType extends NativeType {
+ type JvmType = Boolean
+ val ordering = implicitly[Ordering[JvmType]]
+}
+
+abstract class NumericType extends NativeType {
+ // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
+ // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
+ // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
+ // desugared by the compiler into an argument to the object's constructor. This means there is
+ // no longer a no-argument constructor and thus the JVM cannot serialize the object anymore.
+ val numeric: Numeric[JvmType]
+}
+
+/** Matcher for any expressions that evaluate to [[IntegralType]]s */
+object IntegralType {
+ def unapply(a: Expression): Boolean = a match {
+ case e: Expression if e.dataType.isInstanceOf[IntegralType] => true
+ case _ => false
+ }
+}
+
+abstract class IntegralType extends NumericType {
+ val integral: Integral[JvmType]
+}
+
+case object LongType extends IntegralType {
+ type JvmType = Long
+ val numeric = implicitly[Numeric[Long]]
+ val integral = implicitly[Integral[Long]]
+ val ordering = implicitly[Ordering[JvmType]]
+}
+
+case object IntegerType extends IntegralType {
+ type JvmType = Int
+ val numeric = implicitly[Numeric[Int]]
+ val integral = implicitly[Integral[Int]]
+ val ordering = implicitly[Ordering[JvmType]]
+}
+
+case object ShortType extends IntegralType {
+ type JvmType = Short
+ val numeric = implicitly[Numeric[Short]]
+ val integral = implicitly[Integral[Short]]
+ val ordering = implicitly[Ordering[JvmType]]
+}
+
+case object ByteType extends IntegralType {
+ type JvmType = Byte
+ val numeric = implicitly[Numeric[Byte]]
+ val integral = implicitly[Integral[Byte]]
+ val ordering = implicitly[Ordering[JvmType]]
+}
+
+/** Matcher for any expressions that evaluate to [[FractionalType]]s */
+object FractionalType {
+ def unapply(a: Expression): Boolean = a match {
+ case e: Expression if e.dataType.isInstanceOf[FractionalType] => true
+ case _ => false
+ }
+}
+abstract class FractionalType extends NumericType {
+ val fractional: Fractional[JvmType]
+}
+
+case object DecimalType extends FractionalType {
+ type JvmType = BigDecimal
+ val numeric = implicitly[Numeric[BigDecimal]]
+ val fractional = implicitly[Fractional[BigDecimal]]
+ val ordering = implicitly[Ordering[JvmType]]
+}
+
+case object DoubleType extends FractionalType {
+ type JvmType = Double
+ val numeric = implicitly[Numeric[Double]]
+ val fractional = implicitly[Fractional[Double]]
+ val ordering = implicitly[Ordering[JvmType]]
+}
+
+case object FloatType extends FractionalType {
+ type JvmType = Float
+ val numeric = implicitly[Numeric[Float]]
+ val fractional = implicitly[Fractional[Float]]
+ val ordering = implicitly[Ordering[JvmType]]
+}
+
+case class ArrayType(elementType: DataType) extends DataType
+
+case class StructField(name: String, dataType: DataType, nullable: Boolean)
+case class StructType(fields: Seq[StructField]) extends DataType
+
+case class MapType(keyType: DataType, valueType: DataType) extends DataType
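The unapply defined on DataType (and on the IntegralType/FractionalType matchers) exists so that a type object can be used directly as a pattern against an expression's result type; a sketch:

    import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
    import org.apache.spark.sql.catalyst.types.{IntegralType, StringType}

    def describe(e: Expression): String = e match {
      case StringType()   => "evaluates to a String"
      case IntegralType() => "evaluates to an integral type"
      case _              => "something else"
    }

    describe(Literal("hi"))   // "evaluates to a String"
    describe(Literal(42))     // "evaluates to an integral type"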
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala
new file mode 100644
index 0000000000..b65a5617d9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+/**
+ * Contains a type system for attributes produced by relations, including complex types like
+ * structs, arrays and maps.
+ */
+package object types
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
new file mode 100644
index 0000000000..52adea2661
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+
+import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File}
+
+package object util {
+ /**
+ * Returns a path to a temporary file that probably does not exist.
+ * Note that there is always a race condition: someone may have created this
+ * file since the last time we checked. Thus, this shouldn't be used
+ * for anything security sensitive.
+ */
+ def getTempFilePath(prefix: String, suffix: String = ""): File = {
+ val tempFile = File.createTempFile(prefix, suffix)
+ tempFile.delete()
+ tempFile
+ }
+
+ def fileToString(file: File, encoding: String = "UTF-8") = {
+ val inStream = new FileInputStream(file)
+ val outStream = new ByteArrayOutputStream
+ try {
+ var reading = true
+ while ( reading ) {
+ inStream.read() match {
+ case -1 => reading = false
+ case c => outStream.write(c)
+ }
+ }
+ outStream.flush()
+ }
+ finally {
+ inStream.close()
+ }
+ new String(outStream.toByteArray, encoding)
+ }
+
+ def resourceToString(
+ resource: String,
+ encoding: String = "UTF-8",
+ classLoader: ClassLoader = this.getClass.getClassLoader) = {
+ val inStream = classLoader.getResourceAsStream(resource)
+ val outStream = new ByteArrayOutputStream
+ try {
+ var reading = true
+ while ( reading ) {
+ inStream.read() match {
+ case -1 => reading = false
+ case c => outStream.write(c)
+ }
+ }
+ outStream.flush()
+ }
+ finally {
+ inStream.close()
+ }
+ new String(outStream.toByteArray, encoding)
+ }
+
+ def stringToFile(file: File, str: String): File = {
+ val out = new PrintWriter(file)
+ out.write(str)
+ out.close()
+ file
+ }
+
+ def sideBySide(left: String, right: String): Seq[String] = {
+ sideBySide(left.split("\n"), right.split("\n"))
+ }
+
+ def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = {
+ val maxLeftSize = left.map(_.size).max
+ val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("")
+ val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("")
+
+ leftPadded.zip(rightPadded).map {
+ case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.size) + 3)) + r
+ }
+ }
+
+ def stackTraceToString(t: Throwable): String = {
+ val out = new java.io.ByteArrayOutputStream
+ val writer = new PrintWriter(out)
+ t.printStackTrace(writer)
+ writer.flush()
+ new String(out.toByteArray)
+ }
+
+ def stringOrNull(a: AnyRef) = if (a == null) null else a.toString
+
+ def benchmark[A](f: => A): A = {
+ val startTime = System.nanoTime()
+ val ret = f
+ val endTime = System.nanoTime()
+ println(s"${(endTime - startTime).toDouble / 1000000}ms")
+ ret
+ }
+
+ /* FIX ME
+ implicit class debugLogging(a: AnyRef) {
+ def debugLogging() {
+ org.apache.log4j.Logger.getLogger(a.getClass.getName).setLevel(org.apache.log4j.Level.DEBUG)
+ }
+ } */
+}
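sideBySide is the helper RuleExecutor uses above to log plans before and after a rule fires; a quick sketch of its output shape:

    import org.apache.spark.sql.catalyst.util._

    sideBySide("a\nb\nc", "a\nB\nc").foreach(println)
    //  a   a
    // !b   B          (lines that differ between the two sides are flagged with '!')
    //  c   c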
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala
new file mode 100644
index 0000000000..9ec31689b5
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Allows the execution of relational queries, including those expressed in SQL using Spark.
+ *
+ * Note that this package is located in catalyst instead of in core so that all subprojects can
+ * inherit the settings from this package object.
+ */
+package object sql {
+
+ protected[sql] def Logger(name: String) =
+ com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger(name))
+
+ protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging
+
+ type Row = catalyst.expressions.Row
+
+ object Row {
+ /**
+ * This method can be used to extract fields from a [[Row]] object in a pattern match. Example:
+ * {{{
+ * import org.apache.spark.sql._
+ *
+ * val pairs = sql("SELECT key, value FROM src").rdd.map {
+ * case Row(key: Int, value: String) =>
+ * key -> value
+ * }
+ * }}}
+ */
+ def unapplySeq(row: Row): Some[Seq[Any]] = Some(row)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/AnalysisSuite.scala
new file mode 100644
index 0000000000..1fd0d26b6f
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/AnalysisSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package analysis
+
+import org.scalatest.FunSuite
+
+import analysis._
+import expressions._
+import plans.logical._
+import types._
+
+import dsl._
+import dsl.expressions._
+
+class AnalysisSuite extends FunSuite {
+ val analyze = SimpleAnalyzer
+
+ val testRelation = LocalRelation('a.int)
+
+ test("analyze project") {
+ assert(analyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) === Project(testRelation.output, testRelation))
+
+ }
+} \ No newline at end of file
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
new file mode 100644
index 0000000000..fb25e1c246
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.plans.physical._
+
+/* Implicit conversions */
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class DistributionSuite extends FunSuite {
+
+ protected def checkSatisfied(
+ inputPartitioning: Partitioning,
+ requiredDistribution: Distribution,
+ satisfied: Boolean) {
+ if (inputPartitioning.satisfies(requiredDistribution) != satisfied)
+ fail(
+ s"""
+ |== Input Partitioning ==
+ |$inputPartitioning
+ |== Required Distribution ==
+ |$requiredDistribution
+ |== Does input partitioning satisfy required distribution? ==
+ |Expected $satisfied got ${inputPartitioning.satisfies(requiredDistribution)}
+ """.stripMargin)
+ }
+
+ test("HashPartitioning is the output partitioning") {
+ // Cases which do not need an exchange between two data properties.
+ checkSatisfied(
+ HashPartitioning(Seq('a, 'b, 'c), 10),
+ UnspecifiedDistribution,
+ true)
+
+ checkSatisfied(
+ HashPartitioning(Seq('a, 'b, 'c), 10),
+ ClusteredDistribution(Seq('a, 'b, 'c)),
+ true)
+
+ checkSatisfied(
+ HashPartitioning(Seq('b, 'c), 10),
+ ClusteredDistribution(Seq('a, 'b, 'c)),
+ true)
+
+ checkSatisfied(
+ SinglePartition,
+ ClusteredDistribution(Seq('a, 'b, 'c)),
+ true)
+
+ checkSatisfied(
+ SinglePartition,
+ OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
+ true)
+
+ // Cases which need an exchange between two data properties.
+ checkSatisfied(
+ HashPartitioning(Seq('a, 'b, 'c), 10),
+ ClusteredDistribution(Seq('b, 'c)),
+ false)
+
+ checkSatisfied(
+ HashPartitioning(Seq('a, 'b, 'c), 10),
+ ClusteredDistribution(Seq('d, 'e)),
+ false)
+
+ checkSatisfied(
+ HashPartitioning(Seq('a, 'b, 'c), 10),
+ AllTuples,
+ false)
+
+ checkSatisfied(
+ HashPartitioning(Seq('a, 'b, 'c), 10),
+ OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
+ false)
+
+ checkSatisfied(
+ HashPartitioning(Seq('b, 'c), 10),
+ OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
+ false)
+
+ // TODO: We should check functional dependencies
+ /*
+ checkSatisfied(
+ ClusteredDistribution(Seq('b)),
+ ClusteredDistribution(Seq('b + 1)),
+ true)
+ */
+ }
+
+ test("RangePartitioning is the output partitioning") {
+ // Cases which do not need an exchange between two data properties.
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ UnspecifiedDistribution,
+ true)
+
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
+ true)
+
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ OrderedDistribution(Seq('a.asc, 'b.asc)),
+ true)
+
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)),
+ true)
+
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ ClusteredDistribution(Seq('a, 'b, 'c)),
+ true)
+
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ ClusteredDistribution(Seq('c, 'b, 'a)),
+ true)
+
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ ClusteredDistribution(Seq('b, 'c, 'a, 'd)),
+ true)
+
+ // Cases which need an exchange between two data properties.
+ // TODO: We can have an optimization to first sort the dataset
+ // by a.asc and then sort b and c within each partition. This optimization
+ // should trade off the benefit of having fewer Exchange operators
+ // against the loss of parallelism.
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)),
+ false)
+
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ OrderedDistribution(Seq('b.asc, 'a.asc)),
+ false)
+
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ ClusteredDistribution(Seq('a, 'b)),
+ false)
+
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ ClusteredDistribution(Seq('c, 'd)),
+ false)
+
+ checkSatisfied(
+ RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
+ AllTuples,
+ false)
+ }
+} \ No newline at end of file
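Taken together, the cases above pin down the satisfies relation: a single partition satisfies every distribution, hash partitioning satisfies a clustered distribution exactly when its hash expressions are a subset of the clustering expressions, and range partitioning satisfies an ordered distribution when the two sort orders share a prefix (and a clustered distribution when its sort columns are a subset of the clustering columns). A standalone sketch of those rules over plain column names follows; the names below are illustrative stand-ins, not the Catalyst classes.

sealed trait Dist
case object Unspecified extends Dist
case object AllRows extends Dist                                     // stand-in for AllTuples
case class Clustered(cols: Set[String]) extends Dist
case class OrderedDist(sort: Seq[(String, Boolean)]) extends Dist    // (column, ascending?)

sealed trait Part { def satisfies(d: Dist): Boolean }

case object OnePartition extends Part {
  def satisfies(d: Dist): Boolean = true                             // one partition satisfies anything
}

case class HashPart(cols: Set[String], numPartitions: Int) extends Part {
  def satisfies(d: Dist): Boolean = d match {
    case Unspecified       => true
    case Clustered(needed) => cols.subsetOf(needed)                  // equal keys end up in one partition
    case _                 => false                                  // no ordering or single-partition guarantee
  }
}

case class RangePart(sort: Seq[(String, Boolean)], numPartitions: Int) extends Part {
  def satisfies(d: Dist): Boolean = d match {
    case Unspecified         => true
    case OrderedDist(needed) => needed.startsWith(sort) || sort.startsWith(needed)
    case Clustered(needed)   => sort.map(_._1).toSet.subsetOf(needed)
    case _                   => false
  }
}

// Mirrors two of the cases above:
//   HashPart(Set("b", "c"), 10).satisfies(Clustered(Set("a", "b", "c")))     == true
//   RangePart(Seq("a" -> true, "b" -> true, "c" -> true), 10)
//     .satisfies(OrderedDist(Seq("a" -> true, "b" -> true)))                 == true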
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala
new file mode 100644
index 0000000000..f06618ad11
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
+
+/* Implicit conversions */
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class ExpressionEvaluationSuite extends FunSuite {
+
+ test("literals") {
+ assert((Literal(1) + Literal(1)).apply(null) === 2)
+ }
+
+ /**
+   * Checks three-valued logic (3VL). Based on:
+ * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29
+ *
+ * p q p OR q p AND q p = q
+ * True True True True True
+ * True False True False False
+ * True Unknown True Unknown Unknown
+ * False True True False False
+ * False False False False True
+ * False Unknown Unknown False Unknown
+ * Unknown True True Unknown Unknown
+ * Unknown False Unknown False Unknown
+ * Unknown Unknown Unknown Unknown Unknown
+ *
+ * p NOT p
+ * True False
+ * False True
+ * Unknown Unknown
+ */
+
+ val notTrueTable =
+ (true, false) ::
+ (false, true) ::
+ (null, null) :: Nil
+
+ test("3VL Not") {
+ notTrueTable.foreach {
+ case (v, answer) =>
+ val expr = Not(Literal(v, BooleanType))
+ val result = expr.apply(null)
+ if (result != answer)
+ fail(s"$expr should not evaluate to $result, expected: $answer") }
+ }
+
+ booleanLogicTest("AND", _ && _,
+ (true, true, true) ::
+ (true, false, false) ::
+ (true, null, null) ::
+ (false, true, false) ::
+ (false, false, false) ::
+ (false, null, false) ::
+ (null, true, null) ::
+ (null, false, false) ::
+ (null, null, null) :: Nil)
+
+ booleanLogicTest("OR", _ || _,
+ (true, true, true) ::
+ (true, false, true) ::
+ (true, null, true) ::
+ (false, true, true) ::
+ (false, false, false) ::
+ (false, null, null) ::
+ (null, true, true) ::
+ (null, false, null) ::
+ (null, null, null) :: Nil)
+
+ booleanLogicTest("=", _ === _,
+ (true, true, true) ::
+ (true, false, false) ::
+ (true, null, null) ::
+ (false, true, false) ::
+ (false, false, true) ::
+ (false, null, null) ::
+ (null, true, null) ::
+ (null, false, null) ::
+ (null, null, null) :: Nil)
+
+ def booleanLogicTest(name: String, op: (Expression, Expression) => Expression, truthTable: Seq[(Any, Any, Any)]) {
+ test(s"3VL $name") {
+ truthTable.foreach {
+ case (l,r,answer) =>
+ val expr = op(Literal(l, BooleanType), Literal(r, BooleanType))
+ val result = expr.apply(null)
+ if (result != answer)
+ fail(s"$expr should not evaluate to $result, expected: $answer")
+ }
+ }
+ }
+} \ No newline at end of file
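The commented truth tables are Kleene's three-valued logic, with SQL NULL passed into the expressions as a Scala null literal. The same tables can be reproduced by modelling NULL as None over Option[Boolean]; the helpers below are an illustrative sketch, not part of the expression library.

object ThreeValuedLogic {
  def not(p: Option[Boolean]): Option[Boolean] = p.map(!_)           // NOT NULL is NULL

  def and(p: Option[Boolean], q: Option[Boolean]): Option[Boolean] = (p, q) match {
    case (Some(false), _) | (_, Some(false)) => Some(false)          // false wins regardless of the other side
    case (Some(true), Some(true))            => Some(true)
    case _                                   => None                 // any remaining NULL makes the result unknown
  }

  def or(p: Option[Boolean], q: Option[Boolean]): Option[Boolean] = (p, q) match {
    case (Some(true), _) | (_, Some(true)) => Some(true)             // true wins regardless of the other side
    case (Some(false), Some(false))        => Some(false)
    case _                                 => None
  }

  def equalTo(p: Option[Boolean], q: Option[Boolean]): Option[Boolean] =
    for (a <- p; b <- q) yield a == b                                // NULL on either side yields NULL
}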
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/HiveTypeCoercionSuite.scala
new file mode 100644
index 0000000000..f595bf7e44
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/HiveTypeCoercionSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package analysis
+
+import org.scalatest.FunSuite
+
+import catalyst.types._
+
+
+class HiveTypeCoercionSuite extends FunSuite {
+
+ val rules = new HiveTypeCoercion { }
+ import rules._
+
+ test("tightest common bound for numeric and boolean types") {
+ def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
+ var found = WidenTypes.findTightestCommonType(t1, t2)
+ assert(found == tightestCommon,
+ s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
+ // Test both directions to make sure the widening is symmetric.
+ found = WidenTypes.findTightestCommonType(t2, t1)
+ assert(found == tightestCommon,
+ s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
+ }
+
+ // Boolean
+ widenTest(NullType, BooleanType, Some(BooleanType))
+ widenTest(BooleanType, BooleanType, Some(BooleanType))
+ widenTest(IntegerType, BooleanType, None)
+ widenTest(LongType, BooleanType, None)
+
+ // Integral
+ widenTest(NullType, ByteType, Some(ByteType))
+ widenTest(NullType, IntegerType, Some(IntegerType))
+ widenTest(NullType, LongType, Some(LongType))
+ widenTest(ShortType, IntegerType, Some(IntegerType))
+ widenTest(ShortType, LongType, Some(LongType))
+ widenTest(IntegerType, LongType, Some(LongType))
+ widenTest(LongType, LongType, Some(LongType))
+
+ // Floating point
+ widenTest(NullType, FloatType, Some(FloatType))
+ widenTest(NullType, DoubleType, Some(DoubleType))
+ widenTest(FloatType, DoubleType, Some(DoubleType))
+ widenTest(FloatType, FloatType, Some(FloatType))
+ widenTest(DoubleType, DoubleType, Some(DoubleType))
+
+ // Integral mixed with floating point.
+ widenTest(NullType, FloatType, Some(FloatType))
+ widenTest(NullType, DoubleType, Some(DoubleType))
+ widenTest(IntegerType, FloatType, Some(FloatType))
+ widenTest(IntegerType, DoubleType, Some(DoubleType))
+ widenTest(IntegerType, DoubleType, Some(DoubleType))
+ widenTest(LongType, FloatType, Some(FloatType))
+ widenTest(LongType, DoubleType, Some(DoubleType))
+ }
+}
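The assertions above describe the widening rule being tested: NullType widens to any other type, BooleanType only matches itself, and the numeric types widen along a single precedence chain where the wider of the two operands wins (so LongType and FloatType meet at FloatType). Below is a standalone sketch of that rule over plain type names; it is illustrative only, not the HiveTypeCoercion implementation.

object WideningSketch {
  private val numericPrecedence = Seq("byte", "short", "int", "long", "float", "double")

  def tightestCommonType(t1: String, t2: String): Option[String] = (t1, t2) match {
    case (a, b) if a == b => Some(a)
    case ("null", other)  => Some(other)                             // NullType widens to anything
    case (other, "null")  => Some(other)
    case (a, b) if numericPrecedence.contains(a) && numericPrecedence.contains(b) =>
      // The wider of the two operands (the one later in the chain) wins.
      Some(numericPrecedence(numericPrecedence.indexOf(a) max numericPrecedence.indexOf(b)))
    case _ => None                                                   // e.g. boolean vs. int: no implicit widening
  }
}

// tightestCommonType("long", "float")  == Some("float")
// tightestCommonType("int", "boolean") == None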
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/RuleExecutorSuite.scala
new file mode 100644
index 0000000000..ff7c15b718
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/RuleExecutorSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package trees
+
+import org.scalatest.FunSuite
+
+import expressions._
+import rules._
+
+class RuleExecutorSuite extends FunSuite {
+ object DecrementLiterals extends Rule[Expression] {
+ def apply(e: Expression): Expression = e transform {
+ case IntegerLiteral(i) if i > 0 => Literal(i - 1)
+ }
+ }
+
+ test("only once") {
+ object ApplyOnce extends RuleExecutor[Expression] {
+ val batches = Batch("once", Once, DecrementLiterals) :: Nil
+ }
+
+ assert(ApplyOnce(Literal(10)) === Literal(9))
+ }
+
+ test("to fixed point") {
+ object ToFixedPoint extends RuleExecutor[Expression] {
+ val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil
+ }
+
+ assert(ToFixedPoint(Literal(10)) === Literal(0))
+ }
+
+ test("to maxIterations") {
+ object ToFixedPoint extends RuleExecutor[Expression] {
+ val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil
+ }
+
+ assert(ToFixedPoint(Literal(100)) === Literal(90))
+ }
+} \ No newline at end of file
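The three tests fix the semantics of the batch strategies: Once applies the batch a single time, while FixedPoint(n) keeps re-applying it until the tree stops changing or n iterations have been used, which is why Literal(100) only reaches Literal(90) under FixedPoint(10). A simplified standalone model of that driver loop (not the Catalyst RuleExecutor itself):

object RuleExecutorSketch {
  type Rule[T] = T => T

  def runOnce[T](rules: Seq[Rule[T]], input: T): T =
    rules.foldLeft(input)((tree, rule) => rule(tree))

  def runToFixedPoint[T](rules: Seq[Rule[T]], input: T, maxIterations: Int): T = {
    var current = input
    var changed = true
    var iteration = 0
    while (changed && iteration < maxIterations) {
      val next = runOnce(rules, current)
      changed = next != current                                      // stop early once a full pass changes nothing
      current = next
      iteration += 1
    }
    current
  }
}

// With a decrement-towards-zero rule, mirroring the tests above:
//   val decrement: Int => Int = i => math.max(i - 1, 0)
//   RuleExecutorSketch.runOnce(Seq(decrement), 10)              == 9
//   RuleExecutorSketch.runToFixedPoint(Seq(decrement), 10, 100) == 0
//   RuleExecutorSketch.runToFixedPoint(Seq(decrement), 100, 10) == 90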
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/TreeNodeSuite.scala
new file mode 100644
index 0000000000..98bb090c29
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/TreeNodeSuite.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package trees
+
+import scala.collection.mutable.ArrayBuffer
+
+import expressions._
+
+import org.scalatest.FunSuite
+
+class TreeNodeSuite extends FunSuite {
+
+ test("top node changed") {
+ val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
+ assert(after === Literal(2))
+ }
+
+ test("one child changed") {
+ val before = Add(Literal(1), Literal(2))
+ val after = before transform { case Literal(2, _) => Literal(1) }
+
+ assert(after === Add(Literal(1), Literal(1)))
+ }
+
+ test("no change") {
+ val before = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4))))
+ val after = before transform { case Literal(5, _) => Literal(1)}
+
+ assert(before === after)
+ assert(before.map(_.id) === after.map(_.id))
+ }
+
+ test("collect") {
+ val tree = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4))))
+ val literals = tree collect {case l: Literal => l}
+
+ assert(literals.size === 4)
+ (1 to 4).foreach(i => assert(literals contains Literal(i)))
+ }
+
+ test("pre-order transform") {
+ val actual = new ArrayBuffer[String]()
+ val expected = Seq("+", "1", "*", "2", "-", "3", "4")
+ val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4))))
+ expression transformDown {
+ case b: BinaryExpression => {actual.append(b.symbol); b}
+ case l: Literal => {actual.append(l.toString); l}
+ }
+
+ assert(expected === actual)
+ }
+
+ test("post-order transform") {
+ val actual = new ArrayBuffer[String]()
+ val expected = Seq("1", "2", "3", "4", "-", "*", "+")
+ val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4))))
+ expression transformUp {
+ case b: BinaryExpression => {actual.append(b.symbol); b}
+ case l: Literal => {actual.append(l.toString); l}
+ }
+
+ assert(expected === actual)
+ }
+} \ No newline at end of file
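The pre-order and post-order tests pin down the two traversal orders: transformDown applies the rule to a node before recursing into its children, while transformUp rebuilds the children first and applies the rule on the way back up, which is why the operator symbols appear before their operands in one trace and after them in the other. A toy tree makes the difference concrete; the types below are illustrative stand-ins, not Catalyst's TreeNode.

object TraversalSketch {
  sealed trait Toy
  case class Leaf(value: Int) extends Toy
  case class Branch(symbol: String, left: Toy, right: Toy) extends Toy

  def transformDown(node: Toy)(rule: PartialFunction[Toy, Toy]): Toy =
    rule.applyOrElse(node, identity[Toy]) match {                    // visit the node first ...
      case Branch(s, l, r) => Branch(s, transformDown(l)(rule), transformDown(r)(rule))
      case leaf            => leaf                                   // ... then recurse into the children
    }

  def transformUp(node: Toy)(rule: PartialFunction[Toy, Toy]): Toy = {
    val rebuilt = node match {                                       // recurse into the children first ...
      case Branch(s, l, r) => Branch(s, transformUp(l)(rule), transformUp(r)(rule))
      case leaf            => leaf
    }
    rule.applyOrElse(rebuilt, identity[Toy])                         // ... then visit the node
  }
}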
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
new file mode 100644
index 0000000000..7ce42b2b0a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package optimizer
+
+import types.IntegerType
+import util._
+import plans.logical.{LogicalPlan, LocalRelation}
+import rules._
+import expressions._
+import dsl.plans._
+import dsl.expressions._
+
+class ConstantFoldingSuite extends OptimizerTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Subqueries", Once,
+ EliminateSubqueries) ::
+ Batch("ConstantFolding", Once,
+ ConstantFolding,
+ BooleanSimplification) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ test("eliminate subqueries") {
+ val originalQuery =
+ testRelation
+ .subquery('y)
+ .select('a)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a.attr)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ /**
+ * Unit tests for constant folding in expressions.
+ */
+ test("Constant folding test: expressions only have literals") {
+ val originalQuery =
+ testRelation
+ .select(
+ Literal(2) + Literal(3) + Literal(4) as Symbol("2+3+4"),
+ Literal(2) * Literal(3) + Literal(4) as Symbol("2*3+4"),
+ Literal(2) * (Literal(3) + Literal(4)) as Symbol("2*(3+4)"))
+ .where(
+ Literal(1) === Literal(1) &&
+ Literal(2) > Literal(3) ||
+ Literal(3) > Literal(2) )
+ .groupBy(
+ Literal(2) * Literal(3) - Literal(6) / (Literal(4) - Literal(2))
+ )(Literal(9) / Literal(3) as Symbol("9/3"))
+
+ val optimized = Optimize(originalQuery.analyze)
+
+ val correctAnswer =
+ testRelation
+ .select(
+ Literal(9) as Symbol("2+3+4"),
+ Literal(10) as Symbol("2*3+4"),
+ Literal(14) as Symbol("2*(3+4)"))
+ .where(Literal(true))
+ .groupBy(Literal(3))(Literal(3) as Symbol("9/3"))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("Constant folding test: expressions have attribute references and literals in " +
+ "arithmetic operations") {
+ val originalQuery =
+ testRelation
+ .select(
+ Literal(2) + Literal(3) + 'a as Symbol("c1"),
+ 'a + Literal(2) + Literal(3) as Symbol("c2"),
+ Literal(2) * 'a + Literal(4) as Symbol("c3"),
+ 'a * (Literal(3) + Literal(4)) as Symbol("c4"))
+
+ val optimized = Optimize(originalQuery.analyze)
+
+ val correctAnswer =
+ testRelation
+ .select(
+ Literal(5) + 'a as Symbol("c1"),
+ 'a + Literal(2) + Literal(3) as Symbol("c2"),
+ Literal(2) * 'a + Literal(4) as Symbol("c3"),
+ 'a * (Literal(7)) as Symbol("c4"))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("Constant folding test: expressions have attribute references and literals in " +
+ "predicates") {
+ val originalQuery =
+ testRelation
+ .where(
+ (('a > 1 && Literal(1) === Literal(1)) ||
+ ('a < 10 && Literal(1) === Literal(2)) ||
+ (Literal(1) === Literal(1) && 'b > 1) ||
+ (Literal(1) === Literal(2) && 'b < 10)) &&
+ (('a > 1 || Literal(1) === Literal(1)) &&
+ ('a < 10 || Literal(1) === Literal(2)) &&
+ (Literal(1) === Literal(1) || 'b > 1) &&
+ (Literal(1) === Literal(2) || 'b < 10)))
+
+ val optimized = Optimize(originalQuery.analyze)
+
+ val correctAnswer =
+ testRelation
+ .where(('a > 1 || 'b > 1) && ('a < 10 && 'b < 10))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("Constant folding test: expressions have foldable functions") {
+ val originalQuery =
+ testRelation
+ .select(
+ Cast(Literal("2"), IntegerType) + Literal(3) + 'a as Symbol("c1"),
+ Coalesce(Seq(Cast(Literal("abc"), IntegerType), Literal(3))) as Symbol("c2"))
+
+ val optimized = Optimize(originalQuery.analyze)
+
+ val correctAnswer =
+ testRelation
+ .select(
+ Literal(5) + 'a as Symbol("c1"),
+ Literal(3) as Symbol("c2"))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("Constant folding test: expressions have nonfoldable functions") {
+ val originalQuery =
+ testRelation
+ .select(
+ Rand + Literal(1) as Symbol("c1"),
+ Sum('a) as Symbol("c2"))
+
+ val optimized = Optimize(originalQuery.analyze)
+
+ val correctAnswer =
+ testRelation
+ .select(
+ Rand + Literal(1.0) as Symbol("c1"),
+ Sum('a) as Symbol("c2"))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+} \ No newline at end of file
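The pattern in these plans is uniform: any subtree whose inputs are all literals is evaluated at optimization time and replaced by a single literal, while subtrees that reference attributes or non-foldable expressions such as Rand are left untouched. A minimal standalone sketch of that bottom-up folding over a toy integer AST (not the Catalyst ConstantFolding rule):

object ConstantFoldingSketch {
  sealed trait Expr
  case class Lit(value: Int) extends Expr
  case class Attr(name: String) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr
  case class Mul(left: Expr, right: Expr) extends Expr

  def fold(e: Expr): Expr = e match {
    case Add(l, r) => (fold(l), fold(r)) match {
      case (Lit(a), Lit(b)) => Lit(a + b)                            // both sides constant: evaluate now
      case (fl, fr)         => Add(fl, fr)
    }
    case Mul(l, r) => (fold(l), fold(r)) match {
      case (Lit(a), Lit(b)) => Lit(a * b)
      case (fl, fr)         => Mul(fl, fr)
    }
    case other => other                                              // literals and attribute references are left alone
  }
}

// fold(Add(Add(Lit(2), Lit(3)), Attr("a"))) == Add(Lit(5), Attr("a")), like "2 + 3 + 'a" above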
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
new file mode 100644
index 0000000000..cd611b3fb3
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -0,0 +1,222 @@
+package org.apache.spark.sql
+package catalyst
+package optimizer
+
+import expressions._
+import plans.logical._
+import rules._
+import util._
+
+import dsl.plans._
+import dsl.expressions._
+
+class FilterPushdownSuite extends OptimizerTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Subqueries", Once,
+ EliminateSubqueries) ::
+ Batch("Filter Pushdown", Once,
+ EliminateSubqueries,
+ CombineFilters,
+ PushPredicateThroughProject,
+ PushPredicateThroughInnerJoin) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ // This test already passes.
+ test("eliminate subqueries") {
+ val originalQuery =
+ testRelation
+ .subquery('y)
+ .select('a)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a.attr)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+  // The optimizations exercised below this line are not implemented yet.
+ test("simple push down") {
+ val originalQuery =
+ testRelation
+ .select('a)
+ .where('a === 1)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .where('a === 1)
+ .select('a)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("can't push without rewrite") {
+ val originalQuery =
+ testRelation
+ .select('a + 'b as 'e)
+ .where('e === 1)
+ .analyze
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .where('a + 'b === 1)
+ .select('a + 'b as 'e)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("filters: combines filters") {
+ val originalQuery = testRelation
+ .select('a)
+ .where('a === 1)
+ .where('a === 2)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .where('a === 1 && 'a === 2)
+ .select('a).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push to either side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y)
+ .where("x.b".attr === 1)
+ .where("y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 1)
+ val right = testRelation.where('b === 2)
+ val correctAnswer =
+ left.join(right).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push to one side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y)
+ .where("x.b".attr === 1)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 1)
+ val right = testRelation
+ val correctAnswer =
+ left.join(right).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: rewrite filter to push to either side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y)
+ .where("x.b".attr === 1 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 1)
+ val right = testRelation.where('b === 2)
+ val correctAnswer =
+ left.join(right).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: can't push down") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, condition = Some("x.b".attr === "y.b".attr))
+ }
+ val optimized = Optimize(originalQuery.analyze)
+
+ comparePlans(optimizer.EliminateSubqueries(originalQuery.analyze), optimized)
+ }
+
+ test("joins: conjunctive predicates") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y)
+ .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("y.a".attr === 1))
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('a === 1).subquery('x)
+ val right = testRelation.where('a === 1).subquery('y)
+ val correctAnswer =
+ left.join(right, condition = Some("x.b".attr === "y.b".attr))
+ .analyze
+
+ comparePlans(optimized, optimizer.EliminateSubqueries(correctAnswer))
+ }
+
+ test("joins: conjunctive predicates #2") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y)
+ .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1))
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('a === 1).subquery('x)
+ val right = testRelation.subquery('y)
+ val correctAnswer =
+ left.join(right, condition = Some("x.b".attr === "y.b".attr))
+ .analyze
+
+ comparePlans(optimized, optimizer.EliminateSubqueries(correctAnswer))
+ }
+
+ test("joins: conjunctive predicates #3") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ val z = testRelation.subquery('z)
+
+ val originalQuery = {
+ z.join(x.join(y))
+ .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr))
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val lleft = testRelation.where('a >= 3).subquery('z)
+ val left = testRelation.where('a === 1).subquery('x)
+ val right = testRelation.subquery('y)
+ val correctAnswer =
+ lleft.join(
+ left.join(right, condition = Some("x.b".attr === "y.b".attr)),
+ condition = Some("z.a".attr === "x.b".attr))
+ .analyze
+
+ comparePlans(optimized, optimizer.EliminateSubqueries(correctAnswer))
+ }
+} \ No newline at end of file
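Every join test above follows one recipe: split the filter sitting on top of the join into conjuncts, push each conjunct to whichever side contains all of the columns it references, and keep whatever remains as the join condition. A standalone sketch of that split with columns as plain strings; it is illustrative only, not the PushPredicateThroughInnerJoin rule.

object PushdownSketch {
  case class Predicate(columns: Set[String], sql: String)

  /** Returns (push to the left side, push to the right side, keep as join condition). */
  def split(
      conjuncts: Seq[Predicate],
      leftColumns: Set[String],
      rightColumns: Set[String]): (Seq[Predicate], Seq[Predicate], Seq[Predicate]) = {
    val (leftOnly, rest)       = conjuncts.partition(_.columns.subsetOf(leftColumns))
    val (rightOnly, joinConds) = rest.partition(_.columns.subsetOf(rightColumns))
    (leftOnly, rightOnly, joinConds)
  }
}

// Mirrors "joins: conjunctive predicates #2" above:
//   split(
//     Seq(Predicate(Set("x.b", "y.b"), "x.b = y.b"), Predicate(Set("x.a"), "x.a = 1")),
//     leftColumns  = Set("x.a", "x.b", "x.c"),
//     rightColumns = Set("y.a", "y.b", "y.c"))
//   pushes "x.a = 1" to the left and keeps "x.b = y.b" as the join condition.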
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala
new file mode 100644
index 0000000000..7b3653d0f9
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala
@@ -0,0 +1,44 @@
+package org.apache.spark.sql
+package catalyst
+package optimizer
+
+import org.scalatest.FunSuite
+
+import types.IntegerType
+import util._
+import plans.logical.{LogicalPlan, LocalRelation}
+import expressions._
+import dsl._
+
+/* Implicit conversions for creating query plans */
+
+/**
+ * Provides helper methods for comparing plans produced by optimization rules with the expected
+ * result.
+ */
+class OptimizerTest extends FunSuite {
+
+ /**
+ * Since attribute references are given globally unique ids during analysis,
+ * we must normalize them to check if two different queries are identical.
+ */
+ protected def normalizeExprIds(plan: LogicalPlan) = {
+ val minId = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)).min
+ plan transformAllExpressions {
+ case a: AttributeReference =>
+ AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId))
+ }
+ }
+
+ /** Fails the test if the two plans do not match */
+ protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
+ val normalized1 = normalizeExprIds(plan1)
+ val normalized2 = normalizeExprIds(plan2)
+ if (normalized1 != normalized2)
+ fail(
+ s"""
+ |== FAIL: Plans do not match ===
+ |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
+ """.stripMargin)
+ }
+} \ No newline at end of file