"""Classes for day20."""
import itertools
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Flag
from graphviz import Digraph
[docs]
class Pulse(Flag):
"""Simple True/False enum."""
LOW = False
HIGH = True
def __str__(self) -> str:
"""Custom str to show ``low`` and ``high``."""
if self.name is not None:
return self.name.lower()
raise AssertionError("no valid value for this Pulse")
[docs]
@dataclass
class MappingModule:
"""map to a list of outputs."""
name: str
outputs: list[str]
[docs]
@dataclass
class PulseTarget:
"""A pulse(low/high) from src to dest."""
pulse: Pulse
src: str
target: str
def __str__(self) -> str:
"""Custom arrow string to match problem description."""
return f"{self.src} -{self.pulse}-> {self.target}"
[docs]
@dataclass
class BaseModule(ABC):
"""Abstract base module."""
name: str
outputs: list[str] = field(repr=False)
num_low: int = field(init=False, default=0)
num_high: int = field(init=False, default=0)
def __post_init__(self) -> None:
"""Stop users from constructing BaseModule."""
if self.__class__ == BaseModule:
raise AssertionError("Cannot instantiate abstract class")
[docs]
def arrow_color(self) -> str:
"""Return arrow color for graphviz."""
return "#000000"
[docs]
def handle_pulse(self, input: str, pulse: Pulse) -> list[PulseTarget]:
"""Keep track of lows/highs through all modules."""
if pulse == Pulse.LOW:
self.num_low += 1
else:
self.num_high += 1
return []
[docs]
def add_to_graph(self, dot: Digraph) -> None:
"""Adds edges only to the graph. inheritors need to handle their repr."""
attrs = {"color": self.arrow_color()}
for output in self.outputs:
dot.edge(self.name, output, **attrs)
[docs]
@abstractmethod
def is_initial_state(self) -> bool:
"""Returns if the module is in the initial state."""
raise AssertionError("Implement me")
[docs]
@dataclass
class FlipFlopModule(BaseModule):
"""If we receive HIGH, we are a sink (do nothing).
If we receive LOW, flip our current value and send it to everyone
"""
state: Pulse = Pulse.LOW
[docs]
def handle_pulse(self, input: str, pulse: Pulse) -> list[PulseTarget]:
"""Handle pulse by forwarding if we receive low."""
super().handle_pulse(input, pulse)
if pulse:
return []
# if we receive low, become the opposite and send it.
# if we receive high, do ont send, do not modify state.
self.state = Pulse(not self.state)
return [PulseTarget(self.state, self.name, target) for target in self.outputs]
[docs]
def add_to_graph(self, dot: Digraph) -> None:
"""Adds ourselves to a graphviz digraph."""
attrs = {"shape": "box", "style": "filled"}
if self.state == Pulse.LOW:
attrs["fillcolor"] = "#FF0000"
else:
attrs["fillcolor"] = "#0000FF"
dot.node(self.name, **attrs)
super().add_to_graph(dot)
[docs]
def is_initial_state(self) -> bool:
"""Returns true if we are in our initial state."""
return self.state == Pulse.LOW
[docs]
@dataclass
class ConjunctionModule(BaseModule):
"""Keeps track of all inputs.
Changes internal state, then sends high/low based on internal state
"""
inputs: dict[str, Pulse] = field(init=False, repr=False)
[docs]
def arrow_color(self) -> str:
"""Returns red."""
return "#FF0000"
[docs]
def handle_pulse(self, input: str, pulse: Pulse) -> list[PulseTarget]:
"""Store pulse, then send based on current state.
If all our values are high, we send low to all our outputs.
Otherwise, we send ``LOW`` to all our outputs.
"""
super().handle_pulse(input, pulse)
self.inputs[input] = pulse
if all(self.inputs.values()):
return [
PulseTarget(Pulse.LOW, self.name, target) for target in self.outputs
]
return [PulseTarget(Pulse.HIGH, self.name, target) for target in self.outputs]
[docs]
def add_to_graph(self, dot: Digraph) -> None:
"""Add this module to a GraphViz Digraph."""
count = self.current_count()
length = len(list(self.inputs.values()))
label = f"{self.name} {count}/{length}"
dot.node(self.name, label=label)
super().add_to_graph(dot)
[docs]
def current_count(self) -> int:
"""Returns current count of inputs that sent ``high``."""
return list(self.inputs.values()).count(Pulse.HIGH)
[docs]
def is_initial_state(self) -> bool:
"""Returns True if all our inputs are LOW."""
return self.current_count() == 0
[docs]
@dataclass
class BroadcastModule(BaseModule):
"""Broadcasts to all outputs."""
[docs]
def handle_pulse(self, input: str, pulse: Pulse) -> list[PulseTarget]:
"""Broadcasts to all outputs immediately."""
super().handle_pulse(input, pulse)
return [PulseTarget(pulse, self.name, target) for target in self.outputs]
[docs]
def add_to_graph(self, dot: Digraph) -> None:
"""Add node to graphviz digraph."""
dot.node(self.name)
super().add_to_graph(dot)
[docs]
def is_initial_state(self) -> bool:
"""Always true."""
return True
[docs]
@dataclass
class SinkModule(BaseModule):
"""Sink module, gets something but sends it no where else."""
[docs]
def handle_pulse(self, input: str, pulse: Pulse) -> list[PulseTarget]:
"""Always eats inputs and never sends onwards."""
super().handle_pulse(input, pulse)
return []
[docs]
def add_to_graph(self, dot: Digraph) -> None:
"""Adds this node to the graph."""
dot.node(self.name)
super().add_to_graph(dot)
[docs]
def is_initial_state(self) -> bool:
"""Always true."""
return True
[docs]
@dataclass
class ModuleGroups:
"""A group of modules for part2."""
head: BroadcastModule
loops: list[list[BaseModule]]
loop_tails: list[ConjunctionModule]
penultimate: ConjunctionModule
sink: SinkModule
all_nodes: list[BaseModule] = field(init=False)
def __post_init__(self) -> None:
"""Sets up our ``all_nodes`` property."""
all_nodes: list[BaseModule] = []
all_nodes.append(self.head)
all_nodes.extend(list(itertools.chain(*self.loops)))
all_nodes.extend(self.loop_tails)
all_nodes.append(self.penultimate)
all_nodes.append(self.sink)
self.all_nodes = all_nodes
[docs]
@dataclass
class LoopCounter:
"""Keeps track of loop lengths."""
target_loop_count: int = field(repr=False)
loop_lengths: dict[str, int] = field(default_factory=dict, init=False)
@property
def num_results(self) -> int:
"""Returns number of loop_lenghts submitted."""
return len(self.loop_lengths)
@property
def finished(self) -> bool:
"""Returns True if our loop_lengths are equal to our target."""
return self.num_results == self.target_loop_count
[docs]
def add_result(self, loop_name: str, value: int) -> None:
"""Adds a result to our loop_count.
If we already had a loop with that name, we ignore it.
"""
if loop_name not in self.loop_lengths:
self.loop_lengths[loop_name] = value