thewindmage420
Viper.py
Jan 27th, 2025 (edited)
import pygame
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import os
import matplotlib.pyplot as plt
import pickle
import sys
from torch.utils.tensorboard import SummaryWriter

# ================================
# Configuration
# ================================

# Mode Selection: 'train' or 'play'
MODE = 'train'  # Change to 'play' to watch the AI play the game

# Constants
BLOCK_SIZE = 24              # Increased block size for larger visuals
GRID_WIDTH = 20              # Increased grid width for a wider game area
GRID_HEIGHT = 40             # Increased grid height for a taller game area
SCREEN_WIDTH = BLOCK_SIZE * GRID_WIDTH    # 24 * 20 = 480 pixels
SCREEN_HEIGHT = BLOCK_SIZE * GRID_HEIGHT  # 24 * 40 = 960 pixels
FPS = 20000                  # Frame cap; 20000 effectively runs training unthrottled, lower it (e.g. 20) to watch at a human speed

# Colors
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
GREEN = (0, 255, 0)
RED = (255, 0, 0)

# Directions
UP = (0, -1)
DOWN = (0, 1)
LEFT = (-1, 0)
RIGHT = (1, 0)

# Hyperparameters
LR = 1e-4                # Adjusted learning rate for stability
GAMMA = 0.98             # Adjusted discount factor for long-term rewards
MEMORY_SIZE = 100_000    # Increased memory size for more diverse experiences
BATCH_SIZE = 256         # Adjusted batch size for stable gradient estimates
TARGET_UPDATE = 1000     # Steps between hard target updates (unused here; the target network is soft-updated with TAU after every game)
TAU = 0.005              # Polyak averaging factor for soft updates
EPS_START = 1.0          # Starting value of epsilon
EPS_END = 0.05           # Minimum value of epsilon
EPS_DECAY = 1000         # Number of episodes over which epsilon decays

# Prioritized Replay Parameters
PER_ALPHA = 0.6          # How much prioritization is used (0 - no prioritization, 1 - full prioritization)
PER_BETA_START = 0.4     # Initial value of beta for importance-sampling
PER_BETA_FRAMES = 1_000  # Number of training updates (one per game) over which beta is annealed from its initial value to 1

# Model saving path (Use relative paths)
MODEL_PATH = 'snake_dqn.pth'
MEMORY_PATH = 'snake_memory.pkl'
AGENT_INFO_PATH = 'agent_info.pkl'
PLOTS_PATH = 'rewards.png'

# ================================
# Initialize Pygame
# ================================

pygame.init()
font = pygame.font.SysFont('arial', 20)

# ================================
# Helper Functions
# ================================

def get_file_path(filename):
    """
    Returns the absolute path for a given filename in the current working directory.
    """
    return os.path.join(os.getcwd(), filename)

def manhattan_distance(p1, p2):
    """
    Calculates the Manhattan distance between two points.
    """
    return abs(p1[0] - p2[0]) + abs(p1[1] - p2[1])
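
# Worked example (illustrative): manhattan_distance((2, 3), (5, 1)) == |2 - 5| + |3 - 1| == 5.
# Note this helper is not referenced elsewhere in the file.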

# ================================
# Game Environment
# ================================

class SnakeGame:
    def __init__(self, screen=None, grid_width=GRID_WIDTH, grid_height=GRID_HEIGHT):
        self.screen = screen
        self.grid_width = grid_width
        self.grid_height = grid_height
        self.reset()

    def reset(self):
        """
        Resets the game to its initial state.
        """
        self.direction = random.choice([UP, DOWN, LEFT, RIGHT])
        self.head = (self.grid_width // 2, self.grid_height // 2)
        # Initialize snake with length 3
        self.snake = deque([
            self.head,
            (self.head[0] - self.direction[0], self.head[1] - self.direction[1]),
            (self.head[0] - 2 * self.direction[0], self.head[1] - 2 * self.direction[1]),
        ])
        self.spawn_food()
        self.score = 0
        self.frame_iteration = 0

    def spawn_food(self):
        """
        Spawns food at a random location not occupied by the snake.
        """
        while True:
            self.food = (random.randint(0, self.grid_width - 1), random.randint(0, self.grid_height - 1))
            if self.food not in self.snake:
                break

    def play_step(self, action):
        """
        Executes one step of the game based on the given action.

        Parameters:
            action (list): One-hot encoded action [straight, right, left]

        Returns:
            reward (float): Reward obtained after the action
            game_over (bool): Whether the game has ended
            score (int): Current score
        """
        self.frame_iteration += 1

        # Move the snake
        self.move(action)
        self.snake.appendleft(self.head)

        # Initialize reward
        reward = -0.1  # Step penalty to encourage efficiency
        game_over = False

        # Check if game over
        if self.is_collision() or self.frame_iteration > 100 * len(self.snake):
            game_over = True
            reward = -50.0
            return reward, game_over, self.score

        # Check if food is eaten
        if self.head == self.food:
            self.score += 1
            reward = 10.0
            self.spawn_food()
        else:
            self.snake.pop()

        # Render the game if screen is provided
        if self.screen:
            self.render()

        return reward, game_over, self.score

    def is_collision(self, pt=None):
        """
        Checks if the given point collides with the boundaries or the snake itself.

        Parameters:
            pt (tuple): The point to check. If None, checks the snake's head.

        Returns:
            bool: True if collision occurs, False otherwise
        """
        if pt is None:
            pt = self.head
        # Hits boundary
        if pt[0] < 0 or pt[0] >= self.grid_width or pt[1] < 0 or pt[1] >= self.grid_height:
            return True
        # Hits itself
        if pt in list(self.snake)[1:]:
            return True
        return False

    def move(self, action):
        """
        Updates the snake's direction and moves its head based on the action.

        Parameters:
            action (list): One-hot encoded action [straight, right, left]
        """
        # [straight, right, left]
        directions = [UP, RIGHT, DOWN, LEFT]
        idx = directions.index(self.direction)
        if np.array_equal(action, [1, 0, 0]):
            new_dir = self.direction
        elif np.array_equal(action, [0, 1, 0]):
            new_dir = directions[(idx + 1) % 4]
        else:  # [0, 0, 1]
            new_dir = directions[(idx - 1) % 4]
        self.direction = new_dir
        x, y = self.head
        dx, dy = self.direction
        self.head = (x + dx, y + dy)
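
    # Example of the relative-action scheme above (illustrative): facing UP, [1, 0, 0]
    # keeps the snake on UP, [0, 1, 0] turns it to RIGHT, and [0, 0, 1] turns it to LEFT.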

    def calculate_distances(self, head):
        """
        Calculates the distance from the head to the walls and to the snake's body in all four directions.

        Parameters:
            head (tuple): Current head position (x, y)

        Returns:
            dict: Distances in 'left', 'right', 'up', 'down' directions
        """
        distances = {
            'left': head[0],
            'right': self.grid_width - head[0] - 1,
            'up': head[1],
            'down': self.grid_height - head[1] - 1
        }

        # Calculate distance to the snake's body
        for segment in list(self.snake)[1:]:
            if segment[1] == head[1]:
                if segment[0] < head[0]:
                    distances['left'] = min(distances['left'], head[0] - segment[0] - 1)
                else:
                    distances['right'] = min(distances['right'], segment[0] - head[0] - 1)
            elif segment[0] == head[0]:
                if segment[1] < head[1]:
                    distances['up'] = min(distances['up'], head[1] - segment[1] - 1)
                else:
                    distances['down'] = min(distances['down'], segment[1] - head[1] - 1)

        return distances

    def render(self):
        """
        Renders the game state on the screen.
        """
        if self.screen is None:
            return
        # Clear the screen
        self.screen.fill(BLACK)

        # Draw snake
        for pos in self.snake:
            pygame.draw.rect(self.screen, GREEN, pygame.Rect(pos[0] * BLOCK_SIZE, pos[1] * BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE))

        # Draw food
        pygame.draw.rect(self.screen, RED, pygame.Rect(self.food[0] * BLOCK_SIZE, self.food[1] * BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE))

        # Draw score
        score_text = font.render(f'Score: {self.score}', True, WHITE)
        self.screen.blit(score_text, [0, 0])

        # Update the display
        pygame.display.flip()

# ================================
# Neural Network Model
# ================================

class DuelingDQN(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        """
        Initializes the Dueling Deep Q-Network.

        Parameters:
            input_size (int): Number of input features
            hidden_size1 (int): Number of neurons in the first hidden layer
            hidden_size2 (int): Number of neurons in the second hidden layer
            output_size (int): Number of possible actions
        """
        super(DuelingDQN, self).__init__()
        self.feature = nn.Sequential(
            nn.Linear(input_size, hidden_size1),
            nn.LayerNorm(hidden_size1),   # Layer Normalization for stability
            nn.ReLU(),
            nn.Linear(hidden_size1, hidden_size2),
            nn.LayerNorm(hidden_size2),
            nn.ReLU(),
        )
        # Value stream
        self.value = nn.Sequential(
            nn.Linear(hidden_size2, hidden_size2),
            nn.ReLU(),
            nn.Linear(hidden_size2, 1)
        )
        # Advantage stream
        self.advantage = nn.Sequential(
            nn.Linear(hidden_size2, hidden_size2),
            nn.ReLU(),
            nn.Linear(hidden_size2, output_size)
        )

    def forward(self, x):
        """
        Forward pass through the network.

        Parameters:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Output tensor representing Q-values for each action
        """
        features = self.feature(x)
        value = self.value(features)
        advantage = self.advantage(features)
        # Subtract the per-sample advantage mean (not the batch-wide mean) so the
        # value and advantage streams stay identifiable
        q_vals = value + (advantage - advantage.mean(dim=1, keepdim=True))
        return q_vals

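# Illustrative only, not called anywhere: a quick sanity check of the network's output
# shape, using the same sizes the Agent below uses (15 state features, 3 actions).
def _duelingdqn_shape_check():
    net = DuelingDQN(15, 256, 128, 3)
    dummy = torch.zeros(4, 15)   # a batch of 4 fake states
    q = net(dummy)               # -> torch.Size([4, 3]): one Q-value per action
    assert q.shape == (4, 3)
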
# ================================
# Prioritized Replay Memory
# ================================

Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))

class PrioritizedReplayMemory:
    def __init__(self, capacity, alpha=PER_ALPHA):
        self.capacity = capacity
        self.alpha = alpha
        self.memory = []
        self.pos = 0
        self.priorities = np.zeros((capacity,), dtype=np.float32)

    def push(self, *args):
        """
        Saves a transition with maximum priority.
        """
        max_priority = self.priorities.max() if self.memory else 1.0

        if len(self.memory) < self.capacity:
            self.memory.append(Transition(*args))
        else:
            self.memory[self.pos] = Transition(*args)

        self.priorities[self.pos] = max_priority
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=PER_BETA_START):
        """
        Samples a batch of transitions with probabilities proportional to their priorities.

        Returns:
            tuple: (samples, indices, weights)
        """
        if len(self.memory) == self.capacity:
            priorities = self.priorities
        else:
            priorities = self.priorities[:self.pos]

        probabilities = priorities ** self.alpha
        probabilities /= probabilities.sum()

        indices = np.random.choice(len(self.memory), batch_size, p=probabilities)
        samples = [self.memory[idx] for idx in indices]

        total = len(self.memory)
        weights = (total * probabilities[indices]) ** (-beta)
        weights /= weights.max()
        weights = np.array(weights, dtype=np.float32)

        return samples, indices, weights

    def update_priorities(self, batch_indices, batch_priorities):
        """
        Updates the priorities of sampled transitions.

        Parameters:
            batch_indices (list): List of indices sampled
            batch_priorities (list): List of new priorities
        """
        for idx, priority in zip(batch_indices, batch_priorities):
            self.priorities[idx] = priority

    def __len__(self):
        return len(self.memory)

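# Minimal usage sketch for the buffer above (illustrative only, not called anywhere).
# New transitions enter with the current maximum priority; after a training step their
# priorities are refreshed from the TD errors via update_priorities().
def _replay_memory_example():
    mem = PrioritizedReplayMemory(capacity=1000)
    dummy_state = np.zeros(15, dtype=np.float32)
    for _ in range(300):
        mem.push(dummy_state, [1, 0, 0], 0.0, dummy_state, False)
    samples, indices, weights = mem.sample(batch_size=32, beta=0.4)
    mem.update_priorities(indices, np.ones(len(indices), dtype=np.float32))
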
# ================================
# Agent
# ================================

class Agent:
    def __init__(self):
        """
        Initializes the agent with necessary parameters and models.
        """
        self.n_games = 0
        self.epsilon = EPS_START                # Exploration rate starts high for exploration
        self.gamma = GAMMA
        self.memory = PrioritizedReplayMemory(MEMORY_SIZE)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize the neural network
        self.model = DuelingDQN(15, 256, 128, 3).to(self.device)  # Updated input_size to 15
        self.target_model = DuelingDQN(15, 256, 128, 3).to(self.device)  # Updated input_size to 15
        self.update_target(tau=1.0)      # Initial hard update

        # Optimizer and loss function
        self.optimizer = optim.AdamW(self.model.parameters(), lr=LR, weight_decay=1e-4)
        self.criterion = nn.MSELoss()

        # TensorBoard writer
        self.writer = SummaryWriter('runs/snake_dqn')

        # For adaptive epsilon
        self.beta = PER_BETA_START
        self.beta_increment = (1.0 - PER_BETA_START) / PER_BETA_FRAMES
        self.frame = 1  # To track frames for beta annealing

    def update_target(self, tau=TAU):
        """
        Soft updates the target model's weights towards the current model's weights.

        Parameters:
            tau (float): Polyak averaging factor
        """
        for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

    def get_state(self, game):
        """
        Constructs the current state from the game.

        Parameters:
            game (SnakeGame): The current game instance

        Returns:
            np.array: The state representation
        """
        head = game.head
        distances = game.calculate_distances(head)

        # Current direction
        dir_l = game.direction == LEFT
        dir_r = game.direction == RIGHT
        dir_u = game.direction == UP
        dir_d = game.direction == DOWN

        # Food location relative to head
        food_left = game.food[0] < head[0]
        food_right = game.food[0] > head[0]
        food_up = game.food[1] < head[1]
        food_down = game.food[1] > head[1]

        state = [
            # Danger straight
            (game.direction == UP and game.is_collision((head[0], head[1] - 1))) or
            (game.direction == RIGHT and game.is_collision((head[0] + 1, head[1]))) or
            (game.direction == DOWN and game.is_collision((head[0], head[1] + 1))) or
            (game.direction == LEFT and game.is_collision((head[0] - 1, head[1]))),

            # Danger right
            (game.direction == UP and game.is_collision((head[0] + 1, head[1]))) or
            (game.direction == RIGHT and game.is_collision((head[0], head[1] + 1))) or
            (game.direction == DOWN and game.is_collision((head[0] - 1, head[1]))) or
            (game.direction == LEFT and game.is_collision((head[0], head[1] - 1))),

            # Danger left
            (game.direction == UP and game.is_collision((head[0] - 1, head[1]))) or
            (game.direction == RIGHT and game.is_collision((head[0], head[1] - 1))) or
            (game.direction == DOWN and game.is_collision((head[0] + 1, head[1]))) or
            (game.direction == LEFT and game.is_collision((head[0], head[1] + 1))),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Food location
            food_left,
            food_right,
            food_up,
            food_down,

            # Distances to obstacles (normalized)
            distances['left'] / max(game.grid_width, game.grid_height),
            distances['right'] / max(game.grid_width, game.grid_height),
            distances['up'] / max(game.grid_width, game.grid_height),
            distances['down'] / max(game.grid_width, game.grid_height)
        ]

        # Convert to numpy array
        state = np.array(state, dtype=np.float32)
        return state

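    # For reference, the 15 state features built above, in order:
    #   [0..2]   danger straight / right / left (would the next cell collide?)
    #   [3..6]   current direction one-hot: left / right / up / down
    #   [7..10]  food position relative to the head: left / right / up / down
    #   [11..14] normalized free distance to the nearest wall or body segment: left / right / up / down
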
    def remember(self, state, action, reward, next_state, done):
        """
        Stores the experience in memory with maximum priority.

        Parameters:
            state (np.array): Current state
            action (list): Action taken
            reward (float): Reward received
            next_state (np.array): Next state
            done (bool): Whether the game is over
        """
        transition = Transition(state, action, reward, next_state, done)
        self.memory.push(*transition)

    def train_long_memory(self):
        """
        Samples a batch from memory and trains the model.
        """
        if len(self.memory) < BATCH_SIZE:
            return  # Not enough samples to train

        # Anneal beta towards 1
        beta = min(1.0, self.beta + self.beta_increment)
        self.beta = beta

        samples, indices, weights = self.memory.sample(BATCH_SIZE, self.beta)
        batch = Transition(*zip(*samples))

        states = torch.tensor(np.array(batch.state), dtype=torch.float).to(self.device)
        actions = torch.tensor(np.array([np.argmax(a) for a in batch.action]), dtype=torch.long).to(self.device)
        rewards = torch.tensor(np.array(batch.reward), dtype=torch.float).to(self.device)
        next_states = torch.tensor(np.array(batch.next_state), dtype=torch.float).to(self.device)
        dones = torch.tensor(np.array(batch.done), dtype=torch.bool).to(self.device)
        weights = torch.tensor(weights, dtype=torch.float).to(self.device)

        # Current Q values
        pred = self.model(states).gather(1, actions.view(-1, 1)).squeeze(1)

        # Double DQN: select the next action with the online model, evaluate it with the
        # target model; no gradients should flow through the target computation
        with torch.no_grad():
            next_actions = torch.argmax(self.model(next_states), dim=1).unsqueeze(1)
            target = self.target_model(next_states).gather(1, next_actions).squeeze(1)
            target_vals = rewards + (self.gamma * target * (~dones))

        # Compute TD errors
        td_errors = target_vals - pred

        # Compute loss with importance sampling weights
        loss = (torch.pow(td_errors, 2) * weights).mean()

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Update priorities
        new_priorities = torch.abs(td_errors).detach().cpu().numpy() + 1e-6  # Avoid zero priority
        self.memory.update_priorities(indices, new_priorities)

        # Log loss to TensorBoard
        self.writer.add_scalar('Loss/train', loss.item(), self.n_games)

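    # For reference, the update computed above (Double DQN with prioritized-replay weights):
    #   a* = argmax_a Q_online(s', a)
    #   y  = r + GAMMA * Q_target(s', a*) * (1 - done)
    #   loss = mean_i( w_i * (y_i - Q_online(s_i, a_i))^2 ), with w_i the importance-sampling weights
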
    def get_action(self, state, game, play_mode=False):
        """
        Decides the next action based on the current state and game instance.

        Parameters:
            state (np.array): Current state
            game (SnakeGame): Current game instance
            play_mode (bool): If True, disables exploration

        Returns:
            list: One-hot encoded action [straight, right, left]
        """
        final_move = [0, 0, 0]
        epsilon = 0.0 if play_mode else self.epsilon

        if random.random() < epsilon:
            # Explore: Random valid action
            valid_actions = self.get_valid_actions(state, game)
            move = random.choice(valid_actions)
        else:
            # Exploit: Use the model to select the best action
            self.model.eval()  # Set model to evaluation mode for inference
            state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(self.device)  # Add batch dimension
            with torch.no_grad():
                prediction = self.model(state_tensor)
            move = torch.argmax(prediction).item()
            self.model.train()  # Revert back to training mode

        final_move[move] = 1
        return final_move

    def get_valid_actions(self, state, game):
        """
        Determines valid actions based on the current state and game to prevent immediate collisions.

        Parameters:
            state (np.array): Current state
            game (SnakeGame): Current game instance

        Returns:
            list: List of valid action indices
        """
        # Decode the state to get current direction and danger flags
        # Not using pre-encoded danger flags; instead, using the actual game state

        # Current direction
        dir_l, dir_r, dir_u, dir_d = state[3], state[4], state[5], state[6]
        food_left, food_right, food_up, food_down = state[7], state[8], state[9], state[10]
        distances = state[11:15]  # Distances: left, right, up, down

        # Map direction encoding to corresponding direction vectors
        direction_map = {
            (1, 0, 0, 0): LEFT,
            (0, 1, 0, 0): RIGHT,
            (0, 0, 1, 0): UP,
            (0, 0, 0, 1): DOWN
        }
        current_direction = direction_map.get((dir_l, dir_r, dir_u, dir_d), LEFT)  # Default to LEFT if not found

        # All possible moves: [straight, right, left]
        directions = [UP, RIGHT, DOWN, LEFT]
        try:
            idx = directions.index(current_direction)
        except ValueError:
            idx = 0  # Default index if direction not found

        # Predict new head positions for each action
        action_to_new_dir = {
            0: current_direction,          # Straight
            1: directions[(idx + 1) % 4],  # Right
            2: directions[(idx - 1) % 4]   # Left
        }

        # Validate each action
        valid_actions = []
        for action, new_dir in action_to_new_dir.items():
            # Predict new head position
            head = game.head
            new_head = (head[0] + new_dir[0], head[1] + new_dir[1])

            # Check for boundary and self-collision
            if 0 <= new_head[0] < game.grid_width and 0 <= new_head[1] < game.grid_height and new_head not in list(game.snake)[1:]:
                valid_actions.append(action)

        # If all actions are invalid, allow all to prevent deadlock
        return valid_actions if valid_actions else [0, 1, 2]

    def save_agent(self, checkpoint=None):
        """
        Saves the agent's model and memory to disk.

        Parameters:
            checkpoint (str): Optional checkpoint identifier
        """
        if checkpoint:
            path = get_file_path(f'snake_dqn_{checkpoint}.pth')
        else:
            path = get_file_path(MODEL_PATH)
        torch.save(self.model.state_dict(), path)
        with open(get_file_path(MEMORY_PATH), 'wb') as f:
            pickle.dump(self.memory, f)
        with open(get_file_path(AGENT_INFO_PATH), 'wb') as f:
            pickle.dump({'n_games': self.n_games, 'epsilon': self.epsilon, 'beta': self.beta}, f)

    def load_agent(self):
        """
        Loads the agent's model and memory from disk if they exist.
        """
        if os.path.exists(get_file_path(MODEL_PATH)):
            self.model.load_state_dict(torch.load(get_file_path(MODEL_PATH), map_location=self.device))
            self.model.eval()
        if os.path.exists(get_file_path(MEMORY_PATH)):
            with open(get_file_path(MEMORY_PATH), 'rb') as f:
                self.memory = pickle.load(f)
        if os.path.exists(get_file_path(AGENT_INFO_PATH)):
            with open(get_file_path(AGENT_INFO_PATH), 'rb') as f:
                info = pickle.load(f)
                self.n_games = info.get('n_games', 0)
                self.epsilon = info.get('epsilon', EPS_START)
                self.beta = info.get('beta', PER_BETA_START)
                self.writer.add_scalar('Epsilon', self.epsilon, self.n_games)

# ================================
# Plotting Function
# ================================

def plot_rewards_graph(rewards, avg_rewards):
    """
    Plots per-game rewards and the 100-game average reward as line graphs with a fixed y-axis range (-100 to 100).
    """
    plt.figure(figsize=(16, 8))  # Create a new figure
    plt.title('Training Progress', fontsize=18)
    plt.xlabel('Games', fontsize=14)
    plt.ylabel('Reward', fontsize=14)

    # Plot per-game rewards; the average series has one point per 100 games, so place it at the matching x positions
    plt.plot(rewards, label='Reward', linestyle='-', linewidth=1.5, alpha=0.8)
    if avg_rewards:
        plt.plot(range(100, 100 * len(avg_rewards) + 1, 100), avg_rewards,
                 label='Avg Reward (100)', linestyle='--', linewidth=2)

    # Set y-axis limits to a fixed range
    plt.ylim(-100, 100)

    # Add a legend and grid for better readability
    plt.legend(fontsize=12)
    plt.grid(alpha=0.4)

    # Save the plot
    plt.tight_layout()
    plt.savefig(get_file_path(PLOTS_PATH))

    # Close the figure to free up memory
    plt.close()

# ================================
# Save and Load Functions
# ================================

def save_agent(agent, checkpoint=None):
    """
    Saves the agent's model and memory to disk.

    Parameters:
        agent (Agent): The agent to save
        checkpoint (str): Optional checkpoint identifier
    """
    agent.save_agent(checkpoint)
    print("Agent saved successfully.")

def load_agent_func(agent):
    """
    Loads the agent's model and memory from disk if they exist.

    Parameters:
        agent (Agent): The agent to load
    """
    agent.load_agent()
    print("Agent loaded successfully.")

# ================================
# Training Function
# ================================

def train(agent, game):
    """
    The main training loop for the agent.

    Parameters:
        agent (Agent): The agent to train
        game (SnakeGame): The game environment
    """
    rewards = []       # total reward per game, so the plot's x-axis matches 'Games'
    avg_rewards = []
    clock = pygame.time.Clock()

    while True:
        state = agent.get_state(game)
        done = False
        score = 0
        episode_reward = 0.0

        while not done:
            # Handle Pygame events to allow exiting during training
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    agent.writer.close()
                    pygame.quit()
                    sys.exit()

            # Agent takes an action
            action = agent.get_action(state, game)
            reward, done, score = game.play_step(action)
            next_state = agent.get_state(game)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward

            if done:
                break

            # Control the training speed
            clock.tick(FPS)

        # Record the total reward of the finished game and update the game counter
        rewards.append(episode_reward)
        agent.n_games += 1

        # Train the agent with the experience of the current game
        agent.train_long_memory()

        # Soft update target network
        agent.update_target(tau=TAU)

        # Decay epsilon
        agent.epsilon = max(EPS_END, agent.epsilon - (EPS_START - EPS_END) / EPS_DECAY)
        agent.writer.add_scalar('Epsilon', agent.epsilon, agent.n_games)

        # Log progress and plot rewards every 100 games
        if agent.n_games % 100 == 0:
            avg_reward = sum(rewards[-100:]) / 100
            avg_rewards.append(avg_reward)
            agent.writer.add_scalar('Average Reward', avg_reward, agent.n_games)
            print(f'Game: {agent.n_games}, Score: {score}, Avg Reward: {avg_reward:.2f}, Epsilon: {agent.epsilon:.2f}')
            plot_rewards_graph(rewards, avg_rewards)

        # Save the agent periodically
        if agent.n_games % 500 == 0:
            save_agent(agent)
            print("Model and memory saved.")

        # Reset the game after each game over
        game.reset()

# ================================
# Play Function
# ================================

def play(agent, game):
    """
    Allows the agent to play the game using the trained model.

    Parameters:
        agent (Agent): The trained agent
        game (SnakeGame): The game environment
    """
    if not os.path.exists(get_file_path(MODEL_PATH)):
        print("No trained model found. Please train the agent first.")
        return
    clock = pygame.time.Clock()

    # Setup Pygame display
    screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
    pygame.display.set_caption('Snake AI Play Mode')

    # Initialize the game with the screen
    game.screen = screen
    game.reset()

    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                agent.writer.close()
                pygame.quit()
                sys.exit()
            # Optional: Add a key to exit play mode
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_ESCAPE:
                    agent.writer.close()
                    pygame.quit()
                    sys.exit()

        state = agent.get_state(game)
        action = agent.get_action(state, game, play_mode=True)
        reward, done, score = game.play_step(action)

        if done:
            print(f'Game Over! Score: {score}')
            agent.writer.add_scalar('Score', score, agent.n_games)
            game.reset()

        # Render the game
        game.render()

        clock.tick(FPS)

# ================================
# Main Function
# ================================

def main():
    """
    The main entry point of the program.
    """
    agent = Agent()
    load_agent_func(agent)

    if MODE == 'train':
        # Initialize Pygame display for training visualization
        screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        pygame.display.set_caption('Snake AI Training Mode')
        game = SnakeGame(screen)
        print("Starting training...")
        train(agent, game)
    elif MODE == 'play':
        game = SnakeGame()  # Create game without screen initially
        print("Starting play mode...")
        play(agent, game)
    else:
        print("Invalid MODE! Please set MODE to 'train' or 'play'.")

if __name__ == '__main__':
    main()

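# Usage: set MODE at the top of this file to 'train' or 'play', then run `python Viper.py`.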