
Abstractions 🏗

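Each algorithm is mainly composed of four classes: Model📦, Algorithm👣, Agent🤖, and Trainer🔁. They are linked by HAS-A relationships: a Trainer🔁 has an Agent🤖, an Agent🤖 has an Algorithm👣, and an Algorithm👣 has a Model📦.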

  • Model📦: Defines one or more forward networks. It takes observations as input and returns the raw network outputs.
  • Algorithm👣: Defines how the parameters of the Model📦 are updated and how the output of the Model📦 is post-processed (argmax, ...).
  • Agent🤖: Acts as a data bridge between the Environment🗺 and the Algorithm👣.
  • Trainer🔁: Defines the overall training process of the Agent🤖 and the tools that support it (Buffer, ...).
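
The HAS-A chain described above can be illustrated with a dependency-free toy. This is only a sketch under stated assumptions: the `Toy*` classes, the `weights` parameter, and the hand-written linear scores are hypothetical stand-ins, not the library's classes. The Model returns raw scores, the Algorithm post-processes them with argmax, the Agent converts the environment's data format, and the Trainer owns the Agent.

```python
from typing import List, Sequence


class ToyModel:
    """Stand-in for a forward network: observation -> raw per-action scores."""

    def __init__(self, weights: List[List[float]]) -> None:
        self.weights = weights  # one weight row per action (hypothetical)

    def value(self, obs: Sequence[float]) -> List[float]:
        return [sum(w * o for w, o in zip(row, obs)) for row in self.weights]


class ToyAlgorithm:
    """Owns the model and post-processes its raw output (here: argmax)."""

    def __init__(self, weights: List[List[float]]) -> None:
        self.model = ToyModel(weights)

    def predict(self, obs: Sequence[float]) -> int:
        scores = self.model.value(obs)
        return max(range(len(scores)), key=scores.__getitem__)


class ToyAgent:
    """Bridge between environment data and the algorithm."""

    def __init__(self, weights: List[List[float]]) -> None:
        self.alg = ToyAlgorithm(weights)

    def predict(self, obs: tuple) -> int:
        # Pre-processing: convert the environment's tuple into a list.
        return self.alg.predict(list(obs))


class ToyTrainer:
    """Owns the agent; a real trainer would also own environments and a buffer."""

    def __init__(self, weights: List[List[float]]) -> None:
        self.agent = ToyAgent(weights)


trainer = ToyTrainer(weights=[[1.0, 0.0], [0.0, 1.0]])
action = trainer.agent.predict((0.2, 0.9))  # scores are [0.2, 0.9]
```

Each class only talks to the one it owns, so swapping the post-processing step (for example, sampling instead of argmax) would touch `ToyAlgorithm` alone.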

The Trainer.__call__ function returns a generator that holds the training control flow and all related data. The generator yields a log_data training log at each step; iterating over it until it is exhausted completes the training and produces every log_data.
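
The generator pattern can be sketched as follows. This is a minimal illustration, not the library's implementation: `ToyTrainer`, `num_steps`, `learning_starts`, and the log keys are hypothetical names chosen for the example. The point is that `__call__` contains the control flow, and the caller drives training simply by iterating.

```python
from typing import Any, Dict, Generator


class ToyTrainer:
    """Hypothetical trainer: collects on every step, trains after a warm-up."""

    def __init__(self, num_steps: int = 6, learning_starts: int = 3) -> None:
        self.num_steps = num_steps
        self.learning_starts = learning_starts

    def __call__(self) -> Generator[Dict[str, Any], None, None]:
        # The control flow lives here; one log_data dict is yielded per step.
        for step in range(self.num_steps):
            log_data: Dict[str, Any] = {"step": step}
            log_data.update(self._run_collect(step))
            if step >= self.learning_starts:
                log_data.update(self._run_train(step))
            yield log_data

    def _run_collect(self, step: int) -> Dict[str, Any]:
        # Placeholder for sampling one environment step into a buffer.
        return {"collect/reward": float(step)}

    def _run_train(self, step: int) -> Dict[str, Any]:
        # Placeholder for one gradient update.
        return {"train/loss": 1.0 / (step + 1)}


# Exhausting the generator runs the whole training loop.
logs = list(ToyTrainer()())
```

Because the loop is a generator, the caller can also stop early, interleave evaluation, or forward each `log_data` to a logger without the trainer knowing about any of it.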

The logger📊 records training logs with TensorBoard and Weights & Biases by decorating the Trainer.__call__ function; see the core code for the specific implementation.
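
One way such a decorator could work is to wrap the generator and record each log_data before passing it on. The sketch below is an assumption about the general shape, not the project's logger: `with_logger`, `records`, and `ToyTrainer` are hypothetical, and a plain list stands in for the TensorBoard or Weights & Biases writer.

```python
import functools
from typing import Any, Callable, Dict, Generator, List

LogGen = Generator[Dict[str, Any], None, None]


def with_logger(records: List[Dict[str, Any]]) -> Callable:
    """Hypothetical decorator: records each log_data, then yields it unchanged."""

    def decorator(call: Callable[..., LogGen]) -> Callable[..., LogGen]:
        @functools.wraps(call)
        def wrapped(*args: Any, **kwargs: Any) -> LogGen:
            for log_data in call(*args, **kwargs):
                # A real logger would write scalars to its backend here.
                records.append(dict(log_data))
                yield log_data

        return wrapped

    return decorator


records: List[Dict[str, Any]] = []


class ToyTrainer:
    @with_logger(records)
    def __call__(self) -> LogGen:
        for step in range(3):
            yield {"step": step, "train/loss": 1.0 / (step + 1)}


steps = [log_data["step"] for log_data in ToyTrainer()()]
```

The training loop stays free of logging code, and the caller still receives every `log_data` exactly as the trainer produced it.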




🧵 Solid lines indicate the control flow; dotted lines indicate the data flow.


Model Structure Diagram
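from __future__ import annotations  # annotations are not evaluated at import time

from typing import Any, Generator, Optional

import numpy as np
import torch
import torch.nn as nn

# BufferSamples is the batch type sampled from the Buffer; its import is omitted here.


class Model(nn.Module):
    def __init__(self, **kwargs) -> None:
        super().__init__()  # required before registering sub-modules
        # Build one or more forward networks here

    def value(self, x: torch.Tensor, a: Optional[torch.Tensor] = None) -> tuple[Any, ...]:
        # Returns the output value of one or more critics
        ...

    def action(self, x: torch.Tensor) -> tuple[Any, ...]:
        # Returns the action or the action probability distribution
        ...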

class Algorithm:
    def __init__(self, **kwargs) -> None:
        # 1. Initialize model, target model
        self.model = Model(**kwargs)
        # 2. Initialize optimizer

    def predict(self, obs: torch.Tensor) -> tuple[Any, ...]:
        # Returns the action, the action probability distribution, or the Q-function values
        ...

    def learn(self, data: BufferSamples) -> dict[str, Any]:
        # Given the training data, define a loss function and update the parameters of the Model.

        # 1. Compute the target
        # 2. Compute the loss
        # 3. Update the model
        # 4. Return the training log_data
        ...

    def sync_target(self) -> None:
        # Synchronize the model and the target model
        ...

class Agent:
    def __init__(self, **kwargs) -> None:
        # 1. Initialize Algorithm
        self.alg = Algorithm(**kwargs)
        # 2. Initialize the run-steps variable

    def predict(self, obs: np.ndarray) -> np.ndarray:
        # 1. obs pre-processing (to_tensor & to_device)
        # 2. act = Algorithm.predict
        # 3. act post-processing (to_cpu & to_numpy)
        # 4. Return the act used for evaluation
        ...

    def sample(self, obs: np.ndarray) -> np.ndarray:
        # 1. obs pre-processing (to_tensor & to_device)
        # 2. act = Algorithm.predict
        # 3. act post-processing (to_cpu & to_numpy & add noise)
        # 4. Return the act used for training
        ...

    def learn(self, data: BufferSamples) -> dict[str, Any]:
        # 1. Data pre-processing
        # 2. Call Algorithm.learn
        # 3. Return the result of Algorithm.learn
        ...

class Trainer:
    def __init__(self, **kwargs) -> None:
        # 1. Initialize args
        # 2. Initialize the training and evaluation environments
        # 3. Initialize Buffer
        # 4. Initialize Agent
        self.agent = Agent(**kwargs)

    def __call__(self) -> Generator[dict[str, Any], None, None]:
        # 1. Define the training control flow
        # 2. Yield log_data at each step, so that calling the method returns a generator
        ...

    def _run_collect(self) -> dict[str, Any]:
        # 1. Sample one step and add the data to the Buffer
        # 2. Return log_data
        ...

    def _run_train(self) -> dict[str, Any]:
        # 1. Sample data from the Buffer
        # 2. Train for a single step
        # 3. Return log_data
        ...

if __name__ == "__main__":
    trainer = Trainer()
    for log_data in trainer():