#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import itertools
import warnings
import random

from py4j.java_collections import ListConverter, MapConverter

from pyspark.context import SparkContext
from pyspark.rdd import RDD, _load_from_socket
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import *
from pyspark.sql.types import _create_cls, _parse_datatype_json_string


__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"]


class DataFrame(object):
    """A distributed collection of data grouped into named columns.

    A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
    and can be created using various functions in :class:`SQLContext`::

        people = sqlContext.parquetFile("...")

    Once created, it can be manipulated using the various domain-specific-language
    (DSL) functions defined in: :class:`DataFrame`, :class:`Column`.

    To select a column from the data frame, use the apply method::

        ageCol = people.age

    A more concrete example::

        # To create DataFrame using SQLContext
        people = sqlContext.parquetFile("...")
        department = sqlContext.parquetFile("...")

        people.filter(people.age > 30).join(department, people.deptId == department.id) \
          .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
    """

    def __init__(self, jdf, sql_ctx):
        self._jdf = jdf
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx and sql_ctx._sc
        self.is_cached = False
        self._schema = None  # initialized lazily

    @property
    def rdd(self):
        """Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
        """
        if not hasattr(self, '_lazy_rdd'):
            jrdd = self._jdf.javaToPython()
            rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
            schema = self.schema

            def applySchema(it):
                cls = _create_cls(schema)
                return itertools.imap(cls, it)

            self._lazy_rdd = rdd.mapPartitions(applySchema)

        return self._lazy_rdd

    @property
    def na(self):
        """Returns a :class:`DataFrameNaFunctions` for handling missing values.
        """
        return DataFrameNaFunctions(self)

    def toJSON(self, use_unicode=False):
        """Converts a :class:`DataFrame` into a :class:`RDD` of string.

        Each row is turned into a JSON document as one element in the returned RDD.

        >>> df.toJSON().first()
        '{"age":2,"name":"Alice"}'
        """
        rdd = self._jdf.toJSON()
        return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))

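    # Comment-only usage sketch for the persistence methods defined below
    # (`saveAsParquetFile`, `saveAsTable`, `save`). The path and table name here
    # are hypothetical placeholders, not values defined in this module:
    #
    #   df.saveAsParquetFile("/tmp/people.parquet")
    #   df.saveAsTable("people_backup", mode="ignore")
    #   df.save("/tmp/people", source="org.apache.spark.sql.parquet", mode="overwrite")
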
    def saveAsParquetFile(self, path):
        """Saves the contents as a Parquet file, preserving the schema.

        Files that are written out using this method can be read back in as
        a :class:`DataFrame` using :func:`SQLContext.parquetFile`.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> df.saveAsParquetFile(parquetFile)
        >>> df2 = sqlContext.parquetFile(parquetFile)
        >>> sorted(df2.collect()) == sorted(df.collect())
        True
        """
        self._jdf.saveAsParquetFile(path)

    def registerTempTable(self, name):
        """Registers this RDD as a temporary table using the given name.

        The lifetime of this temporary table is tied to the :class:`SQLContext`
        that was used to create this :class:`DataFrame`.

        >>> df.registerTempTable("people")
        >>> df2 = sqlContext.sql("select * from people")
        >>> sorted(df.collect()) == sorted(df2.collect())
        True
        """
        self._jdf.registerTempTable(name)

    def registerAsTable(self, name):
        """DEPRECATED: use :func:`registerTempTable` instead"""
        warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning)
        self.registerTempTable(name)

    def insertInto(self, tableName, overwrite=False):
        """Inserts the contents of this :class:`DataFrame` into the specified table,
        optionally overwriting any existing data.
        """
        self._jdf.insertInto(tableName, overwrite)

    def _java_save_mode(self, mode):
        """Returns the Java save mode based on the Python save mode represented by a string.
        """
        jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode
        jmode = jSaveMode.ErrorIfExists
        mode = mode.lower()
        if mode == "append":
            jmode = jSaveMode.Append
        elif mode == "overwrite":
            jmode = jSaveMode.Overwrite
        elif mode == "ignore":
            jmode = jSaveMode.Ignore
        elif mode == "error":
            pass
        else:
            raise ValueError(
                "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save modes.")
        return jmode

    def saveAsTable(self, tableName, source=None, mode="error", **options):
        """Saves the contents of this :class:`DataFrame` to a data source as a table.

        The data source is specified by the ``source`` and a set of ``options``.
        If ``source`` is not specified, the default data source configured by
        ``spark.sql.sources.default`` will be used.

        Additionally, ``mode`` is used to specify the behavior of the saveAsTable operation
        when the table already exists in the data source. There are four modes:

        * `append`: Append contents of this :class:`DataFrame` to existing data.
        * `overwrite`: Overwrite existing data.
        * `error`: Throw an exception if data already exists.
        * `ignore`: Silently ignore this operation if data already exists.
        """
        if source is None:
            source = self.sql_ctx.getConf("spark.sql.sources.default",
                                          "org.apache.spark.sql.parquet")
        jmode = self._java_save_mode(mode)
        joptions = MapConverter().convert(options,
                                          self.sql_ctx._sc._gateway._gateway_client)
        self._jdf.saveAsTable(tableName, source, jmode, joptions)

    def save(self, path=None, source=None, mode="error", **options):
        """Saves the contents of the :class:`DataFrame` to a data source.

        The data source is specified by the ``source`` and a set of ``options``.
        If ``source`` is not specified, the default data source configured by
        ``spark.sql.sources.default`` will be used.

        Additionally, ``mode`` is used to specify the behavior of the save operation
        when data already exists in the data source. There are four modes:

        * `append`: Append contents of this :class:`DataFrame` to existing data.
        * `overwrite`: Overwrite existing data.
        * `error`: Throw an exception if data already exists.
        * `ignore`: Silently ignore this operation if data already exists.
        """
        if path is not None:
            options["path"] = path
        if source is None:
            source = self.sql_ctx.getConf("spark.sql.sources.default",
                                          "org.apache.spark.sql.parquet")
        jmode = self._java_save_mode(mode)
        joptions = MapConverter().convert(options,
                                          self._sc._gateway._gateway_client)
        self._jdf.save(source, jmode, joptions)

    @property
    def schema(self):
        """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`.

        >>> df.schema
        StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
        """
        if self._schema is None:
            self._schema = _parse_datatype_json_string(self._jdf.schema().json())
        return self._schema

    def printSchema(self):
        """Prints out the schema in the tree format.

        >>> df.printSchema()
        root
         |-- age: integer (nullable = true)
         |-- name: string (nullable = true)
        """
        print (self._jdf.schema().treeString())

    def explain(self, extended=False):
        """Prints the (logical and physical) plans to the console for debugging purpose.

        :param extended: boolean, default ``False``. If ``False``, prints only the physical plan.

        >>> df.explain()
        PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:...

        >>> df.explain(True)
        == Parsed Logical Plan ==
        ...
        == Analyzed Logical Plan ==
        ...
        == Optimized Logical Plan ==
        ...
        == Physical Plan ==
        ...
        == RDD ==
        """
        if extended:
            print self._jdf.queryExecution().toString()
        else:
            print self._jdf.queryExecution().executedPlan().toString()

    def isLocal(self):
        """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
        (without any Spark executors).
        """
        return self._jdf.isLocal()

    def show(self, n=20):
        """Prints the first ``n`` rows to the console.

        >>> df
        DataFrame[age: int, name: string]
        >>> df.show()
        age name
        2   Alice
        5   Bob
        """
        print self._jdf.showString(n).encode('utf8', 'ignore')

    def __repr__(self):
        return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))

    def count(self):
        """Returns the number of rows in this :class:`DataFrame`.

        >>> df.count()
        2L
        """
        return self._jdf.count()

    def collect(self):
        """Returns all the records as a list of :class:`Row`.

        >>> df.collect()
        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
        """
        with SCCallSiteSync(self._sc) as css:
            port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
        rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
        cls = _create_cls(self.schema)
        return [cls(r) for r in rs]

    def limit(self, num):
        """Limits the result count to the number specified.

        >>> df.limit(1).collect()
        [Row(age=2, name=u'Alice')]
        >>> df.limit(0).collect()
        []
        """
        jdf = self._jdf.limit(num)
        return DataFrame(jdf, self.sql_ctx)

    def take(self, num):
        """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.

        >>> df.take(2)
        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
        """
        return self.limit(num).collect()

    def map(self, f):
        """ Returns a new :class:`RDD` by applying the ``f`` function to each :class:`Row`.

        This is a shorthand for ``df.rdd.map()``.

        >>> df.map(lambda p: p.name).collect()
        [u'Alice', u'Bob']
        """
        return self.rdd.map(f)

    def flatMap(self, f):
        """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`,
        and then flattening the results.

        This is a shorthand for ``df.rdd.flatMap()``.

        >>> df.flatMap(lambda p: p.name).collect()
        [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b']
        """
        return self.rdd.flatMap(f)

    def mapPartitions(self, f, preservesPartitioning=False):
        """Returns a new :class:`RDD` by applying the ``f`` function to each partition.
        This is a shorthand for ``df.rdd.mapPartitions()``.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
        >>> def f(iterator): yield 1
        >>> rdd.mapPartitions(f).sum()
        4
        """
        return self.rdd.mapPartitions(f, preservesPartitioning)

    def foreach(self, f):
        """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`.

        This is a shorthand for ``df.rdd.foreach()``.

        >>> def f(person):
        ...     print person.name
        >>> df.foreach(f)
        """
        return self.rdd.foreach(f)

    def foreachPartition(self, f):
        """Applies the ``f`` function to each partition of this :class:`DataFrame`.

        This is a shorthand for ``df.rdd.foreachPartition()``.

        >>> def f(people):
        ...     for person in people:
        ...         print person.name
        >>> df.foreachPartition(f)
        """
        return self.rdd.foreachPartition(f)

    def cache(self):
        """ Persists with the default storage level (C{MEMORY_ONLY_SER}).
        """
        self.is_cached = True
        self._jdf.cache()
        return self

    def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
        """Sets the storage level to persist its values across operations
        after the first time it is computed. This can only be used to assign
        a new storage level if the RDD does not have a storage level set yet.
        If no storage level is specified, defaults to (C{MEMORY_ONLY_SER}).
        """
        self.is_cached = True
        javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
        self._jdf.persist(javaStorageLevel)
        return self

    def unpersist(self, blocking=True):
        """Marks the :class:`DataFrame` as non-persistent, and removes all blocks for it from
        memory and disk.
        """
        self.is_cached = False
        self._jdf.unpersist(blocking)
        return self

    # def coalesce(self, numPartitions, shuffle=False):
    #     rdd = self._jdf.coalesce(numPartitions, shuffle, None)
    #     return DataFrame(rdd, self.sql_ctx)

    def repartition(self, numPartitions):
        """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions.

        >>> df.repartition(10).rdd.getNumPartitions()
        10
        """
        return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)

    def distinct(self):
        """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.

        >>> df.distinct().count()
        2L
        """
        return DataFrame(self._jdf.distinct(), self.sql_ctx)

    def sample(self, withReplacement, fraction, seed=None):
        """Returns a sampled subset of this :class:`DataFrame`.

        >>> df.sample(False, 0.5, 97).count()
        1L
        """
        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
        seed = seed if seed is not None else random.randint(0, sys.maxint)
        rdd = self._jdf.sample(withReplacement, fraction, long(seed))
        return DataFrame(rdd, self.sql_ctx)

    @property
    def dtypes(self):
        """Returns all column names and their data types as a list.

        >>> df.dtypes
        [('age', 'int'), ('name', 'string')]
        """
        return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]

    @property
    def columns(self):
        """Returns all column names as a list.

        >>> df.columns
        [u'age', u'name']
        """
        return [f.name for f in self.schema.fields]

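    # Hedged illustration (comment only) of the join types accepted by `join`
    # below; `people` and `department` are hypothetical DataFrames borrowed from
    # the class docstring, not defined in this module:
    #
    #   people.join(department, people.deptId == department.id)                # inner (default)
    #   people.join(department, people.deptId == department.id, "outer")       # full outer
    #   people.join(department, people.deptId == department.id, "left_outer")  # left outer
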
    def join(self, other, joinExprs=None, joinType=None):
        """Joins with another :class:`DataFrame`, using the given join expression.

        The following performs a full outer join between ``df`` and ``df2``.

        :param other: Right side of the join
        :param joinExprs: Join expression
        :param joinType: str, default 'inner'.
            One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.

        >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
        [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
        """
        if joinExprs is None:
            jdf = self._jdf.join(other._jdf)
        else:
            assert isinstance(joinExprs, Column), "joinExprs should be Column"
            if joinType is None:
                jdf = self._jdf.join(other._jdf, joinExprs._jc)
            else:
                assert isinstance(joinType, basestring), "joinType should be basestring"
                jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
        return DataFrame(jdf, self.sql_ctx)

    def sort(self, *cols):
        """Returns a new :class:`DataFrame` sorted by the specified column(s).

        :param cols: list of :class:`Column` to sort by.

        >>> df.sort(df.age.desc()).collect()
        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
        >>> df.orderBy(df.age.desc()).collect()
        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
        >>> from pyspark.sql.functions import *
        >>> df.sort(asc("age")).collect()
        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
        >>> df.orderBy(desc("age"), "name").collect()
        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
        """
        if not cols:
            raise ValueError("should sort by at least one column")
        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                        self._sc._gateway._gateway_client)
        jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
        return DataFrame(jdf, self.sql_ctx)

    orderBy = sort

    def describe(self, *cols):
        """Computes statistics for numeric columns.

        This includes count, mean, stddev, min, and max. If no columns are given,
        this function computes statistics for all numerical columns.

        >>> df.describe().show()
        summary age
        count   2
        mean    3.5
        stddev  1.5
        min     2
        max     5
        """
        cols = ListConverter().convert(cols,
                                       self.sql_ctx._sc._gateway._gateway_client)
        jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
        return DataFrame(jdf, self.sql_ctx)

    def head(self, n=None):
        """ Returns the first ``n`` rows as a list of :class:`Row`,
        or the first :class:`Row` if ``n`` is ``None``.

        >>> df.head()
        Row(age=2, name=u'Alice')
        >>> df.head(1)
        [Row(age=2, name=u'Alice')]
        """
        if n is None:
            rs = self.head(1)
            return rs[0] if rs else None
        return self.take(n)

    def first(self):
        """Returns the first row as a :class:`Row`.

        >>> df.first()
        Row(age=2, name=u'Alice')
        """
        return self.head()

    def __getitem__(self, item):
        """Returns the column as a :class:`Column`.

        >>> df.select(df['age']).collect()
        [Row(age=2), Row(age=5)]
        >>> df[ ["name", "age"]].collect()
        [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
        >>> df[ df.age > 3 ].collect()
        [Row(age=5, name=u'Bob')]
        """
        if isinstance(item, basestring):
            jc = self._jdf.apply(item)
            return Column(jc)
        elif isinstance(item, Column):
            return self.filter(item)
        elif isinstance(item, list):
            return self.select(*item)
        else:
            raise IndexError("unexpected index: %s" % item)

    def __getattr__(self, name):
        """Returns the :class:`Column` denoted by ``name``.

        >>> df.select(df.age).collect()
        [Row(age=2), Row(age=5)]
        """
        if name.startswith("__"):
            raise AttributeError(name)
        jc = self._jdf.apply(name)
        return Column(jc)

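    # Illustrative note: attribute access (`df.age`, via `__getattr__` above) only
    # resolves names that do not start with "__" and do not shadow a real
    # attribute or method of DataFrame; item access (`df["age"]`) works for any
    # column name. For example, a hypothetical column named "count" would have to
    # be referenced as df["count"], since df.count resolves to the method.
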
    def select(self, *cols):
        """Projects a set of expressions and returns a new :class:`DataFrame`.

        :param cols: list of column names (string) or expressions (:class:`Column`).
            If one of the column names is '*', that column is expanded to include all columns
            in the current DataFrame.

        >>> df.select('*').collect()
        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
        >>> df.select('name', 'age').collect()
        [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
        >>> df.select(df.name, (df.age + 10).alias('age')).collect()
        [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
        """
        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                        self._sc._gateway._gateway_client)
        jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
        return DataFrame(jdf, self.sql_ctx)

    def selectExpr(self, *expr):
        """Projects a set of SQL expressions and returns a new :class:`DataFrame`.

        This is a variant of :func:`select` that accepts SQL expressions.

        >>> df.selectExpr("age * 2", "abs(age)").collect()
        [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
        """
        jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
        jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
        return DataFrame(jdf, self.sql_ctx)

    def filter(self, condition):
        """Filters rows using the given condition.

        :func:`where` is an alias for :func:`filter`.

        :param condition: a :class:`Column` of :class:`types.BooleanType`
            or a string of SQL expression.

        >>> df.filter(df.age > 3).collect()
        [Row(age=5, name=u'Bob')]
        >>> df.where(df.age == 2).collect()
        [Row(age=2, name=u'Alice')]

        >>> df.filter("age > 3").collect()
        [Row(age=5, name=u'Bob')]
        >>> df.where("age = 2").collect()
        [Row(age=2, name=u'Alice')]
        """
        if isinstance(condition, basestring):
            jdf = self._jdf.filter(condition)
        elif isinstance(condition, Column):
            jdf = self._jdf.filter(condition._jc)
        else:
            raise TypeError("condition should be string or Column")
        return DataFrame(jdf, self.sql_ctx)

    where = filter

    def groupBy(self, *cols):
        """Groups the :class:`DataFrame` using the specified columns,
        so we can run aggregation on them. See :class:`GroupedData`
        for all the available aggregate functions.

        :param cols: list of columns to group by.
            Each element should be a column name (string) or an expression (:class:`Column`).

        >>> df.groupBy().avg().collect()
        [Row(AVG(age)=3.5)]
        >>> df.groupBy('name').agg({'age': 'mean'}).collect()
        [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
        >>> df.groupBy(df.name).avg().collect()
        [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
        """
        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                        self._sc._gateway._gateway_client)
        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
        return GroupedData(jdf, self.sql_ctx)

    def agg(self, *exprs):
        """ Aggregate on the entire :class:`DataFrame` without groups
        (shorthand for ``df.groupBy.agg()``).

        >>> df.agg({"age": "max"}).collect()
        [Row(MAX(age)=5)]
        >>> from pyspark.sql import functions as F
        >>> df.agg(F.min(df.age)).collect()
        [Row(MIN(age)=2)]
        """
        return self.groupBy().agg(*exprs)

    def unionAll(self, other):
        """ Return a new :class:`DataFrame` containing union of rows in this
        frame and another frame.

        This is equivalent to `UNION ALL` in SQL.
        """
        return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)

    def intersect(self, other):
        """ Return a new :class:`DataFrame` containing rows only in
        both this frame and another frame.

        This is equivalent to `INTERSECT` in SQL.
        """
        return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)

    def subtract(self, other):
        """ Return a new :class:`DataFrame` containing rows in this frame
        but not in another frame.

        This is equivalent to `EXCEPT` in SQL.
        """
        return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)

    def dropna(self, how='any', thresh=None, subset=None):
        """Returns a new :class:`DataFrame` omitting rows with null values.

        This is an alias for ``na.drop()``.

        :param how: 'any' or 'all'.
            If 'any', drop a row if it contains any nulls.
            If 'all', drop a row only if all its values are null.
        :param thresh: int, default None
            If specified, drop rows that have less than `thresh` non-null values.
            This overwrites the `how` parameter.
        :param subset: optional list of column names to consider.

        >>> df4.dropna().show()
        age height name
        10  80     Alice

        >>> df4.na.drop().show()
        age height name
        10  80     Alice
        """
        if how is not None and how not in ['any', 'all']:
            raise ValueError("how ('" + how + "') should be 'any' or 'all'")

        if subset is None:
            subset = self.columns
        elif isinstance(subset, basestring):
            subset = [subset]
        elif not isinstance(subset, (list, tuple)):
            raise ValueError("subset should be a list or tuple of column names")

        if thresh is None:
            thresh = len(subset) if how == 'any' else 1

        cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
        cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
        return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx)

    def fillna(self, value, subset=None):
        """Replace null values, alias for ``na.fill()``.

        :param value: int, long, float, string, or dict.
            Value to replace null values with.
            If the value is a dict, then `subset` is ignored and `value` must be a mapping
            from column name (string) to replacement value. The replacement value must be
            an int, long, float, or string.
        :param subset: optional list of column names to consider.
            Columns specified in subset that do not have matching data type are ignored.
            For example, if `value` is a string, and subset contains a non-string column,
            then the non-string column is simply ignored.

        >>> df4.fillna(50).show()
        age height name
        10  80     Alice
        5   50     Bob
        50  50     Tom
        50  50     null

        >>> df4.fillna({'age': 50, 'name': 'unknown'}).show()
        age height name
        10  80     Alice
        5   null   Bob
        50  null   Tom
        50  null   unknown

        >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
        age height name
        10  80     Alice
        5   null   Bob
        50  null   Tom
        50  null   unknown
        """
        if not isinstance(value, (float, int, long, basestring, dict)):
            raise ValueError("value should be a float, int, long, string, or dict")

        if isinstance(value, (int, long)):
            value = float(value)

        if isinstance(value, dict):
            value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client)
            return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
        elif subset is None:
            return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
        else:
            if isinstance(subset, basestring):
                subset = [subset]
            elif not isinstance(subset, (list, tuple)):
                raise ValueError("subset should be a list or tuple of column names")

            cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
            cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
            return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)

    def withColumn(self, colName, col):
        """Returns a new :class:`DataFrame` by adding a column.

        :param colName: string, name of the new column.
        :param col: a :class:`Column` expression for the new column.

        >>> df.withColumn('age2', df.age + 2).collect()
        [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
        """
        return self.select('*', col.alias(colName))

    def withColumnRenamed(self, existing, new):
        """Returns a new :class:`DataFrame` by renaming an existing column.
        :param existing: string, name of the existing column to rename.
        :param new: string, new name of the column.

        >>> df.withColumnRenamed('age', 'age2').collect()
        [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
        """
        cols = [Column(_to_java_column(c)).alias(new)
                if c == existing else c
                for c in self.columns]
        return self.select(*cols)

    def toPandas(self):
        """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.

        This is only available if Pandas is installed and available.

        >>> df.toPandas()  # doctest: +SKIP
           age   name
        0    2  Alice
        1    5    Bob
        """
        import pandas as pd
        return pd.DataFrame.from_records(self.collect(), columns=self.columns)


# Having SchemaRDD for backward compatibility (for docs)
class SchemaRDD(DataFrame):
    """SchemaRDD is deprecated, please use :class:`DataFrame`.
    """


def dfapi(f):
    def _api(self):
        name = f.__name__
        jdf = getattr(self._jdf, name)()
        return DataFrame(jdf, self.sql_ctx)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api


def df_varargs_api(f):
    def _api(self, *args):
        jargs = ListConverter().convert(args,
                                        self.sql_ctx._sc._gateway._gateway_client)
        name = f.__name__
        jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
        return DataFrame(jdf, self.sql_ctx)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api

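# Illustrative note: the two decorators above turn the empty-bodied methods of
# GroupedData below into thin forwarders to the underlying Java object
# (self._jdf). For example, wrapping `count` with `dfapi` makes `gdf.count()`
# call `gdf._jdf.count()` and wrap the result in a new DataFrame, while
# `df_varargs_api` does the same after converting the column-name varargs
# into a Java Seq.
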
class GroupedData(object):
    """
    A set of methods for aggregations on a :class:`DataFrame`,
    created by :func:`DataFrame.groupBy`.
    """

    def __init__(self, jdf, sql_ctx):
        self._jdf = jdf
        self.sql_ctx = sql_ctx

    def agg(self, *exprs):
        """Computes aggregates and returns the result as a :class:`DataFrame`.

        The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.

        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
        is the column to perform aggregation on, and the value is the aggregate function.

        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.

        :param exprs: a dict mapping from column name (string) to aggregate functions (string),
            or a list of :class:`Column`.

        >>> gdf = df.groupBy(df.name)
        >>> gdf.agg({"*": "count"}).collect()
        [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]

        >>> from pyspark.sql import functions as F
        >>> gdf.agg(F.min(df.age)).collect()
        [Row(MIN(age)=2), Row(MIN(age)=5)]
        """
        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):
            jmap = MapConverter().convert(exprs[0],
                                          self.sql_ctx._sc._gateway._gateway_client)
            jdf = self._jdf.agg(jmap)
        else:
            # Columns
            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
            jcols = ListConverter().convert([c._jc for c in exprs[1:]],
                                            self.sql_ctx._sc._gateway._gateway_client)
            jdf = self._jdf.agg(exprs[0]._jc,
                                self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
        return DataFrame(jdf, self.sql_ctx)

    @dfapi
    def count(self):
        """Counts the number of records for each group.

        >>> df.groupBy(df.age).count().collect()
        [Row(age=2, count=1), Row(age=5, count=1)]
        """

    @df_varargs_api
    def mean(self, *cols):
        """Computes average values for each numeric column for each group.

        :func:`mean` is an alias for :func:`avg`.

        :param cols: list of column names (string).
            Non-numeric columns are ignored.

        >>> df.groupBy().mean('age').collect()
        [Row(AVG(age)=3.5)]
        >>> df3.groupBy().mean('age', 'height').collect()
        [Row(AVG(age)=3.5, AVG(height)=82.5)]
        """

    @df_varargs_api
    def avg(self, *cols):
        """Computes average values for each numeric column for each group.

        :func:`mean` is an alias for :func:`avg`.

        :param cols: list of column names (string).
            Non-numeric columns are ignored.

        >>> df.groupBy().avg('age').collect()
        [Row(AVG(age)=3.5)]
        >>> df3.groupBy().avg('age', 'height').collect()
        [Row(AVG(age)=3.5, AVG(height)=82.5)]
        """

    @df_varargs_api
    def max(self, *cols):
        """Computes the max value for each numeric column for each group.

        >>> df.groupBy().max('age').collect()
        [Row(MAX(age)=5)]
        >>> df3.groupBy().max('age', 'height').collect()
        [Row(MAX(age)=5, MAX(height)=85)]
        """

    @df_varargs_api
    def min(self, *cols):
        """Computes the min value for each numeric column for each group.

        :param cols: list of column names (string).
            Non-numeric columns are ignored.

        >>> df.groupBy().min('age').collect()
        [Row(MIN(age)=2)]
        >>> df3.groupBy().min('age', 'height').collect()
        [Row(MIN(age)=2, MIN(height)=80)]
        """

    @df_varargs_api
    def sum(self, *cols):
        """Computes the sum for each numeric column for each group.

        :param cols: list of column names (string).
            Non-numeric columns are ignored.

        >>> df.groupBy().sum('age').collect()
        [Row(SUM(age)=7)]
        >>> df3.groupBy().sum('age', 'height').collect()
        [Row(SUM(age)=7, SUM(height)=165)]
        """


def _create_column_from_literal(literal):
    sc = SparkContext._active_spark_context
    return sc._jvm.functions.lit(literal)


def _create_column_from_name(name):
    sc = SparkContext._active_spark_context
    return sc._jvm.functions.col(name)


def _to_java_column(col):
    if isinstance(col, Column):
        jcol = col._jc
    else:
        jcol = _create_column_from_name(col)
    return jcol


def _unary_op(name, doc="unary operator"):
    """ Create a method for given unary operator """
    def _(self):
        jc = getattr(self._jc, name)()
        return Column(jc)
    _.__doc__ = doc
    return _


def _func_op(name, doc=''):
    def _(self):
        sc = SparkContext._active_spark_context
        jc = getattr(sc._jvm.functions, name)(self._jc)
        return Column(jc)
    _.__doc__ = doc
    return _


def _bin_op(name, doc="binary operator"):
    """ Create a method for given binary operator """
    def _(self, other):
        jc = other._jc if isinstance(other, Column) else other
        njc = getattr(self._jc, name)(jc)
        return Column(njc)
    _.__doc__ = doc
    return _


def _reverse_op(name, doc="binary operator"):
    """ Create a method for binary operator (this object is on right side) """
    def _(self, other):
        jother = _create_column_from_literal(other)
        jc = getattr(jother, name)(self._jc)
        return Column(jc)
    _.__doc__ = doc
    return _

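# Illustrative note: the factory helpers above are used by the Column class
# below to generate operator methods that delegate to the wrapped Java column.
# For example, `__add__ = _bin_op("plus")` means `df.age + 1` calls
# `df.age._jc.plus(1)` and wraps the result in a new Column, while
# `__rsub__ = _reverse_op("minus")` means `1 - df.age` first builds a literal
# column for 1 and then calls `minus` on it.
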
class Column(object):
    """
    A column in a DataFrame.

    :class:`Column` instances can be created by::

        # 1. Select a column out of a DataFrame
        df.colName
        df["colName"]

        # 2. Create from an expression
        df.colName + 1
        1 / df.colName
    """

    def __init__(self, jc):
        self._jc = jc

    # arithmetic operators
    __neg__ = _func_op("negate")
    __add__ = _bin_op("plus")
    __sub__ = _bin_op("minus")
    __mul__ = _bin_op("multiply")
    __div__ = _bin_op("divide")
    __mod__ = _bin_op("mod")
    __radd__ = _bin_op("plus")
    __rsub__ = _reverse_op("minus")
    __rmul__ = _bin_op("multiply")
    __rdiv__ = _reverse_op("divide")
    __rmod__ = _reverse_op("mod")

    # comparison operators
    __eq__ = _bin_op("equalTo")
    __ne__ = _bin_op("notEqual")
    __lt__ = _bin_op("lt")
    __le__ = _bin_op("leq")
    __ge__ = _bin_op("geq")
    __gt__ = _bin_op("gt")

    # `and`, `or`, `not` cannot be overloaded in Python,
    # so use bitwise operators as boolean operators
    __and__ = _bin_op('and')
    __or__ = _bin_op('or')
    __invert__ = _func_op('not')
    __rand__ = _bin_op("and")
    __ror__ = _bin_op("or")

    # container operators
    __contains__ = _bin_op("contains")
    __getitem__ = _bin_op("getItem")
    getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")

    # string methods
    rlike = _bin_op("rlike")
    like = _bin_op("like")
    startswith = _bin_op("startsWith")
    endswith = _bin_op("endsWith")

    def substr(self, startPos, length):
        """
        Return a :class:`Column` which is a substring of the column.

        :param startPos: start position (int or Column)
        :param length: length of the substring (int or Column)

        >>> df.select(df.name.substr(1, 3).alias("col")).collect()
        [Row(col=u'Ali'), Row(col=u'Bob')]
        """
        if type(startPos) != type(length):
            raise TypeError("Can not mix the type")
        if isinstance(startPos, (int, long)):
            jc = self._jc.substr(startPos, length)
        elif isinstance(startPos, Column):
            jc = self._jc.substr(startPos._jc, length._jc)
        else:
            raise TypeError("Unexpected type: %s" % type(startPos))
        return Column(jc)

    __getslice__ = substr

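    # Hedged example (comment only): because `and`/`or`/`not` cannot be
    # overloaded (see the note above the operator table), composite filters
    # combine boolean Columns with the bitwise operators, parenthesizing each
    # comparison:
    #
    #   df.filter((df.age > 3) & (df.name == 'Bob'))
    #   df.filter(~(df.age > 3) | df.name.isNull())
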
    def inSet(self, *cols):
        """ A boolean expression that is evaluated to true if the value of this
        expression is contained by the evaluated values of the arguments.

        >>> df[df.name.inSet("Bob", "Mike")].collect()
        [Row(age=5, name=u'Bob')]
        >>> df[df.age.inSet([1, 2, 3])].collect()
        [Row(age=2, name=u'Alice')]
        """
        if len(cols) == 1 and isinstance(cols[0], (list, set)):
            cols = cols[0]
        cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
        sc = SparkContext._active_spark_context
        jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
        jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols))
        return Column(jc)

    # order
    asc = _unary_op("asc", "Returns a sort expression based on the"
                           " ascending order of the given column name.")
    desc = _unary_op("desc", "Returns a sort expression based on the"
                             " descending order of the given column name.")

    isNull = _unary_op("isNull", "True if the current expression is null.")
    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")

    def alias(self, alias):
        """Returns an alias for this column.

        >>> df.select(df.age.alias("age2")).collect()
        [Row(age2=2), Row(age2=5)]
        """
        return Column(getattr(self._jc, "as")(alias))

    def cast(self, dataType):
        """ Converts the column into type ``dataType``.

        >>> df.select(df.age.cast("string").alias('ages')).collect()
        [Row(ages=u'2'), Row(ages=u'5')]
        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
        [Row(ages=u'2'), Row(ages=u'5')]
        """
        if isinstance(dataType, basestring):
            jc = self._jc.cast(dataType)
        elif isinstance(dataType, DataType):
            sc = SparkContext._active_spark_context
            ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
            jdt = ssql_ctx.parseDataType(dataType.json())
            jc = self._jc.cast(jdt)
        else:
            raise TypeError("unexpected type: %s" % type(dataType))
        return Column(jc)

    def __repr__(self):
        return 'Column<%s>' % self._jc.toString().encode('utf8')


class DataFrameNaFunctions(object):
    """Functionality for working with missing data in :class:`DataFrame`.
    """

    def __init__(self, df):
        self.df = df

    def drop(self, how='any', thresh=None, subset=None):
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    drop.__doc__ = DataFrame.dropna.__doc__

    def fill(self, value, subset=None):
        return self.df.fillna(value=value, subset=subset)

    fill.__doc__ = DataFrame.fillna.__doc__


def _test():
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import Row, SQLContext
    import pyspark.sql.dataframe
    globs = pyspark.sql.dataframe.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    globs['sqlContext'] = SQLContext(sc)
    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\
        .toDF(StructType([StructField('age', IntegerType()),
                          StructField('name', StringType())]))
    globs['df2'] = sc.parallelize([Row(name='Tom', height=80),
                                   Row(name='Bob', height=85)]).toDF()
    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                   Row(name='Bob', age=5, height=85)]).toDF()
    globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
                                   Row(name='Bob', age=5, height=None),
                                   Row(name='Tom', age=None, height=None),
                                   Row(name=None, age=None, height=None)]).toDF()
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.dataframe, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()