author     Davies Liu <davies@databricks.com>    2015-02-04 15:55:09 -0800
committer  Reynold Xin <rxin@databricks.com>     2015-02-04 15:55:09 -0800
commit     dc101b0e4e23dffddbc2f70d14a19fae5d87a328 (patch)
tree       e436271c351a64caa4727661cd6143ba6e415fa6 /sql
parent     e0490e271d078aa55d7c7583e2ba80337ed1b0c4 (diff)
[SPARK-5577] Python udf for DataFrame
Author: Davies Liu <davies@databricks.com>

Closes #4351 from davies/python_udf and squashes the following commits:

d250692 [Davies Liu] fix conflict
34234d4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
440f769 [Davies Liu] address comments
f0a3121 [Davies Liu] track life cycle of broadcast
f99b2e1 [Davies Liu] address comments
462b334 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
7bccc3b [Davies Liu] python udf
58dee20 [Davies Liu] clean up
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala               19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala  27
2 files changed, 44 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index ddce77deb8..4c2aeadae9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -128,7 +128,6 @@ trait Column extends DataFrame {
*/
def unary_! : Column = exprToColumn(Not(expr))
-
/**
* Equality test.
* {{{
@@ -166,7 +165,7 @@ trait Column extends DataFrame {
*
* // Java:
* import static org.apache.spark.sql.Dsl.*;
- * df.filter( not(col("colA").equalTo(col("colB"))) );
+ * df.filter( col("colA").notEqual(col("colB")) );
* }}}
*/
def !== (other: Any): Column = constructColumn(other) { o =>
@@ -174,6 +173,22 @@ trait Column extends DataFrame {
}
/**
+ * Inequality test.
+ * {{{
+ * // Scala:
+ * df.select( df("colA") !== df("colB") )
+ * df.select( !(df("colA") === df("colB")) )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.filter( col("colA").notEqual(col("colB")) );
+ * }}}
+ */
+ def notEqual(other: Any): Column = constructColumn(other) { o =>
+ Not(EqualTo(expr, o.expr))
+ }
+
+ /**
* Greater than.
* {{{
* // Scala: The following selects people older than 21.
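A minimal usage sketch of the new inequality test (not part of the patch; the DataFrame df and its columns colA and colB are illustrative assumptions). Both spellings build the same Not(EqualTo(...)) expression:

    // Scala: keep only the rows where colA and colB differ.
    val changed = df.filter(df("colA") !== df("colB"))
    // Equivalent spelling; from Java this reads col("colA").notEqual(col("colB")).
    val changedToo = df.filter(df("colA").notEqual(df("colB")))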
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index 8d7c2a1b83..c60d407094 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -17,7 +17,13 @@
package org.apache.spark.sql
+import java.util.{List => JList, Map => JMap}
+
+import org.apache.spark.Accumulator
+import org.apache.spark.api.python.PythonBroadcast
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.expressions.ScalaUdf
+import org.apache.spark.sql.execution.PythonUDF
import org.apache.spark.sql.types.DataType
/**
@@ -37,3 +43,24 @@ case class UserDefinedFunction(f: AnyRef, dataType: DataType) {
Column(ScalaUdf(f, dataType, exprs.map(_.expr)))
}
}
+
+/**
+ * A user-defined Python function. To create one, use the `pythonUDF` functions in [[Dsl]].
+ * This is used by the Python API.
+ */
+private[sql] case class UserDefinedPythonFunction(
+ name: String,
+ command: Array[Byte],
+ envVars: JMap[String, String],
+ pythonIncludes: JList[String],
+ pythonExec: String,
+ broadcastVars: JList[Broadcast[PythonBroadcast]],
+ accumulator: Accumulator[JList[Array[Byte]]],
+ dataType: DataType) {
+
+ def apply(exprs: Column*): Column = {
+ val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
+ accumulator, dataType, exprs.map(_.expr))
+ Column(udf)
+ }
+}
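UserDefinedPythonFunction.apply follows the same pattern as the Scala-side UserDefinedFunction above: the UDF expression is wrapped in a Column so it composes with the rest of the DataFrame DSL. A minimal sketch of that pattern using the Scala variant (the DataFrame df and its "name" column are assumptions for illustration, not part of the patch):

    import org.apache.spark.sql.types.IntegerType

    // Wrap a plain Scala function; apply() turns it into Column(ScalaUdf(...)).
    val strLen = UserDefinedFunction((s: String) => s.length, IntegerType)

    // The resulting Column is usable anywhere a column expression is expected.
    val withLen = df.select(df("name"), strLen(df("name")))

The Python variant differs only in what it wraps: instead of a Scala function it carries the pickled command bytes, interpreter settings, broadcast variables, and accumulator needed to run the function in a Python worker, and apply() wraps them in a PythonUDF expression.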