"""Trace recording, filtering, and formatting for comprehensible race condition errors.
When frontrun finds a race condition, the raw counterexample is a list of thread
indices — one per bytecode instruction. This module transforms that into a
human-readable "story" of which source lines executed in which order.
The pipeline:
1. **Record** a TraceEvent at each opcode during the failing run.
2. **Filter** to events that touch shared state (LOAD_ATTR/STORE_ATTR, etc.)
3. **Deduplicate** consecutive events from the same thread on the same source line.
4. **Classify** the conflict pattern (lost update, order violation, etc.)
5. **Format** as an interleaved source-line trace.
"""
from __future__ import annotations
import dis
import linecache
import sys
from dataclasses import dataclass
from typing import Any
from frontrun._cooperative import real_lock
_PY_VERSION = sys.version_info[:2]
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
[docs]
@dataclass(slots=True)
class TraceEvent:
"""A single recorded event from the trace."""
step_index: int
thread_id: int
filename: str
lineno: int
function_name: str
opcode: str
access_type: str | None = None # "read", "write", or None
attr_name: str | None = None # e.g. "value", "balance"
obj_type_name: str | None = None # e.g. "Counter", "BankAccount"
call_chain: list[str] | None = None # e.g. ["DB.dict", "do_incrs"]
[docs]
@dataclass(slots=True)
class SourceLineEvent:
"""A deduplicated, source-level event for display."""
thread_id: int
filename: str
lineno: int
function_name: str
source_line: str
access_type: str | None = None # "read", "write", or None (if mixed, "read+write")
attr_name: str | None = None
obj_type_name: str | None = None
call_chain: list[str] | None = None
# ---------------------------------------------------------------------------
# Frame introspection helpers
# ---------------------------------------------------------------------------
def qualified_name(frame: Any) -> str:
    """Get a qualified function name from a frame (e.g. ``DB.dict``)."""
    code = frame.f_code
    # Python 3.11+ exposes the dotted name directly on the code object.
    qualname = getattr(code, "co_qualname", None)
    if qualname is not None:
        return qualname
    # Python 3.10 fallback: infer the owning class from a bound ``self``.
    plain = code.co_name
    try:
        receiver = frame.f_locals.get("self")
        if receiver is not None:
            return f"{type(receiver).__name__}.{plain}"
    except Exception:
        # Best-effort only: frame introspection must never break tracing.
        pass
    return plain
def build_call_chain(frame: Any, *, filter_fn: Any, max_depth: int = 3) -> list[str] | None:
    """Walk user-code frames from *frame* upward, returning qualified names.

    ``filter_fn(filename) -> bool`` selects which frames count as user code
    (typically :func:`frontrun._tracing.should_trace_file`).
    Returns ``None`` when the chain would be empty.
    """
    names: list[str] = []
    current: Any = frame
    while current is not None:
        if len(names) >= max_depth:
            break
        if filter_fn(current.f_code.co_filename):
            names.append(qualified_name(current))
        current = current.f_back
    return names or None
