author     Davies Liu <davies.liu@gmail.com>    2014-09-03 11:49:45 -0700
committer  Josh Rosen <joshrosen@apache.org>    2014-09-03 11:49:45 -0700
commit     6481d27425f6d42ead36663c9a4ef7ee13b3a8c9 (patch)
tree       051c394c0735be33d4bb7f9fd90f403e9b5f2dcd /python
parent     6a72a36940311fcb3429bd34c8818bc7d513115c (diff)
[SPARK-3309] [PySpark] Put all public API in __all__
Put all public APIs in __all__ and re-export them from pyspark/__init__.py, so that `pydoc pyspark` shows documentation for the entire public API. The __all__ lists can also be used by other tools (such as Sphinx or Epydoc) to generate documentation for the public APIs only.

Author: Davies Liu <davies.liu@gmail.com>

Closes #2205 from davies/public and squashes the following commits:

c6c5567 [Davies Liu] fix message
f7b35be [Davies Liu] put SchemaRDD, Row in pyspark.sql module
7e3016a [Davies Liu] add __all__ in mllib
6281b48 [Davies Liu] fix doc for SchemaRDD
6caab21 [Davies Liu] add public interfaces into pyspark.__init__.py
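For context, `__all__` is the list that `pydoc`, `help()`, and `from module import *` consult to decide which names in a module are public, and documentation generators such as Sphinx or Epydoc can honor it as well. A minimal sketch of the effect (the module and class names below are illustrative only, not part of this patch):

# vectors.py -- hypothetical module showing what __all__ controls
"""Tiny example module."""

__all__ = ['SparseVector']  # only this name is advertised as public


class SparseVector(object):
    """A vector that stores only its non-zero entries."""


class _DenseHelper(object):
    """Internal helper; hidden from `import *` and from doc tools."""

# >>> from vectors import *   # brings in SparseVector only
# >>> help('vectors')         # documents the public surface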
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/__init__.py               14
-rw-r--r--  python/pyspark/accumulators.py            3
-rw-r--r--  python/pyspark/broadcast.py              24
-rw-r--r--  python/pyspark/conf.py                    2
-rw-r--r--  python/pyspark/context.py                 3
-rw-r--r--  python/pyspark/files.py                   3
-rw-r--r--  python/pyspark/mllib/classification.py    4
-rw-r--r--  python/pyspark/mllib/clustering.py        2
-rw-r--r--  python/pyspark/mllib/linalg.py            3
-rw-r--r--  python/pyspark/mllib/random.py            3
-rw-r--r--  python/pyspark/mllib/recommendation.py    2
-rw-r--r--  python/pyspark/mllib/regression.py       10
-rw-r--r--  python/pyspark/mllib/stat.py              6
-rw-r--r--  python/pyspark/mllib/tree.py              4
-rw-r--r--  python/pyspark/rdd.py                     1
-rw-r--r--  python/pyspark/serializers.py             2
-rw-r--r--  python/pyspark/sql.py                    21
17 files changed, 81 insertions, 26 deletions
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index c58555fc9d..1a2e774738 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -61,13 +61,17 @@ sys.path.insert(0, s)
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
-from pyspark.sql import SQLContext
from pyspark.rdd import RDD
-from pyspark.sql import SchemaRDD
-from pyspark.sql import Row
from pyspark.files import SparkFiles
from pyspark.storagelevel import StorageLevel
+from pyspark.accumulators import Accumulator, AccumulatorParam
+from pyspark.broadcast import Broadcast
+from pyspark.serializers import MarshalSerializer, PickleSerializer
+# for backward compatibility
+from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
-__all__ = ["SparkConf", "SparkContext", "SQLContext", "RDD", "SchemaRDD",
- "SparkFiles", "StorageLevel", "Row"]
+__all__ = [
+ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
+ "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
+]
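With this change the most common entry points are importable straight from the top-level `pyspark` package. A quick usage sketch (the local master and app name are assumptions for illustration, not part of the patch):

from pyspark import SparkConf, SparkContext, StorageLevel

conf = SparkConf().setMaster("local[2]").setAppName("all_demo")
sc = SparkContext(conf=conf)
rdd = sc.parallelize(range(10)).persist(StorageLevel.MEMORY_ONLY)
print(rdd.sum())  # 45
sc.stop()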
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index f133cf6f7b..ccbca67656 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -94,6 +94,9 @@ from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, PickleSerializer
+__all__ = ['Accumulator', 'AccumulatorParam']
+
+
pickleSer = PickleSerializer()
# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 675a2fcd2f..5c7c9cc161 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -31,6 +31,10 @@ import os
from pyspark.serializers import CompressedSerializer, PickleSerializer
+
+__all__ = ['Broadcast']
+
+
# Holds broadcasted data received from Java, keyed by its id.
_broadcastRegistry = {}
@@ -59,11 +63,20 @@ class Broadcast(object):
"""
self.bid = bid
if path is None:
- self.value = value
+ self._value = value
self._jbroadcast = java_broadcast
self._pickle_registry = pickle_registry
self.path = path
+ @property
+ def value(self):
+ """ Return the broadcasted value
+ """
+ if not hasattr(self, "_value") and self.path is not None:
+ ser = CompressedSerializer(PickleSerializer())
+ self._value = ser.load_stream(open(self.path)).next()
+ return self._value
+
def unpersist(self, blocking=False):
self._jbroadcast.unpersist(blocking)
os.unlink(self.path)
@@ -72,15 +85,6 @@ class Broadcast(object):
self._pickle_registry.add(self)
return (_from_id, (self.bid, ))
- def __getattr__(self, item):
- if item == 'value' and self.path is not None:
- ser = CompressedSerializer(PickleSerializer())
- value = ser.load_stream(open(self.path)).next()
- self.value = value
- return value
-
- raise AttributeError(item)
-
if __name__ == "__main__":
import doctest
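The patch replaces Broadcast.__getattr__ with an explicit read-only property: the value is still loaded lazily from disk on first access and cached in `_value`, but without the risk of `__getattr__` masking unrelated attribute lookups. The pattern in isolation, with plain pickle standing in for Spark's CompressedSerializer (class and file names here are hypothetical):

import pickle


class LazyValue(object):
    """Load a pickled value from disk the first time .value is read."""

    def __init__(self, path):
        self.path = path

    @property
    def value(self):
        # Cache on first access; subsequent reads reuse the cached object.
        if not hasattr(self, "_value"):
            with open(self.path, "rb") as f:
                self._value = pickle.load(f)
        return self._value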
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
index fb716f6753..b64875a3f4 100644
--- a/python/pyspark/conf.py
+++ b/python/pyspark/conf.py
@@ -54,6 +54,8 @@ spark.home=/path
(u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')]
"""
+__all__ = ['SparkConf']
+
class SparkConf(object):
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6e4fdaa6ee..5a30431568 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -37,6 +37,9 @@ from pyspark.rdd import RDD
from py4j.java_collections import ListConverter
+__all__ = ['SparkContext']
+
+
# These are special default configs for PySpark, they will overwrite
# the default ones for Spark if they are not configured by user.
DEFAULT_CONFIGS = {
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
index 331de9a9b2..797573f49d 100644
--- a/python/pyspark/files.py
+++ b/python/pyspark/files.py
@@ -18,6 +18,9 @@
import os
+__all__ = ['SparkFiles']
+
+
class SparkFiles(object):
"""
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index ffdda7ee19..71ab46b61d 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -30,6 +30,10 @@ from pyspark.mllib.regression import LabeledPoint, LinearModel
from math import exp, log
+__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel',
+ 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
+
+
class LogisticRegressionModel(LinearModel):
"""A linear binary classification model derived from logistic regression.
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index a0630d1d5c..f3e952a1d8 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -25,6 +25,8 @@ from pyspark.mllib._common import \
_get_initial_weights, _serialize_rating, _regression_train_wrapper
from pyspark.mllib.linalg import SparseVector
+__all__ = ['KMeansModel', 'KMeans']
+
class KMeansModel(object):
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index f485a69db1..e69051c104 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -27,6 +27,9 @@ import numpy
from numpy import array, array_equal, ndarray, float64, int32
+__all__ = ['SparseVector', 'Vectors']
+
+
class SparseVector(object):
"""
diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py
index 4dc1a4a912..3e59c73db8 100644
--- a/python/pyspark/mllib/random.py
+++ b/python/pyspark/mllib/random.py
@@ -25,6 +25,9 @@ from pyspark.mllib._common import _deserialize_double, _deserialize_double_vecto
from pyspark.serializers import NoOpSerializer
+__all__ = ['RandomRDDs', ]
+
+
class RandomRDDs:
"""
Generator methods for creating RDDs comprised of i.i.d samples from
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index e863fc249e..2df23394da 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -24,6 +24,8 @@ from pyspark.mllib._common import \
_serialize_tuple, RatingDeserializer
from pyspark.rdd import RDD
+__all__ = ['MatrixFactorizationModel', 'ALS']
+
class MatrixFactorizationModel(object):
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index d8792cf448..f572dcfb84 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -17,15 +17,15 @@
from numpy import array, ndarray
from pyspark import SparkContext
-from pyspark.mllib._common import \
- _dot, _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
- _serialize_double_matrix, _deserialize_double_matrix, \
- _serialize_double_vector, _deserialize_double_vector, \
- _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
+from pyspark.mllib._common import _dot, _regression_train_wrapper, \
_linear_predictor_typecheck, _have_scipy, _scipy_issparse
from pyspark.mllib.linalg import SparseVector, Vectors
+__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
+ 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD']
+
+
class LabeledPoint(object):
"""
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index feef0d16cd..8c726f171c 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -21,8 +21,10 @@ Python package for statistical functions in MLlib.
from pyspark.mllib._common import \
_get_unmangled_double_vector_rdd, _get_unmangled_rdd, \
- _serialize_double, _serialize_double_vector, \
- _deserialize_double, _deserialize_double_matrix, _deserialize_double_vector
+ _serialize_double, _deserialize_double_matrix, _deserialize_double_vector
+
+
+__all__ = ['MultivariateStatisticalSummary', 'Statistics']
class MultivariateStatisticalSummary(object):
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index e9d778df5a..a2fade61e9 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -26,6 +26,9 @@ from pyspark.mllib.regression import LabeledPoint
from pyspark.serializers import NoOpSerializer
+__all__ = ['DecisionTreeModel', 'DecisionTree']
+
+
class DecisionTreeModel(object):
"""
@@ -88,6 +91,7 @@ class DecisionTree(object):
It will probably be modified for Spark v1.2.
Example usage:
+
>>> from numpy import array
>>> import sys
>>> from pyspark.mllib.regression import LabeledPoint
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 6fc9f66bc5..dff6fc26fc 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -48,6 +48,7 @@ from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
from py4j.java_collections import ListConverter, MapConverter
+
__all__ = ["RDD"]
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fc49aa42db..55e6cf3308 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -409,7 +409,7 @@ class AutoSerializer(FramedSerializer):
class CompressedSerializer(FramedSerializer):
"""
- compress the serialized data
+ Compress the serialized data
"""
def __init__(self, serializer):
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 0ff6a548a8..44316926ba 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -40,8 +40,7 @@ __all__ = [
"StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
"DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
"ShortType", "ArrayType", "MapType", "StructField", "StructType",
- "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext",
- "SchemaRDD", "Row"]
+ "SQLContext", "HiveContext", "SchemaRDD", "Row"]
class DataType(object):
@@ -1037,7 +1036,7 @@ class SQLContext:
"can not infer schema")
if type(first) is dict:
warnings.warn("Using RDD of dict to inferSchema is deprecated,"
- "please use pyspark.Row instead")
+ "please use pyspark.sql.Row instead")
schema = _infer_schema(first)
rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))
@@ -1487,6 +1486,21 @@ class Row(tuple):
return "<Row(%s)>" % ", ".join(self)
+def inherit_doc(cls):
+ for name, func in vars(cls).items():
+ # only inherit docstring for public functions
+ if name.startswith("_"):
+ continue
+ if not func.__doc__:
+ for parent in cls.__bases__:
+ parent_func = getattr(parent, name, None)
+ if parent_func and getattr(parent_func, "__doc__", None):
+ func.__doc__ = parent_func.__doc__
+ break
+ return cls
+
+
+@inherit_doc
class SchemaRDD(RDD):
"""An RDD of L{Row} objects that has an associated schema.
@@ -1563,6 +1577,7 @@ class SchemaRDD(RDD):
self._jschema_rdd.registerTempTable(name)
def registerAsTable(self, name):
+ """DEPRECATED: use registerTempTable() instead"""
warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning)
self.registerTempTable(name)
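The `inherit_doc` decorator added to sql.py copies docstrings from a base class onto undocumented public methods of the subclass, so the RDD methods that SchemaRDD overrides keep their help text. A self-contained sketch of the same idea (Base and Child are illustrative names, not from this patch):

def inherit_doc(cls):
    """Fill in missing docstrings of public methods from the base classes."""
    for name, func in vars(cls).items():
        if name.startswith("_") or getattr(func, "__doc__", None):
            continue
        for parent in cls.__bases__:
            parent_func = getattr(parent, name, None)
            if parent_func and getattr(parent_func, "__doc__", None):
                func.__doc__ = parent_func.__doc__
                break
    return cls


class Base(object):
    def count(self):
        """Return the number of elements."""
        return 0


@inherit_doc
class Child(Base):
    def count(self):  # overridden without a docstring of its own
        return 42


print(Child.count.__doc__)  # -> Return the number of elements.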