aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
author MechCoder <manojkumarsivaraj334@gmail.com> 2015-06-06 14:52:14 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-06-06 14:52:14 -0700
commit 5aa804f3c6485670937a658ce8207c2317c6a506 (patch)
tree 92757d7ad6e2195618d0d597a8be10712817e19f /python
parent 16fc49617e1dfcbe9122b224f7f63b7bfddb36ce (diff)
download spark-5aa804f3c6485670937a658ce8207c2317c6a506.tar.gz
spark-5aa804f3c6485670937a658ce8207c2317c6a506.tar.bz2
spark-5aa804f3c6485670937a658ce8207c2317c6a506.zip
[SPARK-7639] [PYSPARK] [MLLIB] Python API for KernelDensity
Python API for KernelDensity. Author: MechCoder <manojkumarsivaraj334@gmail.com>. Closes #6387 from MechCoder/spark-7639 and squashes the following commits: 17abc62 [MechCoder] add tests; 2de6540 [MechCoder] style tests; bf4acc0 [MechCoder] Added doctests; 84359d5 [MechCoder] [SPARK-7639] Python API for KernelDensity.
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/stat/KernelDensity.py 61
-rw-r--r-- python/pyspark/mllib/stat/__init__.py 3
-rwxr-xr-x python/run-tests 1
3 files changed, 64 insertions, 1 deletions
diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py
new file mode 100644
index 0000000000..7da921976d
--- /dev/null
+++ b/python/pyspark/mllib/stat/KernelDensity.py
@@ -0,0 +1,61 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+
+if sys.version > '3':
+ xrange = range
+
+import numpy as np
+
+from pyspark.mllib.common import callMLlibFunc
+from pyspark.rdd import RDD
+
+
+class KernelDensity(object):
+ """
+ .. note:: Experimental
+
+ Estimate probability density at required points given a RDD of samples
+ from the population.
+
+ >>> kd = KernelDensity()
+ >>> sample = sc.parallelize([0.0, 1.0])
+ >>> kd.setSample(sample)
+ >>> kd.estimate([0.0, 1.0])
+ array([ 0.12938758, 0.12938758])
+ """
+ def __init__(self):
+ self._bandwidth = 1.0
+ self._sample = None
+
+ def setBandwidth(self, bandwidth):
+ """Set bandwidth of each sample. Defaults to 1.0"""
+ self._bandwidth = bandwidth
+
+ def setSample(self, sample):
+ """Set sample points from the population. Should be a RDD"""
+ if not isinstance(sample, RDD):
+ raise TypeError("samples should be a RDD, received %s" % type(sample))
+ self._sample = sample
+
+ def estimate(self, points):
+ """Estimate the probability density at points"""
+ points = list(points)
+ densities = callMLlibFunc(
+ "estimateKernelDensity", self._sample, self._bandwidth, points)
+ return np.asarray(densities)
diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py
index e3e128513e..c8a721d3fe 100644
--- a/python/pyspark/mllib/stat/__init__.py
+++ b/python/pyspark/mllib/stat/__init__.py
@@ -22,6 +22,7 @@ Python package for statistical functions in MLlib.
from pyspark.mllib.stat._statistics import *
from pyspark.mllib.stat.distribution import MultivariateGaussian
from pyspark.mllib.stat.test import ChiSqTestResult
+from pyspark.mllib.stat.KernelDensity import KernelDensity
__all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult",
- "MultivariateGaussian"]
+ "MultivariateGaussian", "KernelDensity"]
diff --git a/python/run-tests b/python/run-tests
index 17dda3eada..4468fdb3f2 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -93,6 +93,7 @@ function run_mllib_tests() {
run_test "pyspark.mllib.recommendation"
run_test "pyspark.mllib.regression"
run_test "pyspark.mllib.stat._statistics"
+ run_test "pyspark.mllib.stat.KernelDensity"
run_test "pyspark.mllib.tree"
run_test "pyspark.mllib.util"
run_test "pyspark.mllib.tests"