Skip to content

openg2g.controller

openg2g.controller.base

Abstract base class for controllers.

Controller

Bases: Generic[DCBackendT, GridBackendT], ABC

Interface for a control component in the G2G framework.

Controllers receive datacenter and grid state and produce control actions. Multiple controllers compose in order within the coordinator.

Source code in openg2g/controller/base.py
class Controller(Generic[DCBackendT, GridBackendT], ABC):
    """Interface for a control component in the G2G framework.

    Controllers receive datacenter and grid state and produce control actions.
    Multiple controllers compose in order within the coordinator.
    """

    # Compatibility declaration: which backend classes this controller can
    # drive. Filled in by __init_subclass__ from the subclass's generic
    # specialization; the base defaults accept any backend.
    _dc_types: tuple[type[DatacenterBackend], ...] = (DatacenterBackend,)
    _grid_types: tuple[type[GridBackend], ...] = (GridBackend,)

    def __init_subclass__(cls, **kwargs: object) -> None:
        super().__init_subclass__(**kwargs)
        dc_types: tuple[type[DatacenterBackend], ...] | None = None
        grid_types: tuple[type[GridBackend], ...] | None = None
        # __orig_bases__ preserves the parameterized form (Controller[X, Y])
        # that plain __bases__ erases to bare Controller.
        for base in getattr(cls, "__orig_bases__", ()):
            if get_origin(base) is Controller:
                args = get_args(base)
                if len(args) != 2:
                    raise TypeError(
                        f"{cls.__name__} must specialize Controller with two generic args: "
                        "Controller[DatacenterType, GridType]."
                    )
                dc_raw, grid_raw = args
                # _normalize_backend_type_arg is defined elsewhere in this
                # module; presumably it expands each generic arg into a tuple
                # of concrete classes — confirm against its definition.
                dc_norm = _normalize_backend_type_arg(dc_raw, required_base=DatacenterBackend)
                grid_norm = _normalize_backend_type_arg(grid_raw, required_base=GridBackend)
                # Keep only genuine backend subclasses from the normalized args.
                dc_types = tuple(t for t in dc_norm if issubclass(t, DatacenterBackend))
                grid_types = tuple(t for t in grid_norm if issubclass(t, GridBackend))
                break

        if dc_types is None or grid_types is None:
            # No direct Controller[...] specialization on this class: inherit
            # the compatibility tuples from an already-specialized parent
            # controller, if there is one.
            inherited = [b for b in cls.__bases__ if issubclass(b, Controller)]
            inherited = [b for b in inherited if b is not Controller]
            if inherited:
                parent = inherited[0]
                cls._dc_types = parent.compatible_datacenter_types()
                cls._grid_types = parent.compatible_grid_types()
                return
            raise TypeError(
                f"{cls.__name__} must explicitly specialize Controller generics as "
                "Controller[DatacenterType, GridType]."
            )

        cls._dc_types = dc_types
        cls._grid_types = grid_types

    @final
    @classmethod
    def compatible_datacenter_types(cls) -> tuple[type[DatacenterBackend], ...]:
        """Datacenter backend classes this controller class accepts."""
        return cls._dc_types

    @final
    @classmethod
    def compatible_grid_types(cls) -> tuple[type[GridBackend], ...]:
        """Grid backend classes this controller class accepts."""
        return cls._grid_types

    @final
    @classmethod
    def compatibility_signature(cls) -> str:
        """Human-readable summary, e.g. ``Controller[FooDC, BarGrid]``."""
        dc = " | ".join(t.__name__ for t in cls.compatible_datacenter_types())
        grid = " | ".join(t.__name__ for t in cls.compatible_grid_types())
        return f"Controller[{dc}, {grid}]"

    @property
    @abstractmethod
    def dt_s(self) -> Fraction:
        """Control interval as a Fraction (seconds)."""

    @abstractmethod
    def reset(self) -> None:
        """Reset simulation state to initial conditions.

        Called by the coordinator before each [`start`][..start]. Must
        clear all simulation state: dual variables, counters, cached
        matrices. Configuration (dt_s, fits, step sizes) is not
        affected.

        Abstract so every implementation explicitly enumerates its state.
        A forgotten field is a bug -- not clearing it silently corrupts
        the second run.
        """

    def start(self) -> None:
        """Acquire per-run resources.

        Called after [`reset`][..reset], before the simulation loop.
        No-op by default because most controllers have no resources to
        acquire.
        """

    def stop(self) -> None:
        """Release per-run resources. Simulation state is preserved.

        Called after the simulation loop in LIFO order. No-op by default.
        """

    @abstractmethod
    def step(
        self,
        clock: SimulationClock,
        datacenter: DCBackendT,
        grid: GridBackendT,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Compute control commands for this step. Return an empty list for no-op."""

dt_s abstractmethod property

Control interval as a Fraction (seconds).

reset() abstractmethod

Reset simulation state to initial conditions.

Called by the coordinator before each start. Must clear all simulation state: dual variables, counters, cached matrices. Configuration (dt_s, fits, step sizes) is not affected.

Abstract so every implementation explicitly enumerates its state. A forgotten field is a bug -- not clearing it silently corrupts the second run.

Source code in openg2g/controller/base.py
@abstractmethod
def reset(self) -> None:
    """Reset simulation state to initial conditions.

    Called by the coordinator before each [`start`][..start]. Must
    clear all simulation state: dual variables, counters, cached
    matrices. Configuration (dt_s, fits, step sizes) is not
    affected.

    Abstract so every implementation explicitly enumerates its state.
    A forgotten field is a bug -- not clearing it silently corrupts
    the second run.

    Returns nothing; implementations mutate the controller in place.
    """

start()

Acquire per-run resources.

Called after reset, before the simulation loop. No-op by default because most controllers have no resources to acquire.

Source code in openg2g/controller/base.py
def start(self) -> None:
    """Acquire per-run resources.

    Called after [`reset`][..reset], before the simulation loop.
    No-op by default because most controllers have no resources to
    acquire. Override only when the controller must acquire something
    per run; the base implementation does nothing.
    """

stop()

Release per-run resources. Simulation state is preserved.

Called after the simulation loop in LIFO order. No-op by default.

Source code in openg2g/controller/base.py
def stop(self) -> None:
    """Release per-run resources. Simulation state is preserved.

    Called after the simulation loop in LIFO order. No-op by default;
    the base implementation does nothing.
    """

step(clock, datacenter, grid, events) abstractmethod

Compute control commands for this step. Return an empty list for no-op.

Source code in openg2g/controller/base.py
@abstractmethod
def step(
    self,
    clock: SimulationClock,
    datacenter: DCBackendT,
    grid: GridBackendT,
    events: EventEmitter,
) -> list[DatacenterCommand | GridCommand]:
    """Compute control commands for this step. Return an empty list for no-op.

    Args:
        clock: Simulation clock for the current step.
        datacenter: Datacenter backend whose state the controller reads.
        grid: Grid backend whose state the controller reads.
        events: Emitter for controller telemetry events.

    Returns:
        Commands to apply this step; an empty list means no action.
    """

openg2g.controller.batch_size_schedule

Batch size schedule controller: applies pre-defined batch size changes at specified times.

BatchSizeChange dataclass

A batch size change event, optionally with gradual ramp-up.

Attributes:

Name Type Description
batch_size int

Target batch size (max_num_seqs).

ramp_up_rate float

Requests/second ramp-up rate. 0 means immediate.

Source code in openg2g/controller/batch_size_schedule.py
@dataclass(frozen=True)
class BatchSizeChange:
    """An immutable batch size change event, with optional gradual ramp-up.

    Attributes:
        batch_size: Target batch size (max_num_seqs).
        ramp_up_rate: Requests/second ramp-up rate. 0 means immediate.
    """

    batch_size: int
    ramp_up_rate: float = 0.0

    def __post_init__(self) -> None:
        # Validate at construction so an invalid change can never be scheduled.
        if not self.batch_size > 0:
            raise ValueError(f"batch_size must be positive, got {self.batch_size}.")
        if not self.ramp_up_rate >= 0:
            raise ValueError(f"ramp_up_rate must be >= 0, got {self.ramp_up_rate}.")

    def at(self, t: float) -> BatchSizeSchedule:
        """Schedule this change at time *t* seconds.

        Args:
            t: Simulation time (seconds) at which the change fires.

        Returns:
            A single-entry [`BatchSizeSchedule`][...BatchSizeSchedule].
        """
        entry = (t, self)
        return BatchSizeSchedule((entry,))

at(t)

Schedule this change at time t seconds.

Returns:

Type Description
BatchSizeSchedule

A single-entry BatchSizeSchedule.

Source code in openg2g/controller/batch_size_schedule.py
def at(self, t: float) -> BatchSizeSchedule:
    """Schedule this change at time *t* seconds.

    Args:
        t: Simulation time (seconds) at which the change fires.

    Returns:
        A single-entry [`BatchSizeSchedule`][...BatchSizeSchedule].
    """
    return BatchSizeSchedule(((t, self),))

BatchSizeSchedule

Ordered sequence of batch size changes, built with | operator.

Example:

schedule = (
    BatchSizeChange(48).at(40)
    | BatchSizeChange(32).at(60)
    | BatchSizeChange(48, ramp_up_rate=4).at(280)
)

Raises:

Type Description
ValueError

If two entries share the same timestamp.

Source code in openg2g/controller/batch_size_schedule.py
class BatchSizeSchedule:
    """Time-ordered sequence of batch size changes, composed with `|`.

    Example:

    ```python
    schedule = (
        BatchSizeChange(48).at(40)
        | BatchSizeChange(32).at(60)
        | BatchSizeChange(48, ramp_up_rate=4).at(280)
    )
    ```

    Raises:
        ValueError: If two entries share the same timestamp.
    """

    __slots__ = ("_entries",)

    def __init__(self, entries: tuple[tuple[float, BatchSizeChange], ...]) -> None:
        # Store entries sorted by timestamp; reject duplicate timestamps
        # because the schedule would otherwise be ambiguous.
        self._entries = tuple(sorted(entries, key=lambda item: item[0]))
        seen: set[float] = set()
        dupes: set[float] = set()
        for ts, _ in self._entries:
            if ts in seen:
                dupes.add(ts)
            seen.add(ts)
        if dupes:
            raise ValueError(f"BatchSizeSchedule has duplicate timestamps: {sorted(dupes)}")

    def __or__(self, other: BatchSizeSchedule) -> BatchSizeSchedule:
        combined = self._entries + other._entries
        return BatchSizeSchedule(combined)

    def __iter__(self) -> Iterator[tuple[float, BatchSizeChange]]:
        yield from self._entries

    def __len__(self) -> int:
        return len(self._entries)

    def __bool__(self) -> bool:
        return len(self._entries) > 0

    def __repr__(self) -> str:
        def _render(entry: tuple[float, BatchSizeChange]) -> str:
            ts, change = entry
            if change.ramp_up_rate > 0:
                ramp = f", ramp_up_rate={change.ramp_up_rate}"
            else:
                ramp = ""
            return f"BatchSizeChange({change.batch_size}{ramp}).at(t={ts})"

        return " | ".join(_render(e) for e in self._entries)

BatchSizeScheduleController

Bases: Controller[DatacenterBackend, GridBackend]

Applies pre-defined batch size changes at scheduled times.

Walks each model's schedule and emits SetBatchSize commands when the simulation clock reaches the scheduled time.

Parameters:

Name Type Description Default
schedules dict[str, BatchSizeSchedule]

Per-model batch size schedules, keyed by model label.

required
dt_s Fraction

How often the controller checks the schedule (seconds).

Fraction(1)
Source code in openg2g/controller/batch_size_schedule.py
class BatchSizeScheduleController(Controller[DatacenterBackend, GridBackend]):
    """Replays pre-defined batch size changes at their scheduled times.

    Walks each model's schedule and emits
    [`SetBatchSize`][openg2g.datacenter.command.SetBatchSize] commands when the
    simulation clock reaches the scheduled time.

    Args:
        schedules: Per-model batch size schedules, keyed by model label.
        dt_s: How often the controller checks the schedule (seconds).
    """

    def __init__(
        self,
        *,
        schedules: dict[str, BatchSizeSchedule],
        dt_s: Fraction = Fraction(1),
    ) -> None:
        self._dt_s = dt_s
        # Defensive copy: callers may mutate their dict after construction.
        self._schedules = dict(schedules)
        # Per-model cursor into the (sorted) schedule entries.
        self._indices: dict[str, int] = dict.fromkeys(schedules, 0)

    def reset(self) -> None:
        """Rewind every schedule cursor to the beginning."""
        self._indices = dict.fromkeys(self._schedules, 0)

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        return self._dt_s

    def step(
        self,
        clock: SimulationClock,
        datacenter: DatacenterBackend,
        grid: GridBackend,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Emit one SetBatchSize command covering all schedule entries now due."""
        now = clock.time_s
        fired: dict[str, int] = {}
        ramps: dict[str, float] = {}

        for label, schedule in self._schedules.items():
            entries = list(schedule)
            cursor = self._indices[label]
            # Consume every entry whose timestamp has passed; the small
            # epsilon absorbs float rounding at exact boundaries. If several
            # entries are due at once, the latest one wins.
            while cursor < len(entries) and float(entries[cursor][0]) <= now + 1e-12:
                _, change = entries[cursor]
                fired[label] = change.batch_size
                if change.ramp_up_rate > 0:
                    ramps[label] = change.ramp_up_rate
                cursor += 1
            self._indices[label] = cursor

        if not fired:
            return []
        events.emit(
            "controller.batch_schedule.fired",
            {"batch_size_by_model": fired},
        )
        return [
            SetBatchSize(
                batch_size_by_model=fired,
                ramp_up_rate_by_model=ramps,
            )
        ]

openg2g.controller.load_shift

Cross-site LLM load shifting controller.

Shifts replicas between datacenters when batch-size control (OFO) is exhausted and voltage violations persist. Runs after all per-site OFO controllers in the coordinator loop.

LoadShiftConfig

Bases: BaseModel

Configuration for cross-site load shifting.

Source code in openg2g/controller/load_shift.py
class LoadShiftConfig(BaseModel):
    """Configuration for cross-site load shifting."""

    # Master switch; when False the controller emits no commands.
    enabled: bool = False
    # GPUs' worth of replicas to move per corrective shift.
    gpus_per_shift: int = 8
    headroom: float = 0.3
    """Fraction of extra server capacity to pre-allocate at each DC
    so incoming replicas have room (e.g. 0.3 = 30% headroom)."""

headroom = 0.3 class-attribute instance-attribute

Fraction of extra server capacity to pre-allocate at each DC so incoming replicas have room (e.g. 0.3 = 30% headroom).

LoadShiftController

Bases: Controller[LLMBatchSizeControlledDatacenter, OpenDSSGrid]

Shift LLM replicas between datacenters to resolve voltage violations.

Rules:

1. Only shift models already running at both source and destination.
2. Only act when batch sizes are saturated AND the violation persists.
3. For undervoltage: shift load OUT of the violated site → the highest-voltage site. For overvoltage: shift load INTO the violated site ← the lowest-voltage site.
4. Shift gpus_per_shift GPUs worth of replicas per time step.
5. Repeat until the violation resolves or no candidates remain.

Source code in openg2g/controller/load_shift.py
class LoadShiftController(Controller[LLMBatchSizeControlledDatacenter, OpenDSSGrid]):
    """Shift LLM replicas between datacenters to resolve voltage violations.

    Rules:
    1. Only shift models already running at both source and destination.
    2. Only act when batch sizes are saturated AND violation persists.
    3. For undervoltage: shift load OUT of violated site → highest-voltage site.
       For overvoltage: shift load INTO violated site ← lowest-voltage site.
    4. Shift ``gpus_per_shift`` GPUs worth of replicas per time step.
    5. Repeat until violation resolves or no candidates remain.
    """

    def __init__(
        self,
        *,
        config: LoadShiftConfig,
        dt_s: Fraction,
        datacenters: dict[str, LLMBatchSizeControlledDatacenter],
        site_bus_map: dict[str, str],
        models_by_site: dict[str, list[str]],
        gpus_per_replica_by_model: dict[str, int],
        feasible_batch_sizes_by_model: dict[str, list[int]],
        v_min: float = 0.95,
        v_max: float = 1.05,
    ) -> None:
        self._config = config
        self._dt_s = dt_s
        self._datacenters = datacenters  # site_id -> datacenter
        self._site_bus_map = site_bus_map  # site_id -> bus name
        self._models_by_site = models_by_site  # site_id -> [model_labels]
        self._gpus_per_replica = gpus_per_replica_by_model
        self._feasible_bs = feasible_batch_sizes_by_model
        self._v_min = v_min  # lower voltage bound (pu)
        self._v_max = v_max  # upper voltage bound (pu)
        self._step_count = 0  # simulation state; cleared by reset()

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        return self._dt_s

    @property
    def site_id(self) -> str | None:
        # None signals this controller operates across sites, not on one.
        return None  # cross-site controller

    def step(
        self,
        clock: SimulationClock,
        datacenter: LLMBatchSizeControlledDatacenter,
        grid: OpenDSSGrid,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Emit paired ShiftReplicas commands for sites with voltage violations."""
        self._step_count += 1
        if not self._config.enabled:
            return []

        # Build bus -> min voltage mapping from grid
        voltages = grid.voltages_vector()
        v_index = grid.v_index  # list of (bus, phase)
        bus_voltages: dict[str, list[float]] = {}
        # NOTE(review): strict=False tolerates a length mismatch between
        # v_index and voltages — confirm that is intended and not masking one.
        for (bus, _phase), v in zip(v_index, voltages, strict=False):
            bus_voltages.setdefault(bus.lower(), []).append(float(v))

        # Per-site min/max voltage
        site_vmin: dict[str, float] = {}
        site_vmax: dict[str, float] = {}
        for site_id, bus in self._site_bus_map.items():
            vs = bus_voltages.get(bus.lower(), [])
            if vs:
                site_vmin[site_id] = min(vs)
                site_vmax[site_id] = max(vs)

        commands: list[DatacenterCommand | GridCommand] = []

        # Check each site for violations
        for site_id in list(self._site_bus_map.keys()):
            # Sites with no voltage reading default to 1.0 pu (no violation).
            vmin = site_vmin.get(site_id, 1.0)
            vmax = site_vmax.get(site_id, 1.0)

            is_undervoltage = vmin < self._v_min
            is_overvoltage = vmax > self._v_max

            if not is_undervoltage and not is_overvoltage:
                continue

            # Check if batch sizes are saturated at this site
            if not self._is_batch_saturated(site_id, datacenter, grid, is_undervoltage):
                continue

            # Find best destination and model to shift
            site_models = set(self._models_by_site.get(site_id, []))

            if is_undervoltage:
                # Shift load OUT: pick destination with highest voltage AND available capacity
                best_dest = None
                best_v = -1.0
                for other_id in self._site_bus_map:
                    if other_id == site_id:
                        continue
                    other_models = set(self._models_by_site.get(other_id, []))
                    shared = site_models & other_models
                    if not shared:
                        continue
                    # NOTE(review): capacity is checked against gpus_per_shift,
                    # but the actual move below is max(1, gpus_per_shift //
                    # gpus_per_replica) replicas, which can need more GPUs than
                    # gpus_per_shift when a replica is larger — confirm intended.
                    dest_dc = self._datacenters.get(other_id)
                    if dest_dc is not None and dest_dc.available_gpu_capacity() < self._config.gpus_per_shift:
                        continue  # no room at this destination
                    ov = site_vmin.get(other_id, 0.0)
                    if ov > best_v:
                        best_v = ov
                        best_dest = other_id

                if best_dest is None:
                    continue

                shared_models = site_models & set(self._models_by_site.get(best_dest, []))
                model = self._pick_model(shared_models)
                if model is None:
                    continue

                # At least one replica per shift, even for very large models.
                replicas = max(1, self._config.gpus_per_shift // self._gpus_per_replica[model])
                commands.append(
                    ShiftReplicas(
                        model_label=model,
                        replica_delta=-replicas,
                        target_site_id=site_id,
                    )
                )
                commands.append(
                    ShiftReplicas(
                        model_label=model,
                        replica_delta=+replicas,
                        target_site_id=best_dest,
                    )
                )
                logger.info(
                    "LoadShift: undervoltage at %s (Vmin=%.4f), shift %s ×%d replicas -> %s (Vmin=%.4f, free=%d GPUs)",
                    site_id,
                    vmin,
                    model,
                    replicas,
                    best_dest,
                    best_v,
                    self._datacenters[best_dest].available_gpu_capacity() if best_dest in self._datacenters else -1,
                )

            elif is_overvoltage:
                # Shift load IN: check this site has capacity, pick source with lowest voltage
                site_dc = self._datacenters.get(site_id)
                if site_dc is not None and site_dc.available_gpu_capacity() < self._config.gpus_per_shift:
                    continue  # violated site is full, can't accept more load

                best_src = None
                best_v = 2.0
                for other_id in self._site_bus_map:
                    if other_id == site_id:
                        continue
                    other_models = set(self._models_by_site.get(other_id, []))
                    shared = site_models & other_models
                    if not shared:
                        continue
                    # Default 2.0 equals the initial best_v, so sites without a
                    # voltage reading can never be selected as the source.
                    ov = site_vmax.get(other_id, 2.0)
                    if ov < best_v:
                        best_v = ov
                        best_src = other_id

                if best_src is None:
                    continue

                shared_models = site_models & set(self._models_by_site.get(best_src, []))
                model = self._pick_model(shared_models)
                if model is None:
                    continue

                replicas = max(1, self._config.gpus_per_shift // self._gpus_per_replica[model])
                commands.append(
                    ShiftReplicas(
                        model_label=model,
                        replica_delta=-replicas,
                        target_site_id=best_src,
                    )
                )
                commands.append(
                    ShiftReplicas(
                        model_label=model,
                        replica_delta=+replicas,
                        target_site_id=site_id,
                    )
                )
                logger.info(
                    "LoadShift: overvoltage at %s (Vmax=%.4f), shift %s ×%d replicas <- %s (Vmax=%.4f)",
                    site_id,
                    vmax,
                    model,
                    replicas,
                    best_src,
                    best_v,
                )

        return commands

    def _is_batch_saturated(
        self,
        site_id: str,
        datacenter: LLMBatchSizeControlledDatacenter,
        grid: OpenDSSGrid,
        is_undervoltage: bool,
    ) -> bool:
        """Check if all models at site have batch sizes at their limit.

        For undervoltage: saturated = all at minimum batch (can't reduce power further
        via OFO — batch is already at min so load can't be lowered).
        For overvoltage: saturated = all at maximum batch (can't increase power further).

        NOTE(review): the ``datacenter`` and ``grid`` parameters are unused here;
        the site's datacenter is looked up via self._datacenters instead.
        """
        dc = self._datacenters.get(site_id)
        if dc is None:
            return False
        state = dc.state
        if state is None:
            return False

        site_models = self._models_by_site.get(site_id, [])
        for model_label in site_models:
            current_bs = state.batch_size_by_model.get(model_label)
            feasible = self._feasible_bs.get(model_label, [])
            # Models with no feasible-batch data or no current reading are
            # skipped, i.e. they never block the "saturated" verdict.
            if not feasible or current_bs is None:
                continue
            if is_undervoltage:
                # Saturated = at minimum batch (OFO can't reduce power further)
                if current_bs > min(feasible):
                    return False  # Still room to reduce
            else:
                # Saturated = at maximum batch (OFO can't increase power further)
                if current_bs < max(feasible):
                    return False  # Still room to increase
        return True

    def _pick_model(self, shared_models: set[str]) -> str | None:
        """Pick the model with the most GPUs per replica (largest power impact)."""
        if not shared_models:
            return None
        return max(shared_models, key=lambda m: self._gpus_per_replica.get(m, 1))

    def reset(self) -> None:
        """Clear simulation state: only the step counter."""
        self._step_count = 0

    def start(self) -> None:
        # No per-run resources to acquire.
        pass

    def stop(self) -> None:
        # No per-run resources to release.
        pass

openg2g.controller.noop

No-op controller that does nothing.

NoopController

Bases: Controller[DatacenterBackend, GridBackend]

Controller that always returns an empty action.

Source code in openg2g/controller/noop.py
class NoopController(Controller[DatacenterBackend, GridBackend]):
    """Controller that always returns an empty action."""

    def __init__(self, dt_s: Fraction = Fraction(1)) -> None:
        # Only configuration, no simulation state.
        self._dt_s = dt_s

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        return self._dt_s

    def reset(self) -> None:
        """Stateless, so there is nothing to clear."""

    def step(
        self,
        clock: SimulationClock,
        datacenter: DatacenterBackend,
        grid: GridBackend,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Never issues commands."""
        return []

openg2g.controller.ofo

Online Feedback Optimization (OFO) batch-size controller.

Implements the primal-dual algorithm for joint voltage regulation and latency management via GPU batch size control.

OFOConfig

Bases: BaseModel

Online Feedback Optimization tuning parameters.

Attributes:

Name Type Description
primal_step_size float

Primal descent step size ρ_x (Eq. 8).

w_throughput float

Throughput weight in primal gradient.

w_switch float

Switching cost regularizer weight γ (Eq. 4a).

voltage_gradient_scale float

Scaling factor k_v for voltage dual term in the primal gradient.

v_min float

Lower voltage bound (pu).

v_max float

Upper voltage bound (pu).

voltage_dual_step_size float

Voltage dual ascent step size ρ_v (Eqs. 5-6).

latency_dual_step_size float

Latency dual ascent step size ρ_l (Eq. 7).

sensitivity_update_interval int

Steps between H-matrix re-estimation (0 = only once at init).

sensitivity_perturbation_kw float

Perturbation magnitude (kW) for finite-difference sensitivity estimation.

Source code in openg2g/controller/ofo.py
class OFOConfig(BaseModel):
    """Online Feedback Optimization tuning parameters.

    Attributes:
        primal_step_size: Primal descent step size ρ_x (Eq. 8).
        w_throughput: Throughput weight in primal gradient.
        w_switch: Switching cost regularizer weight γ (Eq. 4a).
        voltage_gradient_scale: Scaling factor k_v for voltage dual term
            in the primal gradient.
        v_min: Lower voltage bound (pu).
        v_max: Upper voltage bound (pu).
        voltage_dual_step_size: Voltage dual ascent step size ρ_v (Eqs. 5-6).
        latency_dual_step_size: Latency dual ascent step size ρ_l (Eq. 7).
        sensitivity_update_interval: Steps between H-matrix re-estimation
            (0 = only once at init).
        sensitivity_perturbation_kw: Perturbation magnitude (kW) for
            finite-difference sensitivity estimation.
    """

    # Frozen: an OFO config is immutable once constructed.
    model_config = ConfigDict(frozen=True)

    # Primal
    primal_step_size: float = 0.05  # ρ_x (Eq. 8)
    w_throughput: float = 0.1  # throughput weight in primal gradient
    w_switch: float = 0.0  # switching-cost weight γ (Eq. 4a)
    voltage_gradient_scale: float = 1e6  # k_v scaling of the voltage dual term

    # Dual
    v_min: float = 0.95  # lower voltage bound (pu)
    v_max: float = 1.05  # upper voltage bound (pu)
    voltage_dual_step_size: float = 0.5  # ρ_v (Eqs. 5-6)
    latency_dual_step_size: float = 1.0  # ρ_l (Eq. 7)

    # Sensitivity
    sensitivity_update_interval: int = 0  # steps between H re-estimation; 0 = only at init
    sensitivity_perturbation_kw: float = 100.0  # finite-difference perturbation (kW)

LogisticModelStore

Per-model logistic models for power, latency, and throughput.

Used by OFOBatchSizeController to compute gradients of the Lagrangian with respect to batch size.

Attributes:

Name Type Description
COL_MODEL_LABEL

Column name for model label in the CSV.

COL_METRIC

Column name for metric type in the CSV.

Source code in openg2g/controller/ofo.py
class LogisticModelStore:
    """Per-model logistic models for power, latency, and throughput.

    Used by
    [`OFOBatchSizeController`][openg2g.controller.ofo.OFOBatchSizeController]
    to compute gradients of the Lagrangian with respect to batch size.

    Attributes:
        COL_MODEL_LABEL: Column name for model label in the CSV.
        COL_METRIC: Column name for metric type in the CSV.
    """

    # CSV schema shared by save() and load().
    COL_MODEL_LABEL = "model_label"
    COL_METRIC = "metric"

    def __init__(
        self,
        power: dict[str, LogisticModel],
        latency: dict[str, LogisticModel],
        throughput: dict[str, LogisticModel],
    ) -> None:
        # Defensive copies so callers cannot mutate the store's fits.
        self._power = dict(power)
        self._latency = dict(latency)
        self._throughput = dict(throughput)
        # Raw benchmark samples per model/batch; populated only by generate()
        # and consumed by save(plot=True) to overlay data points on the fits.
        self._by_batch: dict[str, dict[int, list[tuple[float, float, float]]]] | None = None

    def power(self, model: str) -> LogisticModel:
        """Return the power logistic model for a model label."""
        return self._power[model]

    def latency(self, model: str) -> LogisticModel:
        """Return the latency logistic model for a model label."""
        return self._latency[model]

    def throughput(self, model: str) -> LogisticModel:
        """Return the throughput logistic model for a model label."""
        return self._throughput[model]

    @property
    def power_fits(self) -> dict[str, LogisticModel]:
        """Copy of all power fits, keyed by model label."""
        return dict(self._power)

    @property
    def latency_fits(self) -> dict[str, LogisticModel]:
        """Copy of all latency fits, keyed by model label."""
        return dict(self._latency)

    @property
    def throughput_fits(self) -> dict[str, LogisticModel]:
        """Copy of all throughput fits, keyed by model label."""
        return dict(self._throughput)

    @classmethod
    def generate(
        cls,
        models: tuple[InferenceModelSpec, ...],
        data_sources: dict[str, Any],
        *,
        runs: Any = None,
        mlenergy_data_dir: Path | None = None,
    ) -> LogisticModelStore:
        """Generate logistic fits from ML.ENERGY benchmark data.

        Args:
            models: Model specifications.
            data_sources: Per-model `MLEnergySource` instances, keyed by
                `model_label`.
            runs: Pre-loaded `LLMRuns` object. If `None`, loads from
                `mlenergy_data_dir` or the HuggingFace Hub.
            mlenergy_data_dir: Path to compiled mlenergy-data directory.
                Ignored if `runs` is provided.

        Returns:
            A new `LogisticModelStore` with fitted logistic models.

        Raises:
            ValueError: If no runs match, a model lacks a data source or
                `model_id`, or a model's filters match zero runs.
        """
        if runs is None:
            # Restrict the load to tasks actually referenced by data sources.
            unique_tasks = {src.task for src in data_sources.values()}
            if mlenergy_data_dir:
                runs = LLMRuns.from_directory(str(mlenergy_data_dir), stable_only=False).task(*unique_tasks)
            else:
                runs = LLMRuns.from_hf(stable_only=False).task(*unique_tasks)
        if not runs:
            raise ValueError("No runs found for the specified tasks")

        # One filtered subset of benchmark runs per served model label.
        subsets_by_label: dict[str, Any] = {}
        for ms in models:
            src = data_sources.get(ms.model_label)
            if src is None:
                raise ValueError(f"No data source for model {ms.model_label!r}")
            model_id = ms.model_id
            if not model_id:
                raise ValueError(f"model_id is required for data generation (model={ms.model_label!r})")

            subset = (
                runs.model_id(model_id).gpu_model(src.gpu).num_gpus(ms.gpus_per_replica).max_num_seqs(*src.batch_sizes)
            )
            if not subset:
                raise ValueError(
                    f"Config matched zero runs for logistic fits: model_id={model_id!r}, "
                    f"gpu={src.gpu!r}, num_gpus={ms.gpus_per_replica}, "
                    f"batch_sizes={src.batch_sizes}"
                )
            subsets_by_label[ms.model_label] = subset

        all_by_batch: dict[str, dict[int, list[tuple[float, float, float]]]] = {}
        power: dict[str, LogisticModel] = {}
        latency: dict[str, LogisticModel] = {}
        throughput: dict[str, LogisticModel] = {}
        for model_label, group in subsets_by_label.items():
            exclude = set(data_sources[model_label].fit_exclude_batch_sizes)
            # Group (power W, latency s, throughput tok/s) samples by batch size.
            by_batch: dict[int, list[tuple[float, float, float]]] = {}
            for r in group:
                if r.max_num_seqs in exclude:
                    continue
                by_batch.setdefault(r.max_num_seqs, []).append(
                    (r.avg_power_watts, r.mean_itl_ms / 1000.0, r.output_throughput_tokens_per_sec)
                )
            all_by_batch[model_label] = by_batch

            batches = sorted(by_batch.keys())
            if not batches:
                continue

            # Fits operate in x = log2(batch); clip guards against log2(0).
            x = np.log2(np.array(batches, dtype=float).clip(min=1))
            for _metric_name, idx, target in [
                ("power", 0, power),
                ("latency", 1, latency),
                ("throughput", 2, throughput),
            ]:
                # Median across repeated runs at the same batch size tames outliers.
                y = np.array([float(np.median([t[idx] for t in by_batch[b]])) for b in batches])
                fit = LogisticModel.fit(x, y)
                target[model_label] = fit

        if not power and not latency and not throughput:
            raise ValueError("No logistic fit rows produced")
        store = cls(power=power, latency=latency, throughput=throughput)
        # Keep the raw samples so save(plot=True) can plot data behind the fits.
        store._by_batch = all_by_batch
        return store

    def save(self, csv_path: Path, *, plot: bool = False) -> None:
        """Save logistic fits to a CSV.

        Args:
            csv_path: Output CSV path.
            plot: If `True`, also write a logistic fits plot to the
                same directory.
        """
        csv_path = Path(csv_path)
        csv_path.parent.mkdir(parents=True, exist_ok=True)
        rows: list[dict[str, Any]] = []
        for metric_name, fits in [("power", self._power), ("latency", self._latency), ("throughput", self._throughput)]:
            for label in sorted(fits):
                model = fits[label]
                rows.append(
                    {
                        self.COL_MODEL_LABEL: label,
                        self.COL_METRIC: metric_name,
                        "L": model.L,
                        "x0": model.x0,
                        "k": model.k,
                        "b0": model.b0,
                    }
                )
        pd.DataFrame(rows).to_csv(csv_path, index=False)

        # getattr guards instances missing the attribute (e.g. unpickled);
        # plotting needs the raw samples only generate() records.
        by_batch = getattr(self, "_by_batch", None)
        if plot and by_batch is not None:
            model_labels = sorted(self._power.keys())
            _plot_logistic_fits(
                by_batch,
                self._power,
                self._latency,
                self._throughput,
                model_labels,
                csv_path.parent,
            )

    @classmethod
    def load(cls, csv_path: Path | str) -> LogisticModelStore:
        """Load power, latency, and throughput fits from a merged CSV.

        Expected columns: `model_label`, `metric`, plus the logistic
        model parameter columns (`L`, `x0`, `k`, `b0`).

        The `metric` column must contain `power`, `latency`, or
        `throughput` (case-insensitive).

        Args:
            csv_path: Path to the logistic fits CSV.

        Raises:
            ValueError: If required columns are missing or no fit rows load.
        """
        csv_path = Path(csv_path)
        df = pd.read_csv(csv_path)

        required_cols = [cls.COL_MODEL_LABEL, cls.COL_METRIC]
        missing = [c for c in required_cols if c not in df.columns]
        if missing:
            raise ValueError(f"{csv_path} missing columns: {missing}. Got: {list(df.columns)}")

        power: dict[str, LogisticModel] = {}
        latency: dict[str, LogisticModel] = {}
        throughput: dict[str, LogisticModel] = {}
        targets = {"power": power, "latency": latency, "throughput": throughput}
        for row in df.to_dict(orient="records"):
            metric = str(row[cls.COL_METRIC]).strip().lower()
            # Rows with unknown metric values are skipped, not treated as errors.
            if metric in targets:
                targets[metric][str(row[cls.COL_MODEL_LABEL])] = LogisticModel.from_dict(row)

        if not power and not latency and not throughput:
            raise ValueError(f"No logistic model rows loaded from {csv_path}")
        return cls(power=power, latency=latency, throughput=throughput)

    @classmethod
    def ensure(
        cls,
        csv_path: Path,
        models: tuple[InferenceModelSpec, ...] | None = None,
        data_sources: dict[str, Any] | None = None,
        *,
        mlenergy_data_dir: Path | None = None,
        plot: bool = False,
    ) -> LogisticModelStore:
        """Load from `csv_path`, generating first if needed.

        Args:
            csv_path: Path to the logistic fits CSV.
            models: Model specifications. Required when no cached file exists.
            data_sources: Per-model `MLEnergySource` instances, keyed by
                `model_label`. Required when no cached file exists.
            mlenergy_data_dir: Path to compiled mlenergy-data directory.
            plot: If `True`, generate a logistic fits plot on generation.

        Raises:
            ValueError: If generation is needed but `models` or
                `data_sources` is not provided.
        """
        csv_path = Path(csv_path)
        if not csv_path.exists():
            if models is None or data_sources is None:
                raise ValueError("models and data_sources required for LogisticModelStore generation (no cached data)")
            logger.info("Generating logistic fits to %s ...", csv_path)
            # Round-trip through save()/load() so generated and cached paths
            # exercise identical CSV parsing.
            cls.generate(models, data_sources, mlenergy_data_dir=mlenergy_data_dir).save(csv_path, plot=plot)
        return cls.load(csv_path)

power(model)

Return the power logistic model for a model label.

Source code in openg2g/controller/ofo.py
def power(self, model: str) -> LogisticModel:
    """Look up the fitted power-vs-batch-size curve for `model`."""
    fits = self._power
    return fits[model]

latency(model)

Return the latency logistic model for a model label.

Source code in openg2g/controller/ofo.py
def latency(self, model: str) -> LogisticModel:
    """Look up the fitted latency-vs-batch-size curve for `model`."""
    fits = self._latency
    return fits[model]

throughput(model)

Return the throughput logistic model for a model label.

Source code in openg2g/controller/ofo.py
def throughput(self, model: str) -> LogisticModel:
    """Look up the fitted throughput-vs-batch-size curve for `model`."""
    fits = self._throughput
    return fits[model]

generate(models, data_sources, *, runs=None, mlenergy_data_dir=None) classmethod

Generate logistic fits from ML.ENERGY benchmark data.

Parameters:

Name Type Description Default
models tuple[InferenceModelSpec, ...]

Model specifications.

required
data_sources dict[str, Any]

Per-model MLEnergySource instances, keyed by model_label.

required
runs Any

Pre-loaded LLMRuns object. If None, loads from mlenergy_data_dir or the HuggingFace Hub.

None
mlenergy_data_dir Path | None

Path to compiled mlenergy-data directory. Ignored if runs is provided.

None

Returns:

Type Description
LogisticModelStore

A new LogisticModelStore with fitted logistic models.

Source code in openg2g/controller/ofo.py
@classmethod
def generate(
    cls,
    models: tuple[InferenceModelSpec, ...],
    data_sources: dict[str, Any],
    *,
    runs: Any = None,
    mlenergy_data_dir: Path | None = None,
) -> LogisticModelStore:
    """Generate logistic fits from ML.ENERGY benchmark data.

    Args:
        models: Model specifications.
        data_sources: Per-model `MLEnergySource` instances, keyed by
            `model_label`.
        runs: Pre-loaded `LLMRuns` object. If `None`, loads from
            `mlenergy_data_dir` or the HuggingFace Hub.
        mlenergy_data_dir: Path to compiled mlenergy-data directory.
            Ignored if `runs` is provided.

    Returns:
        A new `LogisticModelStore` with fitted logistic models.

    Raises:
        ValueError: If no runs match, a model lacks a data source or
            `model_id`, or a model's filters match zero runs.
    """
    if runs is None:
        # Restrict the load to tasks actually referenced by data sources.
        unique_tasks = {src.task for src in data_sources.values()}
        if mlenergy_data_dir:
            runs = LLMRuns.from_directory(str(mlenergy_data_dir), stable_only=False).task(*unique_tasks)
        else:
            runs = LLMRuns.from_hf(stable_only=False).task(*unique_tasks)
    if not runs:
        raise ValueError("No runs found for the specified tasks")

    # One filtered subset of benchmark runs per served model label.
    subsets_by_label: dict[str, Any] = {}
    for ms in models:
        src = data_sources.get(ms.model_label)
        if src is None:
            raise ValueError(f"No data source for model {ms.model_label!r}")
        model_id = ms.model_id
        if not model_id:
            raise ValueError(f"model_id is required for data generation (model={ms.model_label!r})")

        subset = (
            runs.model_id(model_id).gpu_model(src.gpu).num_gpus(ms.gpus_per_replica).max_num_seqs(*src.batch_sizes)
        )
        if not subset:
            raise ValueError(
                f"Config matched zero runs for logistic fits: model_id={model_id!r}, "
                f"gpu={src.gpu!r}, num_gpus={ms.gpus_per_replica}, "
                f"batch_sizes={src.batch_sizes}"
            )
        subsets_by_label[ms.model_label] = subset

    all_by_batch: dict[str, dict[int, list[tuple[float, float, float]]]] = {}
    power: dict[str, LogisticModel] = {}
    latency: dict[str, LogisticModel] = {}
    throughput: dict[str, LogisticModel] = {}
    for model_label, group in subsets_by_label.items():
        exclude = set(data_sources[model_label].fit_exclude_batch_sizes)
        # Group (power W, latency s, throughput tok/s) samples by batch size.
        by_batch: dict[int, list[tuple[float, float, float]]] = {}
        for r in group:
            if r.max_num_seqs in exclude:
                continue
            by_batch.setdefault(r.max_num_seqs, []).append(
                (r.avg_power_watts, r.mean_itl_ms / 1000.0, r.output_throughput_tokens_per_sec)
            )
        all_by_batch[model_label] = by_batch

        batches = sorted(by_batch.keys())
        if not batches:
            continue

        # Fits operate in x = log2(batch); clip guards against log2(0).
        x = np.log2(np.array(batches, dtype=float).clip(min=1))
        for _metric_name, idx, target in [
            ("power", 0, power),
            ("latency", 1, latency),
            ("throughput", 2, throughput),
        ]:
            # Median across repeated runs at the same batch size tames outliers.
            y = np.array([float(np.median([t[idx] for t in by_batch[b]])) for b in batches])
            fit = LogisticModel.fit(x, y)
            target[model_label] = fit

    if not power and not latency and not throughput:
        raise ValueError("No logistic fit rows produced")
    store = cls(power=power, latency=latency, throughput=throughput)
    # Keep the raw samples so save(plot=True) can plot data behind the fits.
    store._by_batch = all_by_batch
    return store

save(csv_path, *, plot=False)

Save logistic fits to a CSV.

Parameters:

Name Type Description Default
csv_path Path

Output CSV path.

required
plot bool

If True, also write a logistic fits plot to the same directory.

False
Source code in openg2g/controller/ofo.py
def save(self, csv_path: Path, *, plot: bool = False) -> None:
    """Serialize the fitted logistic curves to a CSV file.

    Args:
        csv_path: Destination file; parent directories are created.
        plot: When `True`, additionally render a fits plot in the
            same directory as the CSV.
    """
    csv_path = Path(csv_path)
    csv_path.parent.mkdir(parents=True, exist_ok=True)

    records: list[dict[str, Any]] = []
    metric_groups = (
        ("power", self._power),
        ("latency", self._latency),
        ("throughput", self._throughput),
    )
    for metric_name, fits in metric_groups:
        for label in sorted(fits):
            fit = fits[label]
            record = {
                self.COL_MODEL_LABEL: label,
                self.COL_METRIC: metric_name,
                "L": fit.L,
                "x0": fit.x0,
                "k": fit.k,
                "b0": fit.b0,
            }
            records.append(record)
    pd.DataFrame(records).to_csv(csv_path, index=False)

    # Plotting needs the raw samples that only generate() records.
    raw_points = getattr(self, "_by_batch", None)
    if plot and raw_points is not None:
        model_labels = sorted(self._power.keys())
        _plot_logistic_fits(
            raw_points,
            self._power,
            self._latency,
            self._throughput,
            model_labels,
            csv_path.parent,
        )

load(csv_path) classmethod

Load power, latency, and throughput fits from a merged CSV.

Expected columns: model_label, metric, plus the logistic model parameter columns (L, x0, k, b0).

The metric column must contain power, latency, or throughput (case-insensitive).

Parameters:

Name Type Description Default
csv_path Path | str

Path to the logistic fits CSV.

required
Source code in openg2g/controller/ofo.py
@classmethod
def load(cls, csv_path: Path | str) -> LogisticModelStore:
    """Load power, latency, and throughput fits from a merged CSV.

    The file must provide `model_label` and `metric` columns alongside
    the logistic parameter columns (`L`, `x0`, `k`, `b0`). The `metric`
    values are matched case-insensitively against `power`, `latency`,
    and `throughput`; other rows are ignored.

    Args:
        csv_path: Path to the logistic fits CSV.

    Raises:
        ValueError: If required columns are absent or no fit rows load.
    """
    csv_path = Path(csv_path)
    df = pd.read_csv(csv_path)

    needed = [cls.COL_MODEL_LABEL, cls.COL_METRIC]
    missing = [col for col in needed if col not in df.columns]
    if missing:
        raise ValueError(f"{csv_path} missing columns: {missing}. Got: {list(df.columns)}")

    fits_by_metric: dict[str, dict[str, LogisticModel]] = {
        "power": {},
        "latency": {},
        "throughput": {},
    }
    for record in df.to_dict(orient="records"):
        metric_key = str(record[cls.COL_METRIC]).strip().lower()
        bucket = fits_by_metric.get(metric_key)
        if bucket is not None:
            bucket[str(record[cls.COL_MODEL_LABEL])] = LogisticModel.from_dict(record)

    power = fits_by_metric["power"]
    latency = fits_by_metric["latency"]
    throughput = fits_by_metric["throughput"]
    if not (power or latency or throughput):
        raise ValueError(f"No logistic model rows loaded from {csv_path}")
    return cls(power=power, latency=latency, throughput=throughput)

ensure(csv_path, models=None, data_sources=None, *, mlenergy_data_dir=None, plot=False) classmethod

Load from csv_path, generating first if needed.

Parameters:

Name Type Description Default
csv_path Path

Path to the logistic fits CSV.

required
models tuple[InferenceModelSpec, ...] | None

Model specifications. Required when no cached file exists.

None
data_sources dict[str, Any] | None

Per-model MLEnergySource instances, keyed by model_label. Required when no cached file exists.

None
mlenergy_data_dir Path | None

Path to compiled mlenergy-data directory.

None
plot bool

If True, generate a logistic fits plot on generation.

False
Source code in openg2g/controller/ofo.py
@classmethod
def ensure(
    cls,
    csv_path: Path,
    models: tuple[InferenceModelSpec, ...] | None = None,
    data_sources: dict[str, Any] | None = None,
    *,
    mlenergy_data_dir: Path | None = None,
    plot: bool = False,
) -> LogisticModelStore:
    """Return fits from `csv_path`, generating the file first if absent.

    Args:
        csv_path: Path to the logistic fits CSV.
        models: Model specifications; required only when generating.
        data_sources: Per-model `MLEnergySource` instances, keyed by
            `model_label`; required only when generating.
        mlenergy_data_dir: Path to compiled mlenergy-data directory.
        plot: If `True`, also emit a fits plot when generating.

    Raises:
        ValueError: If generation is needed but `models` or
            `data_sources` is missing.
    """
    csv_path = Path(csv_path)
    needs_generation = not csv_path.exists()
    if needs_generation:
        if models is None or data_sources is None:
            raise ValueError("models and data_sources required for LogisticModelStore generation (no cached data)")
        logger.info("Generating logistic fits to %s ...", csv_path)
        store = cls.generate(models, data_sources, mlenergy_data_dir=mlenergy_data_dir)
        store.save(csv_path, plot=plot)
    # Always read back through load() so both paths share CSV parsing.
    return cls.load(csv_path)

VoltageDualVariables

Full-network duals for voltage box constraints.

Maintains per-bus dual variables for under- and overvoltage and updates them via projected gradient ascent:

dual_undervoltage  <- [dual_undervoltage  + ρ_v * (v_min - v̂)]+
dual_overvoltage   <- [dual_overvoltage   + ρ_v * (v̂ - v_max)]+

Parameters:

Name Type Description Default
n_bus_phases int

Number of bus-phase pairs in the voltage vector (3M).

required
config OFOConfig

OFO configuration (voltage bounds and dual step size).

required
Source code in openg2g/controller/ofo.py
class VoltageDualVariables:
    """Projected-ascent duals for the per-bus voltage box constraints.

    Tracks one nonnegative multiplier per bus-phase for each side of the
    voltage band and advances them by projected gradient ascent:

        dual_undervoltage  <- [dual_undervoltage  + ρ_v * (v_min - v̂)]+
        dual_overvoltage   <- [dual_overvoltage   + ρ_v * (v̂ - v_max)]+

    Args:
        n_bus_phases: Number of bus-phase pairs in the voltage vector (3M).
        config: OFO configuration (voltage bounds and dual step size).
    """

    def __init__(self, n_bus_phases: int, config: OFOConfig) -> None:
        self.config = config
        size = int(n_bus_phases)
        # λ (undervoltage, Eq. 5) and λ̄ (overvoltage, Eq. 6) in the G2G paper.
        self.dual_undervoltage = np.zeros(size, dtype=float)
        self.dual_overvoltage = np.zeros(size, dtype=float)

    def update(self, observed_voltages: np.ndarray) -> None:
        """Advance both duals one projected-ascent step.

        Args:
            observed_voltages: Observed voltage magnitudes (pu), shape
                `(n_bus_phases,)`.

        Raises:
            ValueError: If `observed_voltages` length does not match the
                dual dimension.
        """
        v = np.asarray(observed_voltages, float).reshape(-1)
        n = self.dual_undervoltage.shape[0]
        if v.shape[0] != n:
            raise ValueError(
                f"observed_voltages has len {v.shape[0]} "
                f"but duals have len {n}"
            )
        cfg = self.config
        step = float(cfg.voltage_dual_step_size)
        undershoot = float(cfg.v_min) - v
        overshoot = v - float(cfg.v_max)
        # Ascend on the violation, then project onto the nonnegative orthant.
        self.dual_undervoltage = np.maximum(self.dual_undervoltage + step * undershoot, 0.0)
        self.dual_overvoltage = np.maximum(self.dual_overvoltage + step * overshoot, 0.0)

    def dual_difference(self) -> np.ndarray:
        """Return the voltage dual difference (η = λ̄ − λ, Appendix B)."""
        overvoltage = self.dual_overvoltage
        undervoltage = self.dual_undervoltage
        return overvoltage - undervoltage

update(observed_voltages)

Update duals given observed voltage vector.

Parameters:

Name Type Description Default
observed_voltages ndarray

Observed voltage magnitudes (pu), shape (n_bus_phases,).

required

Raises:

Type Description
ValueError

If observed_voltages length does not match the dual dimension.

Source code in openg2g/controller/ofo.py
def update(self, observed_voltages: np.ndarray) -> None:
    """Advance both voltage duals one projected-ascent step.

    Args:
        observed_voltages: Observed voltage magnitudes (pu), shape
            `(n_bus_phases,)`.

    Raises:
        ValueError: If `observed_voltages` length does not match the dual
            dimension.
    """
    v = np.asarray(observed_voltages, float).reshape(-1)
    n = self.dual_undervoltage.shape[0]
    if v.shape[0] != n:
        raise ValueError(
            f"observed_voltages has len {v.shape[0]} "
            f"but duals have len {n}"
        )
    cfg = self.config
    step = float(cfg.voltage_dual_step_size)
    undershoot = float(cfg.v_min) - v
    overshoot = v - float(cfg.v_max)
    # Ascend on the violation, then project onto the nonnegative orthant.
    self.dual_undervoltage = np.maximum(self.dual_undervoltage + step * undershoot, 0.0)
    self.dual_overvoltage = np.maximum(self.dual_overvoltage + step * overshoot, 0.0)

dual_difference()

Return the voltage dual difference (η = λ̄ − λ, Appendix B).

Source code in openg2g/controller/ofo.py
def dual_difference(self) -> np.ndarray:
    """Compute the net voltage dual η = λ̄ − λ (Appendix B)."""
    overvoltage = self.dual_overvoltage
    undervoltage = self.dual_undervoltage
    return overvoltage - undervoltage

PrimalBatchOptimizer

Primal batch-size optimizer operating in log2 space.

Maintains continuous state x_i = log2(batch_i) per model and applies a gradient descent step using voltage duals, latency duals, and fitted power/latency/throughput curves.

Parameters:

Name Type Description Default
models list[InferenceModelSpec]

Model specifications for each served model.

required
feasible_batch_sizes list[int]

Allowed batch sizes (union across all models).

required
power_fits dict[str, LogisticModel]

Per-model logistic fit for power vs log2(batch_size).

required
latency_fits dict[str, LogisticModel]

Per-model logistic fit for latency vs log2(batch_size).

required
throughput_fits dict[str, LogisticModel]

Per-model logistic fit for throughput vs log2(batch_size).

required
config OFOConfig

OFO configuration (step size, throughput/switch weights, voltage gradient scale).

required
Source code in openg2g/controller/ofo.py
class PrimalBatchOptimizer:
    """Primal batch-size optimizer operating in log2 space.

    Maintains continuous state `x_i = log2(batch_i)` per model and applies
    a gradient descent step using voltage duals, latency duals, and fitted
    power/latency/throughput curves.

    Args:
        models: Model specifications for each served model.
        feasible_batch_sizes: Allowed batch sizes (union across all models).
        power_fits: Per-model logistic fit for power vs log2(batch_size).
        latency_fits: Per-model logistic fit for latency vs log2(batch_size).
        throughput_fits: Per-model logistic fit for throughput vs
            log2(batch_size).
        config: OFO configuration (step size, throughput/switch weights,
            voltage gradient scale).

    Raises:
        ValueError: If `feasible_batch_sizes` is empty.
    """

    def __init__(
        self,
        *,
        models: list[InferenceModelSpec],
        feasible_batch_sizes: list[int],
        power_fits: dict[str, LogisticModel],
        latency_fits: dict[str, LogisticModel],
        throughput_fits: dict[str, LogisticModel],
        config: OFOConfig,
    ) -> None:
        self.models = list(models)
        # De-duplicate and sort so bisect-based discretization works.
        self.feasible_batch_sizes = sorted({int(b) for b in feasible_batch_sizes})
        if not self.feasible_batch_sizes:
            raise ValueError("feasible_batch_sizes cannot be empty.")

        self.power_fits = power_fits
        self.latency_fits = latency_fits
        self.throughput_fits = throughput_fits
        self.config = config

        # Continuous state lives in x = log2(batch); these are its box bounds.
        self.log_batch_size_min = math.log2(min(self.feasible_batch_sizes))
        self.log_batch_size_max = math.log2(max(self.feasible_batch_sizes))

        # Every model starts at the largest feasible batch size.
        self.log_batch_size_by_model: dict[str, float] = {
            ms.model_label: float(self.log_batch_size_max) for ms in self.models
        }
        self.prev_log_batch_size_by_model: dict[str, float] = dict(self.log_batch_size_by_model)

        # Per-model throughput normalization: r_i(x_max) for a single replica
        self.throughput_max_by_model: dict[str, float] = {}
        b_max = int(max(self.feasible_batch_sizes))
        for ms in self.models:
            label = ms.model_label
            try:
                # NOTE(review): eval() receives the raw batch size b_max here,
                # while deriv_wrt_x() in step() is given log2(batch); confirm
                # that LogisticModel.eval expects raw batch sizes — TODO verify.
                th_max = float(self.throughput_fits[label].eval(b_max))
            except Exception:
                th_max = float("nan")
            if (not np.isfinite(th_max)) or (th_max <= 0.0):
                # Fallback keeps the normalization divisor positive and finite.
                th_max = 1.0
            self.throughput_max_by_model[label] = th_max

    def _clamp_log_batch_size(self, log_batch_size: float) -> float:
        """Project a log2 batch value onto [log_batch_size_min, log_batch_size_max]."""
        return float(min(max(float(log_batch_size), self.log_batch_size_min), self.log_batch_size_max))

    def _discretize_batch(self, log_batch_size: float) -> int:
        """Map continuous log2 state to the nearest feasible discrete batch size."""
        b_cont = 2.0 ** float(log_batch_size)
        # bisect yields the two feasible sizes bracketing b_cont; take the closer.
        idx = bisect.bisect_left(self.feasible_batch_sizes, b_cont)
        candidates = []
        if idx > 0:
            candidates.append(self.feasible_batch_sizes[idx - 1])
        if idx < len(self.feasible_batch_sizes):
            candidates.append(self.feasible_batch_sizes[idx])
        return int(min(candidates, key=lambda bb: abs(bb - b_cont)))

    def init_from_batches(self, batch_init: dict[str, int]) -> None:
        """Initialize log-batch-size state from discrete batch sizes."""
        for ms in self.models:
            label = ms.model_label
            # Models missing from batch_init default to the largest feasible batch.
            b = int(batch_init.get(label, max(self.feasible_batch_sizes)))
            log_batch_size = math.log2(max(b, 1))
            log_batch_size = self._clamp_log_batch_size(log_batch_size)
            # prev == current so the first step carries no switching penalty.
            self.log_batch_size_by_model[label] = float(log_batch_size)
            self.prev_log_batch_size_by_model[label] = float(log_batch_size)

    def step(
        self,
        *,
        voltage_dual_diff: np.ndarray,
        sensitivity_matrix: np.ndarray,
        phase_share_by_model: dict[str, np.ndarray],
        latency_dual_by_model: dict[str, float] | None = None,
        replica_count_by_model: dict[str, float] | None = None,
    ) -> dict[str, int]:
        """Primal gradient descent step.

        Args:
            voltage_dual_diff: Voltage dual difference vector
                (η = λ̄ − λ), shape `(n_bus_phases,)`.
            sensitivity_matrix: Voltage sensitivity matrix (H = dv/dp),
                shape `(n_bus_phases, 3)`.
            phase_share_by_model: Per-model normalized phase share vectors,
                shape `(3,)` each.
            latency_dual_by_model: Per-model latency dual variables (μ_i).
            replica_count_by_model: Per-model active replica counts (w_i).

        Returns:
            Next batch sizes per model.
        """
        voltage_dual_diff = np.asarray(voltage_dual_diff, float).reshape(-1)
        sensitivity_matrix = np.asarray(sensitivity_matrix, float)
        latency_dual_by_model = {} if latency_dual_by_model is None else dict(latency_dual_by_model)
        replica_count_by_model = {} if replica_count_by_model is None else dict(replica_count_by_model)

        step_size = float(self.config.primal_step_size)  # ρ_x
        w_throughput = float(self.config.w_throughput)
        w_switch = float(self.config.w_switch)
        voltage_gradient_scale = float(self.config.voltage_gradient_scale)

        batch_next: dict[str, int] = {}

        for ms in self.models:
            label = ms.model_label
            log_batch_size = float(self.log_batch_size_by_model[label])
            prev_log_batch_size = float(self.prev_log_batch_size_by_model.get(label, log_batch_size))

            replica_count = float(replica_count_by_model.get(label, 0.0))  # w_i
            # Sanitize: missing/NaN/negative replica counts contribute nothing.
            if (not np.isfinite(replica_count)) or (replica_count < 0.0):
                replica_count = 0.0

            phase_share = np.asarray(  # e_i (phase-allocation weight, p.7)
                phase_share_by_model.get(label, np.array([1 / 3, 1 / 3, 1 / 3], dtype=float)),
                float,
            ).reshape(3)
            s = float(np.sum(phase_share))
            # Renormalize to sum to 1; degenerate input falls back to uniform.
            if (not np.isfinite(s)) or s <= 0.0:
                phase_share = np.array([1 / 3, 1 / 3, 1 / 3], dtype=float)
            else:
                phase_share = phase_share / s

            weighted_sensitivity = sensitivity_matrix @ phase_share  # H @ e_i
            voltage_gradient = float(voltage_dual_diff @ weighted_sensitivity)

            # Single-replica derivatives of the fitted curves at the current x.
            dPdx_1 = float(self.power_fits[label].deriv_wrt_x(log_batch_size))
            dLdx_1 = float(self.latency_fits[label].deriv_wrt_x(log_batch_size))
            dThdx_1 = float(self.throughput_fits[label].deriv_wrt_x(log_batch_size))

            # Power fits were trained on watts; the grid-side terms use kW.
            dPdx_1_kw = dPdx_1 / 1000.0

            th_max = float(self.throughput_max_by_model.get(label, 1.0))
            if (not np.isfinite(th_max)) or (th_max <= 0.0):
                th_max = 1.0
            dThdx_norm_1 = dThdx_1 / th_max

            # Power/throughput scale with replica count; latency is per replica.
            dPdx = replica_count * dPdx_1_kw
            dThdx = replica_count * dThdx_norm_1
            dLdx = dLdx_1

            latency_dual = float(latency_dual_by_model.get(label, 0.0))  # μ_i
            if (not np.isfinite(latency_dual)) or (latency_dual < 0.0):
                latency_dual = 0.0

            # Gradient of the Lagrangian w.r.t. x_i = log2(batch_i).
            # G2G paper Eq. 18: nabla_x L = -dR/dx (throughput)
            #                              + 2*gamma*(x - x_prev) (switching)
            #                              + eta^T H e_i dP/dx (voltage dual)
            #                              + mu_i * dL/dx (latency dual)
            # Implementation extensions: wT scaling on throughput,
            #                            k_v scaling on voltage term
            # NOTE(review): the code applies w_switch without Eq. 18's factor
            # of 2 — presumably w_switch absorbs it; confirm against the paper.
            grad = 0.0
            grad -= w_throughput * dThdx
            grad += voltage_gradient_scale * voltage_gradient * dPdx
            grad += latency_dual * dLdx
            grad += w_switch * (log_batch_size - prev_log_batch_size)

            new_log_batch_size = self._clamp_log_batch_size(log_batch_size - step_size * grad)
            self.prev_log_batch_size_by_model[label] = log_batch_size
            self.log_batch_size_by_model[label] = new_log_batch_size
            batch_next[label] = self._discretize_batch(new_log_batch_size)

        return batch_next

init_from_batches(batch_init)

Initialize log-batch-size state from discrete batch sizes.

Source code in openg2g/controller/ofo.py
def init_from_batches(self, batch_init: dict[str, int]) -> None:
    """Initialize log-batch-size state from discrete batch sizes.

    Models absent from `batch_init` default to the largest feasible
    batch size. Both the current and previous log2 batch state are
    seeded with the same (clamped) value.
    """
    fallback = max(self.feasible_batch_sizes)
    for spec in self.models:
        name = spec.model_label
        batch = int(batch_init.get(name, fallback))
        # Guard against non-positive batches before taking log2.
        seeded = float(self._clamp_log_batch_size(math.log2(max(batch, 1))))
        self.log_batch_size_by_model[name] = seeded
        self.prev_log_batch_size_by_model[name] = seeded

step(*, voltage_dual_diff, sensitivity_matrix, phase_share_by_model, latency_dual_by_model=None, replica_count_by_model=None)

Primal gradient descent step.

Parameters:

Name Type Description Default
voltage_dual_diff ndarray

Voltage dual difference vector (η = λ̄ − λ), shape (n_bus_phases,).

required
sensitivity_matrix ndarray

Voltage sensitivity matrix (H = dv/dp), shape (n_bus_phases, 3).

required
phase_share_by_model dict[str, ndarray]

Per-model normalized phase share vectors, shape (3,) each.

required
latency_dual_by_model dict[str, float] | None

Per-model latency dual variables (μ_i).

None
replica_count_by_model dict[str, float] | None

Per-model active replica counts (w_i).

None

Returns:

Type Description
dict[str, int]

Next batch sizes per model.

Source code in openg2g/controller/ofo.py
def step(
    self,
    *,
    voltage_dual_diff: np.ndarray,
    sensitivity_matrix: np.ndarray,
    phase_share_by_model: dict[str, np.ndarray],
    latency_dual_by_model: dict[str, float] | None = None,
    replica_count_by_model: dict[str, float] | None = None,
) -> dict[str, int]:
    """Primal gradient descent step.

    Args:
        voltage_dual_diff: Voltage dual difference vector
            (η = λ̄ − λ), shape `(n_bus_phases,)`.
        sensitivity_matrix: Voltage sensitivity matrix (H = dv/dp),
            shape `(n_bus_phases, 3)`.
        phase_share_by_model: Per-model normalized phase share vectors,
            shape `(3,)` each.
        latency_dual_by_model: Per-model latency dual variables (μ_i).
        replica_count_by_model: Per-model active replica counts (w_i).

    Returns:
        Next batch sizes per model.
    """
    voltage_dual_diff = np.asarray(voltage_dual_diff, float).reshape(-1)
    sensitivity_matrix = np.asarray(sensitivity_matrix, float)
    latency_dual_by_model = {} if latency_dual_by_model is None else dict(latency_dual_by_model)
    replica_count_by_model = {} if replica_count_by_model is None else dict(replica_count_by_model)

    step_size = float(self.config.primal_step_size)  # ρ_x
    w_throughput = float(self.config.w_throughput)
    w_switch = float(self.config.w_switch)
    voltage_gradient_scale = float(self.config.voltage_gradient_scale)

    batch_next: dict[str, int] = {}

    for ms in self.models:
        label = ms.model_label
        log_batch_size = float(self.log_batch_size_by_model[label])
        prev_log_batch_size = float(self.prev_log_batch_size_by_model.get(label, log_batch_size))

        replica_count = float(replica_count_by_model.get(label, 0.0))  # w_i
        if (not np.isfinite(replica_count)) or (replica_count < 0.0):
            replica_count = 0.0

        phase_share = np.asarray(  # e_i (phase-allocation weight, p.7)
            phase_share_by_model.get(label, np.array([1 / 3, 1 / 3, 1 / 3], dtype=float)),
            float,
        ).reshape(3)
        s = float(np.sum(phase_share))
        if (not np.isfinite(s)) or s <= 0.0:
            phase_share = np.array([1 / 3, 1 / 3, 1 / 3], dtype=float)
        else:
            phase_share = phase_share / s

        weighted_sensitivity = sensitivity_matrix @ phase_share  # H @ e_i
        voltage_gradient = float(voltage_dual_diff @ weighted_sensitivity)

        dPdx_1 = float(self.power_fits[label].deriv_wrt_x(log_batch_size))
        dLdx_1 = float(self.latency_fits[label].deriv_wrt_x(log_batch_size))
        dThdx_1 = float(self.throughput_fits[label].deriv_wrt_x(log_batch_size))

        dPdx_1_kw = dPdx_1 / 1000.0

        th_max = float(self.throughput_max_by_model.get(label, 1.0))
        if (not np.isfinite(th_max)) or (th_max <= 0.0):
            th_max = 1.0
        dThdx_norm_1 = dThdx_1 / th_max

        dPdx = replica_count * dPdx_1_kw
        dThdx = replica_count * dThdx_norm_1
        dLdx = dLdx_1

        latency_dual = float(latency_dual_by_model.get(label, 0.0))  # μ_i
        if (not np.isfinite(latency_dual)) or (latency_dual < 0.0):
            latency_dual = 0.0

        # Gradient of the Lagrangian w.r.t. x_i = log2(batch_i).
        # G2G paper Eq. 18: nabla_x L = -dR/dx (throughput)
        #                              + 2*gamma*(x - x_prev) (switching)
        #                              + eta^T H e_i dP/dx (voltage dual)
        #                              + mu_i * dL/dx (latency dual)
        # Implementation extensions: wT scaling on throughput,
        #                            k_v scaling on voltage term
        grad = 0.0
        grad -= w_throughput * dThdx
        grad += voltage_gradient_scale * voltage_gradient * dPdx
        grad += latency_dual * dLdx
        grad += w_switch * (log_batch_size - prev_log_batch_size)

        new_log_batch_size = self._clamp_log_batch_size(log_batch_size - step_size * grad)
        self.prev_log_batch_size_by_model[label] = log_batch_size
        self.log_batch_size_by_model[label] = new_log_batch_size
        batch_next[label] = self._discretize_batch(new_log_batch_size)

    return batch_next

OFOBatchSizeController

Bases: Controller[LLMBatchSizeControlledDatacenter[LLMDatacenterState], OpenDSSGrid]

Online Feedback Optimization controller for batch-size regulation.

Reads grid voltage and datacenter state, updates voltage and latency duals, runs the primal batch-size optimizer, and returns new batch sizes. Latency dual updates use dc_state.observed_itl_s_by_model.

Parameters:

Name Type Description Default
inference_models tuple[InferenceModelSpec, ...]

Model specifications served in the datacenter.

required
models LogisticModelStore

Per-model logistic models for power, latency, and throughput used in gradient computation.

required
config OFOConfig | None

Unified OFO tuning parameters.

None
dt_s Fraction

Control interval (seconds).

Fraction(1)
Source code in openg2g/controller/ofo.py
class OFOBatchSizeController(Controller[LLMBatchSizeControlledDatacenter[LLMDatacenterState], OpenDSSGrid]):
    """Online Feedback Optimization controller for batch-size regulation.

    Reads grid voltage and datacenter state, updates voltage and latency
    duals, runs the primal batch-size optimizer, and returns new batch
    sizes. Latency dual updates use [`dc_state.observed_itl_s_by_model`
    ][openg2g.datacenter.base.LLMDatacenterState.observed_itl_s_by_model].

    Args:
        inference_models: Model specifications served in the datacenter.
        models: Per-model logistic models for power, latency, and
            throughput used in gradient computation.
        config: Unified OFO tuning parameters.
        dt_s: Control interval (seconds).
    """

    def __init__(
        self,
        inference_models: tuple[InferenceModelSpec, ...],
        *,
        models: LogisticModelStore,
        config: OFOConfig | None = None,
        dt_s: Fraction = Fraction(1),
        site_id: str | None = None,
        initial_batch_sizes: dict[str, int] | None = None,
    ) -> None:
        if config is None:
            config = OFOConfig()

        # Fail fast on an empty or ambiguous model set.
        if not inference_models:
            raise ValueError("inference_models must not be empty.")
        labels = [ms.model_label for ms in inference_models]
        if len(labels) != len(set(labels)):
            raise ValueError(f"Duplicate model labels: {labels}")

        model_specs = list(inference_models)
        self._initial_batch_sizes = initial_batch_sizes or {}

        # Verify the store holds all three logistic fits for every model
        # before any state is built, so misconfiguration surfaces at
        # construction time rather than mid-simulation.
        for ms in model_specs:
            label = ms.model_label
            for metric_name, accessor in [
                ("power", models.power),
                ("latency", models.latency),
                ("throughput", models.throughput),
            ]:
                try:
                    accessor(label)
                except KeyError:
                    raise ValueError(f"LogisticModelStore missing {metric_name} model for {label!r}.") from None

        self._dt_s = dt_s
        self._models = model_specs
        self._config = config
        self._site_id = site_id
        self._itl_deadline_by_model = {ms.model_label: ms.itl_deadline_s for ms in model_specs}

        # Voltage duals are created lazily on the first step() because
        # their dimension depends on the grid's bus-phase index; latency
        # duals (μ_i) start at zero for every model.
        self._voltage_dual: VoltageDualVariables | None = None
        self._latency_dual_by_model: dict[str, float] = {ms.model_label: 0.0 for ms in model_specs}

        # The primal optimizer shares one feasible-batch grid: the
        # sorted union of every model's feasible batch sizes.
        all_bs: set[int] = set()
        for ms in model_specs:
            all_bs.update(ms.feasible_batch_sizes)
        feasible_batch_sizes = sorted(all_bs)

        self._optimizer = PrimalBatchOptimizer(
            models=model_specs,
            feasible_batch_sizes=feasible_batch_sizes,
            power_fits=models.power_fits,
            latency_fits=models.latency_fits,
            throughput_fits=models.throughput_fits,
            config=config,
        )
        self._optimizer.init_from_batches(
            {
                ms.model_label: self._initial_batch_sizes.get(ms.model_label, ms.feasible_batch_sizes[0])
                for ms in model_specs
            }
        )

        # Sensitivity matrix is estimated lazily (and periodically
        # re-estimated) in step().
        self._sensitivity_matrix: np.ndarray | None = None
        self._control_step_count: int = 0

        logger.info(
            "OFOBatchSizeController: %d models, dt=%s s, feasible_batches=%s",
            len(model_specs),
            dt_s,
            feasible_batch_sizes,
        )

    def reset(self) -> None:
        """Restore duals, batch-size state, sensitivity cache, and the
        step counter to their initial values."""
        self._voltage_dual = None
        self._latency_dual_by_model = {ms.model_label: 0.0 for ms in self._models}
        self._optimizer.init_from_batches(
            {
                ms.model_label: self._initial_batch_sizes.get(ms.model_label, ms.feasible_batch_sizes[0])
                for ms in self._models
            }
        )
        self._sensitivity_matrix = None
        self._control_step_count = 0

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        return self._dt_s

    def step(
        self,
        clock: SimulationClock,
        datacenter: LLMBatchSizeControlledDatacenter[LLMDatacenterState],
        grid: OpenDSSGrid,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Run one OFO iteration and return a `SetBatchSize` command.

        Raises:
            RuntimeError: if the datacenter state lacks replica counts or
                observed ITL for any configured model.
        """
        # Lazily size the voltage duals to the grid's bus-phase index.
        if self._voltage_dual is None:
            self._voltage_dual = VoltageDualVariables(len(grid.v_index), self._config)

        # 1. Re-estimate sensitivity if needed.  Fires on the first step
        #    and then every `sensitivity_update_interval` steps.
        if self._sensitivity_matrix is None or (
            self._config.sensitivity_update_interval > 0
            and self._control_step_count % self._config.sensitivity_update_interval == 0
        ):
            # NOTE(review): site_id is only forwarded when configured —
            # presumably estimate_sensitivity's default covers all sites;
            # confirm against the grid backend.
            sens_kwargs = {"perturbation_kw": self._config.sensitivity_perturbation_kw}
            if self._site_id is not None:
                sens_kwargs["site_id"] = self._site_id
            self._sensitivity_matrix, _ = grid.estimate_sensitivity(**sens_kwargs)

        # 2. Update voltage duals from grid state
        observed_voltages = grid.voltages_vector()
        self._voltage_dual.update(observed_voltages)

        voltage_dual_diff = self._voltage_dual.dual_difference()  # η = λ̄ − λ

        # 3. Read observed latency from datacenter and update latency duals
        dc_state = datacenter.state
        missing_replicas = [
            ms.model_label for ms in self._models if ms.model_label not in dc_state.active_replicas_by_model
        ]
        if missing_replicas:
            miss = ", ".join(sorted(missing_replicas))
            raise RuntimeError(
                f"OFOBatchSizeController requires active_replicas_by_model for all models. Missing: {miss}."
            )
        missing_itl = [ms.model_label for ms in self._models if ms.model_label not in dc_state.observed_itl_s_by_model]
        if missing_itl:
            miss = ", ".join(sorted(missing_itl))
            raise RuntimeError(
                f"OFOBatchSizeController requires observed_itl_s_by_model for all models. Missing: {miss}."
            )
        for ms in self._models:
            label = ms.model_label
            num_replicas = max(int(dc_state.active_replicas_by_model[label]), 0)
            observed_itl = float(dc_state.observed_itl_s_by_model[label])
            if num_replicas <= 0:
                # With no replicas the observed ITL is meaningless; force
                # NaN so the dual is held (not updated) below.
                logger.debug("Model %s has 0 replicas, skipping latency dual update", label)
                observed_itl = float("nan")

            deadline = float(self._itl_deadline_by_model[label])
            if np.isfinite(observed_itl):
                # Projected subgradient ascent on μ_i: grow when ITL
                # exceeds the deadline, shrink otherwise, floor at 0.
                self._latency_dual_by_model[label] = max(
                    self._latency_dual_by_model[label]
                    + self._config.latency_dual_step_size * (observed_itl - deadline),
                    0.0,
                )
            else:
                self._latency_dual_by_model[label] = max(self._latency_dual_by_model[label], 0.0)

        # 4. Compute replica counts
        replica_count_by_model: dict[str, float] = {}
        for ms in self._models:
            label = ms.model_label
            replica_count_by_model[label] = float(dc_state.active_replicas_by_model[label])

        # 5. Primal update -> next batch sizes
        batch_next = self._optimizer.step(
            voltage_dual_diff=voltage_dual_diff,
            sensitivity_matrix=self._sensitivity_matrix,
            phase_share_by_model=datacenter.phase_share_by_model,
            latency_dual_by_model=self._latency_dual_by_model,
            replica_count_by_model=replica_count_by_model,
        )

        self._control_step_count += 1
        logger.info(
            "OFO step %d (t=%.1f s): batch=%s",
            self._control_step_count,
            clock.time_s,
            batch_next,
        )
        events.emit(
            "controller.ofo.step",
            {
                "batch_size_by_model": batch_next,
                "latency_dual_by_model": dict(self._latency_dual_by_model),
            },
        )
        return [SetBatchSize(batch_size_by_model=batch_next, target_site_id=self._site_id)]

openg2g.controller.rule_based

Rule-based batch-size controller for voltage regulation.

A proportional controller that adjusts LLM batch sizes based on observed voltage violations. Unlike the OFO controller, it requires no sensitivity matrix, no logistic curve fits, and no dual variables — making it a natural "simple baseline" for comparison.

Algorithm (each control step):

1. Read all bus-phase voltages from the grid.
2. Find the worst voltage violation magnitude.
3. Compute a signed "pressure" signal:
   - positive (undervoltage) → reduce batch (less power draw, less voltage drop)
   - negative (overvoltage) → increase batch (more power draw, more voltage drop)
   - zero → no action (all voltages within bounds)
4. Adjust each model's batch size proportionally in log2-space.
5. Snap to the nearest feasible batch size.

RuleBasedConfig

Bases: BaseModel

Configuration for the rule-based batch-size controller.

Source code in openg2g/controller/rule_based.py
class RuleBasedConfig(BaseModel):
    """Configuration for the rule-based batch-size controller.

    Knobs: a proportional gain (``step_size``), a voltage band
    (``v_min``/``v_max``), a ``deadband`` to suppress chattering, and a
    ``latency_guard`` toggle.
    """

    step_size: float = 10.0
    """Proportional gain: log2(batch) change per pu of voltage violation.
    With feasible batches spaced ~1 log2 unit apart, a violation of 0.01 pu
    needs step_size ~10 to produce a 0.1 log2 shift, enough to eventually
    change the discrete batch level."""

    v_min: float = 0.95
    """Lower voltage limit (pu)."""

    v_max: float = 1.05
    """Upper voltage limit (pu)."""

    deadband: float = 0.001
    """Ignore violations smaller than this (pu).  Prevents chattering."""

    latency_guard: bool = True
    """If True, prevent batch size increase when ITL exceeds deadline."""

step_size = 10.0 class-attribute instance-attribute

Proportional gain: log2(batch) change per pu of voltage violation. With feasible batches spaced ~1 log2 unit apart, a violation of 0.01 pu needs step_size ~10 to produce a 0.1 log2 shift, enough to eventually change the discrete batch level.

v_min = 0.95 class-attribute instance-attribute

Lower voltage limit (pu).

v_max = 1.05 class-attribute instance-attribute

Upper voltage limit (pu).

deadband = 0.001 class-attribute instance-attribute

Ignore violations smaller than this (pu). Prevents chattering.

latency_guard = True class-attribute instance-attribute

If True, prevent batch size increase when ITL exceeds deadline.

RuleBasedBatchSizeController

Bases: Controller[LLMBatchSizeControlledDatacenter[LLMDatacenterState], OpenDSSGrid]

Proportional rule-based controller for LLM batch-size regulation.

Reads grid voltages, computes a signed pressure signal from the worst violation, and adjusts batch sizes proportionally. No model fits or sensitivity matrices required.

Source code in openg2g/controller/rule_based.py
class RuleBasedBatchSizeController(
    Controller[LLMBatchSizeControlledDatacenter[LLMDatacenterState], OpenDSSGrid],
):
    """Proportional rule-based controller for LLM batch-size regulation.

    Reads grid voltages, computes a signed pressure signal from the worst
    violation, and adjusts batch sizes proportionally.  No model fits or
    sensitivity matrices required.

    Args:
        inference_models: Model specifications served in the datacenter.
        config: Proportional-controller tuning parameters.
        dt_s: Control interval (seconds).
        site_id: Optional site targeted by emitted batch-size commands.
        exclude_buses: Bus names (case-insensitive) ignored when scanning
            voltages for violations.
        initial_batch_sizes: Optional starting batch size per model label;
            models not listed start at their first declared feasible batch.
    """

    def __init__(
        self,
        inference_models: tuple[InferenceModelSpec, ...] | list[InferenceModelSpec],
        *,
        config: RuleBasedConfig,
        dt_s: Fraction = Fraction(1),
        site_id: str | None = None,
        exclude_buses: tuple[str, ...] = (),
        initial_batch_sizes: dict[str, int] | None = None,
    ) -> None:
        model_specs = list(inference_models)
        self._dt_s = dt_s
        self._site_id = site_id
        self._config = config
        self._models = model_specs
        # Compare bus names case-insensitively when excluding.
        self._exclude_lower = {b.lower() for b in exclude_buses}
        self._initial_batch_sizes = initial_batch_sizes or {}

        # Build per-model feasible batch list (sorted ascending)
        self._feasible: dict[str, list[int]] = {}
        self._itl_deadline: dict[str, float] = {}
        for ms in model_specs:
            self._feasible[ms.model_label] = sorted(ms.feasible_batch_sizes)
            self._itl_deadline[ms.model_label] = ms.itl_deadline_s

        # Continuous state: log2(batch) per model (for smooth proportional control)
        self._log2_batch: dict[str, float] = {
            ms.model_label: math.log2(self._initial_batch_sizes.get(ms.model_label, ms.feasible_batch_sizes[0]))
            for ms in model_specs
        }

        logger.info(
            "RuleBasedBatchSizeController: %d models, dt=%s s, step_size=%.2f, deadband=%.4f, v=[%.2f, %.2f]",
            len(model_specs),
            dt_s,
            config.step_size,
            config.deadband,
            config.v_min,
            config.v_max,
        )

    @property
    def site_id(self) -> str | None:
        """Site this controller targets, or None for the default."""
        return self._site_id

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        return self._dt_s

    def reset(self) -> None:
        """Restore the continuous log2-batch state to the initial batches."""
        self._log2_batch = {
            ms.model_label: math.log2(self._initial_batch_sizes.get(ms.model_label, ms.feasible_batch_sizes[0]))
            for ms in self._models
        }

    def step(
        self,
        clock: SimulationClock,
        datacenter: LLMBatchSizeControlledDatacenter[LLMDatacenterState],
        grid: OpenDSSGrid,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Run one proportional step; emit `SetBatchSize` when batches change."""
        if grid.state is None:
            return []

        cfg = self._config
        voltages = grid.state.voltages

        # ── 1. Find worst voltage violation ──
        worst_under = 0.0  # magnitude of worst undervoltage (positive)
        worst_over = 0.0  # magnitude of worst overvoltage (positive)

        for bus in voltages.buses():
            if bus.lower() in self._exclude_lower:
                continue
            pv = voltages[bus]
            for v in (pv.a, pv.b, pv.c):
                if math.isnan(v):  # absent/unmonitored phase
                    continue
                if v < cfg.v_min:
                    worst_under = max(worst_under, cfg.v_min - v)
                elif v > cfg.v_max:
                    worst_over = max(worst_over, v - cfg.v_max)

        # ── 2. Compute pressure signal ──
        # Positive pressure → reduce batch (undervoltage: DC draws too much power)
        # Negative pressure → increase batch (overvoltage: DC draws too little)
        if worst_under > cfg.deadband:
            pressure = worst_under
        elif worst_over > cfg.deadband:
            pressure = -worst_over
        else:
            pressure = 0.0

        if pressure == 0.0:
            return []

        # ── 3. Read latency state for guard ──
        dc_state = datacenter.state
        itl_by_model = dc_state.observed_itl_s_by_model if dc_state else {}

        # ── 4. Adjust batch sizes ──
        new_batches: dict[str, int] = {}
        changed = False

        for ms in self._models:
            label = ms.model_label
            log2_b = self._log2_batch[label]
            feasible = self._feasible[label]

            # Proportional adjustment in log2-space
            # pressure > 0 (undervoltage) → reduce batch → log2_b decreases
            delta = -cfg.step_size * pressure
            new_log2 = log2_b + delta

            # Clamp to feasible range
            new_log2 = max(math.log2(feasible[0]), min(math.log2(feasible[-1]), new_log2))

            # Latency guard: don't increase batch if ITL already exceeds deadline
            if cfg.latency_guard and delta > 0:
                itl = itl_by_model.get(label, 0.0)
                if not math.isnan(itl) and itl > self._itl_deadline[label]:
                    new_log2 = log2_b  # revert

            # Snap to nearest feasible batch size
            target = 2.0**new_log2
            best = min(feasible, key=lambda b: abs(b - target))

            # Keep continuous state for accumulation; only snap for the command
            self._log2_batch[label] = new_log2
            new_batches[label] = best

            # BUGFIX: dc_state may be None (already handled above for the
            # latency guard) — guard the dereference here too instead of
            # raising AttributeError.
            fallback = self._initial_batch_sizes.get(label, ms.feasible_batch_sizes[0])
            current = dc_state.batch_size_by_model.get(label, fallback) if dc_state else fallback
            if best != current:
                changed = True

        if not changed:
            return []

        events.emit(
            "rule_based.step",
            {"time_s": clock.time_s, "pressure": pressure, "batch": dict(new_batches)},
        )

        # Forward the configured site, consistent with OFOBatchSizeController.
        return [SetBatchSize(batch_size_by_model=new_batches, target_site_id=self._site_id)]

openg2g.controller.tap_schedule

Tap schedule controller: applies pre-defined regulator tap changes at specified times.

TapScheduleController

Bases: Controller[DatacenterBackend, GridBackend]

Applies pre-defined tap changes at scheduled times.

When multiple schedule entries fire in the same step, their tap values are merged (later entries win).

Parameters:

Name Type Description Default
schedule TapSchedule

Tap schedule built via TapPosition(...).at(t=...) | ....

required
dt_s Fraction

How often the controller checks the schedule (seconds).

Fraction(1)
Source code in openg2g/controller/tap_schedule.py
class TapScheduleController(Controller[DatacenterBackend, GridBackend]):
    """Replays a pre-defined regulator tap schedule.

    Entries whose timestamps have been reached are consumed in order;
    when several fire within one step, their tap values are merged with
    later entries taking precedence.

    Args:
        schedule: Tap schedule built via
            [`TapPosition(...).at(t=...) | ...`][openg2g.grid.config.TapSchedule].
        dt_s: How often the controller checks the schedule (seconds).
    """

    def __init__(self, *, schedule: TapSchedule, dt_s: Fraction = Fraction(1)) -> None:
        self._dt_s = dt_s
        self._pending = list(schedule)
        self._cursor = 0

    def reset(self) -> None:
        """Rewind to the start of the schedule."""
        self._cursor = 0

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        return self._dt_s

    def step(
        self,
        clock: SimulationClock,
        datacenter: DatacenterBackend,
        grid: GridBackend,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Emit one merged ``SetTaps`` command for all entries now due."""
        now = clock.time_s
        due_taps: dict[str, float] = {}
        fired = False

        # Consume every entry whose time has arrived; the small epsilon
        # absorbs float round-off in the clock comparison.
        while self._cursor < len(self._pending):
            t_entry, position = self._pending[self._cursor]
            if float(t_entry) > now + 1e-12:
                break
            due_taps.update(position.regulators)
            fired = True
            self._cursor += 1

        if not (fired and due_taps):
            return []

        tap = TapPosition(regulators=due_taps)
        events.emit("controller.tap_schedule.fired", {"tap_position": tap})
        return [SetTaps(tap_position=tap)]