# ---------------------------------------------------------------------------
# Trace recorder
# ---------------------------------------------------------------------------
class TraceRecorder:
    """Accumulates :class:`TraceEvent` objects during a single execution.

    Thread-safe: multiple threads may call :meth:`record` concurrently.
    The step counter and the ``events`` list are only ever touched inside
    ``self._lock``, and within a *single* acquisition per event, so
    ``events`` is always ordered by ``step_index``.
    """

    __slots__ = ("events", "_step", "_lock", "enabled")

    def __init__(self, *, enabled: bool = True) -> None:
        self.events: list[TraceEvent] = []  # ordered by step_index
        self._step = 0  # next step number to hand out
        self._lock = real_lock()  # real (non-cooperative) lock
        self.enabled = enabled  # cheap global on/off switch

    def record(
        self,
        thread_id: int,
        frame: Any,
        opcode: str | None = None,
        access_type: str | None = None,
        attr_name: str | None = None,
        obj: Any = None,
        obj_type_name: str | None = None,
        call_chain: list[str] | None = None,
    ) -> None:
        """Record one trace event from a frame object.

        ``obj_type_name`` wins over ``obj``; when only ``obj`` is given its
        type name is derived here.
        """
        if not self.enabled:
            return
        code = frame.f_code
        if obj_type_name is None and obj is not None:
            obj_type_name = type(obj).__name__
        # Fix: take the step number AND append under ONE lock acquisition.
        # The previous implementation used two separate ``with self._lock``
        # sections, so another thread could append between them and leave
        # ``events`` out of step_index order.
        with self._lock:
            step = self._step
            self._step += 1
            self.events.append(
                TraceEvent(
                    step_index=step,
                    thread_id=thread_id,
                    filename=code.co_filename,
                    lineno=frame.f_lineno,
                    function_name=code.co_name,
                    opcode=opcode or "",
                    access_type=access_type,
                    attr_name=attr_name,
                    obj_type_name=obj_type_name,
                    call_chain=call_chain,
                )
            )

    def record_io(
        self,
        thread_id: int,
        resource_id: str,
        kind: str,
    ) -> None:
        """Record an I/O event that has no Python frame (e.g. C-level socket I/O).

        The resource id is stored in ``attr_name`` and the I/O kind in
        ``access_type``; ``obj_type_name`` is the sentinel ``"IO"``.
        """
        if not self.enabled:
            return
        # Same single-acquisition pattern as record(): counter bump and
        # append must be atomic relative to other recording threads.
        with self._lock:
            step = self._step
            self._step += 1
            self.events.append(
                TraceEvent(
                    step_index=step,
                    thread_id=thread_id,
                    filename="<C extension>",
                    lineno=0,
                    function_name="",
                    opcode="IO",
                    access_type=kind,
                    attr_name=resource_id,
                    obj_type_name="IO",
                )
            )

    def record_from_opcode(
        self,
        thread_id: int,
        frame: Any,
    ) -> None:
        """Record an event using the frame's current instruction.

        Used by the bytecode explorer, which doesn't do shadow-stack
        analysis. We inspect the instruction to extract access info;
        uninteresting opcodes are ignored entirely.
        """
        if not self.enabled:
            return
        code = frame.f_code
        offset = frame.f_lasti
        instr = _get_instruction(code, offset)
        if instr is None:
            return
        op = instr.opname
        access_type: str | None = None
        attr_name: str | None = None
        obj: Any = None  # never resolvable here; kept for record()'s signature
        if op == "LOAD_ATTR":
            access_type = "read"
            attr_name = instr.argval
        elif op == "STORE_ATTR":
            access_type = "write"
            attr_name = instr.argval
        elif op == "DELETE_ATTR":
            access_type = "write"
            attr_name = instr.argval
        elif op in ("BINARY_SUBSCR", "STORE_SUBSCR", "DELETE_SUBSCR"):
            access_type = "write" if op.startswith(("STORE", "DELETE")) else "read"
        elif op == "BINARY_OP":
            # 3.12+ folds subscripting into BINARY_OP; detect it from argrepr.
            argrepr = instr.argrepr
            if argrepr and ("[" in argrepr or "NB_SUBSCR" in argrepr.upper()):
                access_type = "read"
        else:
            # Not an interesting opcode for shared-state analysis.
            return
        self.record(
            thread_id=thread_id,
            frame=frame,
            opcode=op,
            access_type=access_type,
            attr_name=attr_name,
            obj=obj,
        )
# Lightweight instruction cache (separate from DPOR's to avoid cross-module coupling)
# NOTE(review): keyed by id(code) and never evicted; id() values may be
# reused once a code object is garbage-collected. Looks safe while traced
# code objects stay alive for the run — confirm before reusing elsewhere.
_instr_cache: dict[int, dict[int, dis.Instruction]] = {}


def _get_instruction(code: Any, offset: int) -> dis.Instruction | None:
    """Return the instruction of *code* at bytecode *offset*, or ``None``.

    The full offset->instruction mapping is built once per code object and
    memoized in ``_instr_cache``.
    """
    code_id = id(code)
    mapping = _instr_cache.get(code_id)
    if mapping is None:
        # EAFP instead of a version gate: ``show_caches`` exists on 3.11+;
        # older interpreters raise TypeError for the unknown kwarg.
        try:
            instructions = dis.get_instructions(code, show_caches=False)
        except TypeError:
            instructions = dis.get_instructions(code)
        mapping = {inst.offset: inst for inst in instructions}
        _instr_cache[code_id] = mapping
    return mapping.get(offset)
