summaryrefslogtreecommitdiff
path: root/src/repl/scala/tools/nsc/interpreter/Scripted.scala
blob: 8d87d98e534318e96906e56a63c0c2586d96bd01 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
/* NSC -- new Scala compiler
 * Copyright 2005-2016 LAMP/EPFL
 */
package scala
package tools.nsc
package interpreter

import scala.language.dynamics

import scala.beans.BeanProperty
import scala.collection.JavaConverters._
import scala.reflect.classTag
import scala.reflect.internal.util.Position
import scala.tools.nsc.util.stringFromReader
import javax.script._, ScriptContext.{ ENGINE_SCOPE, GLOBAL_SCOPE }
import java.io.{ Closeable, Reader }

/* A REPL adaptor for the javax.script API. */
/** A `javax.script` engine implemented on top of the REPL's `IMain`.
 *
 *  Every script is compiled before it is evaluated. Names that cannot be
 *  satisfied by REPL history but are present in the bindings of the
 *  compile-time `ScriptContext` are rewritten into accessors on the dynamic
 *  `$ctx` object, which forwards to whatever `ScriptContext` is current at
 *  evaluation time.
 *
 *  @param factory  the factory reported through `getFactory`
 *  @param settings compiler settings handed to the underlying `IMain`
 *  @param out      writer handed to the underlying `IMain`
 */
