From 6481d27425f6d42ead36663c9a4ef7ee13b3a8c9 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Sep 2014 11:49:45 -0700 Subject: [SPARK-3309] [PySpark] Put all public API in __all__ Put all public API in __all__, also put them all in pyspark.__init__.py, then we can got all the documents for public API by `pydoc pyspark`. It also can be used by other programs (such as Sphinx or Epydoc) to generate only documents for public APIs. Author: Davies Liu Closes #2205 from davies/public and squashes the following commits: c6c5567 [Davies Liu] fix message f7b35be [Davies Liu] put SchemeRDD, Row in pyspark.sql module 7e3016a [Davies Liu] add __all__ in mllib 6281b48 [Davies Liu] fix doc for SchemaRDD 6caab21 [Davies Liu] add public interfaces into pyspark.__init__.py --- python/pyspark/__init__.py | 14 +++++++++----- python/pyspark/accumulators.py | 3 +++ python/pyspark/broadcast.py | 24 ++++++++++++++---------- python/pyspark/conf.py | 2 ++ python/pyspark/context.py | 3 +++ python/pyspark/files.py | 3 +++ python/pyspark/mllib/classification.py | 4 ++++ python/pyspark/mllib/clustering.py | 2 ++ python/pyspark/mllib/linalg.py | 3 +++ python/pyspark/mllib/random.py | 3 +++ python/pyspark/mllib/recommendation.py | 2 ++ python/pyspark/mllib/regression.py | 10 +++++----- python/pyspark/mllib/stat.py | 6 ++++-- python/pyspark/mllib/tree.py | 4 ++++ python/pyspark/rdd.py | 1 + python/pyspark/serializers.py | 2 +- python/pyspark/sql.py | 21 ++++++++++++++++++--- 17 files changed, 81 insertions(+), 26 deletions(-) (limited to 'python/pyspark') diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index c58555fc9d..1a2e774738 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -61,13 +61,17 @@ sys.path.insert(0, s) from pyspark.conf import SparkConf from pyspark.context import SparkContext -from pyspark.sql import SQLContext from pyspark.rdd import RDD -from pyspark.sql import SchemaRDD -from pyspark.sql import Row from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel +from pyspark.accumulators import Accumulator, AccumulatorParam +from pyspark.broadcast import Broadcast +from pyspark.serializers import MarshalSerializer, PickleSerializer +# for back compatibility +from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row -__all__ = ["SparkConf", "SparkContext", "SQLContext", "RDD", "SchemaRDD", - "SparkFiles", "StorageLevel", "Row"] +__all__ = [ + "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", + "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", +] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index f133cf6f7b..ccbca67656 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -94,6 +94,9 @@ from pyspark.cloudpickle import CloudPickler from pyspark.serializers import read_int, PickleSerializer +__all__ = ['Accumulator', 'AccumulatorParam'] + + pickleSer = PickleSerializer() # Holds accumulators registered on the current machine, keyed by ID. This is then used to send diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 675a2fcd2f..5c7c9cc161 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -31,6 +31,10 @@ import os from pyspark.serializers import CompressedSerializer, PickleSerializer + +__all__ = ['Broadcast'] + + # Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} @@ -59,11 +63,20 @@ class Broadcast(object): """ self.bid = bid if path is None: - self.value = value + self._value = value self._jbroadcast = java_broadcast self._pickle_registry = pickle_registry self.path = path + @property + def value(self): + """ Return the broadcasted value + """ + if not hasattr(self, "_value") and self.path is not None: + ser = CompressedSerializer(PickleSerializer()) + self._value = ser.load_stream(open(self.path)).next() + return self._value + def unpersist(self, blocking=False): self._jbroadcast.unpersist(blocking) os.unlink(self.path) @@ -72,15 +85,6 @@ class Broadcast(object): self._pickle_registry.add(self) return (_from_id, (self.bid, )) - def __getattr__(self, item): - if item == 'value' and self.path is not None: - ser = CompressedSerializer(PickleSerializer()) - value = ser.load_stream(open(self.path)).next() - self.value = value - return value - - raise AttributeError(item) - if __name__ == "__main__": import doctest diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index fb716f6753..b64875a3f4 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -54,6 +54,8 @@ spark.home=/path (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] """ +__all__ = ['SparkConf'] + class SparkConf(object): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6e4fdaa6ee..5a30431568 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -37,6 +37,9 @@ from pyspark.rdd import RDD from py4j.java_collections import ListConverter +__all__ = ['SparkContext'] + + # These are special default configs for PySpark, they will overwrite # the default ones for Spark if they are not configured by user. DEFAULT_CONFIGS = { diff --git a/python/pyspark/files.py b/python/pyspark/files.py index 331de9a9b2..797573f49d 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -18,6 +18,9 @@ import os +__all__ = ['SparkFiles'] + + class SparkFiles(object): """ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index ffdda7ee19..71ab46b61d 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -30,6 +30,10 @@ from pyspark.mllib.regression import LabeledPoint, LinearModel from math import exp, log +__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel', + 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] + + class LogisticRegressionModel(LinearModel): """A linear binary classification model derived from logistic regression. diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index a0630d1d5c..f3e952a1d8 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -25,6 +25,8 @@ from pyspark.mllib._common import \ _get_initial_weights, _serialize_rating, _regression_train_wrapper from pyspark.mllib.linalg import SparseVector +__all__ = ['KMeansModel', 'KMeans'] + class KMeansModel(object): diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index f485a69db1..e69051c104 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -27,6 +27,9 @@ import numpy from numpy import array, array_equal, ndarray, float64, int32 +__all__ = ['SparseVector', 'Vectors'] + + class SparseVector(object): """ diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 4dc1a4a912..3e59c73db8 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -25,6 +25,9 @@ from pyspark.mllib._common import _deserialize_double, _deserialize_double_vecto from pyspark.serializers import NoOpSerializer +__all__ = ['RandomRDDs', ] + + class RandomRDDs: """ Generator methods for creating RDDs comprised of i.i.d samples from diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index e863fc249e..2df23394da 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -24,6 +24,8 @@ from pyspark.mllib._common import \ _serialize_tuple, RatingDeserializer from pyspark.rdd import RDD +__all__ = ['MatrixFactorizationModel', 'ALS'] + class MatrixFactorizationModel(object): diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index d8792cf448..f572dcfb84 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -17,15 +17,15 @@ from numpy import array, ndarray from pyspark import SparkContext -from pyspark.mllib._common import \ - _dot, _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ - _serialize_double_matrix, _deserialize_double_matrix, \ - _serialize_double_vector, _deserialize_double_vector, \ - _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ +from pyspark.mllib._common import _dot, _regression_train_wrapper, \ _linear_predictor_typecheck, _have_scipy, _scipy_issparse from pyspark.mllib.linalg import SparseVector, Vectors +__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel' + 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] + + class LabeledPoint(object): """ diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index feef0d16cd..8c726f171c 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -21,8 +21,10 @@ Python package for statistical functions in MLlib. from pyspark.mllib._common import \ _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ - _serialize_double, _serialize_double_vector, \ - _deserialize_double, _deserialize_double_matrix, _deserialize_double_vector + _serialize_double, _deserialize_double_matrix, _deserialize_double_vector + + +__all__ = ['MultivariateStatisticalSummary', 'Statistics'] class MultivariateStatisticalSummary(object): diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index e9d778df5a..a2fade61e9 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -26,6 +26,9 @@ from pyspark.mllib.regression import LabeledPoint from pyspark.serializers import NoOpSerializer +__all__ = ['DecisionTreeModel', 'DecisionTree'] + + class DecisionTreeModel(object): """ @@ -88,6 +91,7 @@ class DecisionTree(object): It will probably be modified for Spark v1.2. Example usage: + >>> from numpy import array >>> import sys >>> from pyspark.mllib.regression import LabeledPoint diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6fc9f66bc5..dff6fc26fc 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -48,6 +48,7 @@ from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ from py4j.java_collections import ListConverter, MapConverter + __all__ = ["RDD"] diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fc49aa42db..55e6cf3308 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -409,7 +409,7 @@ class AutoSerializer(FramedSerializer): class CompressedSerializer(FramedSerializer): """ - compress the serialized data + Compress the serialized data """ def __init__(self, serializer): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 0ff6a548a8..44316926ba 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -40,8 +40,7 @@ __all__ = [ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", - "SchemaRDD", "Row"] + "SQLContext", "HiveContext", "SchemaRDD", "Row"] class DataType(object): @@ -1037,7 +1036,7 @@ class SQLContext: "can not infer schema") if type(first) is dict: warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.Row instead") + "please use pyspark.sql.Row instead") schema = _infer_schema(first) rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) @@ -1487,6 +1486,21 @@ class Row(tuple): return "" % ", ".join(self) +def inherit_doc(cls): + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls + + +@inherit_doc class SchemaRDD(RDD): """An RDD of L{Row} objects that has an associated schema. @@ -1563,6 +1577,7 @@ class SchemaRDD(RDD): self._jschema_rdd.registerTempTable(name) def registerAsTable(self, name): + """DEPRECATED: use registerTempTable() instead""" warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) self.registerTempTable(name) -- cgit v1.2.3