about · summary · refs · log · tree · commit · diff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-11-23 13:44:30 -0800
committerReynold Xin <rxin@databricks.com>2015-11-23 13:44:30 -0800
commit1d9120201012213edb1971a09e0849336dbb9415 (patch)
tree2bdd3094a67aa195a482c6be0b551431e4f5320c
parent1b6e938be836786bac542fa430580248161e5403 (diff)
downloadspark-1d9120201012213edb1971a09e0849336dbb9415.tar.gz
spark-1d9120201012213edb1971a09e0849336dbb9415.tar.bz2
spark-1d9120201012213edb1971a09e0849336dbb9415.zip
[SPARK-11836][SQL] udf/cast should not create new SQLContext
They should use the existing SQLContext.

Author: Davies Liu <davies@databricks.com>

Closes #9914 from davies/create_udf.
-rw-r--r--python/pyspark/sql/column.py7
-rw-r--r--python/pyspark/sql/functions.py7
2 files changed, 8 insertions, 6 deletions
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 9ca8e1f264..81fd4e7826 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -346,9 +346,10 @@ class Column(object):
if isinstance(dataType, basestring):
jc = self._jc.cast(dataType)
elif isinstance(dataType, DataType):
- sc = SparkContext._active_spark_context
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- jdt = ssql_ctx.parseDataType(dataType.json())
+ from pyspark.sql import SQLContext
+ sc = SparkContext.getOrCreate()
+ ctx = SQLContext.getOrCreate(sc)
+ jdt = ctx._ssql_ctx.parseDataType(dataType.json())
jc = self._jc.cast(jdt)
else:
raise TypeError("unexpected type: %s" % type(dataType))
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index c3da513c13..a1ca723bbd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1457,14 +1457,15 @@ class UserDefinedFunction(object):
self._judf = self._create_judf(name)
def _create_judf(self, name):
+ from pyspark.sql import SQLContext
f, returnType = self.func, self.returnType # put them in closure `func`
func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
- sc = SparkContext._active_spark_context
+ sc = SparkContext.getOrCreate()
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- jdt = ssql_ctx.parseDataType(self.returnType.json())
+ ctx = SQLContext.getOrCreate(sc)
+ jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
if name is None:
name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,