author     Davies Liu <davies@databricks.com>    2015-04-08 13:31:45 -0700
committer  Reynold Xin <rxin@databricks.com>     2015-04-08 13:31:45 -0700
commit     6ada4f6f52cf1d992c7ab0c32318790cf08b0a0d (patch)
tree       495c9bb86bb98de40365538bebcf9144547d8cce /python
parent     66159c35010af35098dd1ec75475bb5d4d0fd6ca (diff)
[SPARK-6781] [SQL] use sqlContext in python shell
Use `sqlContext` in PySpark shell, make it consistent with SQL programming guide. `sqlCtx` is also kept for compatibility.

Author: Davies Liu <davies@databricks.com>

Closes #5425 from davies/sqlCtx and squashes the following commits:

af67340 [Davies Liu] sqlCtx -> sqlContext
15a278f [Davies Liu] use sqlContext in python shell
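For readers of the patch, a minimal sketch of what the change means in an interactive PySpark shell started after this commit (doctest-style illustration; the sample data is made up for this example and is not taken from the patch):

>>> sqlContext is sqlCtx    # both names are bound to the same context
True
>>> df = sqlContext.createDataFrame([('Alice', 1)], ['name', 'age'])
>>> df.registerTempTable("people")
>>> sqlContext.sql("SELECT name, age FROM people").collect()
[Row(name=u'Alice', age=1)]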
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/classification.py    4
-rw-r--r--  python/pyspark/ml/feature.py           4
-rw-r--r--  python/pyspark/shell.py                6
-rw-r--r--  python/pyspark/sql/context.py         79
-rw-r--r--  python/pyspark/sql/dataframe.py        6
-rw-r--r--  python/pyspark/sql/functions.py        2
-rw-r--r--  python/pyspark/sql/types.py            4
7 files changed, 52 insertions, 53 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 4ff7463498..7f42de531f 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -91,9 +91,9 @@ if __name__ == "__main__":
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.feature tests")
- sqlCtx = SQLContext(sc)
+ sqlContext = SQLContext(sc)
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx
+ globs['sqlContext'] = sqlContext
(failure_count, test_count) = doctest.testmod(
globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 433b4fb5d2..1cfcd019df 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -117,9 +117,9 @@ if __name__ == "__main__":
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.feature tests")
- sqlCtx = SQLContext(sc)
+ sqlContext = SQLContext(sc)
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx
+ globs['sqlContext'] = sqlContext
(failure_count, test_count) = doctest.testmod(
globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 1a02fece9c..81aa970a32 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -53,9 +53,9 @@ atexit.register(lambda: sc.stop())
try:
# Try to access HiveConf, it will raise exception if Hive is not added
sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
- sqlCtx = HiveContext(sc)
+ sqlCtx = sqlContext = HiveContext(sc)
except py4j.protocol.Py4JError:
- sqlCtx = SQLContext(sc)
+ sqlCtx = sqlContext = SQLContext(sc)
print("""Welcome to
____ __
@@ -68,7 +68,7 @@ print("Using Python version %s (%s, %s)" % (
platform.python_version(),
platform.python_build()[0],
platform.python_build()[1]))
-print("SparkContext available as sc, %s available as sqlCtx." % sqlCtx.__class__.__name__)
+print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__)
if add_files is not None:
print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead")
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index c2d81ba804..93e2d176a5 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -37,12 +37,12 @@ except ImportError:
__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
-def _monkey_patch_RDD(sqlCtx):
+def _monkey_patch_RDD(sqlContext):
def toDF(self, schema=None, sampleRatio=None):
"""
Converts current :class:`RDD` into a :class:`DataFrame`
- This is a shorthand for ``sqlCtx.createDataFrame(rdd, schema, sampleRatio)``
+ This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)``
:param schema: a StructType or list of names of columns
:param samplingRatio: the sample ratio of rows used for inferring
@@ -51,7 +51,7 @@ def _monkey_patch_RDD(sqlCtx):
>>> rdd.toDF().collect()
[Row(name=u'Alice', age=1)]
"""
- return sqlCtx.createDataFrame(self, schema, sampleRatio)
+ return sqlContext.createDataFrame(self, schema, sampleRatio)
RDD.toDF = toDF
@@ -75,13 +75,13 @@ class SQLContext(object):
"""Creates a new SQLContext.
>>> from datetime import datetime
- >>> sqlCtx = SQLContext(sc)
+ >>> sqlContext = SQLContext(sc)
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
>>> df = allTypes.toDF()
>>> df.registerTempTable("allTypes")
- >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
+ >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
@@ -133,18 +133,18 @@ class SQLContext(object):
:param samplingRatio: lambda function
:param returnType: a :class:`DataType` object
- >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
- >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
+ >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
+ >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
[Row(c0=u'4')]
>>> from pyspark.sql.types import IntegerType
- >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
- >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
+ >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
>>> from pyspark.sql.types import IntegerType
- >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
- >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
+ >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
"""
func = lambda _, it: imap(lambda x: f(*x), it)
@@ -229,26 +229,26 @@ class SQLContext(object):
:param samplingRatio: the sample ratio of rows used for inferring
>>> l = [('Alice', 1)]
- >>> sqlCtx.createDataFrame(l).collect()
+ >>> sqlContext.createDataFrame(l).collect()
[Row(_1=u'Alice', _2=1)]
- >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect()
+ >>> sqlContext.createDataFrame(l, ['name', 'age']).collect()
[Row(name=u'Alice', age=1)]
>>> d = [{'name': 'Alice', 'age': 1}]
- >>> sqlCtx.createDataFrame(d).collect()
+ >>> sqlContext.createDataFrame(d).collect()
[Row(age=1, name=u'Alice')]
>>> rdd = sc.parallelize(l)
- >>> sqlCtx.createDataFrame(rdd).collect()
+ >>> sqlContext.createDataFrame(rdd).collect()
[Row(_1=u'Alice', _2=1)]
- >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
+ >>> df = sqlContext.createDataFrame(rdd, ['name', 'age'])
>>> df.collect()
[Row(name=u'Alice', age=1)]
>>> from pyspark.sql import Row
>>> Person = Row('name', 'age')
>>> person = rdd.map(lambda r: Person(*r))
- >>> df2 = sqlCtx.createDataFrame(person)
+ >>> df2 = sqlContext.createDataFrame(person)
>>> df2.collect()
[Row(name=u'Alice', age=1)]
@@ -256,11 +256,11 @@ class SQLContext(object):
>>> schema = StructType([
... StructField("name", StringType(), True),
... StructField("age", IntegerType(), True)])
- >>> df3 = sqlCtx.createDataFrame(rdd, schema)
+ >>> df3 = sqlContext.createDataFrame(rdd, schema)
>>> df3.collect()
[Row(name=u'Alice', age=1)]
- >>> sqlCtx.createDataFrame(df.toPandas()).collect() # doctest: +SKIP
+ >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP
[Row(name=u'Alice', age=1)]
"""
if isinstance(data, DataFrame):
@@ -316,7 +316,7 @@ class SQLContext(object):
Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`.
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
+ >>> sqlContext.registerDataFrameAsTable(df, "table1")
"""
if (df.__class__ is DataFrame):
self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName)
@@ -330,7 +330,7 @@ class SQLContext(object):
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
>>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> df2 = sqlContext.parquetFile(parquetFile)
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
@@ -352,7 +352,7 @@ class SQLContext(object):
>>> shutil.rmtree(jsonFile)
>>> with open(jsonFile, 'w') as f:
... f.writelines(jsonStrings)
- >>> df1 = sqlCtx.jsonFile(jsonFile)
+ >>> df1 = sqlContext.jsonFile(jsonFile)
>>> df1.printSchema()
root
|-- field1: long (nullable = true)
@@ -365,7 +365,7 @@ class SQLContext(object):
... StructField("field2", StringType()),
... StructField("field3",
... StructType([StructField("field5", ArrayType(IntegerType()))]))])
- >>> df2 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> df2 = sqlContext.jsonFile(jsonFile, schema)
>>> df2.printSchema()
root
|-- field2: string (nullable = true)
@@ -386,11 +386,11 @@ class SQLContext(object):
If the schema is provided, applies the given schema to this JSON dataset.
Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema.
- >>> df1 = sqlCtx.jsonRDD(json)
+ >>> df1 = sqlContext.jsonRDD(json)
>>> df1.first()
Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
- >>> df2 = sqlCtx.jsonRDD(json, df1.schema)
+ >>> df2 = sqlContext.jsonRDD(json, df1.schema)
>>> df2.first()
Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
@@ -400,7 +400,7 @@ class SQLContext(object):
... StructField("field3",
... StructType([StructField("field5", ArrayType(IntegerType()))]))
... ])
- >>> df3 = sqlCtx.jsonRDD(json, schema)
+ >>> df3 = sqlContext.jsonRDD(json, schema)
>>> df3.first()
Row(field2=u'row1', field3=Row(field5=None))
"""
@@ -480,8 +480,8 @@ class SQLContext(object):
def sql(self, sqlQuery):
"""Returns a :class:`DataFrame` representing the result of the given query.
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
- >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+ >>> sqlContext.registerDataFrameAsTable(df, "table1")
+ >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
"""
@@ -490,8 +490,8 @@ class SQLContext(object):
def table(self, tableName):
"""Returns the specified table as a :class:`DataFrame`.
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
- >>> df2 = sqlCtx.table("table1")
+ >>> sqlContext.registerDataFrameAsTable(df, "table1")
+ >>> df2 = sqlContext.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
@@ -505,8 +505,8 @@ class SQLContext(object):
The returned DataFrame has two columns: ``tableName`` and ``isTemporary``
(a column with :class:`BooleanType` indicating if a table is a temporary one or not).
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
- >>> df2 = sqlCtx.tables()
+ >>> sqlContext.registerDataFrameAsTable(df, "table1")
+ >>> df2 = sqlContext.tables()
>>> df2.filter("tableName = 'table1'").first()
Row(tableName=u'table1', isTemporary=True)
"""
@@ -520,10 +520,10 @@ class SQLContext(object):
If ``dbName`` is not specified, the current database will be used.
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
- >>> "table1" in sqlCtx.tableNames()
+ >>> sqlContext.registerDataFrameAsTable(df, "table1")
+ >>> "table1" in sqlContext.tableNames()
True
- >>> "table1" in sqlCtx.tableNames("db")
+ >>> "table1" in sqlContext.tableNames("db")
True
"""
if dbName is None:
@@ -578,11 +578,11 @@ class HiveContext(SQLContext):
class UDFRegistration(object):
"""Wrapper for user-defined function registration."""
- def __init__(self, sqlCtx):
- self.sqlCtx = sqlCtx
+ def __init__(self, sqlContext):
+ self.sqlContext = sqlContext
def register(self, name, f, returnType=StringType()):
- return self.sqlCtx.registerFunction(name, f, returnType)
+ return self.sqlContext.registerFunction(name, f, returnType)
register.__doc__ = SQLContext.registerFunction.__doc__
@@ -595,13 +595,12 @@ def _test():
globs = pyspark.sql.context.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx = SQLContext(sc)
+ globs['sqlContext'] = SQLContext(sc)
globs['rdd'] = rdd = sc.parallelize(
[Row(field1=1, field2="row1"),
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
- _monkey_patch_RDD(sqlCtx)
globs['df'] = rdd.toDF()
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index c30326ebd1..ef91a9c4f5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -110,7 +110,7 @@ class DataFrame(object):
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
>>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> df2 = sqlContext.parquetFile(parquetFile)
>>> sorted(df2.collect()) == sorted(df.collect())
True
"""
@@ -123,7 +123,7 @@ class DataFrame(object):
that was used to create this :class:`DataFrame`.
>>> df.registerTempTable("people")
- >>> df2 = sqlCtx.sql("select * from people")
+ >>> df2 = sqlContext.sql("select * from people")
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
@@ -1180,7 +1180,7 @@ def _test():
globs = pyspark.sql.dataframe.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = SQLContext(sc)
+ globs['sqlContext'] = SQLContext(sc)
globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 146ba6f3e0..daeb6916b5 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -161,7 +161,7 @@ def _test():
globs = pyspark.sql.functions.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = SQLContext(sc)
+ globs['sqlContext'] = SQLContext(sc)
globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.functions, globs=globs,
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 45eb8b945d..7e0124b136 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -434,7 +434,7 @@ def _parse_datatype_json_string(json_string):
>>> def check_datatype(datatype):
... pickled = pickle.loads(pickle.dumps(datatype))
... assert datatype == pickled
- ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
+ ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
... assert datatype == python_datatype
>>> for cls in _all_primitive_types.values():
@@ -1237,7 +1237,7 @@ def _test():
globs = pyspark.sql.types.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx = SQLContext(sc)
+ globs['sqlContext'] = SQLContext(sc)
globs['ExamplePoint'] = ExamplePoint
globs['ExamplePointUDT'] = ExamplePointUDT
(failure_count, test_count) = doctest.testmod(