Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/stdlib.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,8 @@ jobs:
test-scala3-compiler-bootstrapped:
runs-on: ubuntu-latest
needs: [scala3-compiler-bootstrapped, tasty-core-bootstrapped, scala-library-bootstrapped, scala3-staging, scala3-tasty-inspector]
env:
DOTTY_DISABLE_REPL_BYTECODE_INSTRUMENTATION: true
steps:
- name: Git Checkout
uses: actions/checkout@v5
Expand Down Expand Up @@ -550,6 +552,8 @@ jobs:
scripted-tests:
runs-on: ubuntu-latest
needs: [scala3-compiler-bootstrapped, tasty-core-bootstrapped, scala3-staging, scala3-tasty-inspector, scala-library-sjs, scaladoc]
env:
DOTTY_DISABLE_REPL_BYTECODE_INSTRUMENTATION: true
steps:
- name: Git Checkout
uses: actions/checkout@v5
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/ScalaSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ private sealed trait XSettings:
val XprintSuspension: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprint-suspension", "Show when code is suspended until macros are compiled.")
val Xprompt: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprompt", "Display a prompt after each error (debugging option).")
val XreplDisableDisplay: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xrepl-disable-display", "Do not display definitions in REPL.")
val XreplDisableBytecodeInstrumentation: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xrepl-disable-bytecode-instrumentation", "Disable bytecode instrumentation for interrupt handling in REPL.")
val XverifySignatures: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xverify-signatures", "Verify generic signatures in generated bytecode.")
val XignoreScala2Macros: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xignore-scala2-macros", "Ignore errors when compiling code that calls Scala2 macros, these will fail at runtime.")
val XimportSuggestionTimeout: Setting[Int] = IntSetting(AdvancedSetting, "Ximport-suggestion-timeout", "Timeout (in ms) for searching for import suggestions when errors are reported.", 8000)
Expand Down
64 changes: 61 additions & 3 deletions compiler/src/dotty/tools/repl/AbstractFileClassLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ package repl
import scala.language.unsafeNulls

import io.AbstractFile
import dotty.tools.repl.ReplBytecodeInstrumentation

import java.net.{URL, URLConnection, URLStreamHandler}
import java.util.Collections

class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) extends ClassLoader(parent):
class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader, instrumentBytecode: Boolean = true) extends ClassLoader(parent):
private def findAbstractFile(name: String) = root.lookupPath(name.split('/').toIndexedSeq, directory = false)

// on JDK 20 the URL constructor we're using is deprecated,
Expand Down Expand Up @@ -53,9 +54,66 @@ class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) exten
if (file == null) {
throw new ClassNotFoundException(name)
}
val bytes = file.toByteArray
val originalBytes = file.toByteArray

// Instrument bytecode for everything except StopRepl itself to avoid infinite recursion
val bytes =
if !instrumentBytecode || name == "dotty.tools.repl.StopRepl" then originalBytes
else ReplBytecodeInstrumentation.instrument(originalBytes)

defineClass(name, bytes, 0, bytes.length)
}

override def loadClass(name: String): Class[?] = try findClass(name) catch case _: ClassNotFoundException => super.loadClass(name)
private def tryInstrumentLibraryClass(name: String): Class[?] =
try
val resourceName = name.replace('.', '/') + ".class"
getParent.getResourceAsStream(resourceName) match{
case null => super.loadClass(resourceName)
case is =>
try
val bytes = is.readAllBytes()
val instrumentedBytes =
if instrumentBytecode then ReplBytecodeInstrumentation.instrument(bytes)
else bytes
defineClass(name, instrumentedBytes, 0, instrumentedBytes.length)
finally is.close()
}
catch
case ex: Exception => super.loadClass(name)

override def loadClass(name: String): Class[?] =
if !instrumentBytecode then
return super.loadClass(name)

// Check if already loaded
val loaded = findLoadedClass(name)
if loaded != null then
return loaded

