Agent
This module contains classes used to define the standard behavior of the agent. It relies on the controllers, the chosen training/test policy and the learning algorithm to specify its behavior in the environment.
NeuralAgent(environment, q_network[, ...]) | The NeuralAgent class wraps a deep Q-network for training and testing in a given environment.
DataSet(env[, random_state, max_size, ...]) | A replay memory consisting of circular buffers for observations, actions, rewards and terminals.
Detailed description
class deer.agent.NeuralAgent(environment, q_network, replay_memory_size=1000000, replay_start_size=None, batch_size=32, random_state=<mtrand.RandomState object>, exp_priority=0, train_policy=None, test_policy=None, only_full_history=True)

The NeuralAgent class wraps a deep Q-network for training and testing in a given environment. Attach controllers to it in order to conduct an experiment (when to train the agent, when to test, ...).
Parameters:
environment : object from class Environment
The environment in which the agent interacts.
q_network : object from class QNetwork
The q_network associated to the agent.
replay_memory_size : int
Size of the replay memory. Default: 1000000
replay_start_size : int
Number of observations (= number of time steps taken) in the replay memory before learning starts. Default: the minimum possible according to environment.inputDimensions().
batch_size : int
Number of tuples taken into account for each iteration of gradient descent. Default: 32
random_state : numpy random number generator
Default: random seed.
exp_priority : float
The exponent that determines how much prioritization is used. Default: 0 (uniform priority). See Schaul et al. (2016), Prioritized Experience Replay.
train_policy : object from class Policy
Policy followed when in training mode (mode -1).
test_policy : object from class Policy
Policy followed in modes other than training (validation and test modes).
only_full_history : boolean
Whether to train the neural network only on full histories, or to pad with zeros the observations taken before the beginning of the episode.
Methods

attach(controller)
avgBellmanResidual() | Returns the average training loss on the epoch
avgEpisodeVValue() | Returns the average V value on the episode (on time steps where a non-random action has been taken)
bestAction() | Returns the best Action
detach(controllerIdx)
discountFactor() | Get the discount factor
dumpNetwork(fname[, nEpoch]) | Dump the network
learningRate() | Get the learning rate
mode()
overrideNextAction(action) | Possibility to override the chosen action.
resumeTrainingMode()
run(n_epochs, epoch_length) | This function encapsulates the whole process of the learning.
setControllersActive(toDisable, active) | Activate controller
setDiscountFactor(df) | Set the discount factor
setLearningRate(lr) | Set the learning rate for the gradient descent
setNetwork(fname[, nEpoch]) | Set values into the network
startMode(mode, epochLength)
summarizeTestPerformance()
totalRewardOverLastTest() | Returns the average sum of rewards per episode and the number of episodes
train() | This function selects a random batch of data (with self._dataset.randomBatch) and performs a Q-learning iteration (with self._network.train).
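As a rough usage sketch (MyEnv and MyQNetwork are hypothetical stand-ins for user-defined Environment and QNetwork implementations, and the controller module path is assumed rather than documented on this page), a typical experiment builds the agent, attaches controllers and calls run:

    import numpy as np
    from deer.agent import NeuralAgent
    import deer.experiment.base_controllers as bc   # assumed module path for controllers

    from my_project import MyEnv, MyQNetwork        # hypothetical Environment / QNetwork classes

    rng = np.random.RandomState(123456)
    env = MyEnv(rng)
    q_network = MyQNetwork(env, random_state=rng)

    agent = NeuralAgent(
        env,
        q_network,
        replay_memory_size=100000,
        batch_size=32,
        random_state=rng,
        exp_priority=0,              # 0 = uniform replay; > 0 enables prioritization
    )
    agent.setDiscountFactor(0.95)
    agent.setLearningRate(0.0002)

    # Controllers decide when to train, when to test, when to anneal parameters, ...
    agent.attach(bc.VerboseController())
    agent.attach(bc.TrainerController())

    # 20 epochs of at most 1000 steps each.
    agent.run(n_epochs=20, epoch_length=1000)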
avgBellmanResidual()
Returns the average training loss on the epoch.
avgEpisodeVValue()
Returns the average V value on the episode (on time steps where a non-random action has been taken).
bestAction()
Returns the best Action.
discountFactor()
Get the discount factor.
dumpNetwork(fname, nEpoch=-1)
Dump the network.

Parameters:
fname : string
Name of the file where the network will be dumped
nEpoch : int
Epoch number (Optional)
learningRate()
Get the learning rate.
overrideNextAction(action)
Allows the chosen action to be overridden. This should be used on the signal OnActionChosen.
run(n_epochs, epoch_length)
This function encapsulates the whole learning process. It starts by calling the controllers' method "onStart". It then runs a given number of epochs, where an epoch is made up of one or many episodes (called with agent._runEpisode) and ends once the number of steps reaches the argument "epoch_length". It finishes by calling the controllers' method "end".

Parameters:
n_epochs : int
Number of epochs
epoch_length : int
Maximum number of steps for a given epoch
setControllersActive(toDisable, active)
Activate or deactivate the controllers whose indices are given in toDisable.
setDiscountFactor(df)
Set the discount factor.
setLearningRate(lr)
Set the learning rate for the gradient descent.
setNetwork(fname, nEpoch=-1)
Set values into the network.

Parameters:
fname : string
Name of the file where the values are
nEpoch : int
Epoch number (Optional)
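As a minimal checkpointing sketch, given an agent like the one built above (the experiment name and epoch number are arbitrary; where the dump ends up on disk depends on the library version):

    # After (part of) the training, dump the current network parameters
    # under a chosen experiment name and epoch number...
    agent.dumpNetwork("my_experiment", nEpoch=10)

    # ...and later load those values back into the network, for instance
    # into a freshly built agent, before running test epochs.
    agent.setNetwork("my_experiment", nEpoch=10)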
totalRewardOverLastTest()
Returns the average sum of rewards per episode and the number of episodes.
train()
This function selects a random batch of data (with self._dataset.randomBatch) and performs a Q-learning iteration (with self._network.train).
class deer.agent.DataSet(env, random_state=None, max_size=1000, use_priority=False, only_full_history=True)

A replay memory consisting of circular buffers for observations, actions, rewards and terminals.
Methods

actions() | Get all actions currently in the replay memory, ordered by the time at which they were taken.
addSample(obs, action, reward, is_terminal, ...) | Store an (observation [for all subjects], action, reward, is_terminal) tuple in the dataset.
observations() | Get all observations currently in the replay memory, ordered by the time at which they were observed.
randomBatch(size, use_priority) | Return corresponding states, actions, rewards, terminal status, and next_states for size randomly chosen transitions.
rewards() | Get all rewards currently in the replay memory, ordered by the time at which they were received.
terminals() | Get all terminals currently in the replay memory, ordered by the time at which they were observed.
updatePriorities(priorities, rndValidIndices)
actions()
Get all actions currently in the replay memory, ordered by the time at which they were taken.
addSample(obs, action, reward, is_terminal, priority)
Store an (observation [for all subjects], action, reward, is_terminal) tuple in the dataset.

Parameters:
obs : ndarray
An ndarray(dtype='object') where obs[s] corresponds to the observation made on subject s before the agent took action [action].
action : int
The action taken after having observed [obs].
reward : float
The reward associated to taking this [action].
is_terminal : bool
Tells whether [action] led to a terminal state (i.e. corresponded to a terminal transition).
priority : float
The priority to be associated with the sample
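A minimal sketch of filling a replay memory by hand, assuming env is a hypothetical Environment instance with two observation subjects of the illustrated shapes (NeuralAgent performs this bookkeeping internally during training):

    import numpy as np
    from deer.agent import DataSet

    rng = np.random.RandomState(0)
    dataset = DataSet(env, random_state=rng, max_size=10000)   # env: hypothetical Environment

    # One entry per observation subject, as described above; the shapes are illustrative.
    obs = np.empty(2, dtype=object)
    obs[0] = np.zeros((8, 8))    # e.g. a 2D frame observed by subject 0
    obs[1] = 0.0                 # e.g. a scalar reading observed by subject 1

    dataset.addSample(obs, action=1, reward=0.5, is_terminal=False, priority=1)

    # All stored quantities can be read back, ordered by time.
    print(len(dataset.actions()), len(dataset.rewards()), len(dataset.terminals()))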
observations()
Get all observations currently in the replay memory, ordered by the time at which they were observed.
observations[s][i] corresponds to the observation made on subject s before the agent took actions()[i].
randomBatch(size, use_priority)
Return corresponding states, actions, rewards, terminal status, and next_states for size randomly chosen transitions. Note that if terminal[i] == True, then next_states[s][i] == np.zeros_like(states[s][i]) for each subject s.

Parameters:
size : int
Number of transitions to return.

Returns:
states : ndarray
An ndarray(size=number_of_subjects, dtype='object') where states[s] is a 2+D matrix of dimensions size x s.memorySize x "shape of a given observation for this subject". States were taken randomly in the data with the only constraint that they are complete regarding the histories for each observed subject.
actions : ndarray
An ndarray(size=number_of_subjects, dtype='int32') where actions[i] is the action taken after having observed states[:][i].
rewards : ndarray
An ndarray(size=number_of_subjects, dtype='float32') where rewards[i] is the reward obtained for taking actions[i-1].
next_states : ndarray
Same structure as states, but next_states[s][i] is guaranteed to be the information concerning the state following the one described by states[s][i] for each subject s.
terminals : ndarray
An ndarray(size=number_of_subjects, dtype='bool') where terminals[i] is True if actions[i] led to a terminal state and False otherwise.

Throws:
SliceError
If a batch of this size could not be built based on the current data set (not enough data or all trajectories are too short).
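Continuing the dataset sketch above, a batch can be drawn as follows, assuming the five values come back in the order listed under Returns:

    try:
        # 32 transitions, without prioritized sampling.
        states, actions, rewards, next_states, terminals = dataset.randomBatch(32, False)
    except Exception as err:
        # A SliceError is raised when the current data set cannot provide
        # a batch of this size (not enough data, or trajectories too short).
        print("Could not build a batch yet:", err)
    else:
        for s in range(len(states)):
            # next_states mirrors states; for terminal transitions the
            # corresponding next_states[s][i] entries are all zeros.
            assert states[s].shape == next_states[s].shape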
rewards()
Get all rewards currently in the replay memory, ordered by the time at which they were received.
terminals()
Get all terminals currently in the replay memory, ordered by the time at which they were observed.
terminals[i] is True if actions()[i] led to a terminal state (i.e. corresponded to a terminal transition), and False otherwise.
updatePriorities(priorities, rndValidIndices)