Diffstat (limited to 'sql/catalyst')
62 files changed, 6538 insertions, 0 deletions
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml new file mode 100644 index 0000000000..740f1fdc83 --- /dev/null +++ b/sql/catalyst/pom.xml @@ -0,0 +1,66 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + ~ Licensed to the Apache Software Foundation (ASF) under one or more + ~ contributor license agreements. See the NOTICE file distributed with + ~ this work for additional information regarding copyright ownership. + ~ The ASF licenses this file to You under the Apache License, Version 2.0 + ~ (the "License"); you may not use this file except in compliance with + ~ the License. You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> + +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>org.apache.spark</groupId> + <artifactId>spark-parent</artifactId> + <version>1.0.0-SNAPSHOT</version> + <relativePath>../../pom.xml</relativePath> + </parent> + + <groupId>org.apache.spark</groupId> + <artifactId>spark-catalyst_2.10</artifactId> + <packaging>jar</packaging> + <name>Spark Project Catalyst</name> + <url>http://spark.apache.org/</url> + + <dependencies> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-core_${scala.binary.version}</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>com.typesafe</groupId> + <artifactId>scalalogging-slf4j_${scala.binary.version}</artifactId> + <version>1.0.1</version> + </dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.scalacheck</groupId> + <artifactId>scalacheck_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + </dependencies> + <build> + <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> + <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> + <plugins> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + </plugin> + </plugins> + </build> +</project> diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala new file mode 100644 index 0000000000..d3b1070a58 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import scala.util.matching.Regex +import scala.util.parsing.combinator._ +import scala.util.parsing.input.CharArrayReader.EofCh +import lexical._ +import syntactical._ +import token._ + +import analysis._ +import expressions._ +import plans._ +import plans.logical._ +import types._ + +/** + * A very simple SQL parser. Based loosely on: + * https://github.com/stephentu/scala-sql-parser/blob/master/src/main/scala/parser.scala + * + * Limitations: + * - Only supports a very limited subset of SQL. + * - Keywords must be upper case. + * + * This is currently included mostly for illustrative purposes. Users wanting more complete support + * for a SQL-like language should check out the HiveQL support in the sql/hive subproject. + */ +class SqlParser extends StandardTokenParsers { + + def apply(input: String): LogicalPlan = { + phrase(query)(new lexical.Scanner(input)) match { + case Success(r, x) => r + case x => sys.error(x.toString) + } + } + + protected case class Keyword(str: String) + protected implicit def asParser(k: Keyword): Parser[String] = k.str + + protected class SqlLexical extends StdLexical { + case class FloatLit(chars: String) extends Token { + override def toString = chars + } + override lazy val token: Parser[Token] = ( + identChar ~ rep( identChar | digit ) ^^ + { case first ~ rest => processIdent(first :: rest mkString "") } + | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ { + case i ~ None => NumericLit(i mkString "") + case i ~ Some(d) => FloatLit(i.mkString("") + "."
+ d.mkString("")) + } + | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ + { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } + | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ + { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '\"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar = letter | elem('.') | elem('_') + + override def whitespace: Parser[Any] = rep( + whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) + | '#' ~ rep( chrExcept(EofCh, '\n') ) + | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) + | '/' ~ '*' ~ failure("unclosed comment") + ) + } + + override val lexical = new SqlLexical + + protected val ALL = Keyword("ALL") + protected val AND = Keyword("AND") + protected val AS = Keyword("AS") + protected val ASC = Keyword("ASC") + protected val AVG = Keyword("AVG") + protected val BY = Keyword("BY") + protected val CAST = Keyword("CAST") + protected val COUNT = Keyword("COUNT") + protected val DESC = Keyword("DESC") + protected val DISTINCT = Keyword("DISTINCT") + protected val FALSE = Keyword("FALSE") + protected val FIRST = Keyword("FIRST") + protected val FROM = Keyword("FROM") + protected val FULL = Keyword("FULL") + protected val GROUP = Keyword("GROUP") + protected val HAVING = Keyword("HAVING") + protected val IF = Keyword("IF") + protected val IN = Keyword("IN") + protected val INNER = Keyword("INNER") + protected val IS = Keyword("IS") + protected val JOIN = Keyword("JOIN") + protected val LEFT = Keyword("LEFT") + protected val LIMIT = Keyword("LIMIT") + protected val NOT = Keyword("NOT") + protected val NULL = Keyword("NULL") + protected val ON = Keyword("ON") + protected val OR = Keyword("OR") + protected val ORDER = Keyword("ORDER") + protected val OUTER = Keyword("OUTER") + protected val RIGHT = Keyword("RIGHT") + protected val SELECT = Keyword("SELECT") + protected val STRING = Keyword("STRING") + protected val SUM = Keyword("SUM") + protected val TRUE = Keyword("TRUE") + protected val UNION = Keyword("UNION") + protected val WHERE = Keyword("WHERE") + + // Use reflection to find the reserved words defined in this class. 
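+ // (Each `protected val FOO = Keyword("FOO")` below compiles to a no-argument method whose return type is Keyword, so the reflective filter picks up every keyword without it being listed a second time.)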
+ protected val reservedWords = + this.getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword]) + + lexical.reserved ++= reservedWords.map(_.str) + + lexical.delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":" + ) + + protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { + exprs.zipWithIndex.map { + case (ne: NamedExpression, _) => ne + case (e, i) => Alias(e, s"c$i")() + } + } + + protected lazy val query: Parser[LogicalPlan] = + select * ( + UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | + UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } + ) + + protected lazy val select: Parser[LogicalPlan] = + SELECT ~> opt(DISTINCT) ~ projections ~ + opt(from) ~ opt(filter) ~ + opt(grouping) ~ + opt(having) ~ + opt(orderBy) ~ + opt(limit) <~ opt(";") ^^ { + case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => + val base = r.getOrElse(NoRelation) + val withFilter = f.map(f => Filter(f, base)).getOrElse(base) + val withProjection = + g.map {g => + Aggregate(assignAliases(g), assignAliases(p), withFilter) + }.getOrElse(Project(assignAliases(p), withFilter)) + val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) + val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct) + val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving) + val withLimit = l.map { l => StopAfter(l, withOrder) }.getOrElse(withOrder) + withLimit + } + + protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") + + protected lazy val projection: Parser[Expression] = + expression ~ (opt(AS) ~> opt(ident)) ^^ { + case e ~ None => e + case e ~ Some(a) => Alias(e, a)() + } + + protected lazy val from: Parser[LogicalPlan] = FROM ~> relations + + // Based very loosely on the MySQL grammar.
+ // http://dev.mysql.com/doc/refman/5.0/en/join.html + protected lazy val relations: Parser[LogicalPlan] = + relation ~ "," ~ relation ^^ { case r1 ~ _ ~ r2 => Join(r1, r2, Inner, None) } | + relation + + protected lazy val relation: Parser[LogicalPlan] = + joinedRelation | + relationFactor + + protected lazy val relationFactor: Parser[LogicalPlan] = + ident ~ (opt(AS) ~> opt(ident)) ^^ { + case ident ~ alias => UnresolvedRelation(alias, ident) + } | + "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) } + + protected lazy val joinedRelation: Parser[LogicalPlan] = + relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { + case r1 ~ jt ~ _ ~ r2 ~ cond => + Join(r1, r2, joinType = jt.getOrElse(Inner), cond) + } + + protected lazy val joinConditions: Parser[Expression] = + ON ~> expression + + protected lazy val joinType: Parser[JoinType] = + INNER ^^^ Inner | + LEFT ~ opt(OUTER) ^^^ LeftOuter | + RIGHT ~ opt(OUTER) ^^^ RightOuter | + FULL ~ opt(OUTER) ^^^ FullOuter + + protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e } + + protected lazy val orderBy: Parser[Seq[SortOrder]] = + ORDER ~> BY ~> ordering + + protected lazy val ordering: Parser[Seq[SortOrder]] = + rep1sep(singleOrder, ",") | + rep1sep(expression, ",") ~ opt(direction) ^^ { + case exps ~ None => exps.map(SortOrder(_, Ascending)) + case exps ~ Some(d) => exps.map(SortOrder(_, d)) + } + + protected lazy val singleOrder: Parser[SortOrder] = + expression ~ direction ^^ { case e ~ o => SortOrder(e,o) } + + protected lazy val direction: Parser[SortDirection] = + ASC ^^^ Ascending | + DESC ^^^ Descending + + protected lazy val grouping: Parser[Seq[Expression]] = + GROUP ~> BY ~> rep1sep(expression, ",") + + protected lazy val having: Parser[Expression] = + HAVING ~> expression + + protected lazy val limit: Parser[Expression] = + LIMIT ~> expression + + protected lazy val expression: Parser[Expression] = orExpression + + protected lazy val orExpression: Parser[Expression] = + andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1,e2) }) + + protected lazy val andExpression: Parser[Expression] = + comparisionExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) }) + + protected lazy val comparisionExpression: Parser[Expression] = + termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Equals(e1, e2) } | + termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } | + termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } | + termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } | + termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | + termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | + termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | + termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { + case e1 ~ _ ~ _ ~ e2 => In(e1, e2) + } | + termExpression ~ NOT ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { + case e1 ~ _ ~ _ ~ _ ~ e2 => Not(In(e1, e2)) + } | + termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } | + termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } | + NOT ~> termExpression ^^ {e => Not(e)} | + termExpression + + protected lazy val termExpression: Parser[Expression] = + productExpression * ( + "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1,e2) } | + 
"-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1,e2) } ) + + protected lazy val productExpression: Parser[Expression] = + baseExpression * ( + "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1,e2) } | + "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1,e2) } | + "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1,e2) } + ) + + protected lazy val function: Parser[Expression] = + SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } | + SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } | + COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } | + COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } | + COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } | + FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } | + AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } | + IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { + case c ~ "," ~ t ~ "," ~ f => If(c,t,f) + } | + ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { + case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) + } + + protected lazy val cast: Parser[Expression] = + CAST ~> "(" ~> expression ~ AS ~ dataType <~ ")" ^^ { case exp ~ _ ~ t => Cast(exp, t) } + + protected lazy val literal: Parser[Literal] = + numericLit ^^ { + case i if i.toLong > Int.MaxValue => Literal(i.toLong) + case i => Literal(i.toInt) + } | + NULL ^^^ Literal(null, NullType) | + floatLit ^^ {case f => Literal(f.toDouble) } | + stringLit ^^ {case s => Literal(s, StringType) } + + protected lazy val floatLit: Parser[String] = + elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) + + protected lazy val baseExpression: Parser[Expression] = + TRUE ^^^ Literal(true, BooleanType) | + FALSE ^^^ Literal(false, BooleanType) | + cast | + "(" ~> expression <~ ")" | + function | + "-" ~> literal ^^ UnaryMinus | + ident ^^ UnresolvedAttribute | + "*" ^^^ Star(None) | + literal + + protected lazy val dataType: Parser[DataType] = + STRING ^^^ StringType +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala new file mode 100644 index 0000000000..9eb992ee58 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package analysis + +import expressions._ +import plans.logical._ +import rules._ + +/** + * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. 
Used for testing + * when all relations are already filled in and the analyser needs only to resolve attribute + * references. + */ +object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true) + +/** + * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and + * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and + * a [[FunctionRegistry]]. + */ +class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean) + extends RuleExecutor[LogicalPlan] with HiveTypeCoercion { + + // TODO: pass this in as a parameter. + val fixedPoint = FixedPoint(100) + + val batches: Seq[Batch] = Seq( + Batch("MultiInstanceRelations", Once, + NewRelationInstances), + Batch("CaseInsensitiveAttributeReferences", Once, + (if (caseSensitive) Nil else LowercaseAttributeReferences :: Nil) : _*), + Batch("Resolution", fixedPoint, + ResolveReferences :: + ResolveRelations :: + NewRelationInstances :: + ImplicitGenerate :: + StarExpansion :: + ResolveFunctions :: + GlobalAggregates :: + typeCoercionRules :_*) + ) + + /** + * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. + */ + object ResolveRelations extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case UnresolvedRelation(databaseName, name, alias) => + catalog.lookupRelation(databaseName, name, alias) + } + } + + /** + * Makes attribute naming case insensitive by turning all UnresolvedAttributes to lowercase. + */ + object LowercaseAttributeReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case UnresolvedRelation(databaseName, name, alias) => + UnresolvedRelation(databaseName, name, alias.map(_.toLowerCase)) + case Subquery(alias, child) => Subquery(alias.toLowerCase, child) + case q: LogicalPlan => q transformExpressions { + case s: Star => s.copy(table = s.table.map(_.toLowerCase)) + case UnresolvedAttribute(name) => UnresolvedAttribute(name.toLowerCase) + case Alias(c, name) => Alias(c, name.toLowerCase)() + } + } + } + + /** + * Replaces [[UnresolvedAttribute]]s with concrete + * [[expressions.AttributeReference AttributeReferences]] from a logical plan node's children. + */ + object ResolveReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case q: LogicalPlan if q.childrenResolved => + logger.trace(s"Attempting to resolve ${q.simpleString}") + q transformExpressions { + case u @ UnresolvedAttribute(name) => + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val result = q.resolve(name).getOrElse(u) + logger.debug(s"Resolving $u to $result") + result + } + } + } + + /** + * Replaces [[UnresolvedFunction]]s with concrete [[expressions.Expression Expressions]]. + */ + object ResolveFunctions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => + q transformExpressions { + case u @ UnresolvedFunction(name, children) if u.childrenResolved => + registry.lookupFunction(name, children) + } + } + } + + /** + * Turns projections that contain aggregate expressions into aggregations. 
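+ * + * For example (schematically), an aggregate without a GROUP BY clause: + * {{{ + * SELECT COUNT(*) FROM t + * // Project(Seq(Count(1)), t) becomes Aggregate(Nil, Seq(Count(1)), t) + * }}}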
+ */ + object GlobalAggregates extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Project(projectList, child) if containsAggregates(projectList) => + Aggregate(Nil, projectList, child) + } + + def containsAggregates(exprs: Seq[Expression]): Boolean = { + exprs.foreach(_.foreach { + case agg: AggregateExpression => return true + case _ => + }) + false + } + } + + /** + * When a SELECT clause has only a single expression and that expression is a + * [[catalyst.expressions.Generator Generator]] we convert the + * [[catalyst.plans.logical.Project Project]] to a [[catalyst.plans.logical.Generate Generate]]. + */ + object ImplicitGenerate extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Project(Seq(Alias(g: Generator, _)), child) => + Generate(g, join = false, outer = false, None, child) + } + } + + /** + * Expands any references to [[Star]] (*) in project operators. + */ + object StarExpansion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Wait until children are resolved + case p: LogicalPlan if !p.childrenResolved => p + // If the projection list contains Stars, expand it. + case p @ Project(projectList, child) if containsStar(projectList) => + Project( + projectList.flatMap { + case s: Star => s.expand(child.output) + case o => o :: Nil + }, + child) + case t: ScriptTransformation if containsStar(t.input) => + t.copy( + input = t.input.flatMap { + case s: Star => s.expand(t.child.output) + case o => o :: Nil + } + ) + // If the aggregate function argument contains Stars, expand it. + case a: Aggregate if containsStar(a.aggregateExpressions) => + a.copy( + aggregateExpressions = a.aggregateExpressions.flatMap { + case s: Star => s.expand(a.child.output) + case o => o :: Nil + } + ) + } + + /** + * Returns true if `exprs` contains a [[Star]]. + */ + protected def containsStar(exprs: Seq[Expression]): Boolean = + exprs.collect { case _: Star => true }.nonEmpty + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala new file mode 100644 index 0000000000..71e4dcdb15 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package analysis + +import plans.logical.{LogicalPlan, Subquery} +import scala.collection.mutable + +/** + * An interface for looking up relations by name. Used by an [[Analyzer]]. 
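+ * + * A minimal sketch of the contract, assuming some already-constructed plan `plan`: + * {{{ + * catalog.registerTable(None, "people", plan) + * catalog.lookupRelation(None, "people", alias = Some("p")) // Subquery("p", plan) + * }}}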
+ */ +trait Catalog { + def lookupRelation( + databaseName: Option[String], + tableName: String, + alias: Option[String] = None): LogicalPlan + + def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit +} + +class SimpleCatalog extends Catalog { + val tables = new mutable.HashMap[String, LogicalPlan]() + + def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = { + tables += ((tableName, plan)) + } + + def dropTable(tableName: String) = tables -= tableName + + def lookupRelation( + databaseName: Option[String], + tableName: String, + alias: Option[String] = None): LogicalPlan = { + val table = tables.get(tableName).getOrElse(sys.error(s"Table Not Found: $tableName")) + + // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are + // properly qualified with this alias. + alias.map(a => Subquery(a.toLowerCase, table)).getOrElse(table) + } +} + +/** + * A trait that can be mixed in with other Catalogs, allowing specific tables to be overridden with + * new logical plans. This can be used to bind query results to virtual tables, or replace tables + * with in-memory cached versions. Note that the set of overrides is stored in memory and thus + * lost when the JVM exits. + */ +trait OverrideCatalog extends Catalog { + + // TODO: This doesn't work when the database changes... + val overrides = new mutable.HashMap[(Option[String], String), LogicalPlan]() + + abstract override def lookupRelation( + databaseName: Option[String], + tableName: String, + alias: Option[String] = None): LogicalPlan = { + + val overriddenTable = overrides.get((databaseName, tableName)) + + // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are + // properly qualified with this alias. + val withAlias = + overriddenTable.map(r => alias.map(a => Subquery(a.toLowerCase, r)).getOrElse(r)) + + withAlias.getOrElse(super.lookupRelation(databaseName, tableName, alias)) + } + + override def registerTable( + databaseName: Option[String], + tableName: String, + plan: LogicalPlan): Unit = { + overrides.put((databaseName, tableName), plan) + } +} + +/** + * A trivial catalog that returns an error when a relation is requested. Used for testing when all + * relations are already filled in and the analyser needs only to resolve attribute references. + */ +object EmptyCatalog extends Catalog { + def lookupRelation( + databaseName: Option[String], + tableName: String, + alias: Option[String] = None) = { + throw new UnsupportedOperationException + } + + def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = { + throw new UnsupportedOperationException + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala new file mode 100644 index 0000000000..a359eb5411 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package analysis + +import expressions._ + +/** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ +trait FunctionRegistry { + def lookupFunction(name: String, children: Seq[Expression]): Expression +} + +/** + * A trivial catalog that returns an error when a function is requested. Used for testing when all + * functions are already filled in and the analyser needs only to resolve attribute references. + */ +object EmptyFunctionRegistry extends FunctionRegistry { + def lookupFunction(name: String, children: Seq[Expression]): Expression = { + throw new UnsupportedOperationException + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala new file mode 100644 index 0000000000..a0105cd7cf --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package analysis + +import expressions._ +import plans.logical._ +import rules._ +import types._ + +/** + * A collection of [[catalyst.rules.Rule Rules]] that can be used to coerce differing types that + * participate in operations into compatible ones. Most of these rules are based on Hive semantics, + * but they do not introduce any dependencies on the hive codebase. For this reason they remain in + * Catalyst until we have a more standard set of coercions. + */ +trait HiveTypeCoercion { + + val typeCoercionRules = + List(PropagateTypes, ConvertNaNs, WidenTypes, PromoteStrings, BooleanComparisons, BooleanCasts, + StringToIntegralCasts, FunctionArgumentConversion) + + /** + * Applies any changes to [[catalyst.expressions.AttributeReference AttributeReference]] dataTypes + * that are made by other rules to instances higher in the query tree. + */ + object PropagateTypes extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // No propagation required for leaf nodes. + case q: LogicalPlan if q.children.isEmpty => q + + // Don't propagate types from unresolved children. 
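+ // (Their attribute dataTypes may still change once they resolve, so propagating now could bake in stale types.)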
+ case q: LogicalPlan if !q.childrenResolved => q + + case q: LogicalPlan => q transformExpressions { + case a: AttributeReference => + q.inputSet.find(_.exprId == a.exprId) match { + // This can happen when an Attribute reference is born in a non-leaf node, for example + // due to a call to an external script like in the Transform operator. + // TODO: Perhaps those should actually be aliases? + case None => a + // Leave unchanged if the dataTypes match. + case Some(newType) if a.dataType == newType.dataType => a + case Some(newType) => + logger.debug(s"Promoting $a to $newType in ${q.simpleString}") + newType + } + } + } + } + + /** + * Converts string "NaN"s in binary operators with NaN-able types (Float / Double) to + * the appropriate numeric equivalent. + */ + object ConvertNaNs extends Rule[LogicalPlan] { + val stringNaN = Literal("NaN", StringType) + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressions { + // Skip nodes whose children have not been resolved yet. + case e if !e.childrenResolved => e + + /* Double Conversions */ + case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType => + b.makeCopy(Array(b.right, Literal(Double.NaN))) + case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN => + b.makeCopy(Array(Literal(Double.NaN), b.left)) + case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => + b.makeCopy(Array(Literal(Double.NaN), b.left)) + + /* Float Conversions */ + case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType => + b.makeCopy(Array(b.right, Literal(Float.NaN))) + case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN => + b.makeCopy(Array(Literal(Float.NaN), b.left)) + case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => + b.makeCopy(Array(Literal(Float.NaN), b.left)) + } + } + } + + /** + * Widens numeric types and converts strings to numbers when appropriate. + * + * Loosely based on rules from "Hadoop: The Definitive Guide" 2nd edition, by Tom White + * + * The implicit conversion rules can be summarized as follows: + * - Any integral numeric type can be implicitly converted to a wider type. + * - All the integral numeric types, FLOAT, and (perhaps surprisingly) STRING can be implicitly + * converted to DOUBLE. + * - TINYINT, SMALLINT, and INT can all be converted to FLOAT. + * - BOOLEAN types cannot be converted to any other type. + * + * Additionally, all types when UNION-ed with strings will be promoted to strings. + * Other string conversions are handled by PromoteStrings. + */ + object WidenTypes extends Rule[LogicalPlan] { + // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. + // The conversions for integral and floating point types follow a linear widening hierarchy: + val numericPrecedence = + Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) + // Boolean is only wider than Void + val booleanPrecedence = Seq(NullType, BooleanType) + val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil + + def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { + // Try to find a promotion rule that contains both types in question.
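+ // For example, IntegerType and DoubleType both appear in numericPrecedence, and DoubleType is the later (wider) of the two, so it is the tightest common type.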
+ val applicableConversion = allPromotions.find(p => p.contains(t1) && p.contains(t2)) + + // If found, return the widest common type; otherwise None. + applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val castedInput = left.output.zip(right.output).map { + // When a string is found on one side, make the other side a string too. + case (l, r) if l.dataType == StringType && r.dataType != StringType => + (l, Alias(Cast(r, StringType), r.name)()) + case (l, r) if l.dataType != StringType && r.dataType == StringType => + (Alias(Cast(l, StringType), l.name)(), r) + + case (l, r) if l.dataType != r.dataType => + logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") + findTightestCommonType(l.dataType, r.dataType).map { widestType => + val newLeft = + if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() + val newRight = + if (r.dataType == widestType) r else Alias(Cast(r, widestType), r.name)() + + (newLeft, newRight) + }.getOrElse((l, r)) // If there is no applicable conversion, leave expression unchanged. + case other => other + } + + val (castedLeft, castedRight) = castedInput.unzip + + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + logger.debug(s"Widening numeric types in union $castedLeft ${left.output}") + Project(castedLeft, left) + } else { + left + } + + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + logger.debug(s"Widening numeric types in union $castedRight ${right.output}") + Project(castedRight, right) + } else { + right + } + + Union(newLeft, newRight) + + // Also widen types for BinaryExpressions. + case q: LogicalPlan => q transformExpressions { + // Skip nodes whose children have not been resolved yet. + case e if !e.childrenResolved => e + + case b: BinaryExpression if b.left.dataType != b.right.dataType => + findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType => + val newLeft = + if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) + val newRight = + if (b.right.dataType == widestType) b.right else Cast(b.right, widestType) + b.makeCopy(Array(newLeft, newRight)) + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + } + } + } + + /** + * Promotes strings that appear in arithmetic expressions. + */ + object PromoteStrings extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes whose children have not been resolved yet. + case e if !e.childrenResolved => e + + case a: BinaryArithmetic if a.left.dataType == StringType => + a.makeCopy(Array(Cast(a.left, DoubleType), a.right)) + case a: BinaryArithmetic if a.right.dataType == StringType => + a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) + + case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType => + p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) + case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType => + p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) + + case Sum(e) if e.dataType == StringType => + Sum(Cast(e, DoubleType)) + case Average(e) if e.dataType == StringType => + Average(Cast(e, DoubleType)) + } + } + + /** + * Changes Boolean values to Bytes so that expressions like true < false can be evaluated.
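+ * + * For example (schematically): + * {{{ + * Literal(true) < Literal(false) + * // becomes Cast(Literal(true), ByteType) < Cast(Literal(false), ByteType) + * }}}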
+ */ + object BooleanComparisons extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes whose children have not been resolved yet. + case e if !e.childrenResolved => e + // No need to change Equals operators as that actually makes sense for boolean types. + case e: Equals => e + // Otherwise, turn them into Byte types so that an ordering exists. + case p: BinaryComparison + if p.left.dataType == BooleanType && p.right.dataType == BooleanType => + p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType))) + } + } + + /** + * Casts to/from [[catalyst.types.BooleanType BooleanType]] are transformed into comparisons since + * the JVM does not consider Booleans to be numeric types. + */ + object BooleanCasts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes whose children have not been resolved yet. + case e if !e.childrenResolved => e + + case Cast(e, BooleanType) => Not(Equals(e, Literal(0))) + case Cast(e, dataType) if e.dataType == BooleanType => + Cast(If(e, Literal(1), Literal(0)), dataType) + } + } + + /** + * When encountering a cast from a string representing a valid fractional number to an integral + * type, the JVM will throw a `java.lang.NumberFormatException`. Hive, in contrast, returns the + * truncated version of this number. + */ + object StringToIntegralCasts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes whose children have not been resolved yet. + case e if !e.childrenResolved => e + + case Cast(e @ StringType(), t: IntegralType) => + Cast(Cast(e, DecimalType), t) + } + } + + /** + * This ensures that the types for various functions are as expected. + */ + object FunctionArgumentConversion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes whose children have not been resolved yet. + case e if !e.childrenResolved => e + + // Promote SUM to the largest type to prevent overflows. + case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. + case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) + case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala new file mode 100644 index 0000000000..fe18cc466f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst +package analysis + +import plans.logical.LogicalPlan +import rules._ + +/** + * A trait that should be mixed into query operators where a single instance might appear multiple + * times in a logical query plan. It is invalid to have multiple copies of the same attribute + * produced by distinct operators in a query tree as this breaks the guarantee that expression + * ids, which are used to differentiate attributes, are unique. + * + * Before analysis, each operator that includes this trait will be asked to produce a new version + * of itself with globally unique expression ids. + */ +trait MultiInstanceRelation { + def newInstance: this.type +} + +/** + * If any MultiInstanceRelation appears more than once in the query plan, then the plan is updated so + * that each instance has unique expression ids for the attributes produced. + */ +object NewRelationInstances extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val localRelations = plan collect { case l: MultiInstanceRelation => l } + val multiAppearance = localRelations + .groupBy(identity[MultiInstanceRelation]) + .filter { case (_, ls) => ls.size > 1 } + .map(_._1) + .toSet + + plan transform { + case l: MultiInstanceRelation if multiAppearance contains l => l.newInstance + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala new file mode 100644 index 0000000000..375c99f48e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package catalyst + +/** + * Provides a logical query plan [[Analyzer]] and supporting classes for performing analysis. + * Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s + * into fully typed objects using information in a schema [[Catalog]]. + */ +package object analysis diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala new file mode 100644 index 0000000000..2ed2af1352 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package analysis + +import expressions._ +import plans.logical.BaseRelation +import trees.TreeNode + +/** + * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully + * resolved. + */ +class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: String) extends + errors.TreeNodeException(tree, s"Invalid call to $function on unresolved object", null) + +/** + * Holds the name of a relation that has yet to be looked up in a [[Catalog]]. + */ +case class UnresolvedRelation( + databaseName: Option[String], + tableName: String, + alias: Option[String] = None) extends BaseRelation { + def output = Nil + override lazy val resolved = false +} + +/** + * Holds the name of an attribute that has yet to be resolved. + */ +case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { + def exprId = throw new UnresolvedException(this, "exprId") + def dataType = throw new UnresolvedException(this, "dataType") + def nullable = throw new UnresolvedException(this, "nullable") + def qualifiers = throw new UnresolvedException(this, "qualifiers") + override lazy val resolved = false + + def newInstance = this + def withQualifiers(newQualifiers: Seq[String]) = this + + override def toString: String = s"'$name" +} + +case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { + def exprId = throw new UnresolvedException(this, "exprId") + def dataType = throw new UnresolvedException(this, "dataType") + override def foldable = throw new UnresolvedException(this, "foldable") + def nullable = throw new UnresolvedException(this, "nullable") + def qualifiers = throw new UnresolvedException(this, "qualifiers") + def references = children.flatMap(_.references).toSet + override lazy val resolved = false + override def toString = s"'$name(${children.mkString(",")})" +} + +/** + * Represents all of the input attributes to a given relational operator, for example in + * "SELECT * FROM ...". + * + * @param table an optional table that should be the target of the expansion. If omitted all + * tables' columns are produced. 
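+ * + * For example (schematically): + * {{{ + * SELECT t.* FROM t JOIN u -- expands to the attributes qualified with "t" + * SELECT * FROM t JOIN u -- expands to every input attribute + * }}}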
+ */ +case class Star( + table: Option[String], + mapFunction: Attribute => Expression = identity[Attribute]) + extends Attribute with trees.LeafNode[Expression] { + + def name = throw new UnresolvedException(this, "name") + def exprId = throw new UnresolvedException(this, "exprId") + def dataType = throw new UnresolvedException(this, "dataType") + def nullable = throw new UnresolvedException(this, "nullable") + def qualifiers = throw new UnresolvedException(this, "qualifiers") + override lazy val resolved = false + + def newInstance = this + def withQualifiers(newQualifiers: Seq[String]) = this + + def expand(input: Seq[Attribute]): Seq[NamedExpression] = { + val expandedAttributes: Seq[Attribute] = table match { + // If there is no table specified, use all input attributes. + case None => input + // If there is a table, pick out attributes that are part of this table. + case Some(table) => input.filter(_.qualifiers contains table) + } + val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map { + case (n: NamedExpression, _) => n + case (e, originalAttribute) => + Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) + } + mappedAttributes + } + + override def toString = table.map(_ + ".").getOrElse("") + "*" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala new file mode 100644 index 0000000000..cd8de9d52f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import analysis.UnresolvedAttribute +import expressions._ +import plans._ +import plans.logical._ +import types._ + +/** + * Provides experimental support for generating catalyst schemas for Scala objects. + */ +object ScalaReflection { + import scala.reflect.runtime.universe._ + + /** Returns a Sequence of attributes for the given case class type. */ + def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { + case s: StructType => + s.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) + } + + /** Returns a catalyst DataType for the given Scala Type using reflection. */ + def schemaFor[T: TypeTag]: DataType = schemaFor(typeOf[T]) + + /** Returns a catalyst DataType for the given Scala Type using reflection.
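+ * For example (schematically), schemaFor(typeOf[Seq[Int]]) yields ArrayType(IntegerType), and a case class maps to a StructType over its constructor parameters.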
*/ + def schemaFor(tpe: `Type`): DataType = tpe match { + case t if t <:< typeOf[Product] => + val params = t.member("<init>": TermName).asMethod.paramss + StructType( + params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true))) + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + ArrayType(schemaFor(elementType)) + case t if t <:< typeOf[String] => StringType + case t if t <:< definitions.IntTpe => IntegerType + case t if t <:< definitions.LongTpe => LongType + case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.ShortTpe => ShortType + case t if t <:< definitions.ByteTpe => ByteType + } + + implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { + + /** + * Implicitly added to Sequences of case class objects. Returns a catalyst logical relation + * for the data in the sequence. + */ + def asRelation: LocalRelation = { + val output = attributesFor[A] + LocalRelation(output, data) + } + } +} + +/** + * A collection of implicit conversions that create a DSL for constructing catalyst data structures. + * + * {{{ + * scala> import catalyst.dsl._ + * + * // Standard operators are added to expressions. + * scala> Literal(1) + Literal(1) + * res1: catalyst.expressions.Add = (1 + 1) + * + * // There is a conversion from 'symbols to unresolved attributes. + * scala> 'a.attr + * res2: catalyst.analysis.UnresolvedAttribute = 'a + * + * // These unresolved attributes can be used to create more complicated expressions. + * scala> 'a === 'b + * res3: catalyst.expressions.Equals = ('a = 'b) + * + * // SQL verbs can be used to construct logical query plans. + * scala> TestRelation('key.int, 'value.string).where('key === 1).select('value).analyze + * res4: catalyst.plans.logical.LogicalPlan = + * Project {value#1} + * Filter (key#0 = 1) + * TestRelation {key#0,value#1} + * }}} + */ +package object dsl { + trait ImplicitOperators { + def expr: Expression + + def + (other: Expression) = Add(expr, other) + def - (other: Expression) = Subtract(expr, other) + def * (other: Expression) = Multiply(expr, other) + def / (other: Expression) = Divide(expr, other) + + def && (other: Expression) = And(expr, other) + def || (other: Expression) = Or(expr, other) + + def < (other: Expression) = LessThan(expr, other) + def <= (other: Expression) = LessThanOrEqual(expr, other) + def > (other: Expression) = GreaterThan(expr, other) + def >= (other: Expression) = GreaterThanOrEqual(expr, other) + def === (other: Expression) = Equals(expr, other) + def != (other: Expression) = Not(Equals(expr, other)) + + def asc = SortOrder(expr, Ascending) + def desc = SortOrder(expr, Descending) + + def as(s: Symbol) = Alias(expr, s.name)() + } + + trait ExpressionConversions { + implicit class DslExpression(e: Expression) extends ImplicitOperators { + def expr = e + } + + implicit def intToLiteral(i: Int) = Literal(i) + implicit def longToLiteral(l: Long) = Literal(l) + implicit def floatToLiteral(f: Float) = Literal(f) + implicit def doubleToLiteral(d: Double) = Literal(d) + implicit def stringToLiteral(s: String) = Literal(s) + + implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name) + + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } + implicit class DslString(val s: String) extends ImplicitAttribute + + abstract class ImplicitAttribute extends ImplicitOperators { + def s: String + def expr = attr + def attr = analysis.UnresolvedAttribute(s) + + /** Creates a
new typed attribute of type int. */ + def int = AttributeReference(s, IntegerType, nullable = false)() + + /** Creates a new typed attribute of type string. */ + def string = AttributeReference(s, StringType, nullable = false)() + } + + implicit class DslAttribute(a: AttributeReference) { + def notNull = a.withNullability(false) + def nullable = a.withNullability(true) + + // Protobuf terminology + def required = a.withNullability(false) + } + } + + + object expressions extends ExpressionConversions // scalastyle:ignore + + abstract class LogicalPlanFunctions { + def logicalPlan: LogicalPlan + + def select(exprs: NamedExpression*) = Project(exprs, logicalPlan) + + def where(condition: Expression) = Filter(condition, logicalPlan) + + def join( + otherPlan: LogicalPlan, + joinType: JoinType = Inner, + condition: Option[Expression] = None) = + Join(logicalPlan, otherPlan, joinType, condition) + + def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, logicalPlan) + + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = { + val aliasedExprs = aggregateExprs.map { + case ne: NamedExpression => ne + case e => Alias(e, e.toString)() + } + Aggregate(groupingExprs, aliasedExprs, logicalPlan) + } + + def subquery(alias: Symbol) = Subquery(alias.name, logicalPlan) + + def unionAll(otherPlan: LogicalPlan) = Union(logicalPlan, otherPlan) + + def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) = + Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) + + def sfilter(dynamicUdf: (DynamicRow) => Boolean) = + Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan) + + def sample( + fraction: Double, + withReplacement: Boolean = true, + seed: Int = (math.random * 1000).toInt) = + Sample(fraction, withReplacement, seed, logicalPlan) + + def generate( + generator: Generator, + join: Boolean = false, + outer: Boolean = false, + alias: Option[String] = None) = + Generate(generator, join, outer, alias, logicalPlan) + + def insertInto(tableName: String, overwrite: Boolean = false) = + InsertIntoTable( + analysis.UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite) + + def analyze = analysis.SimpleAnalyzer(logicalPlan) + } + + object plans { // scalastyle:ignore + implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions { + def writeToFile(path: String) = WriteToFile(path, logicalPlan) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala new file mode 100644 index 0000000000..c253587f67 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+
+import trees._
+
+/**
+ * Functions for attaching and retrieving trees that are associated with errors.
+ */
+package object errors {
+
+  class TreeNodeException[TreeType <: TreeNode[_]]
+    (tree: TreeType, msg: String, cause: Throwable) extends Exception(msg, cause) {
+
+    // Yes, this is the same as a default parameter, but... those don't seem to work with SBT
+    // external project dependencies for some reason.
+    def this(tree: TreeType, msg: String) = this(tree, msg, null)
+
+    override def getMessage: String = {
+      val treeString = tree.toString
+      s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$treeString"
+    }
+  }
+
+  /**
+   * Wraps any exceptions that are thrown while executing `f` in a
+   * [[catalyst.errors.TreeNodeException TreeNodeException]], attaching the provided `tree`.
+   */
+  def attachTree[TreeType <: TreeNode[_], A](tree: TreeType, msg: String = "")(f: => A): A = {
+    try f catch {
+      case e: Exception => throw new TreeNodeException(tree, msg, e)
+    }
+  }
+
+  /**
+   * Executes `f`, which is expected to throw a
+   * [[catalyst.errors.TreeNodeException TreeNodeException]]. The first tree encountered in
+   * the stack of exceptions of type `TreeType` is returned.
+   */
+  def getTree[TreeType <: TreeNode[_]](f: => Unit): TreeType = ??? // TODO: Implement
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
new file mode 100644
index 0000000000..3b6bac16ff
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import rules._
+import errors._
+
+import catalyst.plans.QueryPlan
+
+/**
+ * A bound reference points to a specific slot in the input tuple, allowing the actual value
+ * to be retrieved more efficiently. However, since operations like column pruning can change
+ * the layout of intermediate tuples, BindReferences should be run after all such transformations.
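+ *
+ * Editor's illustration (not part of the original commit; the attributes below are assumed):
+ * binding `b` against the input schema `[a, b]` resolves it to slot 1, so evaluation becomes a
+ * direct lookup into the input row:
+ * {{{
+ *   val a = AttributeReference("a", IntegerType)()
+ *   val b = AttributeReference("b", StringType)()
+ *   val bound = BindReferences.bindReference(b, Seq(a, b))  // BoundReference(1, b)
+ *   bound.apply(new GenericRow(Array(1, "hello")))          // returns "hello"
+ * }}}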
+ */ +case class BoundReference(ordinal: Int, baseReference: Attribute) + extends Attribute with trees.LeafNode[Expression] { + + type EvaluatedType = Any + + def nullable = baseReference.nullable + def dataType = baseReference.dataType + def exprId = baseReference.exprId + def qualifiers = baseReference.qualifiers + def name = baseReference.name + + def newInstance = BoundReference(ordinal, baseReference.newInstance) + def withQualifiers(newQualifiers: Seq[String]) = + BoundReference(ordinal, baseReference.withQualifiers(newQualifiers)) + + override def toString = s"$baseReference:$ordinal" + + override def apply(input: Row): Any = input(ordinal) +} + +class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { + import BindReferences._ + + def apply(plan: TreeNode): TreeNode = { + plan.transform { + case leafNode if leafNode.children.isEmpty => leafNode + case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e => + bindReference(e, unaryNode.children.head.output) + } + } + } +} + +object BindReferences extends Logging { + def bindReference(expression: Expression, input: Seq[Attribute]): Expression = { + expression.transform { case a: AttributeReference => + attachTree(a, "Binding attribute") { + val ordinal = input.indexWhere(_.exprId == a.exprId) + if (ordinal == -1) { + // TODO: This fallback is required because some operators (such as ScriptTransform) + // produce new attributes that can't be bound. Likely the right thing to do is remove + // this rule and require all operators to explicitly bind to the input schema that + // they specify. + logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") + a + } else { + BoundReference(ordinal, a) + } + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala new file mode 100644 index 0000000000..608656d3a9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import types._ + +/** Cast the child expression to the target data type. 
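+ *
+ * A hedged sketch of the behaviour (editor's addition): {{{ Cast(Literal("123"), IntegerType) }}}
+ * evaluates to the Int 123, while casting an unparsable string such as "abc" yields null rather
+ * than throwing, courtesy of `castOrNull` below.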
*/ +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"CAST($child, $dataType)" + + type EvaluatedType = Any + + lazy val castingFunction: Any => Any = (child.dataType, dataType) match { + case (BinaryType, StringType) => a: Any => new String(a.asInstanceOf[Array[Byte]]) + case (StringType, BinaryType) => a: Any => a.asInstanceOf[String].getBytes + case (_, StringType) => a: Any => a.toString + case (StringType, IntegerType) => a: Any => castOrNull(a, _.toInt) + case (StringType, DoubleType) => a: Any => castOrNull(a, _.toDouble) + case (StringType, FloatType) => a: Any => castOrNull(a, _.toFloat) + case (StringType, LongType) => a: Any => castOrNull(a, _.toLong) + case (StringType, ShortType) => a: Any => castOrNull(a, _.toShort) + case (StringType, ByteType) => a: Any => castOrNull(a, _.toByte) + case (StringType, DecimalType) => a: Any => castOrNull(a, BigDecimal(_)) + case (BooleanType, ByteType) => a: Any => a match { + case null => null + case true => 1.toByte + case false => 0.toByte + } + case (dt, IntegerType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a) + case (dt, DoubleType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a) + case (dt, FloatType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toFloat(a) + case (dt, LongType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toLong(a) + case (dt, ShortType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toShort + case (dt, ByteType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toByte + case (dt, DecimalType) => + a: Any => + BigDecimal(dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a)) + } + + @inline + protected def castOrNull[A](a: Any, f: String => A) = + try f(a.asInstanceOf[String]) catch { + case _: java.lang.NumberFormatException => null + } + + override def apply(input: Row): Any = { + val evaluated = child.apply(input) + if (evaluated == null) { + null + } else { + castingFunction(evaluated) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala new file mode 100644 index 0000000000..78aaaeebbd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import errors._
+import trees._
+import types._
+
+abstract class Expression extends TreeNode[Expression] {
+  self: Product =>
+
+  /** The narrowest possible type that is produced when this expression is evaluated. */
+  type EvaluatedType <: Any
+
+  def dataType: DataType
+
+  /**
+   * Returns true when an expression is a candidate for static evaluation before the query is
+   * executed.
+   *
+   * The following conditions are used to determine suitability for constant folding:
+   *  - A [[expressions.Coalesce Coalesce]] is foldable if all of its children are foldable
+   *  - A [[expressions.BinaryExpression BinaryExpression]] is foldable if both its left and
+   *    right children are foldable
+   *  - A [[expressions.Not Not]], [[expressions.IsNull IsNull]], or
+   *    [[expressions.IsNotNull IsNotNull]] is foldable if its child is foldable.
+   *  - A [[expressions.Literal Literal]] is foldable.
+   *  - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its
+   *    child is foldable.
+   */
+  // TODO: Support more foldable expressions, e.g. deterministic Hive UDFs.
+  def foldable: Boolean = false
+  def nullable: Boolean
+  def references: Set[Attribute]
+
+  /** Returns the result of evaluating this expression on a given input Row */
+  def apply(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
+  /**
+   * Returns `true` if this expression and all its children have been resolved to a specific
+   * schema, and `false` if it still contains any unresolved placeholders. Implementations of
+   * expressions should override this if the resolution of this type of expression involves more
+   * than just the resolution of its children.
+   */
+  lazy val resolved: Boolean = childrenResolved
+
+  /**
+   * Returns true if all the children of this expression have been resolved to a specific schema,
+   * and false if any of them still contains an unresolved placeholder.
+   */
+  def childrenResolved = !children.exists(!_.resolved)
+
+  /**
+   * A set of helper functions that return the correct descendant of [[scala.math.Numeric]] type
+   * and perform any casting necessary on the values produced by evaluating the child expressions.
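+   *
+   * Editor's note: the helpers are named by type class and arity: `n` for Numeric, `f` for
+   * Fractional, `i` for Integral, with the digit giving the number of child expressions
+   * evaluated. For example, `Add` in arithmetic.scala is implemented as
+   * {{{ n2(input, left, right, _.plus(_, _)) }}}: both children are evaluated and null-checked,
+   * then combined using the Numeric instance registered for their common type.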
+   */
+  @inline
+  def n1(e: Expression, i: Row, f: ((Numeric[Any], Any) => Any)): Any = {
+    val evalE = e.apply(i)
+    if (evalE == null) {
+      null
+    } else {
+      e.dataType match {
+        case n: NumericType =>
+          val castedFunction = f.asInstanceOf[(Numeric[n.JvmType], n.JvmType) => n.JvmType]
+          castedFunction(n.numeric, evalE.asInstanceOf[n.JvmType])
+        case other => sys.error(s"Type $other does not support numeric operations")
+      }
+    }
+  }
+
+  @inline
+  protected final def n2(
+      i: Row,
+      e1: Expression,
+      e2: Expression,
+      f: ((Numeric[Any], Any, Any) => Any)): Any = {
+
+    if (e1.dataType != e2.dataType) {
+      throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
+    }
+
+    val evalE1 = e1.apply(i)
+    if (evalE1 == null) {
+      null
+    } else {
+      val evalE2 = e2.apply(i)
+      if (evalE2 == null) {
+        null
+      } else {
+        e1.dataType match {
+          case n: NumericType =>
+            f.asInstanceOf[(Numeric[n.JvmType], n.JvmType, n.JvmType) => n.JvmType](
+              n.numeric, evalE1.asInstanceOf[n.JvmType], evalE2.asInstanceOf[n.JvmType])
+          case other => sys.error(s"Type $other does not support numeric operations")
+        }
+      }
+    }
+  }
+
+  @inline
+  protected final def f2(
+      i: Row,
+      e1: Expression,
+      e2: Expression,
+      f: ((Fractional[Any], Any, Any) => Any)): Any = {
+    if (e1.dataType != e2.dataType) {
+      throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
+    }
+
+    val evalE1 = e1.apply(i)
+    if (evalE1 == null) {
+      null
+    } else {
+      val evalE2 = e2.apply(i)
+      if (evalE2 == null) {
+        null
+      } else {
+        e1.dataType match {
+          case ft: FractionalType =>
+            f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType, ft.JvmType) => ft.JvmType](
+              ft.fractional, evalE1.asInstanceOf[ft.JvmType], evalE2.asInstanceOf[ft.JvmType])
+          case other => sys.error(s"Type $other does not support fractional operations")
+        }
+      }
+    }
+  }
+
+  @inline
+  protected final def i2(
+      i: Row,
+      e1: Expression,
+      e2: Expression,
+      f: ((Integral[Any], Any, Any) => Any)): Any = {
+    if (e1.dataType != e2.dataType) {
+      throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
+    }
+
+    val evalE1 = e1.apply(i)
+    if (evalE1 == null) {
+      null
+    } else {
+      val evalE2 = e2.apply(i)
+      if (evalE2 == null) {
+        null
+      } else {
+        e1.dataType match {
+          // Renamed from `i` to `it` to avoid shadowing the Row parameter `i`.
+          case it: IntegralType =>
+            f.asInstanceOf[(Integral[it.JvmType], it.JvmType, it.JvmType) => it.JvmType](
+              it.integral, evalE1.asInstanceOf[it.JvmType], evalE2.asInstanceOf[it.JvmType])
+          case other => sys.error(s"Type $other does not support numeric operations")
+        }
+      }
+    }
+  }
+}
+
+abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
+  self: Product =>
+
+  def symbol: String
+
+  override def foldable = left.foldable && right.foldable
+
+  def references = left.references ++ right.references
+
+  override def toString = s"($left $symbol $right)"
+}
+
+abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
+  self: Product =>
+}
+
+abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
+  self: Product =>
+
+  def references = child.references
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
new file mode 100644
index 0000000000..8c407d2fdd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+package expressions
+
+/**
+ * Converts a [[Row]] to another Row given a sequence of expressions that define each column of
+ * the new row. If the schema of the input row is specified, then the given expressions will be
+ * bound to that schema.
+ */
+class Projection(expressions: Seq[Expression]) extends (Row => Row) {
+  def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
+    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+
+  protected val exprArray = expressions.toArray
+  def apply(input: Row): Row = {
+    val outputArray = new Array[Any](exprArray.size)
+    var i = 0
+    while (i < exprArray.size) {
+      outputArray(i) = exprArray(i).apply(input)
+      i += 1
+    }
+    new GenericRow(outputArray)
+  }
+}
+
+/**
+ * Converts a [[Row]] to another Row given a sequence of expressions that define each column of
+ * the new row. If the schema of the input row is specified, then the given expressions will be
+ * bound to that schema.
+ *
+ * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
+ * each time an input row is processed. This significantly reduces the cost of calculating the
+ * projection, but means that it is not safe to retain a reference to the returned row across
+ * invocations.
+ */
+case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
+  def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
+    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+
+  private[this] val exprArray = expressions.toArray
+  private[this] val mutableRow = new GenericMutableRow(exprArray.size)
+  def currentValue: Row = mutableRow
+
+  def apply(input: Row): Row = {
+    var i = 0
+    while (i < exprArray.size) {
+      mutableRow(i) = exprArray(i).apply(input)
+      i += 1
+    }
+    mutableRow
+  }
+}
+
+/**
+ * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to
+ * be instantiated once per thread and reused.
+ */
+class JoinedRow extends Row {
+  private[this] var row1: Row = _
+  private[this] var row2: Row = _
+
+  /** Updates this JoinedRow to point at two new base rows. Returns itself.
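+   *
+   * A minimal usage sketch (editor's addition; `leftRow` and `rightRow` are assumed inputs):
+   * {{{
+   *   val joined = new JoinedRow
+   *   joined(leftRow, rightRow).getInt(0)  // reads ordinal 0 from leftRow
+   *   joined(leftRow, anotherRight)        // rebinds in place, no new allocation
+   * }}}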
*/ + def apply(r1: Row, r2: Row): Row = { + row1 = r1 + row2 = r2 + this + } + + def iterator = row1.iterator ++ row2.iterator + + def length = row1.length + row2.length + + def apply(i: Int) = + if (i < row1.size) row1(i) else row2(i - row1.size) + + def isNullAt(i: Int) = apply(i) == null + + def getInt(i: Int): Int = + if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + + def getLong(i: Int): Long = + if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + + def getDouble(i: Int): Double = + if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + + def getBoolean(i: Int): Boolean = + if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + + def getShort(i: Int): Short = + if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + + def getByte(i: Int): Byte = + if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + + def getFloat(i: Int): Float = + if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + + def getString(i: Int): String = + if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def copy() = { + val totalSize = row1.size + row2.size + val copiedValues = new Array[Any](totalSize) + var i = 0 + while(i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericRow(copiedValues) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala new file mode 100644 index 0000000000..a5d0ecf964 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import types.DoubleType + +case object Rand extends LeafExpression { + def dataType = DoubleType + def nullable = false + def references = Set.empty + override def toString = "RAND()" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala new file mode 100644 index 0000000000..3529675468 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import types._
+
+/**
+ * Represents one row of output from a relational operator. Allows both generic access by ordinal,
+ * which will incur boxing overhead for primitives, as well as native primitive access.
+ *
+ * It is invalid to use the native primitive interface to retrieve a value that is null; instead,
+ * a user must check [[isNullAt]] before attempting to retrieve a value that might be null.
+ */
+trait Row extends Seq[Any] with Serializable {
+  def apply(i: Int): Any
+
+  def isNullAt(i: Int): Boolean
+
+  def getInt(i: Int): Int
+  def getLong(i: Int): Long
+  def getDouble(i: Int): Double
+  def getFloat(i: Int): Float
+  def getBoolean(i: Int): Boolean
+  def getShort(i: Int): Short
+  def getByte(i: Int): Byte
+  def getString(i: Int): String
+
+  override def toString() =
+    s"[${this.mkString(",")}]"
+
+  def copy(): Row
+}
+
+/**
+ * An extended interface to [[Row]] that allows the values for each column to be updated. Setting
+ * a value through a primitive function implicitly marks that column as not null.
+ */
+trait MutableRow extends Row {
+  def setNullAt(i: Int): Unit
+
+  def update(ordinal: Int, value: Any)
+
+  def setInt(ordinal: Int, value: Int)
+  def setLong(ordinal: Int, value: Long)
+  def setDouble(ordinal: Int, value: Double)
+  def setBoolean(ordinal: Int, value: Boolean)
+  def setShort(ordinal: Int, value: Short)
+  def setByte(ordinal: Int, value: Byte)
+  def setFloat(ordinal: Int, value: Float)
+  def setString(ordinal: Int, value: String)
+
+  /**
+   * EXPERIMENTAL
+   *
+   * Returns a mutable string builder for the specified column. A given row should return the
+   * result of any mutations made to the returned buffer the next time getString is called for
+   * the same column.
+   */
+  def getStringBuilder(ordinal: Int): StringBuilder
+}
+
+/**
+ * A row with no data. Calling any methods will result in an error. Can be used as a placeholder.
+ */
+object EmptyRow extends Row {
+  def apply(i: Int): Any = throw new UnsupportedOperationException
+
+  def iterator = Iterator.empty
+  def length = 0
+  def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException
+
+  def getInt(i: Int): Int = throw new UnsupportedOperationException
+  def getLong(i: Int): Long = throw new UnsupportedOperationException
+  def getDouble(i: Int): Double = throw new UnsupportedOperationException
+  def getFloat(i: Int): Float = throw new UnsupportedOperationException
+  def getBoolean(i: Int): Boolean = throw new UnsupportedOperationException
+  def getShort(i: Int): Short = throw new UnsupportedOperationException
+  def getByte(i: Int): Byte = throw new UnsupportedOperationException
+  def getString(i: Int): String = throw new UnsupportedOperationException
+
+  def copy() = this
+}
+
+/**
+ * A row implementation that uses an array of objects as the underlying storage. Note that while
+ * the array is not copied, and thus could technically be mutated after creation, doing so is not
+ * allowed.
+ */
+class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
+  /** No-arg constructor for serialization. */
+  def this() = this(null)
+
+  def this(size: Int) = this(new Array[Any](size))
+
+  def iterator = values.iterator
+
+  def length = values.length
+
+  def apply(i: Int) = values(i)
+
+  def isNullAt(i: Int) = values(i) == null
+
+  def getInt(i: Int): Int = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive int value.")
+    values(i).asInstanceOf[Int]
+  }
+
+  def getLong(i: Int): Long = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive long value.")
+    values(i).asInstanceOf[Long]
+  }
+
+  def getDouble(i: Int): Double = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive double value.")
+    values(i).asInstanceOf[Double]
+  }
+
+  def getFloat(i: Int): Float = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive float value.")
+    values(i).asInstanceOf[Float]
+  }
+
+  def getBoolean(i: Int): Boolean = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.")
+    values(i).asInstanceOf[Boolean]
+  }
+
+  def getShort(i: Int): Short = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive short value.")
+    values(i).asInstanceOf[Short]
+  }
+
+  def getByte(i: Int): Byte = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.")
+    values(i).asInstanceOf[Byte]
+  }
+
+  def getString(i: Int): String = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive String value.")
+    values(i).asInstanceOf[String]
+  }
+
+  def copy() = this
+}
+
+class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
+  /** No-arg constructor for serialization. */
+  def this() = this(0)
+
+  def getStringBuilder(ordinal: Int): StringBuilder = ???
+
+  override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value }
+  override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value }
+  override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value }
+  override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
+  override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
+  override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
+  override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value }
+
+  override def setNullAt(i: Int): Unit = { values(i) = null }
+
+  override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
+
+  override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value }
+
+  override def copy() = new GenericRow(values.clone())
+}
+
+class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
+  def compare(a: Row, b: Row): Int = {
+    var i = 0
+    while (i < ordering.size) {
+      val order = ordering(i)
+      val left = order.child.apply(a)
+      val right = order.child.apply(b)
+
+      if (left == null && right == null) {
+        // Both null, continue looking.
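+        // Editor's note: two nulls tie on this column, so the loop falls through to the
+        // next SortOrder; a lone null sorts first under Ascending and last under Descending,
+        // as encoded by the branches below.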
+ } else if (left == null) { + return if (order.direction == Ascending) -1 else 1 + } else if (right == null) { + return if (order.direction == Ascending) 1 else -1 + } else { + val comparison = order.dataType match { + case n: NativeType if order.direction == Ascending => + n.ordering.asInstanceOf[Ordering[Any]].compare(left, right) + case n: NativeType if order.direction == Descending => + n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + } + if (comparison != 0) return comparison + } + i += 1 + } + return 0 + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala new file mode 100644 index 0000000000..a3c7ca1acd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import types._ + +case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) + extends Expression { + + type EvaluatedType = Any + + def references = children.flatMap(_.references).toSet + def nullable = true + + override def apply(input: Row): Any = { + children.size match { + case 1 => function.asInstanceOf[(Any) => Any](children(0).apply(input)) + case 2 => + function.asInstanceOf[(Any, Any) => Any]( + children(0).apply(input), + children(1).apply(input)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala new file mode 100644 index 0000000000..171997b90e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+sealed abstract class SortDirection
+case object Ascending extends SortDirection
+case object Descending extends SortDirection
+
+/**
+ * An expression that can be used to sort a tuple. This class extends Expression primarily so that
+ * transformations over expressions will descend into its child.
+ */
+case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
+  def dataType = child.dataType
+  def nullable = child.nullable
+  override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
new file mode 100644
index 0000000000..2ad8d6f31d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import scala.language.dynamics
+
+import types._
+
+case object DynamicType extends DataType
+
+case class WrapDynamic(children: Seq[Attribute]) extends Expression {
+  type EvaluatedType = DynamicRow
+
+  def nullable = false
+  def references = children.toSet
+  def dataType = DynamicType
+
+  override def apply(input: Row): DynamicRow = input match {
+    // Avoid copy for generic rows.
+    case g: GenericRow => new DynamicRow(children, g.values)
+    case otherRowType => new DynamicRow(children, otherRowType.toArray)
+  }
+}
+
+class DynamicRow(val schema: Seq[Attribute], values: Array[Any])
+  extends GenericRow(values) with Dynamic {
+
+  def selectDynamic(attributeName: String): String = {
+    val ordinal = schema.indexWhere(_.name == attributeName)
+    values(ordinal).toString
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
new file mode 100644
index 0000000000..2287a849e6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import catalyst.types._
+
+abstract class AggregateExpression extends Expression {
+  self: Product =>
+
+  /**
+   * Creates a new instance that can be used to compute this aggregate expression for a group
+   * of input rows.
+   */
+  def newInstance: AggregateFunction
+}
+
+/**
+ * Represents an aggregation that has been rewritten to be performed in two steps.
+ *
+ * @param finalEvaluation an aggregate expression that evaluates to the same final result as the
+ *                        original aggregation.
+ * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial
+ *                           data sets and are required to compute the `finalEvaluation`.
+ */
+case class SplitEvaluation(
+    finalEvaluation: Expression,
+    partialEvaluations: Seq[NamedExpression])
+
+/**
+ * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
+ * These partial evaluations can then be combined to compute the actual answer.
+ */
+abstract class PartialAggregate extends AggregateExpression {
+  self: Product =>
+
+  /**
+   * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation.
+   */
+  def asPartial: SplitEvaluation
+}
+
+/**
+ * A specific implementation of an aggregate function. Used to wrap a generic
+ * [[AggregateExpression]] with an algorithm that will be used to compute one specific result.
+ */
+abstract class AggregateFunction
+  extends AggregateExpression with Serializable with trees.LeafNode[Expression] {
+  self: Product =>
+
+  type EvaluatedType = Any
+
+  /** The generic aggregate expression that this function is computing. */
+  val base: AggregateExpression
+  def references = base.references
+  def nullable = base.nullable
+  def dataType = base.dataType
+
+  def update(input: Row): Unit
+  override def apply(input: Row): Any
+
+  // Do we really need this?
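+  // Editor's note: makeCopy re-invokes the case class constructor with the same arguments,
+  // so this default hands back a fresh function instance (and thus fresh mutable aggregation
+  // state) for each group, mirroring the newInstance overrides on the expressions below.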
+  def newInstance = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
+}
+
+case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+  def references = child.references
+  def nullable = false
+  def dataType = IntegerType
+  override def toString = s"COUNT($child)"
+
+  def asPartial: SplitEvaluation = {
+    val partialCount = Alias(Count(child), "PartialCount")()
+    SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
+  }
+
+  override def newInstance = new CountFunction(child, this)
+}
+
+case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
+  def children = expressions
+  def references = expressions.flatMap(_.references).toSet
+  def nullable = false
+  def dataType = IntegerType
+  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
+  override def newInstance = new CountDistinctFunction(expressions, this)
+}
+
+case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+  def references = child.references
+  def nullable = false
+  def dataType = DoubleType
+  override def toString = s"AVG($child)"
+
+  override def asPartial: SplitEvaluation = {
+    val partialSum = Alias(Sum(child), "PartialSum")()
+    val partialCount = Alias(Count(child), "PartialCount")()
+    val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
+    val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
+
+    SplitEvaluation(
+      Divide(castedSum, castedCount),
+      partialCount :: partialSum :: Nil)
+  }
+
+  override def newInstance = new AverageFunction(child, this)
+}
+
+case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+  def references = child.references
+  def nullable = false
+  def dataType = child.dataType
+  override def toString = s"SUM($child)"
+
+  override def asPartial: SplitEvaluation = {
+    val partialSum = Alias(Sum(child), "PartialSum")()
+    SplitEvaluation(
+      Sum(partialSum.toAttribute),
+      partialSum :: Nil)
+  }
+
+  override def newInstance = new SumFunction(child, this)
+}
+
+case class SumDistinct(child: Expression)
+  extends AggregateExpression with trees.UnaryNode[Expression] {
+
+  def references = child.references
+  def nullable = false
+  def dataType = child.dataType
+  override def toString = s"SUM(DISTINCT $child)"
+
+  override def newInstance = new SumDistinctFunction(child, this)
+}
+
+case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+  def references = child.references
+  def nullable = child.nullable
+  def dataType = child.dataType
+  override def toString = s"FIRST($child)"
+
+  override def asPartial: SplitEvaluation = {
+    val partialFirst = Alias(First(child), "PartialFirst")()
+    SplitEvaluation(
+      First(partialFirst.toAttribute),
+      partialFirst :: Nil)
+  }
+  override def newInstance = new FirstFunction(child, this)
+}
+
+case class AverageFunction(expr: Expression, base: AggregateExpression)
+  extends AggregateFunction {
+
+  def this() = this(null, null) // Required for serialization.
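+
+  // Editor's sketch of the running state (assuming inputs 1, 2, 3): `count` tracks the rows
+  // seen while `sum` is folded forward through the reusable `addFunction` expression below,
+  // so apply() returns sumAsDouble / count = 6.0 / 3 = 2.0.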
+
+  private var count: Long = _
+  private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(EmptyRow))
+  private val sumAsDouble = Cast(sum, DoubleType)
+
+  private val addFunction = Add(sum, expr)
+
+  override def apply(input: Row): Any =
+    sumAsDouble.apply(EmptyRow).asInstanceOf[Double] / count.toDouble
+
+  def update(input: Row): Unit = {
+    count += 1
+    sum.update(addFunction, input)
+  }
+}
+
+case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+  def this() = this(null, null) // Required for serialization.
+
+  var count: Int = _
+
+  def update(input: Row): Unit = {
+    val evaluatedExpr = expr.map(_.apply(input))
+    if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
+      count += 1
+    }
+  }
+
+  override def apply(input: Row): Any = count
+}
+
+case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+  def this() = this(null, null) // Required for serialization.
+
+  private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(null))
+
+  private val addFunction = Add(sum, expr)
+
+  def update(input: Row): Unit = {
+    sum.update(addFunction, input)
+  }
+
+  override def apply(input: Row): Any = sum.apply(null)
+}
+
+case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
+  extends AggregateFunction {
+
+  def this() = this(null, null) // Required for serialization.
+
+  val seen = new scala.collection.mutable.HashSet[Any]()
+
+  def update(input: Row): Unit = {
+    val evaluatedExpr = expr.apply(input)
+    if (evaluatedExpr != null) {
+      seen += evaluatedExpr
+    }
+  }
+
+  override def apply(input: Row): Any =
+    seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
+}
+
+case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpression)
+  extends AggregateFunction {
+
+  def this() = this(null, null) // Required for serialization.
+
+  val seen = new scala.collection.mutable.HashSet[Any]()
+
+  def update(input: Row): Unit = {
+    val evaluatedExpr = expr.map(_.apply(input))
+    if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) {
+      seen += evaluatedExpr
+    }
+  }
+
+  override def apply(input: Row): Any = seen.size
+}
+
+case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+  def this() = this(null, null) // Required for serialization.
+
+  var result: Any = null
+
+  def update(input: Row): Unit = {
+    if (result == null) {
+      result = expr.apply(input)
+    }
+  }
+
+  override def apply(input: Row): Any = result
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
new file mode 100644
index 0000000000..db235645cd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import catalyst.analysis.UnresolvedException
+import catalyst.types._
+
+case class UnaryMinus(child: Expression) extends UnaryExpression {
+  type EvaluatedType = Any
+
+  def dataType = child.dataType
+  override def foldable = child.foldable
+  def nullable = child.nullable
+  override def toString = s"-$child"
+
+  override def apply(input: Row): Any = {
+    n1(child, input, _.negate(_))
+  }
+}
+
+abstract class BinaryArithmetic extends BinaryExpression {
+  self: Product =>
+
+  type EvaluatedType = Any
+
+  def nullable = left.nullable || right.nullable
+
+  override lazy val resolved =
+    left.resolved && right.resolved && left.dataType == right.dataType
+
+  def dataType = {
+    if (!resolved) {
+      throw new UnresolvedException(this,
+        s"datatype. Cannot resolve due to differing types ${left.dataType}, ${right.dataType}")
+    }
+    left.dataType
+  }
+}
+
+case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
+  def symbol = "+"
+
+  override def apply(input: Row): Any = n2(input, left, right, _.plus(_, _))
+}
+
+case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
+  def symbol = "-"
+
+  override def apply(input: Row): Any = n2(input, left, right, _.minus(_, _))
+}
+
+case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
+  def symbol = "*"
+
+  override def apply(input: Row): Any = n2(input, left, right, _.times(_, _))
+}
+
+case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
+  def symbol = "/"
+
+  override def apply(input: Row): Any = dataType match {
+    case _: FractionalType => f2(input, left, right, _.div(_, _))
+    case _: IntegralType => i2(input, left, right, _.quot(_, _))
+  }
+}
+
+case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+  def symbol = "%"
+
+  override def apply(input: Row): Any = i2(input, left, right, _.rem(_, _))
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
new file mode 100644
index 0000000000..d3feb6c461
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import types._
+
+/**
+ * Returns the item at `ordinal` in the array `child`, or the value for the key `ordinal` in the
+ * map `child`.
+ */
+case class GetItem(child: Expression, ordinal: Expression) extends Expression {
+  type EvaluatedType = Any
+
+  val children = child :: ordinal :: Nil
+  /** `Null` is returned for invalid ordinals. */
+  override def nullable = true
+  override def references = children.flatMap(_.references).toSet
+  def dataType = child.dataType match {
+    case ArrayType(dt) => dt
+    case MapType(_, vt) => vt
+  }
+  override lazy val resolved =
+    childrenResolved &&
+    (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
+
+  override def toString = s"$child[$ordinal]"
+
+  override def apply(input: Row): Any = {
+    if (child.dataType.isInstanceOf[ArrayType]) {
+      val baseValue = child.apply(input).asInstanceOf[Seq[_]]
+      val o = ordinal.apply(input).asInstanceOf[Int]
+      if (baseValue == null) {
+        null
+      } else if (o >= baseValue.size || o < 0) {
+        null
+      } else {
+        baseValue(o)
+      }
+    } else {
+      val baseValue = child.apply(input).asInstanceOf[Map[Any, _]]
+      val key = ordinal.apply(input)
+      if (baseValue == null) {
+        null
+      } else {
+        baseValue.get(key).orNull
+      }
+    }
+  }
+}
+
+/**
+ * Returns the value of the field `fieldName` in the struct `child`.
+ */
+case class GetField(child: Expression, fieldName: String) extends UnaryExpression {
+  type EvaluatedType = Any
+
+  def dataType = field.dataType
+  def nullable = field.nullable
+
+  protected def structType = child.dataType match {
+    case s: StructType => s
+    case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
+  }
+
+  lazy val field =
+    structType.fields
+      .find(_.name == fieldName)
+      .getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}"))
+
+  lazy val ordinal = structType.fields.indexOf(field)
+
+  override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType]
+
+  override def apply(input: Row): Any = {
+    val baseValue = child.apply(input).asInstanceOf[Row]
+    if (baseValue == null) null else baseValue(ordinal)
+  }
+
+  override def toString = s"$child.$fieldName"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
new file mode 100644
index 0000000000..c367de2a3e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import catalyst.types._
+
+/**
+ * An expression that produces zero or more rows given a single input row.
+ *
+ * Generators produce multiple output rows instead of a single value like other expressions,
+ * and thus they must have a schema to associate with the rows that are output.
+ *
+ * However, unlike row producing relational operators, which are either leaves or determine their
+ * output schema functionally from their input, generators can contain other expressions that
+ * might result in their modification by rules. This structure means that they might be copied
+ * multiple times after first determining their output schema. If a new output schema were
+ * created for each copy, references up the tree might be rendered invalid. As a result,
+ * generators must instead define a function `makeOutput` which is called only once when the
+ * schema is first requested. The attributes produced by this function will be automatically
+ * copied any time rules result in changes to the Generator or its children.
+ */
+abstract class Generator extends Expression with (Row => TraversableOnce[Row]) {
+  self: Product =>
+
+  type EvaluatedType = TraversableOnce[Row]
+
+  lazy val dataType =
+    ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable))))
+
+  def nullable = false
+
+  def references = children.flatMap(_.references).toSet
+
+  /**
+   * Should be overridden by specific generators. Called only once for each instance to ensure
+   * that rule application does not change the output schema of a generator.
+   */
+  protected def makeOutput(): Seq[Attribute]
+
+  private var _output: Seq[Attribute] = null
+
+  def output: Seq[Attribute] = {
+    if (_output == null) {
+      _output = makeOutput()
+    }
+    _output
+  }
+
+  /** Should be implemented by child classes to perform the specific generation. */
+  def apply(input: Row): TraversableOnce[Row]
+
+  /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
+  override def makeCopy(newArgs: Array[AnyRef]): this.type = {
+    val copy = super.makeCopy(newArgs)
+    copy._output = _output
+    copy
+  }
+}
+
+/**
+ * Given an input array produces a sequence of rows for each value in the array.
+ */
+case class Explode(attributeNames: Seq[String], child: Expression)
+  extends Generator with trees.UnaryNode[Expression] {
+
+  override lazy val resolved =
+    child.resolved &&
+    (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
+
+  lazy val elementTypes = child.dataType match {
+    case ArrayType(et) => et :: Nil
+    case MapType(kt, vt) => kt :: vt :: Nil
+  }
+
+  // TODO: Move this pattern into Generator.
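+  // Editor's note: if the user supplies exactly as many names as there are element types they
+  // are used as-is; otherwise synthetic names c_0, c_1, ... are generated. Exploding a MapType,
+  // for instance, produces two attributes (key and value), named c_0 and c_1 by default.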
+ protected def makeOutput() = + if (attributeNames.size == elementTypes.size) { + attributeNames.zip(elementTypes).map { + case (n, t) => AttributeReference(n, t, nullable = true)() + } + } else { + elementTypes.zipWithIndex.map { + case (t, i) => AttributeReference(s"c_$i", t, nullable = true)() + } + } + + override def apply(input: Row): TraversableOnce[Row] = { + child.dataType match { + case ArrayType(_) => + val inputArray = child.apply(input).asInstanceOf[Seq[Any]] + if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) + case MapType(_, _) => + val inputMap = child.apply(input).asInstanceOf[Map[Any,Any]] + if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) } + } + } + + override def toString() = s"explode($child)" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala new file mode 100644 index 0000000000..229d8f7f7b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import types._ + +object Literal { + def apply(v: Any): Literal = v match { + case i: Int => Literal(i, IntegerType) + case l: Long => Literal(l, LongType) + case d: Double => Literal(d, DoubleType) + case f: Float => Literal(f, FloatType) + case b: Byte => Literal(b, ByteType) + case s: Short => Literal(s, ShortType) + case s: String => Literal(s, StringType) + case b: Boolean => Literal(b, BooleanType) + case null => Literal(null, NullType) + } +} + +/** + * Extractor for retrieving Int literals. 
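+ *
+ * For example (editor's sketch of a hypothetical rewrite), an optimizer rule could match:
+ * {{{
+ *   case Add(IntegerLiteral(0), right) => right   // 0 + x  ==>  x
+ * }}}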
+ */
+object IntegerLiteral {
+  def unapply(a: Any): Option[Int] = a match {
+    case Literal(a: Int, IntegerType) => Some(a)
+    case _ => None
+  }
+}
+
+case class Literal(value: Any, dataType: DataType) extends LeafExpression {
+
+  override def foldable = true
+  def nullable = value == null
+  def references = Set.empty
+
+  override def toString = if (value != null) value.toString else "null"
+
+  type EvaluatedType = Any
+  override def apply(input: Row): Any = value
+}
+
+// TODO: Specialize
+case class MutableLiteral(var value: Any, nullable: Boolean = true) extends LeafExpression {
+  type EvaluatedType = Any
+
+  val dataType = Literal(value).dataType
+
+  def references = Set.empty
+
+  def update(expression: Expression, input: Row) = {
+    value = expression.apply(input)
+  }
+
+  override def apply(input: Row) = value
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
new file mode 100644
index 0000000000..0a06e85325
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package expressions
+
+import catalyst.analysis.UnresolvedAttribute
+import types._
+
+object NamedExpression {
+  private val curId = new java.util.concurrent.atomic.AtomicLong()
+  def newExprId = ExprId(curId.getAndIncrement())
+}
+
+/**
+ * A globally unique (within this JVM) id for a given named expression.
+ * Used to identify which attribute output by a relation is being
+ * referenced in a subsequent computation.
+ */
+case class ExprId(id: Long)
+
+abstract class NamedExpression extends Expression {
+  self: Product =>
+
+  def name: String
+  def exprId: ExprId
+  def qualifiers: Seq[String]
+
+  def toAttribute: Attribute
+
+  protected def typeSuffix =
+    if (resolved) {
+      dataType match {
+        case LongType => "L"
+        case _ => ""
+      }
+    } else {
+      ""
+    }
+}
+
+abstract class Attribute extends NamedExpression {
+  self: Product =>
+
+  def withQualifiers(newQualifiers: Seq[String]): Attribute
+
+  def references = Set(this)
+  def toAttribute = this
+  def newInstance: Attribute
+}
+
+/**
+ * Used to assign a new name to a computation.
+ * For example the SQL expression "1 + 1 AS a" could be represented as follows:
+ *  Alias(Add(Literal(1), Literal(1)), "a")()
+ *
+ * @param child the computation being performed
+ * @param name the name to be associated with the result of computing [[child]].
+ * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this
+ *               alias. Auto-assigned if left blank.
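+ *
+ * Editor's note: `exprId` and `qualifiers` live in a second parameter list, which keeps them
+ * out of the case class's generated equals/copy; they are re-threaded explicitly through
+ * `otherCopyArgs` when trees are transformed.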
+ */
+case class Alias(child: Expression, name: String)
+ (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
+ extends NamedExpression with trees.UnaryNode[Expression] {
+
+ type EvaluatedType = Any
+
+ override def apply(input: Row) = child.apply(input)
+
+ def dataType = child.dataType
+ def nullable = child.nullable
+ def references = child.references
+
+ def toAttribute = {
+ if (resolved) {
+ AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers)
+ } else {
+ UnresolvedAttribute(name)
+ }
+ }
+
+ override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"
+
+ override protected final def otherCopyArgs = exprId :: qualifiers :: Nil
+}
+
+/**
+ * A reference to an attribute produced by another operator in the tree.
+ *
+ * @param name The name of this attribute, should only be used during analysis or for debugging.
+ * @param dataType The [[types.DataType DataType]] of this attribute.
+ * @param nullable True if null is a valid value for this attribute.
+ * @param exprId A globally unique id used to check if different AttributeReferences refer to the
+ * same attribute.
+ * @param qualifiers a list of strings that can be used to refer to this attribute in a fully
+ * qualified way. Consider the examples tableName.name, subQueryAlias.name.
+ * tableName and subQueryAlias are possible qualifiers.
+ */
+case class AttributeReference(name: String, dataType: DataType, nullable: Boolean = true)
+ (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
+ extends Attribute with trees.LeafNode[Expression] {
+
+ override def equals(other: Any) = other match {
+ case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
+ case _ => false
+ }
+
+ override def hashCode: Int = {
+ // See http://stackoverflow.com/questions/113511/hash-code-implementation
+ var h = 17
+ h = h * 37 + exprId.hashCode()
+ h = h * 37 + dataType.hashCode()
+ h
+ }
+
+ def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
+
+ /**
+ * Returns a copy of this [[AttributeReference]] with changed nullability.
+ */
+ def withNullability(newNullability: Boolean) = {
+ if (nullable == newNullability) {
+ this
+ } else {
+ AttributeReference(name, dataType, newNullability)(exprId, qualifiers)
+ }
+ }
+
+ /**
+ * Returns a copy of this [[AttributeReference]] with new qualifiers.
+ */
+ def withQualifiers(newQualifiers: Seq[String]) = {
+ if (newQualifiers == qualifiers) {
+ this
+ } else {
+ AttributeReference(name, dataType, nullable)(exprId, newQualifiers)
+ }
+ }
+
+ override def toString: String = s"$name#${exprId.id}$typeSuffix"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
new file mode 100644
index 0000000000..e869a4d9b0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import catalyst.analysis.UnresolvedException + +case class Coalesce(children: Seq[Expression]) extends Expression { + type EvaluatedType = Any + + /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ + def nullable = !children.exists(!_.nullable) + + def references = children.flatMap(_.references).toSet + // Coalesce is foldable if all children are foldable. + override def foldable = !children.exists(!_.foldable) + + // Only resolved if all the children are of the same type. + override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1) + + override def toString = s"Coalesce(${children.mkString(",")})" + + def dataType = if (resolved) { + children.head.dataType + } else { + throw new UnresolvedException(this, "Coalesce cannot have children of different types.") + } + + override def apply(input: Row): Any = { + var i = 0 + var result: Any = null + while(i < children.size && result == null) { + result = children(i).apply(input) + i += 1 + } + result + } +} + +case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { + def references = child.references + override def foldable = child.foldable + def nullable = false + + override def apply(input: Row): Any = { + child.apply(input) == null + } +} + +case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { + def references = child.references + override def foldable = child.foldable + def nullable = false + override def toString = s"IS NOT NULL $child" + + override def apply(input: Row): Any = { + child.apply(input) != null + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala new file mode 100644 index 0000000000..76554e160b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +/** + * A set of classes that can be used to represent trees of relational expressions. 
A key goal of
+ * the expression library is to hide the details of naming and scoping from developers who want to
+ * manipulate trees of relational operators. As such, the library defines a special type of
+ * expression, a [[NamedExpression]], in addition to the standard collection of expressions.
+ *
+ * ==Standard Expressions==
+ * A library of standard expressions (e.g., [[Add]], [[Equals]]), aggregates (e.g., SUM, COUNT),
+ * and other computations (e.g., UDFs). Each expression type is capable of determining its output
+ * schema as a function of its children's output schema.
+ *
+ * ==Named Expressions==
+ * Some expressions are named and thus can be referenced by later operators in the dataflow graph.
+ * The two types of named expressions are [[AttributeReference]]s and [[Alias]]es.
+ * [[AttributeReference]]s refer to attributes of the input tuple for a given operator and form
+ * the leaves of some expression trees. Aliases assign a name to intermediate computations.
+ * For example, in the SQL statement `SELECT a+b AS c FROM ...`, the expressions `a` and `b` would
+ * be represented by `AttributeReferences` and `c` would be represented by an `Alias`.
+ *
+ * During [[analysis]], all named expressions are assigned a globally unique expression id, which
+ * can be used for equality comparisons. While the original names are kept around for debugging
+ * purposes, they should never be used to check if two attributes refer to the same value, as
+ * plan transformations can result in the introduction of naming ambiguity. For example, consider
+ * a plan that contains subqueries, both of which are reading from the same table. If an
+ * optimization removes the subqueries, scoping information would be destroyed, eliminating the
+ * ability to reason about which subquery produced a given attribute.
+ *
+ * ==Evaluation==
+ * The result of expressions can be evaluated using the [[Evaluate]] object.
+ */
+package object expressions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
new file mode 100644
index 0000000000..561396eb43
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql +package catalyst +package expressions + +import types._ +import catalyst.analysis.UnresolvedException + +trait Predicate extends Expression { + self: Product => + + def dataType = BooleanType + + type EvaluatedType = Any +} + +trait PredicateHelper { + def splitConjunctivePredicates(condition: Expression): Seq[Expression] = condition match { + case And(cond1, cond2) => splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2) + case other => other :: Nil + } +} + +abstract class BinaryPredicate extends BinaryExpression with Predicate { + self: Product => + def nullable = left.nullable || right.nullable +} + +case class Not(child: Expression) extends Predicate with trees.UnaryNode[Expression] { + def references = child.references + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"NOT $child" + + override def apply(input: Row): Any = { + child.apply(input) match { + case null => null + case b: Boolean => !b + } + } +} + +/** + * Evaluates to `true` if `list` contains `value`. + */ +case class In(value: Expression, list: Seq[Expression]) extends Predicate { + def children = value +: list + def references = children.flatMap(_.references).toSet + def nullable = true // TODO: Figure out correct nullability semantics of IN. + override def toString = s"$value IN ${list.mkString("(", ",", ")")}" + + override def apply(input: Row): Any = { + val evaluatedValue = value.apply(input) + list.exists(e => e.apply(input) == evaluatedValue) + } +} + +case class And(left: Expression, right: Expression) extends BinaryPredicate { + def symbol = "&&" + + override def apply(input: Row): Any = { + val l = left.apply(input) + val r = right.apply(input) + if (l == false || r == false) { + false + } else if (l == null || r == null ) { + null + } else { + true + } + } +} + +case class Or(left: Expression, right: Expression) extends BinaryPredicate { + def symbol = "||" + + override def apply(input: Row): Any = { + val l = left.apply(input) + val r = right.apply(input) + if (l == true || r == true) { + true + } else if (l == null || r == null) { + null + } else { + false + } + } +} + +abstract class BinaryComparison extends BinaryPredicate { + self: Product => +} + +case class Equals(left: Expression, right: Expression) extends BinaryComparison { + def symbol = "=" + override def apply(input: Row): Any = { + val l = left.apply(input) + val r = right.apply(input) + if (l == null || r == null) null else l == r + } +} + +case class LessThan(left: Expression, right: Expression) extends BinaryComparison { + def symbol = "<" + override def apply(input: Row): Any = { + if (left.dataType == StringType && right.dataType == StringType) { + val l = left.apply(input) + val r = right.apply(input) + if(l == null || r == null) { + null + } else { + l.asInstanceOf[String] < r.asInstanceOf[String] + } + } else { + n2(input, left, right, _.lt(_, _)) + } + } +} + +case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { + def symbol = "<=" + override def apply(input: Row): Any = { + if (left.dataType == StringType && right.dataType == StringType) { + val l = left.apply(input) + val r = right.apply(input) + if(l == null || r == null) { + null + } else { + l.asInstanceOf[String] <= r.asInstanceOf[String] + } + } else { + n2(input, left, right, _.lteq(_, _)) + } + } +} + +case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { + def symbol = ">" + override def apply(input: 
Row): Any = { + if (left.dataType == StringType && right.dataType == StringType) { + val l = left.apply(input) + val r = right.apply(input) + if(l == null || r == null) { + null + } else { + l.asInstanceOf[String] > r.asInstanceOf[String] + } + } else { + n2(input, left, right, _.gt(_, _)) + } + } +} + +case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { + def symbol = ">=" + override def apply(input: Row): Any = { + if (left.dataType == StringType && right.dataType == StringType) { + val l = left.apply(input) + val r = right.apply(input) + if(l == null || r == null) { + null + } else { + l.asInstanceOf[String] >= r.asInstanceOf[String] + } + } else { + n2(input, left, right, _.gteq(_, _)) + } + } +} + +case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) + extends Expression { + + def children = predicate :: trueValue :: falseValue :: Nil + def nullable = trueValue.nullable || falseValue.nullable + def references = children.flatMap(_.references).toSet + override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType + def dataType = { + if (!resolved) { + throw new UnresolvedException( + this, + s"Can not resolve due to differing types ${trueValue.dataType}, ${falseValue.dataType}") + } + trueValue.dataType + } + + type EvaluatedType = Any + override def apply(input: Row): Any = { + if (predicate(input).asInstanceOf[Boolean]) { + trueValue.apply(input) + } else { + falseValue.apply(input) + } + } + + override def toString = s"if ($predicate) $trueValue else $falseValue" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala new file mode 100644 index 0000000000..6e585236b1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import catalyst.types.BooleanType + +case class Like(left: Expression, right: Expression) extends BinaryExpression { + def dataType = BooleanType + def nullable = left.nullable // Right cannot be null. 
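+ // Note: evaluation of LIKE is not implemented in this file yet. A SQL LIKE pattern is
+ // conventionally compiled to a regular expression, roughly '%' -> ".*" and '_' -> ".",
+ // with all other characters matched literally.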
+ def symbol = "LIKE"
+}
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
new file mode 100644
index 0000000000..4db2803173
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package optimizer
+
+import catalyst.expressions._
+import catalyst.plans.logical._
+import catalyst.rules._
+import catalyst.types.BooleanType
+import catalyst.plans.Inner
+
+object Optimizer extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Subqueries", Once,
+ EliminateSubqueries) ::
+ Batch("ConstantFolding", Once,
+ ConstantFolding,
+ BooleanSimplification,
+ SimplifyCasts) ::
+ Batch("Filter Pushdown", Once,
+ EliminateSubqueries,
+ CombineFilters,
+ PushPredicateThroughProject,
+ PushPredicateThroughInnerJoin) :: Nil
+}
+
+/**
+ * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are
+ * only required to provide scoping information for attributes and can be removed once analysis is
+ * complete.
+ */
+object EliminateSubqueries extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Subquery(_, child) => child
+ }
+}
+
+/**
+ * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
+ * equivalent [[catalyst.expressions.Literal Literal]] values.
+ */
+object ConstantFolding extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsDown {
+ // Skip redundant folding of literals.
+ case l: Literal => l
+ case e if e.foldable => Literal(e.apply(null), e.dataType)
+ }
+ }
+}
+
+/**
+ * Simplifies boolean expressions where the answer can be determined without evaluating both sides.
+ * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus
+ * is only safe when expression evaluation does not result in side effects.
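+ *
+ * For example, `And(Literal(false), e)` simplifies to `Literal(false)` without evaluating `e`,
+ * and `Or(Literal(true), e)` simplifies to `Literal(true)`.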
+ */
+object BooleanSimplification extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ case and @ And(left, right) => {
+ (left, right) match {
+ case (Literal(true, BooleanType), r) => r
+ case (l, Literal(true, BooleanType)) => l
+ case (Literal(false, BooleanType), _) => Literal(false)
+ case (_, Literal(false, BooleanType)) => Literal(false)
+ case (_, _) => and
+ }
+ }
+ case or @ Or(left, right) => {
+ (left, right) match {
+ case (Literal(true, BooleanType), _) => Literal(true)
+ case (_, Literal(true, BooleanType)) => Literal(true)
+ case (Literal(false, BooleanType), r) => r
+ case (l, Literal(false, BooleanType)) => l
+ case (_, _) => or
+ }
+ }
+ }
+ }
+}
+
+/**
+ * Combines two adjacent [[catalyst.plans.logical.Filter Filter]] operators into one, merging the
+ * conditions into one conjunctive predicate.
+ */
+object CombineFilters extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case ff@Filter(fc, nf@Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild)
+ }
+}
+
+/**
+ * Pushes [[catalyst.plans.logical.Filter Filter]] operators through
+ * [[catalyst.plans.logical.Project Project]] operators, in-lining any
+ * [[catalyst.expressions.Alias Aliases]] that were defined in the projection.
+ *
+ * This heuristic is valid assuming the expression evaluation cost is minimal.
+ */
+object PushPredicateThroughProject extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case filter@Filter(condition, project@Project(fields, grandChild)) =>
+ val sourceAliases = fields.collect { case a@Alias(c, _) => a.toAttribute -> c }.toMap
+ project.copy(child = filter.copy(
+ replaceAlias(condition, sourceAliases),
+ grandChild))
+ }
+
+ // Substitutes attributes that refer to aliases with the original aliased expressions.
+ def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = {
+ condition transform {
+ case a: AttributeReference => sourceAliases.getOrElse(a, a)
+ }
+ }
+}
+
+/**
+ * Pushes down [[catalyst.plans.logical.Filter Filter]] operators where the `condition` can be
+ * evaluated using only the attributes of the left or right side of an inner join. Other
+ * [[catalyst.plans.logical.Filter Filter]] conditions are moved into the `condition` of the
+ * [[catalyst.plans.logical.Join Join]].
+ */
+object PushPredicateThroughInnerJoin extends Rule[LogicalPlan] with PredicateHelper {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case f @ Filter(filterCondition, Join(left, right, Inner, joinCondition)) =>
+ val allConditions =
+ splitConjunctivePredicates(filterCondition) ++
+ joinCondition.map(splitConjunctivePredicates).getOrElse(Nil)
+
+ // Split the predicates into those that can be evaluated on the left, right, and those that
+ // must be evaluated after the join.
+ val (rightConditions, leftOrJoinConditions) =
+ allConditions.partition(_.references subsetOf right.outputSet)
+ val (leftConditions, joinConditions) =
+ leftOrJoinConditions.partition(_.references subsetOf left.outputSet)
+
+ // Build the new left and right side, optionally with the pushed down filters.
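+ // For example, filtering Join(l, r, Inner, None) by l.a = 1 && r.b = 2 && l.c = r.d
+ // splits into leftConditions = [l.a = 1], rightConditions = [r.b = 2] and
+ // joinConditions = [l.c = r.d], producing
+ // Join(Filter(l.a = 1, l), Filter(r.b = 2, r), Inner, Some(l.c = r.d)).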
+ val newLeft = leftConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+ val newRight = rightConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+ Join(newLeft, newRight, Inner, joinConditions.reduceLeftOption(And))
+ }
+}
+
+/**
+ * Removes [[catalyst.expressions.Cast Casts]] that are unnecessary because the input is already
+ * the correct type.
+ */
+object SimplifyCasts extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case Cast(e, dataType) if e.dataType == dataType => e
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
new file mode 100644
index 0000000000..22f8ea005b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package planning
+
+
+import plans.logical.LogicalPlan
+import trees._
+
+/**
+ * Abstract class for transforming [[plans.logical.LogicalPlan LogicalPlan]]s into physical plans.
+ * Child classes are responsible for specifying a list of [[Strategy]] objects, each of which
+ * can return a list of possible physical plan options. If a given strategy is unable to plan all
+ * of the remaining operators in the tree, it can call [[planLater]], which returns a placeholder
+ * object that will be filled in using other available strategies.
+ *
+ * TODO: RIGHT NOW ONLY ONE PLAN IS RETURNED EVER...
+ * PLAN SPACE EXPLORATION WILL BE IMPLEMENTED LATER.
+ *
+ * @tparam PhysicalPlan The type of physical plan produced by this [[QueryPlanner]]
+ */
+abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {
+ /** A list of execution strategies that can be used by the planner */
+ def strategies: Seq[Strategy]
+
+ /**
+ * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can
+ * be used for execution. If this strategy does not apply to the given logical operation then an
+ * empty list should be returned.
+ */
+ abstract protected class Strategy extends Logging {
+ def apply(plan: LogicalPlan): Seq[PhysicalPlan]
+ }
+
+ /**
+ * Returns a placeholder for a physical plan that executes `plan`. This placeholder will be
+ * filled in automatically by the QueryPlanner using the other execution strategies that are
+ * available.
+ */
+ protected def planLater(plan: LogicalPlan) = apply(plan).next()
+
+ def apply(plan: LogicalPlan): Iterator[PhysicalPlan] = {
+ // Obviously a lot to do here still...
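+ // The lazy `view` ensures that strategies after the first one to produce a plan are only
+ // invoked if the returned iterator is advanced further.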
+ val iter = strategies.view.flatMap(_(plan)).toIterator + assert(iter.hasNext, s"No plan for $plan") + iter + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/package.scala new file mode 100644 index 0000000000..64370ec7c0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +/** + * Contains classes for enumerating possible physical plans for a given logical query plan. + */ +package object planning diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala new file mode 100644 index 0000000000..613b028ca8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package planning + +import scala.annotation.tailrec + +import expressions._ +import plans.logical._ + +/** + * A pattern that matches any number of filter operations on top of another relational operator. + * Adjacent filter operators are collected and their conditions are broken up and returned as a + * sequence of conjunctive predicates. + * + * @return A tuple containing a sequence of conjunctive predicates that should be used to filter the + * output and a relational operator. 
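+ * For example, `Filter(a && b, Filter(c, r))` matches as `(Seq(a, b, c), r)`.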
+ */ +object FilteredOperation extends PredicateHelper { + type ReturnType = (Seq[Expression], LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = Some(collectFilters(Nil, plan)) + + @tailrec + private def collectFilters(filters: Seq[Expression], plan: LogicalPlan): ReturnType = plan match { + case Filter(condition, child) => + collectFilters(filters ++ splitConjunctivePredicates(condition), child) + case other => (filters, other) + } +} + +/** + * A pattern that matches any number of project or filter operations on top of another relational + * operator. All filter operators are collected and their conditions are broken up and returned + * together with the top project operator. [[Alias Aliases]] are in-lined/substituted if necessary. + */ +object PhysicalOperation extends PredicateHelper { + type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = { + val (fields, filters, child, _) = collectProjectsAndFilters(plan) + Some((fields.getOrElse(child.output), filters, child)) + } + + /** + * Collects projects and filters, in-lining/substituting aliases if necessary. Here are two + * examples for alias in-lining/substitution. Before: + * {{{ + * SELECT c1 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 + * SELECT c1 AS c2 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 + * }}} + * After: + * {{{ + * SELECT key AS c1 FROM t1 WHERE key > 10 + * SELECT key AS c2 FROM t1 WHERE key > 10 + * }}} + */ + def collectProjectsAndFilters(plan: LogicalPlan): + (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) = + plan match { + case Project(fields, child) => + val (_, filters, other, aliases) = collectProjectsAndFilters(child) + val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] + (Some(substitutedFields), filters, other, collectAliases(substitutedFields)) + + case Filter(condition, child) => + val (fields, filters, other, aliases) = collectProjectsAndFilters(child) + val substitutedCondition = substitute(aliases)(condition) + (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) + + case other => + (None, Nil, other, Map.empty) + } + + def collectAliases(fields: Seq[Expression]) = fields.collect { + case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child + }.toMap + + def substitute(aliases: Map[Attribute, Expression])(expr: Expression) = expr.transform { + case a @ Alias(ref: AttributeReference, name) => + aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) + + case a: AttributeReference => + aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + } +} + +/** + * A pattern that collects all adjacent unions and returns their children as a Seq. 
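+ * For example, `Union(Union(a, b), Union(c, d))` matches as `Seq(a, b, c, d)`.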
+ */ +object Unions { + def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match { + case u: Union => Some(collectUnionChildren(u)) + case _ => None + } + + private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match { + case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r) + case other => other :: Nil + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala new file mode 100644 index 0000000000..20f230c5c4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package plans + +import catalyst.expressions.{SortOrder, Attribute, Expression} +import catalyst.trees._ + +abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { + self: PlanType with Product => + + def output: Seq[Attribute] + + /** + * Returns the set of attributes that are output by this node. + */ + def outputSet: Set[Attribute] = output.toSet + + /** + * Runs [[transform]] with `rule` on all expressions present in this query operator. + * Users should not expect a specific directionality. If a specific directionality is needed, + * transformExpressionsDown or transformExpressionsUp should be used. + * @param rule the rule to be applied to every expression in this operator. + */ + def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsDown(rule) + } + + /** + * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * @param rule the rule to be applied to every expression in this operator. + */ + def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { + var changed = false + + @inline def transformExpressionDown(e: Expression) = { + val newE = e.transformDown(rule) + if (newE.id != e.id && newE != e) { + changed = true + newE + } else { + e + } + } + + val newArgs = productIterator.map { + case e: Expression => transformExpressionDown(e) + case Some(e: Expression) => Some(transformExpressionDown(e)) + case m: Map[_,_] => m + case seq: Traversable[_] => seq.map { + case e: Expression => transformExpressionDown(e) + case other => other + } + case other: AnyRef => other + }.toArray + + if (changed) makeCopy(newArgs) else this + } + + /** + * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * @param rule the rule to be applied to every expression in this operator. 
+ * @return + */ + def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { + var changed = false + + @inline def transformExpressionUp(e: Expression) = { + val newE = e.transformUp(rule) + if (newE.id != e.id && newE != e) { + changed = true + newE + } else { + e + } + } + + val newArgs = productIterator.map { + case e: Expression => transformExpressionUp(e) + case Some(e: Expression) => Some(transformExpressionUp(e)) + case m: Map[_,_] => m + case seq: Traversable[_] => seq.map { + case e: Expression => transformExpressionUp(e) + case other => other + } + case other: AnyRef => other + }.toArray + + if (changed) makeCopy(newArgs) else this + } + + /** Returns the result of running [[transformExpressions]] on this node + * and all its children. */ + def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transform { + case q: QueryPlan[_] => q.transformExpressions(rule).asInstanceOf[PlanType] + }.asInstanceOf[this.type] + } + + /** Returns all of the expressions present in this query plan operator. */ + def expressions: Seq[Expression] = { + productIterator.flatMap { + case e: Expression => e :: Nil + case Some(e: Expression) => e :: Nil + case seq: Traversable[_] => seq.flatMap { + case e: Expression => e :: Nil + case other => Nil + } + case other => Nil + }.toSeq + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala new file mode 100644 index 0000000000..9f2283ad43 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package plans + +sealed abstract class JoinType +case object Inner extends JoinType +case object LeftOuter extends JoinType +case object RightOuter extends JoinType +case object FullOuter extends JoinType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala new file mode 100644 index 0000000000..48ff45c3d3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package logical
+
+abstract class BaseRelation extends LeafNode {
+ self: Product =>
+
+ def tableName: String
+ def isPartitioned: Boolean = false
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
new file mode 100644
index 0000000000..bc7b6871df
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package logical
+
+import catalyst.expressions._
+import catalyst.errors._
+import catalyst.types.StructType
+
+abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
+ self: Product =>
+
+ /**
+ * Returns the set of attributes that are referenced by this node
+ * during evaluation.
+ */
+ def references: Set[Attribute]
+
+ /**
+ * Returns the set of attributes that this node takes as
+ * input from its children.
+ */
+ lazy val inputSet: Set[Attribute] = children.flatMap(_.output).toSet
+
+ /**
+ * Returns true if this expression and all its children have been resolved to a specific schema
+ * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan
+ * can override this (e.g. [[catalyst.analysis.UnresolvedRelation UnresolvedRelation]] should
+ * return `false`).
+ */
+ lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved
+
+ /**
+ * Returns true if all the children of this query plan have been resolved.
+ */
+ def childrenResolved = !children.exists(!_.resolved)
+
+ /**
+ * Optionally resolves the given string to a
+ * [[catalyst.expressions.NamedExpression NamedExpression]]. The attribute is expressed as
+ * a string in the following form: `[scope].AttributeName.[nested].[fields]...`.
+ */
+ def resolve(name: String): Option[NamedExpression] = {
+ val parts = name.split("\\.")
+ // Collect all attributes that are output by this node's children where either the first part
+ // matches the name or where the first part matches the scope and the second part matches the
Return these matches along with any remaining parts, which represent dotted access to + // struct fields. + val options = children.flatMap(_.output).flatMap { option => + // If the first part of the desired name matches a qualifier for this possible match, drop it. + val remainingParts = if (option.qualifiers contains parts.head) parts.drop(1) else parts + if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil + } + + options.distinct match { + case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it. + // One match, but we also need to extract the requested nested field. + case (a, nestedFields) :: Nil => + a.dataType match { + case StructType(fields) => + Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) + case _ => None // Don't know how to resolve these field references + } + case Nil => None // No matches. + case ambiguousReferences => + throw new TreeNodeException( + this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") + } + } +} + +/** + * A logical plan node with no children. + */ +abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { + self: Product => + + // Leaf nodes by definition cannot reference any input attributes. + def references = Set.empty +} + +/** + * A logical node that represents a non-query command to be executed by the system. For example, + * commands can be used by parsers to represent DDL operations. + */ +abstract class Command extends LeafNode { + self: Product => + def output = Seq.empty +} + +/** + * Returned for commands supported by a given parser, but not catalyst. In general these are DDL + * commands that are passed directly to another system. + */ +case class NativeCommand(cmd: String) extends Command + +/** + * Returned by a parser when the users only wants to see what query plan would be executed, without + * actually performing the execution. + */ +case class ExplainCommand(plan: LogicalPlan) extends Command + +/** + * A logical plan node with single child. + */ +abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] { + self: Product => +} + +/** + * A logical plan node with a left and right child. + */ +abstract class BinaryNode extends LogicalPlan with trees.BinaryNode[LogicalPlan] { + self: Product => +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala new file mode 100644 index 0000000000..1a1a2b9b88 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package logical
+
+import expressions._
+
+/**
+ * Transforms the input by forking and running the specified script.
+ *
+ * @param input the set of expressions that should be passed to the script.
+ * @param script the command that should be executed.
+ * @param output the attributes that are produced by the script.
+ */
+case class ScriptTransformation(
+ input: Seq[Expression],
+ script: String,
+ output: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode {
+ def references = input.flatMap(_.references).toSet
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
new file mode 100644
index 0000000000..b5905a4456
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package logical
+
+import expressions._
+import rules._
+
+object LocalRelation {
+ def apply(output: Attribute*) =
+ new LocalRelation(output)
+}
+
+case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil)
+ extends LeafNode with analysis.MultiInstanceRelation {
+
+ // TODO: Validate schema compliance.
+ def loadData(newData: Seq[Product]) = new LocalRelation(output, data ++ newData)
+
+ /**
+ * Returns an identical copy of this relation with new exprIds for all attributes. Different
+ * attributes are required when a relation is going to be included multiple times in the same
+ * query.
+ */
+ override final def newInstance: this.type = {
+ LocalRelation(output.map(_.newInstance), data).asInstanceOf[this.type]
+ }
+
+ override protected def stringArgs = Iterator(output)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
new file mode 100644
index 0000000000..8e98aab736
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package plans +package logical + +import expressions._ + +case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { + def output = projectList.map(_.toAttribute) + def references = projectList.flatMap(_.references).toSet +} + +/** + * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the + * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional + * programming with one important additional feature, which allows the input rows to be joined with + * their output. + * @param join when true, each output row is implicitly joined with the input tuple that produced + * it. + * @param outer when true, each input row will be output at least once, even if the output of the + * given `generator` is empty. `outer` has no effect when `join` is false. + * @param alias when set, this string is applied to the schema of the output of the transformation + * as a qualifier. + */ +case class Generate( + generator: Generator, + join: Boolean, + outer: Boolean, + alias: Option[String], + child: LogicalPlan) + extends UnaryNode { + + protected def generatorOutput = + alias + .map(a => generator.output.map(_.withQualifiers(a :: Nil))) + .getOrElse(generator.output) + + def output = + if (join) child.output ++ generatorOutput else generatorOutput + + def references = + if (join) child.outputSet else generator.references +} + +case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { + def output = child.output + def references = condition.references +} + +case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + // TODO: These aren't really the same attributes as nullability etc might change. + def output = left.output + + override lazy val resolved = + childrenResolved && + !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } + + def references = Set.empty +} + +case class Join( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + + def references = condition.map(_.references).getOrElse(Set.empty) + def output = left.output ++ right.output +} + +case class InsertIntoTable( + table: BaseRelation, + partition: Map[String, Option[String]], + child: LogicalPlan, + overwrite: Boolean) + extends LogicalPlan { + // The table being inserted into is a child for the purposes of transformations. 
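+ // Hence this node extends LogicalPlan directly and lists both the table and the child,
+ // so that rules (e.g. subquery elimination) are also applied to the target relation.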
+ def children = table :: child :: Nil
+ def references = Set.empty
+ def output = child.output
+
+ override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
+ case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType
+ }
+}
+
+case class InsertIntoCreatedTable(
+ databaseName: Option[String],
+ tableName: String,
+ child: LogicalPlan) extends UnaryNode {
+ def references = Set.empty
+ def output = child.output
+}
+
+case class WriteToFile(
+ path: String,
+ child: LogicalPlan) extends UnaryNode {
+ def references = Set.empty
+ def output = child.output
+}
+
+case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode {
+ def output = child.output
+ def references = order.flatMap(_.references).toSet
+}
+
+case class Aggregate(
+ groupingExpressions: Seq[Expression],
+ aggregateExpressions: Seq[NamedExpression],
+ child: LogicalPlan)
+ extends UnaryNode {
+
+ def output = aggregateExpressions.map(_.toAttribute)
+ def references = child.references
+}
+
+case class StopAfter(limit: Expression, child: LogicalPlan) extends UnaryNode {
+ def output = child.output
+ def references = limit.references
+}
+
+case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
+ def output = child.output.map(_.withQualifiers(alias :: Nil))
+ def references = Set.empty
+}
+
+case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: LogicalPlan)
+ extends UnaryNode {
+
+ def output = child.output
+ def references = Set.empty
+}
+
+case class Distinct(child: LogicalPlan) extends UnaryNode {
+ def output = child.output
+ def references = child.outputSet
+}
+
+case object NoRelation extends LeafNode {
+ def output = Nil
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
new file mode 100644
index 0000000000..f7fcdc5fdb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package logical
+
+import expressions._
+
+/**
+ * Performs a physical redistribution of the data. Used when the consumer of the query
+ * result has expectations about the distribution and ordering of partitioned input data.
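+ *
+ * For example, a HiveQL `DISTRIBUTE BY ... SORT BY ...` clause could be represented by the
+ * operators defined below.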
+ */
+abstract class RedistributeData extends UnaryNode {
+ self: Product =>
+
+ def output = child.output
+}
+
+case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan)
+ extends RedistributeData {
+
+ def references = sortExpressions.flatMap(_.references).toSet
+}
+
+case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan)
+ extends RedistributeData {
+
+ def references = partitionExpressions.flatMap(_.references).toSet
+}
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala
new file mode 100644
index 0000000000..a40ab4bbb1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+
+/**
+ * A collection of common abstractions for query plans as well as
+ * a base logical plan representation.
+ */
+package object plans
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
new file mode 100644
index 0000000000..2d8f3ad335
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package catalyst
+package plans
+package physical
+
+import expressions._
+import types._
+
+/**
+ * Specifies how tuples that share common expressions will be distributed when a query is executed
+ * in parallel on many machines. Distribution can be used to refer to two distinct physical
+ * properties:
+ * - Inter-node partitioning of data: In this case the distribution describes how tuples are
+ * partitioned across physical machines in a cluster.
Knowing this property allows some + *    operators (e.g., Aggregate) to perform partition-local operations instead of global ones. + *  - Intra-partition ordering of data: In this case the distribution describes guarantees made + *    about how tuples are distributed within a single partition. + */ +sealed trait Distribution + +/** + * Represents a distribution where no promises are made about co-location of data. + */ +case object UnspecifiedDistribution extends Distribution + +/** + * Represents a distribution that only has a single partition and all tuples of the dataset + * are co-located. + */ +case object AllTuples extends Distribution + +/** + * Represents data where tuples that share the same values for the `clustering` + * [[catalyst.expressions.Expression Expressions]] will be co-located. Based on the context, this + * can mean such tuples are either co-located in the same partition or they will be contiguous + * within a single partition. + */ +case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution { + require( + clustering != Nil, + "The clustering expressions of a ClusteredDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") +} + +/** + * Represents data where tuples have been ordered according to the `ordering` + * [[catalyst.expressions.Expression Expressions]].  This is a strictly stronger guarantee than + * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for + * the ordering expressions are contiguous and will never be split across partitions. + */ +case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { + require( + ordering != Nil, + "The ordering expressions of an OrderedDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") + + def clustering = ordering.map(_.child).toSet +} + +sealed trait Partitioning { + /** Returns the number of partitions that the data is split across */ + val numPartitions: Int + + /** + * Returns true iff the guarantees made by this + * [[catalyst.plans.physical.Partitioning Partitioning]] are sufficient to satisfy + * the partitioning scheme mandated by the `required` + * [[catalyst.plans.physical.Distribution Distribution]], i.e. the current dataset does not + * need to be re-partitioned for the `required` Distribution (it is possible that tuples within + * a partition need to be reorganized). + */ + def satisfies(required: Distribution): Boolean + + /** + * Returns true iff all distribution guarantees made by this partitioning can also be made + * for the `other` specified partitioning. + * For example, two [[catalyst.plans.physical.HashPartitioning HashPartitioning]]s are + * only compatible if they have the same `numPartitions`.
+ */ + def compatibleWith(other: Partitioning): Boolean +} + +case class UnknownPartitioning(numPartitions: Int) extends Partitioning { + override def satisfies(required: Distribution): Boolean = required match { + case UnspecifiedDistribution => true + case _ => false + } + + override def compatibleWith(other: Partitioning): Boolean = other match { + case UnknownPartitioning(_) => true + case _ => false + } +} + +case object SinglePartition extends Partitioning { + val numPartitions = 1 + + override def satisfies(required: Distribution): Boolean = true + + override def compatibleWith(other: Partitioning) = other match { + case SinglePartition => true + case _ => false + } +} + +case object BroadcastPartitioning extends Partitioning { + val numPartitions = 1 + + override def satisfies(required: Distribution): Boolean = true + + override def compatibleWith(other: Partitioning) = other match { + case SinglePartition => true + case _ => false + } +} + +/** + * Represents a partitioning where rows are split up across partitions based on the hash + * of `expressions`.  All rows where `expressions` evaluate to the same values are guaranteed to be + * in the same partition. + */ +case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) + extends Expression + with Partitioning { + + def children = expressions + def references = expressions.flatMap(_.references).toSet + def nullable = false + def dataType = IntegerType + + lazy val clusteringSet = expressions.toSet + + override def satisfies(required: Distribution): Boolean = required match { + case UnspecifiedDistribution => true + case ClusteredDistribution(requiredClustering) => + clusteringSet.subsetOf(requiredClustering.toSet) + case _ => false + } + + override def compatibleWith(other: Partitioning) = other match { + case BroadcastPartitioning => true + case h: HashPartitioning if h == this => true + case _ => false + } +} + +/** + * Represents a partitioning where rows are split across partitions based on some total ordering of + * the expressions specified in `ordering`.  When data is partitioned in this manner the following + * two conditions are guaranteed to hold: + *  - All rows where the expressions in `ordering` evaluate to the same values will be in the same + *    partition. + *  - Each partition will have a `min` and `max` row, relative to the given ordering.  All rows + *    that are in between `min` and `max` in this `ordering` will reside in this partition.
+ */ +case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) + extends Expression + with Partitioning { + + def children = ordering + def references = ordering.flatMap(_.references).toSet + def nullable = false + def dataType = IntegerType + + lazy val clusteringSet = ordering.map(_.child).toSet + + override def satisfies(required: Distribution): Boolean = required match { + case UnspecifiedDistribution => true + case OrderedDistribution(requiredOrdering) => + val minSize = Seq(requiredOrdering.size, ordering.size).min + requiredOrdering.take(minSize) == ordering.take(minSize) + case ClusteredDistribution(requiredClustering) => + clusteringSet.subsetOf(requiredClustering.toSet) + case _ => false + } + + override def compatibleWith(other: Partitioning) = other match { + case BroadcastPartitioning => true + case r: RangePartitioning if r == this => true + case _ => false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala new file mode 100644 index 0000000000..6ff4891a3f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package rules + +import trees._ + +abstract class Rule[TreeType <: TreeNode[_]] extends Logging { + + /** Name for this rule, automatically inferred based on class name. */ + val ruleName: String = { + val className = getClass.getName + if (className endsWith "$") className.dropRight(1) else className + } + + def apply(plan: TreeType): TreeType +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala new file mode 100644 index 0000000000..68ae30cde1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package rules + +import trees._ +import util._ + +abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { + + /** + * An execution strategy for rules that indicates the maximum number of executions. If the + * execution reaches a fixed point (i.e., converges) before maxIterations, it will stop. + */ + abstract class Strategy { def maxIterations: Int } + + /** A strategy that only runs once. */ + case object Once extends Strategy { val maxIterations = 1 } + + /** A strategy that runs until a fixed point is reached or maxIterations times, whichever comes first. */ + case class FixedPoint(maxIterations: Int) extends Strategy + + /** A batch of rules. */ + protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) + + /** Defines a sequence of rule batches, to be overridden by the implementation. */ + protected val batches: Seq[Batch] + + /** + * Executes the batches of rules defined by the subclass. The batches are executed serially + * using the defined execution strategy. Within each batch, rules are also executed serially. + */ + def apply(plan: TreeType): TreeType = { + var curPlan = plan + + batches.foreach { batch => + var iteration = 1 + var lastPlan = curPlan + curPlan = batch.rules.foldLeft(curPlan) { case (curPlan, rule) => rule(curPlan) } + + // Run until fixed point (or the max number of iterations as specified in the strategy). + while (iteration < batch.strategy.maxIterations && !curPlan.fastEquals(lastPlan)) { + lastPlan = curPlan + curPlan = batch.rules.foldLeft(curPlan) { + case (curPlan, rule) => + val result = rule(curPlan) + if (!result.fastEquals(curPlan)) { + logger.debug( + s""" + |=== Applying Rule ${rule.ruleName} === + |${sideBySide(curPlan.treeString, result.treeString).mkString("\n")} + """.stripMargin) + } + + result + } + iteration += 1 + } + } + + curPlan + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/package.scala new file mode 100644 index 0000000000..26ab543082 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +/** + * A framework for applying batches of rewrite rules to trees, possibly to a fixed point.
+ */ +package object rules diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala new file mode 100644 index 0000000000..76ede87e4e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package trees + +import errors._ + +object TreeNode { + private val currentId = new java.util.concurrent.atomic.AtomicLong + protected def nextId() = currentId.getAndIncrement() +} + +/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ +private class MutableInt(var i: Int) + +abstract class TreeNode[BaseType <: TreeNode[BaseType]] { + self: BaseType with Product => + + /** Returns a Seq of the children of this node */ + def children: Seq[BaseType] + + /** + * A globally unique id for this specific instance. Not preserved across copies. + * Unlike `equals`, `id` can be used to differentiate distinct but structurally + * identical branches of a tree. + */ + val id = TreeNode.nextId() + + /** + * Returns true if other is the same [[catalyst.trees.TreeNode TreeNode]] instance.  Unlike + * `equals` this function will return false for different instances of structurally identical + * trees. + */ + def sameInstance(other: TreeNode[_]): Boolean = { + this.id == other.id + } + + /** + * Faster version of equality which short-circuits when two treeNodes are the same instance. + * We don't just override Object.equals, as doing so prevents the scala compiler from + * generating case class `equals` methods. + */ + def fastEquals(other: TreeNode[_]): Boolean = { + sameInstance(other) || this == other + } + + /** + * Runs the given function on this node and then recursively on [[children]]. + * @param f the function to be applied to each node in the tree. + */ + def foreach(f: BaseType => Unit): Unit = { + f(this) + children.foreach(_.foreach(f)) + } + + /** + * Returns a Seq containing the result of applying the given function to each + * node in this tree in a preorder traversal. + * @param f the function to be applied. + */ + def map[A](f: BaseType => A): Seq[A] = { + val ret = new collection.mutable.ArrayBuffer[A]() + foreach(ret += f(_)) + ret + } + + /** + * Returns a Seq by applying a function to all nodes in this tree and using the elements of the + * resulting collections.
+ */ + def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A] = { + val ret = new collection.mutable.ArrayBuffer[A]() + foreach(ret ++= f(_)) + ret + } + + /** + * Returns a Seq containing the result of applying a partial function to all elements in this + * tree on which the function is defined. + */ + def collect[B](pf: PartialFunction[BaseType, B]): Seq[B] = { + val ret = new collection.mutable.ArrayBuffer[B]() + val lifted = pf.lift + foreach(node => lifted(node).foreach(ret.+=)) + ret + } + + /** + * Returns a copy of this node where `f` has been applied to all the node's children. + */ + def mapChildren(f: BaseType => BaseType): this.type = { + var changed = false + val newArgs = productIterator.map { + case arg: TreeNode[_] if children contains arg => + val newChild = f(arg.asInstanceOf[BaseType]) + if (newChild fastEquals arg) { + arg + } else { + changed = true + newChild + } + case nonChild: AnyRef => nonChild + case null => null + }.toArray + if (changed) makeCopy(newArgs) else this + } + + /** + * Returns a copy of this node with the children replaced. + * TODO: Validate somewhere (in debug mode?) that children are ordered correctly. + */ + def withNewChildren(newChildren: Seq[BaseType]): this.type = { + assert(newChildren.size == children.size, "Incorrect number of children") + var changed = false + val remainingNewChildren = newChildren.toBuffer + val remainingOldChildren = children.toBuffer + val newArgs = productIterator.map { + case arg: TreeNode[_] if children contains arg => + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + case nonChild: AnyRef => nonChild + case null => null + }.toArray + + if (changed) makeCopy(newArgs) else this + } + + /** + * Returns a copy of this node where `rule` has been recursively applied to the tree. + * When `rule` does not apply to a given node it is left unchanged. + * Users should not expect a specific directionality. If a specific directionality is needed, + * transformDown or transformUp should be used. + * @param rule the function used to transform this node's children + */ + def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = { + transformDown(rule) + } + + /** + * Returns a copy of this node where `rule` has been recursively applied to it and all of its + * children (pre-order). When `rule` does not apply to a given node it is left unchanged. + * @param rule the function used to transform this node's children + */ + def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = { + val afterRule = rule.applyOrElse(this, identity[BaseType]) + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (this fastEquals afterRule) { + transformChildrenDown(rule) + } else { + afterRule.transformChildrenDown(rule) + } + } + + /** + * Returns a copy of this node where `rule` has been recursively applied to all the children of + * this node.  When `rule` does not apply to a given node it is left unchanged.
+ * @param rule the function used to transform this node's children + */ + def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = { + var changed = false + val newArgs = productIterator.map { + case arg: TreeNode[_] if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case m: Map[_,_] => m + case args: Traversable[_] => args.map { + case arg: TreeNode[_] if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case other => other + } + case nonChild: AnyRef => nonChild + case null => null + }.toArray + if (changed) makeCopy(newArgs) else this + } + + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order). When `rule` does not apply to a given node, it is left + * unchanged. + * @param rule the function used to transform this node's children + */ + def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { + val afterRuleOnChildren = transformChildrenUp(rule) + if (this fastEquals afterRuleOnChildren) { + rule.applyOrElse(this, identity[BaseType]) + } else { + rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) + } + } + + def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = { + var changed = false + val newArgs = productIterator.map { + case arg: TreeNode[_] if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformUp(rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case m: Map[_,_] => m + case args: Traversable[_] => args.map { + case arg: TreeNode[_] if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformUp(rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case other => other + } + case nonChild: AnyRef => nonChild + case null => null + }.toArray + if (changed) makeCopy(newArgs) else this + } + + /** + * Args to the constructor that should be copied, but not transformed. + * These are appended to the transformed args automatically by makeCopy. + */ + protected def otherCopyArgs: Seq[AnyRef] = Nil + + /** + * Creates a copy of this type of tree node after a transformation. + * Must be overridden by child classes that have constructor arguments + * that are not present in the productIterator. + * @param newArgs the new product arguments. + */ + def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { + try { + val defaultCtor = getClass.getConstructors.head + if (otherCopyArgs.isEmpty) { + defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] + } else { + defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type] + } + } catch { + case e: java.lang.IllegalArgumentException => + throw new TreeNodeException( + this, s"Failed to copy node.  Is otherCopyArgs specified correctly for $nodeName?") + } + } + + /** Returns the name of this type of TreeNode.  Defaults to the class name. */ + def nodeName = getClass.getSimpleName + + /** + * The arguments that should be included in the arg string.  Defaults to the `productIterator`.
+ */ + protected def stringArgs = productIterator + + /** Returns a string representing the arguments to this node, minus any children */ + def argString: String = productIterator.flatMap { + case tn: TreeNode[_] if children contains tn => Nil + case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil + case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil + case set: Set[_] => set.mkString("{", ",", "}") :: Nil + case other => other :: Nil + }.mkString(", ") + + /** String representation of this node without any children */ + def simpleString = s"$nodeName $argString" + + override def toString: String = treeString + + /** Returns a string representation of the nodes in this tree */ + def treeString = generateTreeString(0, new StringBuilder).toString + + /** + * Returns a string representation of the nodes in this tree, where each operator is numbered. + * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees. + */ + def numberedTreeString = + treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n") + + /** + * Returns the tree node at the specified number. + * Numbers for each node can be found in the [[numberedTreeString]]. + */ + def apply(number: Int): BaseType = getNodeNumbered(new MutableInt(number)) + + protected def getNodeNumbered(number: MutableInt): BaseType = { + if (number.i < 0) { + null.asInstanceOf[BaseType] + } else if (number.i == 0) { + this + } else { + number.i -= 1 + children.map(_.getNodeNumbered(number)).find(_ != null).getOrElse(null.asInstanceOf[BaseType]) + } + } + + /** Appends the string representation of this node and its children to the given StringBuilder. */ + protected def generateTreeString(depth: Int, builder: StringBuilder): StringBuilder = { + builder.append(" " * depth) + builder.append(simpleString) + builder.append("\n") + children.foreach(_.generateTreeString(depth + 1, builder)) + builder + } +} + +/** + * A [[TreeNode]] that has two children, [[left]] and [[right]]. + */ +trait BinaryNode[BaseType <: TreeNode[BaseType]] { + def left: BaseType + def right: BaseType + + def children = Seq(left, right) +} + +/** + * A [[TreeNode]] with no children. + */ +trait LeafNode[BaseType <: TreeNode[BaseType]] { + def children = Nil +} + +/** + * A [[TreeNode]] with a single [[child]]. + */ +trait UnaryNode[BaseType <: TreeNode[BaseType]] { + def child: BaseType + def children = child :: Nil +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala new file mode 100644 index 0000000000..e2da1d2439 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +/** + * A library for easily manipulating trees of operators.  Operators that extend TreeNode are + * granted the following interface: + * <ul> + *   <li>Scala-collection-like methods (foreach, map, flatMap, collect, etc)</li> + *   <li> + *     transform - accepts a partial function that is used to generate a new tree.  When the + *     partial function can be applied to a given tree segment, that segment is replaced with the + *     result.  After attempting to apply the partial function to a given node, the transform + *     function recursively attempts to apply the function to that node's children. + *   </li> + *   <li>debugging support - pretty printing, easy splicing of trees, etc.</li> + * </ul> + */ +package object trees { + // Since we want tree nodes to be lightweight, we create one logger for all TreeNode instances. + protected val logger = Logger("catalyst.trees") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala new file mode 100644 index 0000000000..6eb2b62ecc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package types + +import expressions.Expression + +abstract class DataType { + /** Matches any expression that evaluates to this DataType */ + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType == this => true + case _ => false + } +} + +case object NullType extends DataType + +abstract class NativeType extends DataType { + type JvmType + val ordering: Ordering[JvmType] +} + +case object StringType extends NativeType { + type JvmType = String + val ordering = implicitly[Ordering[JvmType]] +} +case object BinaryType extends DataType { + type JvmType = Array[Byte] +} +case object BooleanType extends NativeType { + type JvmType = Boolean + val ordering = implicitly[Ordering[JvmType]] +} + +abstract class NumericType extends NativeType { + // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for + // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a + // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // desugared by the compiler into an argument to the object's constructor. This means there is + // no longer a no-argument constructor and thus the JVM cannot serialize the object anymore.
+ val numeric: Numeric[JvmType] +} + +/** Matcher for any expressions that evaluate to [[IntegralType]]s */ +object IntegralType { + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType.isInstanceOf[IntegralType] => true + case _ => false + } +} + +abstract class IntegralType extends NumericType { + val integral: Integral[JvmType] +} + +case object LongType extends IntegralType { + type JvmType = Long + val numeric = implicitly[Numeric[Long]] + val integral = implicitly[Integral[Long]] + val ordering = implicitly[Ordering[JvmType]] +} + +case object IntegerType extends IntegralType { + type JvmType = Int + val numeric = implicitly[Numeric[Int]] + val integral = implicitly[Integral[Int]] + val ordering = implicitly[Ordering[JvmType]] +} + +case object ShortType extends IntegralType { + type JvmType = Short + val numeric = implicitly[Numeric[Short]] + val integral = implicitly[Integral[Short]] + val ordering = implicitly[Ordering[JvmType]] +} + +case object ByteType extends IntegralType { + type JvmType = Byte + val numeric = implicitly[Numeric[Byte]] + val integral = implicitly[Integral[Byte]] + val ordering = implicitly[Ordering[JvmType]] +} + +/** Matcher for any expressions that evaluate to [[FractionalType]]s */ +object FractionalType { + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType.isInstanceOf[FractionalType] => true + case _ => false + } +} +abstract class FractionalType extends NumericType { + val fractional: Fractional[JvmType] +} + +case object DecimalType extends FractionalType { + type JvmType = BigDecimal + val numeric = implicitly[Numeric[BigDecimal]] + val fractional = implicitly[Fractional[BigDecimal]] + val ordering = implicitly[Ordering[JvmType]] +} + +case object DoubleType extends FractionalType { + type JvmType = Double + val numeric = implicitly[Numeric[Double]] + val fractional = implicitly[Fractional[Double]] + val ordering = implicitly[Ordering[JvmType]] +} + +case object FloatType extends FractionalType { + type JvmType = Float + val numeric = implicitly[Numeric[Float]] + val fractional = implicitly[Fractional[Float]] + val ordering = implicitly[Ordering[JvmType]] +} + +case class ArrayType(elementType: DataType) extends DataType + +case class StructField(name: String, dataType: DataType, nullable: Boolean) +case class StructType(fields: Seq[StructField]) extends DataType + +case class MapType(keyType: DataType, valueType: DataType) extends DataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala new file mode 100644 index 0000000000..b65a5617d9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +/** + * Contains a type system for attributes produced by relations, including complex types like + * structs, arrays and maps. + */ +package object types diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala new file mode 100644 index 0000000000..52adea2661 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File} + +package object util { + /** + * Returns a path to a temporary file that probably does not exist. + * Note, there is always the race condition that someone created this + * file since the last time we checked. Thus, this shouldn't be used + * for anything security conscious. 
+ */ + def getTempFilePath(prefix: String, suffix: String = ""): File = { + val tempFile = File.createTempFile(prefix, suffix) + tempFile.delete() + tempFile + } + + def fileToString(file: File, encoding: String = "UTF-8") = { + val inStream = new FileInputStream(file) + val outStream = new ByteArrayOutputStream + try { + var reading = true + while ( reading ) { + inStream.read() match { + case -1 => reading = false + case c => outStream.write(c) + } + } + outStream.flush() + } + finally { + inStream.close() + } + new String(outStream.toByteArray, encoding) + } + + def resourceToString( + resource:String, + encoding: String = "UTF-8", + classLoader: ClassLoader = this.getClass.getClassLoader) = { + val inStream = classLoader.getResourceAsStream(resource) + val outStream = new ByteArrayOutputStream + try { + var reading = true + while ( reading ) { + inStream.read() match { + case -1 => reading = false + case c => outStream.write(c) + } + } + outStream.flush() + } + finally { + inStream.close() + } + new String(outStream.toByteArray, encoding) + } + + def stringToFile(file: File, str: String): File = { + val out = new PrintWriter(file) + out.write(str) + out.close() + file + } + + def sideBySide(left: String, right: String): Seq[String] = { + sideBySide(left.split("\n"), right.split("\n")) + } + + def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { + val maxLeftSize = left.map(_.size).max + val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") + val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") + + leftPadded.zip(rightPadded).map { + case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.size) + 3)) + r + } + } + + def stackTraceToString(t: Throwable): String = { + val out = new java.io.ByteArrayOutputStream + val writer = new PrintWriter(out) + t.printStackTrace(writer) + writer.flush() + new String(out.toByteArray) + } + + def stringOrNull(a: AnyRef) = if (a == null) null else a.toString + + def benchmark[A](f: => A): A = { + val startTime = System.nanoTime() + val ret = f + val endTime = System.nanoTime() + println(s"${(endTime - startTime).toDouble / 1000000}ms") + ret + } + + /* FIX ME + implicit class debugLogging(a: AnyRef) { + def debugLogging() { + org.apache.log4j.Logger.getLogger(a.getClass.getName).setLevel(org.apache.log4j.Level.DEBUG) + } + } */ +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala new file mode 100644 index 0000000000..9ec31689b5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +/** + * Allows the execution of relational queries, including those expressed in SQL using Spark. + * + * Note that this package is located in catalyst instead of in core so that all subprojects can + * inherit the settings from this package object. + */ +package object sql { + + protected[sql] def Logger(name: String) = + com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger(name)) + + protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging + + type Row = catalyst.expressions.Row + + object Row { + /** + * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: + * {{{ + * import org.apache.spark.sql._ + * + * val pairs = sql("SELECT key, value FROM src").rdd.map { + * case Row(key: Int, value: String) => + * key -> value + * } + * }}} + */ + def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/AnalysisSuite.scala new file mode 100644 index 0000000000..1fd0d26b6f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/AnalysisSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package analysis + +import org.scalatest.FunSuite + +import analysis._ +import expressions._ +import plans.logical._ +import types._ + +import dsl._ +import dsl.expressions._ + +class AnalysisSuite extends FunSuite { + val analyze = SimpleAnalyzer + + val testRelation = LocalRelation('a.int) + + test("analyze project") { + assert(analyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) === Project(testRelation.output, testRelation)) + + } +}
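A note on what the assertion above checks: analysis replaces the name-based UnresolvedAttribute with the relation's typed output attribute. A minimal sketch of that round trip, assuming only the helpers already imported by AnalysisSuite (SimpleAnalyzer, LocalRelation and the 'a.int DSL):

  // Before analysis, "a" is just a name; the plan is not resolved.
  val relation = LocalRelation('a.int)
  val unresolved = Project(Seq(UnresolvedAttribute("a")), relation)
  assert(!unresolved.resolved)

  // After analysis, the projection holds the relation's typed output attribute.
  assert(SimpleAnalyzer(unresolved) === Project(relation.output, relation))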
\ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala new file mode 100644 index 0000000000..fb25e1c246 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.plans.physical._ + +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class DistributionSuite extends FunSuite { + + protected def checkSatisfied( + inputPartitioning: Partitioning, + requiredDistribution: Distribution, + satisfied: Boolean) { + if (inputPartitioning.satisfies(requiredDistribution) != satisfied) + fail( + s""" + |== Input Partitioning == + |$inputPartitioning + |== Required Distribution == + |$requiredDistribution + |== Does input partitioning satisfy required distribution? == + |Expected $satisfied got ${inputPartitioning.satisfies(requiredDistribution)} + """.stripMargin) + } + + test("HashPartitioning is the output partitioning") { + // Cases which do not need an exchange between two data properties. + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + SinglePartition, + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + SinglePartition, + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + true) + + // Cases which need an exchange between two data properties. + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('b, 'c)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('d, 'e)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + AllTuples, + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) + + // TODO: We should check functional dependencies + /* + checkSatisfied( + ClusteredDistribution(Seq('b)), + ClusteredDistribution(Seq('b + 1)), + true) + */ + } + + test("RangePartitioning is the output partitioning") { + // Cases which do not need an exchange between two data properties. 
+ checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('c, 'b, 'a)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('b, 'c, 'a, 'd)), + true) + + // Cases which need an exchange between two data properties. + // TODO: We can have an optimization that first sorts the dataset + // by a.asc and then sorts b and c within each partition. This optimization + // should trade off the benefit of fewer Exchange operators + // against the reduced parallelism. + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('b.asc, 'a.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('a, 'b)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('c, 'd)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + AllTuples, + false) + } +}
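Many of these cases reduce to a subset check. A hedged sketch of the reasoning, reusing the suite's DSL (shown here for HashPartitioning rather than RangePartitioning):

  // Hashing on ('b, 'c) sends every pair of rows that agree on ('b, 'c) to the
  // same partition; rows that agree on ('a, 'b, 'c) agree on ('b, 'c) in
  // particular, so the required clustering is satisfied.
  assert(HashPartitioning(Seq('b, 'c), 10).satisfies(ClusteredDistribution(Seq('a, 'b, 'c))))

  // The converse fails: rows that agree only on ('b, 'c) may still hash to
  // different partitions when 'a differs.
  assert(!HashPartitioning(Seq('a, 'b, 'c), 10).satisfies(ClusteredDistribution(Seq('b, 'c))))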
\ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala new file mode 100644 index 0000000000..f06618ad11 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class ExpressionEvaluationSuite extends FunSuite { + + test("literals") { + assert((Literal(1) + Literal(1)).apply(null) === 2) + } + + /** + * Checks for three-valued logic.  Based on: + * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 + * + * p       q       p OR q  p AND q  p = q + * True    True    True    True     True + * True    False   True    False    False + * True    Unknown True    Unknown  Unknown + * False   True    True    False    False + * False   False   False   False    True + * False   Unknown Unknown False    Unknown + * Unknown True    True    Unknown  Unknown + * Unknown False   Unknown False    Unknown + * Unknown Unknown Unknown Unknown  Unknown + * + * p       NOT p + * True    False + * False   True + * Unknown Unknown + */ + + val notTrueTable = + (true, false) :: + (false, true) :: + (null, null) :: Nil + + test("3VL Not") { + notTrueTable.foreach { + case (v, answer) => + val expr = Not(Literal(v, BooleanType)) + val result = expr.apply(null) + if (result != answer) + fail(s"$expr should not evaluate to $result, expected: $answer") + } + } + + booleanLogicTest("AND", _ && _, + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: + (false, false, false) :: + (false, null, false) :: + (null, true, null) :: + (null, false, false) :: + (null, null, null) :: Nil) + + booleanLogicTest("OR", _ || _, + (true, true, true) :: + (true, false, true) :: + (true, null, true) :: + (false, true, true) :: + (false, false, false) :: + (false, null, null) :: + (null, true, true) :: + (null, false, null) :: + (null, null, null) :: Nil) + + booleanLogicTest("=", _ === _, + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: + (false, false, true) :: + (false, null, null) :: + (null, true, null) :: + (null, false, null) :: + (null, null, null) :: Nil) + + def booleanLogicTest( + name: String, + op: (Expression, Expression) => Expression, + truthTable: Seq[(Any, Any, Any)]) { + test(s"3VL $name") { + truthTable.foreach { + case (l, r, answer) => + val expr =
op(Literal(l, BooleanType), Literal(r, BooleanType)) + val result = expr.apply(null) + if (result != answer) + fail(s"$expr should not evaluate to $result, expected: $answer") + } + } + } +}
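To make one row of the truth table concrete, a minimal sketch under the suite's imports, with unknown encoded as a null boolean literal and And as the expression the `&&` DSL builds:

  val unknown = Literal(null, BooleanType)

  // false AND unknown = false: the false operand alone decides the result.
  assert(And(Literal(false, BooleanType), unknown).apply(null) === false)

  // true AND unknown = unknown: the result depends on the missing value.
  assert(And(Literal(true, BooleanType), unknown).apply(null) === null)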
\ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/HiveTypeCoercionSuite.scala new file mode 100644 index 0000000000..f595bf7e44 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/HiveTypeCoercionSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package analysis + +import org.scalatest.FunSuite + +import catalyst.types._ + + +class HiveTypeCoercionSuite extends FunSuite { + + val rules = new HiveTypeCoercion { } + import rules._ + + test("tightest common bound for numeric and boolean types") { + def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { + var found = WidenTypes.findTightestCommonType(t1, t2) + assert(found == tightestCommon, + s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") + // Test both directions to make sure the widening is symmetric. + found = WidenTypes.findTightestCommonType(t2, t1) + assert(found == tightestCommon, + s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") + } + + // Boolean + widenTest(NullType, BooleanType, Some(BooleanType)) + widenTest(BooleanType, BooleanType, Some(BooleanType)) + widenTest(IntegerType, BooleanType, None) + widenTest(LongType, BooleanType, None) + + // Integral + widenTest(NullType, ByteType, Some(ByteType)) + widenTest(NullType, IntegerType, Some(IntegerType)) + widenTest(NullType, LongType, Some(LongType)) + widenTest(ShortType, IntegerType, Some(IntegerType)) + widenTest(ShortType, LongType, Some(LongType)) + widenTest(IntegerType, LongType, Some(LongType)) + widenTest(LongType, LongType, Some(LongType)) + + // Floating point + widenTest(NullType, FloatType, Some(FloatType)) + widenTest(NullType, DoubleType, Some(DoubleType)) + widenTest(FloatType, DoubleType, Some(DoubleType)) + widenTest(FloatType, FloatType, Some(FloatType)) + widenTest(DoubleType, DoubleType, Some(DoubleType)) + + // Integral mixed with floating point. 
+ widenTest(NullType, FloatType, Some(FloatType)) + widenTest(NullType, DoubleType, Some(DoubleType)) + widenTest(IntegerType, FloatType, Some(FloatType)) + widenTest(IntegerType, DoubleType, Some(DoubleType)) + widenTest(LongType, FloatType, Some(FloatType)) + widenTest(LongType, DoubleType, Some(DoubleType)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/RuleExecutorSuite.scala new file mode 100644 index 0000000000..ff7c15b718 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/RuleExecutorSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package trees + +import org.scalatest.FunSuite + +import expressions._ +import rules._ + +class RuleExecutorSuite extends FunSuite { + object DecrementLiterals extends Rule[Expression] { + def apply(e: Expression): Expression = e transform { + case IntegerLiteral(i) if i > 0 => Literal(i - 1) + } + } + + test("only once") { + object ApplyOnce extends RuleExecutor[Expression] { + val batches = Batch("once", Once, DecrementLiterals) :: Nil + } + + assert(ApplyOnce(Literal(10)) === Literal(9)) + } + + test("to fixed point") { + object ToFixedPoint extends RuleExecutor[Expression] { + val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil + } + + assert(ToFixedPoint(Literal(10)) === Literal(0)) + } + + test("to maxIterations") { + object ToMaxIterations extends RuleExecutor[Expression] { + val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil + } + + assert(ToMaxIterations(Literal(100)) === Literal(90)) + } +}
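DecrementLiterals above shows the general shape of a rule. As a further illustrative sketch (FoldIntegerAdd and FoldOnce are hypothetical names, not part of this commit), a rule that folds integer addition bottom-up, so a single Once batch suffices even for nested sums:

  object FoldIntegerAdd extends Rule[Expression] {
    def apply(e: Expression): Expression = e transformUp {
      // Children are rewritten first, so nested additions collapse in one pass.
      case Add(IntegerLiteral(a), IntegerLiteral(b)) => Literal(a + b)
    }
  }

  object FoldOnce extends RuleExecutor[Expression] {
    val batches = Batch("fold", Once, FoldIntegerAdd) :: Nil
  }

  // FoldOnce(Add(Literal(1), Add(Literal(2), Literal(3)))) evaluates to Literal(6).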
\ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/TreeNodeSuite.scala new file mode 100644 index 0000000000..98bb090c29 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/TreeNodeSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package trees + +import scala.collection.mutable.ArrayBuffer + +import expressions._ + +import org.scalatest.FunSuite + +class TreeNodeSuite extends FunSuite { + + test("top node changed") { + val after = Literal(1) transform { case Literal(1, _) => Literal(2) } + assert(after === Literal(2)) + } + + test("one child changed") { + val before = Add(Literal(1), Literal(2)) + val after = before transform { case Literal(2, _) => Literal(1) } + + assert(after === Add(Literal(1), Literal(1))) + } + + test("no change") { + val before = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4)))) + val after = before transform { case Literal(5, _) => Literal(1) } + + assert(before === after) + assert(before.map(_.id) === after.map(_.id)) + } + + test("collect") { + val tree = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4)))) + val literals = tree collect { case l: Literal => l } + + assert(literals.size === 4) + (1 to 4).foreach(i => assert(literals contains Literal(i))) + } + + test("pre-order transform") { + val actual = new ArrayBuffer[String]() + val expected = Seq("+", "1", "*", "2", "-", "3", "4") + val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) + expression transformDown { + case b: BinaryExpression => { actual.append(b.symbol); b } + case l: Literal => { actual.append(l.toString); l } + } + + assert(expected === actual) + } + + test("post-order transform") { + val actual = new ArrayBuffer[String]() + val expected = Seq("1", "2", "3", "4", "-", "*", "+") + val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) + expression transformUp { + case b: BinaryExpression => { actual.append(b.symbol); b } + case l: Literal => { actual.append(l.toString); l } + } + + assert(expected === actual) + } +}
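The "no change" test relies on a property worth spelling out: when no rule fires, transform returns the original instance rather than a structurally equal copy, which is why the node ids compare equal. A minimal sketch:

  val tree = Add(Literal(1), Literal(2))

  // No case matches, so no copy is made: the very same instance comes back.
  val untouched = tree transform { case Literal(99, _) => Literal(0) }
  assert(untouched sameInstance tree)

  // A matching rule forces a fresh node with a new id.
  val changed = tree transform { case Literal(2, _) => Literal(3) }
  assert(!(changed sameInstance tree))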
\ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala new file mode 100644 index 0000000000..7ce42b2b0a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package optimizer + +import types.IntegerType +import util._ +import plans.logical.{LogicalPlan, LocalRelation} +import rules._ +import expressions._ +import dsl.plans._ +import dsl.expressions._ + +class ConstantFoldingSuite extends OptimizerTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueries) :: + Batch("ConstantFolding", Once, + ConstantFolding, + BooleanSimplification) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("eliminate subqueries") { + val originalQuery = + testRelation + .subquery('y) + .select('a) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a.attr) + .analyze + + comparePlans(optimized, correctAnswer) + } + + /** + * Unit tests for constant folding in expressions. 
+ */ + test("Constant folding test: expressions only have literals") { + val originalQuery = + testRelation + .select( + Literal(2) + Literal(3) + Literal(4) as Symbol("2+3+4"), + Literal(2) * Literal(3) + Literal(4) as Symbol("2*3+4"), + Literal(2) * (Literal(3) + Literal(4)) as Symbol("2*(3+4)")) + .where( + Literal(1) === Literal(1) && + Literal(2) > Literal(3) || + Literal(3) > Literal(2) ) + .groupBy( + Literal(2) * Literal(3) - Literal(6) / (Literal(4) - Literal(2)) + )(Literal(9) / Literal(3) as Symbol("9/3")) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Literal(9) as Symbol("2+3+4"), + Literal(10) as Symbol("2*3+4"), + Literal(14) as Symbol("2*(3+4)")) + .where(Literal(true)) + .groupBy(Literal(3))(Literal(3) as Symbol("9/3")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Constant folding test: expressions have attribute references and literals in " + + "arithmetic operations") { + val originalQuery = + testRelation + .select( + Literal(2) + Literal(3) + 'a as Symbol("c1"), + 'a + Literal(2) + Literal(3) as Symbol("c2"), + Literal(2) * 'a + Literal(4) as Symbol("c3"), + 'a * (Literal(3) + Literal(4)) as Symbol("c4")) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Literal(5) + 'a as Symbol("c1"), + 'a + Literal(2) + Literal(3) as Symbol("c2"), + Literal(2) * 'a + Literal(4) as Symbol("c3"), + 'a * (Literal(7)) as Symbol("c4")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Constant folding test: expressions have attribute references and literals in " + + "predicates") { + val originalQuery = + testRelation + .where( + (('a > 1 && Literal(1) === Literal(1)) || + ('a < 10 && Literal(1) === Literal(2)) || + (Literal(1) === Literal(1) && 'b > 1) || + (Literal(1) === Literal(2) && 'b < 10)) && + (('a > 1 || Literal(1) === Literal(1)) && + ('a < 10 || Literal(1) === Literal(2)) && + (Literal(1) === Literal(1) || 'b > 1) && + (Literal(1) === Literal(2) || 'b < 10))) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .where(('a > 1 || 'b > 1) && ('a < 10 && 'b < 10)) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Constant folding test: expressions have foldable functions") { + val originalQuery = + testRelation + .select( + Cast(Literal("2"), IntegerType) + Literal(3) + 'a as Symbol("c1"), + Coalesce(Seq(Cast(Literal("abc"), IntegerType), Literal(3))) as Symbol("c2")) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Literal(5) + 'a as Symbol("c1"), + Literal(3) as Symbol("c2")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Constant folding test: expressions have nonfoldable functions") { + val originalQuery = + testRelation + .select( + Rand + Literal(1) as Symbol("c1"), + Sum('a) as Symbol("c2")) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Rand + Literal(1.0) as Symbol("c1"), + Sum('a) as Symbol("c2")) + .analyze + + comparePlans(optimized, correctAnswer) + } +}
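\ No newline at end of file

These tests lean on the plan DSL added elsewhere in this commit: 'a.int builds an integer attribute, and select/where/groupBy construct logical operators directly. Below is a minimal, hedged usage sketch; relation, query, and Fold are invented names, and it assumes ConstantFolding lives in org.apache.spark.sql.catalyst.optimizer, as the suite's imports suggest:

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.optimizer.ConstantFolding
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules._

// A one-column relation and a projection containing a constant subexpression.
val relation = LocalRelation('a.int)
val query = relation.select(Literal(1) + Literal(2) + 'a as Symbol("total"))

// A single-batch executor that only folds constants.
object Fold extends RuleExecutor[LogicalPlan] {
  val batches = Batch("fold", Once, ConstantFolding) :: Nil
}

// Fold(query.analyze) should project Literal(3) + 'a, so 1 + 2 is evaluated
// once at optimization time rather than once per row.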
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala new file mode 100644 index 0000000000..cd611b3fb3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -0,0 +1,222 @@ +package org.apache.spark.sql
+package catalyst
+package optimizer
+
+import expressions._
+import plans.logical._
+import rules._
+import util._
+
+import dsl.plans._
+import dsl.expressions._
+
+class FilterPushdownSuite extends OptimizerTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Subqueries", Once,
+        EliminateSubqueries) ::
+      Batch("Filter Pushdown", Once,
+        EliminateSubqueries,
+        CombineFilters,
+        PushPredicateThroughProject,
+        PushPredicateThroughInnerJoin) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+  // This test already passes.
+  test("eliminate subqueries") {
+    val originalQuery =
+      testRelation
+        .subquery('y)
+        .select('a)
+
+    val optimized = Optimize(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a.attr)
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  // The tests below this line do not pass yet: the rules they exercise are not fully implemented.
+  test("simple push down") {
+    val originalQuery =
+      testRelation
+        .select('a)
+        .where('a === 1)
+
+    val optimized = Optimize(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .where('a === 1)
+        .select('a)
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("can't push without rewrite") {
+    val originalQuery =
+      testRelation
+        .select('a + 'b as 'e)
+        .where('e === 1)
+
+    val optimized = Optimize(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .where('a + 'b === 1)
+        .select('a + 'b as 'e)
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("filters: combines filters") {
+    val originalQuery = testRelation
+      .select('a)
+      .where('a === 1)
+      .where('a === 2)
+
+    val optimized = Optimize(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .where('a === 1 && 'a === 2)
+        .select('a)
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: push to either side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y)
+        .where("x.b".attr === 1)
+        .where("y.b".attr === 2)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val left = testRelation.where('b === 1)
+    val right = testRelation.where('b === 2)
+    val correctAnswer =
+      left.join(right).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: push to one side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y)
+        .where("x.b".attr === 1)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val left = testRelation.where('b === 1)
+    val right = testRelation
+    val correctAnswer =
+      left.join(right).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("joins: rewrite filter to push to either side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y)
+        .where("x.b".attr === 1 && "y.b".attr === 2)
+    }
+
+    val optimized = Optimize(originalQuery.analyze)
+    val left = testRelation.where('b === 1)
+    val right = testRelation.where('b === 2)
+    val correctAnswer =
+      left.join(right).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+ + test("joins: can't push down") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some("x.b".attr === "y.b".attr)) + } + val optimized = Optimize(originalQuery.analyze) + + comparePlans(optimizer.EliminateSubqueries(originalQuery.analyze), optimized) + } + + test("joins: conjunctive predicates") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y) + .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("y.a".attr === 1)) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('a === 1).subquery('x) + val right = testRelation.where('a === 1).subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr)) + .analyze + + comparePlans(optimized, optimizer.EliminateSubqueries(correctAnswer)) + } + + test("joins: conjunctive predicates #2") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y) + .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1)) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('a === 1).subquery('x) + val right = testRelation.subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr)) + .analyze + + comparePlans(optimized, optimizer.EliminateSubqueries(correctAnswer)) + } + + test("joins: conjunctive predicates #3") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val z = testRelation.subquery('z) + + val originalQuery = { + z.join(x.join(y)) + .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) + } + + val optimized = Optimize(originalQuery.analyze) + val lleft = testRelation.where('a >= 3).subquery('z) + val left = testRelation.where('a === 1).subquery('x) + val right = testRelation.subquery('y) + val correctAnswer = + lleft.join( + left.join(right, condition = Some("x.b".attr === "y.b".attr)), + condition = Some("z.a".attr === "x.b".attr)) + .analyze + + comparePlans(optimized, optimizer.EliminateSubqueries(correctAnswer)) + } +}
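\ No newline at end of file

The join tests above all follow one shape: a filter sitting on top of a join is split into conjuncts, and each conjunct that references only one side is pushed below the join, while cross-relation conjuncts (such as "x.b".attr === "y.b".attr) stay behind as the join condition. A before/after sketch in the same DSL — illustrative only; t, before, and after are invented names:

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val t = LocalRelation('a.int, 'b.int)
val x = t.subquery('x)
val y = t.subquery('y)

// Before: a single filter above the join, so the join consumes both inputs in full.
val before = x.join(y).where("x.a".attr === 1 && "y.a".attr === 2)

// After pushdown: each conjunct lands on the side it references, shrinking the
// join's inputs before any rows are matched.
val after = t.where('a === 1).join(t.where('a === 2))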
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala new file mode 100644 index 0000000000..7b3653d0f9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala @@ -0,0 +1,44 @@ +package org.apache.spark.sql
+package catalyst
+package optimizer
+
+import org.scalatest.FunSuite
+
+import types.IntegerType
+import util._
+import plans.logical.{LogicalPlan, LocalRelation}
+import expressions._
+
+/* Implicit conversions for creating query plans */
+import dsl._
+
+/**
+ * Provides helper methods for comparing plans produced by optimization rules with the expected
+ * result.
+ */
+class OptimizerTest extends FunSuite {
+
+  /**
+   * Since attribute references are given globally unique ids during analysis,
+   * we must normalize them to check if two different queries are identical.
+   */
+  protected def normalizeExprIds(plan: LogicalPlan) = {
+    val minId = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)).min
+    plan transformAllExpressions {
+      case a: AttributeReference =>
+        AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId))
+    }
+  }
+
+  /** Fails the test if the two plans do not match. */
+  protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
+    val normalized1 = normalizeExprIds(plan1)
+    val normalized2 = normalizeExprIds(plan2)
+    if (normalized1 != normalized2)
+      fail(
+        s"""
+          |=== FAIL: Plans do not match ===
+          |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
+        """.stripMargin)
+  }
+}
\ No newline at end of file
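A closing note on why OptimizerTest normalizes expression ids: every analysis pass mints fresh, globally unique ExprIds, so two structurally identical plans compare unequal until their ids are rebased to a common origin. A hedged sketch of the problem — p1 and p2 are invented names, and it assumes the DSL and analyzer behave as in the suites above:

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// Structurally identical plans built independently: each LocalRelation('a.int)
// call allocates an attribute with a fresh ExprId, so the two trees compare
// unequal even though they describe the same query.
val p1 = LocalRelation('a.int).select('a).analyze
val p2 = LocalRelation('a.int).select('a).analyze

// normalizeExprIds rebases every id in a plan so that the smallest becomes
// zero; after that, comparePlans(p1, p2) sees identical trees.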