Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion malsim/envs/malsim_vectorized_obs_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def reset(
# Reset observation and action mask for agents
self._agent_observations[agent.name] = self._create_blank_observation()
self._agent_infos[agent.name] = self.create_action_mask(agent)
pre_enabled_nodes |= agent.performed_nodes
pre_enabled_nodes |= set(agent.performed_nodes)

self._update_observations(pre_enabled_nodes, set())

Expand Down
26 changes: 14 additions & 12 deletions malsim/mal_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class MalSimAgentState:
action_surface: frozenset[AttackGraphNode]

# Contains all nodes that this agent has performed successfully
performed_nodes: frozenset[AttackGraphNode]
performed_nodes: tuple[AttackGraphNode, ...]

# Contains the nodes performed successfully in the last step
step_performed_nodes: frozenset[AttackGraphNode]
Expand Down Expand Up @@ -652,10 +652,10 @@ def _create_attacker_state(
else None
)
compromised_nodes = (
entry_points
if self.sim_settings.compromise_entrypoints_at_start else frozenset()
tuple(entry_points)
if self.sim_settings.compromise_entrypoints_at_start else tuple()
)
attack_surface = self._get_attack_surface(compromised_nodes)
attack_surface = self._get_attack_surface(set(compromised_nodes))

if not self.sim_settings.compromise_entrypoints_at_start:
# If entrypoints not compromised at start,
Expand All @@ -671,7 +671,7 @@ def _create_attacker_state(
action_surface=frozenset(attack_surface),
step_action_surface_additions=frozenset(attack_surface),
step_action_surface_removals=frozenset(),
step_performed_nodes=compromised_nodes,
step_performed_nodes=frozenset(compromised_nodes),
step_unviable_nodes=frozenset(),
step_attempted_nodes=frozenset(),
num_attempts=MappingProxyType(
Expand All @@ -695,7 +695,7 @@ def _update_attacker_state(
# Find what nodes attacker can reach this step
action_surface_additions = (
self._get_attack_surface(
attacker_state.performed_nodes | step_agent_compromised_nodes,
set(attacker_state.performed_nodes) | step_agent_compromised_nodes,
from_nodes=step_agent_compromised_nodes,
)
- attacker_state.action_surface
Expand All @@ -718,7 +718,7 @@ def _update_attacker_state(
attacker_state.name,
sim=self,
performed_nodes=(
attacker_state.performed_nodes | step_agent_compromised_nodes
attacker_state.performed_nodes + tuple(step_agent_compromised_nodes)
),
action_surface=new_action_surface,
step_action_surface_additions=action_surface_additions,
Expand Down Expand Up @@ -762,15 +762,15 @@ def _create_defender_state(self, name: str) -> MalSimDefenderState:

compromised_steps: set[AttackGraphNode] = set()
for attacker_state in self._get_attacker_agents():
compromised_steps |= attacker_state.performed_nodes
compromised_steps |= set(attacker_state.performed_nodes)

defense_surface = self._get_defense_surface()
step_observed_nodes = self._defender_observed_nodes(compromised_steps)

defender_state = MalSimDefenderState(
name,
sim=self,
performed_nodes=frozenset(self._enabled_defenses),
performed_nodes=tuple(self._enabled_defenses),
compromised_nodes=frozenset(compromised_steps),
step_compromised_nodes=frozenset(compromised_steps),
observed_nodes=frozenset(step_observed_nodes),
Expand Down Expand Up @@ -822,7 +822,9 @@ def _update_defender_state(
updated_defender_state = MalSimDefenderState(
defender_state.name,
sim=self,
performed_nodes=(defender_state.performed_nodes | step_enabled_defenses),
performed_nodes=(
defender_state.performed_nodes + tuple(step_enabled_defenses)
),
compromised_nodes=frozenset(
defender_state.compromised_nodes | step_compromised_nodes
),
Expand Down Expand Up @@ -998,7 +1000,7 @@ def _attacker_step(
'comes from the agents action surface.'
)

traversable = self.node_is_traversable(agent.performed_nodes, node)
traversable = self.node_is_traversable(set(agent.performed_nodes), node)
if node in agent.entry_points:
# Entrypoints are traversable as long as they are viable
traversable = self.node_is_viable(node)
Expand Down Expand Up @@ -1161,7 +1163,7 @@ def _attacker_is_terminated(attacker_state: MalSimAttackerState) -> bool:
if attacker_state.goals:
# Attacker is terminated if it has goals and all goals are met
return (
attacker_state.goals & attacker_state.performed_nodes
attacker_state.goals & set(attacker_state.performed_nodes)
== attacker_state.goals
)
# Otherwise not terminated
Expand Down
45 changes: 23 additions & 22 deletions tests/test_mal_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,17 +325,23 @@ def test_is_traversable(corelang_lang_graph: LanguageGraph, model: Model) -> Non

if node in children_of_reached_nodes:
if node.type == 'and':
if not sim.node_is_traversable(attacker_state.performed_nodes, node):
if not sim.node_is_traversable(
set(attacker_state.performed_nodes), node
):
assert not all(
p in attacker_state.performed_nodes
for p in node.parents
if p.type in ('or', 'and')
) or not sim.node_is_viable(node)
if node.type == 'or':
if not sim.node_is_traversable(attacker_state.performed_nodes, node):
if not sim.node_is_traversable(
set(attacker_state.performed_nodes), node
):
assert not sim.node_is_viable(node)
else:
assert not sim.node_is_traversable(attacker_state.performed_nodes, node)
assert not sim.node_is_traversable(
set(attacker_state.performed_nodes), node
)


def test_not_initial_compromise_entrypoints(
Expand All @@ -345,21 +351,19 @@ def test_not_initial_compromise_entrypoints(
entry_point = get_node(attack_graph, 'OS App:fullAccess')
sim = MalSimulator(
attack_graph,
sim_settings=MalSimulatorSettings(
compromise_entrypoints_at_start=False
)
sim_settings=MalSimulatorSettings(compromise_entrypoints_at_start=False),
)
attacker_name = 'Test Attacker'
sim.register_attacker(attacker_name, {entry_point})
attacker_state = sim.reset()[attacker_name]

# No performed nodes, action surface is only the entrypoint
assert attacker_state.performed_nodes == set()
assert attacker_state.performed_nodes == tuple()
assert attacker_state.action_surface == {entry_point}

# Step through entrypoint adds it to performed nodes and extends the action surface
attacker_state = sim.step({attacker_name: [entry_point]})[attacker_name]
assert attacker_state.performed_nodes == {entry_point}
assert attacker_state.performed_nodes == (entry_point,)
assert attacker_state.action_surface == {n for n in entry_point.children}


Expand All @@ -369,9 +373,7 @@ def test_not_initial_compromise_entrypoints_unviable_step(
attack_graph = AttackGraph(corelang_lang_graph, model)
sim = MalSimulator(
attack_graph,
sim_settings=MalSimulatorSettings(
compromise_entrypoints_at_start=False
)
sim_settings=MalSimulatorSettings(compromise_entrypoints_at_start=False),
)
attacker_name = 'Test Attacker'
defender_name = 'Test Defender'
Expand All @@ -380,12 +382,11 @@ def test_not_initial_compromise_entrypoints_unviable_step(
attacker_state = sim.reset()[attacker_name]

# Step should not succeed if defender defended the entrypoint
attacker_state = sim.step({
attacker_name: ['OS App:fullAccess'],
defender_name: ['OS App:notPresent']
})[attacker_name]
assert attacker_state.performed_nodes == set()
assert attacker_state.action_surface == set()
attacker_state = sim.step(
{attacker_name: ['OS App:fullAccess'], defender_name: ['OS App:notPresent']}
)[attacker_name]
assert attacker_state.performed_nodes == tuple()
assert attacker_state.action_surface == frozenset()


def test_is_compromised(corelang_lang_graph: LanguageGraph, model: Model) -> None:
Expand Down Expand Up @@ -663,8 +664,8 @@ def test_agent_state_views_simple(
assert asv.step_performed_nodes == {entry_point}
assert dsv.step_performed_nodes == pre_enabled_defenses

assert asv.performed_nodes == {entry_point}
assert dsv.performed_nodes == pre_enabled_defenses
assert asv.performed_nodes == (entry_point,)
assert dsv.performed_nodes == tuple(pre_enabled_defenses)

assert len(asv.action_surface) == 6
assert set(n.full_name for n in dsv.action_surface) == {
Expand Down Expand Up @@ -719,8 +720,8 @@ def test_agent_state_views_simple(
assert asv.step_performed_nodes == {os_app_attempt_deny}
assert dsv.step_performed_nodes == {program2_not_present}

assert asv.performed_nodes == {os_app_attempt_deny, entry_point}
assert dsv.performed_nodes == pre_enabled_defenses | {program2_not_present}
assert asv.performed_nodes == (entry_point, os_app_attempt_deny)
assert dsv.performed_nodes == tuple(pre_enabled_defenses) + (program2_not_present,)

assert asv.step_action_surface_additions == {os_app_success_deny}
assert dsv.step_action_surface_additions == set()
Expand Down Expand Up @@ -948,7 +949,7 @@ def test_simulator_no_fpr_fnr() -> None:
assert isinstance(attacker_state, MalSimAttackerState)

# No false positives or negatives
assert defender_state.compromised_nodes == attacker_state.performed_nodes
assert defender_state.compromised_nodes == frozenset(attacker_state.performed_nodes)


def test_simulator_ttcs() -> None:
Expand Down