# ---------------------------------------------------------------------------
# Filtering and deduplication
# ---------------------------------------------------------------------------
def _is_shared_access(ev: TraceEvent) -> bool:
    """True when the event touched shared state (it carries an access type)."""
    return ev.access_type is not None


def filter_to_shared_accesses(events: list[TraceEvent]) -> list[TraceEvent]:
    """Keep only events that access shared mutable state."""
    return list(filter(_is_shared_access, events))
def deduplicate_to_source_lines(events: list[TraceEvent]) -> list[SourceLineEvent]:
    """Collapse consecutive events from the same thread+line into one SourceLineEvent.

    When multiple opcodes on the same source line produce events (e.g.,
    LOAD_ATTR then STORE_ATTR for ``self.value += 1``), merge them into
    a single entry with a combined access_type — but only when they
    access the same (obj_type, attr_name) key. Events with different
    keys on the same line get separate entries so that filtering can
    distinguish them later (e.g. an attribute read vs an I/O event).
    """
    if not events:
        return []
    result: list[SourceLineEvent] = []
    # Rolling state describing the *previous* event; a new event continues
    # the current run only if thread, line, file AND key all match.
    prev_tid = -1
    prev_lineno = -1
    prev_filename = ""
    prev_key: tuple[str | None, str | None] = (None, None)
    for ev in events:
        same_line = ev.thread_id == prev_tid and ev.lineno == prev_lineno and ev.filename == prev_filename
        ev_key = (ev.obj_type_name, ev.attr_name)
        same_key = ev_key == prev_key
        if same_line and same_key and result:
            # Continuation of the current run: fold into the last entry.
            last = result[-1]
            # Merge access types
            if last.access_type != ev.access_type and ev.access_type is not None:
                if last.access_type is None:
                    last.access_type = ev.access_type
                elif last.access_type == "read" and ev.access_type == "write":
                    # Read-then-write on one line (e.g. ``+=``) -> combined type.
                    last.access_type = "read+write"
                elif last.access_type == "write" and ev.access_type == "read":
                    last.access_type = "read+write"
            # Prefer more specific attr info
            if ev.attr_name is not None and last.attr_name is None:
                last.attr_name = ev.attr_name
            if ev.obj_type_name is not None and last.obj_type_name is None:
                last.obj_type_name = ev.obj_type_name
        else:
            # New run: look up the source text once per displayed line.
            source_line = linecache.getline(ev.filename, ev.lineno).strip()
            result.append(
                SourceLineEvent(
                    thread_id=ev.thread_id,
                    filename=ev.filename,
                    lineno=ev.lineno,
                    function_name=ev.function_name,
                    source_line=source_line,
                    access_type=ev.access_type,
                    attr_name=ev.attr_name,
                    obj_type_name=ev.obj_type_name,
                    call_chain=ev.call_chain,
                )
            )
        prev_tid = ev.thread_id
        prev_lineno = ev.lineno
        prev_filename = ev.filename
        prev_key = ev_key
    return result
# ---------------------------------------------------------------------------
# Conflict pattern classification
# ---------------------------------------------------------------------------
@dataclass
class ConflictInfo:
    """Description of the conflict pattern found in the trace."""

    # One of: "lost_update", "stale_read", "write_write", "order_violation", "unknown"
    pattern: str
    # One-line human-readable explanation for the error report.
    summary: str
    # Attribute involved, when it could be identified.
    attr_name: str | None = None