// Don't instrument JDK classes or StopRepl
name match{
case s"java.$_" => super.loadClass(name)
case s"javax.$_" => super.loadClass(name)
case s"sun.$_" => super.loadClass(name)
case s"jdk.$_" => super.loadClass(name)
case "dotty.tools.repl.StopRepl" =>
// Load StopRepl from parent but ensure each classloader gets its own copy
val is = getParent.getResourceAsStream(name.replace('.', '/') + ".class")
if is != null then
try
val bytes = is.readAllBytes()
defineClass(name, bytes, 0, bytes.length)
finally
is.close()
else
// Can't get as resource, use the classloader that loaded this AbstractFileClassLoader
// class itself, which must have access to StopRepl
classOf[AbstractFileClassLoader].getClassLoader.loadClass(name)
case _ =>
try findClass(name)
catch case _: ClassNotFoundException =>
// Not in REPL output, try to load from parent and instrument it
tryInstrumentLibraryClass(name)
}

end AbstractFileClassLoader
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/repl/Rendering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None):
new java.net.URLClassLoader(compilerClasspath.toArray, baseClassLoader)
}

myClassLoader = new AbstractFileClassLoader(ctx.settings.outputDir.value, parent)
val instrumentBytecode = !ctx.settings.XreplDisableBytecodeInstrumentation.value
myClassLoader = new AbstractFileClassLoader(ctx.settings.outputDir.value, parent, instrumentBytecode)
myClassLoader
}

Expand Down
75 changes: 75 additions & 0 deletions compiler/src/dotty/tools/repl/ReplBytecodeInstrumentation.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package dotty.tools
package repl

import scala.language.unsafeNulls

import scala.tools.asm.*
import scala.tools.asm.Opcodes.*
import scala.tools.asm.tree.*
import scala.collection.JavaConverters.*
import java.util.concurrent.atomic.AtomicBoolean

object ReplBytecodeInstrumentation:
/** Instrument bytecode to add checks to throw an exception if the REPL command is cancelled
*/
def instrument(originalBytes: Array[Byte]): Array[Byte] =
try
val cr = new ClassReader(originalBytes)
val cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES)
val instrumenter = new InstrumentClassVisitor(cw)
cr.accept(instrumenter, ClassReader.EXPAND_FRAMES)
cw.toByteArray
catch
case ex: Exception => originalBytes

def setStopFlag(classLoader: ClassLoader, b: Boolean): Unit =
val cancelClassOpt =
try Some(classLoader.loadClass(classOf[dotty.tools.repl.StopRepl].getName))
catch{
case _: java.lang.ClassNotFoundException => None
}
for(cancelClass <- cancelClassOpt){
val setAllStopMethod = cancelClass.getDeclaredMethod("setStop", classOf[Boolean])
setAllStopMethod.invoke(null, b.asInstanceOf[AnyRef])
}

private class InstrumentClassVisitor(cv: ClassVisitor) extends ClassVisitor(ASM9, cv):

override def visitMethod(
access: Int,
name: String,
descriptor: String,
signature: String,
exceptions: Array[String]
): MethodVisitor =
new InstrumentMethodVisitor(super.visitMethod(access, name, descriptor, signature, exceptions))

/** MethodVisitor that inserts stop checks at backward branches */
private class InstrumentMethodVisitor(mv: MethodVisitor) extends MethodVisitor(ASM9, mv):
// Track labels we've seen to identify backward branches
private val seenLabels = scala.collection.mutable.Set[Label]()

def addStopCheck() = mv.visitMethodInsn(
INVOKESTATIC,
classOf[dotty.tools.repl.StopRepl].getName.replace('.', '/'),
"throwIfReplStopped",
"()V",
false
)

override def visitCode(): Unit =
super.visitCode()
// Insert throwIfReplStopped() call at the start of the method
// to allow breaking out of deeply recursive methods like fib(99)
addStopCheck()

override def visitLabel(label: Label): Unit =
seenLabels.add(label)
super.visitLabel(label)

override def visitJumpInsn(opcode: Int, label: Label): Unit =
// Add throwIfReplStopped if this is a backward branch (jumping to a label we've already seen)
if seenLabels.contains(label) then addStopCheck()
super.visitJumpInsn(opcode, label)

