mirror of
https://github.com/huggingface/deep-rl-class.git
synced 2026-04-13 11:59:45 +08:00
Add standardization of the returns to obtain more stable training of Reinforce
This commit is contained in:
@@ -673,8 +673,13 @@
|
||||
" for t in range(n_steps)[::-1]:\n",
|
||||
" disc_return_t = (returns[0] if len(returns)>0 else 0)\n",
|
||||
" returns.appendleft( ) # complete here \n",
|
||||
" \n",
|
||||
"\n",
|
||||
" \n",
|
||||
" ## standardization of the returns is employed to make training more stable\n",
|
||||
" eps = np.finfo(np.float32).eps.item()\n",
|
||||
" ## eps is the smallest representable float, which is \n",
|
||||
" # added to the standard deviation of the returns to avoid numerical instabilities\n",
|
||||
" returns = torch.tensor(returns)\n",
|
||||
" returns = (returns - returns.mean()) / (returns.std() + eps)\n",
|
||||
" # Line 7:\n",
|
||||
" policy_loss = []\n",
|
||||
" for log_prob, disc_return in zip(saved_log_probs, returns):\n",
|
||||
@@ -766,6 +771,12 @@
|
||||
" disc_return_t = (returns[0] if len(returns)>0 else 0)\n",
|
||||
" returns.appendleft( gamma*disc_return_t + rewards[t] ) \n",
|
||||
" \n",
|
||||
" ## standardization of the returns is employed to make training more stable\n",
|
||||
" eps = np.finfo(np.float32).eps.item()\n",
|
||||
" ## eps is the smallest representable float, which is \n",
|
||||
" # added to the standard deviation of the returns to avoid numerical instabilities \n",
|
||||
" returns = torch.tensor(returns)\n",
|
||||
" returns = (returns - returns.mean()) / (returns.std() + eps)\n",
|
||||
" # Line 7:\n",
|
||||
" policy_loss = []\n",
|
||||
" for log_prob, disc_return in zip(saved_log_probs, returns):\n",
|
||||
|
||||
Reference in New Issue
Block a user