class Scripted(@BeanProperty val factory: ScriptEngineFactory, settings: Settings, out: JPrintWriter)
  extends AbstractScriptEngine with Compilable {

  /** Returns a new, empty `SimpleBindings`. */
  def createBindings: Bindings = new SimpleBindings

  // dynamic context bound under this name
  final val ctx = "$ctx"

  // the underlying interpreter, tweaked to handle dynamic bindings
  val intp = new IMain(settings, out) {
    import global.{ Name, TermName }

    /* Modify the template to snag definitions from dynamic context.
     * So object $iw { x + 42 } becomes object $iw { def x = $ctx.x ; x + 42 }
     */
    override protected def importsCode(wanted: Set[Name], wrapper: Request#Wrapper, definesClass: Boolean, generousImports: Boolean) = {

      // cull references that can be satisfied from the current dynamic context
      val contextual = wanted & contextNames

      if (contextual.nonEmpty) {
        // the remaining names are imported as usual, plus $ctx itself for the generated accessors
        val neededContext = (wanted &~ contextual) + TermName(ctx)
        val ComputedImports(header, preamble, trailer, path) = super.importsCode(neededContext, wrapper, definesClass, generousImports)
        // one getter/setter pair per contextual name, appended to the regular preamble
        val adjusted = contextual.map { n =>
            val valname = n.decodedName
            s"""def `$valname` = $ctx.`$valname`
                def `${valname}_=`(x: _root_.java.lang.Object) = $ctx.`$valname` = x"""
          }.mkString(preamble, "\n", "\n")
        ComputedImports(header, adjusted, trailer, path)
      }
      else super.importsCode(wanted, wrapper, definesClass, generousImports)
    }

    // names available in current dynamic context:
    // every binding key of every scope of the compile-time context (scopes without bindings are skipped)
    def contextNames: Set[Name] = {
      val ctx = compileContext
      val terms = for {
        scope   <- ctx.getScopes.asScala
        binding <- Option(ctx.getBindings(scope)) map (_.asScala) getOrElse Nil
        key = binding._1
      } yield (TermName(key): Name)
      terms.to[Set]
    }

    // save first error for exception; console display only if debugging
    override lazy val reporter: ReplReporter = new ReplReporter(this) {
      override def display(pos: Position, msg: String, severity: Severity): Unit =
        if (isReplDebug) super.display(pos, msg, severity)
      override def error(pos: Position, msg: String): Unit = {
        if (firstError.isEmpty) firstError = Some((pos, msg))
        super.error(pos, msg)
      }
      override def reset() = { super.reset() ; firstError = None }
    }
  }
  intp.initializeSynchronous()

  /** The context whose bindings decide which free names are routed through `$ctx` at compile time. */
  var compileContext: ScriptContext = getContext

  /** Holder whose compiled object stores the `ScriptContext` used at evaluation time. */
  val scriptContextRep = new intp.ReadEvalPrint

  /** Stores `ctx` in the compiled holder object by calling its `set` method. */
  def dynamicContext_=(ctx: ScriptContext): Unit = scriptContextRep.callEither("set", ctx)

  /** Reads the context back from the compiled holder object.
   *  Rethrows a failure of the call; a non-`ScriptContext` value yields a `ScriptException`.
   */
  def dynamicContext: ScriptContext = scriptContextRep.callEither("value") match {
    case Right(ctx: ScriptContext) => ctx
    case Left(e) => throw e
    case Right(other) => throw new ScriptException(s"Unexpected value for context: $other")
  }

  if (intp.isInitializeComplete) {
    // compile the dynamic ScriptContext object holder
    val ctxRes = scriptContextRep compile s"""
      |import _root_.javax.script._
      |object ${scriptContextRep.evalName} {
      |  var value: ScriptContext = _
      |  def set(x: _root_.scala.Any) = value = x.asInstanceOf[ScriptContext]
      |}
    """.stripMargin
    if (!ctxRes) throw new ScriptException("Failed to compile ctx")
    dynamicContext = getContext

    // Bridge dynamic references and script context
    val dynRes = intp compileString s"""
     |package scala.tools.nsc.interpreter
     |import _root_.scala.language.dynamics
     |import _root_.javax.script._, ScriptContext.ENGINE_SCOPE
     |object dynamicBindings extends _root_.scala.Dynamic {
     |  def context: ScriptContext = ${ scriptContextRep.evalPath }.value
     |  // $ctx.x retrieves the attribute x
     |  def selectDynamic(field: _root_.java.lang.String): _root_.java.lang.Object = context.getAttribute(field)
     |  // $ctx.x = v
     |  def updateDynamic(field: _root_.java.lang.String)(value: _root_.java.lang.Object) = context.setAttribute(field, value, ENGINE_SCOPE)
     |}
     |""".stripMargin
    if (!dynRes) throw new ScriptException("Failed to compile dynamicBindings")
    intp beQuietDuring {
      intp interpret s"val $ctx: _root_.scala.tools.nsc.interpreter.dynamicBindings.type = _root_.scala.tools.nsc.interpreter.dynamicBindings"
      intp bind ("$engine" -> (this: ScriptEngine with Compilable))
    }
  }

  // Set the context for dynamic resolution and run the body.
  // Defines attributes available for evaluation.
  // Avoid reflective access if using default context.
  // The previous dynamic context is restored even if the body throws.
  def withScriptContext[A](context: ScriptContext)(body: => A): A =
    if (context eq getContext) body else {
      val saved = dynamicContext
      dynamicContext = context
      try body
      finally dynamicContext = saved
    }
  // Defines attributes available for compilation.
  // The previous compile context is restored even if the body throws.
  def withCompileContext[A](context: ScriptContext)(body: => A): A = {
    val saved = compileContext
    compileContext = context
    try body
    finally compileContext = saved
  }

  // not obvious that ScriptEngine should accumulate code text
  private var code = ""

  // first error reported since the reporter was last reset; set by the reporter installed on intp
  private var firstError: Option[(Position, String)] = None

  /* Build the exception for a failed compilation of `script`: the first
   * reported error with its line and column if there is one, otherwise a
   * generic message. As before, the script text is passed in the exception's
   * file-name slot.
   */
  private def compileError(script: String): ScriptException =
    firstError map {
      case (pos, msg) => new ScriptException(msg, script, pos.line, pos.column)
    } getOrElse new ScriptException("compile-time error")

  /* All scripts are compiled. The supplied context defines what references
   * not in REPL history are allowed, though a different context may be
   * supplied for evaluation of a compiled script.
   *
   * Incomplete input is accumulated and prepended to the next script; the
   * script returned for incomplete input evaluates to null.
   */
  def compile(script: String, context: ScriptContext): CompiledScript =
    withCompileContext(context) {
      val cat = code + script
      intp.compile(cat, synthetic = false) match {
        case Right(req)          =>
          code = ""
          new WrappedRequest(req)
        case Left(IR.Incomplete) =>
          code = cat + "\n"
          new CompiledScript {
            def eval(context: ScriptContext): Object = null
            def getEngine: ScriptEngine = Scripted.this
          }
        case Left(_)             =>
          code = ""
          throw compileError(script)
      }
    }

  // documentation
  //protected var context: ScriptContext
  //def getContext: ScriptContext = context

  /* Compile with the default context. All references must be resolvable. */
  @throws[ScriptException]
  def compile(script: String): CompiledScript = compile(script, context)

  @throws[ScriptException]
  def compile(reader: Reader): CompiledScript = compile(stringFromReader(reader), context)

  /* Compile and evaluate with the given context. */
  @throws[ScriptException]
  def eval(script: String, context: ScriptContext): Object = compile(script, context).eval(context)

  @throws[ScriptException]
  def eval(reader: Reader, context: ScriptContext): Object = compile(stringFromReader(reader), context).eval(context)

  /* A compiled request that can be evaluated repeatedly. */
  private class WrappedRequest(val req: intp.Request) extends CompiledScript {
    // true until the first successful evaluation
    var first = true

    /* Evaluate the request's line. If the context supplies a writer, error
     * writer or reader, Console output, error and input are redirected to
     * them for the duration of the evaluation; the adapter streams created
     * here are closed afterwards (the context's own writers are not).
     */
    private def evalEither(r: intp.Request, ctx: ScriptContext) = {
      if (ctx.getWriter == null && ctx.getErrorWriter == null && ctx.getReader == null) r.lineRep.evalEither
      else {
        val outStream = Option(ctx.getWriter) map (new WriterOutputStream(_))
        val errStream = Option(ctx.getErrorWriter) map (new WriterOutputStream(_))
        val in        = Option(ctx.getReader) getOrElse Console.in
        try {
          Console.withOut(outStream getOrElse Console.out) {
            Console.withErr(errStream getOrElse Console.err) {
              Console.withIn(in) {
                r.lineRep.evalEither
              }
            }
          }
        } finally {
          List(outStream, errStream).flatten foreach (_.close())
        }
      }
    }

    /* Unwrap an evaluation outcome.
     * Per the API: Checked exception types thrown by underlying scripting implementations
     * must be wrapped in instances of ScriptException. Runtime exceptions and
     * non-Exception throwables are rethrown as they are.
     */
    private def resultOrThrow(outcome: Either[Throwable, Any]): Object = outcome match {
      case Left(e: RuntimeException) => throw e
      case Left(e: Exception)        => throw new ScriptException(e)
      case Left(e)                   => throw e
      case Right(result)             => result.asInstanceOf[Object]
    }

    /* First time, cause lazy evaluation of a memoized result.
     * Subsequently, instantiate a new object for evaluation.
     * A request is recorded with the interpreter only after it evaluated successfully.
     */
    @throws[ScriptException]
    override def eval(context: ScriptContext): Object = withScriptContext(context) {
      if (first) {
        val result = resultOrThrow(evalEither(req, context))
        intp recordRequest req
        first = false
        result
      } else {
        val defines = req.defines
        if (defines.isEmpty) {
          // nothing to rebind: instantiate the wrapper again and discard the value
          Scripted.this.eval(s"new ${req.lineRep.readPath}")
          intp recordRequest duplicate(req)
          null
        } else {
          // create a fresh instance and rebind each defined name to the new instance's member
          val instance = s"val $$INSTANCE = new ${req.lineRep.readPath};"
          val newline  = (defines map (s => s"val ${s.name} = $$INSTANCE${req.accessPath}.${s.name}")).mkString(instance, ";", ";")
          // report a failure to build or compile the synthetic line as a ScriptException
          // instead of failing on Either.right.get or evaluating an uncompiled request
          val newreq   = intp.requestFromLine(newline) match {
            case Right(r) => r
            case Left(_)  => throw compileError(newline)
          }
          if (!newreq.compile) throw compileError(newline)

          val result = resultOrThrow(evalEither(newreq, context))
          intp recordRequest newreq
          result
        }
      }
    }

    /* A new request built from the same line and trees. */
    def duplicate(req: intp.Request) = new intp.Request(req.line, req.trees)

    def getEngine: ScriptEngine = Scripted.this
  }
}