def classify_conflict(events: list[SourceLineEvent]) -> ConflictInfo:
    """Examine a filtered, deduplicated trace and classify the conflict type.

    Looks for classic patterns:
    - Lost update: R_a R_b W_a W_b (or R_a R_b W_b W_a)
    - Write-write: W_a W_b on same attribute without intervening sync

    Python-level conflicts take priority over I/O-level ones: an I/O match
    is held back in ``io_fallback`` and only returned if no non-I/O
    attribute produced a classification.
    """
    if not events:
        return ConflictInfo(pattern="unknown", summary="No shared-state accesses recorded.")
    # Track per-attribute access sequences across threads
    # Group by (obj_type, attr_name), look for cross-thread read-before-write patterns
    attr_accesses: dict[str, list[tuple[int, str]]] = {}  # attr -> [(thread_id, access_type), ...]
    io_attrs: set[str] = set()  # attributes that come from I/O events
    for ev in events:
        key = ev.attr_name or "(unknown)"
        attr_accesses.setdefault(key, []).append((ev.thread_id, ev.access_type or "unknown"))
        if ev.obj_type_name == "IO":
            io_attrs.add(key)
    # Process non-I/O attributes first so Python-level conflicts take
    # priority over raw socket-level ones in the summary line.
    io_fallback: ConflictInfo | None = None
    for attr, accesses in attr_accesses.items():
        threads_involved = sorted({tid for tid, _ in accesses})
        if len(threads_involved) < 2:
            # Single-thread attribute: cannot be a race by itself.
            continue
        is_io = attr in io_attrs
        # Check for lost-update pattern: two threads both read before either writes
        # Pattern: R_a ... R_b ... W_a ... W_b (or W_b before W_a)
        first_read: dict[int, int] = {}  # thread -> index of first read
        first_write: dict[int, int] = {}  # thread -> index of first write
        for i, (tid, atype) in enumerate(accesses):
            # "read+write" counts as both a read and a write.
            if atype in ("read", "read+write") and tid not in first_read:
                first_read[tid] = i
            if atype in ("write", "read+write") and tid not in first_write:
                first_write[tid] = i
        # Look for pairs where both threads read before either writes
        for t_a in threads_involved:
            for t_b in threads_involved:
                if t_a >= t_b:
                    # Each unordered pair is examined once (t_a < t_b).
                    continue
                r_a = first_read.get(t_a)
                r_b = first_read.get(t_b)
                w_a = first_write.get(t_a)
                w_b = first_write.get(t_b)
                if r_a is not None and r_b is not None and w_a is not None and w_b is not None:
                    # Both read before both write?
                    writes_start = min(w_a, w_b)
                    if r_a < writes_start and r_b < writes_start:
                        obj_desc = attr
                        if is_io:
                            # Keep the first I/O match; only used if nothing
                            # Python-level classifies first.
                            io_fallback = io_fallback or ConflictInfo(
                                pattern="lost_update",
                                summary=(
                                    f"Lost update via database I/O: threads {t_a} and {t_b} "
                                    f"both queried {obj_desc} before either committed."
                                ),
                                attr_name=attr,
                            )
                        else:
                            return ConflictInfo(
                                pattern="lost_update",
                                summary=(
                                    f"Lost update: threads {t_a} and {t_b} both read "
                                    f"{obj_desc} before either wrote it back."
                                ),
                                attr_name=attr,
                            )
        # Check for write-write without reads (simple overwrite)
        for t_a in threads_involved:
            for t_b in threads_involved:
                if t_a >= t_b:
                    continue
                w_a = first_write.get(t_a)
                w_b = first_write.get(t_b)
                if w_a is not None and w_b is not None:
                    if is_io:
                        io_fallback = io_fallback or ConflictInfo(
                            pattern="write_write",
                            summary=(f"Concurrent database I/O: threads {t_a} and {t_b} both sent queries to {attr}."),
                            attr_name=attr,
                        )
                    else:
                        return ConflictInfo(
                            pattern="write_write",
                            summary=f"Write-write conflict: threads {t_a} and {t_b} both wrote to {attr}.",
                            attr_name=attr,
                        )
    if io_fallback is not None:
        return io_fallback
    # Fallback: we recorded shared accesses but couldn't classify the pattern
    all_threads = sorted({ev.thread_id for ev in events})
    return ConflictInfo(
        pattern="unknown",
        summary=f"Race condition involving threads {', '.join(map(str, all_threads))}.",
    )
