-rw-r--r--  docs/ml-guide.md | 12
-rw-r--r--  docs/sql-programming-guide.md | 16
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java | 4
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java | 4
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java | 4
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java | 4
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java | 2
-rw-r--r--  examples/src/main/python/sql.py | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 4
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java | 2
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 2
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java | 2
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java | 2
-rw-r--r--  python/pyspark/sql/context.py | 87
-rw-r--r--  python/pyspark/sql/tests.py | 26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 95
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala | 8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 4
23 files changed, 222 insertions, 97 deletions
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index be178d7689..4bf14fba34 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -260,7 +260,7 @@ List<LabeledPoint> localTraining = Lists.newArrayList(
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
-JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+JavaSchemaRDD training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
LogisticRegression lr = new LogisticRegression();
@@ -300,7 +300,7 @@ List<LabeledPoint> localTest = Lists.newArrayList(
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
-JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+JavaSchemaRDD test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
@@ -443,7 +443,7 @@ List<LabeledDocument> localTraining = Lists.newArrayList(
new LabeledDocument(2L, "spark f g h", 1.0),
new LabeledDocument(3L, "hadoop mapreduce", 0.0));
JavaSchemaRDD training =
- jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -469,7 +469,7 @@ List<Document> localTest = Lists.newArrayList(
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
JavaSchemaRDD test =
- jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents.
model.transform(test).registerAsTable("prediction");
@@ -626,7 +626,7 @@ List<LabeledDocument> localTraining = Lists.newArrayList(
new LabeledDocument(10L, "spark compile", 1.0),
new LabeledDocument(11L, "hadoop software", 0.0));
JavaSchemaRDD training =
- jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -669,7 +669,7 @@ List<Document> localTest = Lists.newArrayList(
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
-JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+JavaSchemaRDD test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test).registerAsTable("prediction");
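Illustration only (not part of the patch): all of the Java guide snippets above go through the bean-class overload being renamed here, and the same overload can be reached from Scala. A rough sketch, assuming an existing SparkContext `sc` and relying on `LabeledPoint` carrying the `@BeanInfo` annotation that the bean introspection needs:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
val training = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
  LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0))))

// The bean-class overload renamed from applySchema to createDataFrame in this patch.
val trainingDF = sqlContext.createDataFrame(training, classOf[LabeledPoint])
trainingDF.printSchema()  // columns: label (double), features (vector)
```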
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 38f617d0c8..b2b007509c 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -225,7 +225,7 @@ public static class Person implements Serializable {
{% endhighlight %}
-A schema can be applied to an existing RDD by calling `applySchema` and providing the Class object
+A schema can be applied to an existing RDD by calling `createDataFrame` and providing the Class object
for the JavaBean.
{% highlight java %}
@@ -247,7 +247,7 @@ JavaRDD<Person> people = sc.textFile("examples/src/main/resources/people.txt").m
});
// Apply a schema to an RDD of JavaBeans and register it as a table.
-JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class);
+JavaSchemaRDD schemaPeople = sqlContext.createDataFrame(people, Person.class);
schemaPeople.registerTempTable("people");
// SQL can be run over RDDs that have been registered as tables.
@@ -315,7 +315,7 @@ a `SchemaRDD` can be created programmatically with three steps.
1. Create an RDD of `Row`s from the original RDD;
2. Create the schema represented by a `StructType` matching the structure of
`Row`s in the RDD created in Step 1.
-3. Apply the schema to the RDD of `Row`s via `applySchema` method provided
+3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided
by `SQLContext`.
For example:
@@ -341,7 +341,7 @@ val schema =
val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim))
// Apply the schema to the RDD.
-val peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema)
+val peopleSchemaRDD = sqlContext.createDataFrame(rowRDD, schema)
// Register the SchemaRDD as a table.
peopleSchemaRDD.registerTempTable("people")
@@ -367,7 +367,7 @@ a `SchemaRDD` can be created programmatically with three steps.
1. Create an RDD of `Row`s from the original RDD;
2. Create the schema represented by a `StructType` matching the structure of
`Row`s in the RDD created in Step 1.
-3. Apply the schema to the RDD of `Row`s via `applySchema` method provided
+3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided
by `JavaSQLContext`.
For example:
@@ -406,7 +406,7 @@ JavaRDD<Row> rowRDD = people.map(
});
// Apply the schema to the RDD.
-JavaSchemaRDD peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema);
+JavaSchemaRDD peopleSchemaRDD = sqlContext.createDataFrame(rowRDD, schema);
// Register the SchemaRDD as a table.
peopleSchemaRDD.registerTempTable("people");
@@ -436,7 +436,7 @@ a `SchemaRDD` can be created programmatically with three steps.
1. Create an RDD of tuples or lists from the original RDD;
2. Create the schema represented by a `StructType` matching the structure of
tuples or lists in the RDD created in step 1.
-3. Apply the schema to the RDD via `applySchema` method provided by `SQLContext`.
+3. Apply the schema to the RDD via `createDataFrame` method provided by `SQLContext`.
For example:
{% highlight python %}
@@ -458,7 +458,7 @@ fields = [StructField(field_name, StringType(), True) for field_name in schemaSt
schema = StructType(fields)
# Apply the schema to the RDD.
-schemaPeople = sqlContext.applySchema(people, schema)
+schemaPeople = sqlContext.createDataFrame(people, schema)
# Register the SchemaRDD as a table.
schemaPeople.registerTempTable("people")
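For orientation only (not part of the patch), a minimal Scala sketch of the guide's three-step programmatic-schema recipe under the new name, assuming an existing SparkContext `sc` and the `people.txt` sample that ships with Spark:

```scala
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val sqlContext = new SQLContext(sc)

// Step 1: build an RDD of Rows from the raw text file.
val people = sc.textFile("examples/src/main/resources/people.txt")
val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim))

// Step 2: describe the structure of those Rows with a StructType.
val schema = StructType(Seq(
  StructField("name", StringType, nullable = true),
  StructField("age", StringType, nullable = true)))

// Step 3: apply the schema with the new method name (applySchema is now deprecated).
val peopleDF = sqlContext.createDataFrame(rowRDD, schema)
peopleDF.registerTempTable("people")
sqlContext.sql("SELECT name FROM people").collect().foreach(println)
```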
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 5041e0b6d3..5d8c5d0a92 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -71,7 +71,7 @@ public class JavaCrossValidatorExample {
new LabeledDocument(9L, "a e c l", 0.0),
new LabeledDocument(10L, "spark compile", 1.0),
new LabeledDocument(11L, "hadoop software", 0.0));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -112,7 +112,7 @@ public class JavaCrossValidatorExample {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test).registerTempTable("prediction");
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index 4d9dad9f23..19d0eb2168 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -62,7 +62,7 @@ public class JavaDeveloperApiExample {
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
@@ -80,7 +80,7 @@ public class JavaDeveloperApiExample {
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
DataFrame results = model.transform(test);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index cc69e6315f..4c4d532388 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -54,7 +54,7 @@ public class JavaSimpleParamsExample {
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
LogisticRegression lr = new LogisticRegression();
@@ -94,7 +94,7 @@ public class JavaSimpleParamsExample {
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index d929f1ad20..fdcfc888c2 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -54,7 +54,7 @@ public class JavaSimpleTextClassificationPipeline {
new LabeledDocument(1L, "b d", 0.0),
new LabeledDocument(2L, "spark f g h", 1.0),
new LabeledDocument(3L, "hadoop mapreduce", 0.0));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -79,7 +79,7 @@ public class JavaSimpleTextClassificationPipeline {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents.
model.transform(test).registerTempTable("prediction");
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
index 8defb769ff..dee794840a 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
@@ -74,7 +74,7 @@ public class JavaSparkSQL {
});
// Apply a schema to an RDD of Java Beans and register it as a table.
- DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class);
+ DataFrame schemaPeople = sqlCtx.createDataFrame(people, Person.class);
schemaPeople.registerTempTable("people");
// SQL can be run over RDDs that have been registered as tables.
diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py
index 7f5c68e3d0..47202fde75 100644
--- a/examples/src/main/python/sql.py
+++ b/examples/src/main/python/sql.py
@@ -31,7 +31,7 @@ if __name__ == "__main__":
Row(name="Smith", age=23),
Row(name="Sarah", age=18)])
# Infer schema from the first row, create a DataFrame and print the schema
- some_df = sqlContext.inferSchema(some_rdd)
+ some_df = sqlContext.createDataFrame(some_rdd)
some_df.printSchema()
# Another RDD is created from a list of tuples
@@ -40,7 +40,7 @@ if __name__ == "__main__":
schema = StructType([StructField("person_name", StringType(), False),
StructField("person_age", IntegerType(), False)])
# Create a DataFrame by applying the schema to the RDD and print the schema
- another_df = sqlContext.applySchema(another_rdd, schema)
+ another_df = sqlContext.createDataFrame(another_rdd, schema)
another_df.printSchema()
# root
# |-- age: integer (nullable = true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 5d51c51346..324b1ba784 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -76,8 +76,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
val metrics = new Array[Double](epm.size)
val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
- val trainingDataset = sqlCtx.applySchema(training, schema).cache()
- val validationDataset = sqlCtx.applySchema(validation, schema).cache()
+ val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
+ val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
// multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
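Illustration only (not the class's actual code): the pattern used above — splitting the underlying `RDD[Row]` with `MLUtils.kFold` and wrapping each split back into a cached DataFrame via `createDataFrame` — looks roughly like this in isolation, assuming an input DataFrame `dataset`:

```scala
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.DataFrame

// Rebuild cached DataFrames from k-fold splits of the underlying RDD[Row].
def kFoldFrames(dataset: DataFrame, numFolds: Int): Array[(DataFrame, DataFrame)] = {
  val sqlCtx = dataset.sqlContext
  val schema = dataset.schema
  MLUtils.kFold(dataset.rdd, numFolds, 0).map { case (training, validation) =>
    // createDataFrame replaces the former applySchema call; caching mirrors CrossValidator.
    (sqlCtx.createDataFrame(training, schema).cache(),
     sqlCtx.createDataFrame(validation, schema).cache())
  }
}
```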
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 50995ffef9..0a8c9e5954 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -45,7 +45,7 @@ public class JavaPipelineSuite {
jsql = new SQLContext(jsc);
JavaRDD<LabeledPoint> points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
- dataset = jsql.applySchema(points, LabeledPoint.class);
+ dataset = jsql.createDataFrame(points, LabeledPoint.class);
}
@After
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index d4b6644792..3f8e59de0f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -50,7 +50,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 40d5a92bb3..0cc36c8d56 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -46,7 +46,7 @@ public class JavaLinearRegressionSuite implements Serializable {
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 074b58c07d..0bb6b489f2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -45,7 +45,7 @@ public class JavaCrossValidatorSuite implements Serializable {
jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
- dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
}
@After
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 882c0f98ea..9d29ef4839 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -25,7 +25,7 @@ from py4j.java_collections import MapConverter
from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import StringType, StructType, _verify_type, \
+from pyspark.sql.types import StringType, StructType, _infer_type, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
from pyspark.sql.dataframe import DataFrame
@@ -47,23 +47,11 @@ class SQLContext(object):
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
SQLContext in the JVM; instead we make all calls to this object.
- >>> df = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- TypeError:...
-
- >>> bad_rdd = sc.parallelize([1,2,3])
- >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError:...
-
>>> from datetime import datetime
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
- >>> df = sqlCtx.inferSchema(allTypes)
+ >>> df = sqlCtx.createDataFrame(allTypes)
>>> df.registerTempTable("allTypes")
>>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
@@ -131,6 +119,9 @@ class SQLContext(object):
def inferSchema(self, rdd, samplingRatio=None):
"""Infer and apply a schema to an RDD of L{Row}.
+ .. note::
+     Deprecated in 1.3, use :func:`createDataFrame` instead.
+
When samplingRatio is specified, the schema is inferred by looking
at the types of each row in the sampled dataset. Otherwise, the
first 100 rows of the RDD are inspected. Nested collections are
@@ -199,7 +190,7 @@ class SQLContext(object):
warnings.warn("Some of types cannot be determined by the "
"first 100 rows, please try again with sampling")
else:
- if samplingRatio > 0.99:
+ if samplingRatio < 0.99:
rdd = rdd.sample(False, float(samplingRatio))
schema = rdd.map(_infer_schema).reduce(_merge_type)
@@ -211,6 +202,9 @@ class SQLContext(object):
"""
Applies the given schema to the given RDD of L{tuple} or L{list}.
+ .. note::
+     Deprecated in 1.3, use :func:`createDataFrame` instead.
+
These tuples or lists can contain complex nested structures like
lists, maps or nested rows.
@@ -300,13 +294,68 @@ class SQLContext(object):
df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return DataFrame(df, self)
+ def createDataFrame(self, rdd, schema=None, samplingRatio=None):
+ """
+ Create a DataFrame from an RDD of tuple/list and an optional `schema`.
+
+ `schema` could be :class:`StructType` or a list of column names.
+
+ When `schema` is a list of column names, the type of each column
+ will be inferred from `rdd`.
+
+ When `schema` is None, the column names and types will be inferred
+ from `rdd`, which should be an RDD of :class:`Row`, namedtuple,
+ or dict.
+
+ When schema inference is needed, `samplingRatio` is used to determine
+ how many rows will be sampled for inference. Only the first row is
+ used if `samplingRatio` is None.
+
+ :param rdd: an RDD of Row or tuple or list or dict
+ :param schema: a StructType or list of names of columns
+ :param samplingRatio: the sample ratio of rows used for inferring
+ :return: a DataFrame
+
+ >>> rdd = sc.parallelize([('Alice', 1)])
+ >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
+ >>> df.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> from pyspark.sql import Row
+ >>> Person = Row('name', 'age')
+ >>> person = rdd.map(lambda r: Person(*r))
+ >>> df2 = sqlCtx.createDataFrame(person)
+ >>> df2.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> from pyspark.sql.types import *
+ >>> schema = StructType([
+ ... StructField("name", StringType(), True),
+ ... StructField("age", IntegerType(), True)])
+ >>> df3 = sqlCtx.createDataFrame(rdd, schema)
+ >>> df3.collect()
+ [Row(name=u'Alice', age=1)]
+ """
+ if isinstance(rdd, DataFrame):
+ raise TypeError("rdd is already a DataFrame")
+
+ if isinstance(schema, StructType):
+ return self.applySchema(rdd, schema)
+ else:
+ if isinstance(schema, (list, tuple)):
+ first = rdd.first()
+ if not isinstance(first, (list, tuple)):
+ raise ValueError("each row in `rdd` should be list or tuple")
+ row_cls = Row(*schema)
+ rdd = rdd.map(lambda r: row_cls(*r))
+ return self.inferSchema(rdd, samplingRatio)
+
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
Temporary tables exist only during the lifetime of this instance of
SQLContext.
- >>> df = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(df, "table1")
"""
if (rdd.__class__ is DataFrame):
@@ -321,7 +370,6 @@ class SQLContext(object):
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> df = sqlCtx.inferSchema(rdd)
>>> df.saveAsParquetFile(parquetFile)
>>> df2 = sqlCtx.parquetFile(parquetFile)
>>> sorted(df.collect()) == sorted(df2.collect())
@@ -526,7 +574,6 @@ class SQLContext(object):
def sql(self, sqlQuery):
"""Return a L{DataFrame} representing the result of the given query.
- >>> df = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(df, "table1")
>>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
@@ -537,7 +584,6 @@ class SQLContext(object):
def table(self, tableName):
"""Returns the specified table as a L{DataFrame}.
- >>> df = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(df, "table1")
>>> df2 = sqlCtx.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
@@ -685,11 +731,12 @@ def _test():
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlCtx'] = sqlCtx = SQLContext(sc)
- globs['rdd'] = sc.parallelize(
+ globs['rdd'] = rdd = sc.parallelize(
[Row(field1=1, field2="row1"),
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
+ globs['df'] = sqlCtx.createDataFrame(rdd)
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index bc945091f7..5e41e36897 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -96,7 +96,7 @@ class SQLTests(ReusedPySparkTestCase):
cls.sqlCtx = SQLContext(cls.sc)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
rdd = cls.sc.parallelize(cls.testData)
- cls.df = cls.sqlCtx.inferSchema(rdd)
+ cls.df = cls.sqlCtx.createDataFrame(rdd)
@classmethod
def tearDownClass(cls):
@@ -110,14 +110,14 @@ class SQLTests(ReusedPySparkTestCase):
def test_udf2(self):
self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
- self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+ self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
def test_udf_with_array_type(self):
d = [Row(l=range(3), d={"key": range(5)})]
rdd = self.sc.parallelize(d)
- self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+ self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
[(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
@@ -155,17 +155,17 @@ class SQLTests(ReusedPySparkTestCase):
def test_apply_schema_to_row(self):
df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
- df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+ df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema())
self.assertEqual(df.collect(), df2.collect())
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- df3 = self.sqlCtx.applySchema(rdd, df.schema())
+ df3 = self.sqlCtx.createDataFrame(rdd, df.schema())
self.assertEqual(10, df3.count())
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
row = df.head()
self.assertEqual(1, len(row.l))
self.assertEqual(1, row.l[0].a)
@@ -187,14 +187,14 @@ class SQLTests(ReusedPySparkTestCase):
d = [Row(l=[], d={}),
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
self.assertEqual([], df.map(lambda r: r.l).first())
self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
df.registerTempTable("test")
result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
- df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+ df2 = self.sqlCtx.createDataFrame(rdd, 1.0)
self.assertEqual(df.schema(), df2.schema())
self.assertEqual({}, df2.map(lambda r: r.d).first())
self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
@@ -205,7 +205,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
k, v = df.head().m.items()[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
@@ -214,7 +214,7 @@ class SQLTests(ReusedPySparkTestCase):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
rdd = self.sc.parallelize([row])
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
df.registerTempTable("test")
row = self.sqlCtx.sql("select l, d from test").head()
self.assertEqual(1, row.asDict()["l"][0].a)
@@ -224,7 +224,7 @@ class SQLTests(ReusedPySparkTestCase):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
schema = df.schema()
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
@@ -238,7 +238,7 @@ class SQLTests(ReusedPySparkTestCase):
rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- df = self.sqlCtx.applySchema(rdd, schema)
+ df = self.sqlCtx.createDataFrame(rdd, schema)
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
@@ -246,7 +246,7 @@ class SQLTests(ReusedPySparkTestCase):
from pyspark.sql.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
- df0 = self.sqlCtx.inferSchema(rdd)
+ df0 = self.sqlCtx.createDataFrame(rdd)
output_dir = os.path.join(self.tempdir.name, "labeled_point")
df0.saveAsParquetFile(output_dir)
df1 = self.sqlCtx.parquetFile(output_dir)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 801505bceb..523911d108 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -243,7 +243,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* val people =
* sc.textFile("examples/src/main/resources/people.txt").map(
* _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
- * val dataFrame = sqlContext. applySchema(people, schema)
+ * val dataFrame = sqlContext.createDataFrame(people, schema)
* dataFrame.printSchema
* // root
* // |-- name: string (nullable = false)
@@ -252,11 +252,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
* dataFrame.registerTempTable("people")
* sqlContext.sql("select name from people").collect.foreach(println)
* }}}
- *
- * @group userf
*/
@DeveloperApi
- def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
@@ -264,8 +262,21 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
@DeveloperApi
- def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
- applySchema(rowRDD.rdd, schema);
+ def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD.rdd, schema)
+ }
+
+ /**
+ * Creates a [[DataFrame]] from a [[JavaRDD]] containing [[Row]]s by applying
+ * a sequence of column names to this RDD; the data type of each column is
+ * inferred from the first row.
+ *
+ * @param rowRDD a JavaRDD of Rows
+ * @param columns names for each column
+ * @return DataFrame
+ */
+ def createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]): DataFrame = {
+ createDataFrame(rowRDD.rdd, columns.toSeq)
}
/**
@@ -274,7 +285,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
- def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+ def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
val attributeSeq = getSchema(beanClass)
val className = beanClass.getName
val rowRdd = rdd.mapPartitions { iter =>
@@ -301,8 +312,72 @@ class SQLContext(@transient val sparkContext: SparkContext)
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
+ def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
+ createDataFrame(rdd.rdd, beanClass)
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
+ * It is important to make sure that the structure of every [[Row]] of the provided RDD matches
+ * the provided schema. Otherwise, there will be a runtime exception.
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ * val sqlContext = new org.apache.spark.sql.SQLContext(sc)
+ *
+ * val schema =
+ * StructType(
+ * StructField("name", StringType, false) ::
+ * StructField("age", IntegerType, true) :: Nil)
+ *
+ * val people =
+ * sc.textFile("examples/src/main/resources/people.txt").map(
+ * _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
+ * val dataFrame = sqlContext.applySchema(people, schema)
+ * dataFrame.printSchema
+ * // root
+ * // |-- name: string (nullable = false)
+ * // |-- age: integer (nullable = true)
+ *
+ * dataFrame.registerTempTable("people")
+ * sqlContext.sql("select name from people").collect.foreach(println)
+ * }}}
+ *
+ * @group userf
+ */
+ @DeveloperApi
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema)
+ }
+
+ @DeveloperApi
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema)
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ */
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+ createDataFrame(rdd, beanClass)
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ */
+ @deprecated("use createDataFrame", "1.3.0")
def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
- applySchema(rdd.rdd, beanClass)
+ createDataFrame(rdd, beanClass)
}
/**
@@ -375,7 +450,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- applySchema(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema)
}
@Experimental
@@ -393,7 +468,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- applySchema(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema)
}
@Experimental
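Illustration only (not part of the patch): on the Scala side the change is a rename with deprecated forwarders, so both spellings compile against 1.3. A minimal sketch, assuming an existing SparkContext `sc`; the `applySchema` call merely emits a deprecation warning:

```scala
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val sqlContext = new SQLContext(sc)

val schema = StructType(Seq(
  StructField("name", StringType, nullable = false),
  StructField("age", IntegerType, nullable = true)))
val rows = sc.parallelize(Seq(Row("Alice", 1), Row("Bob", 2)))

val df = sqlContext.createDataFrame(rows, schema)   // new name in 1.3.0
val legacy = sqlContext.applySchema(rows, schema)   // still compiles; deprecation warning only

df.registerTempTable("people")
sqlContext.sql("SELECT name FROM people WHERE age > 1").collect().foreach(println)
```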
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index fa4cdecbcb..1d71039872 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -180,7 +180,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("!==") {
- val nullData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize(
+ val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
Row(1, 1) ::
Row(1, 2) ::
Row(1, null) ::
@@ -240,7 +240,7 @@ class ColumnExpressionSuite extends QueryTest {
testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1)))
}
- val booleanData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize(
+ val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 55fd0b0892..bba8899651 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -34,6 +34,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
TestData
import org.apache.spark.sql.test.TestSQLContext.implicits._
+ val sqlCtx = TestSQLContext
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
checkAnswer(
@@ -669,7 +670,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
}
- val df1 = applySchema(rowRDD1, schema1)
+ val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
@@ -699,7 +700,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df2 = applySchema(rowRDD2, schema2)
+ val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
df2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
@@ -724,7 +725,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
}
- val df3 = applySchema(rowRDD3, schema2)
+ val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
df3.registerTempTable("applySchema3")
checkAnswer(
@@ -769,7 +770,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = applySchema(person.rdd, schemaWithMeta)
+ val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index df108a9d26..c3210733f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -71,7 +71,7 @@ class PlannerSuite extends FunSuite {
val schema = StructType(fields)
val row = Row.fromSeq(Seq.fill(fields.size)(null))
val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
- applySchema(rowRDD, schema).registerTempTable("testLimit")
+ createDataFrame(rowRDD, schema).registerTempTable("testLimit")
val planned = sql(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e581ac9b50..21e7093610 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -54,7 +54,7 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
StructField("seq", IntegerType) :: Nil)
test("Basic CREATE") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
srdd.createJDBCTable(url, "TEST.BASICCREATETEST", false)
assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").count)
@@ -62,8 +62,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE with overwrite") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.DROPTEST", false)
assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count)
@@ -75,8 +75,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE then INSERT to append") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.APPENDTEST", false)
srdd2.insertIntoJDBC(url, "TEST.APPENDTEST", false)
@@ -85,8 +85,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("CREATE then INSERT to truncate") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
srdd.createJDBCTable(url, "TEST.TRUNCATETEST", false)
srdd2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true)
@@ -95,8 +95,8 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
}
test("Incompatible INSERT to append") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
+ val srdd = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
srdd.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false)
intercept[org.apache.spark.SparkException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 4fc92e3e3b..fde4b47438 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -820,7 +820,7 @@ class JsonSuite extends QueryTest {
Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
}
- val df1 = applySchema(rowRDD1, schema1)
+ val df1 = createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
val df2 = df1.toDataFrame
val result = df2.toJSON.collect()
@@ -841,7 +841,7 @@ class JsonSuite extends QueryTest {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df3 = applySchema(rowRDD2, schema2)
+ val df3 = createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
val df4 = df3.toDataFrame
val result2 = df4.toJSON.collect()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 43da7519ac..89b18c3439 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -97,7 +97,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil)
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithMapValue")
sql("CREATE TABLE hiveTableWithMapValue(m MAP <STRING, STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
@@ -142,7 +142,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
val schema = StructType(Seq(
StructField("a", ArrayType(StringType, containsNull = false))))
val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithArrayValue")
sql("CREATE TABLE hiveTableWithArrayValue(a Array <STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue")
@@ -159,7 +159,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
StructField("m", MapType(StringType, StringType, valueContainsNull = false))))
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(Map(s"key$i" -> s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithMapValue")
sql("CREATE TABLE hiveTableWithMapValue(m Map <STRING, STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
@@ -176,7 +176,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
StructField("s", StructType(Seq(StructField("f", StringType, nullable = false))))))
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(Row(s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithStructValue")
sql("CREATE TABLE hiveTableWithStructValue(s Struct <f: STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 49fe79d989..9a6e8650a0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.hive.HiveShim
+import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
@@ -34,6 +35,7 @@ case class Nested3(f3: Int)
class SQLQuerySuite extends QueryTest {
import org.apache.spark.sql.hive.test.TestHive.implicits._
+ val sqlCtx = TestHive
test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") {
checkAnswer(
@@ -277,7 +279,7 @@ class SQLQuerySuite extends QueryTest {
val rowRdd = sparkContext.parallelize(row :: Nil)
- applySchema(rowRdd, schema).registerTempTable("testTable")
+ sqlCtx.createDataFrame(rowRdd, schema).registerTempTable("testTable")
sql(
"""CREATE TABLE nullValuesInInnerComplexTypes