author     Reynold Xin <rxin@databricks.com>  2015-01-27 16:08:24 -0800
committer  Reynold Xin <rxin@databricks.com>  2015-01-27 16:08:24 -0800
commit     119f45d61d7b48d376cca05e1b4f0c7fcf65bfa8 (patch)
tree       714df6362313e93bee0e9dba2f84b3ba1697e555 /sql
parent     b1b35ca2e440df40b253bf967bb93705d355c1c0 (diff)
[SPARK-5097][SQL] DataFrame
This pull request redesigns the existing Spark SQL dsl, which already provides data frame like functionalities.

TODOs: With the exception of Python support, other tasks can be done in separate, follow-up PRs.
- [ ] Audit of the API
- [ ] Documentation
- [ ] More test cases to cover the new API
- [x] Python support
- [ ] Type alias SchemaRDD

Author: Reynold Xin <rxin@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4173 from rxin/df1 and squashes the following commits:

0a1a73b [Reynold Xin] Merge branch 'df1' of github.com:rxin/spark into df1
23b4427 [Reynold Xin] Mima.
828f70d [Reynold Xin] Merge pull request #7 from davies/df
257b9e6 [Davies Liu] add repartition
6bf2b73 [Davies Liu] fix collect with UDT and tests
e971078 [Reynold Xin] Missing quotes.
b9306b4 [Reynold Xin] Remove removeColumn/updateColumn for now.
a728bf2 [Reynold Xin] Example rename.
e8aa3d3 [Reynold Xin] groupby -> groupBy.
9662c9e [Davies Liu] improve DataFrame Python API
4ae51ea [Davies Liu] python API for dataframe
1e5e454 [Reynold Xin] Fixed a bug with symbol conversion.
2ca74db [Reynold Xin] Couple minor fixes.
ea98ea1 [Reynold Xin] Documentation & literal expressions.
2b22684 [Reynold Xin] Got rid of IntelliJ problems.
02bbfbc [Reynold Xin] Tightening imports.
ffbce66 [Reynold Xin] Fixed compilation error.
59b6d8b [Reynold Xin] Style violation.
b85edfb [Reynold Xin] ALS.
8c37f0a [Reynold Xin] Made MLlib and examples compile
6d53134 [Reynold Xin] Hive module.
d35efd5 [Reynold Xin] Fixed compilation error.
ce4a5d2 [Reynold Xin] Fixed test cases in SQL except ParquetIOSuite.
66d5ef1 [Reynold Xin] SQLContext minor patch.
c9bcdc0 [Reynold Xin] Checkpoint: SQL module compiles!
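The class-level documentation added in DataFrame.scala below sketches the intended usage of the new API. For orientation, here is a minimal example in that style; it is a sketch only, assuming an existing `sqlContext`, hypothetical Parquet paths, and that `avg` and `max` are provided by the new `org.apache.spark.sql.dsl` package introduced by this patch:

{{{
import org.apache.spark.sql.dsl._   // assumed home of aggregate functions such as avg and max

// Hypothetical inputs; any two tables with the referenced columns would do.
val people = sqlContext.parquetFile("people.parquet")
val department = sqlContext.parquetFile("department.parquet")

people.filter(people("age") > 30)
  .join(department, people("deptId") === department("id"))
  .groupBy(department("name"), department("gender"))
  .agg(avg(people("salary")), max(people("age")))
}}}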
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala | 2
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala | 3
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala | 15
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala | 8
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala | 8
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 528
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 596
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala | 139
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/Literal.scala | 98
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 85
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 511
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala | 139
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/api.scala | 289
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala | 495
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala | 8
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/package.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala | 32
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala | 5
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala | 6
-rw-r--r-- sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java | 4
-rw-r--r-- sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java | 16
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 1
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala | 119
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 67
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 12
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 18
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/TestData.scala | 23
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 11
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala | 1
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 14
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala | 65
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 11
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala | 126
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala | 7
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala | 2
-rwxr-xr-x sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 2
-rw-r--r-- sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala | 6
-rw-r--r-- sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala | 6
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 9
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala | 17
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala | 9
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala | 10
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala | 4
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 7
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala | 11
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala | 2
50 files changed, 2494 insertions, 1073 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
index 22941edef2..4c5fb3f45b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
@@ -47,7 +47,7 @@ object NewRelationInstances extends Rule[LogicalPlan] {
.toSet
plan transform {
- case l: MultiInstanceRelation if multiAppearance contains l => l.newInstance
+ case l: MultiInstanceRelation if multiAppearance.contains(l) => l.newInstance()
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 3035d934ff..f388cd5972 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -77,6 +77,9 @@ abstract class Attribute extends NamedExpression {
* For example the SQL expression "1 + 1 AS a" could be represented as follows:
* Alias(Add(Literal(1), Literal(1)), "a")()
*
+ * Note that exprId and qualifiers are in a separate parameter list because
+ * we only pattern match on child and name.
+ *
* @param child the computation being performed
* @param name the name to be associated with the result of computing [[child]].
* @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 613f4bb09d..5dc0539cae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -17,9 +17,24 @@
package org.apache.spark.sql.catalyst.plans
+object JoinType {
+ def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
+ case "inner" => Inner
+ case "outer" | "full" | "fullouter" => FullOuter
+ case "leftouter" | "left" => LeftOuter
+ case "rightouter" | "right" => RightOuter
+ case "leftsemi" => LeftSemi
+ }
+}
+
sealed abstract class JoinType
+
case object Inner extends JoinType
+
case object LeftOuter extends JoinType
+
case object RightOuter extends JoinType
+
case object FullOuter extends JoinType
+
case object LeftSemi extends JoinType
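A small illustrative sketch, not part of the patch, of how the new `JoinType.apply` factory normalizes join-type strings:

{{{
import org.apache.spark.sql.catalyst.plans._

// The factory lowercases the string and strips underscores before matching,
// so these spellings all resolve to the same case object.
assert(JoinType("left_outer") == LeftOuter)
assert(JoinType("LEFT") == LeftOuter)
assert(JoinType("fullouter") == FullOuter)
// Any other string falls through the match and fails with a MatchError.
}}}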
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
index 19769986ef..d90af45b37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
@@ -19,10 +19,14 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.types.{StructType, StructField}
object LocalRelation {
- def apply(output: Attribute*) =
- new LocalRelation(output)
+ def apply(output: Attribute*): LocalRelation = new LocalRelation(output)
+
+ def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation(
+ StructType(output1 +: output).toAttributes
+ )
}
case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil)
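The new `apply(StructField*)` overload lets tests build a `LocalRelation` directly from field definitions; a hedged sketch using the `ColumnName` builders that Column.scala adds later in this patch:

{{{
import org.apache.spark.sql.ColumnName
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// new ColumnName("id").int is shorthand for StructField("id", IntegerType).
val rel = LocalRelation(new ColumnName("id").int, new ColumnName("name").string)
}}}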
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
index e715d9434a..bc22f68833 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -80,7 +80,7 @@ private[sql] trait CacheManager {
* the in-memory columnar representation of the underlying table is expensive.
*/
private[sql] def cacheQuery(
- query: SchemaRDD,
+ query: DataFrame,
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
@@ -100,7 +100,7 @@ private[sql] trait CacheManager {
}
/** Removes the data for the given SchemaRDD from the cache */
- private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = true): Unit = writeLock {
+ private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
require(dataIndex >= 0, s"Table $query is not cached.")
@@ -110,7 +110,7 @@ private[sql] trait CacheManager {
/** Tries to remove the data for the given SchemaRDD from the cache if it's cached */
private[sql] def tryUncacheQuery(
- query: SchemaRDD,
+ query: DataFrame,
blocking: Boolean = true): Boolean = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
@@ -123,7 +123,7 @@ private[sql] trait CacheManager {
}
/** Optionally returns cached data for the given SchemaRDD */
- private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock {
+ private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock {
lookupCachedData(query.queryExecution.analyzed)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
new file mode 100644
index 0000000000..7fc8347428
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -0,0 +1,528 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
+import org.apache.spark.sql.types._
+
+
+object Column {
+ def unapply(col: Column): Option[Expression] = Some(col.expr)
+
+ def apply(colName: String): Column = new Column(colName)
+}
+
+
+/**
+ * A column in a [[DataFrame]].
+ *
+ * `Column` instances can be created by:
+ * {{{
+ * // 1. Select a column out of a DataFrame
+ * df("colName")
+ *
+ * // 2. Create a literal expression
+ * Literal(1)
+ *
+ * // 3. Create new columns from
+ * }}}
+ *
+ */
+// TODO: Improve documentation.
+class Column(
+ sqlContext: Option[SQLContext],
+ plan: Option[LogicalPlan],
+ val expr: Expression)
+ extends DataFrame(sqlContext, plan) with ExpressionApi {
+
+ /** Turn a Catalyst expression into a `Column`. */
+ protected[sql] def this(expr: Expression) = this(None, None, expr)
+
+ /**
+ * Create a new `Column` expression based on a column or attribute name.
+ * The resolution of this is the same as SQL. For example:
+ *
+ * - "colName" becomes an expression selecting the column named "colName".
+ * - "*" becomes an expression selecting all columns.
+ * - "df.*" becomes an expression selecting all columns in data frame "df".
+ */
+ def this(name: String) = this(name match {
+ case "*" => Star(None)
+ case _ if name.endsWith(".*") => Star(Some(name.substring(0, name.length - 2)))
+ case _ => UnresolvedAttribute(name)
+ })
+
+ override def isComputable: Boolean = sqlContext.isDefined && plan.isDefined
+
+ /**
+ * An implicit conversion function internal to this class. This function creates a new Column
+ * based on an expression. If the expression itself is not named, it aliases the expression
+ * by calling it "col".
+ */
+ private[this] implicit def toColumn(expr: Expression): Column = {
+ val projectedPlan = plan.map { p =>
+ Project(Seq(expr match {
+ case named: NamedExpression => named
+ case unnamed: Expression => Alias(unnamed, "col")()
+ }), p)
+ }
+ new Column(sqlContext, projectedPlan, expr)
+ }
+
+ /**
+ * Unary minus, i.e. negate the expression.
+ * {{{
+ * // Select the amount column and negate all values.
+ * df.select( -df("amount") )
+ * }}}
+ */
+ override def unary_- : Column = UnaryMinus(expr)
+
+ /**
+ * Bitwise NOT.
+ * {{{
+ * // Select the flags column and negate every bit.
+ * df.select( ~df("flags") )
+ * }}}
+ */
+ override def unary_~ : Column = BitwiseNot(expr)
+
+ /**
+ * Invert a boolean expression, i.e. NOT.
+ * {{{
+ * // Select rows that are not active (isActive === false)
+ * df.select( !df("isActive") )
+ * }}}
+ */
+ override def unary_! : Column = Not(expr)
+
+
+ /**
+ * Equality test with an expression.
+ * {{{
+ * // The following two both select rows in which colA equals colB.
+ * df.select( df("colA") === df("colB") )
+ * df.select( df("colA".equalTo(df("colB")) )
+ * }}}
+ */
+ override def === (other: Column): Column = EqualTo(expr, other.expr)
+
+ /**
+ * Equality test with a literal value.
+ * {{{
+ * // The following two both select rows in which colA is "Zaharia".
+ * df.select( df("colA") === "Zaharia")
+ * df.select( df("colA".equalTo("Zaharia") )
+ * }}}
+ */
+ override def === (literal: Any): Column = this === Literal.anyToLiteral(literal)
+
+ /**
+ * Equality test with an expression.
+ * {{{
+ * // The following two both select rows in which colA equals colB.
+ * df.select( df("colA") === df("colB") )
+ * df.select( df("colA".equalTo(df("colB")) )
+ * }}}
+ */
+ override def equalTo(other: Column): Column = this === other
+
+ /**
+ * Equality test with a literal value.
+ * {{{
+ * // The following two both select rows in which colA is "Zaharia".
+ * df.select( df("colA") === "Zaharia")
+ * df.select( df("colA".equalTo("Zaharia") )
+ * }}}
+ */
+ override def equalTo(literal: Any): Column = this === literal
+
+ /**
+ * Inequality test with an expression.
+ * {{{
+ * // The following two both select rows in which colA does not equal colB.
+ * df.select( df("colA") !== df("colB") )
+ * df.select( !(df("colA") === df("colB")) )
+ * }}}
+ */
+ override def !== (other: Column): Column = Not(EqualTo(expr, other.expr))
+
+ /**
+ * Inequality test with a literal value.
+ * {{{
+ * // The following two both select rows in which colA does not equal 15.
+ * df.select( df("colA") !== 15 )
+ * df.select( !(df("colA") === 15) )
+ * }}}
+ */
+ override def !== (literal: Any): Column = this !== Literal.anyToLiteral(literal)
+
+ /**
+ * Greater than an expression.
+ * {{{
+ * // The following selects people older than 21.
+ * people.select( people("age") > Literal(21) )
+ * }}}
+ */
+ override def > (other: Column): Column = GreaterThan(expr, other.expr)
+
+ /**
+ * Greater than a literal value.
+ * {{{
+ * // The following selects people older than 21.
+ * people.select( people("age") > 21 )
+ * }}}
+ */
+ override def > (literal: Any): Column = this > Literal.anyToLiteral(literal)
+
+ /**
+ * Less than an expression.
+ * {{{
+ * // The following selects people younger than 21.
+ * people.select( people("age") < Literal(21) )
+ * }}}
+ */
+ override def < (other: Column): Column = LessThan(expr, other.expr)
+
+ /**
+ * Less than a literal value.
+ * {{{
+ * // The following selects people younger than 21.
+ * people.select( people("age") < 21 )
+ * }}}
+ */
+ override def < (literal: Any): Column = this < Literal.anyToLiteral(literal)
+
+ /**
+ * Less than or equal to an expression.
+ * {{{
+ * // The following selects people age 21 or younger.
+ * people.select( people("age") <= Literal(21) )
+ * }}}
+ */
+ override def <= (other: Column): Column = LessThanOrEqual(expr, other.expr)
+
+ /**
+ * Less than or equal to a literal value.
+ * {{{
+ * // The following selects people age 21 or younger.
+ * people.select( people("age") <= 21 )
+ * }}}
+ */
+ override def <= (literal: Any): Column = this <= Literal.anyToLiteral(literal)
+
+ /**
+ * Greater than or equal to an expression.
+ * {{{
+ * // The following selects people age 21 or older.
+ * people.select( people("age") >= Literal(21) )
+ * }}}
+ */
+ override def >= (other: Column): Column = GreaterThanOrEqual(expr, other.expr)
+
+ /**
+ * Greater than or equal to a literal value.
+ * {{{
+ * // The following selects people age 21 or older.
+ * people.select( people("age") >= 21 )
+ * }}}
+ */
+ override def >= (literal: Any): Column = this >= Literal.anyToLiteral(literal)
+
+ /**
+ * Equality test with an expression that is safe for null values.
+ */
+ override def <=> (other: Column): Column = EqualNullSafe(expr, other.expr)
+
+ /**
+ * Equality test with a literal value that is safe for null values.
+ */
+ override def <=> (literal: Any): Column = this <=> Literal.anyToLiteral(literal)
+
+ /**
+ * True if the current expression is null.
+ */
+ override def isNull: Column = IsNull(expr)
+
+ /**
+ * True if the current expression is NOT null.
+ */
+ override def isNotNull: Column = IsNotNull(expr)
+
+ /**
+ * Boolean OR with an expression.
+ * {{{
+ * // The following selects people that are in school or employed.
+ * people.select( people("inSchool") || people("isEmployed") )
+ * }}}
+ */
+ override def || (other: Column): Column = Or(expr, other.expr)
+
+ /**
+ * Boolean OR with a literal value.
+ * {{{
+ * // The following selects everything.
+ * people.select( people("inSchool") || true )
+ * }}}
+ */
+ override def || (literal: Boolean): Column = this || Literal.anyToLiteral(literal)
+
+ /**
+ * Boolean AND with an expression.
+ * {{{
+ * // The following selects people that are in school and employed at the same time.
+ * people.select( people("inSchool") && people("isEmployed") )
+ * }}}
+ */
+ override def && (other: Column): Column = And(expr, other.expr)
+
+ /**
+ * Boolean AND with a literal value.
+ * {{{
+ * // The following selects people that are in school.
+ * people.select( people("inSchool") && true )
+ * }}}
+ */
+ override def && (literal: Boolean): Column = this && Literal.anyToLiteral(literal)
+
+ /**
+ * Bitwise AND with an expression.
+ */
+ override def & (other: Column): Column = BitwiseAnd(expr, other.expr)
+
+ /**
+ * Bitwise AND with a literal value.
+ */
+ override def & (literal: Any): Column = this & Literal.anyToLiteral(literal)
+
+ /**
+ * Bitwise OR with an expression.
+ */
+ override def | (other: Column): Column = BitwiseOr(expr, other.expr)
+
+ /**
+ * Bitwise OR with a literal value.
+ */
+ override def | (literal: Any): Column = this | Literal.anyToLiteral(literal)
+
+ /**
+ * Bitwise XOR with an expression.
+ */
+ override def ^ (other: Column): Column = BitwiseXor(expr, other.expr)
+
+ /**
+ * Bitwise XOR with a literal value.
+ */
+ override def ^ (literal: Any): Column = this ^ Literal.anyToLiteral(literal)
+
+ /**
+ * Sum of this expression and another expression.
+ * {{{
+ * // The following selects the sum of a person's height and weight.
+ * people.select( people("height") + people("weight") )
+ * }}}
+ */
+ override def + (other: Column): Column = Add(expr, other.expr)
+
+ /**
+ * Sum of this expression and another expression.
+ * {{{
+ * // The following selects the sum of a person's height and 10.
+ * people.select( people("height") + 10 )
+ * }}}
+ */
+ override def + (literal: Any): Column = this + Literal.anyToLiteral(literal)
+
+ /**
+ * Subtraction. Subtract the other expression from this expression.
+ * {{{
+ * // The following selects the difference between people's height and their weight.
+ * people.select( people("height") - people("weight") )
+ * }}}
+ */
+ override def - (other: Column): Column = Subtract(expr, other.expr)
+
+ /**
+ * Subtraction. Subtract a literal value from this expression.
+ * {{{
+ * // The following subtracts 10 from a person's height.
+ * people.select( people("height") - 10 )
+ * }}}
+ */
+ override def - (literal: Any): Column = this - Literal.anyToLiteral(literal)
+
+ /**
+ * Multiply this expression and another expression.
+ * {{{
+ * // The following multiplies a person's height by their weight.
+ * people.select( people("height") * people("weight") )
+ * }}}
+ */
+ override def * (other: Column): Column = Multiply(expr, other.expr)
+
+ /**
+ * Multiply this expression and a literal value.
+ * {{{
+ * // The following multiplies a person's height by 10.
+ * people.select( people("height") * 10 )
+ * }}}
+ */
+ override def * (literal: Any): Column = this * Literal.anyToLiteral(literal)
+
+ /**
+ * Divide this expression by another expression.
+ * {{{
+ * // The following divides a person's height by their weight.
+ * people.select( people("height") / people("weight") )
+ * }}}
+ */
+ override def / (other: Column): Column = Divide(expr, other.expr)
+
+ /**
+ * Divide this expression by a literal value.
+ * {{{
+ * // The following divides a person's height by 10.
+ * people.select( people("height") / 10 )
+ * }}}
+ */
+ override def / (literal: Any): Column = this / Literal.anyToLiteral(literal)
+
+ /**
+ * Modulo (a.k.a. remainder) expression.
+ */
+ override def % (other: Column): Column = Remainder(expr, other.expr)
+
+ /**
+ * Modulo (a.k.a. remainder) expression.
+ */
+ override def % (literal: Any): Column = this % Literal.anyToLiteral(literal)
+
+
+ /**
+ * A boolean expression that is evaluated to true if the value of this expression is contained
+ * by the evaluated values of the arguments.
+ */
+ @scala.annotation.varargs
+ override def in(list: Column*): Column = In(expr, list.map(_.expr))
+
+ override def like(other: Column): Column = Like(expr, other.expr)
+
+ override def like(literal: String): Column = this.like(Literal.anyToLiteral(literal))
+
+ override def rlike(other: Column): Column = RLike(expr, other.expr)
+
+ override def rlike(literal: String): Column = this.rlike(Literal.anyToLiteral(literal))
+
+
+ override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal))
+
+ override def getItem(ordinal: Column): Column = GetItem(expr, ordinal.expr)
+
+ override def getField(fieldName: String): Column = GetField(expr, fieldName)
+
+
+ override def substr(startPos: Column, len: Column): Column =
+ Substring(expr, startPos.expr, len.expr)
+
+ override def substr(startPos: Int, len: Int): Column =
+ this.substr(Literal.anyToLiteral(startPos), Literal.anyToLiteral(len))
+
+ override def contains(other: Column): Column = Contains(expr, other.expr)
+
+ override def contains(literal: Any): Column = this.contains(Literal.anyToLiteral(literal))
+
+
+ override def startsWith(other: Column): Column = StartsWith(expr, other.expr)
+
+ override def startsWith(literal: String): Column = this.startsWith(Literal.anyToLiteral(literal))
+
+ override def endsWith(other: Column): Column = EndsWith(expr, other.expr)
+
+ override def endsWith(literal: String): Column = this.endsWith(Literal.anyToLiteral(literal))
+
+ override def as(alias: String): Column = Alias(expr, alias)()
+
+ override def cast(to: DataType): Column = Cast(expr, to)
+
+ override def desc: Column = SortOrder(expr, Descending)
+
+ override def asc: Column = SortOrder(expr, Ascending)
+}
+
+
+class ColumnName(name: String) extends Column(name) {
+
+ /** Creates a new StructField of type boolean */
+ def boolean: StructField = StructField(name, BooleanType)
+
+ /** Creates a new StructField of type byte */
+ def byte: StructField = StructField(name, ByteType)
+
+ /** Creates a new StructField of type short */
+ def short: StructField = StructField(name, ShortType)
+
+ /** Creates a new StructField of type int */
+ def int: StructField = StructField(name, IntegerType)
+
+ /** Creates a new StructField of type long */
+ def long: StructField = StructField(name, LongType)
+
+ /** Creates a new StructField of type float */
+ def float: StructField = StructField(name, FloatType)
+
+ /** Creates a new StructField of type double */
+ def double: StructField = StructField(name, DoubleType)
+
+ /** Creates a new StructField of type string */
+ def string: StructField = StructField(name, StringType)
+
+ /** Creates a new StructField of type date */
+ def date: StructField = StructField(name, DateType)
+
+ /** Creates a new StructField of type decimal */
+ def decimal: StructField = StructField(name, DecimalType.Unlimited)
+
+ /** Creates a new StructField of type decimal */
+ def decimal(precision: Int, scale: Int): StructField =
+ StructField(name, DecimalType(precision, scale))
+
+ /** Creates a new StructField of type timestamp */
+ def timestamp: StructField = StructField(name, TimestampType)
+
+ /** Creates a new StructField of type binary */
+ def binary: StructField = StructField(name, BinaryType)
+
+ /** Creates a new StructField of type array */
+ def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType))
+
+ /** Creates a new StructField of type map */
+ def map(keyType: DataType, valueType: DataType): StructField =
+ map(MapType(keyType, valueType))
+
+ def map(mapType: MapType): StructField = StructField(name, mapType)
+
+ /** Creates a new StructField of type struct */
+ def struct(fields: StructField*): StructField = struct(StructType(fields))
+
+ def struct(structType: StructType): StructField = StructField(name, structType)
+}
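A hedged sketch of how the operators above compose, assuming `df` is an existing DataFrame (hypothetical) with the referenced columns:

{{{
import org.apache.spark.sql.types.IntegerType

val discounted = df("price") * 0.9                          // arithmetic lifts 0.9 to a literal
val adults = df.filter(df("age") >= 18 && df("country") === "US")
val casted = df.select(df("price").cast(IntegerType).as("price_int"))

// ColumnName builds typed StructFields, e.g. for LocalRelation in the catalyst tests:
val field = new ColumnName("price").double                  // StructField("price", DoubleType)
}}}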
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
new file mode 100644
index 0000000000..d0bb3640f8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -0,0 +1,596 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+import scala.reflect.ClassTag
+import scala.collection.JavaConversions._
+
+import java.util.{ArrayList, List => JList}
+
+import com.fasterxml.jackson.core.JsonFactory
+import net.razorvine.pickle.Pickler
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.SerDeUtil
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
+import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
+import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.util.Utils
+
+
+/**
+ * A collection of rows that have the same columns.
+ *
+ * A [[DataFrame]] is equivalent to a relational table in Spark SQL, and can be created using
+ * various functions in [[SQLContext]].
+ * {{{
+ * val people = sqlContext.parquetFile("...")
+ * }}}
+ *
+ * Once created, it can be manipulated using the various domain-specific-language (DSL) functions
+ * defined in: [[DataFrame]] (this class), [[Column]], and [[dsl]] for Scala DSL.
+ *
+ * To select a column from the data frame, use the apply method:
+ * {{{
+ * val ageCol = people("age") // in Scala
+ * Column ageCol = people.apply("age") // in Java
+ * }}}
+ *
+ * Note that the [[Column]] type can also be manipulated through its various functions.
+ * {{{
+ * // The following creates a new column that increases everybody's age by 10.
+ * people("age") + 10 // in Scala
+ * }}}
+ *
+ * A more concrete example:
+ * {{{
+ * // To create DataFrame using SQLContext
+ * val people = sqlContext.parquetFile("...")
+ * val department = sqlContext.parquetFile("...")
+ *
+ * people.filter(people("age") > 30)
+ * .join(department, people("deptId") === department("id"))
+ * .groupBy(department("name"), department("gender"))
+ * .agg(avg(people("salary")), max(people("age")))
+ * }}}
+ */
+// TODO: Improve documentation.
+class DataFrame protected[sql](
+ val sqlContext: SQLContext,
+ private val baseLogicalPlan: LogicalPlan,
+ operatorsEnabled: Boolean)
+ extends DataFrameSpecificApi with RDDApi[Row] {
+
+ protected[sql] def this(sqlContext: Option[SQLContext], plan: Option[LogicalPlan]) =
+ this(sqlContext.orNull, plan.orNull, sqlContext.isDefined && plan.isDefined)
+
+ protected[sql] def this(sqlContext: SQLContext, plan: LogicalPlan) = this(sqlContext, plan, true)
+
+ @transient protected[sql] lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan)
+
+ @transient protected[sql] val logicalPlan: LogicalPlan = baseLogicalPlan match {
+ // For various commands (like DDL) and queries with side effects, we force query optimization to
+ // happen right away to let these side effects take place eagerly.
+ case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile =>
+ LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
+ case _ =>
+ baseLogicalPlan
+ }
+
+ /**
+ * An implicit conversion function internal to this class for us to avoid doing
+ * "new DataFrame(...)" everywhere.
+ */
+ private[this] implicit def toDataFrame(logicalPlan: LogicalPlan): DataFrame = {
+ new DataFrame(sqlContext, logicalPlan, true)
+ }
+
+ /** Return the list of numeric columns, useful for doing aggregation. */
+ protected[sql] def numericColumns: Seq[Expression] = {
+ schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
+ logicalPlan.resolve(n.name, sqlContext.analyzer.resolver).get
+ }
+ }
+
+ /** Resolve a column name into a Catalyst [[NamedExpression]]. */
+ protected[sql] def resolve(colName: String): NamedExpression = {
+ logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(
+ throw new RuntimeException(s"""Cannot resolve column name "$colName""""))
+ }
+
+ /** Left here for compatibility reasons. */
+ @deprecated("use toDataFrame", "1.3.0")
+ def toSchemaRDD: DataFrame = this
+
+ /**
+ * Return the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala.
+ */
+ def toDF: DataFrame = this
+
+ /** Return the schema of this [[DataFrame]]. */
+ override def schema: StructType = queryExecution.analyzed.schema
+
+ /** Return all column names and their data types as an array. */
+ override def dtypes: Array[(String, String)] = schema.fields.map { field =>
+ (field.name, field.dataType.toString)
+ }
+
+ /** Return all column names as an array. */
+ override def columns: Array[String] = schema.fields.map(_.name)
+
+ /** Print the schema to the console in a nice tree format. */
+ override def printSchema(): Unit = println(schema.treeString)
+
+ /**
+ * Cartesian join with another [[DataFrame]].
+ *
+ * Note that cartesian joins are very expensive without an extra filter that can be pushed down.
+ *
+ * @param right Right side of the join operation.
+ */
+ override def join(right: DataFrame): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
+ }
+
+ /**
+ * Inner join with another [[DataFrame]], using the given join expression.
+ *
+ * {{{
+ * // The following two are equivalent:
+ * df1.join(df2, $"df1Key" === $"df2Key")
+ * df1.join(df2).where($"df1Key" === $"df2Key")
+ * }}}
+ */
+ override def join(right: DataFrame, joinExprs: Column): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr))
+ }
+
+ /**
+ * Join with another [[DataFrame]], using the given join expression. The following performs
+ * a full outer join between `df1` and `df2`.
+ *
+ * {{{
+ * df1.join(df2, "outer", $"df1Key" === $"df2Key")
+ * }}}
+ *
+ * @param right Right side of the join.
+ * @param joinExprs Join expression.
+ * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `left_semi`.
+ */
+ override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))
+ }
+
+ /**
+ * Return a new [[DataFrame]] sorted by the specified column, in ascending order.
+ * {{{
+ * // The following 3 are equivalent
+ * df.sort("sortcol")
+ * df.sort($"sortcol")
+ * df.sort($"sortcol".asc)
+ * }}}
+ */
+ override def sort(colName: String): DataFrame = {
+ Sort(Seq(SortOrder(apply(colName).expr, Ascending)), global = true, logicalPlan)
+ }
+
+ /**
+ * Return a new [[DataFrame]] sorted by the given expressions. For example:
+ * {{{
+ * df.sort($"col1", $"col2".desc)
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = {
+ val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col =>
+ col.expr match {
+ case expr: SortOrder =>
+ expr
+ case expr: Expression =>
+ SortOrder(expr, Ascending)
+ }
+ }
+ Sort(sortOrder, global = true, logicalPlan)
+ }
+
+ /**
+ * Return a new [[DataFrame]] sorted by the given expressions.
+ * This is an alias of the `sort` function.
+ */
+ @scala.annotation.varargs
+ override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = {
+ sort(sortExpr, sortExprs :_*)
+ }
+
+ /**
+ * Selecting a single column and returning it as a [[Column]].
+ */
+ override def apply(colName: String): Column = {
+ val expr = resolve(colName)
+ new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr)
+ }
+
+ /**
+ * Selecting a set of expressions, wrapped in a Product.
+ * {{{
+ * // The following two are equivalent:
+ * df.apply(($"colA", $"colB" + 1))
+ * df.select($"colA", $"colB" + 1)
+ * }}}
+ */
+ override def apply(projection: Product): DataFrame = {
+ require(projection.productArity >= 1)
+ select(projection.productIterator.map {
+ case c: Column => c
+ case o: Any => new Column(Some(sqlContext), None, LiteralExpr(o))
+ }.toSeq :_*)
+ }
+
+ /**
+ * Alias the current [[DataFrame]].
+ */
+ override def as(name: String): DataFrame = Subquery(name, logicalPlan)
+
+ /**
+ * Selecting a set of expressions.
+ * {{{
+ * df.select($"colA", $"colB" + 1)
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def select(cols: Column*): DataFrame = {
+ val exprs = cols.zipWithIndex.map {
+ case (Column(expr: NamedExpression), _) =>
+ expr
+ case (Column(expr: Expression), _) =>
+ Alias(expr, expr.toString)()
+ }
+ Project(exprs.toSeq, logicalPlan)
+ }
+
+ /**
+ * Selecting a set of columns. This is a variant of `select` that can only select
+ * existing columns using column names (i.e. cannot construct expressions).
+ *
+ * {{{
+ * // The following two are equivalent:
+ * df.select("colA", "colB")
+ * df.select($"colA", $"colB")
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def select(col: String, cols: String*): DataFrame = {
+ select((col +: cols).map(new Column(_)) :_*)
+ }
+
+ /**
+ * Filtering rows using the given condition.
+ * {{{
+ * // The following are equivalent:
+ * peopleDf.filter($"age" > 15)
+ * peopleDf.where($"age" > 15)
+ * peopleDf($"age" > 15)
+ * }}}
+ */
+ override def filter(condition: Column): DataFrame = {
+ Filter(condition.expr, logicalPlan)
+ }
+
+ /**
+ * Filtering rows using the given condition. This is an alias for `filter`.
+ * {{{
+ * // The following are equivalent:
+ * peopleDf.filter($"age" > 15)
+ * peopleDf.where($"age" > 15)
+ * peopleDf($"age" > 15)
+ * }}}
+ */
+ override def where(condition: Column): DataFrame = filter(condition)
+
+ /**
+ * Filtering rows using the given condition. This is a shorthand meant for Scala.
+ * {{{
+ * // The following are equivalent:
+ * peopleDf.filter($"age" > 15)
+ * peopleDf.where($"age" > 15)
+ * peopleDf($"age" > 15)
+ * }}}
+ */
+ override def apply(condition: Column): DataFrame = filter(condition)
+
+ /**
+ * Group the [[DataFrame]] using the specified columns, so we can run aggregation on them.
+ * See [[GroupedDataFrame]] for all the available aggregate functions.
+ *
+ * {{{
+ * // Compute the average for all numeric columns grouped by department.
+ * df.groupBy($"department").avg()
+ *
+ * // Compute the max age and average salary, grouped by department and gender.
+ * df.groupBy($"department", $"gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def groupBy(cols: Column*): GroupedDataFrame = {
+ new GroupedDataFrame(this, cols.map(_.expr))
+ }
+
+ /**
+ * Group the [[DataFrame]] using the specified columns, so we can run aggregation on them.
+ * See [[GroupedDataFrame]] for all the available aggregate functions.
+ *
+ * This is a variant of groupBy that can only group by existing columns using column names
+ * (i.e. cannot construct expressions).
+ *
+ * {{{
+ * // Compute the average for all numeric columns grouped by department.
+ * df.groupBy("department").avg()
+ *
+ * // Compute the max age and average salary, grouped by department and gender.
+ * df.groupBy($"department", $"gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def groupBy(col1: String, cols: String*): GroupedDataFrame = {
+ val colNames: Seq[String] = col1 +: cols
+ new GroupedDataFrame(this, colNames.map(colName => resolve(colName)))
+ }
+
+ /**
+ * Aggregate on the entire [[DataFrame]] without groups.
+ * {{{
+ * // df.agg(...) is a shorthand for df.groupBy().agg(...)
+ * df.agg(Map("age" -> "max", "salary" -> "avg"))
+ * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
+ * }}}
+ */
+ override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
+
+ /**
+ * Aggregate on the entire [[DataFrame]] without groups.
+ * {{{
+ * // df.agg(...) is a shorthand for df.groupBy().agg(...)
+ * df.agg(max($"age"), avg($"salary"))
+ * df.groupBy().agg(max($"age"), avg($"salary"))
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*)
+
+ /**
+ * Return a new [[DataFrame]] by taking the first `n` rows. The difference between this function
+ * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]].
+ */
+ override def limit(n: Int): DataFrame = Limit(LiteralExpr(n), logicalPlan)
+
+ /**
+ * Return a new [[DataFrame]] containing union of rows in this frame and another frame.
+ * This is equivalent to `UNION ALL` in SQL.
+ */
+ override def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan)
+
+ /**
+ * Return a new [[DataFrame]] containing rows only in both this frame and another frame.
+ * This is equivalent to `INTERSECT` in SQL.
+ */
+ override def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan)
+
+ /**
+ * Return a new [[DataFrame]] containing rows in this frame but not in another frame.
+ * This is equivalent to `EXCEPT` in SQL.
+ */
+ override def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan)
+
+ /**
+ * Return a new [[DataFrame]] by sampling a fraction of rows.
+ *
+ * @param withReplacement Sample with replacement or not.
+ * @param fraction Fraction of rows to generate.
+ * @param seed Seed for sampling.
+ */
+ override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
+ Sample(fraction, withReplacement, seed, logicalPlan)
+ }
+
+ /**
+ * Return a new [[DataFrame]] by sampling a fraction of rows, using a random seed.
+ *
+ * @param withReplacement Sample with replacement or not.
+ * @param fraction Fraction of rows to generate.
+ */
+ override def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
+ sample(withReplacement, fraction, Utils.random.nextLong)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Return a new [[DataFrame]] by adding a column.
+ */
+ override def addColumn(colName: String, col: Column): DataFrame = {
+ select(Column("*"), col.as(colName))
+ }
+
+ /**
+ * Return the first `n` rows.
+ */
+ override def head(n: Int): Array[Row] = limit(n).collect()
+
+ /**
+ * Return the first row.
+ */
+ override def head(): Row = head(1).head
+
+ /**
+ * Return the first row. Alias for head().
+ */
+ override def first(): Row = head()
+
+ override def map[R: ClassTag](f: Row => R): RDD[R] = {
+ rdd.map(f)
+ }
+
+ override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
+ rdd.mapPartitions(f)
+ }
+
+ /**
+ * Return the first `n` rows in the [[DataFrame]].
+ */
+ override def take(n: Int): Array[Row] = head(n)
+
+ /**
+ * Return an array that contains all of the [[Row]]s in this [[DataFrame]].
+ */
+ override def collect(): Array[Row] = rdd.collect()
+
+ /**
+ * Return a Java list that contains all of the [[Row]]s in this [[DataFrame]].
+ */
+ override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*)
+
+ /**
+ * Return the number of rows in the [[DataFrame]].
+ */
+ override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
+
+ /**
+ * Return a new [[DataFrame]] that has exactly `numPartitions` partitions.
+ */
+ override def repartition(numPartitions: Int): DataFrame = {
+ sqlContext.applySchema(rdd.repartition(numPartitions), schema)
+ }
+
+ override def persist(): this.type = {
+ sqlContext.cacheQuery(this)
+ this
+ }
+
+ override def persist(newLevel: StorageLevel): this.type = {
+ sqlContext.cacheQuery(this, None, newLevel)
+ this
+ }
+
+ override def unpersist(blocking: Boolean): this.type = {
+ sqlContext.tryUncacheQuery(this, blocking)
+ this
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // I/O
+ /////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Return the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
+ */
+ override def rdd: RDD[Row] = {
+ val schema = this.schema
+ queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema))
+ }
+
+ /**
+ * Registers this RDD as a temporary table using the given name. The lifetime of this temporary
+ * table is tied to the [[SQLContext]] that was used to create this DataFrame.
+ *
+ * @group schema
+ */
+ override def registerTempTable(tableName: String): Unit = {
+ sqlContext.registerRDDAsTable(this, tableName)
+ }
+
+ /**
+ * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema.
+ * Files that are written out using this method can be read back in as a [[DataFrame]]
+ * using the `parquetFile` function in [[SQLContext]].
+ */
+ override def saveAsParquetFile(path: String): Unit = {
+ sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a table from the contents of this DataFrame. This will fail if the table already
+ * exists.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext, as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ */
+ @Experimental
+ override def saveAsTable(tableName: String): Unit = {
+ sqlContext.executePlan(
+ CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd
+ }
+
+ /**
+ * :: Experimental ::
+ * Adds the rows from this RDD to the specified table, optionally overwriting the existing data.
+ */
+ @Experimental
+ override def insertInto(tableName: String, overwrite: Boolean): Unit = {
+ sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
+ Map.empty, logicalPlan, overwrite)).toRdd
+ }
+
+ /**
+ * Return the content of the [[DataFrame]] as an RDD of JSON strings.
+ */
+ override def toJSON: RDD[String] = {
+ val rowSchema = this.schema
+ this.mapPartitions { iter =>
+ val jsonFactory = new JsonFactory()
+ iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory))
+ }
+ }
+
+ ////////////////////////////////////////////////////////////////////////////
+ // for Python API
+ ////////////////////////////////////////////////////////////////////////////
+ /**
+ * A helper function for Py4j: converts a list of [[Column]]s into an array.
+ */
+ protected[sql] def toColumnArray(cols: JList[Column]): Array[Column] = {
+ cols.toList.toArray
+ }
+
+ /**
+ * Converts a JavaRDD to a PythonRDD.
+ */
+ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+ val fieldTypes = schema.fields.map(_.dataType)
+ val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
+ SerDeUtil.javaToPython(jrdd)
+ }
+}
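A short sketch, assuming two hypothetical DataFrames `df1` and `df2`, of how the join, sort and limit operators above combine:

{{{
// Left outer join using the string form of the join type, then order and truncate.
val joined = df1.join(df2, df1("key") === df2("key"), "left_outer")
val top = joined.sort(df2("value").desc).limit(10)
top.collect().foreach(println)
}}}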
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
new file mode 100644
index 0000000000..1f1e9bd989
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+import scala.collection.JavaConversions._
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+
+
+/**
+ * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
+ */
+class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
+ extends GroupedDataFrameApi {
+
+ private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = {
+ val namedGroupingExprs = groupingExprs.map {
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.toString)()
+ }
+ new DataFrame(df.sqlContext,
+ Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
+ }
+
+ private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = {
+ df.numericColumns.map { c =>
+ val a = f(c)
+ Alias(a, a.toString)()
+ }
+ }
+
+ private[this] def strToExpr(expr: String): (Expression => Expression) = {
+ expr.toLowerCase match {
+ case "avg" | "average" | "mean" => Average
+ case "max" => Max
+ case "min" => Min
+ case "sum" => Sum
+ case "count" | "size" => Count
+ }
+ }
+
+ /**
+ * Compute aggregates by specifying a map from column name to aggregate methods.
+ * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * df.groupBy("department").agg(Map(
+ * "age" -> "max"
+ * "sum" -> "expense"
+ * ))
+ * }}}
+ */
+ override def agg(exprs: Map[String, String]): DataFrame = {
+ exprs.map { case (colName, expr) =>
+ val a = strToExpr(expr)(df(colName).expr)
+ Alias(a, a.toString)()
+ }.toSeq
+ }
+
+ /**
+ * Compute aggregates by specifying a map from column name to aggregate methods.
+ * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * df.groupBy("department").agg(Map(
+ * "age" -> "max"
+ * "sum" -> "expense"
+ * ))
+ * }}}
+ */
+ def agg(exprs: java.util.Map[String, String]): DataFrame = {
+ agg(exprs.toMap)
+ }
+
+ /**
+ * Compute aggregates by specifying a series of aggregate columns.
+ * The available aggregate methods are defined in [[org.apache.spark.sql.dsl]].
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * import org.apache.spark.sql.dsl._
+ * df.groupBy("department").agg(max($"age"), sum($"expense"))
+ * }}}
+ */
+ @scala.annotation.varargs
+ override def agg(expr: Column, exprs: Column*): DataFrame = {
+ val aggExprs = (expr +: exprs).map(_.expr).map {
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.toString)()
+ }
+
+ new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+ }
+
+ /** Count the number of rows for each group. */
+ override def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")())
+
+ /**
+ * Compute the average value for each numeric column for each group. This is an alias for `avg`.
+ */
+ override def mean(): DataFrame = aggregateNumericColumns(Average)
+
+ /**
+ * Compute the max value for each numeric column for each group.
+ */
+ override def max(): DataFrame = aggregateNumericColumns(Max)
+
+ /**
+ * Compute the mean value for each numeric column for each group.
+ */
+ override def avg(): DataFrame = aggregateNumericColumns(Average)
+
+ /**
+ * Compute the min value for each numeric column for each group.
+ */
+ override def min(): DataFrame = aggregateNumericColumns(Min)
+
+ /**
+ * Compute the sum for each numeric column for each group.
+ */
+ override def sum(): DataFrame = aggregateNumericColumns(Sum)
+}
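The two `agg` styles above are interchangeable for simple cases; a sketch, assuming a hypothetical DataFrame `df` with these columns and that `max` and `sum` come from `org.apache.spark.sql.dsl`:

{{{
import org.apache.spark.sql.dsl._   // assumed home of the max/sum aggregate functions

// Map form: column name -> aggregate method ("avg", "max", "min", "sum", "count").
df.groupBy("department").agg(Map("age" -> "max", "expense" -> "sum"))

// Expression form, equivalent for these two aggregates.
df.groupBy(df("department")).agg(max(df("age")), sum(df("expense")))

// Convenience methods aggregate every numeric column in each group.
df.groupBy("department").avg()
}}}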
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala b/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala
new file mode 100644
index 0000000000..08cd4d0f3f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Literal.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
+import org.apache.spark.sql.types._
+
+object Literal {
+
+ /** Return a new boolean literal. */
+ def apply(literal: Boolean): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new byte literal. */
+ def apply(literal: Byte): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new short literal. */
+ def apply(literal: Short): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new int literal. */
+ def apply(literal: Int): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new long literal. */
+ def apply(literal: Long): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new float literal. */
+ def apply(literal: Float): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new double literal. */
+ def apply(literal: Double): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new string literal. */
+ def apply(literal: String): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new decimal literal. */
+ def apply(literal: BigDecimal): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new decimal literal. */
+ def apply(literal: java.math.BigDecimal): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new timestamp literal. */
+ def apply(literal: java.sql.Timestamp): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new date literal. */
+ def apply(literal: java.sql.Date): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new binary (byte array) literal. */
+ def apply(literal: Array[Byte]): Column = new Column(LiteralExpr(literal))
+
+ /** Return a new null literal. */
+ def apply(literal: Null): Column = new Column(LiteralExpr(null))
+
+ /**
+ * Return a Column expression representing the literal value. Throws an exception if the
+ * data type is not supported by SparkSQL.
+ */
+ protected[sql] def anyToLiteral(literal: Any): Column = {
+ // If the literal is a symbol, convert it into a Column.
+ if (literal.isInstanceOf[Symbol]) {
+ return dsl.symbolToColumn(literal.asInstanceOf[Symbol])
+ }
+
+ val literalExpr = literal match {
+ case v: Int => LiteralExpr(v, IntegerType)
+ case v: Long => LiteralExpr(v, LongType)
+ case v: Double => LiteralExpr(v, DoubleType)
+ case v: Float => LiteralExpr(v, FloatType)
+ case v: Byte => LiteralExpr(v, ByteType)
+ case v: Short => LiteralExpr(v, ShortType)
+ case v: String => LiteralExpr(v, StringType)
+ case v: Boolean => LiteralExpr(v, BooleanType)
+ case v: BigDecimal => LiteralExpr(Decimal(v), DecimalType.Unlimited)
+ case v: java.math.BigDecimal => LiteralExpr(Decimal(v), DecimalType.Unlimited)
+ case v: Decimal => LiteralExpr(v, DecimalType.Unlimited)
+ case v: java.sql.Timestamp => LiteralExpr(v, TimestampType)
+ case v: java.sql.Date => LiteralExpr(v, DateType)
+ case v: Array[Byte] => LiteralExpr(v, BinaryType)
+ case null => LiteralExpr(null, NullType)
+ case _ =>
+ throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal)
+ }
+ new Column(literalExpr)
+ }
+}
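A small, hedged illustration of the literal factory above; the DataFrame `df` and its `amount` column are assumed for the example:

{{{
import org.apache.spark.sql.{Column, Literal}

val ten: Column = Literal(10)           // int literal
val label: Column = Literal("pending")  // string literal
// Literals combine with other columns through the usual Column operators:
val discounted: Column = df("amount") - ten
}}}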
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 0a22968cc7..5030e689c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -30,7 +30,6 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.dsl.ExpressionConversions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -43,7 +42,7 @@ import org.apache.spark.util.Utils
/**
* :: AlphaComponent ::
- * The entry point for running relational queries using Spark. Allows the creation of [[SchemaRDD]]
+ * The entry point for running relational queries using Spark. Allows the creation of [[DataFrame]]
* objects and the execution of SQL queries.
*
* @groupname userf Spark SQL Functions
@@ -53,7 +52,6 @@ import org.apache.spark.util.Utils
class SQLContext(@transient val sparkContext: SparkContext)
extends org.apache.spark.Logging
with CacheManager
- with ExpressionConversions
with Serializable {
self =>
@@ -111,8 +109,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql))
- protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
- new this.QueryExecution { val logical = plan }
+
+ protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan)
sparkContext.getConf.getAll.foreach {
case (key, value) if key.startsWith("spark.sql") => setConf(key, value)
@@ -124,24 +122,24 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
- implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): SchemaRDD = {
+ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = {
SparkPlan.currentContext.set(self)
val attributeSeq = ScalaReflection.attributesFor[A]
val schema = StructType.fromAttributes(attributeSeq)
val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
- new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self))
+ new DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self))
}
/**
- * Convert a [[BaseRelation]] created for external data sources into a [[SchemaRDD]].
+ * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]].
*/
- def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = {
- new SchemaRDD(this, LogicalRelation(baseRelation))
+ def baseRelationToSchemaRDD(baseRelation: BaseRelation): DataFrame = {
+ new DataFrame(this, LogicalRelation(baseRelation))
}
/**
* :: DeveloperApi ::
- * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
+ * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
* It is important to make sure that the structure of every [[Row]] of the provided RDD matches
* the provided schema. Otherwise, there will be runtime exception.
* Example:
@@ -170,11 +168,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
@DeveloperApi
- def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = {
+ def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
// TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied
// schema differs from the existing schema on any field data type.
val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
- new SchemaRDD(this, logicalPlan)
+ new DataFrame(this, logicalPlan)
}
/**
@@ -183,7 +181,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
- def applySchema(rdd: RDD[_], beanClass: Class[_]): SchemaRDD = {
+ def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
val attributeSeq = getSchema(beanClass)
val className = beanClass.getName
val rowRdd = rdd.mapPartitions { iter =>
@@ -201,7 +199,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
) : Row
}
}
- new SchemaRDD(this, LogicalRDD(attributeSeq, rowRdd)(this))
+ new DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this))
}
/**
@@ -210,35 +208,35 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
- def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): SchemaRDD = {
+ def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
applySchema(rdd.rdd, beanClass)
}
/**
- * Loads a Parquet file, returning the result as a [[SchemaRDD]].
+ * Loads a Parquet file, returning the result as a [[DataFrame]].
*
* @group userf
*/
- def parquetFile(path: String): SchemaRDD =
- new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this))
+ def parquetFile(path: String): DataFrame =
+ new DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this))
/**
- * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]].
+ * Loads a JSON file (one object per line), returning the result as a [[DataFrame]].
* It goes through the entire dataset once to determine the schema.
*
* @group userf
*/
- def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0)
+ def jsonFile(path: String): DataFrame = jsonFile(path, 1.0)
/**
* :: Experimental ::
* Loads a JSON file (one object per line) and applies the given schema,
- * returning the result as a [[SchemaRDD]].
+ * returning the result as a [[DataFrame]].
*
* @group userf
*/
@Experimental
- def jsonFile(path: String, schema: StructType): SchemaRDD = {
+ def jsonFile(path: String, schema: StructType): DataFrame = {
val json = sparkContext.textFile(path)
jsonRDD(json, schema)
}
@@ -247,29 +245,29 @@ class SQLContext(@transient val sparkContext: SparkContext)
* :: Experimental ::
*/
@Experimental
- def jsonFile(path: String, samplingRatio: Double): SchemaRDD = {
+ def jsonFile(path: String, samplingRatio: Double): DataFrame = {
val json = sparkContext.textFile(path)
jsonRDD(json, samplingRatio)
}
/**
* Loads an RDD[String] storing JSON objects (one object per record), returning the result as a
- * [[SchemaRDD]].
+ * [[DataFrame]].
* It goes through the entire dataset once to determine the schema.
*
* @group userf
*/
- def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0)
+ def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0)
/**
* :: Experimental ::
* Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema,
- * returning the result as a [[SchemaRDD]].
+ * returning the result as a [[DataFrame]].
*
* @group userf
*/
@Experimental
- def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = {
+ def jsonRDD(json: RDD[String], schema: StructType): DataFrame = {
val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
val appliedSchema =
Option(schema).getOrElse(
@@ -283,7 +281,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* :: Experimental ::
*/
@Experimental
- def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = {
+ def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = {
val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
val appliedSchema =
JsonRDD.nullTypeToStringType(
@@ -298,8 +296,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
- def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
- catalog.registerTable(Seq(tableName), rdd.queryExecution.logical)
+ def registerRDDAsTable(rdd: DataFrame, tableName: String): Unit = {
+ catalog.registerTable(Seq(tableName), rdd.logicalPlan)
}
/**
@@ -321,17 +319,17 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
- def sql(sqlText: String): SchemaRDD = {
+ def sql(sqlText: String): DataFrame = {
if (conf.dialect == "sql") {
- new SchemaRDD(this, parseSql(sqlText))
+ new DataFrame(this, parseSql(sqlText))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}")
}
}
/** Returns the specified table as a SchemaRDD */
- def table(tableName: String): SchemaRDD =
- new SchemaRDD(this, catalog.lookupRelation(Seq(tableName)))
+ def table(tableName: String): DataFrame =
+ new DataFrame(this, catalog.lookupRelation(Seq(tableName)))
/**
* A collection of methods that are considered experimental, but can be used to hook into
@@ -454,15 +452,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
* access to the intermediate phases of query execution for developers.
*/
@DeveloperApi
- protected abstract class QueryExecution {
- def logical: LogicalPlan
+ protected class QueryExecution(val logical: LogicalPlan) {
- lazy val analyzed = ExtractPythonUdfs(analyzer(logical))
- lazy val withCachedData = useCachedData(analyzed)
- lazy val optimizedPlan = optimizer(withCachedData)
+ lazy val analyzed: LogicalPlan = ExtractPythonUdfs(analyzer(logical))
+ lazy val withCachedData: LogicalPlan = useCachedData(analyzed)
+ lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData)
// TODO: Don't just pick the first one...
- lazy val sparkPlan = {
+ lazy val sparkPlan: SparkPlan = {
SparkPlan.currentContext.set(self)
planner(optimizedPlan).next()
}
@@ -512,7 +509,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
protected[sql] def applySchemaToPythonRDD(
rdd: RDD[Array[Any]],
- schemaString: String): SchemaRDD = {
+ schemaString: String): DataFrame = {
val schema = parseDataType(schemaString).asInstanceOf[StructType]
applySchemaToPythonRDD(rdd, schema)
}
@@ -522,7 +519,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
protected[sql] def applySchemaToPythonRDD(
rdd: RDD[Array[Any]],
- schema: StructType): SchemaRDD = {
+ schema: StructType): DataFrame = {
def needsConversion(dataType: DataType): Boolean = dataType match {
case ByteType => true
@@ -549,7 +546,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
iter.map { m => new GenericRow(m): Row}
}
- new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
+ new DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
}
/**
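Taken together, the signature changes above mean the familiar SQLContext entry points now hand back DataFrames. A hedged end-to-end sketch (the case class, the pre-existing SparkContext `sc`, and the table name are invented for illustration):

{{{
import org.apache.spark.sql.{DataFrame, SQLContext}

case class Record(key: Int, value: String)

val sqlContext = new SQLContext(sc)
import sqlContext.createSchemaRDD   // implicit RDD[A <: Product] => DataFrame

// The implicit conversion now yields a DataFrame rather than a SchemaRDD.
val df: DataFrame = sc.parallelize(1 to 100).map(i => Record(i, s"val_$i"))
sqlContext.registerRDDAsTable(df, "records")
val top: DataFrame = sqlContext.sql("SELECT key, value FROM records WHERE key > 90")
}}}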
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
deleted file mode 100644
index d1e21dffeb..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ /dev/null
@@ -1,511 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-import java.util.{List => JList}
-
-import scala.collection.JavaConversions._
-
-import com.fasterxml.jackson.core.JsonFactory
-
-import net.razorvine.pickle.Pickler
-
-import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext}
-import org.apache.spark.annotation.{AlphaComponent, Experimental}
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.api.python.SerDeUtil
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
-import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.types.{BooleanType, StructType}
-import org.apache.spark.storage.StorageLevel
-
-/**
- * :: AlphaComponent ::
- * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions,
- * SchemaRDDs can be used in relational queries, as shown in the examples below.
- *
- * Importing a SQLContext brings an implicit into scope that automatically converts a standard RDD
- * whose elements are scala case classes into a SchemaRDD. This conversion can also be done
- * explicitly using the `createSchemaRDD` function on a [[SQLContext]].
- *
- * A `SchemaRDD` can also be created by loading data in from external sources.
- * Examples are loading data from Parquet files by using the `parquetFile` method on [[SQLContext]]
- * and loading JSON datasets by using `jsonFile` and `jsonRDD` methods on [[SQLContext]].
- *
- * == SQL Queries ==
- * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once
- * an RDD has been registered as a table, it can be used in the FROM clause of SQL statements.
- *
- * {{{
- * // One method for defining the schema of an RDD is to make a case class with the desired column
- * // names and types.
- * case class Record(key: Int, value: String)
- *
- * val sc: SparkContext // An existing spark context.
- * val sqlContext = new SQLContext(sc)
- *
- * // Importing the SQL context gives access to all the SQL functions and implicit conversions.
- * import sqlContext._
- *
- * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
- * // Any RDD containing case classes can be registered as a table. The schema of the table is
- * // automatically inferred using scala reflection.
- * rdd.registerTempTable("records")
- *
- * val results: SchemaRDD = sql("SELECT * FROM records")
- * }}}
- *
- * == Language Integrated Queries ==
- *
- * {{{
- *
- * case class Record(key: Int, value: String)
- *
- * val sc: SparkContext // An existing spark context.
- * val sqlContext = new SQLContext(sc)
- *
- * // Importing the SQL context gives access to all the SQL functions and implicit conversions.
- * import sqlContext._
- *
- * val rdd = sc.parallelize((1 to 100).map(i => Record(i, "val_" + i)))
- *
- * // Example of language integrated queries.
- * rdd.where('key === 1).orderBy('value.asc).select('key).collect()
- * }}}
- *
- * @groupname Query Language Integrated Queries
- * @groupdesc Query Functions that create new queries from SchemaRDDs. The
- * result of all query functions is also a SchemaRDD, allowing multiple operations to be
- * chained using a builder pattern.
- * @groupprio Query -2
- * @groupname schema SchemaRDD Functions
- * @groupprio schema -1
- * @groupname Ungrouped Base RDD Functions
- */
-@AlphaComponent
-class SchemaRDD(
- @transient val sqlContext: SQLContext,
- @transient val baseLogicalPlan: LogicalPlan)
- extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike {
-
- def baseSchemaRDD = this
-
- // =========================================================================================
- // RDD functions: Copy the internal row representation so we present immutable data to users.
- // =========================================================================================
-
- override def compute(split: Partition, context: TaskContext): Iterator[Row] =
- firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema))
-
- override def getPartitions: Array[Partition] = firstParent[Row].partitions
-
- override protected def getDependencies: Seq[Dependency[_]] = {
- schema // Force reification of the schema so it is available on executors.
-
- List(new OneToOneDependency(queryExecution.toRdd))
- }
-
- /**
- * Returns the schema of this SchemaRDD (represented by a [[StructType]]).
- *
- * @group schema
- */
- lazy val schema: StructType = queryExecution.analyzed.schema
-
- /**
- * Returns a new RDD with each row transformed to a JSON string.
- *
- * @group schema
- */
- def toJSON: RDD[String] = {
- val rowSchema = this.schema
- this.mapPartitions { iter =>
- val jsonFactory = new JsonFactory()
- iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory))
- }
- }
-
-
- // =======================================================================
- // Query DSL
- // =======================================================================
-
- /**
- * Changes the output of this relation to the given expressions, similar to the `SELECT` clause
- * in SQL.
- *
- * {{{
- * schemaRDD.select('a, 'b + 'c, 'd as 'aliasedName)
- * }}}
- *
- * @param exprs a set of logical expression that will be evaluated for each input row.
- *
- * @group Query
- */
- def select(exprs: Expression*): SchemaRDD = {
- val aliases = exprs.zipWithIndex.map {
- case (ne: NamedExpression, _) => ne
- case (e, i) => Alias(e, s"c$i")()
- }
- new SchemaRDD(sqlContext, Project(aliases, logicalPlan))
- }
-
- /**
- * Filters the output, only returning those rows where `condition` evaluates to true.
- *
- * {{{
- * schemaRDD.where('a === 'b)
- * schemaRDD.where('a === 1)
- * schemaRDD.where('a + 'b > 10)
- * }}}
- *
- * @group Query
- */
- def where(condition: Expression): SchemaRDD =
- new SchemaRDD(sqlContext, Filter(condition, logicalPlan))
-
- /**
- * Performs a relational join on two SchemaRDDs
- *
- * @param otherPlan the [[SchemaRDD]] that should be joined with this one.
- * @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner.`
- * @param on An optional condition for the join operation. This is equivalent to the `ON`
- * clause in standard SQL. In the case of `Inner` joins, specifying a
- * `condition` is equivalent to adding `where` clauses after the `join`.
- *
- * @group Query
- */
- def join(
- otherPlan: SchemaRDD,
- joinType: JoinType = Inner,
- on: Option[Expression] = None): SchemaRDD =
- new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, on))
-
- /**
- * Sorts the results by the given expressions.
- * {{{
- * schemaRDD.orderBy('a)
- * schemaRDD.orderBy('a, 'b)
- * schemaRDD.orderBy('a.asc, 'b.desc)
- * }}}
- *
- * @group Query
- */
- def orderBy(sortExprs: SortOrder*): SchemaRDD =
- new SchemaRDD(sqlContext, Sort(sortExprs, true, logicalPlan))
-
- /**
- * Sorts the results by the given expressions within partition.
- * {{{
- * schemaRDD.sortBy('a)
- * schemaRDD.sortBy('a, 'b)
- * schemaRDD.sortBy('a.asc, 'b.desc)
- * }}}
- *
- * @group Query
- */
- def sortBy(sortExprs: SortOrder*): SchemaRDD =
- new SchemaRDD(sqlContext, Sort(sortExprs, false, logicalPlan))
-
- @deprecated("use limit with integer argument", "1.1.0")
- def limit(limitExpr: Expression): SchemaRDD =
- new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
-
- /**
- * Limits the results by the given integer.
- * {{{
- * schemaRDD.limit(10)
- * }}}
- * @group Query
- */
- def limit(limitNum: Int): SchemaRDD =
- new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan))
-
- /**
- * Performs a grouping followed by an aggregation.
- *
- * {{{
- * schemaRDD.groupBy('year)(Sum('sales) as 'totalSales)
- * }}}
- *
- * @group Query
- */
- def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): SchemaRDD = {
- val aliasedExprs = aggregateExprs.map {
- case ne: NamedExpression => ne
- case e => Alias(e, e.toString)()
- }
- new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan))
- }
-
- /**
- * Performs an aggregation over all Rows in this RDD.
- * This is equivalent to a groupBy with no grouping expressions.
- *
- * {{{
- * schemaRDD.aggregate(Sum('sales) as 'totalSales)
- * }}}
- *
- * @group Query
- */
- def aggregate(aggregateExprs: Expression*): SchemaRDD = {
- groupBy()(aggregateExprs: _*)
- }
-
- /**
- * Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes
- * with the same name, for example, when performing self-joins.
- *
- * {{{
- * val x = schemaRDD.where('a === 1).as('x)
- * val y = schemaRDD.where('a === 2).as('y)
- * x.join(y).where("x.a".attr === "y.a".attr),
- * }}}
- *
- * @group Query
- */
- def as(alias: Symbol) =
- new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan))
-
- /**
- * Combines the tuples of two RDDs with the same schema, keeping duplicates.
- *
- * @group Query
- */
- def unionAll(otherPlan: SchemaRDD) =
- new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan))
-
- /**
- * Performs a relational except on two SchemaRDDs
- *
- * @param otherPlan the [[SchemaRDD]] that should be excepted from this one.
- *
- * @group Query
- */
- def except(otherPlan: SchemaRDD): SchemaRDD =
- new SchemaRDD(sqlContext, Except(logicalPlan, otherPlan.logicalPlan))
-
- /**
- * Performs a relational intersect on two SchemaRDDs
- *
- * @param otherPlan the [[SchemaRDD]] that should be intersected with this one.
- *
- * @group Query
- */
- def intersect(otherPlan: SchemaRDD): SchemaRDD =
- new SchemaRDD(sqlContext, Intersect(logicalPlan, otherPlan.logicalPlan))
-
- /**
- * Filters tuples using a function over the value of the specified column.
- *
- * {{{
- * schemaRDD.where('a)((a: Int) => ...)
- * }}}
- *
- * @group Query
- */
- def where[T1](arg1: Symbol)(udf: (T1) => Boolean) =
- new SchemaRDD(
- sqlContext,
- Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan))
-
- /**
- * :: Experimental ::
- * Returns a sampled version of the underlying dataset.
- *
- * @group Query
- */
- @Experimental
- override
- def sample(
- withReplacement: Boolean = true,
- fraction: Double,
- seed: Long) =
- new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan))
-
- /**
- * :: Experimental ::
- * Return the number of elements in the RDD. Unlike the base RDD implementation of count, this
- * implementation leverages the query optimizer to compute the count on the SchemaRDD, which
- * supports features such as filter pushdown.
- *
- * @group Query
- */
- @Experimental
- override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0)
-
- /**
- * :: Experimental ::
- * Applies the given Generator, or table generating function, to this relation.
- *
- * @param generator A table generating function. The API for such functions is likely to change
- * in future releases
- * @param join when set to true, each output row of the generator is joined with the input row
- * that produced it.
- * @param outer when set to true, at least one row will be produced for each input row, similar to
- * an `OUTER JOIN` in SQL. When no output rows are produced by the generator for a
- * given row, a single row will be output, with `NULL` values for each of the
- * generated columns.
- * @param alias an optional alias that can be used as qualifier for the attributes that are
- * produced by this generate operation.
- *
- * @group Query
- */
- @Experimental
- def generate(
- generator: Generator,
- join: Boolean = false,
- outer: Boolean = false,
- alias: Option[String] = None) =
- new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan))
-
- /**
- * Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit
- * conversion from a standard RDD to a SchemaRDD.
- *
- * @group schema
- */
- def toSchemaRDD = this
-
- /**
- * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
- */
- private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
- val fieldTypes = schema.fields.map(_.dataType)
- val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
- SerDeUtil.javaToPython(jrdd)
- }
-
- /**
- * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same
- * format as javaToPython. It is used by pyspark.
- */
- private[sql] def collectToPython: JList[Array[Byte]] = {
- val fieldTypes = schema.fields.map(_.dataType)
- val pickle = new Pickler
- new java.util.ArrayList(collect().map { row =>
- EvaluatePython.rowToArray(row, fieldTypes)
- }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
- }
-
- /**
- * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same
- * format as javaToPython and collectToPython. It is used by pyspark.
- */
- private[sql] def takeSampleToPython(
- withReplacement: Boolean,
- num: Int,
- seed: Long): JList[Array[Byte]] = {
- val fieldTypes = schema.fields.map(_.dataType)
- val pickle = new Pickler
- new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row =>
- EvaluatePython.rowToArray(row, fieldTypes)
- }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
- }
-
- /**
- * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
- * of base RDD functions that do not change schema.
- *
- * @param rdd RDD derived from this one and has same schema
- *
- * @group schema
- */
- private def applySchema(rdd: RDD[Row]): SchemaRDD = {
- new SchemaRDD(sqlContext,
- LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext))
- }
-
- // =======================================================================
- // Overridden RDD actions
- // =======================================================================
-
- override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
-
- def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(collect() : _*)
-
- override def take(num: Int): Array[Row] = limit(num).collect()
-
- // =======================================================================
- // Base RDD functions that do NOT change schema
- // =======================================================================
-
- // Transformations (return a new RDD)
-
- override def coalesce(numPartitions: Int, shuffle: Boolean = false)
- (implicit ord: Ordering[Row] = null): SchemaRDD =
- applySchema(super.coalesce(numPartitions, shuffle)(ord))
-
- override def distinct(): SchemaRDD = applySchema(super.distinct())
-
- override def distinct(numPartitions: Int)
- (implicit ord: Ordering[Row] = null): SchemaRDD =
- applySchema(super.distinct(numPartitions)(ord))
-
- def distinct(numPartitions: Int): SchemaRDD =
- applySchema(super.distinct(numPartitions)(null))
-
- override def filter(f: Row => Boolean): SchemaRDD =
- applySchema(super.filter(f))
-
- override def intersection(other: RDD[Row]): SchemaRDD =
- applySchema(super.intersection(other))
-
- override def intersection(other: RDD[Row], partitioner: Partitioner)
- (implicit ord: Ordering[Row] = null): SchemaRDD =
- applySchema(super.intersection(other, partitioner)(ord))
-
- override def intersection(other: RDD[Row], numPartitions: Int): SchemaRDD =
- applySchema(super.intersection(other, numPartitions))
-
- override def repartition(numPartitions: Int)
- (implicit ord: Ordering[Row] = null): SchemaRDD =
- applySchema(super.repartition(numPartitions)(ord))
-
- override def subtract(other: RDD[Row]): SchemaRDD =
- applySchema(super.subtract(other))
-
- override def subtract(other: RDD[Row], numPartitions: Int): SchemaRDD =
- applySchema(super.subtract(other, numPartitions))
-
- override def subtract(other: RDD[Row], p: Partitioner)
- (implicit ord: Ordering[Row] = null): SchemaRDD =
- applySchema(super.subtract(other, p)(ord))
-
- /** Overridden cache function will always use the in-memory columnar caching. */
- override def cache(): this.type = {
- sqlContext.cacheQuery(this)
- this
- }
-
- override def persist(newLevel: StorageLevel): this.type = {
- sqlContext.cacheQuery(this, None, newLevel)
- this
- }
-
- override def unpersist(blocking: Boolean): this.type = {
- sqlContext.tryUncacheQuery(this, blocking)
- this
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
deleted file mode 100644
index 3cf9209465..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.LogicalRDD
-
-/**
- * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java)
- */
-private[sql] trait SchemaRDDLike {
- @transient def sqlContext: SQLContext
- @transient val baseLogicalPlan: LogicalPlan
-
- private[sql] def baseSchemaRDD: SchemaRDD
-
- /**
- * :: DeveloperApi ::
- * A lazily computed query execution workflow. All other RDD operations are passed
- * through to the RDD that is produced by this workflow. This workflow is produced lazily because
- * invoking the whole query optimization pipeline can be expensive.
- *
- * The query execution is considered a Developer API as phases may be added or removed in future
- * releases. This execution is only exposed to provide an interface for inspecting the various
- * phases for debugging purposes. Applications should not depend on particular phases existing
- * or producing any specific output, even for exactly the same query.
- *
- * Additionally, the RDD exposed by this execution is not designed for consumption by end users.
- * In particular, it does not contain any schema information, and it reuses Row objects
- * internally. This object reuse improves performance, but can make programming against the RDD
- * more difficult. Instead end users should perform RDD operations on a SchemaRDD directly.
- */
- @transient
- @DeveloperApi
- lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan)
-
- @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match {
- // For various commands (like DDL) and queries with side effects, we force query optimization to
- // happen right away to let these side effects take place eagerly.
- case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile =>
- LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
- case _ =>
- baseLogicalPlan
- }
-
- override def toString =
- s"""${super.toString}
- |== Query Plan ==
- |${queryExecution.simpleString}""".stripMargin.trim
-
- /**
- * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that
- * are written out using this method can be read back in as a SchemaRDD using the `parquetFile`
- * function.
- *
- * @group schema
- */
- def saveAsParquetFile(path: String): Unit = {
- sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
- }
-
- /**
- * Registers this RDD as a temporary table using the given name. The lifetime of this temporary
- * table is tied to the [[SQLContext]] that was used to create this SchemaRDD.
- *
- * @group schema
- */
- def registerTempTable(tableName: String): Unit = {
- sqlContext.registerRDDAsTable(baseSchemaRDD, tableName)
- }
-
- @deprecated("Use registerTempTable instead of registerAsTable.", "1.1")
- def registerAsTable(tableName: String): Unit = registerTempTable(tableName)
-
- /**
- * :: Experimental ::
- * Adds the rows from this RDD to the specified table, optionally overwriting the existing data.
- *
- * @group schema
- */
- @Experimental
- def insertInto(tableName: String, overwrite: Boolean): Unit =
- sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
- Map.empty, logicalPlan, overwrite)).toRdd
-
- /**
- * :: Experimental ::
- * Appends the rows from this RDD to the specified table.
- *
- * @group schema
- */
- @Experimental
- def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false)
-
- /**
- * :: Experimental ::
- * Creates a table from the the contents of this SchemaRDD. This will fail if the table already
- * exists.
- *
- * Note that this currently only works with SchemaRDDs that are created from a HiveContext as
- * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
- * an RDD out to a parquet file, and then register that file as a table. This "table" can then
- * be the target of an `insertInto`.
- *
- * @group schema
- */
- @Experimental
- def saveAsTable(tableName: String): Unit =
- sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan, false)).toRdd
-
- /** Returns the schema as a string in the tree format.
- *
- * @group schema
- */
- def schemaString: String = baseSchemaRDD.schema.treeString
-
- /** Prints out the schema.
- *
- * @group schema
- */
- def printSchema(): Unit = println(schemaString)
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
new file mode 100644
index 0000000000..073d41e938
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -0,0 +1,289 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * An internal interface defining the RDD-like methods for [[DataFrame]].
+ * Please use [[DataFrame]] directly, and do NOT use this.
+ */
+trait RDDApi[T] {
+
+ def cache(): this.type = persist()
+
+ def persist(): this.type
+
+ def persist(newLevel: StorageLevel): this.type
+
+ def unpersist(): this.type = unpersist(blocking = false)
+
+ def unpersist(blocking: Boolean): this.type
+
+ def map[R: ClassTag](f: T => R): RDD[R]
+
+ def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R]
+
+ def take(n: Int): Array[T]
+
+ def collect(): Array[T]
+
+ def collectAsList(): java.util.List[T]
+
+ def count(): Long
+
+ def first(): T
+
+ def repartition(numPartitions: Int): DataFrame
+}
+
+
+/**
+ * An internal interface defining data frame related methods in [[DataFrame]].
+ * Please use [[DataFrame]] directly, and do NOT use this.
+ */
+trait DataFrameSpecificApi {
+
+ def schema: StructType
+
+ def printSchema(): Unit
+
+ def dtypes: Array[(String, String)]
+
+ def columns: Array[String]
+
+ def head(): Row
+
+ def head(n: Int): Array[Row]
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Relational operators
+ /////////////////////////////////////////////////////////////////////////////
+ def apply(colName: String): Column
+
+ def apply(projection: Product): DataFrame
+
+ @scala.annotation.varargs
+ def select(cols: Column*): DataFrame
+
+ @scala.annotation.varargs
+ def select(col: String, cols: String*): DataFrame
+
+ def apply(condition: Column): DataFrame
+
+ def as(name: String): DataFrame
+
+ def filter(condition: Column): DataFrame
+
+ def where(condition: Column): DataFrame
+
+ @scala.annotation.varargs
+ def groupBy(cols: Column*): GroupedDataFrame
+
+ @scala.annotation.varargs
+ def groupBy(col1: String, cols: String*): GroupedDataFrame
+
+ def agg(exprs: Map[String, String]): DataFrame
+
+ @scala.annotation.varargs
+ def agg(expr: Column, exprs: Column*): DataFrame
+
+ def sort(colName: String): DataFrame
+
+ @scala.annotation.varargs
+ def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame
+
+ @scala.annotation.varargs
+ def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+
+ def join(right: DataFrame): DataFrame
+
+ def join(right: DataFrame, joinExprs: Column): DataFrame
+
+ def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame
+
+ def limit(n: Int): DataFrame
+
+ def unionAll(other: DataFrame): DataFrame
+
+ def intersect(other: DataFrame): DataFrame
+
+ def except(other: DataFrame): DataFrame
+
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame
+
+ def sample(withReplacement: Boolean, fraction: Double): DataFrame
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Column mutation
+ /////////////////////////////////////////////////////////////////////////////
+ def addColumn(colName: String, col: Column): DataFrame
+
+ /////////////////////////////////////////////////////////////////////////////
+ // I/O and interaction with other frameworks
+ /////////////////////////////////////////////////////////////////////////////
+
+ def rdd: RDD[Row]
+
+ def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD()
+
+ def toJSON: RDD[String]
+
+ def registerTempTable(tableName: String): Unit
+
+ def saveAsParquetFile(path: String): Unit
+
+ @Experimental
+ def saveAsTable(tableName: String): Unit
+
+ @Experimental
+ def insertInto(tableName: String, overwrite: Boolean): Unit
+
+ @Experimental
+ def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false)
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Stat functions
+ /////////////////////////////////////////////////////////////////////////////
+// def describe(): Unit
+//
+// def mean(): Unit
+//
+// def max(): Unit
+//
+// def min(): Unit
+}
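As a rough sketch of how the relational operators declared in this trait compose (the DataFrames `emp` and `dept` and their column names are invented for the example):

{{{
// Filter, project, join and limit, each returning a new DataFrame.
val senior  = emp.filter(emp("age") > 40)
val slim    = senior.select(emp("name"), emp("deptId"))
val joined  = slim.join(dept, slim("deptId") === dept("id"))
val preview = joined.limit(20)
}}}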
+
+
+/**
+ * An internal interface defining expression APIs for [[DataFrame]].
+ * Please use [[DataFrame]] and [[Column]] directly, and do NOT use this.
+ */
+trait ExpressionApi {
+
+ def isComputable: Boolean
+
+ def unary_- : Column
+ def unary_! : Column
+ def unary_~ : Column
+
+ def + (other: Column): Column
+ def + (other: Any): Column
+ def - (other: Column): Column
+ def - (other: Any): Column
+ def * (other: Column): Column
+ def * (other: Any): Column
+ def / (other: Column): Column
+ def / (other: Any): Column
+ def % (other: Column): Column
+ def % (other: Any): Column
+ def & (other: Column): Column
+ def & (other: Any): Column
+ def | (other: Column): Column
+ def | (other: Any): Column
+ def ^ (other: Column): Column
+ def ^ (other: Any): Column
+
+ def && (other: Column): Column
+ def && (other: Boolean): Column
+ def || (other: Column): Column
+ def || (other: Boolean): Column
+
+ def < (other: Column): Column
+ def < (other: Any): Column
+ def <= (other: Column): Column
+ def <= (other: Any): Column
+ def > (other: Column): Column
+ def > (other: Any): Column
+ def >= (other: Column): Column
+ def >= (other: Any): Column
+ def === (other: Column): Column
+ def === (other: Any): Column
+ def equalTo(other: Column): Column
+ def equalTo(other: Any): Column
+ def <=> (other: Column): Column
+ def <=> (other: Any): Column
+ def !== (other: Column): Column
+ def !== (other: Any): Column
+
+ @scala.annotation.varargs
+ def in(list: Column*): Column
+
+ def like(other: Column): Column
+ def like(other: String): Column
+ def rlike(other: Column): Column
+ def rlike(other: String): Column
+
+ def contains(other: Column): Column
+ def contains(other: Any): Column
+ def startsWith(other: Column): Column
+ def startsWith(other: String): Column
+ def endsWith(other: Column): Column
+ def endsWith(other: String): Column
+
+ def substr(startPos: Column, len: Column): Column
+ def substr(startPos: Int, len: Int): Column
+
+ def isNull: Column
+ def isNotNull: Column
+
+ def getItem(ordinal: Column): Column
+ def getItem(ordinal: Int): Column
+ def getField(fieldName: String): Column
+
+ def cast(to: DataType): Column
+
+ def asc: Column
+ def desc: Column
+
+ def as(alias: String): Column
+}
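A short, assumed example combining a few of the operators above on Column values (the DataFrame `people` and its columns are hypothetical; LongType comes from org.apache.spark.sql.types):

{{{
import org.apache.spark.sql.types.LongType

val isAdult = people("age") >= 18 && people("country") === "US"
val initial = people("name").substr(1, 1).as("initial")
val ageLong = people("age").cast(LongType).as("age")
people.select(initial, ageLong).where(isAdult).orderBy(people("age").desc)
}}}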
+
+
+/**
+ * An internal interface defining aggregation APIs for [[DataFrame]].
+ * Please use [[DataFrame]] and [[GroupedDataFrame]] directly, and do NOT use this.
+ */
+trait GroupedDataFrameApi {
+
+ def agg(exprs: Map[String, String]): DataFrame
+
+ @scala.annotation.varargs
+ def agg(expr: Column, exprs: Column*): DataFrame
+
+ def avg(): DataFrame
+
+ def mean(): DataFrame
+
+ def min(): DataFrame
+
+ def max(): DataFrame
+
+ def sum(): DataFrame
+
+ def count(): DataFrame
+
+ // TODO: Add var, std
+}
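Both agg variants declared above can be sketched as follows. This is an assumption-laden example: `df` and its columns are invented, and the aggregate names in the Map form ("sum", "max") are how the GroupedDataFrame implementation elsewhere in this patch is expected to interpret them.

{{{
// Map form: column name -> aggregate name.
df.groupBy(df("year")).agg(Map("sales" -> "sum", "price" -> "max"))

// Column-expression form, using the helpers from org.apache.spark.sql.dsl.
import org.apache.spark.sql.dsl.{max, sum}
df.groupBy(df("year")).agg(sum(df("sales")).as("totalSales"), max(df("price")))
}}}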
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala
new file mode 100644
index 0000000000..29c3d26ae5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala
@@ -0,0 +1,495 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.sql.{Timestamp, Date}
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.{TypeTag, typeTag}
+
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.DataType
+
+
+package object dsl {
+
+ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+
+ /** Converts $"col name" into a [[Column]]. */
+ implicit class StringToColumn(val sc: StringContext) extends AnyVal {
+ def $(args: Any*): ColumnName = {
+ new ColumnName(sc.s(args :_*))
+ }
+ }
+
+ private[this] implicit def toColumn(expr: Expression): Column = new Column(expr)
+
+ def sum(e: Column): Column = Sum(e.expr)
+ def sumDistinct(e: Column): Column = SumDistinct(e.expr)
+ def count(e: Column): Column = Count(e.expr)
+
+ @scala.annotation.varargs
+ def countDistinct(expr: Column, exprs: Column*): Column =
+ CountDistinct((expr +: exprs).map(_.expr))
+
+ def avg(e: Column): Column = Average(e.expr)
+ def first(e: Column): Column = First(e.expr)
+ def last(e: Column): Column = Last(e.expr)
+ def min(e: Column): Column = Min(e.expr)
+ def max(e: Column): Column = Max(e.expr)
+ def upper(e: Column): Column = Upper(e.expr)
+ def lower(e: Column): Column = Lower(e.expr)
+ def sqrt(e: Column): Column = Sqrt(e.expr)
+ def abs(e: Column): Column = Abs(e.expr)
+
+ // scalastyle:off
+
+ object literals {
+
+ implicit def booleanToLiteral(b: Boolean): Column = Literal(b)
+
+ implicit def byteToLiteral(b: Byte): Column = Literal(b)
+
+ implicit def shortToLiteral(s: Short): Column = Literal(s)
+
+ implicit def intToLiteral(i: Int): Column = Literal(i)
+
+ implicit def longToLiteral(l: Long): Column = Literal(l)
+
+ implicit def floatToLiteral(f: Float): Column = Literal(f)
+
+ implicit def doubleToLiteral(d: Double): Column = Literal(d)
+
+ implicit def stringToLiteral(s: String): Column = Literal(s)
+
+ implicit def dateToLiteral(d: Date): Column = Literal(d)
+
+ implicit def bigDecimalToLiteral(d: BigDecimal): Column = Literal(d.underlying())
+
+ implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Column = Literal(d)
+
+ implicit def timestampToLiteral(t: Timestamp): Column = Literal(t)
+
+ implicit def binaryToLiteral(a: Array[Byte]): Column = Literal(a)
+ }
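To show how the pieces above fit together, a hedged sketch (the DataFrame `df` and its columns are invented): symbols and $-interpolated strings become Columns through the implicits above, while ordinary Scala values can appear on the right-hand side of Column operators.

{{{
import org.apache.spark.sql.dsl._

df.where('price > 100)                // Symbol -> ColumnName via symbolToColumn
df.select($"name", upper($"name"))    // $"..." via StringToColumn, plus the upper() helper
df.groupBy($"dept").agg(avg($"price"), countDistinct($"name"))
}}}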
+
+
+ /* Use the following code to generate:
+ (0 to 22).map { x =>
+ val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
+ val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
+ val args = (1 to x).map(i => s"arg$i: Column").mkString(", ")
+ val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ")
+ println(s"""
+ /**
+ * Call a Scala function of ${x} arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[$typeTags](f: Function$x[$types]${if (args.length > 0) ", " + args else ""}): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq($argsInUdf))
+ }""")
+ }
+
+ (0 to 22).map { x =>
+ val args = (1 to x).map(i => s"arg$i: Column").mkString(", ")
+ val fTypes = Seq.fill(x + 1)("_").mkString(", ")
+ val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ")
+ println(s"""
+ /**
+ * Call a Scala function of ${x} arguments as a user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = {
+ ScalaUdf(f, returnType, Seq($argsInUdf))
+ }""")
+ }
+ }
+ */
+ /**
+ * Call a Scala function of 0 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag](f: Function0[RT]): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq())
+ }
+
+ /**
+ * Call a Scala function of 1 argument as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT], arg1: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr))
+ }
+
+ /**
+ * Call a Scala function of 2 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT], arg1: Column, arg2: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr))
+ }
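A brief, assumed usage of the one- and two-argument forms above; `df` and its columns `a` and `b` are hypothetical, and the return types are inferred from each function's signature:

{{{
import org.apache.spark.sql.dsl._

val squared = callUDF((x: Int) => x * x, df("a"))
val tagged  = callUDF((s: String, n: Int) => s + "#" + n, df("b"), df("a"))
df.select(squared.as("a_squared"), tagged.as("tag"))
}}}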
+
+ /**
+ * Call a Scala function of 3 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT], arg1: Column, arg2: Column, arg3: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr))
+ }
+
+ /**
+ * Call a Scala function of 4 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
+ }
+
+ /**
+ * Call a Scala function of 5 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
+ }
+
+ /**
+ * Call a Scala function of 6 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
+ }
+
+ /**
+ * Call a Scala function of 7 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
+ }
+
+ /**
+ * Call a Scala function of 8 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
+ }
+
+ /**
+ * Call a Scala function of 9 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
+ }
+
+ /**
+ * Call a Scala function of 10 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
+ }
+
+ /**
+ * Call a Scala function of 11 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](f: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr))
+ }
+
+ /**
+ * Call a Scala function of 12 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](f: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr))
+ }
+
+ /**
+ * Call a Scala function of 13 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](f: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr))
+ }
+
+ /**
+ * Call a Scala function of 14 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](f: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr))
+ }
+
+ /**
+ * Call a Scala function of 15 arguments as a user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](f: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr))
+ }
+
+ /**
+ * Call a Scala function of 16 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](f: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr))
+ }
+
+ /**
+ * Call a Scala function of 17 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](f: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr))
+ }
+
+ /**
+ * Call a Scala function of 18 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](f: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr))
+ }
+
+ /**
+ * Call a Scala function of 19 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](f: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr))
+ }
+
+ /**
+ * Call a Scala function of 20 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](f: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr))
+ }
+
+ /**
+ * Call a Scala function of 21 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](f: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr))
+ }
+
+ /**
+ * Call a Scala function of 22 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](f: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr))
+ }
+
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Call a Scala function of 0 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function0[_], returnType: DataType): Column = {
+ ScalaUdf(f, returnType, Seq())
+ }
+
+ /**
+ * Call a Scala function of 1 argument as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr))
+ }
+
+ /**
+ * Call a Scala function of 2 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr))
+ }
+
+ /**
+ * Call a Scala function of 3 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr))
+ }
+
+ /**
+ * Call a Scala function of 4 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
+ }
+
+ /**
+ * Call a Scala function of 5 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
+ }
+
+ /**
+ * Call a Scala function of 6 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
+ }
+
+ /**
+ * Call a Scala function of 7 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
+ }
+
+ /**
+ * Call a Scala function of 8 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
+ }
+
+ /**
+ * Call a Scala function of 9 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
+ }
+
+ /**
+ * Call a Scala function of 10 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
+ }
+
+ /**
+ * Call a Scala function of 11 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr))
+ }
+
+ /**
+ * Call a Scala function of 12 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr))
+ }
+
+ /**
+ * Call a Scala function of 13 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr))
+ }
+
+ /**
+ * Call a Scala function of 14 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr))
+ }
+
+ /**
+ * Call a Scala function of 15 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr))
+ }
+
+ /**
+ * Call a Scala function of 16 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr))
+ }
+
+ /**
+ * Call a Scala function of 17 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr))
+ }
+
+ /**
+ * Call a Scala function of 18 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr))
+ }
+
+ /**
+ * Call a Scala function of 19 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr))
+ }
+
+ /**
+ * Call a Scala function of 20 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr))
+ }
+
+ /**
+ * Call a Scala function of 21 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr))
+ }
+
+ /**
+ * Call a Scala function of 22 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr))
+ }
+
+ // scalastyle:on
+}
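
For illustration only (not part of the patch): a minimal sketch of how the two callUDF families above are meant to be used. The DataFrame `df` and its columns "key" (Int) and "value" (String) are hypothetical, and the 1- and 2-argument TypeTag overloads are assumed to mirror the 11-to-22-argument ones shown here.

    import org.apache.spark.sql.dsl._
    import org.apache.spark.sql.types.IntegerType

    // Return type inferred from the function's signature via TypeTag.
    val concat = (a: Int, s: String) => a.toString + s
    df.select($"*", callUDF(concat, $"key", $"value"))

    // Untyped function plus an explicitly supplied return DataType.
    val strLen = (s: Any) => s.toString.length
    df.select(callUDF(strLen, IntegerType, $"value"))
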
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 52a31f01a4..6fba76c521 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Row, Attribute}
import org.apache.spark.sql.catalyst.plans.logical
@@ -137,7 +137,9 @@ case class CacheTableCommand(
isLazy: Boolean) extends RunnableCommand {
override def run(sqlContext: SQLContext) = {
- plan.foreach(p => new SchemaRDD(sqlContext, p).registerTempTable(tableName))
+ plan.foreach { logicalPlan =>
+ sqlContext.registerRDDAsTable(new DataFrame(sqlContext, logicalPlan), tableName)
+ }
sqlContext.cacheTable(tableName)
if (!isLazy) {
@@ -159,7 +161,7 @@ case class CacheTableCommand(
case class UncacheTableCommand(tableName: String) extends RunnableCommand {
override def run(sqlContext: SQLContext) = {
- sqlContext.table(tableName).unpersist()
+ sqlContext.table(tableName).unpersist(blocking = false)
Seq.empty[Row]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 4d7e338e8e..aeb0960e87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.HashSet
import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext._
-import org.apache.spark.sql.{SchemaRDD, Row}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.types._
@@ -42,7 +42,7 @@ package object debug {
* Augments SchemaRDDs with debug methods.
*/
@DeveloperApi
- implicit class DebugQuery(query: SchemaRDD) {
+ implicit class DebugQuery(query: DataFrame) {
def debug(): Unit = {
val plan = query.queryExecution.executedPlan
val visited = new collection.mutable.HashSet[TreeNodeRef]()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 6dd39be807..7c49b5220d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -37,5 +37,5 @@ package object sql {
* Converts a logical plan into zero or more SparkPlans.
*/
@DeveloperApi
- type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan]
+ protected[sql] type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
index 02ce1b3e6d..0b312ef51d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
-import org.apache.spark.sql.{SQLContext, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.util
import org.apache.spark.util.Utils
@@ -100,7 +100,7 @@ trait ParquetTest {
*/
protected def withParquetRDD[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
- (f: SchemaRDD => Unit): Unit = {
+ (f: DataFrame => Unit): Unit = {
withParquetFile(data)(path => f(parquetFile(path)))
}
@@ -120,7 +120,7 @@ trait ParquetTest {
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withParquetRDD(data) { rdd =>
- rdd.registerTempTable(tableName)
+ sqlContext.registerRDDAsTable(rdd, tableName)
withTempTable(tableName)(f)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 37853d4d03..d13f2ce2a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -18,19 +18,18 @@
package org.apache.spark.sql.sources
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
-import org.apache.spark.sql._
+import org.apache.spark.sql.{Row, Strategy}
import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution
/**
* A Strategy for planning scans over data sources defined using the sources API.
*/
private[sql] object DataSourceStrategy extends Strategy {
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) =>
pruneFilterProjectRaw(
l,
@@ -112,23 +111,26 @@ private[sql] object DataSourceStrategy extends Strategy {
}
}
+ /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. */
protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect {
- case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v)
- case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v)
+ case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v)
+ case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v)
- case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v)
- case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v)
+ case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v)
+ case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v)
- case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v)
- case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v)
+ case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v)
+ case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v)
- case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
+ case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) =>
GreaterThanOrEqual(a.name, v)
- case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
+ case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) =>
LessThanOrEqual(a.name, v)
- case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v)
- case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v)
+ case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) =>
+ LessThanOrEqual(a.name, v)
+ case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) =>
+ GreaterThanOrEqual(a.name, v)
case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray)
}
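
To make the selectFilters translation above concrete, here is a hedged sketch of the mapping it performs. It assumes code living inside the org.apache.spark.sql package (DataSourceStrategy and selectFilters are not public), and the "age" attribute is made up.

    import org.apache.spark.sql.catalyst.expressions
    import org.apache.spark.sql.sources.DataSourceStrategy
    import org.apache.spark.sql.types.IntegerType

    val age = expressions.AttributeReference("age", IntegerType)()

    // A literal on the right maps directly to a data source Filter ...
    DataSourceStrategy.selectFilters(expressions.GreaterThan(age, expressions.Literal(21)) :: Nil)
    // => Seq(GreaterThan("age", 21))

    // ... while a literal on the left flips the comparison.
    DataSourceStrategy.selectFilters(expressions.GreaterThan(expressions.Literal(21), age) :: Nil)
    // => Seq(LessThan("age", 21))
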
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 171b816a26..b4af91a768 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.sources
import scala.language.implicitConversions
import org.apache.spark.Logging
-import org.apache.spark.sql.{SchemaRDD, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.execution.RunnableCommand
@@ -225,7 +225,8 @@ private [sql] case class CreateTempTableUsing(
def run(sqlContext: SQLContext) = {
val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options)
- new SchemaRDD(sqlContext, LogicalRelation(resolved.relation)).registerTempTable(tableName)
+ sqlContext.registerRDDAsTable(
+ new DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
Seq.empty
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
index f9c0822160..2564c849b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.test
import scala.language.implicitConversions
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
/** A SQLContext that can be used for local testing. */
@@ -40,8 +40,8 @@ object TestSQLContext
* Turn a logical plan into a SchemaRDD. This should be removed once we have an easier way to
* construct SchemaRDD directly out of local data without relying on implicits.
*/
- protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = {
- new SchemaRDD(this, plan)
+ protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
+ new DataFrame(this, plan)
}
}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
index 9ff40471a0..e5588938ea 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
@@ -61,7 +61,7 @@ public class JavaAPISuite implements Serializable {
}
}, DataTypes.IntegerType);
- Row result = sqlContext.sql("SELECT stringLengthTest('test')").first();
+ Row result = sqlContext.sql("SELECT stringLengthTest('test')").head();
assert(result.getInt(0) == 4);
}
@@ -81,7 +81,7 @@ public class JavaAPISuite implements Serializable {
}
}, DataTypes.IntegerType);
- Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first();
+ Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head();
assert(result.getInt(0) == 9);
}
}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
index 9e96738ac0..badd00d34b 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
@@ -98,8 +98,8 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- SchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD.rdd(), schema);
- schemaRDD.registerTempTable("people");
+ DataFrame df = javaSqlCtx.applySchema(rowRDD.rdd(), schema);
+ df.registerTempTable("people");
Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect();
List<Row> expected = new ArrayList<Row>(2);
@@ -147,17 +147,17 @@ public class JavaApplySchemaSuite implements Serializable {
null,
"this is another simple string."));
- SchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD.rdd());
- StructType actualSchema1 = schemaRDD1.schema();
+ DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD.rdd());
+ StructType actualSchema1 = df1.schema();
Assert.assertEquals(expectedSchema, actualSchema1);
- schemaRDD1.registerTempTable("jsonTable1");
+ df1.registerTempTable("jsonTable1");
List<Row> actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList();
Assert.assertEquals(expectedResult, actual1);
- SchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema);
- StructType actualSchema2 = schemaRDD2.schema();
+ DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema);
+ StructType actualSchema2 = df2.schema();
Assert.assertEquals(expectedSchema, actualSchema2);
- schemaRDD2.registerTempTable("jsonTable2");
+ df2.registerTempTable("jsonTable2");
List<Row> actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList();
Assert.assertEquals(expectedResult, actual2);
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index cfc037caff..34763156a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index afbfe214f1..a5848f219c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -17,12 +17,10 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.types._
/* Implicits */
-import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.test.TestSQLContext._
import scala.language.postfixOps
@@ -44,46 +42,46 @@ class DslQuerySuite extends QueryTest {
test("agg") {
checkAnswer(
- testData2.groupBy('a)('a, sum('b)),
+ testData2.groupBy("a").agg($"a", sum($"b")),
Seq(Row(1,3), Row(2,3), Row(3,3))
)
checkAnswer(
- testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)),
+ testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
Row(9)
)
checkAnswer(
- testData2.aggregate(sum('b)),
+ testData2.agg(sum('b)),
Row(9)
)
}
test("convert $\"attribute name\" into unresolved attribute") {
checkAnswer(
- testData.where($"key" === 1).select($"value"),
+ testData.where($"key" === Literal(1)).select($"value"),
Row("1"))
}
test("convert Scala Symbol 'attrname into unresolved attribute") {
checkAnswer(
- testData.where('key === 1).select('value),
+ testData.where('key === Literal(1)).select('value),
Row("1"))
}
test("select *") {
checkAnswer(
- testData.select(Star(None)),
+ testData.select($"*"),
testData.collect().toSeq)
}
test("simple select") {
checkAnswer(
- testData.where('key === 1).select('value),
+ testData.where('key === Literal(1)).select('value),
Row("1"))
}
test("select with functions") {
checkAnswer(
- testData.select(sum('value), avg('value), count(1)),
+ testData.select(sum('value), avg('value), count(Literal(1))),
Row(5050.0, 50.5, 100))
checkAnswer(
@@ -120,46 +118,19 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
arrayData.orderBy('data.getItem(0).asc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(0).desc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(1).asc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(1).desc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
- }
-
- test("partition wide sorting") {
- // 2 partitions totally, and
- // Partition #1 with values:
- // (1, 1)
- // (1, 2)
- // (2, 1)
- // Partition #2 with values:
- // (2, 2)
- // (3, 1)
- // (3, 2)
- checkAnswer(
- testData2.sortBy('a.asc, 'b.asc),
- Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
-
- checkAnswer(
- testData2.sortBy('a.asc, 'b.desc),
- Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1)))
-
- checkAnswer(
- testData2.sortBy('a.desc, 'b.desc),
- Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2)))
-
- checkAnswer(
- testData2.sortBy('a.desc, 'b.asc),
- Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2)))
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}
test("limit") {
@@ -176,71 +147,51 @@ class DslQuerySuite extends QueryTest {
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
}
- test("SPARK-3395 limit distinct") {
- val filtered = TestData.testData2
- .distinct()
- .orderBy(SortOrder('a, Ascending), SortOrder('b, Ascending))
- .limit(1)
- .registerTempTable("onerow")
- checkAnswer(
- sql("select * from onerow inner join testData2 on onerow.a = testData2.a"),
- Row(1, 1, 1, 1) ::
- Row(1, 1, 1, 2) :: Nil)
- }
-
- test("SPARK-3858 generator qualifiers are discarded") {
- checkAnswer(
- arrayData.as('ad)
- .generate(Explode("data" :: Nil, 'data), alias = Some("ex"))
- .select("ex.data".attr),
- Seq(1, 2, 3, 2, 3, 4).map(Row(_)))
- }
-
test("average") {
checkAnswer(
- testData2.aggregate(avg('a)),
+ testData2.agg(avg('a)),
Row(2.0))
checkAnswer(
- testData2.aggregate(avg('a), sumDistinct('a)), // non-partial
+ testData2.agg(avg('a), sumDistinct('a)), // non-partial
Row(2.0, 6.0) :: Nil)
checkAnswer(
- decimalData.aggregate(avg('a)),
+ decimalData.agg(avg('a)),
Row(new java.math.BigDecimal(2.0)))
checkAnswer(
- decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
+ decimalData.agg(avg('a), sumDistinct('a)), // non-partial
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
checkAnswer(
- decimalData.aggregate(avg('a cast DecimalType(10, 2))),
+ decimalData.agg(avg('a cast DecimalType(10, 2))),
Row(new java.math.BigDecimal(2.0)))
checkAnswer(
- decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
+ decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
}
test("null average") {
checkAnswer(
- testData3.aggregate(avg('b)),
+ testData3.agg(avg('b)),
Row(2.0))
checkAnswer(
- testData3.aggregate(avg('b), countDistinct('b)),
+ testData3.agg(avg('b), countDistinct('b)),
Row(2.0, 1))
checkAnswer(
- testData3.aggregate(avg('b), sumDistinct('b)), // non-partial
+ testData3.agg(avg('b), sumDistinct('b)), // non-partial
Row(2.0, 2.0))
}
test("zero average") {
checkAnswer(
- emptyTableData.aggregate(avg('a)),
+ emptyTableData.agg(avg('a)),
Row(null))
checkAnswer(
- emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial
+ emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
Row(null, null))
}
@@ -248,28 +199,28 @@ class DslQuerySuite extends QueryTest {
assert(testData2.count() === testData2.map(_ => 1).count())
checkAnswer(
- testData2.aggregate(count('a), sumDistinct('a)), // non-partial
+ testData2.agg(count('a), sumDistinct('a)), // non-partial
Row(6, 6.0))
}
test("null count") {
checkAnswer(
- testData3.groupBy('a)('a, count('b)),
+ testData3.groupBy('a).agg('a, count('b)),
Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
- testData3.groupBy('a)('a, count('a + 'b)),
+ testData3.groupBy('a).agg('a, count('a + 'b)),
Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
- testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
+ testData3.agg(count('a), count('b), count(Literal(1)), countDistinct('a), countDistinct('b)),
Row(2, 1, 2, 2, 1)
)
checkAnswer(
- testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial
+ testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
Row(1, 1, 2)
)
}
@@ -278,19 +229,19 @@ class DslQuerySuite extends QueryTest {
assert(emptyTableData.count() === 0)
checkAnswer(
- emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial
+ emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
Row(0, null))
}
test("zero sum") {
checkAnswer(
- emptyTableData.aggregate(sum('a)),
+ emptyTableData.agg(sum('a)),
Row(null))
}
test("zero sum distinct") {
checkAnswer(
- emptyTableData.aggregate(sumDistinct('a)),
+ emptyTableData.agg(sumDistinct('a)),
Row(null))
}
@@ -320,7 +271,7 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
// SELECT *, foo(key, value) FROM testData
- testData.select(Star(None), foo.call('key, 'value)).limit(3),
+ testData.select($"*", callUDF(foo, 'key, 'value)).limit(3),
Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil
)
}
@@ -362,7 +313,7 @@ class DslQuerySuite extends QueryTest {
test("upper") {
checkAnswer(
lowerCaseData.select(upper('l)),
- ('a' to 'd').map(c => Row(c.toString.toUpperCase()))
+ ('a' to 'd').map(c => Row(c.toString.toUpperCase))
)
checkAnswer(
@@ -379,7 +330,7 @@ class DslQuerySuite extends QueryTest {
test("lower") {
checkAnswer(
upperCaseData.select(lower('L)),
- ('A' to 'F').map(c => Row(c.toString.toLowerCase()))
+ ('A' to 'F').map(c => Row(c.toString.toLowerCase))
)
checkAnswer(
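
As the DslQuerySuite changes above show, the old aggregate/groupBy('a)('a, ...) calls become agg/groupBy(...).agg(...). A minimal sketch of the new style, assuming a hypothetical DataFrame `df` with integer columns "a" and "b" and `import org.apache.spark.sql.dsl._` in scope:

    df.groupBy("a").agg($"a", sum($"b"))                            // was df.groupBy('a)('a, sum('b))
    df.agg(sum($"b"), avg($"b"))                                    // was df.aggregate(sum('b), avg('b))
    df.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)) // chained, as in the "agg" test
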
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index cd36da7751..79713725c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -20,19 +20,20 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.test.TestSQLContext._
+
class JoinSuite extends QueryTest with BeforeAndAfterEach {
// Ensures tables are loaded.
TestData
test("equi-join is hash-join") {
- val x = testData2.as('x)
- val y = testData2.as('y)
- val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed
+ val x = testData2.as("x")
+ val y = testData2.as("y")
+ val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed
val planned = planner.HashJoin(join)
assert(planned.size === 1)
}
@@ -105,17 +106,16 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("multiple-key equi-join is hash-join") {
- val x = testData2.as('x)
- val y = testData2.as('y)
- val join = x.join(y, Inner,
- Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed
+ val x = testData2.as("x")
+ val y = testData2.as("y")
+ val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed
val planned = planner.HashJoin(join)
assert(planned.size === 1)
}
test("inner join where, one match per row") {
checkAnswer(
- upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
+ upperCaseData.join(lowerCaseData).where('n === 'N),
Seq(
Row(1, "A", 1, "a"),
Row(2, "B", 2, "b"),
@@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("inner join ON, one match per row") {
checkAnswer(
- upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N"),
Seq(
Row(1, "A", 1, "a"),
Row(2, "B", 2, "b"),
@@ -136,10 +136,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("inner join, where, multiple matches") {
- val x = testData2.where('a === 1).as('x)
- val y = testData2.where('a === 1).as('y)
+ val x = testData2.where($"a" === Literal(1)).as("x")
+ val y = testData2.where($"a" === Literal(1)).as("y")
checkAnswer(
- x.join(y).where("x.a".attr === "y.a".attr),
+ x.join(y).where($"x.a" === $"y.a"),
Row(1,1,1,1) ::
Row(1,1,1,2) ::
Row(1,2,1,1) ::
@@ -148,22 +148,21 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("inner join, no matches") {
- val x = testData2.where('a === 1).as('x)
- val y = testData2.where('a === 2).as('y)
+ val x = testData2.where($"a" === Literal(1)).as("x")
+ val y = testData2.where($"a" === Literal(2)).as("y")
checkAnswer(
- x.join(y).where("x.a".attr === "y.a".attr),
+ x.join(y).where($"x.a" === $"y.a"),
Nil)
}
test("big inner join, 4 matches per row") {
val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
- val bigDataX = bigData.as('x)
- val bigDataY = bigData.as('y)
+ val bigDataX = bigData.as("x")
+ val bigDataY = bigData.as("y")
checkAnswer(
- bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
- testData.flatMap(
- row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
+ bigDataX.join(bigDataY).where($"x.key" === $"y.key"),
+ testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
test("cartisian product join") {
@@ -177,7 +176,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("left outer join") {
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N", "left"),
Row(1, "A", 1, "a") ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -186,7 +185,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(6, "F", null, null) :: Nil)
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > Literal(1), "left"),
Row(1, "A", null, null) ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -195,7 +194,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(6, "F", null, null) :: Nil)
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > Literal(1), "left"),
Row(1, "A", null, null) ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -204,7 +203,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(6, "F", null, null) :: Nil)
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"),
Row(1, "A", 1, "a") ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -240,7 +239,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("right outer join") {
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N", "right"),
Row(1, "a", 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -248,7 +247,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 5, "E") ::
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > Literal(1), "right"),
Row(null, null, 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -256,7 +255,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 5, "E") ::
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > Literal(1), "right"),
Row(null, null, 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -264,7 +263,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 5, "E") ::
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"),
Row(1, "a", 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -299,14 +298,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("full outer join") {
- upperCaseData.where('N <= 4).registerTempTable("left")
- upperCaseData.where('N >= 3).registerTempTable("right")
+ upperCaseData.where('N <= Literal(4)).registerTempTable("left")
+ upperCaseData.where('N >= Literal(3)).registerTempTable("right")
val left = UnresolvedRelation(Seq("left"), None)
val right = UnresolvedRelation(Seq("right"), None)
checkAnswer(
- left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
+ left.join(right, $"left.N" === $"right.N", "full"),
Row(1, "A", null, null) ::
Row(2, "B", null, null) ::
Row(3, "C", 3, "C") ::
@@ -315,7 +314,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
+ left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== Literal(3)), "full"),
Row(1, "A", null, null) ::
Row(2, "B", null, null) ::
Row(3, "C", null, null) ::
@@ -325,7 +324,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
+ left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== Literal(3)), "full"),
Row(1, "A", null, null) ::
Row(2, "B", null, null) ::
Row(3, "C", null, null) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 42a21c148d..07c52de377 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -26,12 +26,12 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer contains all of the keywords, or the
* none of keywords are listed in the answer
- * @param rdd the [[SchemaRDD]] to be executed
+ * @param rdd the [[DataFrame]] to be executed
* @param exists true for make sure the keywords are listed in the output, otherwise
* to make sure none of the keyword are not listed in the output
* @param keywords keyword in string array
*/
- def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) {
+ def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) {
val outputs = rdd.collect().map(_.mkString).mkString
for (key <- keywords) {
if (exists) {
@@ -44,10 +44,10 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
- * @param rdd the [[SchemaRDD]] to be executed
+ * @param rdd the [[DataFrame]] to be executed
* @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
*/
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -91,7 +91,7 @@ class QueryTest extends PlanTest {
}
}
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
checkAnswer(rdd, Seq(expectedAnswer))
}
@@ -102,7 +102,7 @@ class QueryTest extends PlanTest {
}
/** Asserts that a given SchemaRDD will be executed using the given number of cached results. */
- def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
+ def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
val planWithCaching = query.queryExecution.withCachedData
val cachedData = planWithCaching collect {
case cached: InMemoryRelation => cached
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 03b44ca1d6..4fff99cb3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,6 +21,7 @@ import java.util.TimeZone
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
@@ -29,6 +30,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext._
+
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
// Make sure the tables are loaded.
TestData
@@ -381,8 +383,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("big inner join, 4 matches per row") {
-
-
checkAnswer(
sql(
"""
@@ -396,7 +396,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| SELECT * FROM testData UNION ALL
| SELECT * FROM testData) y
|WHERE x.key = y.key""".stripMargin),
- testData.flatMap(
+ testData.rdd.flatMap(
row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
@@ -742,7 +742,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("metadata is propagated correctly") {
- val person = sql("SELECT * FROM person")
+ val person: DataFrame = sql("SELECT * FROM person")
val schema = person.schema
val docKey = "doc"
val docValue = "first name"
@@ -751,14 +751,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = applySchema(person, schemaWithMeta)
- def validateMetadata(rdd: SchemaRDD): Unit = {
+ val personWithMeta = applySchema(person.rdd, schemaWithMeta)
+ def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
personWithMeta.registerTempTable("personWithMeta")
- validateMetadata(personWithMeta.select('name))
- validateMetadata(personWithMeta.select("name".attr))
- validateMetadata(personWithMeta.select('id, 'name))
+ validateMetadata(personWithMeta.select($"name"))
+ validateMetadata(personWithMeta.select($"name"))
+ validateMetadata(personWithMeta.select($"id", $"name"))
validateMetadata(sql("SELECT * FROM personWithMeta"))
validateMetadata(sql("SELECT id, name FROM personWithMeta"))
validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 808ed5288c..fffa2b7dfa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.test._
/* Implicits */
@@ -29,11 +30,11 @@ case class TestData(key: Int, value: String)
object TestData {
val testData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD
+ (1 to 100).map(i => TestData(i, i.toString))).toDF
testData.registerTempTable("testData")
val negativeData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD
+ (1 to 100).map(i => TestData(-i, (-i).toString))).toDF
negativeData.registerTempTable("negativeData")
case class LargeAndSmallInts(a: Int, b: Int)
@@ -44,7 +45,7 @@ object TestData {
LargeAndSmallInts(2147483645, 1) ::
LargeAndSmallInts(2, 2) ::
LargeAndSmallInts(2147483646, 1) ::
- LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD
+ LargeAndSmallInts(3, 2) :: Nil).toDF
largeAndSmallInts.registerTempTable("largeAndSmallInts")
case class TestData2(a: Int, b: Int)
@@ -55,7 +56,7 @@ object TestData {
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
- TestData2(3, 2) :: Nil, 2).toSchemaRDD
+ TestData2(3, 2) :: Nil, 2).toDF
testData2.registerTempTable("testData2")
case class DecimalData(a: BigDecimal, b: BigDecimal)
@@ -67,7 +68,7 @@ object TestData {
DecimalData(2, 1) ::
DecimalData(2, 2) ::
DecimalData(3, 1) ::
- DecimalData(3, 2) :: Nil).toSchemaRDD
+ DecimalData(3, 2) :: Nil).toDF
decimalData.registerTempTable("decimalData")
case class BinaryData(a: Array[Byte], b: Int)
@@ -77,17 +78,17 @@ object TestData {
BinaryData("22".getBytes(), 5) ::
BinaryData("122".getBytes(), 3) ::
BinaryData("121".getBytes(), 2) ::
- BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD
+ BinaryData("123".getBytes(), 4) :: Nil).toDF
binaryData.registerTempTable("binaryData")
case class TestData3(a: Int, b: Option[Int])
val testData3 =
TestSQLContext.sparkContext.parallelize(
TestData3(1, None) ::
- TestData3(2, Some(2)) :: Nil).toSchemaRDD
+ TestData3(2, Some(2)) :: Nil).toDF
testData3.registerTempTable("testData3")
- val emptyTableData = logical.LocalRelation('a.int, 'b.int)
+ val emptyTableData = logical.LocalRelation($"a".int, $"b".int)
case class UpperCaseData(N: Int, L: String)
val upperCaseData =
@@ -97,7 +98,7 @@ object TestData {
UpperCaseData(3, "C") ::
UpperCaseData(4, "D") ::
UpperCaseData(5, "E") ::
- UpperCaseData(6, "F") :: Nil).toSchemaRDD
+ UpperCaseData(6, "F") :: Nil).toDF
upperCaseData.registerTempTable("upperCaseData")
case class LowerCaseData(n: Int, l: String)
@@ -106,7 +107,7 @@ object TestData {
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
- LowerCaseData(4, "d") :: Nil).toSchemaRDD
+ LowerCaseData(4, "d") :: Nil).toDF
lowerCaseData.registerTempTable("lowerCaseData")
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
@@ -200,6 +201,6 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true)
:: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false)
- :: Nil).toSchemaRDD
+ :: Nil).toDF
complexData.registerTempTable("complexData")
}
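Every table in TestData follows the same rename: the implicit conversion that used to produce a SchemaRDD via .toSchemaRDD now produces a DataFrame via .toDF, which is then registered as a temp table. A small sketch of the pattern; the case class and table name here are illustrative, not part of the patch:

  import org.apache.spark.sql.dsl._
  import org.apache.spark.sql.test._
  import org.apache.spark.sql.test.TestSQLContext._

  case class ExampleRecord(key: Int, value: String)  // hypothetical record type

  // The implicit conversion turns an RDD of case classes into a DataFrame.
  val records = TestSQLContext.sparkContext
    .parallelize((1 to 100).map(i => ExampleRecord(i, i.toString)))
    .toDF
  records.registerTempTable("exampleRecords")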
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 0c98120031..5abd7b9383 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.dsl.StringToColumn
import org.apache.spark.sql.test._
/* Implicits */
@@ -28,17 +29,17 @@ class UDFSuite extends QueryTest {
test("Simple UDF") {
udf.register("strLenScala", (_: String).length)
- assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4)
+ assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
}
test("ZeroArgument UDF") {
udf.register("random0", () => { Math.random()})
- assert(sql("SELECT random0()").first().getDouble(0) >= 0.0)
+ assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
}
test("TwoArgument UDF") {
udf.register("strLenScala", (_: String).length + (_:Int))
- assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
+ assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
}
test("struct UDF") {
@@ -46,7 +47,7 @@ class UDFSuite extends QueryTest {
val result=
sql("SELECT returnStruct('test', 'test2') as ret")
- .select("ret.f1".attr).first().getString(0)
- assert(result == "test")
+ .select($"ret.f1").head().getString(0)
+ assert(result === "test")
}
}
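Two smaller shifts show up in UDFSuite: first() on a query result becomes head(), and struct fields are selected with $"ret.f1" instead of "ret.f1".attr. A rough sketch of registering and calling a Scala UDF under the new API, reusing the suite's strLenScala name:

  import org.apache.spark.sql.dsl._
  import org.apache.spark.sql.test.TestSQLContext._

  // Register a one-argument UDF, then read the first row of the query result with head().
  udf.register("strLenScala", (_: String).length)
  val len = sql("SELECT strLenScala('test')").head().getInt(0)  // 4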
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index fbc8704f78..62b2e89403 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql
import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types._
+
@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
override def equals(other: Any): Boolean = other match {
@@ -66,14 +68,14 @@ class UserDefinedTypeSuite extends QueryTest {
test("register user type: MyDenseVector for MyLabeledPoint") {
- val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v }
+ val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
val labelsArrays: Array[Double] = labels.collect()
assert(labelsArrays.size === 2)
assert(labelsArrays.contains(1.0))
assert(labelsArrays.contains(0.0))
val features: RDD[MyDenseVector] =
- pointsRDD.select('features).map { case Row(v: MyDenseVector) => v }
+ pointsRDD.select('features).rdd.map { case Row(v: MyDenseVector) => v }
val featuresArrays: Array[MyDenseVector] = features.collect()
assert(featuresArrays.size === 2)
assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0))))
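A DataFrame is no longer itself an RDD, so RDD-level operations such as map now go through .rdd, as the change above shows. A minimal sketch of the same pattern against the shared testData table (the key column comes from TestData):

  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.Row
  import org.apache.spark.sql.TestData._
  import org.apache.spark.sql.dsl._
  import org.apache.spark.sql.test.TestSQLContext._

  // select() returns a DataFrame; drop to the underlying RDD[Row] before using map().
  val keys: RDD[Int] = testData.select('key).rdd.map { case Row(k: Int) => k }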
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index e61f3c3963..6f051dfe3d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.columnar
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 67007b8c09..be5e63c76f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.scalatest.FunSuite
import org.apache.spark.sql.{SQLConf, execution}
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -28,6 +29,7 @@ import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.types._
+
class PlannerSuite extends FunSuite {
test("unions are collapsed") {
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
@@ -40,7 +42,7 @@ class PlannerSuite extends FunSuite {
}
test("count is partially aggregated") {
- val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed
+ val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
val planned = HashAggregation(query).head
val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
@@ -48,14 +50,14 @@ class PlannerSuite extends FunSuite {
}
test("count distinct is partially aggregated") {
- val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed
+ val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
val planned = HashAggregation(query)
assert(planned.nonEmpty)
}
test("mixed aggregates are partially aggregated") {
val query =
- testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed
+ testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
val planned = HashAggregation(query)
assert(planned.nonEmpty)
}
@@ -128,9 +130,9 @@ class PlannerSuite extends FunSuite {
testData.limit(3).registerTempTable("tiny")
sql("CACHE TABLE tiny")
- val a = testData.as('a)
- val b = table("tiny").as('b)
- val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan
+ val a = testData.as("a")
+ val b = table("tiny").as("b")
+ val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
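The PlannerSuite changes capture two of the larger DSL shifts: aggregation moves from groupBy('value)(Count('key)) to groupBy('value).agg(count('key)) on a GroupedDataFrame, and joins take a Column predicate directly rather than a join type plus an Option of an expression. A rough sketch reusing the registered test data; the val names are illustrative:

  import org.apache.spark.sql.TestData._
  import org.apache.spark.sql.dsl._
  import org.apache.spark.sql.test.TestSQLContext._

  // groupBy returns a GroupedDataFrame; aggregates are built from dsl functions.
  val counts = testData.groupBy('value).agg(count('key), countDistinct('key))

  // as(...) now takes a string alias, and join(...) takes a Column predicate.
  val joined = testData.as("a").join(testData.as("b"), $"a.key" === $"b.key")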
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
deleted file mode 100644
index 272c0d4cb2..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-
-/* Implicit conversions */
-import org.apache.spark.sql.test.TestSQLContext._
-
-/**
- * This is an example TGF that uses UnresolvedAttributes 'name and 'age to access specific columns
- * from the input data. These will be replaced during analysis with specific AttributeReferences
- * and then bound to specific ordinals during query planning. While TGFs could also access specific
- * columns using hand-coded ordinals, doing so violates data independence.
- *
- * Note: this is only a rough example of how TGFs can be expressed, the final version will likely
- * involve a lot more sugar for cleaner use in Scala/Java/etc.
- */
-case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator {
- def children = input
- protected def makeOutput() = 'nameAndAge.string :: Nil
-
- val Seq(nameAttr, ageAttr) = input
-
- override def eval(input: Row): TraversableOnce[Row] = {
- val name = nameAttr.eval(input)
- val age = ageAttr.eval(input).asInstanceOf[Int]
-
- Iterator(
- new GenericRow(Array[Any](s"$name is $age years old")),
- new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old")))
- }
-}
-
-class TgfSuite extends QueryTest {
- val inputData =
- logical.LocalRelation('name.string, 'age.int).loadData(
- ("michael", 29) :: Nil
- )
-
- test("simple tgf example") {
- checkAnswer(
- inputData.generate(ExampleTGF()),
- Seq(
- Row("michael is 29 years old"),
- Row("Next year, michael will be 30 years old")))
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 94d14acccb..ef198f846c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -21,11 +21,12 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{QueryTest, Row, SQLConf}
+import org.apache.spark.sql.{Literal, QueryTest, Row, SQLConf}
class JsonSuite extends QueryTest {
import org.apache.spark.sql.json.TestJsonData._
@@ -463,8 +464,8 @@ class JsonSuite extends QueryTest {
// in the Project.
checkAnswer(
jsonSchemaRDD.
- where('num_str > BigDecimal("92233720368547758060")).
- select('num_str + 1.2 as Symbol("num")),
+ where('num_str > Literal(BigDecimal("92233720368547758060"))).
+ select(('num_str + Literal(1.2)).as("num")),
Row(new java.math.BigDecimal("92233720368547758061.2"))
)
@@ -820,7 +821,7 @@ class JsonSuite extends QueryTest {
val schemaRDD1 = applySchema(rowRDD1, schema1)
schemaRDD1.registerTempTable("applySchema1")
- val schemaRDD2 = schemaRDD1.toSchemaRDD
+ val schemaRDD2 = schemaRDD1.toDF
val result = schemaRDD2.toJSON.collect()
assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}")
assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}")
@@ -841,7 +842,7 @@ class JsonSuite extends QueryTest {
val schemaRDD3 = applySchema(rowRDD2, schema2)
schemaRDD3.registerTempTable("applySchema2")
- val schemaRDD4 = schemaRDD3.toSchemaRDD
+ val schemaRDD4 = schemaRDD3.toDF
val result2 = schemaRDD4.toJSON.collect()
assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
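In JsonSuite the bare constants on the right-hand side of DSL comparisons are now wrapped explicitly in Literal, and aliases use .as("name") rather than as Symbol("name"). A small sketch of the same pattern on the shared testData table; the column arithmetic and alias are illustrative:

  import org.apache.spark.sql.Literal
  import org.apache.spark.sql.TestData._
  import org.apache.spark.sql.dsl._
  import org.apache.spark.sql.test.TestSQLContext._

  // Constants used inside DSL expressions are wrapped in Literal explicitly.
  val bumped = testData
    .where('key > Literal(90))
    .select(('key + Literal(1)).as("keyPlusOne"))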
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index 1e7d3e06fc..c9bc55900d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -23,7 +23,7 @@ import parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row}
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf}
/**
* A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -41,15 +41,17 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
val sqlContext = TestSQLContext
private def checkFilterPredicate(
- rdd: SchemaRDD,
+ rdd: DataFrame,
predicate: Predicate,
filterClass: Class[_ <: FilterPredicate],
- checker: (SchemaRDD, Seq[Row]) => Unit,
+ checker: (DataFrame, Seq[Row]) => Unit,
expected: Seq[Row]): Unit = {
val output = predicate.collect { case a: Attribute => a }.distinct
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") {
- val query = rdd.select(output: _*).where(predicate)
+ val query = rdd
+ .select(output.map(e => new org.apache.spark.sql.Column(e)): _*)
+ .where(new org.apache.spark.sql.Column(predicate))
val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect {
case plan: ParquetTableScan => plan.columnPruningPred
@@ -71,13 +73,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
private def checkFilterPredicate
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row])
- (implicit rdd: SchemaRDD): Unit = {
+ (implicit rdd: DataFrame): Unit = {
checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected)
}
private def checkFilterPredicate[T]
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T)
- (implicit rdd: SchemaRDD): Unit = {
+ (implicit rdd: DataFrame): Unit = {
checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd)
}
@@ -93,24 +95,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - integer") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1)
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
@@ -118,24 +120,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - long") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
@@ -143,24 +145,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - float") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1)
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
@@ -168,24 +170,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
test("filter pushdown - double") {
withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1)
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
@@ -197,30 +199,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
checkFilterPredicate(
'_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString)))
- checkFilterPredicate('_1 === "1", classOf[Eq [_]], "1")
+ checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1")
checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString)))
- checkFilterPredicate('_1 < "2", classOf[Lt [_]], "1")
- checkFilterPredicate('_1 > "3", classOf[Gt [_]], "4")
+ checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1")
+ checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4")
checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1")
checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4")
- checkFilterPredicate(Literal("1") === '_1, classOf[Eq [_]], "1")
- checkFilterPredicate(Literal("2") > '_1, classOf[Lt [_]], "1")
- checkFilterPredicate(Literal("3") < '_1, classOf[Gt [_]], "4")
- checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1")
- checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4")
+ checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1")
+ checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1")
+ checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4")
+ checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1")
+ checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4")
- checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4")
+ checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4")
checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3")
- checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4")))
+ checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4")))
}
}
def checkBinaryFilterPredicate
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row])
- (implicit rdd: SchemaRDD): Unit = {
- def checkBinaryAnswer(rdd: SchemaRDD, expected: Seq[Row]) = {
+ (implicit rdd: DataFrame): Unit = {
+ def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = {
assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) {
rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted
}
@@ -231,7 +233,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
def checkBinaryFilterPredicate
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte])
- (implicit rdd: SchemaRDD): Unit = {
+ (implicit rdd: DataFrame): Unit = {
checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd)
}
@@ -249,16 +251,16 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
checkBinaryFilterPredicate(
'_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq)
- checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt [_]], 1.b)
- checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt [_]], 4.b)
+ checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b)
+ checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b)
checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b)
checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b)
- checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq [_]], 1.b)
- checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt [_]], 1.b)
- checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt [_]], 4.b)
- checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b)
- checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b)
+ checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b)
+ checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b)
+ checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b)
+ checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b)
+ checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b)
checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b)
checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index a57e4e85a3..f03b3a32e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -32,12 +32,13 @@ import parquet.schema.{MessageType, MessageTypeParser}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf}
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types.DecimalType
-import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD}
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
// with an empty configuration (it is after all not intended to be used in this way?)
@@ -97,11 +98,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
}
test("fixed-length decimals") {
- def makeDecimalRDD(decimal: DecimalType): SchemaRDD =
+ def makeDecimalRDD(decimal: DecimalType): DataFrame =
sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(i / 100.0))
- .select('_1 cast decimal)
+ .select($"_1" cast decimal as "abcd")
for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
withTempPath { dir =>
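The fixed-length decimal test now builds its DataFrame with a Column-level cast followed by an alias, $"_1" cast decimal as "abcd", instead of the old expression DSL's '_1 cast decimal. A minimal sketch of that cast-and-alias pattern on a tuple RDD; the alias name is arbitrary:

  import org.apache.spark.sql.DataFrame
  import org.apache.spark.sql.dsl._
  import org.apache.spark.sql.test.TestSQLContext
  import org.apache.spark.sql.test.TestSQLContext._
  import org.apache.spark.sql.types.DecimalType

  // Cast the implicit tuple column _1 to the requested decimal type and name it.
  def makeDecimalDF(decimal: DecimalType): DataFrame =
    TestSQLContext.sparkContext
      .parallelize(0 to 1000)
      .map(i => Tuple1(i / 100.0))
      .select($"_1" cast decimal as "value")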
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index 7900b3e894..a33cf1172c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.sources
+import scala.language.existentials
+
import org.apache.spark.sql._
import org.apache.spark.sql.types._
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
index 7385952861..bb19ac232f 100755
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -23,6 +23,7 @@ import java.io._
import java.util.{ArrayList => JArrayList}
import jline.{ConsoleReader, History}
+
import org.apache.commons.lang.StringUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.conf.Configuration
@@ -39,7 +40,6 @@ import org.apache.thrift.transport.TSocket
import org.apache.spark.Logging
import org.apache.spark.sql.hive.HiveShim
-import org.apache.spark.sql.hive.thriftserver.HiveThriftServerShim
private[hive] object SparkSQLCLIDriver {
private var prompt = "spark-sql"
diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
index 166c56b9df..ea9d61d8d0 100644
--- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
+++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
@@ -32,7 +32,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation
import org.apache.hive.service.cli.session.HiveSession
import org.apache.spark.Logging
-import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow}
+import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow}
import org.apache.spark.sql.execution.SetCommand
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
@@ -71,7 +71,7 @@ private[hive] class SparkExecuteStatementOperation(
sessionToActivePool: SMap[SessionHandle, String])
extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging {
- private var result: SchemaRDD = _
+ private var result: DataFrame = _
private var iter: Iterator[SparkRow] = _
private var dataTypes: Array[DataType] = _
@@ -202,7 +202,7 @@ private[hive] class SparkExecuteStatementOperation(
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
if (useIncrementalCollect) {
- result.toLocalIterator
+ result.rdd.toLocalIterator
} else {
result.collect().iterator
}
diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
index eaf7a1ddd4..71e3954b2c 100644
--- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
+++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
@@ -30,7 +30,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation
import org.apache.hive.service.cli.session.HiveSession
import org.apache.spark.Logging
-import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf}
import org.apache.spark.sql.execution.SetCommand
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
@@ -72,7 +72,7 @@ private[hive] class SparkExecuteStatementOperation(
// NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution
extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging {
- private var result: SchemaRDD = _
+ private var result: DataFrame = _
private var iter: Iterator[SparkRow] = _
private var dataTypes: Array[DataType] = _
@@ -173,7 +173,7 @@ private[hive] class SparkExecuteStatementOperation(
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
if (useIncrementalCollect) {
- result.toLocalIterator
+ result.rdd.toLocalIterator
} else {
result.collect().iterator
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 9d2cfd8e0d..b746942cb1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -64,15 +64,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true"
override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
- new this.QueryExecution { val logical = plan }
+ new this.QueryExecution(plan)
- override def sql(sqlText: String): SchemaRDD = {
+ override def sql(sqlText: String): DataFrame = {
val substituted = new VariableSubstitution().substitute(hiveconf, sqlText)
// TODO: Create a framework for registering parsers instead of just hardcoding if statements.
if (conf.dialect == "sql") {
super.sql(substituted)
} else if (conf.dialect == "hiveql") {
- new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted)))
+ new DataFrame(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted)))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'")
}
@@ -352,7 +352,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
override protected[sql] val planner = hivePlanner
/** Extends QueryExecution with hive specific features. */
- protected[sql] abstract class QueryExecution extends super.QueryExecution {
+ protected[sql] class QueryExecution(logicalPlan: LogicalPlan)
+ extends super.QueryExecution(logicalPlan) {
/**
* Returns the result as a hive compatible sequence of strings. For native commands, the
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 6952b126cf..ace9329cd5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive
import scala.collection.JavaConversions._
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy}
+import org.apache.spark.sql.{Column, DataFrame, SQLContext, Strategy}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
@@ -55,16 +55,15 @@ private[hive] trait HiveStrategies {
*/
@Experimental
object ParquetConversion extends Strategy {
- implicit class LogicalPlanHacks(s: SchemaRDD) {
- def lowerCase =
- new SchemaRDD(s.sqlContext, s.logicalPlan)
+ implicit class LogicalPlanHacks(s: DataFrame) {
+ def lowerCase = new DataFrame(s.sqlContext, s.logicalPlan)
def addPartitioningAttributes(attrs: Seq[Attribute]) = {
// Don't add the partitioning key if its already present in the data.
if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) {
s
} else {
- new SchemaRDD(
+ new DataFrame(
s.sqlContext,
s.logicalPlan transform {
case p: ParquetRelation => p.copy(partitioningAttributes = attrs)
@@ -97,13 +96,13 @@ private[hive] trait HiveStrategies {
// We are going to throw the predicates and projection back at the whole optimization
// sequence so lets unresolve all the attributes, allowing them to be rebound to the
// matching parquet attributes.
- val unresolvedOtherPredicates = otherPredicates.map(_ transform {
+ val unresolvedOtherPredicates = new Column(otherPredicates.map(_ transform {
case a: AttributeReference => UnresolvedAttribute(a.name)
- }).reduceOption(And).getOrElse(Literal(true))
+ }).reduceOption(And).getOrElse(Literal(true)))
- val unresolvedProjection = projectList.map(_ transform {
+ val unresolvedProjection: Seq[Column] = projectList.map(_ transform {
case a: AttributeReference => UnresolvedAttribute(a.name)
- })
+ }).map(new Column(_))
try {
if (relation.hiveQlTable.isPartitioned) {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 47431cef03..8e70ae8f56 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -99,7 +99,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql))
override def executePlan(plan: LogicalPlan): this.QueryExecution =
- new this.QueryExecution { val logical = plan }
+ new this.QueryExecution(plan)
/** Fewer partitions to speed up testing. */
protected[sql] override lazy val conf: SQLConf = new SQLConf {
@@ -150,8 +150,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
val describedTable = "DESCRIBE (\\w+)".r
- protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution {
- lazy val logical = HiveQl.parseSql(hql)
+ protected[hive] class HiveQLQueryExecution(hql: String)
+ extends this.QueryExecution(HiveQl.parseSql(hql)) {
def hiveExec() = runSqlHive(hql)
override def toString = hql + "\n" + super.toString
}
@@ -159,7 +159,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
/**
* Override QueryExecution with special debug workflow.
*/
- abstract class QueryExecution extends super.QueryExecution {
+ class QueryExecution(logicalPlan: LogicalPlan)
+ extends super.QueryExecution(logicalPlan) {
override lazy val analyzed = {
val describedTables = logical match {
case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f320d732fb..ba39129388 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -36,12 +36,12 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer either contains all of the keywords or
* contains none of the keywords.
- * @param rdd the [[SchemaRDD]] to be executed
+ * @param rdd the [[DataFrame]] to be executed
* @param exists true to make sure the keywords are listed in the output, otherwise
* to make sure none of the keywords is listed in the output
* @param keywords keyword in string array
*/
- def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) {
+ def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) {
val outputs = rdd.collect().map(_.mkString).mkString
for (key <- keywords) {
if (exists) {
@@ -54,10 +54,10 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
- * @param rdd the [[SchemaRDD]] to be executed
+ * @param rdd the [[DataFrame]] to be executed
* @param expectedAnswer the expected result; can be an Any, a Seq[Product], or a Seq[Seq[Any]].
*/
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -101,7 +101,7 @@ class QueryTest extends PlanTest {
}
}
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
checkAnswer(rdd, Seq(expectedAnswer))
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index f95a6b43af..61e5117fea 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.{QueryTest, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.storage.RDDBlockId
class CachedTableSuite extends QueryTest {
@@ -28,7 +28,7 @@ class CachedTableSuite extends QueryTest {
* Throws a test failed exception when the number of cached tables differs from the expected
* number.
*/
- def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
+ def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
val planWithCaching = query.queryExecution.withCachedData
val cachedData = planWithCaching collect {
case cached: InMemoryRelation => cached
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 0e6636d38e..5775d83fcb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
// Make sure the table has been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq
+ testData.toDF.collect().toSeq ++ testData.toDF.collect().toSeq
)
// Now overwrite.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index df72be7746..d67b00bc9d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -27,11 +27,12 @@ import scala.util.Try
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.spark.{SparkFiles, SparkException}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.{SQLConf, Row, SchemaRDD}
case class TestData(a: Int, b: String)
@@ -473,7 +474,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
}
- def isExplanation(result: SchemaRDD) = {
+ def isExplanation(result: DataFrame) = {
val explanation = result.select('plan).collect().map { case Row(plan: String) => plan }
explanation.contains("== Physical Plan ==")
}
@@ -842,7 +843,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
val KV = "([^=]+)=([^=]*)".r
- def collectResults(rdd: SchemaRDD): Set[(String, String)] =
+ def collectResults(rdd: DataFrame): Set[(String, String)] =
rdd.collect().map {
case Row(key: String, value: String) => key -> value
case Row(KV(key, value)) => key -> value
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 16f77a438e..a081227b4e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.hive.execution
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.Row
import org.apache.spark.util.Utils
@@ -82,10 +83,10 @@ class HiveTableScanSuite extends HiveComparisonTest {
sql("create table spark_4959 (col1 string)")
sql("""insert into table spark_4959 select "hi" from src limit 1""")
table("spark_4959").select(
- 'col1.as('CaseSensitiveColName),
- 'col1.as('CaseSensitiveColName2)).registerTempTable("spark_4959_2")
+ 'col1.as("CaseSensitiveColName"),
+ 'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2")
- assert(sql("select CaseSensitiveColName from spark_4959_2").first() === Row("hi"))
- assert(sql("select casesensitivecolname from spark_4959_2").first() === Row("hi"))
+ assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi"))
+ assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi"))
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index f2374a2152..dd0df1a9f6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -58,7 +58,7 @@ class HiveUdfSuite extends QueryTest {
| getStruct(1).f3,
| getStruct(1).f4,
| getStruct(1).f5 FROM src LIMIT 1
- """.stripMargin).first() === Row(1, 2, 3, 4, 5))
+ """.stripMargin).head() === Row(1, 2, 3, 4, 5))
}
test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {