From 83f2a2f14e4145a04672e42216d43100a66b1fc2 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Mon, 7 Apr 2014 10:45:31 -0700
Subject: [sql] Rename Expression.apply to eval for better readability.

Also used this opportunity to add a bunch of override's and made some
members private.

Author: Reynold Xin

Closes #340 from rxin/eval and squashes the following commits:

a7c7ca7 [Reynold Xin] Fixed conflicts in merge.
9069de6 [Reynold Xin] Merge branch 'master' into eval
3ccc313 [Reynold Xin] Merge branch 'master' into eval
1a47e10 [Reynold Xin] Renamed apply to eval for generators and added a bunch of override's.
ea061de [Reynold Xin] Rename Expression.apply to eval for better readability.
---
 .../sql/catalyst/expressions/BoundAttribute.scala  |  2 +-
 .../spark/sql/catalyst/expressions/Cast.scala      |  4 +-
 .../sql/catalyst/expressions/Expression.scala      | 26 +++---
 .../sql/catalyst/expressions/Projection.scala      |  5 +-
 .../spark/sql/catalyst/expressions/Row.scala       |  4 +-
 .../spark/sql/catalyst/expressions/ScalaUdf.scala  |  8 +-
 .../sql/catalyst/expressions/WrapDynamic.scala     |  2 +-
 .../sql/catalyst/expressions/aggregates.scala      | 96 +++++++++++-----------
 .../sql/catalyst/expressions/arithmetic.scala      | 12 +--
 .../sql/catalyst/expressions/complexTypes.scala    | 14 ++--
 .../sql/catalyst/expressions/generators.scala      | 20 ++---
 .../spark/sql/catalyst/expressions/literals.scala  |  6 +-
 .../catalyst/expressions/namedExpressions.scala    |  2 +-
 .../sql/catalyst/expressions/nullFunctions.scala   | 12 +--
 .../sql/catalyst/expressions/predicates.scala      | 46 +++++------
 .../catalyst/expressions/stringOperations.scala    | 10 +--
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  2 +-
 .../expressions/ExpressionEvaluationSuite.scala    |  8 +-
 18 files changed, 138 insertions(+), 141 deletions(-)

(limited to 'sql/catalyst')

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 37b9035df9..4ebf6c4584 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -45,7 +45,7 @@ case class BoundReference(ordinal: Int, baseReference: Attribute)

   override def toString = s"$baseReference:$ordinal"

-  override def apply(input: Row): Any = input(ordinal)
+  override def eval(input: Row): Any = input(ordinal)
 }

 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 941b53fe70..89226999ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -185,8 +185,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
     case DoubleType => castToDouble
   }

-  override def apply(input: Row): Any = {
-    val evaluated = child.apply(input)
+  override def eval(input: Row): Any = {
+    val evaluated = child.eval(input)
     if (evaluated == null) {
       null
     } else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index a3d1952550..f190bd0cca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -17,8 +17,8 @@
 package org.apache.spark.sql.catalyst.expressions

-import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType}

 abstract class Expression extends TreeNode[Expression] {
@@ -50,7 +50,7 @@ abstract class Expression extends TreeNode[Expression] {
   def references: Set[Attribute]

   /** Returns the result of evaluating this expression on a given input Row */
-  def apply(input: Row = null): EvaluatedType =
+  def eval(input: Row = null): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

   /**
@@ -73,7 +73,7 @@ abstract class Expression extends TreeNode[Expression] {
    */
   @inline
   def n1(e: Expression, i: Row, f: ((Numeric[Any], Any) => Any)): Any = {
-    val evalE = e.apply(i)
+    val evalE = e.eval(i)
     if (evalE == null) {
       null
     } else {
@@ -102,11 +102,11 @@ abstract class Expression extends TreeNode[Expression] {
       throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
     }

-    val evalE1 = e1.apply(i)
+    val evalE1 = e1.eval(i)
     if(evalE1 == null) {
       null
     } else {
-      val evalE2 = e2.apply(i)
+      val evalE2 = e2.eval(i)
       if (evalE2 == null) {
         null
       } else {
@@ -135,11 +135,11 @@ abstract class Expression extends TreeNode[Expression] {
       throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
     }

-    val evalE1 = e1.apply(i: Row)
+    val evalE1 = e1.eval(i: Row)
     if(evalE1 == null) {
       null
     } else {
-      val evalE2 = e2.apply(i: Row)
+      val evalE2 = e2.eval(i: Row)
       if (evalE2 == null) {
         null
       } else {
@@ -168,11 +168,11 @@ abstract class Expression extends TreeNode[Expression] {
       throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
     }

-    val evalE1 = e1.apply(i)
+    val evalE1 = e1.eval(i)
     if(evalE1 == null) {
       null
     } else {
-      val evalE2 = e2.apply(i)
+      val evalE2 = e2.eval(i)
       if (evalE2 == null) {
         null
       } else {
@@ -205,11 +205,11 @@ abstract class Expression extends TreeNode[Expression] {
       throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
     }

-    val evalE1 = e1.apply(i)
+    val evalE1 = e1.eval(i)
     if(evalE1 == null) {
       null
     } else {
-      val evalE2 = e2.apply(i)
+      val evalE2 = e2.eval(i)
       if (evalE2 == null) {
         null
       } else {
@@ -231,7 +231,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express

   override def foldable = left.foldable && right.foldable

-  def references = left.references ++ right.references
+  override def references = left.references ++ right.references

   override def toString = s"($left $symbol $right)"
 }
@@ -243,5 +243,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
 abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
   self: Product =>

-  def references = child.references
+  override def references = child.references
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 5576ecbb65..c9b7cea6a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -27,11 +27,12 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) {
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))

   protected val exprArray = expressions.toArray
+
   def apply(input: Row): Row = {
     val outputArray = new Array[Any](exprArray.length)
     var i = 0
     while (i < exprArray.length) {
-      outputArray(i) = exprArray(i).apply(input)
+      outputArray(i) = exprArray(i).eval(input)
       i += 1
     }
     new GenericRow(outputArray)
@@ -58,7 +59,7 @@ case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row)
   def apply(input: Row): Row = {
     var i = 0
     while (i < exprArray.length) {
-      mutableRow(i) = exprArray(i).apply(input)
+      mutableRow(i) = exprArray(i).eval(input)
       i += 1
     }
     mutableRow
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index 9f4d84466e..0f06ea088e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -212,8 +212,8 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
     var i = 0
     while (i < ordering.size) {
       val order = ordering(i)
-      val left = order.child.apply(a)
-      val right = order.child.apply(b)
+      val left = order.child.eval(a)
+      val right = order.child.eval(b)

       if (left == null && right == null) {
         // Both null, continue looking.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index f53d8504b0..5e089f7618 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -27,13 +27,13 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
   def references = children.flatMap(_.references).toSet
   def nullable = true

-  override def apply(input: Row): Any = {
+  override def eval(input: Row): Any = {
     children.size match {
-      case 1 => function.asInstanceOf[(Any) => Any](children(0).apply(input))
+      case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
       case 2 =>
         function.asInstanceOf[(Any, Any) => Any](
-          children(0).apply(input),
-          children(1).apply(input))
+          children(0).eval(input),
+          children(1).eval(input))
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index 9828d0b9bd..e787c59e75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -30,7 +30,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression {
   def references = children.toSet
   def dataType = DynamicType

-  override def apply(input: Row): DynamicRow = input match {
+  override def eval(input: Row): DynamicRow = input match {
     // Avoid copy for generic rows.
     case g: GenericRow => new DynamicRow(children, g.values)
     case otherRowType => new DynamicRow(children, otherRowType.toArray)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 53b884a41e..5edcea1427 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -43,7 +43,7 @@ case class SplitEvaluation(
     partialEvaluations: Seq[NamedExpression])

 /**
- * An [[AggregateExpression]] that can be partially computed without seeing all relevent tuples.
+ * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
  * These partial evaluations can then be combined to compute the actual answer.
  */
 abstract class PartialAggregate extends AggregateExpression {
@@ -63,28 +63,28 @@ abstract class AggregateFunction
   extends AggregateExpression with Serializable with trees.LeafNode[Expression] {
   self: Product =>

-  type EvaluatedType = Any
+  override type EvaluatedType = Any

   /** Base should return the generic aggregate expression that this function is computing */
   val base: AggregateExpression
-  def references = base.references
-  def nullable = base.nullable
-  def dataType = base.dataType
+  override def references = base.references
+  override def nullable = base.nullable
+  override def dataType = base.dataType

   def update(input: Row): Unit
-  override def apply(input: Row): Any
+  override def eval(input: Row): Any

   // Do we really need this?
-  def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
+  override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
 }

 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  def references = child.references
-  def nullable = false
-  def dataType = IntegerType
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
   override def toString = s"COUNT($child)"

-  def asPartial: SplitEvaluation = {
+  override def asPartial: SplitEvaluation = {
     val partialCount = Alias(Count(child), "PartialCount")()
     SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
   }
@@ -93,18 +93,18 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
 }

 case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
-  def children = expressions
-  def references = expressions.flatMap(_.references).toSet
-  def nullable = false
-  def dataType = IntegerType
+  override def children = expressions
+  override def references = expressions.flatMap(_.references).toSet
+  override def nullable = false
+  override def dataType = IntegerType
   override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
   override def newInstance()= new CountDistinctFunction(expressions, this)
 }

 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  def references = child.references
-  def nullable = false
-  def dataType = DoubleType
+  override def references = child.references
+  override def nullable = false
+  override def dataType = DoubleType
   override def toString = s"AVG($child)"

   override def asPartial: SplitEvaluation = {
@@ -122,9 +122,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
 }

 case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  def references = child.references
-  def nullable = false
-  def dataType = child.dataType
+  override def references = child.references
+  override def nullable = false
+  override def dataType = child.dataType
   override def toString = s"SUM($child)"

   override def asPartial: SplitEvaluation = {
@@ -140,18 +140,18 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[

 case class SumDistinct(child: Expression)
   extends AggregateExpression with trees.UnaryNode[Expression] {
-  def references = child.references
-  def nullable = false
-  def dataType = child.dataType
+  override def references = child.references
+  override def nullable = false
+  override def dataType = child.dataType
   override def toString = s"SUM(DISTINCT $child)"

   override def newInstance()= new SumDistinctFunction(child, this)
 }

 case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  def references = child.references
-  def nullable = child.nullable
-  def dataType = child.dataType
+  override def references = child.references
+  override def nullable = child.nullable
+  override def dataType = child.dataType
   override def toString = s"FIRST($child)"

   override def asPartial: SplitEvaluation = {
@@ -169,17 +169,15 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)

   def this() = this(null, null) // Required for serialization.

   private var count: Long = _
-  private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(EmptyRow))
+  private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow))
   private val sumAsDouble = Cast(sum, DoubleType)
-
-
   private val addFunction = Add(sum, expr)

-  override def apply(input: Row): Any =
-    sumAsDouble.apply(EmptyRow).asInstanceOf[Double] / count.toDouble
+  override def eval(input: Row): Any =
+    sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble

-  def update(input: Row): Unit = {
+  override def update(input: Row): Unit = {
     count += 1
     sum.update(addFunction, input)
   }
@@ -190,28 +188,28 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
   var count: Int = _

-  def update(input: Row): Unit = {
-    val evaluatedExpr = expr.map(_.apply(input))
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.map(_.eval(input))
     if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
       count += 1
     }
   }

-  override def apply(input: Row): Any = count
+  override def eval(input: Row): Any = count
 }

 case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {

   def this() = this(null, null) // Required for serialization.

-  private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(null))
+  private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(null))

   private val addFunction = Add(sum, expr)

-  def update(input: Row): Unit = {
+  override def update(input: Row): Unit = {
     sum.update(addFunction, input)
   }

-  override def apply(input: Row): Any = sum.apply(null)
+  override def eval(input: Row): Any = sum.eval(null)
 }

 case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
@@ -219,16 +217,16 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpressio

   def this() = this(null, null) // Required for serialization.

-  val seen = new scala.collection.mutable.HashSet[Any]()
+  private val seen = new scala.collection.mutable.HashSet[Any]()

-  def update(input: Row): Unit = {
-    val evaluatedExpr = expr.apply(input)
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.eval(input)
     if (evaluatedExpr != null) {
       seen += evaluatedExpr
     }
   }

-  override def apply(input: Row): Any =
+  override def eval(input: Row): Any =
     seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
 }

@@ -239,14 +237,14 @@ case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpressio

   val seen = new scala.collection.mutable.HashSet[Any]()

-  def update(input: Row): Unit = {
-    val evaluatedExpr = expr.map(_.apply(input))
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.map(_.eval(input))
     if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) {
       seen += evaluatedExpr
     }
   }

-  override def apply(input: Row): Any = seen.size
+  override def eval(input: Row): Any = seen.size
 }

 case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -254,11 +252,11 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag

   var result: Any = null

-  def update(input: Row): Unit = {
+  override def update(input: Row): Unit = {
     if (result == null) {
-      result = expr.apply(input)
+      result = expr.eval(input)
     }
   }

-  override def apply(input: Row): Any = result
+  override def eval(input: Row): Any = result
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index fba056e7c0..c79c1847ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -28,7 +28,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression {
   def nullable = child.nullable
   override def toString = s"-$child"

-  override def apply(input: Row): Any = {
+  override def eval(input: Row): Any = {
     n1(child, input, _.negate(_))
   }
 }
@@ -55,25 +55,25 @@ abstract class BinaryArithmetic extends BinaryExpression {
 case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
   def symbol = "+"

-  override def apply(input: Row): Any = n2(input, left, right, _.plus(_, _))
+  override def eval(input: Row): Any = n2(input, left, right, _.plus(_, _))
 }

 case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
   def symbol = "-"

-  override def apply(input: Row): Any = n2(input, left, right, _.minus(_, _))
+  override def eval(input: Row): Any = n2(input, left, right, _.minus(_, _))
 }

 case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
   def symbol = "*"

-  override def apply(input: Row): Any = n2(input, left, right, _.times(_, _))
+  override def eval(input: Row): Any = n2(input, left, right, _.times(_, _))
 }

 case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
   def symbol = "/"

-  override def apply(input: Row): Any = dataType match {
+  override def eval(input: Row): Any = dataType match {
     case _: FractionalType => f2(input, left, right, _.div(_, _))
     case _: IntegralType => i2(input, left , right, _.quot(_, _))
   }
@@ -83,5 +83,5 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
 case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
symbol = "%" - override def apply(input: Row): Any = i2(input, left, right, _.rem(_, _)) + override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index ab96618d73..c947155cb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -39,10 +39,10 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { override def toString = s"$child[$ordinal]" - override def apply(input: Row): Any = { + override def eval(input: Row): Any = { if (child.dataType.isInstanceOf[ArrayType]) { - val baseValue = child.apply(input).asInstanceOf[Seq[_]] - val o = ordinal.apply(input).asInstanceOf[Int] + val baseValue = child.eval(input).asInstanceOf[Seq[_]] + val o = ordinal.eval(input).asInstanceOf[Int] if (baseValue == null) { null } else if (o >= baseValue.size || o < 0) { @@ -51,8 +51,8 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { baseValue(o) } } else { - val baseValue = child.apply(input).asInstanceOf[Map[Any, _]] - val key = ordinal.apply(input) + val baseValue = child.eval(input).asInstanceOf[Map[Any, _]] + val key = ordinal.eval(input) if (baseValue == null) { null } else { @@ -85,8 +85,8 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType] - override def apply(input: Row): Any = { - val baseValue = child.apply(input).asInstanceOf[Row] + override def eval(input: Row): Any = { + val baseValue = child.eval(input).asInstanceOf[Row] if (baseValue == null) null else baseValue(ordinal) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e9b491b10a..dd78614754 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -35,17 +35,17 @@ import org.apache.spark.sql.catalyst.types._ * requested. The attributes produced by this function will be automatically copied anytime rules * result in changes to the Generator or its children. */ -abstract class Generator extends Expression with (Row => TraversableOnce[Row]) { +abstract class Generator extends Expression { self: Product => - type EvaluatedType = TraversableOnce[Row] + override type EvaluatedType = TraversableOnce[Row] - lazy val dataType = + override lazy val dataType = ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable)))) - def nullable = false + override def nullable = false - def references = children.flatMap(_.references).toSet + override def references = children.flatMap(_.references).toSet /** * Should be overridden by specific generators. Called only once for each instance to ensure @@ -63,7 +63,7 @@ abstract class Generator extends Expression with (Row => TraversableOnce[Row]) { } /** Should be implemented by child classes to perform specific Generators. 
-  def apply(input: Row): TraversableOnce[Row]
+  override def eval(input: Row): TraversableOnce[Row]

   /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
   override def makeCopy(newArgs: Array[AnyRef]): this.type = {
@@ -83,7 +83,7 @@ case class Explode(attributeNames: Seq[String], child: Expression)
     child.resolved &&
     (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])

-  lazy val elementTypes = child.dataType match {
+  private lazy val elementTypes = child.dataType match {
     case ArrayType(et) => et :: Nil
     case MapType(kt,vt) => kt :: vt :: Nil
   }
@@ -100,13 +100,13 @@ case class Explode(attributeNames: Seq[String], child: Expression)
     }
   }

-  override def apply(input: Row): TraversableOnce[Row] = {
+  override def eval(input: Row): TraversableOnce[Row] = {
     child.dataType match {
       case ArrayType(_) =>
-        val inputArray = child.apply(input).asInstanceOf[Seq[Any]]
+        val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
         if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v)))
       case MapType(_, _) =>
-        val inputMap = child.apply(input).asInstanceOf[Map[Any,Any]]
+        val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]]
         if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index d879b2b5e8..e15e16d633 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -57,7 +57,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression {
   override def toString = if (value != null) value.toString else "null"

   type EvaluatedType = Any
-  override def apply(input: Row):Any = value
+  override def eval(input: Row):Any = value
 }

 // TODO: Specialize
@@ -69,8 +69,8 @@ case class MutableLiteral(var value: Any, nullable: Boolean = true) extends Leaf
   def references = Set.empty

   def update(expression: Expression, input: Row) = {
-    value = expression.apply(input)
+    value = expression.eval(input)
   }

-  override def apply(input: Row) = value
+  override def eval(input: Row) = value
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 69c8bed309..eb4bc8e755 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -79,7 +79,7 @@ case class Alias(child: Expression, name: String)

   type EvaluatedType = Any

-  override def apply(input: Row) = child.apply(input)
+  override def eval(input: Row) = child.eval(input)

   def dataType = child.dataType
   def nullable = child.nullable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 5a47768dcb..ce6d99c911 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -41,11 +41,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
     throw new UnresolvedException(this, "Coalesce cannot have children of different types.")
   }

-  override def apply(input: Row): Any = {
+  override def eval(input: Row): Any = {
     var i = 0
     var result: Any = null
     while(i < children.size && result == null) {
-      result = children(i).apply(input)
+      result = children(i).eval(input)
       i += 1
     }
     result
@@ -57,8 +57,8 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
   override def foldable = child.foldable
   def nullable = false

-  override def apply(input: Row): Any = {
-    child.apply(input) == null
+  override def eval(input: Row): Any = {
+    child.eval(input) == null
   }
 }

@@ -68,7 +68,7 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
   def nullable = false
   override def toString = s"IS NOT NULL $child"

-  override def apply(input: Row): Any = {
-    child.apply(input) != null
+  override def eval(input: Row): Any = {
+    child.eval(input) != null
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index b74809e5ca..da5b2cf5b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.types.{BooleanType, StringType, TimestampTy

 object InterpretedPredicate {
   def apply(expression: Expression): (Row => Boolean) = {
-    (r: Row) => expression.apply(r).asInstanceOf[Boolean]
+    (r: Row) => expression.eval(r).asInstanceOf[Boolean]
   }
 }

@@ -54,8 +54,8 @@ case class Not(child: Expression) extends Predicate with trees.UnaryNode[Express
   def nullable = child.nullable
   override def toString = s"NOT $child"

-  override def apply(input: Row): Any = {
-    child.apply(input) match {
+  override def eval(input: Row): Any = {
+    child.eval(input) match {
       case null => null
       case b: Boolean => !b
     }
@@ -71,18 +71,18 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
   def nullable = true // TODO: Figure out correct nullability semantics of IN.
override def toString = s"$value IN ${list.mkString("(", ",", ")")}" - override def apply(input: Row): Any = { - val evaluatedValue = value.apply(input) - list.exists(e => e.apply(input) == evaluatedValue) + override def eval(input: Row): Any = { + val evaluatedValue = value.eval(input) + list.exists(e => e.eval(input) == evaluatedValue) } } case class And(left: Expression, right: Expression) extends BinaryPredicate { def symbol = "&&" - override def apply(input: Row): Any = { - val l = left.apply(input) - val r = right.apply(input) + override def eval(input: Row): Any = { + val l = left.eval(input) + val r = right.eval(input) if (l == false || r == false) { false } else if (l == null || r == null ) { @@ -96,9 +96,9 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate { case class Or(left: Expression, right: Expression) extends BinaryPredicate { def symbol = "||" - override def apply(input: Row): Any = { - val l = left.apply(input) - val r = right.apply(input) + override def eval(input: Row): Any = { + val l = left.eval(input) + val r = right.eval(input) if (l == true || r == true) { true } else if (l == null || r == null) { @@ -115,31 +115,31 @@ abstract class BinaryComparison extends BinaryPredicate { case class Equals(left: Expression, right: Expression) extends BinaryComparison { def symbol = "=" - override def apply(input: Row): Any = { - val l = left.apply(input) - val r = right.apply(input) + override def eval(input: Row): Any = { + val l = left.eval(input) + val r = right.eval(input) if (l == null || r == null) null else l == r } } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { def symbol = "<" - override def apply(input: Row): Any = c2(input, left, right, _.lt(_, _)) + override def eval(input: Row): Any = c2(input, left, right, _.lt(_, _)) } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { def symbol = "<=" - override def apply(input: Row): Any = c2(input, left, right, _.lteq(_, _)) + override def eval(input: Row): Any = c2(input, left, right, _.lteq(_, _)) } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { def symbol = ">" - override def apply(input: Row): Any = c2(input, left, right, _.gt(_, _)) + override def eval(input: Row): Any = c2(input, left, right, _.gt(_, _)) } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { def symbol = ">=" - override def apply(input: Row): Any = c2(input, left, right, _.gteq(_, _)) + override def eval(input: Row): Any = c2(input, left, right, _.gteq(_, _)) } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) @@ -159,11 +159,11 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } type EvaluatedType = Any - override def apply(input: Row): Any = { - if (predicate(input).asInstanceOf[Boolean]) { - trueValue.apply(input) + override def eval(input: Row): Any = { + if (predicate.eval(input).asInstanceOf[Boolean]) { + trueValue.eval(input) } else { - falseValue.apply(input) + falseValue.eval(input) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 42b7a9b125..a27c71db1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ 
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -22,8 +22,6 @@ import java.util.regex.Pattern
 import org.apache.spark.sql.catalyst.types.DataType
 import org.apache.spark.sql.catalyst.types.StringType
 import org.apache.spark.sql.catalyst.types.BooleanType
-import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.catalyst.errors.`package`.TreeNodeException

 trait StringRegexExpression {

@@ -52,12 +50,12 @@ trait StringRegexExpression {

   protected def pattern(str: String) = if(cache == null) compile(str) else cache

-  override def apply(input: Row): Any = {
-    val l = left.apply(input)
-    if(l == null) {
+  override def eval(input: Row): Any = {
+    val l = left.eval(input)
+    if (l == null) {
       null
     } else {
-      val r = right.apply(input)
+      val r = right.eval(input)
       if(r == null) {
         null
       } else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3dd6818029..37b23ba582 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -45,7 +45,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
     case q: LogicalPlan => q transformExpressionsDown {
       // Skip redundant folding of literals.
      case l: Literal => l
-      case e if e.foldable => Literal(e.apply(null), e.dataType)
+      case e if e.foldable => Literal(e.eval(null), e.dataType)
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 43876033d3..92987405aa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 class ExpressionEvaluationSuite extends FunSuite {

   test("literals") {
-    assert((Literal(1) + Literal(1)).apply(null) === 2)
+    assert((Literal(1) + Literal(1)).eval(null) === 2)
   }

   /**
@@ -62,7 +62,7 @@ class ExpressionEvaluationSuite extends FunSuite {
     notTrueTable.foreach {
       case (v, answer) =>
         val expr = Not(Literal(v, BooleanType))
-        val result = expr.apply(null)
+        val result = expr.eval(null)
         if (result != answer)
           fail(s"$expr should not evaluate to $result, expected: $answer")
     }
@@ -105,7 +105,7 @@ class ExpressionEvaluationSuite extends FunSuite {
     truthTable.foreach {
       case (l,r,answer) =>
         val expr = op(Literal(l, BooleanType), Literal(r, BooleanType))
-        val result = expr.apply(null)
+        val result = expr.eval(null)
         if (result != answer)
           fail(s"$expr should not evaluate to $result, expected: $answer")
     }
@@ -113,7 +113,7 @@
   }

   def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
-    expression.apply(inputRow)
+    expression.eval(inputRow)
   }

   def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
--
cgit v1.2.3
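
The heart of the patch is the contract change in Expression.scala: expressions now expose eval(input: Row) instead of apply(input: Row), so evaluating a child reads as an explicit children(0).eval(input) rather than function application that is easy to mistake for Scala's apply sugar. What follows is a minimal, self-contained sketch of that contract — toy classes that assume Row is just a Seq[Any], not the real Catalyst types:

object EvalSketch {
  type Row = Seq[Any] // stand-in for Catalyst's Row trait

  trait Expression {
    // Renamed from `apply` in this patch; a null input suffices for foldable trees.
    def eval(input: Row = null): Any
  }

  // Reads a column by position, in the spirit of BoundReference above.
  case class BoundReference(ordinal: Int) extends Expression {
    override def eval(input: Row): Any = input(ordinal)
  }

  case class Literal(value: Any) extends Expression {
    override def eval(input: Row): Any = value
  }

  // Evaluates both children and propagates null, in the style of the n2 helper.
  case class Add(left: Expression, right: Expression) extends Expression {
    override def eval(input: Row): Any = {
      val evalE1 = left.eval(input)
      if (evalE1 == null) {
        null
      } else {
        val evalE2 = right.eval(input)
        if (evalE2 == null) null
        else evalE1.asInstanceOf[Int] + evalE2.asInstanceOf[Int]
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val expr = Add(BoundReference(0), Literal(1))
    println(expr.eval(Seq(41)))   // 42
    println(expr.eval(Seq(null))) // null: a null child nulls the result
  }
}

The two-step null check mirrors what the n1/n2/f2/i2/c2 helpers in Expression.scala do internally: a null on either side short-circuits to null before the numeric operation runs.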
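The aggregates.scala hunks also surface the two-method lifecycle that AggregateFunction keeps after the rename: update(input) folds one row into mutable state, and eval(...) reads the current aggregate back out. Below is a simplified stand-in under the same Row-as-Seq[Any] assumption — a hypothetical counter in the spirit of CountFunction, not the Catalyst class itself:

object AggSketch {
  type Row = Seq[Any]

  trait AggregateFunction {
    def update(input: Row): Unit     // called once per input row
    def eval(input: Row = null): Any // reads the aggregate; was `apply` before the patch
  }

  // Counts rows whose column at `ordinal` is non-null, like COUNT(col).
  class CountFunction(ordinal: Int) extends AggregateFunction {
    private var count: Int = 0 // kept private, matching the patch's visibility tightening

    override def update(input: Row): Unit = {
      if (input(ordinal) != null) count += 1
    }

    override def eval(input: Row): Any = count
  }

  def main(args: Array[String]): Unit = {
    val count = new CountFunction(0)
    Seq(Seq[Any]("a"), Seq[Any](null), Seq[Any]("b")).foreach(count.update)
    println(count.eval()) // 2
  }
}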