aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--python/pyspark/sql/__init__.py15
-rw-r--r--python/pyspark/sql/_types.py (renamed from python/pyspark/sql/types.py)49
-rw-r--r--python/pyspark/sql/context.py32
-rw-r--r--python/pyspark/sql/dataframe.py63
-rw-r--r--python/pyspark/sql/functions.py6
-rw-r--r--python/pyspark/sql/tests.py11
6 files changed, 120 insertions, 56 deletions
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 65abb24eed..6d54b9e49e 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -37,9 +37,22 @@ Important classes of Spark SQL and DataFrames:
- L{types}
List of data types available.
"""
+from __future__ import absolute_import
+
+# fix the module name conflict for Python 3+
+import sys
+from . import _types as types
+modname = __name__ + '.types'
+types.__name__ = modname
+# update the __module__ for all objects, make them picklable
+for v in types.__dict__.values():
+ if hasattr(v, "__module__") and v.__module__.endswith('._types'):
+ v.__module__ = modname
+sys.modules[modname] = types
+del modname, sys
-from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.types import Row
+from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions
__all__ = [
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/_types.py
index ef76d84c00..492c0cbdcf 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/_types.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import sys
import decimal
import datetime
import keyword
@@ -25,6 +26,9 @@ import weakref
from array import array
from operator import itemgetter
+if sys.version >= "3":
+ long = int
+ unicode = str
__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
@@ -410,7 +414,7 @@ class UserDefinedType(DataType):
split = pyUDT.rfind(".")
pyModule = pyUDT[:split]
pyClass = pyUDT[split+1:]
- m = __import__(pyModule, globals(), locals(), [pyClass], -1)
+ m = __import__(pyModule, globals(), locals(), [pyClass])
UDT = getattr(m, pyClass)
return UDT()
@@ -419,10 +423,9 @@ class UserDefinedType(DataType):
_all_primitive_types = dict((v.typeName(), v)
- for v in globals().itervalues()
- if type(v) is PrimitiveTypeSingleton and
- v.__base__ == PrimitiveType)
-
+ for v in list(globals().values())
+ if (type(v) is type or type(v) is PrimitiveTypeSingleton)
+ and v.__base__ == PrimitiveType)
_all_complex_types = dict((v.typeName(), v)
for v in [ArrayType, MapType, StructType])
@@ -486,10 +489,10 @@ _FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
def _parse_datatype_json_value(json_value):
- if type(json_value) is unicode:
+ if not isinstance(json_value, dict):
if json_value in _all_primitive_types.keys():
return _all_primitive_types[json_value]()
- elif json_value == u'decimal':
+ elif json_value == 'decimal':
return DecimalType()
elif _FIXED_DECIMAL.match(json_value):
m = _FIXED_DECIMAL.match(json_value)
@@ -511,10 +514,8 @@ _type_mappings = {
type(None): NullType,
bool: BooleanType,
int: LongType,
- long: LongType,
float: DoubleType,
str: StringType,
- unicode: StringType,
bytearray: BinaryType,
decimal.Decimal: DecimalType,
datetime.date: DateType,
@@ -522,6 +523,12 @@ _type_mappings = {
datetime.time: TimestampType,
}
+if sys.version < "3":
+ _type_mappings.update({
+ unicode: StringType,
+ long: LongType,
+ })
+
def _infer_type(obj):
"""Infer the DataType from obj
@@ -541,7 +548,7 @@ def _infer_type(obj):
return dataType()
if isinstance(obj, dict):
- for key, value in obj.iteritems():
+ for key, value in obj.items():
if key is not None and value is not None:
return MapType(_infer_type(key), _infer_type(value), True)
else:
@@ -565,10 +572,10 @@ def _infer_schema(row):
items = sorted(row.items())
elif isinstance(row, (tuple, list)):
- if hasattr(row, "_fields"): # namedtuple
- items = zip(row._fields, tuple(row))
- elif hasattr(row, "__fields__"): # Row
+ if hasattr(row, "__fields__"): # Row
items = zip(row.__fields__, tuple(row))
+ elif hasattr(row, "_fields"): # namedtuple
+ items = zip(row._fields, tuple(row))
else:
names = ['_%d' % i for i in range(1, len(row) + 1)]
items = zip(names, row)
@@ -647,7 +654,7 @@ def _python_to_sql_converter(dataType):
if isinstance(obj, dict):
return tuple(c(obj.get(n)) for n, c in zip(names, converters))
elif isinstance(obj, tuple):
- if hasattr(obj, "_fields") or hasattr(obj, "__fields__"):
+ if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
return tuple(c(v) for c, v in zip(converters, obj))
elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
d = dict(obj)
@@ -733,12 +740,12 @@ def _create_converter(dataType):
if isinstance(dataType, ArrayType):
conv = _create_converter(dataType.elementType)
- return lambda row: map(conv, row)
+ return lambda row: [conv(v) for v in row]
elif isinstance(dataType, MapType):
kconv = _create_converter(dataType.keyType)
vconv = _create_converter(dataType.valueType)
- return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())
+ return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items())
elif isinstance(dataType, NullType):
return lambda x: None
@@ -881,7 +888,7 @@ def _infer_schema_type(obj, dataType):
>>> _infer_schema_type(row, schema)
StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
"""
- if dataType is NullType():
+ if isinstance(dataType, NullType):
return _infer_type(obj)
if not obj:
@@ -892,7 +899,7 @@ def _infer_schema_type(obj, dataType):
return ArrayType(eType, True)
elif isinstance(dataType, MapType):
- k, v = obj.iteritems().next()
+ k, v = next(iter(obj.items()))
return MapType(_infer_schema_type(k, dataType.keyType),
_infer_schema_type(v, dataType.valueType))
@@ -935,7 +942,7 @@ def _verify_type(obj, dataType):
>>> _verify_type(None, StructType([]))
>>> _verify_type("", StringType())
>>> _verify_type(0, LongType())
- >>> _verify_type(range(3), ArrayType(ShortType()))
+ >>> _verify_type(list(range(3)), ArrayType(ShortType()))
>>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
@@ -976,7 +983,7 @@ def _verify_type(obj, dataType):
_verify_type(i, dataType.elementType)
elif isinstance(dataType, MapType):
- for k, v in obj.iteritems():
+ for k, v in obj.items():
_verify_type(k, dataType.keyType)
_verify_type(v, dataType.valueType)
@@ -1213,6 +1220,8 @@ class Row(tuple):
return self[idx]
except IndexError:
raise AttributeError(item)
+ except ValueError:
+ raise AttributeError(item)
def __reduce__(self):
if hasattr(self, "__fields__"):
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index e8529a8f8e..c90afc326c 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -15,14 +15,19 @@
# limitations under the License.
#
+import sys
import warnings
import json
-from itertools import imap
+
+if sys.version >= '3':
+ basestring = unicode = str
+else:
+ from itertools import imap as map
from py4j.protocol import Py4JError
from py4j.java_collections import MapConverter
-from pyspark.rdd import RDD, _prepare_for_python_RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
@@ -62,31 +67,27 @@ class SQLContext(object):
A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as
tables, execute SQL over tables, cache tables, and read parquet files.
- When created, :class:`SQLContext` adds a method called ``toDF`` to :class:`RDD`,
- which could be used to convert an RDD into a DataFrame, it's a shorthand for
- :func:`SQLContext.createDataFrame`.
-
:param sparkContext: The :class:`SparkContext` backing this SQLContext.
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
SQLContext in the JVM, instead we make all calls to this object.
"""
+ @ignore_unicode_prefix
def __init__(self, sparkContext, sqlContext=None):
"""Creates a new SQLContext.
>>> from datetime import datetime
>>> sqlContext = SQLContext(sc)
- >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
+ >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
>>> df = allTypes.toDF()
>>> df.registerTempTable("allTypes")
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
- [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
- >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
- ... x.row.a, x.list)).collect()
- [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
+ [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
+ >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
+ [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
self._sc = sparkContext
self._jsc = self._sc._jsc
@@ -122,6 +123,7 @@ class SQLContext(object):
"""Returns a :class:`UDFRegistration` for UDF registration."""
return UDFRegistration(self)
+ @ignore_unicode_prefix
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a lambda function as a UDF so it can be used in SQL statements.
@@ -147,7 +149,7 @@ class SQLContext(object):
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
"""
- func = lambda _, it: imap(lambda x: f(*x), it)
+ func = lambda _, it: map(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
@@ -185,6 +187,7 @@ class SQLContext(object):
schema = rdd.map(_infer_schema).reduce(_merge_type)
return schema
+ @ignore_unicode_prefix
def inferSchema(self, rdd, samplingRatio=None):
"""::note: Deprecated in 1.3, use :func:`createDataFrame` instead.
"""
@@ -195,6 +198,7 @@ class SQLContext(object):
return self.createDataFrame(rdd, None, samplingRatio)
+ @ignore_unicode_prefix
def applySchema(self, rdd, schema):
"""::note: Deprecated in 1.3, use :func:`createDataFrame` instead.
"""
@@ -208,6 +212,7 @@ class SQLContext(object):
return self.createDataFrame(rdd, schema)
+ @ignore_unicode_prefix
def createDataFrame(self, data, schema=None, samplingRatio=None):
"""
Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`,
@@ -380,6 +385,7 @@ class SQLContext(object):
df = self._ssql_ctx.jsonFile(path, scala_datatype)
return DataFrame(df, self)
+ @ignore_unicode_prefix
def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
"""Loads an RDD storing one JSON object per string as a :class:`DataFrame`.
@@ -477,6 +483,7 @@ class SQLContext(object):
joptions)
return DataFrame(df, self)
+ @ignore_unicode_prefix
def sql(self, sqlQuery):
"""Returns a :class:`DataFrame` representing the result of the given query.
@@ -497,6 +504,7 @@ class SQLContext(object):
"""
return DataFrame(self._ssql_ctx.table(tableName), self)
+ @ignore_unicode_prefix
def tables(self, dbName=None):
"""Returns a :class:`DataFrame` containing names of tables in the given database.
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index f2c3b74a18..d76504f986 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -16,14 +16,19 @@
#
import sys
-import itertools
import warnings
import random
+if sys.version >= '3':
+ basestring = unicode = str
+ long = int
+else:
+ from itertools import imap as map
+
from py4j.java_collections import ListConverter, MapConverter
from pyspark.context import SparkContext
-from pyspark.rdd import RDD, _load_from_socket
+from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@@ -65,19 +70,20 @@ class DataFrame(object):
self._sc = sql_ctx and sql_ctx._sc
self.is_cached = False
self._schema = None # initialized lazily
+ self._lazy_rdd = None
@property
def rdd(self):
"""Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
"""
- if not hasattr(self, '_lazy_rdd'):
+ if self._lazy_rdd is None:
jrdd = self._jdf.javaToPython()
rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
schema = self.schema
def applySchema(it):
cls = _create_cls(schema)
- return itertools.imap(cls, it)
+ return map(cls, it)
self._lazy_rdd = rdd.mapPartitions(applySchema)
@@ -89,13 +95,14 @@ class DataFrame(object):
"""
return DataFrameNaFunctions(self)
- def toJSON(self, use_unicode=False):
+ @ignore_unicode_prefix
+ def toJSON(self, use_unicode=True):
"""Converts a :class:`DataFrame` into a :class:`RDD` of string.
Each row is turned into a JSON document as one element in the returned RDD.
>>> df.toJSON().first()
- '{"age":2,"name":"Alice"}'
+ u'{"age":2,"name":"Alice"}'
"""
rdd = self._jdf.toJSON()
return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
@@ -228,7 +235,7 @@ class DataFrame(object):
|-- name: string (nullable = true)
<BLANKLINE>
"""
- print (self._jdf.schema().treeString())
+ print(self._jdf.schema().treeString())
def explain(self, extended=False):
"""Prints the (logical and physical) plans to the console for debugging purpose.
@@ -250,9 +257,9 @@ class DataFrame(object):
== RDD ==
"""
if extended:
- print self._jdf.queryExecution().toString()
+ print(self._jdf.queryExecution().toString())
else:
- print self._jdf.queryExecution().executedPlan().toString()
+ print(self._jdf.queryExecution().executedPlan().toString())
def isLocal(self):
"""Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
@@ -270,7 +277,7 @@ class DataFrame(object):
2 Alice
5 Bob
"""
- print self._jdf.showString(n).encode('utf8', 'ignore')
+ print(self._jdf.showString(n))
def __repr__(self):
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
@@ -279,10 +286,11 @@ class DataFrame(object):
"""Returns the number of rows in this :class:`DataFrame`.
>>> df.count()
- 2L
+ 2
"""
- return self._jdf.count()
+ return int(self._jdf.count())
+ @ignore_unicode_prefix
def collect(self):
"""Returns all the records as a list of :class:`Row`.
@@ -295,6 +303,7 @@ class DataFrame(object):
cls = _create_cls(self.schema)
return [cls(r) for r in rs]
+ @ignore_unicode_prefix
def limit(self, num):
"""Limits the result count to the number specified.
@@ -306,6 +315,7 @@ class DataFrame(object):
jdf = self._jdf.limit(num)
return DataFrame(jdf, self.sql_ctx)
+ @ignore_unicode_prefix
def take(self, num):
"""Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
@@ -314,6 +324,7 @@ class DataFrame(object):
"""
return self.limit(num).collect()
+ @ignore_unicode_prefix
def map(self, f):
""" Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`.
@@ -324,6 +335,7 @@ class DataFrame(object):
"""
return self.rdd.map(f)
+ @ignore_unicode_prefix
def flatMap(self, f):
""" Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`,
and then flattening the results.
@@ -353,7 +365,7 @@ class DataFrame(object):
This is a shorthand for ``df.rdd.foreach()``.
>>> def f(person):
- ... print person.name
+ ... print(person.name)
>>> df.foreach(f)
"""
return self.rdd.foreach(f)
@@ -365,7 +377,7 @@ class DataFrame(object):
>>> def f(people):
... for person in people:
- ... print person.name
+ ... print(person.name)
>>> df.foreachPartition(f)
"""
return self.rdd.foreachPartition(f)
@@ -412,7 +424,7 @@ class DataFrame(object):
"""Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
>>> df.distinct().count()
- 2L
+ 2
"""
return DataFrame(self._jdf.distinct(), self.sql_ctx)
@@ -420,10 +432,10 @@ class DataFrame(object):
"""Returns a sampled subset of this :class:`DataFrame`.
>>> df.sample(False, 0.5, 97).count()
- 1L
+ 1
"""
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
- seed = seed if seed is not None else random.randint(0, sys.maxint)
+ seed = seed if seed is not None else random.randint(0, sys.maxsize)
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)
@@ -437,6 +449,7 @@ class DataFrame(object):
return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
@property
+ @ignore_unicode_prefix
def columns(self):
"""Returns all column names as a list.
@@ -445,6 +458,7 @@ class DataFrame(object):
"""
return [f.name for f in self.schema.fields]
+ @ignore_unicode_prefix
def join(self, other, joinExprs=None, joinType=None):
"""Joins with another :class:`DataFrame`, using the given join expression.
@@ -470,6 +484,7 @@ class DataFrame(object):
jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
return DataFrame(jdf, self.sql_ctx)
+ @ignore_unicode_prefix
def sort(self, *cols):
"""Returns a new :class:`DataFrame` sorted by the specified column(s).
@@ -513,6 +528,7 @@ class DataFrame(object):
jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
return DataFrame(jdf, self.sql_ctx)
+ @ignore_unicode_prefix
def head(self, n=None):
"""
Returns the first ``n`` rows as a list of :class:`Row`,
@@ -528,6 +544,7 @@ class DataFrame(object):
return rs[0] if rs else None
return self.take(n)
+ @ignore_unicode_prefix
def first(self):
"""Returns the first row as a :class:`Row`.
@@ -536,6 +553,7 @@ class DataFrame(object):
"""
return self.head()
+ @ignore_unicode_prefix
def __getitem__(self, item):
"""Returns the column as a :class:`Column`.
@@ -567,6 +585,7 @@ class DataFrame(object):
jc = self._jdf.apply(name)
return Column(jc)
+ @ignore_unicode_prefix
def select(self, *cols):
"""Projects a set of expressions and returns a new :class:`DataFrame`.
@@ -598,6 +617,7 @@ class DataFrame(object):
jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
return DataFrame(jdf, self.sql_ctx)
+ @ignore_unicode_prefix
def filter(self, condition):
"""Filters rows using the given condition.
@@ -626,6 +646,7 @@ class DataFrame(object):
where = filter
+ @ignore_unicode_prefix
def groupBy(self, *cols):
"""Groups the :class:`DataFrame` using the specified columns,
so we can run aggregation on them. See :class:`GroupedData`
@@ -775,6 +796,7 @@ class DataFrame(object):
cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
+ @ignore_unicode_prefix
def withColumn(self, colName, col):
"""Returns a new :class:`DataFrame` by adding a column.
@@ -786,6 +808,7 @@ class DataFrame(object):
"""
return self.select('*', col.alias(colName))
+ @ignore_unicode_prefix
def withColumnRenamed(self, existing, new):
"""REturns a new :class:`DataFrame` by renaming an existing column.
@@ -852,6 +875,7 @@ class GroupedData(object):
self._jdf = jdf
self.sql_ctx = sql_ctx
+ @ignore_unicode_prefix
def agg(self, *exprs):
"""Compute aggregates and returns the result as a :class:`DataFrame`.
@@ -1041,11 +1065,13 @@ class Column(object):
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
__div__ = _bin_op("divide")
+ __truediv__ = _bin_op("divide")
__mod__ = _bin_op("mod")
__radd__ = _bin_op("plus")
__rsub__ = _reverse_op("minus")
__rmul__ = _bin_op("multiply")
__rdiv__ = _reverse_op("divide")
+ __rtruediv__ = _reverse_op("divide")
__rmod__ = _reverse_op("mod")
# logistic operators
@@ -1075,6 +1101,7 @@ class Column(object):
startswith = _bin_op("startsWith")
endswith = _bin_op("endsWith")
+ @ignore_unicode_prefix
def substr(self, startPos, length):
"""
Return a :class:`Column` which is a substring of the column
@@ -1097,6 +1124,7 @@ class Column(object):
__getslice__ = substr
+ @ignore_unicode_prefix
def inSet(self, *cols):
""" A boolean expression that is evaluated to true if the value of this
expression is contained by the evaluated values of the arguments.
@@ -1131,6 +1159,7 @@ class Column(object):
"""
return Column(getattr(self._jc, "as")(alias))
+ @ignore_unicode_prefix
def cast(self, dataType):
""" Convert the column into type `dataType`
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index daeb6916b5..1d65369528 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -18,8 +18,10 @@
"""
A collections of builtin functions
"""
+import sys
-from itertools import imap
+if sys.version < "3":
+ from itertools import imap as map
from py4j.java_collections import ListConverter
@@ -116,7 +118,7 @@ class UserDefinedFunction(object):
def _create_judf(self):
f = self.func # put it in closure `func`
- func = lambda _, it: imap(lambda x: f(*x), it)
+ func = lambda _, it: map(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
sc = SparkContext._active_spark_context
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b3a6a2c6a9..7c09a0cfe3 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -157,13 +157,13 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(4, res[0])
def test_udf_with_array_type(self):
- d = [Row(l=range(3), d={"key": range(5)})]
+ d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
[(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
- self.assertEqual(range(3), l1)
+ self.assertEqual(list(range(3)), l1)
self.assertEqual(1, l2)
def test_broadcast_in_udf(self):
@@ -266,7 +266,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_apply_schema(self):
from datetime import date, datetime
- rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
+ rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
{"a": 1}, (2,), [1, 2, 3], None)])
schema = StructType([
@@ -309,7 +309,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
df = self.sc.parallelize(d).toDF()
- k, v = df.head().m.items()[0]
+ k, v = list(df.head().m.items())[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
@@ -554,6 +554,9 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
except py4j.protocol.Py4JError:
cls.sqlCtx = None
return
+ except TypeError:
+ cls.sqlCtx = None
+ return
os.unlink(cls.tempdir.name)
_scala_HiveContext =\
cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())