object Scripted {
  import IMain.{ defaultSettings, defaultOut }
  import java.util.Arrays.asList
  import scala.util.Properties.versionString

  /** The `ScriptEngineFactory` describing and creating `Scripted` engines. */
  class Factory extends ScriptEngineFactory {
    @BeanProperty val engineName: String                 = "Scala REPL"

    @BeanProperty val engineVersion: String              = "2.0"

    @BeanProperty val extensions: java.util.List[String] = asList("scala")

    @BeanProperty val languageName: String               = "Scala"

    @BeanProperty val languageVersion: String            = versionString

    @BeanProperty val mimeTypes: java.util.List[String]  = asList("application/x-scala")

    @BeanProperty val names: java.util.List[String]      = asList("scala")

    /** Renders `obj.m(arg1, arg2, ...)`. */
    def getMethodCallSyntax(obj: String, m: String, args: String*): String = {
      val argList = args mkString ", "
      s"$obj.$m($argList)"
    }

    /** Renders a `Console.println` call around the given text, inserted verbatim. */
    def getOutputStatement(toDisplay: String): String = "Console.println(" + toDisplay + ")"

    /** Looks up one of the standard engine keys; any other key yields null. */
    def getParameter(key: String): Object = key match {
      case ScriptEngine.NAME             => names.get(0)
      case ScriptEngine.ENGINE           => engineName
      case ScriptEngine.ENGINE_VERSION   => engineVersion
      case ScriptEngine.LANGUAGE         => languageName
      case ScriptEngine.LANGUAGE_VERSION => languageVersion
      case _                             => null
    }

