author    Yun Ni <yunn@uber.com>    2017-02-15 16:26:05 -0800
committer Yanbo Liang <ybliang8@gmail.com>    2017-02-15 16:26:05 -0800
commit    08c1972a0661d42f300520cc6e5fb31023de093b (patch)
tree      8b392b4520df66ca32834c11fc376009be70e8b8 /examples/src/main/java
parent    21b4ba2d6f21a9759af879471715c123073bd67a (diff)
[SPARK-18080][ML][PYTHON] Python API & Examples for Locality Sensitive Hashing
## What changes were proposed in this pull request?

This pull request adds a Python API and examples for Locality Sensitive Hashing (LSH). The API changes are based on yanboliang's PR #15768, with conflicts resolved and updates applied to match the changed Scala API. The examples are consistent with the Scala examples for MinHashLSH and BucketedRandomProjectionLSH.

## How was this patch tested?

The API and examples were tested using spark-submit:
`bin/spark-submit examples/src/main/python/ml/min_hash_lsh.py`
`bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh.py`

User guide changes were generated and manually inspected:
`SKIP_API=1 jekyll build`

Author: Yun Ni <yunn@uber.com>
Author: Yanbo Liang <ybliang8@gmail.com>
Author: Yunni <Euler57721@gmail.com>

Closes #16715 from Yunni/spark-18080.
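The new Python example files named above are not part of this diff, which is limited to examples/src/main/java. As a rough, hedged sketch of what the new PySpark MinHashLSH usage looks like (data values and column names mirror the Java example below; the actual contents of min_hash_lsh.py are not reproduced here):

from pyspark.ml.feature import MinHashLSH
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("MinHashLSHExample").getOrCreate()

# Binary feature vectors; values mirror the Java example in this diff.
dfA = spark.createDataFrame([
    (0, Vectors.sparse(6, [0, 1, 2], [1.0, 1.0, 1.0])),
    (1, Vectors.sparse(6, [2, 3, 4], [1.0, 1.0, 1.0])),
    (2, Vectors.sparse(6, [0, 2, 4], [1.0, 1.0, 1.0]))], ["id", "features"])

dfB = spark.createDataFrame([
    (0, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0])),
    (1, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0])),
    (2, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]))], ["id", "features"])

key = Vectors.sparse(6, [1, 3], [1.0, 1.0])

mh = MinHashLSH(inputCol="features", outputCol="hashes", numHashTables=5)
model = mh.fit(dfA)

# Feature transformation: the hashed values are stored in the 'hashes' column.
model.transform(dfA).show()

# Approximate similarity join of dfA and dfB on Jaccard distance smaller than 0.6.
model.approxSimilarityJoin(dfA, dfB, 0.6, distCol="JaccardDistance") \
    .select(col("datasetA.id").alias("idA"),
            col("datasetB.id").alias("idB"),
            col("JaccardDistance")).show()

# Approximate search for the 2 nearest neighbors of `key` in dfA.
model.approxNearestNeighbors(dfA, key, 2).show()

spark.stop()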
Diffstat (limited to 'examples/src/main/java')
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java | 38
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java                  | 57
2 files changed, 74 insertions(+), 21 deletions(-)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java
index ca3ee5a285..4594e3462b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java
@@ -35,8 +35,15 @@ import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
+
+import static org.apache.spark.sql.functions.col;
// $example off$
+/**
+ * An example demonstrating BucketedRandomProjectionLSH.
+ * Run with:
+ * bin/run-example org.apache.spark.examples.ml.JavaBucketedRandomProjectionLSHExample
+ */
public class JavaBucketedRandomProjectionLSHExample {
public static void main(String[] args) {
SparkSession spark = SparkSession
@@ -61,7 +68,7 @@ public class JavaBucketedRandomProjectionLSHExample {
StructType schema = new StructType(new StructField[]{
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
- new StructField("keys", new VectorUDT(), false, Metadata.empty())
+ new StructField("features", new VectorUDT(), false, Metadata.empty())
});
Dataset<Row> dfA = spark.createDataFrame(dataA, schema);
Dataset<Row> dfB = spark.createDataFrame(dataB, schema);
@@ -71,26 +78,31 @@ public class JavaBucketedRandomProjectionLSHExample {
BucketedRandomProjectionLSH mh = new BucketedRandomProjectionLSH()
.setBucketLength(2.0)
.setNumHashTables(3)
- .setInputCol("keys")
- .setOutputCol("values");
+ .setInputCol("features")
+ .setOutputCol("hashes");
BucketedRandomProjectionLSHModel model = mh.fit(dfA);
// Feature Transformation
+ System.out.println("The hashed dataset where hashed values are stored in the column 'hashes':");
model.transform(dfA).show();
- // Cache the transformed columns
- Dataset<Row> transformedA = model.transform(dfA).cache();
- Dataset<Row> transformedB = model.transform(dfB).cache();
- // Approximate similarity join
- model.approxSimilarityJoin(dfA, dfB, 1.5).show();
- model.approxSimilarityJoin(transformedA, transformedB, 1.5).show();
- // Self Join
- model.approxSimilarityJoin(dfA, dfA, 2.5).filter("datasetA.id < datasetB.id").show();
+ // Compute the locality sensitive hashes for the input rows, then perform approximate
+ // similarity join.
+ // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+ // `model.approxSimilarityJoin(transformedA, transformedB, 1.5)`
+ System.out.println("Approximately joining dfA and dfB on distance smaller than 1.5:");
+ model.approxSimilarityJoin(dfA, dfB, 1.5, "EuclideanDistance")
+ .select(col("datasetA.id").alias("idA"),
+ col("datasetB.id").alias("idB"),
+ col("EuclideanDistance")).show();
- // Approximate nearest neighbor search
+ // Compute the locality sensitive hashes for the input rows, then perform approximate nearest
+ // neighbor search.
+ // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+ // `model.approxNearestNeighbors(transformedA, key, 2)`
+ System.out.println("Approximately searching dfA for 2 nearest neighbors of the key:");
model.approxNearestNeighbors(dfA, key, 2).show();
- model.approxNearestNeighbors(transformedA, key, 2).show();
// $example off$
spark.stop();
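For comparison, a hedged PySpark sketch of the equivalent BucketedRandomProjectionLSH flow, roughly what the new bucketed_random_projection_lsh.py referenced in the commit message does. The input vectors and the query key below are illustrative assumptions, since the data-construction part of the Java example falls outside the hunks shown above:

from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("BucketedRandomProjectionLSHExample").getOrCreate()

# Illustrative 2-dimensional dense vectors (assumed, not copied from the example).
dfA = spark.createDataFrame([
    (0, Vectors.dense([1.0, 1.0])),
    (1, Vectors.dense([1.0, -1.0])),
    (2, Vectors.dense([-1.0, -1.0])),
    (3, Vectors.dense([-1.0, 1.0]))], ["id", "features"])

dfB = spark.createDataFrame([
    (4, Vectors.dense([1.0, 0.0])),
    (5, Vectors.dense([-1.0, 0.0])),
    (6, Vectors.dense([0.0, 1.0])),
    (7, Vectors.dense([0.0, -1.0]))], ["id", "features"])

key = Vectors.dense([1.0, 0.0])  # illustrative query point

brp = BucketedRandomProjectionLSH(inputCol="features", outputCol="hashes",
                                  bucketLength=2.0, numHashTables=3)
model = brp.fit(dfA)

# Feature transformation: the hashed values are stored in the 'hashes' column.
model.transform(dfA).show()

# Approximate similarity join of dfA and dfB on Euclidean distance smaller than 1.5.
model.approxSimilarityJoin(dfA, dfB, 1.5, distCol="EuclideanDistance") \
    .select(col("datasetA.id").alias("idA"),
            col("datasetB.id").alias("idB"),
            col("EuclideanDistance")).show()

# Approximate search for the 2 nearest neighbors of `key` in dfA.
model.approxNearestNeighbors(dfA, key, 2).show()

spark.stop()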
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java
index 9dbbf6d117..0aace46939 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java
@@ -25,6 +25,7 @@ import java.util.List;
import org.apache.spark.ml.feature.MinHashLSH;
import org.apache.spark.ml.feature.MinHashLSHModel;
+import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
@@ -34,8 +35,15 @@ import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
+
+import static org.apache.spark.sql.functions.col;
// $example off$
+/**
+ * An example demonstrating MinHashLSH.
+ * Run with:
+ * bin/run-example org.apache.spark.examples.ml.JavaMinHashLSHExample
+ */
public class JavaMinHashLSHExample {
public static void main(String[] args) {
SparkSession spark = SparkSession
@@ -44,25 +52,58 @@ public class JavaMinHashLSHExample {
.getOrCreate();
// $example on$
- List<Row> data = Arrays.asList(
+ List<Row> dataA = Arrays.asList(
RowFactory.create(0, Vectors.sparse(6, new int[]{0, 1, 2}, new double[]{1.0, 1.0, 1.0})),
RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 4}, new double[]{1.0, 1.0, 1.0})),
RowFactory.create(2, Vectors.sparse(6, new int[]{0, 2, 4}, new double[]{1.0, 1.0, 1.0}))
);
+ List<Row> dataB = Arrays.asList(
+ RowFactory.create(0, Vectors.sparse(6, new int[]{1, 3, 5}, new double[]{1.0, 1.0, 1.0})),
+ RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 5}, new double[]{1.0, 1.0, 1.0})),
+ RowFactory.create(2, Vectors.sparse(6, new int[]{1, 2, 4}, new double[]{1.0, 1.0, 1.0}))
+ );
+
StructType schema = new StructType(new StructField[]{
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
- new StructField("keys", new VectorUDT(), false, Metadata.empty())
+ new StructField("features", new VectorUDT(), false, Metadata.empty())
});
- Dataset<Row> dataFrame = spark.createDataFrame(data, schema);
+ Dataset<Row> dfA = spark.createDataFrame(dataA, schema);
+ Dataset<Row> dfB = spark.createDataFrame(dataB, schema);
+
+ int[] indices = {1, 3};
+ double[] values = {1.0, 1.0};
+ Vector key = Vectors.sparse(6, indices, values);
MinHashLSH mh = new MinHashLSH()
- .setNumHashTables(1)
- .setInputCol("keys")
- .setOutputCol("values");
+ .setNumHashTables(5)
+ .setInputCol("features")
+ .setOutputCol("hashes");
+
+ MinHashLSHModel model = mh.fit(dfA);
+
+ // Feature Transformation
+ System.out.println("The hashed dataset where hashed values are stored in the column 'hashes':");
+ model.transform(dfA).show();
+
+ // Compute the locality sensitive hashes for the input rows, then perform approximate
+ // similarity join.
+ // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+ // `model.approxSimilarityJoin(transformedA, transformedB, 0.6)`
+ System.out.println("Approximately joining dfA and dfB on Jaccard distance smaller than 0.6:");
+ model.approxSimilarityJoin(dfA, dfB, 0.6, "JaccardDistance")
+ .select(col("datasetA.id").alias("idA"),
+ col("datasetB.id").alias("idB"),
+ col("JaccardDistance")).show();
- MinHashLSHModel model = mh.fit(dataFrame);
- model.transform(dataFrame).show();
+ // Compute the locality sensitive hashes for the input rows, then perform approximate nearest
+ // neighbor search.
+ // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+ // `model.approxNearestNeighbors(transformedA, key, 2)`
+ // It may return less than 2 rows when not enough approximate near-neighbor candidates are
+ // found.
+ System.out.println("Approximately searching dfA for 2 nearest neighbors of the key:");
+ model.approxNearestNeighbors(dfA, key, 2).show();
// $example off$
spark.stop();
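The comments added in both examples note that the hashes can be precomputed and the already-transformed datasets passed straight to the query methods, which is what the pre-refactor Java code did with cached DataFrames. A hedged PySpark fragment of that variant, reusing model, dfA, dfB and key from the MinHash sketch after the commit message above:

# Precompute and cache the hashed datasets so repeated queries skip re-hashing.
transformedA = model.transform(dfA).cache()
transformedB = model.transform(dfB).cache()

# Both query methods accept already-transformed input and reuse the 'hashes' column.
model.approxSimilarityJoin(transformedA, transformedB, 0.6, distCol="JaccardDistance").show()
model.approxNearestNeighbors(transformedA, key, 2).show()

# Self-join pattern from the pre-refactor Java example: find close pairs within one
# dataset and drop the mirrored duplicate pairs.
model.approxSimilarityJoin(dfA, dfA, 0.6, distCol="JaccardDistance") \
    .filter("datasetA.id < datasetB.id").show()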