TorchStateActionCritic

class maze.core.agent.torch_state_action_critic.TorchStateActionCritic(networks: Mapping[Union[str, int], torch.nn.Module], num_policies: int, device: str, only_discrete_spaces: Dict[Union[str, int], bool], action_spaces_dict: Dict[Union[str, int], gym.spaces.Dict])

Encapsulates multiple torch state action critics for training in structured environments.

Parameters
  • networks – Mapping of value functions (critic) to encapsulate.

  • num_policies – The number of corresponding policies.

  • device – Device the critic networks should be located on (cpu or cuda).

  • only_discrete_spaces – A dict specifying, per step, whether the action space holds only discrete sub-spaces.

  • action_spaces_dict – A dict holding the gym action space for each step.
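
A minimal construction sketch follows. TorchStateActionCritic declares abstract members, so a concrete subclass is required; the subclass name TorchStepStateActionCritic, the placeholder network, and the space layout are illustrative assumptions, not prescribed by this interface.

    import torch.nn as nn
    from gym import spaces
    # Assumed concrete subclass of TorchStateActionCritic (assumption):
    from maze.core.agent.torch_state_action_critic import TorchStepStateActionCritic

    # Stand-in Q-network for sub-step 0. Real Maze critic networks consume
    # dictionary observations/actions; treat this module as a placeholder only.
    networks = {0: nn.Linear(6, 1)}

    critic = TorchStepStateActionCritic(
        networks=networks,
        num_policies=1,
        device="cpu",
        only_discrete_spaces={0: False},
        action_spaces_dict={0: spaces.Dict({"action": spaces.Box(-1.0, 1.0, shape=(2,))})},
    )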

compute_state_action_value_step(observation: Dict[str, torch.Tensor], action: Dict[str, torch.Tensor], critic_id: Union[str, int, tuple]) → List[torch.Tensor]

Predict the Q-value for the given observation and action with the critic specified by critic_id.

Parameters
  • observation – The observation for the current step.

  • action – The action performed at the current step.

  • critic_id – The current step key of the multi-step env.

Returns

A list of tensors holding the predicted Q-values, one tensor per critic.
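
A hedged usage sketch (observation/action key names and tensor shapes are assumptions; critic is an instance as constructed above):

    import torch

    obs = {"observation": torch.randn(32, 4)}   # batch of 32 observations
    act = {"action": torch.randn(32, 2)}        # matching batch of actions

    q_values = critic.compute_state_action_value_step(
        observation=obs, action=act, critic_id=0)
    # One tensor per critic network, e.g. two entries for twin Q-networks.
    print(len(q_values), q_values[0].shape)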

compute_state_action_values_step(observation: Dict[str, torch.Tensor], critic_id: Union[str, int, tuple]) → List[Dict[str, torch.Tensor]]

Predict the Q-values of all discrete actions for the given observation, with the critic specified by critic_id (applicable to discrete-only action spaces).

Parameters
  • observation – The observation for the current step.

  • critic_id – The current step key of the multi-step env.

Returns

A list of dicts holding the predicted Q-values for each action w.r.t. the critic.
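
In the discrete case no action tensor is passed in; each returned dict maps an action key to the Q-values of all its discrete actions. A sketch under the same naming assumptions as above:

    q_value_dicts = critic.compute_state_action_values_step(
        observation={"observation": torch.randn(32, 4)}, critic_id=0)
    for per_critic in q_value_dicts:             # one dict per critic network
        for action_key, q in per_critic.items():
            print(action_key, q.shape)           # e.g. (32, n_actions)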

property device

implementation of TorchModel

eval() → None

(overrides TorchModel)

implementation of TorchModel

load_state_dict(state_dict: Dict) → None

(overrides TorchModel)

implementation of TorchModel

abstract property num_critics

Returns the number of critic networks.

Returns

Number of critic networks.

parameters() → List[torch.Tensor]

(overrides TorchModel)

implementation of TorchModel

per_critic_parameters() → List[List[torch.Tensor]]

Retrieve all trainable critic parameters (to be assigned to optimizers).

Returns

A list of lists holding all parameters of the base critic, corresponding to the number of critics per step.
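
The per-critic lists are typically handed to separate optimizers. A sketch (the optimizer type and learning rate are assumptions, not part of this API):

    from torch.optim import Adam

    optimizers = [Adam(params, lr=3e-4)
                  for params in critic.per_critic_parameters()]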

abstract predict_next_q_values(next_observations: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions_logits: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions_log_probs: Dict[Union[str, int], Dict[str, torch.Tensor]], alpha: Dict[Union[str, int], torch.Tensor]) → Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]]

Predict the target Q-value for the next step, \(V(s_t) := E_{a_t \sim \pi}[Q(s_t, a_t) - \alpha \log(\pi(a_t \mid s_t))]\).

Parameters
  • next_observations – The next observations.

  • next_actions – The next actions sampled from the policy.

  • next_actions_logits – The logits of the next actions (only relevant for the discrete case).

  • next_actions_log_probs – The log probabilities of the actions.

  • alpha – The alpha or entropy coefficient for each step.

Returns

A dict w.r.t. the step, holding tensors representing the predicted next Q-values.
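
These next Q-values feed the soft Bellman target used during training. A hedged sketch of that downstream step (batch shapes, key names, and the placeholder inputs are assumptions about the trainer, not part of this class):

    import torch

    batch = 32
    next_obs = {0: {"observation": torch.randn(batch, 4)}}
    next_act = {0: {"action": torch.randn(batch, 2)}}
    next_logits = {0: {}}                           # empty for the continuous case
    next_log_probs = {0: {"action": torch.randn(batch)}}
    alpha = {0: torch.tensor(0.2)}                  # entropy coefficient per step

    next_q = critic.predict_next_q_values(
        next_obs, next_act, next_logits, next_log_probs, alpha)[0]

    rewards, dones, gamma = torch.randn(batch), torch.zeros(batch), 0.99
    target_q = rewards + gamma * (1.0 - dones) * next_q.detach()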

abstract predict_q_values(observations: Dict[Union[str, int], Dict[str, torch.Tensor]], actions: Dict[Union[str, int], Dict[str, torch.Tensor]], gather_output: bool) → Dict[Union[str, int], List[Union[torch.Tensor, Dict[str, torch.Tensor]]]]

(overrides StateActionCritic)

implementation of StateActionCritic

re_init_networks() → None

Reinitialize all parameters of the network.

state_dict() → Dict

(overrides TorchModel)

implementation of TorchModel

to(device: str) → None

(overrides TorchModel)

implementation of TorchModel

train() → None

(overrides TorchModel)

implementation of TorchModel

update_target_weights(tau: float) → None

Perform a soft or hard update depending on the chosen tau value; tau == 1 results in a hard update.

Parameters

tau – Parameter weighting the soft update of the target network.
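
The soft update is standard Polyak averaging of the target weights, \(\theta_{target} \leftarrow \tau \theta + (1 - \tau) \theta_{target}\), which matches the tau semantics above; usage is then simply:

    critic.update_target_weights(tau=0.005)   # soft update with a small tau
    critic.update_target_weights(tau=1.0)     # tau == 1: hard copy of the weights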