    /** Wraps the statements, one per tab-indented line, in an `object Main extends App`. */
    def getProgram(statements: String*): String = {
      val body = statements mkString "\n\t"
      s"object Main extends _root_.scala.App {\n\t$body\n}"
    }

    /** Creates an engine from fresh settings with `usemanifestcp` switched on. */
    def getScriptEngine: ScriptEngine = {
      val manifestSettings = new Settings()
      manifestSettings.usemanifestcp.value = true
      Scripted(this, manifestSettings)
    }
  }

  /** Builds an engine and installs fresh engine-scope bindings on it.
   *
   *  Note that `Yreplclassbased` and `usejavacp` are switched on in the
   *  settings object that is passed in.
   */
  def apply(factory: ScriptEngineFactory = new Factory, settings: Settings = defaultSettings, out: JPrintWriter = defaultOut): Scripted = {
    settings.Yreplclassbased.value = true
    settings.usejavacp.value       = true
    val engine = new Scripted(factory, settings, out)
    engine.setBindings(engine.createBindings, ScriptContext.ENGINE_SCOPE)
    engine
  }
}

import java.io.Writer
import java.nio.{ ByteBuffer, CharBuffer }
import java.nio.charset.{ Charset, CodingErrorAction }
import CodingErrorAction.{ REPLACE => Replace }

/* An OutputStream that decodes bytes and flushes to the writer. */
/** An `OutputStream` that decodes incoming bytes with the platform default
 *  charset and writes the resulting characters to `writer`.
 *
 *  Malformed and unmappable input is replaced rather than reported. Bytes of
 *  an incomplete multi-byte sequence are kept until the rest of the sequence
 *  arrives. Closing this stream decodes any pending bytes and forwards the
 *  result; it neither flushes nor closes `writer`.
 *
 *  @param writer destination of the decoded characters
 */
class WriterOutputStream(writer: Writer) extends OutputStream {
  val decoder = Charset.defaultCharset.newDecoder
  decoder onMalformedInput Replace
  decoder onUnmappableCharacter Replace

  // Between calls byteBuffer is in "write" mode: position = number of pending undecoded bytes.
  val byteBuffer = ByteBuffer.allocate(64)
  val charBuffer = CharBuffer.allocate(64)

  // Forward whatever has been decoded so far and make charBuffer ready for more output.
  private def drain(): Unit =
    if (charBuffer.position() > 0) {
      charBuffer.flip()
      writer write charBuffer.toString
      charBuffer.clear()
    }

  override def write(b: Int): Unit = {
    byteBuffer.put(b.toByte)
    byteBuffer.flip()
    decoder.decode(byteBuffer, charBuffer, /*eoi=*/ false)
    // Keep the bytes the decoder did not consume (the start of a multi-byte
    // character) in front of the next byte instead of overwriting them.
    byteBuffer.compact()
    drain()
  }
  override def close(): Unit = {
    // Decode the pending bytes only, not the unused remainder of the buffer.
    byteBuffer.flip()
    decoder.decode(byteBuffer, charBuffer, /*eoi=*/ true)
    decoder.flush(charBuffer)
    byteBuffer.clear()
    drain()
    // A flushed decoder rejects further decode calls; resetting it makes a repeated close() harmless.
    decoder.reset()
  }
  override def toString = charBuffer.toString
}