Skip to content

openg2g.controller

openg2g.controller.base

Abstract base class for controllers.

Controller

Bases: Generic[DCBackendT, GridBackendT], ABC

Interface for a control component in the G2G framework.

Controllers receive datacenter and grid state and produce control actions. Multiple controllers compose in order within the coordinator.

Source code in openg2g/controller/base.py
class Controller(Generic[DCBackendT, GridBackendT], ABC):
    """Base interface for a control component in the G2G framework.

    A controller receives datacenter and grid state and produces control
    actions. The coordinator composes multiple controllers in order.

    Subclasses declare the backends they work with by parameterizing the
    base class as `Controller[DatacenterType, GridType]`. The declared types
    are recorded when the subclass is created and exposed through
    `compatible_datacenter_types` and `compatible_grid_types`.
    """

    _dc_types: tuple[type[DatacenterBackend], ...] = (DatacenterBackend,)
    _grid_types: tuple[type[GridBackend], ...] = (GridBackend,)

    def __init_subclass__(cls, **kwargs: object) -> None:
        """Capture the backend types a new subclass is specialized for.

        The first `Controller[...]` entry found in `__orig_bases__` supplies
        the types. Without one, the types are copied from the first direct
        base class that is itself a proper `Controller` subclass.

        Raises:
            TypeError: If the specialization does not carry exactly two
                generic arguments, or if no specialization can be found
                or inherited.
        """
        super().__init_subclass__(**kwargs)

        specialization = next(
            (base for base in getattr(cls, "__orig_bases__", ()) if get_origin(base) is Controller),
            None,
        )

        if specialization is not None:
            type_args = get_args(specialization)
            if len(type_args) != 2:
                raise TypeError(
                    f"{cls.__name__} must specialize Controller with two generic args: "
                    "Controller[DatacenterType, GridType]."
                )
            dc_candidates = _normalize_backend_type_arg(type_args[0], required_base=DatacenterBackend)
            grid_candidates = _normalize_backend_type_arg(type_args[1], required_base=GridBackend)
            # Build both tuples first so a failure leaves the class attributes untouched.
            dc_selected = tuple(c for c in dc_candidates if issubclass(c, DatacenterBackend))
            grid_selected = tuple(c for c in grid_candidates if issubclass(c, GridBackend))
            cls._dc_types = dc_selected
            cls._grid_types = grid_selected
            return

        # No explicit specialization: inherit from the first concrete controller parent.
        for parent in cls.__bases__:
            if parent is Controller or not issubclass(parent, Controller):
                continue
            cls._dc_types = parent.compatible_datacenter_types()
            cls._grid_types = parent.compatible_grid_types()
            return

        raise TypeError(
            f"{cls.__name__} must explicitly specialize Controller generics as "
            "Controller[DatacenterType, GridType]."
        )

    @final
    @classmethod
    def compatible_datacenter_types(cls) -> tuple[type[DatacenterBackend], ...]:
        """Return the datacenter backend types recorded for this controller class."""
        return cls._dc_types

    @final
    @classmethod
    def compatible_grid_types(cls) -> tuple[type[GridBackend], ...]:
        """Return the grid backend types recorded for this controller class."""
        return cls._grid_types

    @final
    @classmethod
    def compatibility_signature(cls) -> str:
        """Render the recorded backend types as `Controller[A | B, C | D]`."""
        dc_names = [backend.__name__ for backend in cls.compatible_datacenter_types()]
        grid_names = [backend.__name__ for backend in cls.compatible_grid_types()]
        dc = " | ".join(dc_names)
        grid = " | ".join(grid_names)
        return f"Controller[{dc}, {grid}]"

    @property
    @abstractmethod
    def dt_s(self) -> Fraction:
        """Control interval in seconds, expressed as a Fraction."""

    @abstractmethod
    def reset(self) -> None:
        """Return all simulation state to its initial conditions.

        The coordinator calls this before each [`start`][..start]. Every
        piece of simulation state must be cleared: dual variables,
        counters, cached matrices. Configuration (dt_s, fits, step sizes)
        is left alone.

        The method is abstract on purpose, so each implementation lists
        its own state explicitly. Forgetting a field is a bug: stale state
        silently corrupts the second run.
        """

    def start(self) -> None:
        """Acquire per-run resources.

        Runs after [`reset`][..reset] and before the simulation loop. The
        default does nothing, since most controllers have no resources to
        acquire.
        """

    def stop(self) -> None:
        """Release per-run resources while keeping simulation state.

        Runs after the simulation loop, in LIFO order. The default does
        nothing.
        """

    @abstractmethod
    def step(
        self,
        clock: SimulationClock,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Compute this step's control commands; an empty list means no-op."""

dt_s abstractmethod property

Control interval as a Fraction (seconds).

reset() abstractmethod

Reset simulation state to initial conditions.

Called by the coordinator before each start. Must clear all simulation state: dual variables, counters, cached matrices. Configuration (dt_s, fits, step sizes) is not affected.

Abstract so every implementation explicitly enumerates its state. A forgotten field is a bug -- not clearing it silently corrupts the second run.

Source code in openg2g/controller/base.py
@abstractmethod
def reset(self) -> None:
    """Return all simulation state to its initial conditions.

    The coordinator calls this before each [`start`][..start]. Every
    piece of simulation state must be cleared: dual variables, counters,
    cached matrices. Configuration (dt_s, fits, step sizes) is left
    alone.

    The method is abstract on purpose, so each implementation lists its
    own state explicitly. Forgetting a field is a bug: stale state
    silently corrupts the second run.
    """

start()

Acquire per-run resources.

Called after reset, before the simulation loop. No-op by default because most controllers have no resources to acquire.

Source code in openg2g/controller/base.py
def start(self) -> None:
    """Acquire per-run resources.

    Runs after [`reset`][..reset] and before the simulation loop. The
    default does nothing, since most controllers have no resources to
    acquire.
    """

stop()

Release per-run resources. Simulation state is preserved.

Called after the simulation loop in LIFO order. No-op by default.

Source code in openg2g/controller/base.py
def stop(self) -> None:
    """Release per-run resources while keeping simulation state.

    Runs after the simulation loop, in LIFO order. The default does
    nothing.
    """

step(clock, events) abstractmethod

Compute control commands for this step. Return an empty list for no-op.

Source code in openg2g/controller/base.py
@abstractmethod
def step(
    self,
    clock: SimulationClock,
    events: EventEmitter,
) -> list[DatacenterCommand | GridCommand]:
    """Compute this step's control commands; an empty list means no-op."""

openg2g.controller.batch_size_schedule

Batch size schedule controller: applies pre-defined batch size changes at specified times.

BatchSizeChange dataclass

A batch size change event, optionally with gradual ramp-up.

Attributes:

Name Type Description
batch_size int

Target batch size (max_num_seqs).

ramp_up_rate float

Requests/second ramp-up rate. 0 means immediate.

Source code in openg2g/controller/batch_size_schedule.py
@dataclass(frozen=True)
class BatchSizeChange:
    """One batch size change, applied at once or ramped up gradually.

    Attributes:
        batch_size: Target batch size (max_num_seqs). Must be positive.
        ramp_up_rate: Ramp-up rate in requests/second. 0 means the change
            is immediate; negative values are rejected.
    """

    batch_size: int
    ramp_up_rate: float = 0.0

    def __post_init__(self) -> None:
        """Validate both fields, checking `batch_size` first.

        Raises:
            ValueError: If `batch_size` is not positive or `ramp_up_rate`
                is negative.
        """
        size, rate = self.batch_size, self.ramp_up_rate
        if size <= 0:
            raise ValueError(f"batch_size must be positive, got {size}.")
        if rate < 0:
            raise ValueError(f"ramp_up_rate must be >= 0, got {rate}.")

    def at(self, t: float) -> BatchSizeSchedule:
        """Place this change at time *t* (seconds).

        Returns:
            A [`BatchSizeSchedule`][...BatchSizeSchedule] holding only this
            entry.
        """
        entry = (t, self)
        return BatchSizeSchedule((entry,))

at(t)

Schedule this change at time t seconds.

Returns:

Type Description
BatchSizeSchedule

A single-entry BatchSizeSchedule.

Source code in openg2g/controller/batch_size_schedule.py
def at(self, t: float) -> BatchSizeSchedule:
    """Place this change at time *t* (seconds).

    Returns:
        A [`BatchSizeSchedule`][...BatchSizeSchedule] holding only this
        entry.
    """
    entry = (t, self)
    return BatchSizeSchedule((entry,))

BatchSizeSchedule

An ordered sequence of batch size changes, composed with the `|` operator.

Example:

schedule = (
    BatchSizeChange(48).at(40)
    | BatchSizeChange(32).at(60)
    | BatchSizeChange(48, ramp_up_rate=4).at(280)
)

Raises:

Type Description
ValueError

If two entries share the same timestamp.

Source code in openg2g/controller/batch_size_schedule.py
class BatchSizeSchedule:
    """Ordered sequence of batch size changes, built with `|` operator.

    Entries are `(time_s, BatchSizeChange)` pairs, stored sorted by time
    regardless of the order they were supplied in.

    Example:

    ```python
    schedule = (
        BatchSizeChange(48).at(40)
        | BatchSizeChange(32).at(60)
        | BatchSizeChange(48, ramp_up_rate=4).at(280)
    )
    ```

    Raises:
        ValueError: If two entries share the same timestamp.
    """

    __slots__ = ("_entries",)

    def __init__(self, entries: tuple[tuple[float, BatchSizeChange], ...]) -> None:
        """Store *entries* sorted by timestamp and reject duplicate timestamps.

        Args:
            entries: `(time_s, change)` pairs in any order.

        Raises:
            ValueError: If two entries share the same timestamp.
        """
        self._entries = tuple(sorted(entries, key=lambda e: e[0]))

        seen: set[float] = set()
        dupes: set[float] = set()
        for t, _ in self._entries:
            if t in seen:
                dupes.add(t)
            else:
                seen.add(t)
        if dupes:
            raise ValueError(f"BatchSizeSchedule has duplicate timestamps: {sorted(dupes)}")

    def __or__(self, other: BatchSizeSchedule) -> BatchSizeSchedule:
        """Merge two schedules into a new one; operands are left unchanged.

        Raises:
            ValueError: If the merged schedule has duplicate timestamps.
        """
        if not isinstance(other, BatchSizeSchedule):
            # Let Python try the reflected operation, then raise a proper TypeError.
            return NotImplemented
        return BatchSizeSchedule(self._entries + other._entries)

    def __iter__(self) -> Iterator[tuple[float, BatchSizeChange]]:
        """Iterate over `(time_s, change)` pairs in ascending time order."""
        return iter(self._entries)

    def __len__(self) -> int:
        """Return the number of scheduled changes."""
        return len(self._entries)

    def __bool__(self) -> bool:
        """Return True when the schedule has at least one entry."""
        return bool(self._entries)

    def __repr__(self) -> str:
        """Render the schedule in the same `... | ...` form used to build it."""
        if not self._entries:
            # An empty string would be an unhelpful repr; show a constructor call instead.
            return "BatchSizeSchedule(())"
        parts: list[str] = []
        for t, c in self._entries:
            ramp = f", ramp_up_rate={c.ramp_up_rate}" if c.ramp_up_rate > 0 else ""
            parts.append(f"BatchSizeChange({c.batch_size}{ramp}).at(t={t})")
        return " | ".join(parts)

BatchSizeScheduleController

Bases: Controller[DatacenterBackend, GridBackend]

Applies pre-defined batch size changes at scheduled times.

Walks each model's schedule and emits SetBatchSize commands when the simulation clock reaches the scheduled time.

Parameters:

Name Type Description Default
schedules dict[str, BatchSizeSchedule]

Per-model batch size schedules, keyed by model label.

required
dt_s Fraction

How often the controller checks the schedule (seconds).

Fraction(1)
Source code in openg2g/controller/batch_size_schedule.py
class BatchSizeScheduleController(Controller[DatacenterBackend, GridBackend]):
    """Applies pre-defined batch size changes at scheduled times.

    Walks each model's schedule and emits
    [`SetBatchSize`][openg2g.datacenter.command.SetBatchSize] commands when the
    simulation clock reaches the scheduled time.

    Args:
        datacenter: Datacenter backend the emitted commands target.
        schedules: Per-model batch size schedules, keyed by model label.
        dt_s: How often the controller checks the schedule (seconds).
    """

    def __init__(
        self,
        *,
        datacenter: DatacenterBackend,
        schedules: dict[str, BatchSizeSchedule],
        dt_s: Fraction = Fraction(1),
    ) -> None:
        self._datacenter = datacenter
        self._dt_s = dt_s
        self._schedules = dict(schedules)
        # Materialize each schedule once so step() can index it without
        # rebuilding a list on every tick.
        self._entries_by_model: dict[str, tuple[tuple[float, BatchSizeChange], ...]] = {
            label: tuple(schedule) for label, schedule in self._schedules.items()
        }
        self._indices: dict[str, int] = {label: 0 for label in self._schedules}

    def reset(self) -> None:
        """Rewind every model's schedule cursor to its first entry."""
        self._indices = {label: 0 for label in self._schedules}

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        return self._dt_s

    def step(
        self,
        clock: SimulationClock,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Emit one `SetBatchSize` covering every entry that is now due.

        When several entries for the same model fall due in one tick, only
        the latest one takes effect, including its ramp-up rate.

        Returns:
            A single-element list with the command, or an empty list when
            nothing is due.
        """
        t_now = clock.time_s
        batch_changes: dict[str, int] = {}
        ramp_rates: dict[str, float] = {}

        for label, entries in self._entries_by_model.items():
            idx = self._indices[label]

            while idx < len(entries):
                t_ev, change = entries[idx]
                # Small tolerance so an entry scheduled exactly on a tick is not missed.
                if float(t_ev) > t_now + 1e-12:
                    break
                batch_changes[label] = change.batch_size
                if change.ramp_up_rate > 0:
                    ramp_rates[label] = change.ramp_up_rate
                else:
                    # A later immediate change must not inherit the ramp rate
                    # of an earlier entry that fired in this same tick.
                    ramp_rates.pop(label, None)
                idx += 1

            self._indices[label] = idx

        if not batch_changes:
            return []

        events.emit(
            "controller.batch_schedule.fired",
            {"batch_size_by_model": batch_changes},
        )
        return [
            SetBatchSize(
                batch_size_by_model=batch_changes,
                ramp_up_rate_by_model=ramp_rates,
                target=self._datacenter,
            )
        ]

openg2g.controller.load_shift

Cross-site LLM load shifting controller.

Shifts replicas between datacenters when batch-size control (OFO) is exhausted and voltage violations persist. Runs after all per-site OFO controllers in the coordinator loop.

LoadShiftConfig

Bases: BaseModel

Configuration for cross-site load shifting.

Attributes:

Name Type Description
enabled bool

If False, the controller is a no-op. Useful for ablations.

gpus_per_shift int

Number of GPUs worth of replicas to shift in a single controller tick.

Source code in openg2g/controller/load_shift.py
class LoadShiftConfig(BaseModel):
    """Configuration for cross-site load shifting.

    Attributes:
        enabled: If False, the controller is a no-op. Useful for ablations.
            Defaults to False, so load shifting must be switched on
            explicitly.
        gpus_per_shift: Number of GPUs worth of replicas to shift in a single
            controller tick. Defaults to 8.
    """

    enabled: bool = False
    gpus_per_shift: int = 8

LoadShiftController

Bases: Controller[OfflineDatacenter, OpenDSSGrid]

Shift LLM replicas between datacenters to resolve voltage violations.

Rules:

1. Only shift models already running at both source and destination.
2. Only act when batch sizes are saturated AND violation persists.
3. For undervoltage: shift load OUT of violated site -> highest-voltage site. For overvoltage: shift load INTO violated site <- lowest-voltage site.
4. Shift gpus_per_shift GPUs worth of replicas per time step.
5. Repeat until violation resolves or no candidates remain.

Source code in openg2g/controller/load_shift.py
class LoadShiftController(Controller[OfflineDatacenter, OpenDSSGrid]):
    """Shift LLM replicas between datacenters to resolve voltage violations.

    Rules:
    1. Only shift models already running at both source and destination.
    2. Only act when batch sizes are saturated AND violation persists.
    3. For undervoltage: shift load OUT of violated site -> highest-voltage site.
       For overvoltage: shift load INTO violated site <- lowest-voltage site.
    4. Shift `gpus_per_shift` GPUs worth of replicas per time step.
    5. Repeat until violation resolves or no candidates remain.

    Args:
        config: Enable flag and per-tick shift size.
        dt_s: Control interval in seconds.
        datacenters: Sites replicas may be shifted between.
        grid: Grid backend queried for voltages and for each site's bus.
        models_by_dc: Model labels running at each site.
        gpus_per_replica_by_model: GPUs per replica for each model label.
        feasible_batch_sizes_by_model: Allowed batch sizes per model; the
            smallest and largest define when a site counts as saturated.
        v_min: Lower voltage limit (presumably per-unit); a site whose
            lowest voltage is below it is in undervoltage.
        v_max: Upper voltage limit (presumably per-unit); a site whose
            highest voltage is above it is in overvoltage.
    """

    def __init__(
        self,
        *,
        config: LoadShiftConfig,
        dt_s: Fraction,
        datacenters: list[OfflineDatacenter],
        grid: OpenDSSGrid,
        models_by_dc: dict[OfflineDatacenter, list[str]],
        gpus_per_replica_by_model: dict[str, int],
        feasible_batch_sizes_by_model: dict[str, list[int]],
        v_min: float = 0.95,
        v_max: float = 1.05,
    ) -> None:
        self._config = config
        self._dt_s = dt_s
        self._datacenters = list(datacenters)
        self._grid = grid
        self._dc_bus_map = {dc: grid.dc_bus(dc) for dc in datacenters}
        self._models_by_dc = dict(models_by_dc)
        self._gpus_per_replica = gpus_per_replica_by_model
        self._feasible_bs = feasible_batch_sizes_by_model
        self._v_min = v_min
        self._v_max = v_max
        self._step_count = 0

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        return self._dt_s

    def step(
        self,
        clock: SimulationClock,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Emit paired `ShiftReplicas` commands for each violated, saturated site.

        Each shift is one command removing replicas from the source and one
        adding the same number at the destination. Returns an empty list
        when the controller is disabled or no shift is possible.
        """
        self._step_count += 1
        if not self._config.enabled:
            return []

        dc_vmin, dc_vmax = self._site_voltage_extremes()
        commands: list[DatacenterCommand | GridCommand] = []

        # NOTE(review): capacity and active-replica counts are read from live
        # state and do not account for shifts already queued earlier in this
        # same step -- confirm that the command consumer tolerates this.
        for dc in self._datacenters:
            # Sites without a voltage reading fall back to 1.0.
            vmin = dc_vmin.get(dc, 1.0)
            vmax = dc_vmax.get(dc, 1.0)

            is_undervoltage = vmin < self._v_min
            is_overvoltage = vmax > self._v_max

            if not is_undervoltage and not is_overvoltage:
                continue

            if not self._is_batch_saturated(dc, is_undervoltage):
                continue

            dc_models = set(self._models_by_dc.get(dc, []))

            if is_undervoltage:
                # Move load away: pick the partner with the highest minimum
                # voltage that shares a model and has room for the shift.
                best_dest = None
                best_v = -1.0
                for other_dc in self._datacenters:
                    if other_dc is dc:
                        continue
                    other_models = set(self._models_by_dc.get(other_dc, []))
                    if not dc_models & other_models:
                        continue
                    if other_dc.available_gpu_capacity() < self._config.gpus_per_shift:
                        continue
                    ov = dc_vmin.get(other_dc, 0.0)
                    if ov > best_v:
                        best_v = ov
                        best_dest = other_dc

                if best_dest is None:
                    continue

                shared_models = dc_models & set(self._models_by_dc.get(best_dest, []))
                model = self._pick_model(shared_models)
                if model is None:
                    continue

                replicas = max(1, self._config.gpus_per_shift // self._gpus_per_replica[model])
                src_active = dc.state.active_replicas_by_model[model]
                if src_active < replicas:
                    continue
                commands.append(ShiftReplicas(model_label=model, replica_delta=-replicas, target=dc))
                commands.append(ShiftReplicas(model_label=model, replica_delta=+replicas, target=best_dest))
                logger.info(
                    "LoadShift: undervoltage at %s (Vmin=%.4f), shift %s x%d replicas -> %s (Vmin=%.4f, free=%d GPUs)",
                    dc.name,
                    vmin,
                    model,
                    replicas,
                    best_dest.name,
                    best_v,
                    best_dest.available_gpu_capacity(),
                )

            else:
                # Overvoltage: pull load in from the partner with the lowest
                # maximum voltage, provided this site has room for it.
                if dc.available_gpu_capacity() < self._config.gpus_per_shift:
                    continue

                best_src = None
                best_v = 2.0
                for other_dc in self._datacenters:
                    if other_dc is dc:
                        continue
                    other_models = set(self._models_by_dc.get(other_dc, []))
                    if not dc_models & other_models:
                        continue
                    ov = dc_vmax.get(other_dc, 2.0)
                    if ov < best_v:
                        best_v = ov
                        best_src = other_dc

                if best_src is None:
                    continue

                shared_models = dc_models & set(self._models_by_dc.get(best_src, []))
                model = self._pick_model(shared_models)
                if model is None:
                    continue

                replicas = max(1, self._config.gpus_per_shift // self._gpus_per_replica[model])
                src_active = best_src.state.active_replicas_by_model[model]
                if src_active < replicas:
                    continue
                commands.append(ShiftReplicas(model_label=model, replica_delta=-replicas, target=best_src))
                commands.append(ShiftReplicas(model_label=model, replica_delta=+replicas, target=dc))
                logger.info(
                    "LoadShift: overvoltage at %s (Vmax=%.4f), shift %s x%d replicas <- %s (Vmax=%.4f)",
                    dc.name,
                    vmax,
                    model,
                    replicas,
                    best_src.name,
                    best_v,
                )

        return commands

    def _site_voltage_extremes(
        self,
    ) -> tuple[dict[OfflineDatacenter, float], dict[OfflineDatacenter, float]]:
        """Return per-site minimum and maximum voltage over all entries at the site's bus.

        Sites whose bus has no entry in the grid's voltage index are left
        out of both mappings.
        """
        grid = self._grid
        voltages = grid.voltages_vector()
        bus_voltages: dict[str, list[float]] = {}
        for (bus, _phase), v in zip(grid.v_index, voltages, strict=True):
            # Bus names are compared case-insensitively.
            bus_voltages.setdefault(bus.lower(), []).append(float(v))

        dc_vmin: dict[OfflineDatacenter, float] = {}
        dc_vmax: dict[OfflineDatacenter, float] = {}
        for dc, bus in self._dc_bus_map.items():
            vs = bus_voltages.get(bus.lower(), [])
            if vs:
                dc_vmin[dc] = min(vs)
                dc_vmax[dc] = max(vs)
        return dc_vmin, dc_vmax

    def _is_batch_saturated(
        self,
        dc: OfflineDatacenter,
        is_undervoltage: bool,
    ) -> bool:
        """Check if all models at DC have batch sizes at their limit.

        For undervoltage the limit is the smallest feasible batch size; for
        overvoltage it is the largest. Models with no feasible sizes or no
        current batch size are ignored.
        """
        state = dc.state
        dc_models = self._models_by_dc.get(dc, [])
        for model_label in dc_models:
            current_bs = state.batch_size_by_model.get(model_label)
            feasible = self._feasible_bs.get(model_label, [])
            if not feasible or current_bs is None:
                continue
            if is_undervoltage:
                if current_bs > min(feasible):
                    return False
            else:
                if current_bs < max(feasible):
                    return False
        return True

    def _pick_model(self, shared_models: set[str]) -> str | None:
        """Pick the model with the most GPUs per replica (largest power impact).

        Ties are broken by the lexicographically smallest label. Iterating
        the set directly would make the choice depend on string hash
        randomization and differ between runs.

        Returns:
            The chosen model label, or None when *shared_models* is empty.
        """
        if not shared_models:
            return None
        return min(shared_models, key=lambda m: (-self._gpus_per_replica.get(m, 1), m))

    def reset(self) -> None:
        """Clear the step counter, the only simulation state kept here."""
        self._step_count = 0

    def start(self) -> None:
        """No per-run resources to acquire."""

    def stop(self) -> None:
        """No per-run resources to release."""

openg2g.controller.noop

No-op controller that does nothing.

NoopController

Bases: Controller[DatacenterBackend, GridBackend]

Controller that always returns an empty action.

Source code in openg2g/controller/noop.py
class NoopController(Controller[DatacenterBackend, GridBackend]):
    """Controller that never issues a command.

    Args:
        dt_s: Control interval reported by the controller (seconds).
    """

    def __init__(self, dt_s: Fraction = Fraction(1)) -> None:
        self._dt_s = dt_s

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        interval = self._dt_s
        return interval

    def reset(self) -> None:
        """Nothing to clear: the controller keeps no simulation state."""
        return None

    def step(
        self,
        clock: SimulationClock,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Return a fresh empty command list on every call."""
        no_commands: list[DatacenterCommand | GridCommand] = []
        return no_commands

openg2g.controller.ofo

Online Feedback Optimization (OFO) batch-size controller.

Implements the primal-dual algorithm for joint voltage regulation and latency management via GPU batch size control.

OFOConfig

Bases: BaseModel

Online Feedback Optimization tuning parameters.

Attributes:

Name Type Description
primal_step_size float

Primal descent step size ρ_x (Eq. 8).

w_throughput float

Throughput weight in primal gradient.

w_switch float

Switching cost regularizer weight γ (Eq. 4a).

voltage_gradient_scale float

Scaling factor k_v for voltage dual term in the primal gradient.

v_min float

Lower voltage bound (pu).

v_max float

Upper voltage bound (pu).

voltage_dual_step_size float

Voltage dual ascent step size ρ_v (Eqs. 5-6).

latency_dual_step_size float

Latency dual ascent step size ρ_l (Eq. 7).

sensitivity_update_interval int

Steps between H-matrix re-estimation (0 = only once at init).

sensitivity_perturbation_kw float

Perturbation magnitude (kW) for finite-difference sensitivity estimation.

Source code in openg2g/controller/ofo.py
class OFOConfig(BaseModel):
    """Online Feedback Optimization tuning parameters.

    The model is declared frozen via `model_config`. Equation numbers
    below presumably refer to the paper describing the algorithm --
    TODO confirm which one.

    Attributes:
        primal_step_size: Primal descent step size ρ_x (Eq. 8).
        w_throughput: Throughput weight in primal gradient.
        w_switch: Switching cost regularizer weight γ (Eq. 4a).
        voltage_gradient_scale: Scaling factor k_v for voltage dual term
            in the primal gradient.
        v_min: Lower voltage bound (pu).
        v_max: Upper voltage bound (pu).
        voltage_dual_step_size: Voltage dual ascent step size ρ_v (Eqs. 5-6).
        latency_dual_step_size: Latency dual ascent step size ρ_l (Eq. 7).
        sensitivity_update_interval: Steps between H-matrix re-estimation
            (0 = only once at init).
        sensitivity_perturbation_kw: Perturbation magnitude (kW) for
            finite-difference sensitivity estimation.
    """

    # Frozen: the configuration is not meant to change after construction.
    model_config = ConfigDict(frozen=True)

    # Primal update: step size, objective weights, and voltage-term scaling.
    primal_step_size: float = 0.05
    w_throughput: float = 0.1
    w_switch: float = 0.0
    voltage_gradient_scale: float = 1e6

    # Dual updates: voltage bounds (pu) and dual ascent step sizes.
    v_min: float = 0.95
    v_max: float = 1.05
    voltage_dual_step_size: float = 0.5
    latency_dual_step_size: float = 1.0

    # Sensitivity (H-matrix) estimation; interval 0 means estimate only once at init.
    sensitivity_update_interval: int = 0
    sensitivity_perturbation_kw: float = 100.0

LogisticModelStore

Per-model logistic models for power, latency, and throughput.

Used by OFOBatchSizeController to compute gradients of the Lagrangian with respect to batch size.

Attributes:

Name Type Description
COL_MODEL_LABEL

Column name for model label in the CSV.

COL_METRIC

Column name for metric type in the CSV.

Source code in openg2g/controller/ofo.py
class LogisticModelStore:
    """Per-model logistic models for power, latency, and throughput.

    Used by
    [`OFOBatchSizeController`][openg2g.controller.ofo.OFOBatchSizeController]
    to compute gradients of the Lagrangian with respect to batch size.

    Attributes:
        COL_MODEL_LABEL: Column name for model label in the CSV.
        COL_METRIC: Column name for metric type in the CSV.
    """

    COL_MODEL_LABEL = "model_label"
    COL_METRIC = "metric"

    def __init__(
        self,
        power: dict[str, LogisticModel],
        latency: dict[str, LogisticModel],
        throughput: dict[str, LogisticModel],
    ) -> None:
        """Store shallow copies of the three per-model fit mappings.

        Args:
            power: Power fits keyed by model label.
            latency: Latency fits keyed by model label.
            throughput: Throughput fits keyed by model label.
        """
        self._power = dict(power)
        self._latency = dict(latency)
        self._throughput = dict(throughput)
        # Raw observations per model and batch size; only set by generate()
        # and only used for diagnostic plots in save().
        self._by_batch: dict[str, dict[int, list[tuple[float, float, float]]]] | None = None

    def power(self, model: str) -> LogisticModel:
        """Return the power logistic model for a model label."""
        return self._power[model]

    def latency(self, model: str) -> LogisticModel:
        """Return the latency logistic model for a model label."""
        return self._latency[model]

    def throughput(self, model: str) -> LogisticModel:
        """Return the throughput logistic model for a model label."""
        return self._throughput[model]

    @property
    def power_fits(self) -> dict[str, LogisticModel]:
        """A copy of the power fits keyed by model label."""
        return dict(self._power)

    @property
    def latency_fits(self) -> dict[str, LogisticModel]:
        """A copy of the latency fits keyed by model label."""
        return dict(self._latency)

    @property
    def throughput_fits(self) -> dict[str, LogisticModel]:
        """A copy of the throughput fits keyed by model label."""
        return dict(self._throughput)

    @classmethod
    def generate(
        cls,
        models: tuple[InferenceModelSpec, ...],
        *,
        runs: Any = None,
        mlenergy_data_dir: Path | None = None,
    ) -> LogisticModelStore:
        """Generate logistic fits from ML.ENERGY benchmark data.

        Each metric is fitted against log2(batch size), using the median
        observation per batch size.

        Args:
            models: Model specifications; every field (`task`,
                `gpu_model`, `gpus_per_replica`, `batch_sizes`,
                `fit_exclude_batch_sizes`) is used for the benchmark query
                and fit selection.
            runs: Pre-loaded `LLMRuns` object. If `None`, loads from
                `mlenergy_data_dir` or the HuggingFace Hub.
            mlenergy_data_dir: Path to compiled mlenergy-data directory.
                Ignored if `runs` is provided.

        Returns:
            A new `LogisticModelStore` with fitted logistic models.

        Raises:
            ValueError: If no runs are available, a spec lacks `model_id`,
                a spec matches zero runs, or no fit could be produced.
        """
        if runs is None:
            unique_tasks = {ms.task for ms in models}
            if mlenergy_data_dir:
                runs = LLMRuns.from_directory(str(mlenergy_data_dir), stable_only=False).task(*unique_tasks)
            else:
                runs = LLMRuns.from_hf(stable_only=False).task(*unique_tasks)
        if not runs:
            raise ValueError("No runs found for the specified tasks")

        subsets_by_label: dict[str, Any] = {}
        exclude_by_label: dict[str, set[int]] = {}
        for ms in models:
            if not ms.model_id:
                raise ValueError(f"model_id is required for data generation (model={ms.model_label!r})")

            subset = (
                runs.model_id(ms.model_id)
                .gpu_model(ms.gpu_model)
                .num_gpus(ms.gpus_per_replica)
                .max_num_seqs(*ms.batch_sizes)
            )
            if not subset:
                raise ValueError(
                    f"Config matched zero runs for logistic fits: model_id={ms.model_id!r}, "
                    f"gpu_model={ms.gpu_model!r}, num_gpus={ms.gpus_per_replica}, "
                    f"batch_sizes={ms.batch_sizes}"
                )
            subsets_by_label[ms.model_label] = subset
            exclude_by_label[ms.model_label] = set(ms.fit_exclude_batch_sizes)

        all_by_batch: dict[str, dict[int, list[tuple[float, float, float]]]] = {}
        power: dict[str, LogisticModel] = {}
        latency: dict[str, LogisticModel] = {}
        throughput: dict[str, LogisticModel] = {}
        for model_label, group in subsets_by_label.items():
            exclude = exclude_by_label[model_label]
            # Each observation is (power in watts, ITL in seconds, throughput in tokens/s).
            by_batch: dict[int, list[tuple[float, float, float]]] = {}
            for r in group:
                if r.max_num_seqs in exclude:
                    continue
                by_batch.setdefault(r.max_num_seqs, []).append(
                    (r.avg_power_watts, r.mean_itl_ms / 1000.0, r.output_throughput_tokens_per_sec)
                )
            all_by_batch[model_label] = by_batch

            batches = sorted(by_batch)
            if not batches:
                # Every matched run was excluded. Keep going, but say so:
                # the model ends up with no fit at all.
                logger.warning(
                    "No benchmark rows left for %s after excluding batch sizes %s; no logistic fit produced.",
                    model_label,
                    sorted(exclude),
                )
                continue

            x = np.log2(np.array(batches, dtype=float).clip(min=1))
            for metric_name, idx, target in [
                ("power", 0, power),
                ("latency", 1, latency),
                ("throughput", 2, throughput),
            ]:
                y = np.array([float(np.median([t[idx] for t in by_batch[b]])) for b in batches])
                fit = LogisticModel.fit(x, y)
                target[model_label] = fit
                _warn_if_fit_suspicious(model_label, metric_name, fit, batches, y)

        if not power and not latency and not throughput:
            raise ValueError("No logistic fit rows produced")
        store = cls(power=power, latency=latency, throughput=throughput)
        store._by_batch = all_by_batch
        return store

    def save(
        self,
        base_dir: Path,
        specs: tuple[InferenceModelSpec, ...],
        *,
        plot: bool = False,
    ) -> None:
        """Save per-spec logistic fits to `base_dir/<spec.cache_hash()>/logistic_fit.json`.

        Specs without any fit in this store are skipped and get no
        directory. Files are written as UTF-8.

        Args:
            base_dir: Root of the per-spec cache (typically `data/specs/`).
            specs: Specs whose fits to save.
            plot: If `True`, write a diagnostic fits plot per spec's
                directory when `_by_batch` observations are available.
        """
        import json as _json

        base_dir = Path(base_dir)
        by_batch = getattr(self, "_by_batch", None)
        for spec in specs:
            label = spec.model_label
            if label not in self._power and label not in self._latency and label not in self._throughput:
                continue
            spec_dir = base_dir / spec.cache_hash()
            spec_dir.mkdir(parents=True, exist_ok=True)
            payload: dict[str, Any] = {
                "schema": "logistic_v1",
                "model_label": label,
            }
            for metric_name, store in (
                ("power", self._power),
                ("latency", self._latency),
                ("throughput", self._throughput),
            ):
                model = store.get(label)
                if model is not None:
                    payload[metric_name] = {"L": model.L, "x0": model.x0, "k": model.k, "b0": model.b0}
            # Explicit encoding keeps the cache independent of the platform default.
            (spec_dir / "logistic_fit.json").write_text(
                _json.dumps(payload, indent=2, sort_keys=True),
                encoding="utf-8",
            )

            if plot and by_batch is not None and label in by_batch:
                _plot_logistic_fits(
                    {label: by_batch[label]},
                    self._power,
                    self._latency,
                    self._throughput,
                    [label],
                    spec_dir,
                )

    @classmethod
    def load(
        cls,
        base_dir: Path | str,
        specs: tuple[InferenceModelSpec, ...],
    ) -> LogisticModelStore:
        """Load per-spec logistic fits from `base_dir/<spec.cache_hash()>/logistic_fit.json`.

        Args:
            base_dir: Root of the per-spec cache.
            specs: Specs whose fits to load.

        Returns:
            A store holding every metric found in the cached files.

        Raises:
            FileNotFoundError: If a spec's cache file does not exist.
            ValueError: If no metric entry was loaded for any spec.
        """
        import json as _json

        base_dir = Path(base_dir)
        power: dict[str, LogisticModel] = {}
        latency: dict[str, LogisticModel] = {}
        throughput: dict[str, LogisticModel] = {}
        for spec in specs:
            spec_dir = base_dir / spec.cache_hash()
            path = spec_dir / "logistic_fit.json"
            if not path.exists():
                raise FileNotFoundError(f"Logistic fit not found at {path} (spec={spec.model_label!r})")
            payload = _json.loads(path.read_text(encoding="utf-8"))
            label = spec.model_label
            for metric_name, target in (
                ("power", power),
                ("latency", latency),
                ("throughput", throughput),
            ):
                params = payload.get(metric_name)
                if params is not None:
                    target[label] = LogisticModel.from_dict(params)
        if not power and not latency and not throughput:
            raise ValueError(f"No logistic model entries loaded from {base_dir}")
        return cls(power=power, latency=latency, throughput=throughput)

    @classmethod
    def ensure(
        cls,
        base_dir: Path,
        models: tuple[InferenceModelSpec, ...],
        *,
        mlenergy_data_dir: Path | None = None,
        plot: bool = False,
    ) -> LogisticModelStore:
        """Load per-spec logistic fits under `base_dir`, generating missing ones.

        Any spec whose `base_dir/<hash>/logistic_fit.json` is absent
        triggers a targeted regeneration for just that spec; the cached
        ones are read as-is.

        Args:
            base_dir: Root of the per-spec cache.
            models: Model specifications.
            mlenergy_data_dir: Path to compiled mlenergy-data directory.
            plot: If `True`, generate a logistic fits plot per newly
                generated spec directory.

        Returns:
            A store loaded from the cache after any regeneration.
        """
        base_dir = Path(base_dir)
        base_dir.mkdir(parents=True, exist_ok=True)

        missing = tuple(ms for ms in models if not (base_dir / ms.cache_hash() / "logistic_fit.json").exists())
        if missing:
            missing_labels = [ms.model_label for ms in missing]
            logger.info(
                "Generating logistic fits for %d/%d specs (missing=%s) under %s ...",
                len(missing),
                len(models),
                missing_labels,
                base_dir,
            )
            cls.generate(missing, mlenergy_data_dir=mlenergy_data_dir).save(base_dir, missing, plot=plot)
        return cls.load(base_dir, models)

power(model)

Return the power logistic model for a model label.

Source code in openg2g/controller/ofo.py
def power(self, model: str) -> LogisticModel:
    """Look up the power logistic model stored under the label `model`.

    The label is used to index `self._power` directly, so a lookup failure
    from that mapping propagates to the caller.
    """
    power_fits = self._power
    return power_fits[model]

latency(model)

Return the latency logistic model for a model label.

Source code in openg2g/controller/ofo.py
def latency(self, model: str) -> LogisticModel:
    """Look up the latency logistic model stored under the label `model`.

    The label is used to index `self._latency` directly, so a lookup failure
    from that mapping propagates to the caller.
    """
    latency_fits = self._latency
    return latency_fits[model]

throughput(model)

Return the throughput logistic model for a model label.

Source code in openg2g/controller/ofo.py
def throughput(self, model: str) -> LogisticModel:
    """Look up the throughput logistic model stored under the label `model`.

    The label is used to index `self._throughput` directly, so a lookup
    failure from that mapping propagates to the caller.
    """
    throughput_fits = self._throughput
    return throughput_fits[model]

generate(models, *, runs=None, mlenergy_data_dir=None) classmethod

Generate logistic fits from ML.ENERGY benchmark data.

Parameters:

Name Type Description Default
models tuple[InferenceModelSpec, ...]

Model specifications; every field (task, gpu_model, gpus_per_replica, batch_sizes, fit_exclude_batch_sizes) is used for the benchmark query and fit selection.

required
runs Any

Pre-loaded LLMRuns object. If None, loads from mlenergy_data_dir or the HuggingFace Hub.

None
mlenergy_data_dir Path | None

Path to compiled mlenergy-data directory. Ignored if runs is provided.

None

Returns:

Type Description
LogisticModelStore

A new LogisticModelStore with fitted logistic models.

Source code in openg2g/controller/ofo.py
@classmethod
def generate(
    cls,
    models: tuple[InferenceModelSpec, ...],
    *,
    runs: Any = None,
    mlenergy_data_dir: Path | None = None,
) -> LogisticModelStore:
    """Fit power, latency and throughput logistic models from benchmark runs.

    The work happens in two passes. First, every spec is resolved to its
    matching subset of runs, failing fast if a spec has no `model_id` or
    matches nothing. Second, each subset is grouped by `max_num_seqs`,
    the per-batch median of each metric is taken, and a `LogisticModel` is
    fitted against `log2(batch_size)`.

    Args:
        models: Model specifications; `task`, `model_id`, `gpu_model`,
            `gpus_per_replica`, `batch_sizes` and `fit_exclude_batch_sizes`
            drive the benchmark query and fit selection.
        runs: Pre-loaded `LLMRuns` object. If `None`, runs are loaded from
            `mlenergy_data_dir` when given, otherwise via `LLMRuns.from_hf`,
            and narrowed to the tasks named by `models`.
        mlenergy_data_dir: Path to compiled mlenergy-data directory.
            Ignored if `runs` is provided.

    Returns:
        A new `LogisticModelStore` with fitted logistic models. The raw
        per-batch observations are attached as `_by_batch`.

    Raises:
        ValueError: If no runs are available, a spec lacks `model_id`, a
            spec matches zero runs, or no fit was produced at all.
    """
    if runs is None:
        tasks = {spec.task for spec in models}
        if mlenergy_data_dir:
            loaded = LLMRuns.from_directory(str(mlenergy_data_dir), stable_only=False)
        else:
            loaded = LLMRuns.from_hf(stable_only=False)
        runs = loaded.task(*tasks)
    if not runs:
        raise ValueError("No runs found for the specified tasks")

    # Pass 1: resolve every spec to (matching runs, excluded batch sizes),
    # keyed by label, before any fitting is attempted.
    selection_by_label: dict[str, tuple[Any, set[int]]] = {}
    for spec in models:
        if not spec.model_id:
            raise ValueError(f"model_id is required for data generation (model={spec.model_label!r})")

        matched = runs.model_id(spec.model_id)
        matched = matched.gpu_model(spec.gpu_model)
        matched = matched.num_gpus(spec.gpus_per_replica)
        matched = matched.max_num_seqs(*spec.batch_sizes)
        if not matched:
            raise ValueError(
                f"Config matched zero runs for logistic fits: model_id={spec.model_id!r}, "
                f"gpu_model={spec.gpu_model!r}, num_gpus={spec.gpus_per_replica}, "
                f"batch_sizes={spec.batch_sizes}"
            )
        selection_by_label[spec.model_label] = (matched, set(spec.fit_exclude_batch_sizes))

    # Pass 2: group observations per batch size and fit one curve per metric.
    observations: dict[str, dict[int, list[tuple[float, float, float]]]] = {}
    power: dict[str, LogisticModel] = {}
    latency: dict[str, LogisticModel] = {}
    throughput: dict[str, LogisticModel] = {}
    # (metric name, column in the observation tuple, destination mapping)
    metric_columns = (
        ("power", 0, power),
        ("latency", 1, latency),
        ("throughput", 2, throughput),
    )
    for label, (matched, skipped_batches) in selection_by_label.items():
        samples: dict[int, list[tuple[float, float, float]]] = {}
        for run in matched:
            if run.max_num_seqs not in skipped_batches:
                bucket = samples.setdefault(run.max_num_seqs, [])
                # mean_itl_ms is divided by 1000 before being stored.
                bucket.append((run.avg_power_watts, run.mean_itl_ms / 1000.0, run.output_throughput_tokens_per_sec))
        observations[label] = samples

        if not samples:
            # Every run was excluded; no fit is produced for this label.
            continue

        batch_sizes = sorted(samples)
        log2_batches = np.log2(np.array(batch_sizes, dtype=float).clip(min=1))
        for metric, column, fits in metric_columns:
            medians = np.array(
                [float(np.median([obs[column] for obs in samples[size]])) for size in batch_sizes]
            )
            fitted = LogisticModel.fit(log2_batches, medians)
            fits[label] = fitted
            _warn_if_fit_suspicious(label, metric, fitted, batch_sizes, medians)

    if not (power or latency or throughput):
        raise ValueError("No logistic fit rows produced")
    result = cls(power=power, latency=latency, throughput=throughput)
    result._by_batch = observations
    return result

save(base_dir, specs, *, plot=False)

Save per-spec logistic fits to base_dir/<spec.cache_hash()>/logistic_fit.json.

Parameters:

Name Type Description Default
base_dir Path

Root of the per-spec cache (typically data/specs/).

required
specs tuple[InferenceModelSpec, ...]

Specs whose fits to save.

required
plot bool

If True, write a diagnostic fits plot per spec's directory when _by_batch observations are available.

False
Source code in openg2g/controller/ofo.py
def save(
    self,
    base_dir: Path,
    specs: tuple[InferenceModelSpec, ...],
    *,
    plot: bool = False,
) -> None:
    """Write one `logistic_fit.json` per spec under `base_dir/<spec.cache_hash()>/`.

    A spec whose label has no power, latency, or throughput fit in this
    store is skipped entirely (no directory is created for it). For the
    others, the JSON payload carries a schema tag, the model label, and the
    `L`/`x0`/`k`/`b0` parameters of each metric that has a fit.

    Args:
        base_dir: Root of the per-spec cache (typically `data/specs/`).
        specs: Specs whose fits to save.
        plot: If `True`, also call `_plot_logistic_fits` for each saved spec
            that has `_by_batch` observations on this store.
    """
    import json as _json

    root = Path(base_dir)
    observations = getattr(self, "_by_batch", None)
    fits_by_metric = (
        ("power", self._power),
        ("latency", self._latency),
        ("throughput", self._throughput),
    )

    for spec in specs:
        label = spec.model_label
        if not any(label in fits for _, fits in fits_by_metric):
            continue

        out_dir = root / spec.cache_hash()
        out_dir.mkdir(parents=True, exist_ok=True)

        payload: dict[str, Any] = {"schema": "logistic_v1", "model_label": label}
        for metric, fits in fits_by_metric:
            fit = fits.get(label)
            if fit is None:
                continue
            payload[metric] = {"L": fit.L, "x0": fit.x0, "k": fit.k, "b0": fit.b0}
        serialized = _json.dumps(payload, indent=2, sort_keys=True)
        (out_dir / "logistic_fit.json").write_text(serialized)

        if not plot or observations is None or label not in observations:
            continue
        _plot_logistic_fits(
            {label: observations[label]},
            self._power,
            self._latency,
            self._throughput,
            [label],
            out_dir,
        )

load(base_dir, specs) classmethod

Load per-spec logistic fits from base_dir/<spec.cache_hash()>/logistic_fit.json.

Parameters:

Name Type Description Default
base_dir Path | str

Root of the per-spec cache.

required
specs tuple[InferenceModelSpec, ...]

Specs whose fits to load.

required
Source code in openg2g/controller/ofo.py
@classmethod
def load(
    cls,
    base_dir: Path | str,
    specs: tuple[InferenceModelSpec, ...],
) -> LogisticModelStore:
    """Build a store from the `logistic_fit.json` file cached for each spec.

    Each file is expected at `base_dir/<spec.cache_hash()>/logistic_fit.json`.
    Any of the `power`, `latency` and `throughput` entries present in a
    file is turned into a `LogisticModel` via `LogisticModel.from_dict` and
    registered under the spec's `model_label`.

    Args:
        base_dir: Root of the per-spec cache.
        specs: Specs whose fits to load.

    Returns:
        A new store holding the loaded fits.

    Raises:
        FileNotFoundError: If a spec's fit file does not exist.
        ValueError: If no metric entry was loaded for any spec.
    """
    import json as _json

    root = Path(base_dir)
    # One label -> model mapping per metric, filled in file by file.
    loaded: dict[str, dict[str, LogisticModel]] = {"power": {}, "latency": {}, "throughput": {}}
    for spec in specs:
        fit_file = root / spec.cache_hash() / "logistic_fit.json"
        if not fit_file.exists():
            raise FileNotFoundError(f"Logistic fit not found at {fit_file} (spec={spec.model_label!r})")
        record = _json.loads(fit_file.read_text())
        label = spec.model_label
        for metric, per_label in loaded.items():
            raw_params = record.get(metric)
            if raw_params is None:
                continue
            per_label[label] = LogisticModel.from_dict(raw_params)

    if not any(loaded.values()):
        raise ValueError(f"No logistic model entries loaded from {root}")
    return cls(power=loaded["power"], latency=loaded["latency"], throughput=loaded["throughput"])

ensure(base_dir, models, *, mlenergy_data_dir=None, plot=False) classmethod

Load per-spec logistic fits under base_dir, generating missing ones.

Any spec whose base_dir/<hash>/logistic_fit.json is absent triggers a targeted regeneration for just that spec; the cached ones are read as-is.

Parameters:

Name Type Description Default
base_dir Path

Root of the per-spec cache.

required
models tuple[InferenceModelSpec, ...]

Model specifications.

required
mlenergy_data_dir Path | None

Path to compiled mlenergy-data directory.

None
plot bool

If True, generate a logistic fits plot per newly generated spec directory.

False
Source code in openg2g/controller/ofo.py
@classmethod
def ensure(
    cls,
    base_dir: Path,
    models: tuple[InferenceModelSpec, ...],
    *,
    mlenergy_data_dir: Path | None = None,
    plot: bool = False,
) -> LogisticModelStore:
    """Return a store for `models`, fitting only the specs with no cached fit.

    A spec counts as cached when `base_dir/<spec.cache_hash()>/logistic_fit.json`
    exists. Specs without that file are passed to `generate` and the
    result is written with `save`; afterwards every spec in `models` is
    read back through `load`.

    Args:
        base_dir: Root of the per-spec cache. Created if it does not exist.
        models: Model specifications to load.
        mlenergy_data_dir: Forwarded to `generate` for the uncached specs.
        plot: Forwarded to `save` for the newly generated specs.

    Returns:
        The store produced by `load(base_dir, models)`.
    """
    cache_root = Path(base_dir)
    cache_root.mkdir(parents=True, exist_ok=True)

    # Collect the specs whose fit file is absent from the cache.
    uncached = []
    for spec in models:
        fit_file = cache_root / spec.cache_hash() / "logistic_fit.json"
        if not fit_file.exists():
            uncached.append(spec)
    to_fit = tuple(uncached)

    if not to_fit:
        return cls.load(cache_root, models)

    labels_to_fit = [spec.model_label for spec in to_fit]
    logger.info(
        "Generating logistic fits for %d/%d specs (missing=%s) under %s ...",
        len(to_fit),
        len(models),
        labels_to_fit,
        cache_root,
    )
    fresh_store = cls.generate(to_fit, mlenergy_data_dir=mlenergy_data_dir)
    fresh_store.save(cache_root, to_fit, plot=plot)
    return cls.load(cache_root, models)

VoltageDualVariables

Full-network duals for voltage box constraints.

Maintains per-bus dual variables for under- and overvoltage and updates them via projected gradient ascent:

dual_undervoltage  <- [dual_undervoltage  + ρ_v * (v_min - v̂)]+
dual_overvoltage   <- [dual_overvoltage   + ρ_v * (v̂ - v_max)]+

Parameters:

Name Type Description Default
n_bus_phases int

Number of bus-phase pairs in the voltage vector (3M).

required
config OFOConfig

OFO configuration (voltage bounds and dual step size).

required
Source code in openg2g/controller/ofo.py
class VoltageDualVariables:
    """Dual variables for the lower and upper voltage bounds of every bus-phase.

    Two non-negative vectors are kept, one per bound, and advanced by a
    projected gradient ascent step on each observation:

        dual_undervoltage  <- [dual_undervoltage  + ρ_v * (v_min - v̂)]+
        dual_overvoltage   <- [dual_overvoltage   + ρ_v * (v̂ - v_max)]+

    where `[.]+` clips negative entries to zero.

    Args:
        n_bus_phases: Number of bus-phase pairs in the voltage vector (3M).
        config: OFO configuration; `v_min`, `v_max` and
            `voltage_dual_step_size` are read on every update.
    """

    def __init__(self, n_bus_phases: int, config: OFOConfig) -> None:
        size = int(n_bus_phases)
        self.config = config
        # Both duals start at zero.
        self.dual_undervoltage = np.zeros(size, dtype=float)  # λ in G2G paper Eq. 5
        self.dual_overvoltage = np.zeros(size, dtype=float)  # λ̄ in G2G paper Eq. 6

    def update(self, observed_voltages: np.ndarray) -> None:
        """Advance both dual vectors by one projected ascent step.

        Args:
            observed_voltages: Observed voltage magnitudes (pu); flattened
                to shape `(n_bus_phases,)` before use.

        Raises:
            ValueError: If the flattened observation length differs from
                the length of the dual vectors.
        """
        v_hat = np.asarray(observed_voltages, float).reshape(-1)
        n_observed = v_hat.shape[0]
        n_duals = self.dual_undervoltage.shape[0]
        if n_observed != n_duals:
            raise ValueError(f"observed_voltages has len {n_observed} but duals have len {n_duals}")

        cfg = self.config
        lower_bound = float(cfg.v_min)
        upper_bound = float(cfg.v_max)
        ascent_step = float(cfg.voltage_dual_step_size)

        raised_under = self.dual_undervoltage + ascent_step * (lower_bound - v_hat)
        raised_over = self.dual_overvoltage + ascent_step * (v_hat - upper_bound)
        # Projection onto the non-negative orthant.
        self.dual_undervoltage = np.maximum(raised_under, 0.0)
        self.dual_overvoltage = np.maximum(raised_over, 0.0)

    def dual_difference(self) -> np.ndarray:
        """Return overvoltage minus undervoltage duals (η = λ̄ − λ, Appendix B)."""
        over = self.dual_overvoltage
        under = self.dual_undervoltage
        return over - under

update(observed_voltages)

Update duals given observed voltage vector.

Parameters:

Name Type Description Default
observed_voltages ndarray

Observed voltage magnitudes (pu), shape (n_bus_phases,).

required

Raises:

Type Description
ValueError

If observed_voltages length does not match the dual dimension.

Source code in openg2g/controller/ofo.py
def update(self, observed_voltages: np.ndarray) -> None:
    """Advance both dual vectors by one projected ascent step.

    Args:
        observed_voltages: Observed voltage magnitudes (pu); flattened to
            shape `(n_bus_phases,)` before use.

    Raises:
        ValueError: If the flattened observation length differs from the
            length of the dual vectors.
    """
    v_hat = np.asarray(observed_voltages, float).reshape(-1)
    n_observed = v_hat.shape[0]
    n_duals = self.dual_undervoltage.shape[0]
    if n_observed != n_duals:
        raise ValueError(f"observed_voltages has len {n_observed} but duals have len {n_duals}")

    cfg = self.config
    lower_bound = float(cfg.v_min)
    upper_bound = float(cfg.v_max)
    ascent_step = float(cfg.voltage_dual_step_size)

    raised_under = self.dual_undervoltage + ascent_step * (lower_bound - v_hat)
    raised_over = self.dual_overvoltage + ascent_step * (v_hat - upper_bound)
    # Clip negatives to zero so both duals stay non-negative.
    self.dual_undervoltage = np.maximum(raised_under, 0.0)
    self.dual_overvoltage = np.maximum(raised_over, 0.0)

dual_difference()

Return the voltage dual difference (η = λ̄ − λ, Appendix B).

Source code in openg2g/controller/ofo.py
def dual_difference(self) -> np.ndarray:
    """Return overvoltage minus undervoltage duals (η = λ̄ − λ, Appendix B)."""
    over = self.dual_overvoltage
    under = self.dual_undervoltage
    return over - under

PrimalBatchOptimizer

Primal batch-size optimizer operating in log2 space.

Maintains continuous state x_i = log2(batch_i) per model and applies a gradient descent step using voltage duals, latency duals, and fitted power/latency/throughput curves.

Parameters:

Name Type Description Default
models list[InferenceModelSpec]

Model specifications for each served model.

required
feasible_batch_sizes list[int]

Allowed batch sizes (union across all models).

required
power_fits dict[str, LogisticModel]

Per-model logistic fit for power vs log2(batch_size).

required
latency_fits dict[str, LogisticModel]

Per-model logistic fit for latency vs log2(batch_size).

required
throughput_fits dict[str, LogisticModel]

Per-model logistic fit for throughput vs log2(batch_size).

required
config OFOConfig

OFO configuration (step size, throughput/switch weights, voltage gradient scale).

required
Source code in openg2g/controller/ofo.py
class PrimalBatchOptimizer:
    """Gradient-based primal update for per-model batch sizes, in log2 space.

    Each model carries a continuous state `x_i = log2(batch_i)`. Every call
    to `step` moves that state against a gradient assembled from the
    voltage duals, the latency duals, and the derivatives of the fitted
    power/latency/throughput curves, then snaps it to a feasible batch size.

    Args:
        models: Model specifications for each served model.
        feasible_batch_sizes: Allowed batch sizes (union across all models).
        power_fits: Per-model logistic fit for power vs log2(batch_size).
        latency_fits: Per-model logistic fit for latency vs log2(batch_size).
        throughput_fits: Per-model logistic fit for throughput vs
            log2(batch_size).
        config: OFO configuration (step size, throughput/switch weights,
            voltage gradient scale).

    Raises:
        ValueError: If `feasible_batch_sizes` is empty.
    """

    def __init__(
        self,
        *,
        models: list[InferenceModelSpec],
        feasible_batch_sizes: list[int],
        power_fits: dict[str, LogisticModel],
        latency_fits: dict[str, LogisticModel],
        throughput_fits: dict[str, LogisticModel],
        config: OFOConfig,
    ) -> None:
        self.models = list(models)
        # De-duplicated and sorted once; `_discretize_batch` bisects this list.
        self.feasible_batch_sizes = sorted({int(size) for size in feasible_batch_sizes})
        if not self.feasible_batch_sizes:
            raise ValueError("feasible_batch_sizes cannot be empty.")

        self.power_fits = power_fits
        self.latency_fits = latency_fits
        self.throughput_fits = throughput_fits
        self.config = config

        # The list is sorted, so its ends are the smallest and largest sizes.
        smallest = self.feasible_batch_sizes[0]
        largest = self.feasible_batch_sizes[-1]
        self.log_batch_size_min = math.log2(smallest)
        self.log_batch_size_max = math.log2(largest)

        # Every model starts at the largest feasible batch size.
        start = float(self.log_batch_size_max)
        self.log_batch_size_by_model: dict[str, float] = {ms.model_label: start for ms in self.models}
        self.prev_log_batch_size_by_model: dict[str, float] = dict(self.log_batch_size_by_model)

        # Per-model throughput normalization: r_i(x_max) for a single replica
        self.throughput_max_by_model: dict[str, float] = {}
        for ms in self.models:
            label = ms.model_label
            try:
                peak = float(self.throughput_fits[label].eval(largest))
            except Exception:
                peak = float("nan")
            # Fall back to 1.0 when the fit is missing or yields a
            # non-finite / non-positive value, so division stays safe.
            usable = bool(np.isfinite(peak)) and peak > 0.0
            self.throughput_max_by_model[label] = peak if usable else 1.0

    def _clamp_log_batch_size(self, log_batch_size: float) -> float:
        # Restrict the state to [log_batch_size_min, log_batch_size_max].
        value = float(log_batch_size)
        if self.log_batch_size_min > value:
            value = self.log_batch_size_min
        if self.log_batch_size_max < value:
            value = self.log_batch_size_max
        return float(value)

    def _discretize_batch(self, log_batch_size: float) -> int:
        # Pick the feasible size closest (in linear batch units) to 2**x;
        # on a tie the smaller neighbour wins because it is listed first.
        target = 2.0 ** float(log_batch_size)
        sizes = self.feasible_batch_sizes
        pos = bisect.bisect_left(sizes, target)
        neighbours = sizes[max(pos - 1, 0) : pos + 1]
        return int(min(neighbours, key=lambda size: abs(size - target)))

    def init_from_batches(self, batch_init: dict[str, int]) -> None:
        """Seed current and previous log2 states from discrete batch sizes.

        Models absent from `batch_init` start at the largest feasible batch
        size. Values below 1 are treated as 1, and the result is clamped to
        the feasible log2 range.
        """
        for ms in self.models:
            label = ms.model_label
            requested = int(batch_init.get(label, max(self.feasible_batch_sizes)))
            seeded = float(self._clamp_log_batch_size(math.log2(max(requested, 1))))
            self.log_batch_size_by_model[label] = seeded
            self.prev_log_batch_size_by_model[label] = seeded

    def step(
        self,
        *,
        voltage_dual_diff: np.ndarray,
        sensitivity_matrix: np.ndarray,
        phase_share_by_model: dict[str, np.ndarray],
        latency_dual_by_model: dict[str, float] | None = None,
        replica_count_by_model: dict[str, float] | None = None,
    ) -> dict[str, int]:
        """Run one primal gradient-descent update for every model.

        Args:
            voltage_dual_diff: Voltage dual difference vector
                (η = λ̄ − λ), shape `(n_bus_phases,)`.
            sensitivity_matrix: Voltage sensitivity matrix (H = dv/dp),
                shape `(n_bus_phases, 3)`.
            phase_share_by_model: Per-model phase share vectors, shape
                `(3,)` each. They are renormalized to sum to one; a missing
                entry or a non-finite / non-positive sum falls back to an
                even split.
            latency_dual_by_model: Per-model latency dual variables (μ_i).
                Missing, non-finite or negative values count as zero.
            replica_count_by_model: Per-model active replica counts (w_i).
                Missing, non-finite or negative values count as zero.

        Returns:
            Next batch sizes per model.
        """
        eta = np.asarray(voltage_dual_diff, float).reshape(-1)
        sens = np.asarray(sensitivity_matrix, float)
        latency_duals = dict(latency_dual_by_model) if latency_dual_by_model is not None else {}
        replica_counts = dict(replica_count_by_model) if replica_count_by_model is not None else {}

        cfg = self.config
        rho_x = float(cfg.primal_step_size)  # ρ_x
        throughput_weight = float(cfg.w_throughput)
        switch_weight = float(cfg.w_switch)
        voltage_scale = float(cfg.voltage_gradient_scale)

        decisions: dict[str, int] = {}

        for ms in self.models:
            label = ms.model_label
            x_now = float(self.log_batch_size_by_model[label])
            x_prev = float(self.prev_log_batch_size_by_model.get(label, x_now))

            replicas = float(replica_counts.get(label, 0.0))  # w_i
            if not (np.isfinite(replicas) and replicas >= 0.0):
                replicas = 0.0

            # e_i (phase-allocation weight, p.7), normalized to sum to one.
            uniform = np.array([1 / 3, 1 / 3, 1 / 3], dtype=float)
            share = np.asarray(phase_share_by_model.get(label, uniform), float).reshape(3)
            total = float(np.sum(share))
            if np.isfinite(total) and total > 0.0:
                share = share / total
            else:
                share = uniform

            # η^T (H @ e_i)
            voltage_gradient = float(eta @ (sens @ share))

            power_slope = float(self.power_fits[label].deriv_wrt_x(x_now))
            latency_slope = float(self.latency_fits[label].deriv_wrt_x(x_now))
            throughput_slope = float(self.throughput_fits[label].deriv_wrt_x(x_now))

            # The single-replica power slope is divided by 1000 (W -> kW).
            power_slope_kw = power_slope / 1000.0

            peak = float(self.throughput_max_by_model.get(label, 1.0))
            if not (np.isfinite(peak) and peak > 0.0):
                peak = 1.0
            throughput_slope_norm = throughput_slope / peak

            # Power and throughput scale with the replica count; latency does not.
            dPdx = replicas * power_slope_kw
            dThdx = replicas * throughput_slope_norm
            dLdx = latency_slope

            mu = float(latency_duals.get(label, 0.0))  # μ_i
            if not (np.isfinite(mu) and mu >= 0.0):
                mu = 0.0

            # Gradient of the Lagrangian w.r.t. x_i = log2(batch_i).
            # G2G paper Eq. 18: throughput, voltage-dual, latency-dual and
            # switching terms. Implementation extensions: wT scaling on the
            # throughput term and k_v scaling on the voltage term.
            throughput_term = throughput_weight * dThdx
            voltage_term = voltage_scale * voltage_gradient * dPdx
            latency_term = mu * dLdx
            switch_term = switch_weight * (x_now - x_prev)
            grad = (((0.0 - throughput_term) + voltage_term) + latency_term) + switch_term

            x_next = self._clamp_log_batch_size(x_now - rho_x * grad)
            self.prev_log_batch_size_by_model[label] = x_now
            self.log_batch_size_by_model[label] = x_next
            decisions[label] = self._discretize_batch(x_next)

        return decisions

init_from_batches(batch_init)

Initialize log-batch-size state from discrete batch sizes.

Source code in openg2g/controller/ofo.py
def init_from_batches(self, batch_init: dict[str, int]) -> None:
    """Seed current and previous log2 states from discrete batch sizes.

    Models absent from `batch_init` start at the largest feasible batch
    size. Values below 1 are treated as 1, and the result is clamped to the
    feasible log2 range.
    """
    for ms in self.models:
        label = ms.model_label
        requested = int(batch_init.get(label, max(self.feasible_batch_sizes)))
        seeded = float(self._clamp_log_batch_size(math.log2(max(requested, 1))))
        self.log_batch_size_by_model[label] = seeded
        self.prev_log_batch_size_by_model[label] = seeded

step(*, voltage_dual_diff, sensitivity_matrix, phase_share_by_model, latency_dual_by_model=None, replica_count_by_model=None)

Primal gradient descent step.

Parameters:

Name Type Description Default
voltage_dual_diff ndarray

Voltage dual difference vector (η = λ̄ − λ), shape (n_bus_phases,).

required
sensitivity_matrix ndarray

Voltage sensitivity matrix (H = dv/dp), shape (n_bus_phases, 3).

required
phase_share_by_model dict[str, ndarray]

Per-model normalized phase share vectors, shape (3,) each.

required
latency_dual_by_model dict[str, float] | None

Per-model latency dual variables (μ_i).

None
replica_count_by_model dict[str, float] | None

Per-model active replica counts (w_i).

None

Returns:

Type Description
dict[str, int]

Next batch sizes per model.

Source code in openg2g/controller/ofo.py
def step(
    self,
    *,
    voltage_dual_diff: np.ndarray,
    sensitivity_matrix: np.ndarray,
    phase_share_by_model: dict[str, np.ndarray],
    latency_dual_by_model: dict[str, float] | None = None,
    replica_count_by_model: dict[str, float] | None = None,
) -> dict[str, int]:
    """Run one primal gradient-descent update for every model.

    Args:
        voltage_dual_diff: Voltage dual difference vector
            (η = λ̄ − λ), shape `(n_bus_phases,)`.
        sensitivity_matrix: Voltage sensitivity matrix (H = dv/dp),
            shape `(n_bus_phases, 3)`.
        phase_share_by_model: Per-model phase share vectors, shape `(3,)`
            each. They are renormalized to sum to one; a missing entry or
            a non-finite / non-positive sum falls back to an even split.
        latency_dual_by_model: Per-model latency dual variables (μ_i).
            Missing, non-finite or negative values count as zero.
        replica_count_by_model: Per-model active replica counts (w_i).
            Missing, non-finite or negative values count as zero.

    Returns:
        Next batch sizes per model.
    """
    eta = np.asarray(voltage_dual_diff, float).reshape(-1)
    sens = np.asarray(sensitivity_matrix, float)
    latency_duals = dict(latency_dual_by_model) if latency_dual_by_model is not None else {}
    replica_counts = dict(replica_count_by_model) if replica_count_by_model is not None else {}

    cfg = self.config
    rho_x = float(cfg.primal_step_size)  # ρ_x
    throughput_weight = float(cfg.w_throughput)
    switch_weight = float(cfg.w_switch)
    voltage_scale = float(cfg.voltage_gradient_scale)

    decisions: dict[str, int] = {}

    for ms in self.models:
        label = ms.model_label
        x_now = float(self.log_batch_size_by_model[label])
        x_prev = float(self.prev_log_batch_size_by_model.get(label, x_now))

        replicas = float(replica_counts.get(label, 0.0))  # w_i
        if not (np.isfinite(replicas) and replicas >= 0.0):
            replicas = 0.0

        # e_i (phase-allocation weight, p.7), normalized to sum to one.
        uniform = np.array([1 / 3, 1 / 3, 1 / 3], dtype=float)
        share = np.asarray(phase_share_by_model.get(label, uniform), float).reshape(3)
        total = float(np.sum(share))
        if np.isfinite(total) and total > 0.0:
            share = share / total
        else:
            share = uniform

        # η^T (H @ e_i)
        voltage_gradient = float(eta @ (sens @ share))

        power_slope = float(self.power_fits[label].deriv_wrt_x(x_now))
        latency_slope = float(self.latency_fits[label].deriv_wrt_x(x_now))
        throughput_slope = float(self.throughput_fits[label].deriv_wrt_x(x_now))

        # The single-replica power slope is divided by 1000 (W -> kW).
        power_slope_kw = power_slope / 1000.0

        peak = float(self.throughput_max_by_model.get(label, 1.0))
        if not (np.isfinite(peak) and peak > 0.0):
            peak = 1.0
        throughput_slope_norm = throughput_slope / peak

        # Power and throughput scale with the replica count; latency does not.
        dPdx = replicas * power_slope_kw
        dThdx = replicas * throughput_slope_norm
        dLdx = latency_slope

        mu = float(latency_duals.get(label, 0.0))  # μ_i
        if not (np.isfinite(mu) and mu >= 0.0):
            mu = 0.0

        # Gradient of the Lagrangian w.r.t. x_i = log2(batch_i).
        # G2G paper Eq. 18: throughput, voltage-dual, latency-dual and
        # switching terms. Implementation extensions: wT scaling on the
        # throughput term and k_v scaling on the voltage term.
        throughput_term = throughput_weight * dThdx
        voltage_term = voltage_scale * voltage_gradient * dPdx
        latency_term = mu * dLdx
        switch_term = switch_weight * (x_now - x_prev)
        grad = (((0.0 - throughput_term) + voltage_term) + latency_term) + switch_term

        x_next = self._clamp_log_batch_size(x_now - rho_x * grad)
        self.prev_log_batch_size_by_model[label] = x_now
        self.log_batch_size_by_model[label] = x_next
        decisions[label] = self._discretize_batch(x_next)

    return decisions

OFOBatchSizeController

Bases: Controller[LLMBatchSizeControlledDatacenter[LLMDatacenterState], OpenDSSGrid]

Online Feedback Optimization controller for batch-size regulation.

Reads grid voltage and datacenter state, updates voltage and latency duals, runs the primal batch-size optimizer, and returns new batch sizes. Latency dual updates use dc_state.observed_itl_s_by_model.

Parameters:

Name Type Description Default
inference_models tuple[InferenceModelSpec, ...]

Model specifications served in the datacenter.

required
datacenter LLMBatchSizeControlledDatacenter[LLMDatacenterState]

The datacenter whose batch sizes this controller regulates. Used for sensitivity-matrix perturbations.

required
grid OpenDSSGrid

The grid attached to the datacenter. Used for sensitivity-matrix perturbations and voltage-dual updates.

required
models LogisticModelStore

Per-model logistic models for power, latency, and throughput used in gradient computation.

required
config OFOConfig | None

Unified OFO tuning parameters. Defaults to an OFOConfig() with default tunings.

None
dt_s Fraction

Control interval (seconds).

Fraction(1)
initial_batch_sizes dict[str, int] | None

Optional per-model initial batch size used to seed the primal optimizer. Any model omitted from the mapping starts at its feasible_batch_sizes[0].

None
Source code in openg2g/controller/ofo.py
class OFOBatchSizeController(Controller[LLMBatchSizeControlledDatacenter[LLMDatacenterState], OpenDSSGrid]):
    """Online Feedback Optimization controller for batch-size regulation.

    Reads grid voltage and datacenter state, updates voltage and latency
    duals, runs the primal batch-size optimizer, and returns new batch
    sizes. Latency dual updates use [`dc_state.observed_itl_s_by_model`
    ][openg2g.datacenter.base.LLMDatacenterState.observed_itl_s_by_model].

    Args:
        inference_models: Model specifications served in the datacenter.
        datacenter: The datacenter whose batch sizes this controller
            regulates. Used for sensitivity-matrix perturbations.
        grid: The grid attached to the datacenter. Used for
            sensitivity-matrix perturbations and voltage-dual updates.
        models: Per-model logistic models for power, latency, and
            throughput used in gradient computation.
        config: Unified OFO tuning parameters. Defaults to an
            `OFOConfig()` with default tunings.
        dt_s: Control interval (seconds).
        initial_batch_sizes: Optional per-model initial batch size used
            to seed the primal optimizer. Any model omitted from the
            mapping starts at its `feasible_batch_sizes[0]`. The mapping
            is copied at construction, so later changes to the caller's
            dict do not alter what `reset()` restores.

    Raises:
        ValueError: If `inference_models` is empty, contains duplicate
            model labels, or `models` lacks a power, latency, or
            throughput model for one of the labels.
    """

    def __init__(
        self,
        inference_models: tuple[InferenceModelSpec, ...],
        datacenter: LLMBatchSizeControlledDatacenter[LLMDatacenterState],
        grid: OpenDSSGrid,
        models: LogisticModelStore,
        config: OFOConfig | None = None,
        dt_s: Fraction = Fraction(1),
        initial_batch_sizes: dict[str, int] | None = None,
    ) -> None:
        if config is None:
            config = OFOConfig()

        if not inference_models:
            raise ValueError("inference_models must not be empty.")
        labels = [ms.model_label for ms in inference_models]
        if len(labels) != len(set(labels)):
            raise ValueError(f"Duplicate model labels: {labels}")

        model_specs = list(inference_models)
        # Snapshot the mapping: reset() must restore the values given at
        # construction even if the caller mutates its dict afterwards.
        self._initial_batch_sizes: dict[str, int] = dict(initial_batch_sizes) if initial_batch_sizes else {}

        # Fail fast if any per-model fit is missing from the store.
        for ms in model_specs:
            label = ms.model_label
            for metric_name, accessor in [
                ("power", models.power),
                ("latency", models.latency),
                ("throughput", models.throughput),
            ]:
                try:
                    accessor(label)
                except KeyError:
                    raise ValueError(f"LogisticModelStore missing {metric_name} model for {label!r}.") from None

        self._dt_s = dt_s
        self._models = model_specs
        self._config = config
        self._datacenter = datacenter
        self._grid = grid
        self._itl_deadline_by_model = {ms.model_label: ms.itl_deadline_s for ms in model_specs}

        # The voltage dual is created lazily in step(), sized from the grid.
        self._voltage_dual: VoltageDualVariables | None = None
        self._latency_dual_by_model: dict[str, float] = {ms.model_label: 0.0 for ms in model_specs}

        # The optimizer receives the union of every model's feasible batch sizes.
        all_bs: set[int] = set()
        for ms in model_specs:
            all_bs.update(ms.feasible_batch_sizes)
        feasible_batch_sizes = sorted(all_bs)

        self._optimizer = PrimalBatchOptimizer(
            models=model_specs,
            feasible_batch_sizes=feasible_batch_sizes,
            power_fits=models.power_fits,
            latency_fits=models.latency_fits,
            throughput_fits=models.throughput_fits,
            config=config,
        )
        self._optimizer.init_from_batches(self._seed_batch_sizes())

        self._sensitivity_matrix: np.ndarray | None = None
        self._control_step_count: int = 0

        logger.info(
            "OFOBatchSizeController: %d models, dt=%s s, feasible_batches=%s",
            len(model_specs),
            dt_s,
            feasible_batch_sizes,
        )

    def _seed_batch_sizes(self) -> dict[str, int]:
        """Return the per-model batch sizes used to (re-)seed the optimizer.

        Shared by `__init__` and `reset()` so both always seed identically:
        the configured initial batch size when given, otherwise the model's
        `feasible_batch_sizes[0]`.
        """
        return {
            ms.model_label: self._initial_batch_sizes.get(ms.model_label, ms.feasible_batch_sizes[0])
            for ms in self._models
        }

    def reset(self) -> None:
        """Restore the controller to its freshly constructed state.

        Drops the voltage dual and cached sensitivity matrix, zeroes the
        latency duals and step counter, and re-seeds the primal optimizer.
        """
        self._voltage_dual = None
        self._latency_dual_by_model = {ms.model_label: 0.0 for ms in self._models}
        self._optimizer.init_from_batches(self._seed_batch_sizes())
        self._sensitivity_matrix = None
        self._control_step_count = 0

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        return self._dt_s

    def step(
        self,
        clock: SimulationClock,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Run one OFO iteration and return the next batch sizes.

        Args:
            clock: Simulation clock; its `time_s` is only used for logging.
            events: Emitter that receives a `controller.ofo.step` event
                carrying the new batch sizes and the latency duals.

        Returns:
            A one-element list holding a `SetBatchSize` command targeted at
            the controlled datacenter.

        Raises:
            RuntimeError: If the datacenter state lacks an
                `active_replicas_by_model` or `observed_itl_s_by_model`
                entry for any controlled model.
        """
        datacenter = self._datacenter
        grid = self._grid

        if self._voltage_dual is None:
            self._voltage_dual = VoltageDualVariables(len(grid.v_index), self._config)

        # 1. Re-estimate sensitivity if needed: always when nothing is cached,
        #    otherwise every `sensitivity_update_interval` steps (0 disables).
        if self._sensitivity_matrix is None or (
            self._config.sensitivity_update_interval > 0
            and self._control_step_count % self._config.sensitivity_update_interval == 0
        ):
            self._sensitivity_matrix, _ = grid.estimate_sensitivity(
                perturbation_kw=self._config.sensitivity_perturbation_kw,
                dc=datacenter,
            )

        # 2. Update voltage duals from grid state
        observed_voltages = grid.voltages_vector()
        self._voltage_dual.update(observed_voltages)

        voltage_dual_diff = self._voltage_dual.dual_difference()  # η = λ̄ − λ

        # 3. Read observed latency from datacenter and update latency duals
        dc_state = datacenter.state
        missing_replicas = [
            ms.model_label for ms in self._models if ms.model_label not in dc_state.active_replicas_by_model
        ]
        if missing_replicas:
            miss = ", ".join(sorted(missing_replicas))
            raise RuntimeError(
                f"OFOBatchSizeController requires active_replicas_by_model for all models. Missing: {miss}."
            )
        missing_itl = [ms.model_label for ms in self._models if ms.model_label not in dc_state.observed_itl_s_by_model]
        if missing_itl:
            miss = ", ".join(sorted(missing_itl))
            raise RuntimeError(
                f"OFOBatchSizeController requires observed_itl_s_by_model for all models. Missing: {miss}."
            )
        for ms in self._models:
            label = ms.model_label
            num_replicas = max(int(dc_state.active_replicas_by_model[label]), 0)
            observed_itl = float(dc_state.observed_itl_s_by_model[label])
            if num_replicas <= 0:
                # No replica is serving, so the reported ITL carries no signal.
                logger.debug("Model %s has 0 replicas, skipping latency dual update", label)
                observed_itl = float("nan")

            deadline = float(self._itl_deadline_by_model[label])
            if np.isfinite(observed_itl):
                # Projected dual ascent: step along the deadline violation,
                # then clip at zero.
                self._latency_dual_by_model[label] = max(
                    self._latency_dual_by_model[label]
                    + self._config.latency_dual_step_size * (observed_itl - deadline),
                    0.0,
                )
            else:
                # No usable observation: hold the dual, still clipped at zero.
                self._latency_dual_by_model[label] = max(self._latency_dual_by_model[label], 0.0)

        # 4. Compute replica counts
        replica_count_by_model: dict[str, float] = {
            ms.model_label: float(dc_state.active_replicas_by_model[ms.model_label]) for ms in self._models
        }

        # 5. Primal update -> next batch sizes
        batch_next = self._optimizer.step(
            voltage_dual_diff=voltage_dual_diff,
            sensitivity_matrix=self._sensitivity_matrix,
            phase_share_by_model=datacenter.phase_share_by_model,
            latency_dual_by_model=self._latency_dual_by_model,
            replica_count_by_model=replica_count_by_model,
        )

        self._control_step_count += 1
        logger.debug(
            "OFO step %d (t=%.1f s): batch=%s",
            self._control_step_count,
            clock.time_s,
            batch_next,
        )
        events.emit(
            "controller.ofo.step",
            {
                "batch_size_by_model": batch_next,
                # Copy so the emitted payload is not changed by later steps.
                "latency_dual_by_model": dict(self._latency_dual_by_model),
            },
        )
        return [SetBatchSize(batch_size_by_model=batch_next, target=datacenter)]

openg2g.controller.rule_based

Rule-based batch-size controller for voltage regulation.

A proportional controller that adjusts LLM batch sizes based on observed voltage violations. Unlike the OFO controller, it requires no sensitivity matrix, no logistic curve fits, and no dual variables -- making it a natural "simple baseline" for comparison.

Algorithm (each control step):

1. Read all bus-phase voltages from the grid.
2. Find the worst voltage violation magnitude.
3. Compute a signed "pressure" signal:
   - positive (undervoltage) → reduce batch (less power draw, less voltage drop)
   - negative (overvoltage) → increase batch (more power draw, more voltage drop)
   - zero → no action (all voltages within bounds)
4. Adjust each model's batch size proportionally in log2-space.
5. Snap to the nearest feasible batch size.

RuleBasedConfig

Bases: BaseModel

Configuration for the rule-based batch-size controller.

Attributes:

Name Type Description
step_size float

Proportional gain in log2(batch) change per pu of voltage violation. With feasible batches spaced ~1 log2 unit apart, a violation of 0.01 pu needs step_size ~10 to produce a 0.1 log2 shift, enough to eventually change the discrete batch level.

v_min float

Lower voltage limit (pu).

v_max float

Upper voltage limit (pu).

deadband float

Ignore violations smaller than this (pu). Prevents chattering.

latency_guard bool

If True, prevent batch-size increases when ITL exceeds the model's deadline.

Source code in openg2g/controller/rule_based.py
class RuleBasedConfig(BaseModel):
    """Configuration for the rule-based batch-size controller.

    All voltage quantities are in per-unit (pu). No validators are declared
    on this model, so relationships such as `v_min < v_max` are not checked
    here.

    Attributes:
        step_size: Proportional gain in log2(batch) change per pu of voltage
            violation. With feasible batches spaced ~1 log2 unit apart, a
            violation of 0.01 pu needs `step_size` ~10 to produce a 0.1 log2
            shift, enough to eventually change the discrete batch level.
        v_min: Lower voltage limit (pu). Voltages below it count as
            undervoltage violations.
        v_max: Upper voltage limit (pu). Voltages above it count as
            overvoltage violations.
        deadband: Ignore violations smaller than this (pu). Prevents chattering.
        latency_guard: If True, prevent batch-size increases when ITL exceeds
            the model's deadline.
    """

    # Gain: log2(batch) units moved per pu of worst-case violation.
    step_size: float = 10.0
    # Acceptable voltage band, in pu.
    v_min: float = 0.95
    v_max: float = 1.05
    # Violations at or below this magnitude (pu) produce no control action.
    deadband: float = 0.0001
    # When set, batch increases are vetoed for models already past their ITL deadline.
    latency_guard: bool = True

RuleBasedBatchSizeController

Bases: Controller[LLMBatchSizeControlledDatacenter[LLMDatacenterState], OpenDSSGrid]

Proportional rule-based controller for LLM batch-size regulation.

Reads grid voltages, computes a signed pressure signal from the worst violation, and adjusts batch sizes proportionally. No model fits or sensitivity matrices required.

Source code in openg2g/controller/rule_based.py
class RuleBasedBatchSizeController(
    Controller[LLMBatchSizeControlledDatacenter[LLMDatacenterState], OpenDSSGrid],
):
    """Voltage-driven proportional controller for LLM batch sizes.

    Each step scans the grid's bus-phase voltages, derives a signed pressure
    from the worst violation, and shifts every model's batch size by the same
    amount in log2-space before snapping to a feasible value. It needs neither
    model fits nor a sensitivity matrix.
    """

    def __init__(
        self,
        inference_models: tuple[InferenceModelSpec, ...] | list[InferenceModelSpec],
        *,
        datacenter: LLMBatchSizeControlledDatacenter[LLMDatacenterState],
        grid: OpenDSSGrid,
        config: RuleBasedConfig,
        dt_s: Fraction = Fraction(1),
        exclude_buses: tuple[str, ...] = (),
        zone_buses: tuple[str, ...] | None = None,
        initial_batch_sizes: dict[str, int] | None = None,
    ) -> None:
        specs = list(inference_models)
        self._models = specs
        self._config = config
        self._grid = grid
        self._datacenter = datacenter
        self._dt_s = dt_s
        self._initial_batch_sizes = initial_batch_sizes or {}

        # Bus names are matched case-insensitively.
        self._exclude_lower = {name.lower() for name in exclude_buses}
        # Zone-local observation: with a zone configured, only its buses are
        # scanned for the worst violation. Multi-DC topologies (ieee123) use
        # this so each site reacts only to the part of the network it can
        # actually move. None keeps the original global scan.
        self._zone_lower: set[str] | None = None
        if zone_buses is not None:
            self._zone_lower = {name.lower() for name in zone_buses}

        # Per-model lookup tables: ascending feasible batches and ITL deadlines.
        self._feasible: dict[str, list[int]] = {
            spec.model_label: sorted(spec.feasible_batch_sizes) for spec in specs
        }
        self._itl_deadline: dict[str, float] = {spec.model_label: spec.itl_deadline_s for spec in specs}

        # Continuous controller state: log2(batch) per model, so that small
        # corrections accumulate between discrete batch levels.
        self._log2_batch: dict[str, float] = self._initial_log2_batches()

        logger.info(
            "RuleBasedBatchSizeController: %d models, dt=%s s, step_size=%.2f, deadband=%.4f, v=[%.2f, %.2f], zone_local=%s (n_zone_buses=%d)",  # noqa: E501
            len(specs),
            dt_s,
            config.step_size,
            config.deadband,
            config.v_min,
            config.v_max,
            self._zone_lower is not None,
            len(self._zone_lower) if self._zone_lower is not None else 0,
        )

    def _initial_log2_batches(self) -> dict[str, float]:
        """Return log2 of each model's starting batch size.

        The starting batch is the configured initial value when present,
        otherwise the model's `feasible_batch_sizes[0]`.
        """
        starting: dict[str, float] = {}
        for spec in self._models:
            seed = self._initial_batch_sizes.get(spec.model_label, spec.feasible_batch_sizes[0])
            starting[spec.model_label] = math.log2(seed)
        return starting

    def _worst_violations(self, voltages) -> tuple[float, float]:
        """Return `(worst_under, worst_over)` magnitudes over the observed buses.

        Excluded buses, buses outside the configured zone, and NaN phase
        voltages are skipped. Both magnitudes are non-negative.
        """
        v_lo = self._config.v_min
        v_hi = self._config.v_max
        under = 0.0
        over = 0.0
        for bus in voltages.buses():
            key = bus.lower()
            if key in self._exclude_lower:
                continue
            if self._zone_lower is not None and key not in self._zone_lower:
                continue
            phases = voltages[bus]
            for v in (phases.a, phases.b, phases.c):
                if math.isnan(v):
                    continue
                if v < v_lo:
                    under = max(under, v_lo - v)
                elif v > v_hi:
                    over = max(over, v - v_hi)
        return under, over

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        return self._dt_s

    def reset(self) -> None:
        """Return the continuous log2-batch state to its initial values."""
        self._log2_batch = self._initial_log2_batches()

    def step(
        self,
        clock: SimulationClock,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Run one control step.

        Returns an empty list when no violation exceeds the deadband or when
        the snapped batch sizes match what the datacenter already reports;
        otherwise a single `SetBatchSize` command covering every model.
        """
        cfg = self._config

        # 1. Worst under-/over-voltage magnitudes (both >= 0).
        worst_under, worst_over = self._worst_violations(self._grid.state.voltages)

        # 2. Signed pressure: undervoltage takes precedence.
        #    pressure > 0 -> shrink batches (DC draws too much power)
        #    pressure < 0 -> grow batches (DC draws too little)
        pressure = 0.0
        if worst_under > cfg.deadband:
            pressure = worst_under
        elif worst_over > cfg.deadband:
            pressure = -worst_over
        if pressure == 0.0:
            return []

        # 3. Latency observations for the guard.
        dc_state = self._datacenter.state
        itl_by_model = dc_state.observed_itl_s_by_model

        # 4. Apply the same log2-space shift to every model.
        delta = -cfg.step_size * pressure
        guard_active = cfg.latency_guard and delta > 0
        new_batches: dict[str, int] = {}
        changed = False

        for spec in self._models:
            label = spec.model_label
            previous_log2 = self._log2_batch[label]
            feasible = self._feasible[label]

            # Shift, then clamp into the feasible log2 range.
            floor_log2 = math.log2(feasible[0])
            ceil_log2 = math.log2(feasible[-1])
            updated_log2 = max(floor_log2, min(ceil_log2, previous_log2 + delta))

            # Latency guard: a model already past its ITL deadline must not grow.
            if guard_active:
                itl = itl_by_model[label]
                if not math.isnan(itl) and itl > self._itl_deadline[label]:
                    updated_log2 = previous_log2

            # Snap to the feasible batch closest (in linear space) to the target.
            target = 2.0**updated_log2
            snapped = min(feasible, key=lambda candidate: abs(candidate - target))

            # The continuous value is kept for accumulation; only the command
            # carries the snapped value.
            self._log2_batch[label] = updated_log2
            new_batches[label] = snapped

            fallback = self._initial_batch_sizes.get(label, spec.feasible_batch_sizes[0])
            current = dc_state.batch_size_by_model.get(label, fallback)
            if snapped != current:
                changed = True

        if not changed:
            return []

        events.emit(
            "rule_based.step",
            {"time_s": clock.time_s, "pressure": pressure, "batch": dict(new_batches)},
        )

        return [SetBatchSize(batch_size_by_model=new_batches, target=self._datacenter)]

openg2g.controller.storage

Storage controllers.

StorageDroopConfig

Bases: BaseModel

Configuration for local-voltage storage droop control.

Attributes:

Name Type Description
mode StorageDroopMode

Droop output mode: qv controls kvar, pv controls kW.

v_ref float

Reference local voltage in pu.

deadband_pu float

Symmetric deadband around v_ref, in pu.

full_output_voltage_error_pu float

Absolute voltage error from v_ref where storage reaches its output limit.

droop_gain_per_pu float | None

Optional gain override in kW/pu for P-V mode or kvar/pu for Q-V mode.

max_abs_output float | None

Optional output limit in kW for P-V mode or kvar for Q-V mode.

allow_negative_output bool

Whether overvoltage may command charging in P-V mode or kvar absorption in Q-V mode.

voltage_statistic VoltageWindowStatistic

How to reduce local voltage samples from the previous control window.

Source code in openg2g/controller/storage.py
class StorageDroopConfig(BaseModel):
    """Configuration for local-voltage storage droop control.

    The model is declared frozen via `model_config`. Range checks (positive
    `v_ref`, non-negative `deadband_pu`, `full_output_voltage_error_pu`
    larger than `deadband_pu`, positive optional overrides) are not declared
    here; `LocalVoltageStorageDroopController` performs them at construction.

    Attributes:
        mode: Droop output mode: `qv` controls kvar, `pv` controls kW.
        v_ref: Reference local voltage in pu.
        deadband_pu: Symmetric deadband around `v_ref`, in pu.
        full_output_voltage_error_pu: Absolute voltage error from `v_ref` where
            storage reaches its output limit.
        droop_gain_per_pu: Optional gain override in kW/pu for P-V mode or
            kvar/pu for Q-V mode. When None, the controller derives the gain
            as output limit divided by
            `full_output_voltage_error_pu - deadband_pu`.
        max_abs_output: Optional output limit in kW for P-V mode or kvar for
            Q-V mode.
        allow_negative_output: Whether overvoltage may command charging in P-V
            mode or kvar absorption in Q-V mode.
        voltage_statistic: How to reduce local voltage samples from the previous
            control window.
    """

    model_config = ConfigDict(frozen=True)

    # Which quantity the droop drives: "qv" -> reactive power, "pv" -> real power.
    mode: StorageDroopMode = "qv"
    # Voltage setpoint and the symmetric no-action band around it (pu).
    v_ref: float = 1.0
    deadband_pu: float = 0.005
    # Voltage error (pu) at which the default gain reaches the output limit.
    full_output_voltage_error_pu: float = 0.05
    # Optional overrides; None means "derive from the storage rating".
    droop_gain_per_pu: float | None = None
    max_abs_output: float | None = None
    # When False, the controller clips negative outputs to zero.
    allow_negative_output: bool = True
    # Reduction applied to the window's voltage samples by the controller.
    voltage_statistic: VoltageWindowStatistic = "minimum"

LocalVoltageStorageDroopController

Bases: Controller[DatacenterBackend, OpenDSSGrid]

Proportional storage droop controller using only the storage's local voltage.

The controller samples the grid history emitted since the previous control step, reads the voltage at each storage attachment bus, and emits SetStoragePower commands. In Q-V mode, positive output injects reactive power; in P-V mode, positive output discharges real power into the grid. Commands are held by each storage object until the next control step.

Parameters:

Name Type Description Default
grid OpenDSSGrid

Grid backend.

required
storages Mapping[EnergyStorage, str]

Mapping from each EnergyStorage under control to its attachment bus. Each storage must be attached to grid at the bus given here.

required
config StorageDroopConfig

Droop configuration.

required
dt_s Fraction

Control interval in seconds.

Fraction(1)
Source code in openg2g/controller/storage.py
class LocalVoltageStorageDroopController(Controller[DatacenterBackend, OpenDSSGrid]):
    """Proportional storage droop controller using only the storage's local voltage.

    The controller samples the grid history emitted since the previous control
    step, reads the voltage at each storage attachment bus, and emits
    [`SetStoragePower`][openg2g.grid.command.SetStoragePower] commands. In Q-V
    mode, positive output injects reactive power; in P-V mode, positive output
    discharges real power into the grid. Commands are held by each storage
    object until the next control step.

    Args:
        grid: Grid backend.
        storages: Mapping from each [`EnergyStorage`][openg2g.grid.storage.EnergyStorage]
            under control to its attachment bus. Each storage must be attached
            to *grid* at the bus given here.
        config: Droop configuration.
        dt_s: Control interval in seconds.

    Raises:
        ValueError: If `config` fails validation, `dt_s` is not positive,
            `storages` is empty, or a storage's output rating is not a
            positive finite number.
    """

    def __init__(
        self,
        *,
        grid: OpenDSSGrid,
        storages: Mapping[EnergyStorage, str],
        config: StorageDroopConfig,
        dt_s: Fraction = Fraction(1),
    ) -> None:
        self._validate_config(config)
        if dt_s <= 0:
            raise ValueError("dt_s must be positive.")
        if not storages:
            raise ValueError("storages must contain at least one storage resource.")

        self._grid = grid
        self._config = config
        self._dt_s = dt_s
        self._storages = tuple(self._build_controlled_storage(config, storage, bus) for storage, bus in storages.items())
        # Index into the grid history of the first state not yet consumed.
        self._history_cursor = 0
        # Lower-cased requested bus name -> the spelling found in grid voltages.
        self._bus_case_cache: dict[str, str] = {}

    @classmethod
    def _build_controlled_storage(
        cls, config: StorageDroopConfig, storage: EnergyStorage, bus: str
    ) -> _ControlledStorage:
        """Bundle a storage with its bus, output limit, and droop gain.

        The output limit is resolved once and reused for the gain, so both
        are always derived from the same value.
        """
        output_limit = cls._storage_output_limit(config, storage)
        return _ControlledStorage(
            storage=storage,
            bus=bus,
            output_limit=output_limit,
            droop_gain_per_pu=cls._resolve_droop_gain(config, output_limit),
        )

    @property
    def dt_s(self) -> Fraction:
        """Control interval in seconds."""
        return self._dt_s

    def reset(self) -> None:
        """Rewind the history cursor and forget cached bus-name spellings."""
        self._history_cursor = 0
        self._bus_case_cache.clear()

    def step(
        self,
        clock: SimulationClock,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Compute droop outputs from the grid states seen since the last step.

        Args:
            clock: Simulation clock; `time_s` is included in emitted events.
            events: Emitter that receives one
                `controller.storage_droop.step` event per storage.

        Returns:
            One `SetStoragePower` command per controlled storage, or an empty
            list when no new grid state was recorded since the previous step.

        Raises:
            ValueError: If a storage bus is missing from a grid state, has no
                finite phase voltage, or the voltage statistic is unsupported.
        """
        history = self._grid.history()
        window = history[self._history_cursor :]
        self._history_cursor = len(history)

        if not window:
            return []

        commands: list[DatacenterCommand | GridCommand] = []
        for controlled in self._storages:
            local_voltage_pu = self._window_local_voltage_pu(window, controlled.bus)
            output = self._droop_output(local_voltage_pu, controlled)
            # The single droop output drives kvar in Q-V mode and kW
            # otherwise; the other channel is commanded to zero.
            if self._config.mode == "qv":
                power_kw = 0.0
                reactive_power_kvar = output
            else:
                power_kw = output
                reactive_power_kvar = 0.0

            events.emit(
                "controller.storage_droop.step",
                {
                    "time_s": clock.time_s,
                    "storage_name": controlled.storage.name,
                    "storage_bus": controlled.bus,
                    "mode": self._config.mode,
                    "local_voltage_pu": local_voltage_pu,
                    "output": output,
                    "output_limit": controlled.output_limit,
                    "power_kw": power_kw,
                    "reactive_power_kvar": reactive_power_kvar,
                    "v_ref": self._config.v_ref,
                    "deadband_pu": self._config.deadband_pu,
                    "window_size": len(window),
                    "voltage_statistic": self._config.voltage_statistic,
                },
            )
            commands.append(
                SetStoragePower(
                    storage=controlled.storage,
                    power_kw=power_kw,
                    reactive_power_kvar=reactive_power_kvar,
                )
            )
        return commands

    def _window_local_voltage_pu(self, window: list[GridState], storage_bus: str) -> float:
        """Reduce per-state local voltages over `window` to one value.

        The reduction is selected by `config.voltage_statistic`: `minimum`,
        `mean`, or `latest` (the last state in the window).

        Raises:
            ValueError: If the configured statistic is not one of the above.
        """
        samples = [self._state_local_voltage_pu(state, storage_bus) for state in window]
        stat = self._config.voltage_statistic
        if stat == "minimum":
            return min(samples)
        if stat == "mean":
            return sum(samples) / len(samples)
        if stat == "latest":
            return samples[-1]
        raise ValueError(f"Unsupported voltage statistic: {stat!r}")

    def _state_local_voltage_pu(self, state: GridState, storage_bus: str) -> float:
        """Return the lowest phase voltage at `storage_bus` in one grid state.

        Only the values returned by `_finite_phase_voltages` are considered.

        Raises:
            ValueError: If the bus is not found or no usable phase voltage remains.
        """
        bus = self._resolve_bus_in_state(state, storage_bus)
        phases = state.voltages[bus]
        values = _finite_phase_voltages(phases)
        if not values:
            raise ValueError(f"Storage bus {storage_bus!r} has no finite phase voltages.")
        return min(values)

    def _resolve_bus_in_state(self, state: GridState, storage_bus: str) -> str:
        """Map `storage_bus` to the bus name actually used in `state.voltages`.

        An exact match wins; otherwise a case-insensitive match is searched
        and cached for later calls.

        Raises:
            ValueError: If no bus matches, even case-insensitively.
        """
        if storage_bus in state.voltages:
            return storage_bus
        target = storage_bus.lower()
        cached = self._bus_case_cache.get(target)
        if cached is not None and cached in state.voltages:
            return cached
        for bus in state.voltages.buses():
            if bus.lower() == target:
                self._bus_case_cache[target] = bus
                return bus
        raise ValueError(f"Storage bus {storage_bus!r} not found in grid voltages.")

    def _droop_output(self, local_voltage_pu: float, controlled: _ControlledStorage) -> float:
        """Return the droop output for one storage, clipped to its limit.

        The output is zero inside the deadband and grows linearly with the
        voltage error beyond it. A positive error (voltage below `v_ref`)
        yields a positive output.
        """
        error = self._config.v_ref - local_voltage_pu
        abs_error = abs(error)
        if abs_error <= self._config.deadband_pu:
            output = 0.0
        else:
            # Measure the error from the deadband edge, keeping its sign, so
            # the output is continuous at the deadband boundary.
            effective_error = math.copysign(abs_error - self._config.deadband_pu, error)
            output = controlled.droop_gain_per_pu * effective_error

        if not self._config.allow_negative_output:
            output = max(output, 0.0)
        return max(-controlled.output_limit, min(controlled.output_limit, output))

    @staticmethod
    def _storage_output_limit(config: StorageDroopConfig, storage: EnergyStorage) -> float:
        """Return the absolute output limit for `storage` under `config`.

        The base rating is `rated_apparent_power_kva` in Q-V mode and
        `rated_power_kw` otherwise; `config.max_abs_output`, when set, can
        only lower it.

        Raises:
            ValueError: If the rating is not a positive finite number.
        """
        if config.mode == "qv":
            rating = float(storage.rated_apparent_power_kva)
        else:
            rating = float(storage.rated_power_kw)

        if not math.isfinite(rating) or rating <= 0.0:
            raise ValueError(f"Storage {storage.name!r} output rating must be positive.")

        if config.max_abs_output is not None:
            return min(rating, float(config.max_abs_output))
        return rating

    @staticmethod
    def _validate_config(config: StorageDroopConfig) -> None:
        """Check the numeric ranges of `config`.

        Raises:
            ValueError: On the first field that is non-finite or out of range.
        """
        if not math.isfinite(config.v_ref) or config.v_ref <= 0.0:
            raise ValueError("v_ref must be positive.")
        if not math.isfinite(config.deadband_pu) or config.deadband_pu < 0.0:
            raise ValueError("deadband_pu must be non-negative.")
        if not math.isfinite(config.full_output_voltage_error_pu) or config.full_output_voltage_error_pu <= 0.0:
            raise ValueError("full_output_voltage_error_pu must be positive.")
        # A strictly positive span is needed: the default gain divides by it.
        if config.full_output_voltage_error_pu <= config.deadband_pu:
            raise ValueError("full_output_voltage_error_pu must be larger than deadband_pu.")
        if config.droop_gain_per_pu is not None and (
            not math.isfinite(config.droop_gain_per_pu) or config.droop_gain_per_pu <= 0.0
        ):
            raise ValueError("droop_gain_per_pu must be positive when provided.")
        if config.max_abs_output is not None and (
            not math.isfinite(config.max_abs_output) or config.max_abs_output <= 0.0
        ):
            raise ValueError("max_abs_output must be positive when provided.")

    @staticmethod
    def _resolve_droop_gain(config: StorageDroopConfig, output_limit: float) -> float:
        """Return the droop gain in output units per pu of voltage error.

        An explicit `config.droop_gain_per_pu` wins. Otherwise the gain is
        chosen so the output reaches `output_limit` at a voltage error of
        `full_output_voltage_error_pu`.
        """
        if config.droop_gain_per_pu is not None:
            return float(config.droop_gain_per_pu)
        effective_span_pu = config.full_output_voltage_error_pu - config.deadband_pu
        return output_limit / effective_span_pu

openg2g.controller.tap_schedule

Tap schedule controller: applies pre-defined regulator tap changes at specified times.

TapScheduleController

Bases: Controller[DatacenterBackend, GridBackend]

Applies pre-defined tap changes at scheduled times.

When multiple schedule entries fire in the same step, their tap values are merged (later entries win).

Parameters:

Name Type Description Default
schedule TapSchedule

Tap schedule built via TapPosition(...).at(t=...) | ....

required
dt_s Fraction

How often the controller checks the schedule (seconds).

Fraction(1)
Source code in openg2g/controller/tap_schedule.py
class TapScheduleController(Controller[DatacenterBackend, GridBackend]):
    """Replays a pre-defined regulator tap schedule.

    Entries are consumed in schedule order. If several entries become due
    within the same step, their regulator settings are merged into a single
    tap position, with later entries overriding earlier ones.

    Args:
        schedule: Tap schedule built via
            [`TapPosition(...).at(t=...) | ...`][openg2g.grid.config.TapSchedule].
        dt_s: How often the controller checks the schedule (seconds).
    """

    def __init__(self, *, schedule: TapSchedule, dt_s: Fraction = Fraction(1)) -> None:
        self._entries = list(schedule)
        # Index of the next schedule entry that has not fired yet.
        self._idx = 0
        self._dt_s = dt_s

    def reset(self) -> None:
        """Rewind to the first schedule entry."""
        self._idx = 0

    @property
    def dt_s(self) -> Fraction:
        """Interval between schedule checks, in seconds."""
        return self._dt_s

    def step(
        self,
        clock: SimulationClock,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Fire every schedule entry that is due at the current time.

        Returns a single `SetTaps` command carrying the merged regulator
        settings, or an empty list when nothing is due or the due entries
        carry no regulator settings.
        """
        # The small tolerance lets an entry fire when its time equals the
        # clock time up to floating-point error.
        due_until = clock.time_s + 1e-12
        due_regulators: dict[str, float] = {}

        while self._idx < len(self._entries):
            fire_time, position = self._entries[self._idx]
            if not (float(fire_time) <= due_until):
                break
            due_regulators.update(position.regulators)
            self._idx += 1

        # An empty merge covers both "nothing fired" and "fired without settings".
        if not due_regulators:
            return []

        tap = TapPosition(regulators=due_regulators)
        events.emit("controller.tap_schedule.fired", {"tap_position": tap})
        return [SetTaps(tap_position=tap)]