author     Davies Liu <davies@databricks.com>    2015-02-03 16:01:56 -0800
committer  Reynold Xin <rxin@databricks.com>     2015-02-03 16:01:56 -0800
commit     068c0e2ee05ee8b133c2dc26b8fa094ab2712d45 (patch)
tree       ab40097fe86c2aae58b11d0f0160ae5d07ecfd94 /python
parent     1e8b5394b44a0d3b36f64f10576c3ae3b977810c (diff)
[SPARK-5554] [SQL] [PySpark] add more tests for DataFrame Python API
Add more tests and docs for DataFrame Python API, improve test coverage, fix bugs.

Author: Davies Liu <davies@databricks.com>

Closes #4331 from davies/fix_df and squashes the following commits:

dd9919f [Davies Liu] fix tests
467332c [Davies Liu] support string in cast()
83c92fe [Davies Liu] address comments
c052f6f [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_df
8dd19a9 [Davies Liu] fix tests in python 2.6
35ccb9f [Davies Liu] fix build
78ebcfa [Davies Liu] add sql_test.py in run_tests
9ab78b4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_df
6040ba7 [Davies Liu] fix docs
3ab2661 [Davies Liu] add more tests for DataFrame
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql.py        | 467
-rw-r--r--  python/pyspark/sql_tests.py  | 299
-rw-r--r--  python/pyspark/tests.py      | 261
-rwxr-xr-x  python/run-tests             |   1
4 files changed, 581 insertions(+), 447 deletions(-)
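
Most of the patch adds doctests to the DataFrame DSL. A minimal sketch of the kind of usage those doctests exercise (the data and calls mirror the doctest globals set up in `_test()` in the sql.py diff below; the app name is arbitrary):

    # Sketch of the DataFrame DSL usage covered by the new doctests.
    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row, Dsl

    sc = SparkContext('local[4]', 'dataframe-example')
    sqlCtx = SQLContext(sc)
    df = sqlCtx.inferSchema(sc.parallelize([Row(name='Alice', age=2),
                                            Row(name='Bob', age=5)]))

    df.select(df.name, (df.age + 10).alias('age')).collect()  # projection with expressions
    df.filter(df.age > 3).collect()                            # [Row(age=5, name=u'Bob')]
    df.groupBy('name').agg({'age': 'mean'}).collect()          # grouped aggregation
    df.agg(Dsl.min(df.age)).collect()                          # Dsl aggregate functions

    sc.stop()
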
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 32bff0c7e8..268c7ef97c 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -62,7 +62,7 @@ __all__ = [
"StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
"DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
"ShortType", "ArrayType", "MapType", "StructField", "StructType",
- "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row",
+ "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", "Dsl",
"SchemaRDD"]
@@ -1804,7 +1804,7 @@ class DataFrame(object):
people = sqlContext.parquetFile("...")
Once created, it can be manipulated using the various domain-specific-language
- (DSL) functions defined in: [[DataFrame]], [[Column]].
+ (DSL) functions defined in: :class:`DataFrame`, :class:`Column`.
To select a column from the data frame, use the apply method::
@@ -1835,8 +1835,10 @@ class DataFrame(object):
@property
def rdd(self):
- """Return the content of the :class:`DataFrame` as an :class:`RDD`
- of :class:`Row`s. """
+ """
+ Return the content of the :class:`DataFrame` as an :class:`RDD`
+ of :class:`Row` s.
+ """
if not hasattr(self, '_lazy_rdd'):
jrdd = self._jdf.javaToPython()
rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
@@ -1850,18 +1852,6 @@ class DataFrame(object):
return self._lazy_rdd
- def limit(self, num):
- """Limit the result count to the number specified.
-
- >>> df = sqlCtx.inferSchema(rdd)
- >>> df.limit(2).collect()
- [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
- >>> df.limit(0).collect()
- []
- """
- jdf = self._jdf.limit(num)
- return DataFrame(jdf, self.sql_ctx)
-
def toJSON(self, use_unicode=False):
"""Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
@@ -1886,7 +1876,6 @@ class DataFrame(object):
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> df = sqlCtx.inferSchema(rdd)
>>> df.saveAsParquetFile(parquetFile)
>>> df2 = sqlCtx.parquetFile(parquetFile)
>>> sorted(df2.collect()) == sorted(df.collect())
@@ -1900,9 +1889,8 @@ class DataFrame(object):
The lifetime of this temporary table is tied to the L{SQLContext}
that was used to create this DataFrame.
- >>> df = sqlCtx.inferSchema(rdd)
- >>> df.registerTempTable("test")
- >>> df2 = sqlCtx.sql("select * from test")
+ >>> df.registerTempTable("people")
+ >>> df2 = sqlCtx.sql("select * from people")
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
@@ -1926,11 +1914,22 @@ class DataFrame(object):
def schema(self):
"""Returns the schema of this DataFrame (represented by
- a L{StructType})."""
+ a L{StructType}).
+
+ >>> df.schema()
+ StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
+ """
return _parse_datatype_json_string(self._jdf.schema().json())
def printSchema(self):
- """Prints out the schema in the tree format."""
+ """Prints out the schema in the tree format.
+
+ >>> df.printSchema()
+ root
+ |-- age: integer (nullable = true)
+ |-- name: string (nullable = true)
+ <BLANKLINE>
+ """
print (self._jdf.schema().treeString())
def count(self):
@@ -1940,11 +1939,8 @@ class DataFrame(object):
leverages the query optimizer to compute the count on the DataFrame,
which supports features such as filter pushdown.
- >>> df = sqlCtx.inferSchema(rdd)
>>> df.count()
- 3L
- >>> df.count() == df.map(lambda x: x).count()
- True
+ 2L
"""
return self._jdf.count()
@@ -1954,13 +1950,11 @@ class DataFrame(object):
Each object in the list is a Row, the fields can be accessed as
attributes.
- >>> df = sqlCtx.inferSchema(rdd)
>>> df.collect()
- [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
bytesInJava = self._jdf.javaToPython().collect().iterator()
- cls = _create_cls(self.schema())
tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
tempFile.close()
self._sc._writeToFile(bytesInJava, tempFile.name)
@@ -1968,23 +1962,37 @@ class DataFrame(object):
with open(tempFile.name, 'rb') as tempFile:
rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
os.unlink(tempFile.name)
+ cls = _create_cls(self.schema())
return [cls(r) for r in rs]
+ def limit(self, num):
+ """Limit the result count to the number specified.
+
+ >>> df.limit(1).collect()
+ [Row(age=2, name=u'Alice')]
+ >>> df.limit(0).collect()
+ []
+ """
+ jdf = self._jdf.limit(num)
+ return DataFrame(jdf, self.sql_ctx)
+
def take(self, num):
"""Take the first num rows of the RDD.
Each object in the list is a Row, the fields can be accessed as
attributes.
- >>> df = sqlCtx.inferSchema(rdd)
>>> df.take(2)
- [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
return self.limit(num).collect()
def map(self, f):
""" Return a new RDD by applying a function to each Row, it's a
shorthand for df.rdd.map()
+
+ >>> df.map(lambda p: p.name).collect()
+ [u'Alice', u'Bob']
"""
return self.rdd.map(f)
@@ -2067,140 +2075,167 @@ class DataFrame(object):
@property
def dtypes(self):
"""Return all column names and their data types as a list.
+
+ >>> df.dtypes
+ [(u'age', 'IntegerType'), (u'name', 'StringType')]
"""
return [(f.name, str(f.dataType)) for f in self.schema().fields]
@property
def columns(self):
""" Return all column names as a list.
+
+ >>> df.columns
+ [u'age', u'name']
"""
return [f.name for f in self.schema().fields]
- def show(self):
- raise NotImplemented
-
def join(self, other, joinExprs=None, joinType=None):
"""
Join with another DataFrame, using the given join expression.
The following performs a full outer join between `df1` and `df2`::
- df1.join(df2, df1.key == df2.key, "outer")
-
:param other: Right side of the join
:param joinExprs: Join expression
- :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`,
- `semijoin`.
+ :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
+
+ >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
+ [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
"""
- if joinType is None:
- if joinExprs is None:
- jdf = self._jdf.join(other._jdf)
- else:
- jdf = self._jdf.join(other._jdf, joinExprs)
+
+ if joinExprs is None:
+ jdf = self._jdf.join(other._jdf)
else:
- jdf = self._jdf.join(other._jdf, joinExprs, joinType)
+ assert isinstance(joinExprs, Column), "joinExprs should be Column"
+ if joinType is None:
+ jdf = self._jdf.join(other._jdf, joinExprs._jc)
+ else:
+ assert isinstance(joinType, basestring), "joinType should be basestring"
+ jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
return DataFrame(jdf, self.sql_ctx)
def sort(self, *cols):
- """ Return a new [[DataFrame]] sorted by the specified column,
- in ascending column.
+ """ Return a new :class:`DataFrame` sorted by the specified column.
:param cols: The columns or expressions used for sorting
+
+ >>> df.sort(df.age.desc()).collect()
+ [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+ >>> df.sortBy(df.age.desc()).collect()
+ [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
"""
if not cols:
raise ValueError("should sort by at least one column")
- for i, c in enumerate(cols):
- if isinstance(c, basestring):
- cols[i] = Column(c)
- jcols = [c._jc for c in cols]
- jdf = self._jdf.join(*jcols)
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols[1:]],
+ self._sc._gateway._gateway_client)
+ jdf = self._jdf.sort(_to_java_column(cols[0]),
+ self._sc._jvm.Dsl.toColumns(jcols))
return DataFrame(jdf, self.sql_ctx)
sortBy = sort
def head(self, n=None):
- """ Return the first `n` rows or the first row if n is None. """
+ """ Return the first `n` rows or the first row if n is None.
+
+ >>> df.head()
+ Row(age=2, name=u'Alice')
+ >>> df.head(1)
+ [Row(age=2, name=u'Alice')]
+ """
if n is None:
rs = self.head(1)
return rs[0] if rs else None
return self.take(n)
def first(self):
- """ Return the first row. """
- return self.head()
+ """ Return the first row.
- def tail(self):
- raise NotImplemented
+ >>> df.first()
+ Row(age=2, name=u'Alice')
+ """
+ return self.head()
def __getitem__(self, item):
+ """ Return the column by given name
+
+ >>> df['age'].collect()
+ [Row(age=2), Row(age=5)]
+ """
if isinstance(item, basestring):
- return Column(self._jdf.apply(item))
+ jc = self._jdf.apply(item)
+ return Column(jc, self.sql_ctx)
# TODO projection
raise IndexError
def __getattr__(self, name):
- """ Return the column by given name """
+ """ Return the column by given name
+
+ >>> df.age.collect()
+ [Row(age=2), Row(age=5)]
+ """
if name.startswith("__"):
raise AttributeError(name)
- return Column(self._jdf.apply(name))
-
- def alias(self, name):
- """ Alias the current DataFrame """
- return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx)
+ jc = self._jdf.apply(name)
+ return Column(jc, self.sql_ctx)
def select(self, *cols):
- """ Selecting a set of expressions.::
-
- df.select()
- df.select('colA', 'colB')
- df.select(df.colA, df.colB + 1)
-
+ """ Selecting a set of expressions.
+
+ >>> df.select().collect()
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ >>> df.select('*').collect()
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ >>> df.select('name', 'age').collect()
+ [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+ >>> df.select(df.name, (df.age + 10).As('age')).collect()
+ [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
"""
if not cols:
cols = ["*"]
- if isinstance(cols[0], basestring):
- cols = [_create_column_from_name(n) for n in cols]
- else:
- cols = [c._jc for c in cols]
- jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ self._sc._gateway._gateway_client)
jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
return DataFrame(jdf, self.sql_ctx)
def filter(self, condition):
- """ Filtering rows using the given condition::
-
- df.filter(df.age > 15)
- df.where(df.age > 15)
+ """ Filtering rows using the given condition.
+ >>> df.filter(df.age > 3).collect()
+ [Row(age=5, name=u'Bob')]
+ >>> df.where(df.age == 2).collect()
+ [Row(age=2, name=u'Alice')]
"""
return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
where = filter
def groupBy(self, *cols):
- """ Group the [[DataFrame]] using the specified columns,
+ """ Group the :class:`DataFrame` using the specified columns,
so we can run aggregation on them. See :class:`GroupedDataFrame`
- for all the available aggregate functions::
-
- df.groupBy(df.department).avg()
- df.groupBy("department", "gender").agg({
- "salary": "avg",
- "age": "max",
- })
+ for all the available aggregate functions.
+
+ >>> df.groupBy().avg().collect()
+ [Row(AVG(age#0)=3.5)]
+ >>> df.groupBy('name').agg({'age': 'mean'}).collect()
+ [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+ >>> df.groupBy(df.name).avg().collect()
+ [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
"""
- if cols and isinstance(cols[0], basestring):
- cols = [_create_column_from_name(n) for n in cols]
- else:
- cols = [c._jc for c in cols]
- jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ self._sc._gateway._gateway_client)
jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
return GroupedDataFrame(jdf, self.sql_ctx)
def agg(self, *exprs):
- """ Aggregate on the entire [[DataFrame]] without groups
- (shorthand for df.groupBy.agg())::
-
- df.agg({"age": "max", "salary": "avg"})
+ """ Aggregate on the entire :class:`DataFrame` without groups
+ (shorthand for df.groupBy.agg()).
+
+ >>> df.agg({"age": "max"}).collect()
+ [Row(MAX(age#0)=5)]
+ >>> from pyspark.sql import Dsl
+ >>> df.agg(Dsl.min(df.age)).collect()
+ [Row(MIN(age#0)=2)]
"""
return self.groupBy().agg(*exprs)
@@ -2213,7 +2248,7 @@ class DataFrame(object):
return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
def intersect(self, other):
- """ Return a new [[DataFrame]] containing rows only in
+ """ Return a new :class:`DataFrame` containing rows only in
both this frame and another frame.
This is equivalent to `INTERSECT` in SQL.
@@ -2221,7 +2256,7 @@ class DataFrame(object):
return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
def subtract(self, other):
- """ Return a new [[DataFrame]] containing rows in this frame
+ """ Return a new :class:`DataFrame` containing rows in this frame
but not in another frame.
This is equivalent to `EXCEPT` in SQL.
@@ -2229,7 +2264,11 @@ class DataFrame(object):
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
def sample(self, withReplacement, fraction, seed=None):
- """ Return a new DataFrame by sampling a fraction of rows. """
+ """ Return a new DataFrame by sampling a fraction of rows.
+
+ >>> df.sample(False, 0.5, 10).collect()
+ [Row(age=2, name=u'Alice')]
+ """
if seed is None:
jdf = self._jdf.sample(withReplacement, fraction)
else:
@@ -2237,11 +2276,12 @@ class DataFrame(object):
return DataFrame(jdf, self.sql_ctx)
def addColumn(self, colName, col):
- """ Return a new [[DataFrame]] by adding a column. """
- return self.select('*', col.alias(colName))
+ """ Return a new :class:`DataFrame` by adding a column.
- def removeColumn(self, colName):
- raise NotImplemented
+ >>> df.addColumn('age2', df.age + 2).collect()
+ [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
+ """
+ return self.select('*', col.As(colName))
# Having SchemaRDD for backward compatibility (for docs)
@@ -2280,7 +2320,14 @@ class GroupedDataFrame(object):
`sum`, `count`.
:param exprs: list or aggregate columns or a map from column
- name to agregate methods.
+ name to aggregate methods.
+
+ >>> gdf = df.groupBy(df.name)
+ >>> gdf.agg({"age": "max"}).collect()
+ [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
+ >>> from pyspark.sql import Dsl
+ >>> gdf.agg(Dsl.min(df.age)).collect()
+ [Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
@@ -2297,7 +2344,11 @@ class GroupedDataFrame(object):
@dfapi
def count(self):
- """ Count the number of rows for each group. """
+ """ Count the number of rows for each group.
+
+ >>> df.groupBy(df.age).count().collect()
+ [Row(age=2, count=1), Row(age=5, count=1)]
+ """
@dfapi
def mean(self):
@@ -2349,18 +2400,25 @@ SCALA_METHOD_MAPPINGS = {
def _create_column_from_literal(literal):
sc = SparkContext._active_spark_context
- return sc._jvm.org.apache.spark.sql.Dsl.lit(literal)
+ return sc._jvm.Dsl.lit(literal)
def _create_column_from_name(name):
sc = SparkContext._active_spark_context
- return sc._jvm.IncomputableColumn(name)
+ return sc._jvm.Dsl.col(name)
+
+
+def _to_java_column(col):
+ if isinstance(col, Column):
+ jcol = col._jc
+ else:
+ jcol = _create_column_from_name(col)
+ return jcol
def _scalaMethod(name):
""" Translate operators into methodName in Scala
- For example:
>>> _scalaMethod('+')
'$plus'
>>> _scalaMethod('>=')
@@ -2371,37 +2429,34 @@ def _scalaMethod(name):
return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
-def _unary_op(name):
+def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):
- return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx)
+ jc = getattr(self._jc, _scalaMethod(name))()
+ return Column(jc, self.sql_ctx)
+ _.__doc__ = doc
return _
-def _bin_op(name, pass_literal_through=True):
+def _bin_op(name, doc="binary operator"):
""" Create a method for given binary operator
-
- Keyword arguments:
- pass_literal_through -- whether to pass literal value directly through to the JVM.
"""
def _(self, other):
- if isinstance(other, Column):
- jc = other._jc
- else:
- if pass_literal_through:
- jc = other
- else:
- jc = _create_column_from_literal(other)
- return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx)
+ jc = other._jc if isinstance(other, Column) else other
+ njc = getattr(self._jc, _scalaMethod(name))(jc)
+ return Column(njc, self.sql_ctx)
+ _.__doc__ = doc
return _
-def _reverse_op(name):
+def _reverse_op(name, doc="binary operator"):
""" Create a method for binary operator (this object is on right side)
"""
def _(self, other):
- return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc),
- self._jdf, self.sql_ctx)
+ jother = _create_column_from_literal(other)
+ jc = getattr(jother, _scalaMethod(name))(self._jc)
+ return Column(jc, self.sql_ctx)
+ _.__doc__ = doc
return _
@@ -2410,20 +2465,20 @@ class Column(DataFrame):
"""
A column in a DataFrame.
- `Column` instances can be created by:
- {{{
- // 1. Select a column out of a DataFrame
- df.colName
- df["colName"]
+ `Column` instances can be created by::
+
+ # 1. Select a column out of a DataFrame
+ df.colName
+ df["colName"]
- // 2. Create from an expression
- df["colName"] + 1
- }}}
+ # 2. Create from an expression
+ df.colName + 1
+ 1 / df.colName
"""
- def __init__(self, jc, jdf=None, sql_ctx=None):
+ def __init__(self, jc, sql_ctx=None):
self._jc = jc
- super(Column, self).__init__(jdf, sql_ctx)
+ super(Column, self).__init__(jc, sql_ctx)
# arithmetic operators
__neg__ = _unary_op("unary_-")
@@ -2438,8 +2493,6 @@ class Column(DataFrame):
__rdiv__ = _reverse_op("/")
__rmod__ = _reverse_op("%")
__abs__ = _unary_op("abs")
- abs = _unary_op("abs")
- sqrt = _unary_op("sqrt")
# logistic operators
__eq__ = _bin_op("===")
@@ -2448,47 +2501,45 @@ class Column(DataFrame):
__le__ = _bin_op("<=")
__ge__ = _bin_op(">=")
__gt__ = _bin_op(">")
- # `and`, `or`, `not` cannot be overloaded in Python
- And = _bin_op('&&')
- Or = _bin_op('||')
- Not = _unary_op('unary_!')
-
- # bitwise operators
- __and__ = _bin_op("&")
- __or__ = _bin_op("|")
- __invert__ = _unary_op("unary_~")
- __xor__ = _bin_op("^")
- # __lshift__ = _bin_op("<<")
- # __rshift__ = _bin_op(">>")
- __rand__ = _bin_op("&")
- __ror__ = _bin_op("|")
- __rxor__ = _bin_op("^")
- # __rlshift__ = _reverse_op("<<")
- # __rrshift__ = _reverse_op(">>")
+
+ # `and`, `or`, `not` cannot be overloaded in Python,
+ # so use bitwise operators as boolean operators
+ __and__ = _bin_op('&&')
+ __or__ = _bin_op('||')
+ __invert__ = _unary_op('unary_!')
+ __rand__ = _bin_op("&&")
+ __ror__ = _bin_op("||")
# container operators
__contains__ = _bin_op("contains")
__getitem__ = _bin_op("getItem")
- # __getattr__ = _bin_op("getField")
+ getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
# string methods
rlike = _bin_op("rlike")
like = _bin_op("like")
startswith = _bin_op("startsWith")
endswith = _bin_op("endsWith")
- upper = _unary_op("upper")
- lower = _unary_op("lower")
- def substr(self, startPos, pos):
- if type(startPos) != type(pos):
+ def substr(self, startPos, length):
+ """
+ Return a Column which is a substring of the column
+
+ :param startPos: start position (int or Column)
+ :param length: length of the substring (int or Column)
+
+ >>> df.name.substr(1, 3).collect()
+ [Row(col=u'Ali'), Row(col=u'Bob')]
+ """
+ if type(startPos) != type(length):
raise TypeError("Can not mix the type")
if isinstance(startPos, (int, long)):
- jc = self._jc.substr(startPos, pos)
+ jc = self._jc.substr(startPos, length)
elif isinstance(startPos, Column):
- jc = self._jc.substr(startPos._jc, pos._jc)
+ jc = self._jc.substr(startPos._jc, length._jc)
else:
raise TypeError("Unexpected type: %s" % type(startPos))
- return Column(jc, self._jdf, self.sql_ctx)
+ return Column(jc, self.sql_ctx)
__getslice__ = substr
@@ -2496,55 +2547,89 @@ class Column(DataFrame):
asc = _unary_op("asc")
desc = _unary_op("desc")
- isNull = _unary_op("isNull")
- isNotNull = _unary_op("isNotNull")
+ isNull = _unary_op("isNull", "True if the current expression is null.")
+ isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
# `as` is keyword
def alias(self, alias):
- return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx)
+ """Return a alias for this column
+
+ >>> df.age.As("age2").collect()
+ [Row(age2=2), Row(age2=5)]
+ >>> df.age.alias("age2").collect()
+ [Row(age2=2), Row(age2=5)]
+ """
+ return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
+ As = alias
def cast(self, dataType):
+ """ Convert the column into type `dataType`
+
+ >>> df.select(df.age.cast("string").As('ages')).collect()
+ [Row(ages=u'2'), Row(ages=u'5')]
+ >>> df.select(df.age.cast(StringType()).As('ages')).collect()
+ [Row(ages=u'2'), Row(ages=u'5')]
+ """
if self.sql_ctx is None:
sc = SparkContext._active_spark_context
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
else:
ssql_ctx = self.sql_ctx._ssql_ctx
- jdt = ssql_ctx.parseDataType(dataType.json())
- return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
+ if isinstance(dataType, basestring):
+ jc = self._jc.cast(dataType)
+ elif isinstance(dataType, DataType):
+ jdt = ssql_ctx.parseDataType(dataType.json())
+ jc = self._jc.cast(jdt)
+ return Column(jc, self.sql_ctx)
-def _to_java_column(col):
- if isinstance(col, Column):
- jcol = col._jc
- else:
- jcol = _create_column_from_name(col)
- return jcol
-
-
-def _aggregate_func(name):
+def _aggregate_func(name, doc=""):
""" Create a function for aggregator by name"""
def _(col):
sc = SparkContext._active_spark_context
jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
return Column(jc)
-
+ _.__name__ = name
+ _.__doc__ = doc
return staticmethod(_)
-class Aggregator(object):
+class Dsl(object):
"""
A collections of builtin aggregators
"""
- AGGS = [
- 'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
- 'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
- ]
- for _name in AGGS:
- locals()[_name] = _aggregate_func(_name)
- del _name
+ DSLS = {
+ 'lit': 'Creates a :class:`Column` of literal value.',
+ 'col': 'Returns a :class:`Column` based on the given column name.',
+ 'column': 'Returns a :class:`Column` based on the given column name.',
+ 'upper': 'Converts a string expression to upper case.',
+ 'lower': 'Converts a string expression to lower case.',
+ 'sqrt': 'Computes the square root of the specified float value.',
+ 'abs': 'Computes the absolute value.',
+
+ 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
+ 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
+ 'first': 'Aggregate function: returns the first value in a group.',
+ 'last': 'Aggregate function: returns the last value in a group.',
+ 'count': 'Aggregate function: returns the number of items in a group.',
+ 'sum': 'Aggregate function: returns the sum of all values in the expression.',
+ 'avg': 'Aggregate function: returns the average of the values in a group.',
+ 'mean': 'Aggregate function: returns the average of the values in a group.',
+ 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
+ }
+
+ for _name, _doc in DSLS.items():
+ locals()[_name] = _aggregate_func(_name, _doc)
+ del _name, _doc
@staticmethod
def countDistinct(col, *cols):
+ """ Return a new Column for distinct count of (col, *cols)
+
+ >>> from pyspark.sql import Dsl
+ >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect()
+ [Row(c=2)]
+ """
sc = SparkContext._active_spark_context
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
sc._gateway._gateway_client)
@@ -2554,6 +2639,12 @@ class Aggregator(object):
@staticmethod
def approxCountDistinct(col, rsd=None):
+ """ Return a new Column for approxiate distinct count of (col, *cols)
+
+ >>> from pyspark.sql import Dsl
+ >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect()
+ [Row(c=2)]
+ """
sc = SparkContext._active_spark_context
if rsd is None:
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
@@ -2568,16 +2659,20 @@ def _test():
# let doctest run in pyspark.sql, so DataTypes can be picklable
import pyspark.sql
from pyspark.sql import Row, SQLContext
- from pyspark.tests import ExamplePoint, ExamplePointUDT
+ from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
globs = pyspark.sql.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = SQLContext(sc)
+ globs['sqlCtx'] = sqlCtx = SQLContext(sc)
globs['rdd'] = sc.parallelize(
[Row(field1=1, field2="row1"),
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
+ rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)])
+ rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)])
+ globs['df'] = sqlCtx.inferSchema(rdd2)
+ globs['df2'] = sqlCtx.inferSchema(rdd3)
globs['ExamplePoint'] = ExamplePoint
globs['ExamplePointUDT'] = ExamplePointUDT
jsonStrings = [
diff --git a/python/pyspark/sql_tests.py b/python/pyspark/sql_tests.py
new file mode 100644
index 0000000000..d314f46e8d
--- /dev/null
+++ b/python/pyspark/sql_tests.py
@@ -0,0 +1,299 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Unit tests for pyspark.sql; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import sys
+import pydoc
+import shutil
+import tempfile
+
+if sys.version_info[:2] <= (2, 6):
+ try:
+ import unittest2 as unittest
+ except ImportError:
+ sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+ sys.exit(1)
+else:
+ import unittest
+
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
+ UserDefinedType, DoubleType
+from pyspark.tests import ReusedPySparkTestCase
+
+
+class ExamplePointUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for ExamplePoint.
+ """
+
+ @classmethod
+ def sqlType(self):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return 'pyspark.tests'
+
+ @classmethod
+ def scalaUDT(cls):
+ return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+ """
+ An example class to demonstrate UDT in Scala, Java, and Python.
+ """
+
+ __UDT__ = ExamplePointUDT()
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __repr__(self):
+ return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+ def __str__(self):
+ return "(%s,%s)" % (self.x, self.y)
+
+ def __eq__(self, other):
+ return isinstance(other, ExamplePoint) and \
+ other.x == self.x and other.y == self.y
+
+
+class SQLTests(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+ cls.sqlCtx = SQLContext(cls.sc)
+ cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ rdd = cls.sc.parallelize(cls.testData)
+ cls.df = cls.sqlCtx.inferSchema(rdd)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+ def test_udf(self):
+ self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+ self.assertEqual(row[0], 5)
+
+ def test_udf2(self):
+ self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
+ self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+ [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
+ self.assertEqual(4, res[0])
+
+ def test_udf_with_array_type(self):
+ d = [Row(l=range(3), d={"key": range(5)})]
+ rdd = self.sc.parallelize(d)
+ self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+ self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+ self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
+ [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+ self.assertEqual(range(3), l1)
+ self.assertEqual(1, l2)
+
+ def test_broadcast_in_udf(self):
+ bar = {"a": "aa", "b": "bb", "c": "abc"}
+ foo = self.sc.broadcast(bar)
+ self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+ [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+ self.assertEqual("abc", res[0])
+ [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+ self.assertEqual("", res[0])
+
+ def test_basic_functions(self):
+ rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+ df = self.sqlCtx.jsonRDD(rdd)
+ df.count()
+ df.collect()
+ df.schema()
+
+ # cache and checkpoint
+ self.assertFalse(df.is_cached)
+ df.persist()
+ df.unpersist()
+ df.cache()
+ self.assertTrue(df.is_cached)
+ self.assertEqual(2, df.count())
+
+ df.registerTempTable("temp")
+ df = self.sqlCtx.sql("select foo from temp")
+ df.count()
+ df.collect()
+
+ def test_apply_schema_to_row(self):
+ df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+ df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+ self.assertEqual(df.collect(), df2.collect())
+
+ rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
+ df3 = self.sqlCtx.applySchema(rdd, df.schema())
+ self.assertEqual(10, df3.count())
+
+ def test_serialize_nested_array_and_map(self):
+ d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
+ rdd = self.sc.parallelize(d)
+ df = self.sqlCtx.inferSchema(rdd)
+ row = df.head()
+ self.assertEqual(1, len(row.l))
+ self.assertEqual(1, row.l[0].a)
+ self.assertEqual("2", row.d["key"].d)
+
+ l = df.map(lambda x: x.l).first()
+ self.assertEqual(1, len(l))
+ self.assertEqual('s', l[0].b)
+
+ d = df.map(lambda x: x.d).first()
+ self.assertEqual(1, len(d))
+ self.assertEqual(1.0, d["key"].c)
+
+ row = df.map(lambda x: x.d["key"]).first()
+ self.assertEqual(1.0, row.c)
+ self.assertEqual("2", row.d)
+
+ def test_infer_schema(self):
+ d = [Row(l=[], d={}),
+ Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
+ rdd = self.sc.parallelize(d)
+ df = self.sqlCtx.inferSchema(rdd)
+ self.assertEqual([], df.map(lambda r: r.l).first())
+ self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
+ df.registerTempTable("test")
+ result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
+ self.assertEqual(1, result.head()[0])
+
+ df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+ self.assertEqual(df.schema(), df2.schema())
+ self.assertEqual({}, df2.map(lambda r: r.d).first())
+ self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
+ df2.registerTempTable("test2")
+ result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+ self.assertEqual(1, result.head()[0])
+
+ def test_struct_in_map(self):
+ d = [Row(m={Row(i=1): Row(s="")})]
+ rdd = self.sc.parallelize(d)
+ df = self.sqlCtx.inferSchema(rdd)
+ k, v = df.head().m.items()[0]
+ self.assertEqual(1, k.i)
+ self.assertEqual("", v.s)
+
+ def test_convert_row_to_dict(self):
+ row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
+ self.assertEqual(1, row.asDict()['l'][0].a)
+ rdd = self.sc.parallelize([row])
+ df = self.sqlCtx.inferSchema(rdd)
+ df.registerTempTable("test")
+ row = self.sqlCtx.sql("select l, d from test").head()
+ self.assertEqual(1, row.asDict()["l"][0].a)
+ self.assertEqual(1.0, row.asDict()['d']['key'].c)
+
+ def test_infer_schema_with_udt(self):
+ from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ df = self.sqlCtx.inferSchema(rdd)
+ schema = df.schema()
+ field = [f for f in schema.fields if f.name == "point"][0]
+ self.assertEqual(type(field.dataType), ExamplePointUDT)
+ df.registerTempTable("labeled_point")
+ point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+ self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+ def test_apply_schema_with_udt(self):
+ from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
+ row = (1.0, ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ df = self.sqlCtx.applySchema(rdd, schema)
+ point = df.head().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
+ def test_parquet_with_udt(self):
+ from pyspark.sql_tests import ExamplePoint
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ df0 = self.sqlCtx.inferSchema(rdd)
+ output_dir = os.path.join(self.tempdir.name, "labeled_point")
+ df0.saveAsParquetFile(output_dir)
+ df1 = self.sqlCtx.parquetFile(output_dir)
+ point = df1.head().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
+ def test_column_operators(self):
+ from pyspark.sql import Column, LongType
+ ci = self.df.key
+ cs = self.df.value
+ c = ci == cs
+ self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
+ rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+ self.assertTrue(all(isinstance(c, Column) for c in rcc))
+ cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
+ self.assertTrue(all(isinstance(c, Column) for c in cb))
+ cbool = (ci & ci), (ci | ci), (~ci)
+ self.assertTrue(all(isinstance(c, Column) for c in cbool))
+ css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
+ self.assertTrue(all(isinstance(c, Column) for c in css))
+ self.assertTrue(isinstance(ci.cast(LongType()), Column))
+
+ def test_column_select(self):
+ df = self.df
+ self.assertEqual(self.testData, df.select("*").collect())
+ self.assertEqual(self.testData, df.select(df.key, df.value).collect())
+ self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+
+ def test_aggregator(self):
+ df = self.df
+ g = df.groupBy()
+ self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
+ self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+
+ from pyspark.sql import Dsl
+ self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first()))
+ self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
+ self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+
+ def test_help_command(self):
+ # Regression test for SPARK-5464
+ rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+ df = self.sqlCtx.jsonRDD(rdd)
+ # render_doc() reproduces the help() exception without printing output
+ pydoc.render_doc(df)
+ pydoc.render_doc(df.foo)
+ pydoc.render_doc(df.take(1))
+
+
+if __name__ == "__main__":
+ unittest.main()
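
The new suite is wired into python/run-tests (see the run-tests diff below), but it can also be exercised on its own through unittest. A minimal sketch, assuming pyspark and py4j are importable and SPARK_HOME is set so ReusedPySparkTestCase can start a local SparkContext:

    # Hypothetical standalone run of the new SQL test case.
    # Assumes pyspark/py4j are on PYTHONPATH and SPARK_HOME is set.
    import unittest
    from pyspark.sql_tests import SQLTests

    suite = unittest.TestLoader().loadTestsFromTestCase(SQLTests)
    unittest.TextTestRunner(verbosity=2).run(suite)
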
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c7d0622d65..b5e28c4980 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -23,7 +23,6 @@ from array import array
from fileinput import input
from glob import glob
import os
-import pydoc
import re
import shutil
import subprocess
@@ -52,8 +51,6 @@ from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
- UserDefinedType, DoubleType
from pyspark import shuffle
from pyspark.profiler import BasicProfiler
@@ -795,264 +792,6 @@ class ProfilerTests(PySparkTestCase):
rdd.foreach(heavy_foo)
-class ExamplePointUDT(UserDefinedType):
- """
- User-defined type (UDT) for ExamplePoint.
- """
-
- @classmethod
- def sqlType(self):
- return ArrayType(DoubleType(), False)
-
- @classmethod
- def module(cls):
- return 'pyspark.tests'
-
- @classmethod
- def scalaUDT(cls):
- return 'org.apache.spark.sql.test.ExamplePointUDT'
-
- def serialize(self, obj):
- return [obj.x, obj.y]
-
- def deserialize(self, datum):
- return ExamplePoint(datum[0], datum[1])
-
-
-class ExamplePoint:
- """
- An example class to demonstrate UDT in Scala, Java, and Python.
- """
-
- __UDT__ = ExamplePointUDT()
-
- def __init__(self, x, y):
- self.x = x
- self.y = y
-
- def __repr__(self):
- return "ExamplePoint(%s,%s)" % (self.x, self.y)
-
- def __str__(self):
- return "(%s,%s)" % (self.x, self.y)
-
- def __eq__(self, other):
- return isinstance(other, ExamplePoint) and \
- other.x == self.x and other.y == self.y
-
-
-class SQLTests(ReusedPySparkTestCase):
-
- @classmethod
- def setUpClass(cls):
- ReusedPySparkTestCase.setUpClass()
- cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
- os.unlink(cls.tempdir.name)
-
- @classmethod
- def tearDownClass(cls):
- ReusedPySparkTestCase.tearDownClass()
- shutil.rmtree(cls.tempdir.name, ignore_errors=True)
-
- def setUp(self):
- self.sqlCtx = SQLContext(self.sc)
- self.testData = [Row(key=i, value=str(i)) for i in range(100)]
- rdd = self.sc.parallelize(self.testData)
- self.df = self.sqlCtx.inferSchema(rdd)
-
- def test_udf(self):
- self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
- [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
- self.assertEqual(row[0], 5)
-
- def test_udf2(self):
- self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
- self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
- [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
- self.assertEqual(4, res[0])
-
- def test_udf_with_array_type(self):
- d = [Row(l=range(3), d={"key": range(5)})]
- rdd = self.sc.parallelize(d)
- self.sqlCtx.inferSchema(rdd).registerTempTable("test")
- self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
- self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
- [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
- self.assertEqual(range(3), l1)
- self.assertEqual(1, l2)
-
- def test_broadcast_in_udf(self):
- bar = {"a": "aa", "b": "bb", "c": "abc"}
- foo = self.sc.broadcast(bar)
- self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
- [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
- self.assertEqual("abc", res[0])
- [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
- self.assertEqual("", res[0])
-
- def test_basic_functions(self):
- rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
- df = self.sqlCtx.jsonRDD(rdd)
- df.count()
- df.collect()
- df.schema()
-
- # cache and checkpoint
- self.assertFalse(df.is_cached)
- df.persist()
- df.unpersist()
- df.cache()
- self.assertTrue(df.is_cached)
- self.assertEqual(2, df.count())
-
- df.registerTempTable("temp")
- df = self.sqlCtx.sql("select foo from temp")
- df.count()
- df.collect()
-
- def test_apply_schema_to_row(self):
- df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
- df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
- self.assertEqual(df.collect(), df2.collect())
-
- rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- df3 = self.sqlCtx.applySchema(rdd, df.schema())
- self.assertEqual(10, df3.count())
-
- def test_serialize_nested_array_and_map(self):
- d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
- rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
- row = df.head()
- self.assertEqual(1, len(row.l))
- self.assertEqual(1, row.l[0].a)
- self.assertEqual("2", row.d["key"].d)
-
- l = df.map(lambda x: x.l).first()
- self.assertEqual(1, len(l))
- self.assertEqual('s', l[0].b)
-
- d = df.map(lambda x: x.d).first()
- self.assertEqual(1, len(d))
- self.assertEqual(1.0, d["key"].c)
-
- row = df.map(lambda x: x.d["key"]).first()
- self.assertEqual(1.0, row.c)
- self.assertEqual("2", row.d)
-
- def test_infer_schema(self):
- d = [Row(l=[], d={}),
- Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
- rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
- self.assertEqual([], df.map(lambda r: r.l).first())
- self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
- df.registerTempTable("test")
- result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
- self.assertEqual(1, result.head()[0])
-
- df2 = self.sqlCtx.inferSchema(rdd, 1.0)
- self.assertEqual(df.schema(), df2.schema())
- self.assertEqual({}, df2.map(lambda r: r.d).first())
- self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
- df2.registerTempTable("test2")
- result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
- self.assertEqual(1, result.head()[0])
-
- def test_struct_in_map(self):
- d = [Row(m={Row(i=1): Row(s="")})]
- rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
- k, v = df.head().m.items()[0]
- self.assertEqual(1, k.i)
- self.assertEqual("", v.s)
-
- def test_convert_row_to_dict(self):
- row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
- self.assertEqual(1, row.asDict()['l'][0].a)
- rdd = self.sc.parallelize([row])
- df = self.sqlCtx.inferSchema(rdd)
- df.registerTempTable("test")
- row = self.sqlCtx.sql("select l, d from test").head()
- self.assertEqual(1, row.asDict()["l"][0].a)
- self.assertEqual(1.0, row.asDict()['d']['key'].c)
-
- def test_infer_schema_with_udt(self):
- from pyspark.tests import ExamplePoint, ExamplePointUDT
- row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
- df = self.sqlCtx.inferSchema(rdd)
- schema = df.schema()
- field = [f for f in schema.fields if f.name == "point"][0]
- self.assertEqual(type(field.dataType), ExamplePointUDT)
- df.registerTempTable("labeled_point")
- point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
- self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
- def test_apply_schema_with_udt(self):
- from pyspark.tests import ExamplePoint, ExamplePointUDT
- row = (1.0, ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
- schema = StructType([StructField("label", DoubleType(), False),
- StructField("point", ExamplePointUDT(), False)])
- df = self.sqlCtx.applySchema(rdd, schema)
- point = df.head().point
- self.assertEquals(point, ExamplePoint(1.0, 2.0))
-
- def test_parquet_with_udt(self):
- from pyspark.tests import ExamplePoint
- row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
- df0 = self.sqlCtx.inferSchema(rdd)
- output_dir = os.path.join(self.tempdir.name, "labeled_point")
- df0.saveAsParquetFile(output_dir)
- df1 = self.sqlCtx.parquetFile(output_dir)
- point = df1.head().point
- self.assertEquals(point, ExamplePoint(1.0, 2.0))
-
- def test_column_operators(self):
- from pyspark.sql import Column, LongType
- ci = self.df.key
- cs = self.df.value
- c = ci == cs
- self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
- rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
- self.assertTrue(all(isinstance(c, Column) for c in rcc))
- cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
- self.assertTrue(all(isinstance(c, Column) for c in cb))
- cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci)
- self.assertTrue(all(isinstance(c, Column) for c in cbit))
- css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
- self.assertTrue(all(isinstance(c, Column) for c in css))
- self.assertTrue(isinstance(ci.cast(LongType()), Column))
-
- def test_column_select(self):
- df = self.df
- self.assertEqual(self.testData, df.select("*").collect())
- self.assertEqual(self.testData, df.select(df.key, df.value).collect())
- self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
-
- def test_aggregator(self):
- df = self.df
- g = df.groupBy()
- self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
- self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
-
- from pyspark.sql import Aggregator as Agg
- self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
- self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0])
- self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0])
-
- def test_help_command(self):
- # Regression test for SPARK-5464
- rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
- df = self.sqlCtx.jsonRDD(rdd)
- # render_doc() reproduces the help() exception without printing output
- pydoc.render_doc(df)
- pydoc.render_doc(df.foo)
- pydoc.render_doc(df.take(1))
-
-
class InputFormatTests(ReusedPySparkTestCase):
@classmethod
diff --git a/python/run-tests b/python/run-tests
index e91f1a875d..649a2c44d1 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -65,6 +65,7 @@ function run_core_tests() {
function run_sql_tests() {
echo "Run sql tests ..."
run_test "pyspark/sql.py"
+ run_test "pyspark/sql_tests.py"
}
function run_mllib_tests() {
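
run_sql_tests still executes pyspark/sql.py directly, which runs the module's doctests through _test(). A minimal sketch of that bootstrap, mirroring the _test() shown in the sql.py diff above (error handling and the remaining globals omitted):

    # Sketch of the doctest bootstrap that run_test "pyspark/sql.py" triggers.
    import doctest
    import pyspark.sql
    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    globs = pyspark.sql.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    globs['sqlCtx'] = sqlCtx = SQLContext(sc)
    globs['df'] = sqlCtx.inferSchema(
        sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]))
    globs['df2'] = sqlCtx.inferSchema(
        sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]))
    (failures, _) = doctest.testmod(pyspark.sql, globs=globs,
                                    optionflags=doctest.ELLIPSIS)
    sc.stop()
    if failures:
        exit(-1)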