# ---------------------------------------------------------------------------
# Trace condensation
# ---------------------------------------------------------------------------
[docs]
@dataclass(slots=True)
class CollapsedRun:
"""Placeholder for a collapsed sequence of events from one thread."""
count: int
thread_id: int
def _find_conflicting_keys(events: list[SourceLineEvent]) -> set[tuple[str | None, str | None]]:
"""Find (obj_type, attr_name) pairs accessed by multiple threads with at least one write."""
key_threads: dict[tuple[str | None, str | None], set[int]] = {}
key_has_write: set[tuple[str | None, str | None]] = set()
for ev in events:
key = (ev.obj_type_name, ev.attr_name)
key_threads.setdefault(key, set()).add(ev.thread_id)
if ev.access_type in ("write", "read+write"):
key_has_write.add(key)
return {key for key, tids in key_threads.items() if len(tids) > 1 and key in key_has_write}
def _collapse_runs(lines: list[SourceLineEvent], *, max_lines: int) -> list[SourceLineEvent | CollapsedRun]:
"""Collapse consecutive same-thread events, keeping first and last of each run."""
if not lines:
return []
# Group into runs of consecutive events from the same thread
runs: list[tuple[int, list[SourceLineEvent]]] = []
current_tid = -1
current_run: list[SourceLineEvent] = []
for ev in lines:
if ev.thread_id != current_tid:
if current_run:
runs.append((current_tid, current_run))
current_tid = ev.thread_id
current_run = [ev]
else:
current_run.append(ev)
if current_run:
runs.append((current_tid, current_run))
result: list[SourceLineEvent | CollapsedRun] = []
for tid, run in runs:
if len(run) <= 3:
result.extend(run)
else:
result.append(run[0])
result.append(CollapsedRun(count=len(run) - 2, thread_id=tid))
result.append(run[-1])
# Final cap: if still too long, take first half + last half
if len(result) > max_lines:
half = max_lines // 2
omitted = len(result) - max_lines
result = result[:half] + [CollapsedRun(count=omitted, thread_id=-1)] + result[-half:]
return result
def _merge_consecutive(events: list[SourceLineEvent]) -> list[SourceLineEvent]:
    """Merge consecutive same-thread, same-line events after filtering.

    Conflict-key filtering can remove events that used to sit between two
    entries for the same thread and source line; this pass folds such
    now-adjacent neighbours together, mirroring the merge performed by
    :func:`deduplicate_to_source_lines`.
    """
    if not events:
        return []
    merged: list[SourceLineEvent] = [events[0]]
    for ev in events[1:]:
        tail = merged[-1]
        adjacent = (
            ev.thread_id == tail.thread_id
            and ev.lineno == tail.lineno
            and ev.filename == tail.filename
        )
        if not adjacent:
            merged.append(ev)
            continue
        # A read meeting a write (in either order) upgrades to "read+write";
        # any other combination is absorbed without changing the entry.
        if tail.access_type == "read" and ev.access_type in ("write", "read+write"):
            tail.access_type = "read+write"
        elif tail.access_type == "write" and ev.access_type in ("read", "read+write"):
            tail.access_type = "read+write"
    return merged
def condense_trace(lines: list[SourceLineEvent], *, max_lines: int = 30) -> list[SourceLineEvent | CollapsedRun]:
    """Condense a trace down to the essential interleaving.

    Strategy:
    1. Filter to events involved in cross-thread data conflicts (same
       attribute touched by 2+ threads with at least one write), then
       re-merge same-line neighbours exposed by the filtering.
    2. If the result still exceeds *max_lines*, collapse single-thread
       runs (keeping first/last of each).

    Returns a mixed list of :class:`SourceLineEvent` and :class:`CollapsedRun`
    placeholders for the formatter to render.
    """
    # Filtering always runs, regardless of length — it strips method
    # lookups, lock accesses, and other noise from the story.
    hot_keys = _find_conflicting_keys(lines)
    if hot_keys:
        relevant = [ev for ev in lines if (ev.obj_type_name, ev.attr_name) in hot_keys]
        if relevant:
            lines = _merge_consecutive(relevant)
    if len(lines) <= max_lines:
        return list(lines)
    # Still too long: fall back to run collapsing.
    return _collapse_runs(lines, max_lines=max_lines)
# ---------------------------------------------------------------------------
# Formatting
# ---------------------------------------------------------------------------
def _format_no_shared_accesses(events: list[TraceEvent], *, num_explored: int = 0) -> str:
    """Fallback message when no shared-state accesses were detected."""
    suffix = "(no shared-state accesses recorded).\n"
    if num_explored > 0:
        return f"Race condition found after {num_explored} interleavings {suffix}"
    return f"Race condition found {suffix}"
def _short_filename(path: str) -> str:
    """Convert an absolute path to a short display name (its basename)."""
    import os

    return os.path.basename(path)