Source code for rizemind.strategies.contribution.shapley.trainer_mapping
from eth_typing import ChecksumAddress
[docs]
class ParticipantMapping:
    """Mapping between trainer addresses and their integer IDs.

    Every address added to the mapping receives the next free zero-based
    integer ID. The ID is used as a bit position, so a group (coalition) of
    trainers can be represented by a single integer bit mask. Set identifiers
    handled by this class are the decimal string form of such a mask.

    Attributes:
        participant_ids: Dictionary mapping trainer addresses to their unique IDs.
    """

    participant_ids: dict[ChecksumAddress, int]

    def __init__(self) -> None:
        """Initialize an empty participant mapping."""
        self.participant_ids = {}

    def add_participant(self, participant: ChecksumAddress) -> None:
        """Add a participant to the mapping.

        Adding an address that is already present is a no-op, so its ID
        never changes.

        Args:
            participant: The trainer address to add.
        """
        if participant not in self.participant_ids:
            # IDs are dense and zero-based, so the current size is the next free ID.
            self.participant_ids[participant] = self.get_total_participants()

    def get_total_participants(self) -> int:
        """Get the total number of participants.

        Returns:
            The count of unique participants in the mapping.
        """
        return len(self.participant_ids)

    def get_participant_id(self, participant: ChecksumAddress) -> int:
        """Get the numerical ID for a participant.

        Args:
            participant: The trainer address to look up.

        Returns:
            The numerical ID assigned to this participant.

        Raises:
            ValueError: If the participant is not in the mapping.
        """
        try:
            return self.participant_ids[participant]
        except KeyError:
            # Callers expect ValueError; the KeyError adds no information.
            raise ValueError(f"{participant} did not participate.") from None

    def get_participant_mask(self, participant: ChecksumAddress) -> int:
        """Get the bit mask for a participant.

        Args:
            participant: The trainer address.

        Returns:
            An integer with only the bit at the participant's ID set.

        Raises:
            ValueError: If the participant is not in the mapping.
        """
        return 1 << self.get_participant_id(participant)

    def get_participant_set_id(self, participants: list[ChecksumAddress]) -> str:
        """Generate the set ID for a group of participants.

        Args:
            participants: Trainer addresses in the set.

        Returns:
            String representation of the bit mask identifying this set.

        Raises:
            ValueError: If any participant is not in the mapping.
        """
        return self.include_participants(participants=participants, id="0")

    def in_set(self, trainer: ChecksumAddress, id: str) -> bool:
        """Check if a trainer is a member of a set.

        Args:
            trainer: The trainer address to check.
            id: The set identifier (bit mask as a decimal string).

        Returns:
            True if the trainer's bit is set in the mask, False otherwise.

        Raises:
            ValueError: If ``id`` is not an integer string or the trainer is
                not in the mapping.
        """
        return (int(id) & self.get_participant_mask(trainer)) != 0

    def _combined_mask(
        self,
        participants: ChecksumAddress
        | list[ChecksumAddress]
        | tuple[ChecksumAddress, ...]
        | set[ChecksumAddress]
        | frozenset[ChecksumAddress],
    ) -> int:
        """Return the OR of the bit masks of one address or a collection of them."""
        if isinstance(participants, (list, tuple, set, frozenset)):
            members = participants
        else:
            # Anything that is not a recognised collection is a single address.
            members = (participants,)
        combined = 0
        for member in members:
            combined |= self.get_participant_mask(member)
        return combined

    def exclude_participants(
        self,
        participants: ChecksumAddress
        | list[ChecksumAddress]
        | tuple[ChecksumAddress, ...]
        | set[ChecksumAddress]
        | frozenset[ChecksumAddress],
        id: str | None = None,
    ) -> str:
        """Remove participants from a set.

        Args:
            participants: A single trainer address, or a list, tuple, set or
                frozenset of addresses to exclude.
            id: The set identifier to modify. If None, starts with empty set.

        Returns:
            String representation of the updated set bit mask.

        Raises:
            ValueError: If ``id`` is not an integer string or a participant
                is not in the mapping.
        """
        aggregate_mask = int(id) if id is not None else 0
        # Clearing each bit in turn equals clearing the union of the bits.
        aggregate_mask &= ~self._combined_mask(participants)
        return str(aggregate_mask)

    def include_participants(
        self,
        participants: ChecksumAddress
        | list[ChecksumAddress]
        | tuple[ChecksumAddress, ...]
        | set[ChecksumAddress]
        | frozenset[ChecksumAddress],
        id: str | None = None,
    ) -> str:
        """Add participants to a set.

        Args:
            participants: A single trainer address, or a list, tuple, set or
                frozenset of addresses to include.
            id: The set identifier to modify. If None, starts with empty set.

        Returns:
            String representation of the updated set bit mask.

        Raises:
            ValueError: If ``id`` is not an integer string or a participant
                is not in the mapping.
        """
        aggregate_mask = int(id) if id is not None else 0
        aggregate_mask |= self._combined_mask(participants)
        return str(aggregate_mask)

    def get_participants(self) -> list[ChecksumAddress]:
        """Get all participant addresses.

        Returns:
            List of all trainer addresses in the mapping, in the order they
            were added.
        """
        return list(self.participant_ids)