# Source code for rizemind.authentication.can_train_criterion
import os
from flwr.server.client_proxy import ClientProxy
from flwr.server.criterion import Criterion
from rizemind.authentication.signatures.auth import recover_auth_signer
from rizemind.authentication.train_auth import (
parse_train_auth_res,
prepare_train_auth_ins,
)
from rizemind.authentication.typing import SupportsEthAccountStrategy
from rizemind.contracts.erc.erc5267.typings import EIP712Domain
from rizemind.exception.parse_exception import ParseException
[docs]
class CanTrainCriterion(Criterion):
    """Flower criterion to select clients that can train.

    This criterion authenticates clients using Ethereum signatures. Each
    client is sent a challenge built from the round ID, a fresh random nonce
    and the swarm's EIP-712 domain. The signer address is recovered from the
    client's reply and checked against the swarm's training permissions for
    that round.

    Attributes:
        round_id: The identifier of the current training round.
        domain: The EIP-712 domain for signing authentication messages.
        swarm: The protocol that provides the EIP-712 domain and
            verifies training permissions.
        timeout: Timeout passed to ``ClientProxy.get_properties`` when
            waiting for a client's authentication response.
    """

    round_id: int
    domain: EIP712Domain
    swarm: SupportsEthAccountStrategy
    timeout: float

    def __init__(
        self,
        round_id: int,
        swarm: SupportsEthAccountStrategy,
        timeout: float = 60,
    ):
        """Initializes the CanTrainCriterion.

        Args:
            round_id: The ID of the current training round.
            swarm: The protocol that provides the EIP-712 domain and
                verifies training permissions.
            timeout: Timeout forwarded to ``client.get_properties`` for the
                authentication request. Defaults to 60, the value previously
                hard-coded in ``select``.
        """
        self.round_id = round_id
        # The domain is fetched once here and reused for every client
        # challenge issued by this criterion instance.
        self.domain = swarm.get_eip712_domain()
        self.swarm = swarm
        self.timeout = timeout

    def select(self, client: ClientProxy) -> bool:
        """Selects a client for training based on authentication.

        This method sends the client an authentication challenge, recovers
        the signer's address from the signed reply, and asks the swarm
        whether that address may train in the current round.

        Args:
            client: The client proxy to evaluate for selection.

        Returns:
            True if the client is authenticated and authorized to train,
            False otherwise (including when the reply cannot be parsed or
            the signature cannot be recovered).
        """
        # A fresh random nonce is generated for every challenge. The same
        # nonce is fed to signer recovery below, so the reply must answer
        # this specific challenge.
        nonce = os.urandom(32)
        ins = prepare_train_auth_ins(
            round_id=self.round_id, nonce=nonce, domain=self.domain
        )
        try:
            res = client.get_properties(
                ins, timeout=self.timeout, group_id=self.round_id
            )
            auth = parse_train_auth_res(res)
            signer = recover_auth_signer(
                round=self.round_id,
                nonce=nonce,
                domain=self.domain,
                signature=auth.signature,
            )
            return self.swarm.can_train(signer, self.round_id)
        except (ParseException, ValueError):
            # A malformed reply or an unrecoverable signature means the
            # client is simply not selected; it must not abort selection of
            # the remaining clients.
            return False