end ReplBytecodeInstrumentation
13 changes: 11 additions & 2 deletions compiler/src/dotty/tools/repl/ReplDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import dotty.tools.dotc.{CompilationUnit, Driver}
import dotty.tools.dotc.config.CompilerCommand
import dotty.tools.io.*
import dotty.tools.repl.Rendering.showUser
import dotty.tools.repl.ReplBytecodeInstrumentation
import dotty.tools.runner.ScalaClassLoader.*
import org.jline.reader.*

Expand Down Expand Up @@ -228,13 +229,20 @@ class ReplDriver(settings: Array[String],
// Set up interrupt handler for command execution
var firstCtrlCEntered = false
val thread = Thread.currentThread()

// Clear the stop flag before executing new code
ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), false)

val previousSignalHandler = terminal.handle(
org.jline.terminal.Terminal.Signal.INT,
(sig: org.jline.terminal.Terminal.Signal) => {
if (!firstCtrlCEntered) {
firstCtrlCEntered = true
// Set the stop flag to trigger throwIfReplStopped() in instrumented code
ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), true)
// Also interrupt the thread as a fallback for non-instrumented code
thread.interrupt()
out.println("\nInterrupting running thread, Ctrl-C again to terminate the REPL Process")
out.println("\nInterrupting running thread")
} else {
out.println("\nTerminating REPL Process...")
System.exit(130) // Standard exit code for SIGINT
Expand Down Expand Up @@ -591,8 +599,9 @@ class ReplDriver(settings: Array[String],
val prevClassLoader = rendering.classLoader()
val jarClassLoader = fromURLsParallelCapable(
jarClassPath.asURLs, prevClassLoader)
val instrumentBytecode = !ctx.settings.XreplDisableBytecodeInstrumentation.value
rendering.myClassLoader = new AbstractFileClassLoader(
prevOutputDir, jarClassLoader)
prevOutputDir, jarClassLoader, instrumentBytecode)

out.println(s"Added '$path' to classpath.")
} catch {
Expand Down
18 changes: 18 additions & 0 deletions compiler/src/dotty/tools/repl/StopRepl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package dotty.tools.repl

import scala.annotation.static

class StopRepl

object StopRepl {
// Needs to be volatile, otherwise changes to this may not get seen by other threads
// for arbitrarily long periods of time (minutes!)
@static @volatile private var stop: Boolean = false

@static def setStop(n: Boolean): Unit = { stop = n }

/** Check if execution should stop, and throw ThreadDeath if so */
@static def throwIfReplStopped(): Unit = {
if (stop) throw new ThreadDeath()
}
}
11 changes: 10 additions & 1 deletion compiler/test/dotty/tools/vulpix/TestConfiguration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,16 @@ object TestConfiguration {

val commonOptions = Array("-indent") ++ checkOptions ++ noCheckOptions ++ yCheckOptions
val noYcheckCommonOptions = Array("-indent") ++ checkOptions ++ noCheckOptions
val defaultOptions = TestFlags(basicClasspath, commonOptions) `and` "-Yno-stdlib-patches"

// Conditionally add -Xrepl-disable-bytecode-instrumentation when running in CI environments
// that have classloader issues with the new bytecode instrumentation
private val replBytecodeInstrumentationFlag =
if sys.env.isDefinedAt("DOTTY_DISABLE_REPL_BYTECODE_INSTRUMENTATION") then
Array("-Xrepl-disable-bytecode-instrumentation")
else
Array.empty[String]

val defaultOptions = TestFlags(basicClasspath, commonOptions ++ replBytecodeInstrumentationFlag) `and` "-Yno-stdlib-patches"
val noYcheckOptions = TestFlags(basicClasspath, noYcheckCommonOptions)
val bestEffortBaselineOptions = TestFlags(basicClasspath, noCheckOptions)
val unindentOptions = TestFlags(basicClasspath, Array("-no-indent") ++ checkOptions ++ noCheckOptions ++ yCheckOptions)
Expand Down
Loading