From 068c0e2ee05ee8b133c2dc26b8fa094ab2712d45 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 3 Feb 2015 16:01:56 -0800 Subject: [SPARK-5554] [SQL] [PySpark] add more tests for DataFrame Python API Add more tests and docs for DataFrame Python API, improve test coverage, fix bugs. Author: Davies Liu Closes #4331 from davies/fix_df and squashes the following commits: dd9919f [Davies Liu] fix tests 467332c [Davies Liu] support string in cast() 83c92fe [Davies Liu] address comments c052f6f [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_df 8dd19a9 [Davies Liu] fix tests in python 2.6 35ccb9f [Davies Liu] fix build 78ebcfa [Davies Liu] add sql_test.py in run_tests 9ab78b4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_df 6040ba7 [Davies Liu] fix docs 3ab2661 [Davies Liu] add more tests for DataFrame --- python/pyspark/sql.py | 467 ++++++++++++++++++++++++++------------------ python/pyspark/sql_tests.py | 299 ++++++++++++++++++++++++++++ python/pyspark/tests.py | 261 ------------------------- python/run-tests | 1 + 4 files changed, 581 insertions(+), 447 deletions(-) create mode 100644 python/pyspark/sql_tests.py (limited to 'python') diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 32bff0c7e8..268c7ef97c 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -62,7 +62,7 @@ __all__ = [ "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", + "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", "Dsl", "SchemaRDD"] @@ -1804,7 +1804,7 @@ class DataFrame(object): people = sqlContext.parquetFile("...") Once created, it can be manipulated using the various domain-specific-language - (DSL) functions defined in: [[DataFrame]], [[Column]]. + (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. To select a column from the data frame, use the apply method:: @@ -1835,8 +1835,10 @@ class DataFrame(object): @property def rdd(self): - """Return the content of the :class:`DataFrame` as an :class:`RDD` - of :class:`Row`s. """ + """ + Return the content of the :class:`DataFrame` as an :class:`RDD` + of :class:`Row` s. + """ if not hasattr(self, '_lazy_rdd'): jrdd = self._jdf.javaToPython() rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) @@ -1850,18 +1852,6 @@ class DataFrame(object): return self._lazy_rdd - def limit(self, num): - """Limit the result count to the number specified. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.limit(2).collect() - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - >>> df.limit(0).collect() - [] - """ - jdf = self._jdf.limit(num) - return DataFrame(jdf, self.sql_ctx) - def toJSON(self, use_unicode=False): """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. @@ -1886,7 +1876,6 @@ class DataFrame(object): >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> df = sqlCtx.inferSchema(rdd) >>> df.saveAsParquetFile(parquetFile) >>> df2 = sqlCtx.parquetFile(parquetFile) >>> sorted(df2.collect()) == sorted(df.collect()) @@ -1900,9 +1889,8 @@ class DataFrame(object): The lifetime of this temporary table is tied to the L{SQLContext} that was used to create this DataFrame. - >>> df = sqlCtx.inferSchema(rdd) - >>> df.registerTempTable("test") - >>> df2 = sqlCtx.sql("select * from test") + >>> df.registerTempTable("people") + >>> df2 = sqlCtx.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -1926,11 +1914,22 @@ class DataFrame(object): def schema(self): """Returns the schema of this DataFrame (represented by - a L{StructType}).""" + a L{StructType}). + + >>> df.schema() + StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) + """ return _parse_datatype_json_string(self._jdf.schema().json()) def printSchema(self): - """Prints out the schema in the tree format.""" + """Prints out the schema in the tree format. + + >>> df.printSchema() + root + |-- age: integer (nullable = true) + |-- name: string (nullable = true) + + """ print (self._jdf.schema().treeString()) def count(self): @@ -1940,11 +1939,8 @@ class DataFrame(object): leverages the query optimizer to compute the count on the DataFrame, which supports features such as filter pushdown. - >>> df = sqlCtx.inferSchema(rdd) >>> df.count() - 3L - >>> df.count() == df.map(lambda x: x).count() - True + 2L """ return self._jdf.count() @@ -1954,13 +1950,11 @@ class DataFrame(object): Each object in the list is a Row, the fields can be accessed as attributes. - >>> df = sqlCtx.inferSchema(rdd) >>> df.collect() - [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: bytesInJava = self._jdf.javaToPython().collect().iterator() - cls = _create_cls(self.schema()) tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) tempFile.close() self._sc._writeToFile(bytesInJava, tempFile.name) @@ -1968,23 +1962,37 @@ class DataFrame(object): with open(tempFile.name, 'rb') as tempFile: rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) os.unlink(tempFile.name) + cls = _create_cls(self.schema()) return [cls(r) for r in rs] + def limit(self, num): + """Limit the result count to the number specified. + + >>> df.limit(1).collect() + [Row(age=2, name=u'Alice')] + >>> df.limit(0).collect() + [] + """ + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) + def take(self, num): """Take the first num rows of the RDD. Each object in the list is a Row, the fields can be accessed as attributes. - >>> df = sqlCtx.inferSchema(rdd) >>> df.take(2) - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ return self.limit(num).collect() def map(self, f): """ Return a new RDD by applying a function to each Row, it's a shorthand for df.rdd.map() + + >>> df.map(lambda p: p.name).collect() + [u'Alice', u'Bob'] """ return self.rdd.map(f) @@ -2067,140 +2075,167 @@ class DataFrame(object): @property def dtypes(self): """Return all column names and their data types as a list. + + >>> df.dtypes + [(u'age', 'IntegerType'), (u'name', 'StringType')] """ return [(f.name, str(f.dataType)) for f in self.schema().fields] @property def columns(self): """ Return all column names as a list. + + >>> df.columns + [u'age', u'name'] """ return [f.name for f in self.schema().fields] - def show(self): - raise NotImplemented - def join(self, other, joinExprs=None, joinType=None): """ Join with another DataFrame, using the given join expression. The following performs a full outer join between `df1` and `df2`:: - df1.join(df2, df1.key == df2.key, "outer") - :param other: Right side of the join :param joinExprs: Join expression - :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, - `semijoin`. + :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + + >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() + [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] """ - if joinType is None: - if joinExprs is None: - jdf = self._jdf.join(other._jdf) - else: - jdf = self._jdf.join(other._jdf, joinExprs) + + if joinExprs is None: + jdf = self._jdf.join(other._jdf) else: - jdf = self._jdf.join(other._jdf, joinExprs, joinType) + assert isinstance(joinExprs, Column), "joinExprs should be Column" + if joinType is None: + jdf = self._jdf.join(other._jdf, joinExprs._jc) + else: + assert isinstance(joinType, basestring), "joinType should be basestring" + jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) return DataFrame(jdf, self.sql_ctx) def sort(self, *cols): - """ Return a new [[DataFrame]] sorted by the specified column, - in ascending column. + """ Return a new :class:`DataFrame` sorted by the specified column. :param cols: The columns or expressions used for sorting + + >>> df.sort(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.sortBy(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ if not cols: raise ValueError("should sort by at least one column") - for i, c in enumerate(cols): - if isinstance(c, basestring): - cols[i] = Column(c) - jcols = [c._jc for c in cols] - jdf = self._jdf.join(*jcols) + jcols = ListConverter().convert([_to_java_column(c) for c in cols[1:]], + self._sc._gateway._gateway_client) + jdf = self._jdf.sort(_to_java_column(cols[0]), + self._sc._jvm.Dsl.toColumns(jcols)) return DataFrame(jdf, self.sql_ctx) sortBy = sort def head(self, n=None): - """ Return the first `n` rows or the first row if n is None. """ + """ Return the first `n` rows or the first row if n is None. + + >>> df.head() + Row(age=2, name=u'Alice') + >>> df.head(1) + [Row(age=2, name=u'Alice')] + """ if n is None: rs = self.head(1) return rs[0] if rs else None return self.take(n) def first(self): - """ Return the first row. """ - return self.head() + """ Return the first row. - def tail(self): - raise NotImplemented + >>> df.first() + Row(age=2, name=u'Alice') + """ + return self.head() def __getitem__(self, item): + """ Return the column by given name + + >>> df['age'].collect() + [Row(age=2), Row(age=5)] + """ if isinstance(item, basestring): - return Column(self._jdf.apply(item)) + jc = self._jdf.apply(item) + return Column(jc, self.sql_ctx) # TODO projection raise IndexError def __getattr__(self, name): - """ Return the column by given name """ + """ Return the column by given name + + >>> df.age.collect() + [Row(age=2), Row(age=5)] + """ if name.startswith("__"): raise AttributeError(name) - return Column(self._jdf.apply(name)) - - def alias(self, name): - """ Alias the current DataFrame """ - return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx) + jc = self._jdf.apply(name) + return Column(jc, self.sql_ctx) def select(self, *cols): - """ Selecting a set of expressions.:: - - df.select() - df.select('colA', 'colB') - df.select(df.colA, df.colB + 1) - + """ Selecting a set of expressions. + + >>> df.select().collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('*').collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('name', 'age').collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df.select(df.name, (df.age + 10).As('age')).collect() + [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] """ if not cols: cols = ["*"] - if isinstance(cols[0], basestring): - cols = [_create_column_from_name(n) for n in cols] - else: - cols = [c._jc for c in cols] - jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) return DataFrame(jdf, self.sql_ctx) def filter(self, condition): - """ Filtering rows using the given condition:: - - df.filter(df.age > 15) - df.where(df.age > 15) + """ Filtering rows using the given condition. + >>> df.filter(df.age > 3).collect() + [Row(age=5, name=u'Bob')] + >>> df.where(df.age == 2).collect() + [Row(age=2, name=u'Alice')] """ return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) where = filter def groupBy(self, *cols): - """ Group the [[DataFrame]] using the specified columns, + """ Group the :class:`DataFrame` using the specified columns, so we can run aggregation on them. See :class:`GroupedDataFrame` - for all the available aggregate functions:: - - df.groupBy(df.department).avg() - df.groupBy("department", "gender").agg({ - "salary": "avg", - "age": "max", - }) + for all the available aggregate functions. + + >>> df.groupBy().avg().collect() + [Row(AVG(age#0)=3.5)] + >>> df.groupBy('name').agg({'age': 'mean'}).collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] + >>> df.groupBy(df.name).avg().collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] """ - if cols and isinstance(cols[0], basestring): - cols = [_create_column_from_name(n) for n in cols] - else: - cols = [c._jc for c in cols] - jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) return GroupedDataFrame(jdf, self.sql_ctx) def agg(self, *exprs): - """ Aggregate on the entire [[DataFrame]] without groups - (shorthand for df.groupBy.agg()):: - - df.agg({"age": "max", "salary": "avg"}) + """ Aggregate on the entire :class:`DataFrame` without groups + (shorthand for df.groupBy.agg()). + + >>> df.agg({"age": "max"}).collect() + [Row(MAX(age#0)=5)] + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.min(df.age)).collect() + [Row(MIN(age#0)=2)] """ return self.groupBy().agg(*exprs) @@ -2213,7 +2248,7 @@ class DataFrame(object): return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) def intersect(self, other): - """ Return a new [[DataFrame]] containing rows only in + """ Return a new :class:`DataFrame` containing rows only in both this frame and another frame. This is equivalent to `INTERSECT` in SQL. @@ -2221,7 +2256,7 @@ class DataFrame(object): return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) def subtract(self, other): - """ Return a new [[DataFrame]] containing rows in this frame + """ Return a new :class:`DataFrame` containing rows in this frame but not in another frame. This is equivalent to `EXCEPT` in SQL. @@ -2229,7 +2264,11 @@ class DataFrame(object): return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) def sample(self, withReplacement, fraction, seed=None): - """ Return a new DataFrame by sampling a fraction of rows. """ + """ Return a new DataFrame by sampling a fraction of rows. + + >>> df.sample(False, 0.5, 10).collect() + [Row(age=2, name=u'Alice')] + """ if seed is None: jdf = self._jdf.sample(withReplacement, fraction) else: @@ -2237,11 +2276,12 @@ class DataFrame(object): return DataFrame(jdf, self.sql_ctx) def addColumn(self, colName, col): - """ Return a new [[DataFrame]] by adding a column. """ - return self.select('*', col.alias(colName)) + """ Return a new :class:`DataFrame` by adding a column. - def removeColumn(self, colName): - raise NotImplemented + >>> df.addColumn('age2', df.age + 2).collect() + [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ + return self.select('*', col.As(colName)) # Having SchemaRDD for backward compatibility (for docs) @@ -2280,7 +2320,14 @@ class GroupedDataFrame(object): `sum`, `count`. :param exprs: list or aggregate columns or a map from column - name to agregate methods. + name to aggregate methods. + + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"age": "max"}).collect() + [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)] + >>> from pyspark.sql import Dsl + >>> gdf.agg(Dsl.min(df.age)).collect() + [Row(MIN(age#0)=5), Row(MIN(age#0)=2)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -2297,7 +2344,11 @@ class GroupedDataFrame(object): @dfapi def count(self): - """ Count the number of rows for each group. """ + """ Count the number of rows for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ @dfapi def mean(self): @@ -2349,18 +2400,25 @@ SCALA_METHOD_MAPPINGS = { def _create_column_from_literal(literal): sc = SparkContext._active_spark_context - return sc._jvm.org.apache.spark.sql.Dsl.lit(literal) + return sc._jvm.Dsl.lit(literal) def _create_column_from_name(name): sc = SparkContext._active_spark_context - return sc._jvm.IncomputableColumn(name) + return sc._jvm.Dsl.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol def _scalaMethod(name): """ Translate operators into methodName in Scala - For example: >>> _scalaMethod('+') '$plus' >>> _scalaMethod('>=') @@ -2371,37 +2429,34 @@ def _scalaMethod(name): return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name) -def _unary_op(name): +def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): - return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx) + jc = getattr(self._jc, _scalaMethod(name))() + return Column(jc, self.sql_ctx) + _.__doc__ = doc return _ -def _bin_op(name, pass_literal_through=True): +def _bin_op(name, doc="binary operator"): """ Create a method for given binary operator - - Keyword arguments: - pass_literal_through -- whether to pass literal value directly through to the JVM. """ def _(self, other): - if isinstance(other, Column): - jc = other._jc - else: - if pass_literal_through: - jc = other - else: - jc = _create_column_from_literal(other) - return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx) + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, _scalaMethod(name))(jc) + return Column(njc, self.sql_ctx) + _.__doc__ = doc return _ -def _reverse_op(name): +def _reverse_op(name, doc="binary operator"): """ Create a method for binary operator (this object is on right side) """ def _(self, other): - return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc), - self._jdf, self.sql_ctx) + jother = _create_column_from_literal(other) + jc = getattr(jother, _scalaMethod(name))(self._jc) + return Column(jc, self.sql_ctx) + _.__doc__ = doc return _ @@ -2410,20 +2465,20 @@ class Column(DataFrame): """ A column in a DataFrame. - `Column` instances can be created by: - {{{ - // 1. Select a column out of a DataFrame - df.colName - df["colName"] + `Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + df.colName + df["colName"] - // 2. Create from an expression - df["colName"] + 1 - }}} + # 2. Create from an expression + df.colName + 1 + 1 / df.colName """ - def __init__(self, jc, jdf=None, sql_ctx=None): + def __init__(self, jc, sql_ctx=None): self._jc = jc - super(Column, self).__init__(jdf, sql_ctx) + super(Column, self).__init__(jc, sql_ctx) # arithmetic operators __neg__ = _unary_op("unary_-") @@ -2438,8 +2493,6 @@ class Column(DataFrame): __rdiv__ = _reverse_op("/") __rmod__ = _reverse_op("%") __abs__ = _unary_op("abs") - abs = _unary_op("abs") - sqrt = _unary_op("sqrt") # logistic operators __eq__ = _bin_op("===") @@ -2448,47 +2501,45 @@ class Column(DataFrame): __le__ = _bin_op("<=") __ge__ = _bin_op(">=") __gt__ = _bin_op(">") - # `and`, `or`, `not` cannot be overloaded in Python - And = _bin_op('&&') - Or = _bin_op('||') - Not = _unary_op('unary_!') - - # bitwise operators - __and__ = _bin_op("&") - __or__ = _bin_op("|") - __invert__ = _unary_op("unary_~") - __xor__ = _bin_op("^") - # __lshift__ = _bin_op("<<") - # __rshift__ = _bin_op(">>") - __rand__ = _bin_op("&") - __ror__ = _bin_op("|") - __rxor__ = _bin_op("^") - # __rlshift__ = _reverse_op("<<") - # __rrshift__ = _reverse_op(">>") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('&&') + __or__ = _bin_op('||') + __invert__ = _unary_op('unary_!') + __rand__ = _bin_op("&&") + __ror__ = _bin_op("||") # container operators __contains__ = _bin_op("contains") __getitem__ = _bin_op("getItem") - # __getattr__ = _bin_op("getField") + getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") # string methods rlike = _bin_op("rlike") like = _bin_op("like") startswith = _bin_op("startsWith") endswith = _bin_op("endsWith") - upper = _unary_op("upper") - lower = _unary_op("lower") - def substr(self, startPos, pos): - if type(startPos) != type(pos): + def substr(self, startPos, length): + """ + Return a Column which is a substring of the column + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.name.substr(1, 3).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): raise TypeError("Can not mix the type") if isinstance(startPos, (int, long)): - jc = self._jc.substr(startPos, pos) + jc = self._jc.substr(startPos, length) elif isinstance(startPos, Column): - jc = self._jc.substr(startPos._jc, pos._jc) + jc = self._jc.substr(startPos._jc, length._jc) else: raise TypeError("Unexpected type: %s" % type(startPos)) - return Column(jc, self._jdf, self.sql_ctx) + return Column(jc, self.sql_ctx) __getslice__ = substr @@ -2496,55 +2547,89 @@ class Column(DataFrame): asc = _unary_op("asc") desc = _unary_op("desc") - isNull = _unary_op("isNull") - isNotNull = _unary_op("isNotNull") + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") # `as` is keyword def alias(self, alias): - return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx) + """Return a alias for this column + + >>> df.age.As("age2").collect() + [Row(age2=2), Row(age2=5)] + >>> df.age.alias("age2").collect() + [Row(age2=2), Row(age2=5)] + """ + return Column(getattr(self._jc, "as")(alias), self.sql_ctx) + As = alias def cast(self, dataType): + """ Convert the column into type `dataType` + + >>> df.select(df.age.cast("string").As('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).As('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ if self.sql_ctx is None: sc = SparkContext._active_spark_context ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) else: ssql_ctx = self.sql_ctx._ssql_ctx - jdt = ssql_ctx.parseDataType(dataType.json()) - return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx) + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + return Column(jc, self.sql_ctx) -def _to_java_column(col): - if isinstance(col, Column): - jcol = col._jc - else: - jcol = _create_column_from_name(col) - return jcol - - -def _aggregate_func(name): +def _aggregate_func(name, doc=""): """ Create a function for aggregator by name""" def _(col): sc = SparkContext._active_spark_context jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col)) return Column(jc) - + _.__name__ = name + _.__doc__ = doc return staticmethod(_) -class Aggregator(object): +class Dsl(object): """ A collections of builtin aggregators """ - AGGS = [ - 'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs', - 'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct', - ] - for _name in AGGS: - locals()[_name] = _aggregate_func(_name) - del _name + DSLS = { + 'lit': 'Creates a :class:`Column` of literal value.', + 'col': 'Returns a :class:`Column` based on the given column name.', + 'column': 'Returns a :class:`Column` based on the given column name.', + 'upper': 'Converts a string expression to upper case.', + 'lower': 'Converts a string expression to upper case.', + 'sqrt': 'Computes the square root of the specified float value.', + 'abs': 'Computes the absolutle value.', + + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', + } + + for _name, _doc in DSLS.items(): + locals()[_name] = _aggregate_func(_name, _doc) + del _name, _doc @staticmethod def countDistinct(col, *cols): + """ Return a new Column for distinct count of (col, *cols) + + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect() + [Row(c=2)] + """ sc = SparkContext._active_spark_context jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client) @@ -2554,6 +2639,12 @@ class Aggregator(object): @staticmethod def approxCountDistinct(col, rsd=None): + """ Return a new Column for approxiate distinct count of (col, *cols) + + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect() + [Row(c=2)] + """ sc = SparkContext._active_spark_context if rsd is None: jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col)) @@ -2568,16 +2659,20 @@ def _test(): # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import Row, SQLContext - from pyspark.tests import ExamplePoint, ExamplePointUDT + from pyspark.sql_tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) + globs['sqlCtx'] = sqlCtx = SQLContext(sc) globs['rdd'] = sc.parallelize( [Row(field1=1, field2="row1"), Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) + rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]) + rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]) + globs['df'] = sqlCtx.inferSchema(rdd2) + globs['df2'] = sqlCtx.inferSchema(rdd3) globs['ExamplePoint'] = ExamplePoint globs['ExamplePointUDT'] = ExamplePointUDT jsonStrings = [ diff --git a/python/pyspark/sql_tests.py b/python/pyspark/sql_tests.py new file mode 100644 index 0000000000..d314f46e8d --- /dev/null +++ b/python/pyspark/sql_tests.py @@ -0,0 +1,299 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for pyspark.sql; additional tests are implemented as doctests in +individual modules. +""" +import os +import sys +import pydoc +import shutil +import tempfile + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType +from pyspark.tests import ReusedPySparkTestCase + + +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + +class SQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sqlCtx = SQLContext(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = cls.sc.parallelize(cls.testData) + cls.df = cls.sqlCtx.inferSchema(rdd) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + + def test_udf_with_array_type(self): + d = [Row(l=range(3), d={"key": range(5)})] + rdd = self.sc.parallelize(d) + self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(range(3), l1) + self.assertEqual(1, l2) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema() + + # cache and checkpoint + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() + + def test_apply_schema_to_row(self): + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) + self.assertEqual(df.collect(), df2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.sqlCtx.applySchema(rdd, df.schema()) + self.assertEqual(10, df3.count()) + + def test_serialize_nested_array_and_map(self): + d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + row = df.head() + self.assertEqual(1, len(row.l)) + self.assertEqual(1, row.l[0].a) + self.assertEqual("2", row.d["key"].d) + + l = df.map(lambda x: x.l).first() + self.assertEqual(1, len(l)) + self.assertEqual('s', l[0].b) + + d = df.map(lambda x: x.d).first() + self.assertEqual(1, len(d)) + self.assertEqual(1.0, d["key"].c) + + row = df.map(lambda x: x.d["key"]).first() + self.assertEqual(1.0, row.c) + self.assertEqual("2", row.d) + + def test_infer_schema(self): + d = [Row(l=[], d={}), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + df2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(df.schema(), df2.schema()) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + def test_struct_in_map(self): + d = [Row(m={Row(i=1): Row(s="")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + k, v = df.head().m.items()[0] + self.assertEqual(1, k.i) + self.assertEqual("", v.s) + + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + rdd = self.sc.parallelize([row]) + df = self.sqlCtx.inferSchema(rdd) + df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() + self.assertEqual(1, row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) + + def test_infer_schema_with_udt(self): + from pyspark.sql_tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + df = self.sqlCtx.inferSchema(rdd) + schema = df.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.sql_tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df = self.sqlCtx.applySchema(rdd, schema) + point = df.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.sql_tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + df0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_column_operators(self): + from pyspark.sql import Column, LongType + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbool = (ci & ci), (ci | ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbool)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + + from pyspark.sql import Dsl + self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first())) + self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0]) + self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0]) + + def test_help_command(self): + # Regression test for SPARK-5464 + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(df) + pydoc.render_doc(df.foo) + pydoc.render_doc(df.take(1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c7d0622d65..b5e28c4980 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -23,7 +23,6 @@ from array import array from fileinput import input from glob import glob import os -import pydoc import re import shutil import subprocess @@ -52,8 +51,6 @@ from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -795,264 +792,6 @@ class ProfilerTests(PySparkTestCase): rdd.foreach(heavy_foo) -class ExamplePointUDT(UserDefinedType): - """ - User-defined type (UDT) for ExamplePoint. - """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return 'pyspark.tests' - - @classmethod - def scalaUDT(cls): - return 'org.apache.spark.sql.test.ExamplePointUDT' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return ExamplePoint(datum[0], datum[1]) - - -class ExamplePoint: - """ - An example class to demonstrate UDT in Scala, Java, and Python. - """ - - __UDT__ = ExamplePointUDT() - - def __init__(self, x, y): - self.x = x - self.y = y - - def __repr__(self): - return "ExamplePoint(%s,%s)" % (self.x, self.y) - - def __str__(self): - return "(%s,%s)" % (self.x, self.y) - - def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ - other.x == self.x and other.y == self.y - - -class SQLTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) - - def setUp(self): - self.sqlCtx = SQLContext(self.sc) - self.testData = [Row(key=i, value=str(i)) for i in range(100)] - rdd = self.sc.parallelize(self.testData) - self.df = self.sqlCtx.inferSchema(rdd) - - def test_udf(self): - self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - def test_udf2(self): - self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") - [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(4, res[0]) - - def test_udf_with_array_type(self): - d = [Row(l=range(3), d={"key": range(5)})] - rdd = self.sc.parallelize(d) - self.sqlCtx.inferSchema(rdd).registerTempTable("test") - self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) - self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(range(3), l1) - self.assertEqual(1, l2) - - def test_broadcast_in_udf(self): - bar = {"a": "aa", "b": "bb", "c": "abc"} - foo = self.sc.broadcast(bar) - self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() - self.assertEqual("abc", res[0]) - [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() - self.assertEqual("", res[0]) - - def test_basic_functions(self): - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) - df.count() - df.collect() - df.schema() - - # cache and checkpoint - self.assertFalse(df.is_cached) - df.persist() - df.unpersist() - df.cache() - self.assertTrue(df.is_cached) - self.assertEqual(2, df.count()) - - df.registerTempTable("temp") - df = self.sqlCtx.sql("select foo from temp") - df.count() - df.collect() - - def test_apply_schema_to_row(self): - df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) - self.assertEqual(df.collect(), df2.collect()) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.sqlCtx.applySchema(rdd, df.schema()) - self.assertEqual(10, df3.count()) - - def test_serialize_nested_array_and_map(self): - d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - row = df.head() - self.assertEqual(1, len(row.l)) - self.assertEqual(1, row.l[0].a) - self.assertEqual("2", row.d["key"].d) - - l = df.map(lambda x: x.l).first() - self.assertEqual(1, len(l)) - self.assertEqual('s', l[0].b) - - d = df.map(lambda x: x.d).first() - self.assertEqual(1, len(d)) - self.assertEqual(1.0, d["key"].c) - - row = df.map(lambda x: x.d["key"]).first() - self.assertEqual(1.0, row.c) - self.assertEqual("2", row.d) - - def test_infer_schema(self): - d = [Row(l=[], d={}), - Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], df.map(lambda r: r.l).first()) - self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) - df.registerTempTable("test") - result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - df2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(df.schema(), df2.schema()) - self.assertEqual({}, df2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) - df2.registerTempTable("test2") - result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - def test_struct_in_map(self): - d = [Row(m={Row(i=1): Row(s="")})] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - k, v = df.head().m.items()[0] - self.assertEqual(1, k.i) - self.assertEqual("", v.s) - - def test_convert_row_to_dict(self): - row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) - self.assertEqual(1, row.asDict()['l'][0].a) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) - df.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").head() - self.assertEqual(1, row.asDict()["l"][0].a) - self.assertEqual(1.0, row.asDict()['d']['key'].c) - - def test_infer_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) - schema = df.schema() - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), ExamplePointUDT) - df.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - def test_apply_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - df = self.sqlCtx.applySchema(rdd, schema) - point = df.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_parquet_with_udt(self): - from pyspark.tests import ExamplePoint - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df0 = self.sqlCtx.inferSchema(rdd) - output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.saveAsParquetFile(output_dir) - df1 = self.sqlCtx.parquetFile(output_dir) - point = df1.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_column_operators(self): - from pyspark.sql import Column, LongType - ci = self.df.key - cs = self.df.value - c = ci == cs - self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) - self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] - self.assertTrue(all(isinstance(c, Column) for c in cb)) - cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci) - self.assertTrue(all(isinstance(c, Column) for c in cbit)) - css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') - self.assertTrue(all(isinstance(c, Column) for c in css)) - self.assertTrue(isinstance(ci.cast(LongType()), Column)) - - def test_column_select(self): - df = self.df - self.assertEqual(self.testData, df.select("*").collect()) - self.assertEqual(self.testData, df.select(df.key, df.value).collect()) - self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) - - def test_aggregator(self): - df = self.df - g = df.groupBy() - self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) - self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) - - from pyspark.sql import Aggregator as Agg - self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first())) - self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0]) - self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0]) - - def test_help_command(self): - # Regression test for SPARK-5464 - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) - # render_doc() reproduces the help() exception without printing output - pydoc.render_doc(df) - pydoc.render_doc(df.foo) - pydoc.render_doc(df.take(1)) - - class InputFormatTests(ReusedPySparkTestCase): @classmethod diff --git a/python/run-tests b/python/run-tests index e91f1a875d..649a2c44d1 100755 --- a/python/run-tests +++ b/python/run-tests @@ -65,6 +65,7 @@ function run_core_tests() { function run_sql_tests() { echo "Run sql tests ..." run_test "pyspark/sql.py" + run_test "pyspark/sql_tests.py" } function run_mllib_tests() { -- cgit v1.2.3