import pygame
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import os
import matplotlib.pyplot as plt
import pickle
import sys
from torch.utils.tensorboard import SummaryWriter
# ================================
# Configuration
# ================================
# Mode Selection: 'train' or 'play'
MODE = 'train'  # Change to 'play' to watch the AI play the game

# Constants
BLOCK_SIZE = 24   # Block size in pixels
GRID_WIDTH = 20   # Grid width in cells
GRID_HEIGHT = 40  # Grid height in cells
SCREEN_WIDTH = BLOCK_SIZE * GRID_WIDTH    # 24 * 20 = 480 pixels
SCREEN_HEIGHT = BLOCK_SIZE * GRID_HEIGHT  # 24 * 40 = 960 pixels
FPS = 20000  # Effectively uncapped so training runs as fast as possible; lower (e.g. 15-30) to watch the game

# Colors
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
GREEN = (0, 255, 0)
RED = (255, 0, 0)

# Directions
UP = (0, -1)
DOWN = (0, 1)
LEFT = (-1, 0)
RIGHT = (1, 0)

# Hyperparameters
LR = 1e-4              # Learning rate
GAMMA = 0.98           # Discount factor for long-term rewards
MEMORY_SIZE = 100_000  # Replay memory capacity
BATCH_SIZE = 256       # Batch size for stable gradient estimates
TARGET_UPDATE = 1000   # Steps before a hard target-network update (unused; soft updates via TAU are applied instead)
TAU = 0.005            # Polyak averaging factor for soft updates
EPS_START = 1.0        # Starting value of epsilon
EPS_END = 0.05         # Minimum value of epsilon
EPS_DECAY = 1000       # Number of episodes over which epsilon decays

# Prioritized Replay Parameters
PER_ALPHA = 0.6          # How much prioritization is used (0 - no prioritization, 1 - full prioritization)
PER_BETA_START = 0.4     # Initial value of beta for importance sampling
PER_BETA_FRAMES = 1_000  # Number of training steps over which beta is annealed from its initial value to 1

# Model saving paths (relative to the current working directory)
MODEL_PATH = 'snake_dqn.pth'
MEMORY_PATH = 'snake_memory.pkl'
AGENT_INFO_PATH = 'agent_info.pkl'
PLOTS_PATH = 'rewards.png'
# ================================
# Initialize Pygame
# ================================
pygame.init()
font = pygame.font.SysFont('arial', 20)


# ================================
# Helper Functions
# ================================
def get_file_path(filename):
    """
    Returns the absolute path for a given filename in the current working directory.
    """
    return os.path.join(os.getcwd(), filename)


def manhattan_distance(p1, p2):
    """
    Calculates the Manhattan distance between two points.
    """
    return abs(p1[0] - p2[0]) + abs(p1[1] - p2[1])
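
# Note: manhattan_distance is not referenced anywhere else in this script. A common
# optional use would be distance-based reward shaping inside SnakeGame.play_step
# (a sketch only, not part of the original reward scheme):
#
#   old_dist = manhattan_distance(self.snake[1], self.food)  # head position before the move
#   new_dist = manhattan_distance(self.head, self.food)
#   reward += 0.1 if new_dist < old_dist else -0.1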

# ================================
# Game Environment
# ================================
class SnakeGame:
    def __init__(self, screen=None, grid_width=GRID_WIDTH, grid_height=GRID_HEIGHT):
        self.screen = screen
        self.grid_width = grid_width
        self.grid_height = grid_height
        self.reset()

    def reset(self):
        """
        Resets the game to its initial state.
        """
        self.direction = random.choice([UP, DOWN, LEFT, RIGHT])
        self.head = (self.grid_width // 2, self.grid_height // 2)
        # Initialize snake with length 3, trailing away from the direction of travel
        self.snake = deque([
            self.head,
            (self.head[0] - self.direction[0], self.head[1] - self.direction[1]),
            (self.head[0] - 2 * self.direction[0], self.head[1] - 2 * self.direction[1]),
        ])
        self.spawn_food()
        self.score = 0
        self.frame_iteration = 0

    def spawn_food(self):
        """
        Spawns food at a random location not occupied by the snake.
        """
        while True:
            self.food = (random.randint(0, self.grid_width - 1), random.randint(0, self.grid_height - 1))
            if self.food not in self.snake:
                break

    def play_step(self, action):
        """
        Executes one step of the game based on the given action.

        Parameters:
            action (list): One-hot encoded action [straight, right, left]

        Returns:
            reward (float): Reward obtained after the action
            game_over (bool): Whether the game has ended
            score (int): Current score
        """
        self.frame_iteration += 1

        # Move the snake
        self.move(action)
        self.snake.appendleft(self.head)

        # Initialize reward
        reward = -0.1  # Step penalty to encourage efficiency
        game_over = False

        # Check if game over (collision, or stalling too long without eating)
        if self.is_collision() or self.frame_iteration > 100 * len(self.snake):
            game_over = True
            reward = -50.0
            return reward, game_over, self.score

        # Check if food is eaten
        if self.head == self.food:
            self.score += 1
            reward = 10.0
            self.spawn_food()
        else:
            self.snake.pop()

        # Render the game if a screen is provided
        if self.screen:
            self.render()

        return reward, game_over, self.score

    def is_collision(self, pt=None):
        """
        Checks if the given point collides with the boundaries or the snake itself.

        Parameters:
            pt (tuple): The point to check. If None, checks the snake's head.

        Returns:
            bool: True if a collision occurs, False otherwise
        """
        if pt is None:
            pt = self.head
        # Hits boundary
        if pt[0] < 0 or pt[0] >= self.grid_width or pt[1] < 0 or pt[1] >= self.grid_height:
            return True
        # Hits itself
        if pt in list(self.snake)[1:]:
            return True
        return False

    def move(self, action):
        """
        Updates the snake's direction and moves its head based on the action.

        Parameters:
            action (list): One-hot encoded action [straight, right, left]
        """
        # Directions in clockwise order; turning right/left steps through this list
        directions = [UP, RIGHT, DOWN, LEFT]
        idx = directions.index(self.direction)
        if np.array_equal(action, [1, 0, 0]):
            new_dir = self.direction             # Keep going straight
        elif np.array_equal(action, [0, 1, 0]):
            new_dir = directions[(idx + 1) % 4]  # Turn right
        else:  # [0, 0, 1]
            new_dir = directions[(idx - 1) % 4]  # Turn left
        self.direction = new_dir
        x, y = self.head
        dx, dy = self.direction
        self.head = (x + dx, y + dy)

    def calculate_distances(self, head):
        """
        Calculates the distance from the head to the walls and to the snake's body
        in all four directions.

        Parameters:
            head (tuple): Current head position (x, y)

        Returns:
            dict: Distances in 'left', 'right', 'up', 'down' directions
        """
        distances = {
            'left': head[0],
            'right': self.grid_width - head[0] - 1,
            'up': head[1],
            'down': self.grid_height - head[1] - 1
        }
        # Tighten the distances against the snake's own body
        for segment in list(self.snake)[1:]:
            if segment[1] == head[1]:
                if segment[0] < head[0]:
                    distances['left'] = min(distances['left'], head[0] - segment[0] - 1)
                else:
                    distances['right'] = min(distances['right'], segment[0] - head[0] - 1)
            elif segment[0] == head[0]:
                if segment[1] < head[1]:
                    distances['up'] = min(distances['up'], head[1] - segment[1] - 1)
                else:
                    distances['down'] = min(distances['down'], segment[1] - head[1] - 1)
        return distances

    def render(self):
        """
        Renders the game state on the screen.
        """
        if self.screen is None:
            return
        # Clear the screen
        self.screen.fill(BLACK)
        # Draw snake
        for pos in self.snake:
            pygame.draw.rect(self.screen, GREEN,
                             pygame.Rect(pos[0] * BLOCK_SIZE, pos[1] * BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE))
        # Draw food
        pygame.draw.rect(self.screen, RED,
                         pygame.Rect(self.food[0] * BLOCK_SIZE, self.food[1] * BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE))
        # Draw score
        score_text = font.render(f'Score: {self.score}', True, WHITE)
        self.screen.blit(score_text, [0, 0])
        # Update the display
        pygame.display.flip()
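
# Usage sketch (not executed here): the environment can be exercised on its own with
# random actions to sanity-check the mechanics before training, e.g.:
#
#   game = SnakeGame()          # headless (no screen attached)
#   for _ in range(100):
#       action = [0, 0, 0]
#       action[random.randint(0, 2)] = 1        # one-hot [straight, right, left]
#       reward, done, score = game.play_step(action)
#       if done:
#           game.reset()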

# ================================
# Neural Network Model
# ================================
class DuelingDQN(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        """
        Initializes the Dueling Deep Q-Network.

        Parameters:
            input_size (int): Number of input features
            hidden_size1 (int): Number of neurons in the first hidden layer
            hidden_size2 (int): Number of neurons in the second hidden layer
            output_size (int): Number of possible actions
        """
        super(DuelingDQN, self).__init__()
        self.feature = nn.Sequential(
            nn.Linear(input_size, hidden_size1),
            nn.LayerNorm(hidden_size1),  # Layer normalization for stability
            nn.ReLU(),
            nn.Linear(hidden_size1, hidden_size2),
            nn.LayerNorm(hidden_size2),
            nn.ReLU(),
        )
        # Value stream
        self.value = nn.Sequential(
            nn.Linear(hidden_size2, hidden_size2),
            nn.ReLU(),
            nn.Linear(hidden_size2, 1)
        )
        # Advantage stream
        self.advantage = nn.Sequential(
            nn.Linear(hidden_size2, hidden_size2),
            nn.ReLU(),
            nn.Linear(hidden_size2, output_size)
        )

    def forward(self, x):
        """
        Forward pass through the network.

        Parameters:
            x (torch.Tensor): Input tensor of shape (batch, input_size)

        Returns:
            torch.Tensor: Q-values for each action, shape (batch, output_size)
        """
        features = self.feature(x)
        value = self.value(features)
        advantage = self.advantage(features)
        # Subtract the per-sample mean advantage over the action dimension (not over the
        # whole batch) so the value/advantage decomposition stays identifiable
        q_vals = value + (advantage - advantage.mean(dim=1, keepdim=True))
        return q_vals
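
# The dueling head follows Wang et al. (2016): Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
# Subtracting the mean advantage makes the advantages average to zero, so V(s) carries the
# overall scale of the Q-values and the two streams remain identifiable. For example, with
# V(s) = 2.0 and raw advantages A = [0.5, -0.1, -0.4] (mean 0.0), the resulting Q-values
# are [2.5, 1.9, 1.6].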

# ================================
# Prioritized Replay Memory
# ================================
Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))


class PrioritizedReplayMemory:
    def __init__(self, capacity, alpha=PER_ALPHA):
        self.capacity = capacity
        self.alpha = alpha
        self.memory = []
        self.pos = 0
        self.priorities = np.zeros((capacity,), dtype=np.float32)

    def push(self, *args):
        """
        Saves a transition with maximum priority so it is sampled at least once.
        """
        max_priority = self.priorities.max() if self.memory else 1.0
        if len(self.memory) < self.capacity:
            self.memory.append(Transition(*args))
        else:
            self.memory[self.pos] = Transition(*args)
        self.priorities[self.pos] = max_priority
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=PER_BETA_START):
        """
        Samples a batch of transitions with probabilities proportional to their priorities.

        Returns:
            tuple: (samples, indices, weights)
        """
        if len(self.memory) == self.capacity:
            priorities = self.priorities
        else:
            priorities = self.priorities[:self.pos]
        probabilities = priorities ** self.alpha
        probabilities /= probabilities.sum()
        indices = np.random.choice(len(self.memory), batch_size, p=probabilities)
        samples = [self.memory[idx] for idx in indices]
        total = len(self.memory)
        # Importance-sampling weights, normalized by the largest weight in the batch
        weights = (total * probabilities[indices]) ** (-beta)
        weights /= weights.max()
        weights = np.array(weights, dtype=np.float32)
        return samples, indices, weights

    def update_priorities(self, batch_indices, batch_priorities):
        """
        Updates the priorities of sampled transitions.

        Parameters:
            batch_indices (list): List of sampled indices
            batch_priorities (list): List of new priorities
        """
        for idx, priority in zip(batch_indices, batch_priorities):
            self.priorities[idx] = priority

    def __len__(self):
        return len(self.memory)
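
# Prioritized experience replay (Schaul et al., 2015): a transition i with priority p_i is
# sampled with probability P(i) = p_i^alpha / sum_k p_k^alpha, and the resulting sampling
# bias is corrected with importance-sampling weights w_i = (N * P(i))^(-beta), normalized
# by max_i w_i. Here alpha is fixed (PER_ALPHA) while beta is annealed from PER_BETA_START
# towards 1.0 in Agent.train_long_memory, so the correction becomes exact late in training.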

# ================================
# Agent
# ================================
class Agent:
    def __init__(self):
        """
        Initializes the agent with necessary parameters and models.
        """
        self.n_games = 0
        self.epsilon = EPS_START  # Exploration rate starts high
        self.gamma = GAMMA
        self.memory = PrioritizedReplayMemory(MEMORY_SIZE)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize the online and target networks (input size 15, see get_state)
        self.model = DuelingDQN(15, 256, 128, 3).to(self.device)
        self.target_model = DuelingDQN(15, 256, 128, 3).to(self.device)
        self.update_target(tau=1.0)  # Initial hard copy of the weights

        # Optimizer and loss function
        self.optimizer = optim.AdamW(self.model.parameters(), lr=LR, weight_decay=1e-4)
        self.criterion = nn.MSELoss()

        # TensorBoard writer
        self.writer = SummaryWriter('runs/snake_dqn')

        # Beta annealing for prioritized replay
        self.beta = PER_BETA_START
        self.beta_increment = (1.0 - PER_BETA_START) / PER_BETA_FRAMES
        self.frame = 1  # Tracks frames for beta annealing

    def update_target(self, tau=TAU):
        """
        Soft-updates the target model's weights towards the current model's weights.

        Parameters:
            tau (float): Polyak averaging factor
        """
        for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

    def get_state(self, game):
        """
        Constructs the current state from the game.

        Parameters:
            game (SnakeGame): The current game instance

        Returns:
            np.array: The state representation (15 features)
        """
        head = game.head
        distances = game.calculate_distances(head)

        # Current direction
        dir_l = game.direction == LEFT
        dir_r = game.direction == RIGHT
        dir_u = game.direction == UP
        dir_d = game.direction == DOWN

        # Food location relative to the head
        food_left = game.food[0] < head[0]
        food_right = game.food[0] > head[0]
        food_up = game.food[1] < head[1]
        food_down = game.food[1] > head[1]

        state = [
            # Danger straight
            (game.direction == UP and game.is_collision((head[0], head[1] - 1))) or
            (game.direction == RIGHT and game.is_collision((head[0] + 1, head[1]))) or
            (game.direction == DOWN and game.is_collision((head[0], head[1] + 1))) or
            (game.direction == LEFT and game.is_collision((head[0] - 1, head[1]))),

            # Danger right
            (game.direction == UP and game.is_collision((head[0] + 1, head[1]))) or
            (game.direction == RIGHT and game.is_collision((head[0], head[1] + 1))) or
            (game.direction == DOWN and game.is_collision((head[0] - 1, head[1]))) or
            (game.direction == LEFT and game.is_collision((head[0], head[1] - 1))),

            # Danger left
            (game.direction == UP and game.is_collision((head[0] - 1, head[1]))) or
            (game.direction == RIGHT and game.is_collision((head[0], head[1] - 1))) or
            (game.direction == DOWN and game.is_collision((head[0] + 1, head[1]))) or
            (game.direction == LEFT and game.is_collision((head[0], head[1] + 1))),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Food location
            food_left,
            food_right,
            food_up,
            food_down,

            # Distances to obstacles (normalized by the larger grid dimension)
            distances['left'] / max(game.grid_width, game.grid_height),
            distances['right'] / max(game.grid_width, game.grid_height),
            distances['up'] / max(game.grid_width, game.grid_height),
            distances['down'] / max(game.grid_width, game.grid_height)
        ]

        # Convert to a numpy array
        state = np.array(state, dtype=np.float32)
        return state
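
    # State layout (15 features, matching the input_size of DuelingDQN):
    #   [0..2]   danger straight / right / left (booleans, relative to the current heading)
    #   [3..6]   current direction one-hot: left, right, up, down
    #   [7..10]  food relative to the head: left, right, up, down (booleans)
    #   [11..14] free distance to the nearest wall or body segment in each direction,
    #            normalized by max(grid_width, grid_height)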

    def remember(self, state, action, reward, next_state, done):
        """
        Stores the experience in memory with maximum priority.

        Parameters:
            state (np.array): Current state
            action (list): Action taken
            reward (float): Reward received
            next_state (np.array): Next state
            done (bool): Whether the game is over
        """
        transition = Transition(state, action, reward, next_state, done)
        self.memory.push(*transition)

    def train_long_memory(self):
        """
        Samples a batch from memory and trains the model.
        """
        if len(self.memory) < BATCH_SIZE:
            return  # Not enough samples to train

        # Anneal beta towards 1
        self.beta = min(1.0, self.beta + self.beta_increment)
        samples, indices, weights = self.memory.sample(BATCH_SIZE, self.beta)
        batch = Transition(*zip(*samples))

        states = torch.tensor(np.array(batch.state), dtype=torch.float).to(self.device)
        actions = torch.tensor(np.array([np.argmax(a) for a in batch.action]), dtype=torch.long).to(self.device)
        rewards = torch.tensor(np.array(batch.reward), dtype=torch.float).to(self.device)
        next_states = torch.tensor(np.array(batch.next_state), dtype=torch.float).to(self.device)
        dones = torch.tensor(np.array(batch.done), dtype=torch.bool).to(self.device)
        weights = torch.tensor(weights, dtype=torch.float).to(self.device)

        # Current Q values for the actions that were actually taken
        pred = self.model(states).gather(1, actions.view(-1, 1)).squeeze(1)

        # Double DQN target: select actions with the online model, evaluate them with the
        # target model; no gradients should flow through the target computation
        with torch.no_grad():
            next_actions = torch.argmax(self.model(next_states), dim=1).unsqueeze(1)
            target = self.target_model(next_states).gather(1, next_actions).squeeze(1)
            target_vals = rewards + (self.gamma * target * (~dones))

        # TD errors and loss weighted by the importance-sampling weights
        td_errors = target_vals - pred
        loss = (torch.pow(td_errors, 2) * weights).mean()

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # Gradient clipping
        self.optimizer.step()

        # Update priorities (the small epsilon avoids zero priority)
        new_priorities = torch.abs(td_errors).detach().cpu().numpy() + 1e-6
        self.memory.update_priorities(indices, new_priorities)

        # Log loss to TensorBoard
        self.writer.add_scalar('Loss/train', loss.item(), self.n_games)
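
    # Double DQN (van Hasselt et al., 2016) decouples action selection from evaluation:
    #   y = r + gamma * Q_target(s', argmax_a Q_online(s', a)) * (1 - done)
    # Picking the greedy next action with the online network but scoring it with the
    # target network reduces the overestimation bias of vanilla Q-learning targets.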

    def get_action(self, state, game, play_mode=False):
        """
        Decides the next action based on the current state and game instance.

        Parameters:
            state (np.array): Current state
            game (SnakeGame): Current game instance
            play_mode (bool): If True, disables exploration

        Returns:
            list: One-hot encoded action [straight, right, left]
        """
        final_move = [0, 0, 0]
        epsilon = 0.0 if play_mode else self.epsilon
        if random.random() < epsilon:
            # Explore: pick a random action that does not immediately collide
            valid_actions = self.get_valid_actions(state, game)
            move = random.choice(valid_actions)
        else:
            # Exploit: use the model to select the best action
            self.model.eval()  # Evaluation mode for inference
            state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(self.device)  # Add batch dimension
            with torch.no_grad():
                prediction = self.model(state_tensor)
            move = torch.argmax(prediction).item()
            self.model.train()  # Revert to training mode
        final_move[move] = 1
        return final_move

    def get_valid_actions(self, state, game):
        """
        Determines valid actions based on the current state and game to prevent immediate collisions.

        Parameters:
            state (np.array): Current state
            game (SnakeGame): Current game instance

        Returns:
            list: List of valid action indices
        """
        # Decode the current direction from the state's one-hot encoding; the pre-encoded
        # danger flags are ignored here in favor of the actual game state
        dir_l, dir_r, dir_u, dir_d = state[3], state[4], state[5], state[6]
        food_left, food_right, food_up, food_down = state[7], state[8], state[9], state[10]  # Unused here
        distances = state[11:15]  # Distances: left, right, up, down (unused here)

        # Map the direction encoding back to a direction vector
        direction_map = {
            (1, 0, 0, 0): LEFT,
            (0, 1, 0, 0): RIGHT,
            (0, 0, 1, 0): UP,
            (0, 0, 0, 1): DOWN
        }
        current_direction = direction_map.get((dir_l, dir_r, dir_u, dir_d), LEFT)  # Default to LEFT if not found

        # Possible moves are relative: [straight, right, left]
        directions = [UP, RIGHT, DOWN, LEFT]
        try:
            idx = directions.index(current_direction)
        except ValueError:
            idx = 0  # Default index if direction not found

        # Direction that would result from each action
        action_to_new_dir = {
            0: current_direction,           # Straight
            1: directions[(idx + 1) % 4],   # Right
            2: directions[(idx - 1) % 4]    # Left
        }

        # Keep only the actions whose next head position stays on the grid and off the body
        valid_actions = []
        for action, new_dir in action_to_new_dir.items():
            head = game.head
            new_head = (head[0] + new_dir[0], head[1] + new_dir[1])
            if (0 <= new_head[0] < game.grid_width and
                    0 <= new_head[1] < game.grid_height and
                    new_head not in list(game.snake)[1:]):
                valid_actions.append(action)

        # If all actions are invalid, allow all of them to prevent a deadlock
        return valid_actions if valid_actions else [0, 1, 2]

    def save_agent(self, checkpoint=None):
        """
        Saves the agent's model and memory to disk.

        Parameters:
            checkpoint (str): Optional checkpoint identifier
        """
        if checkpoint:
            path = get_file_path(f'snake_dqn_{checkpoint}.pth')
        else:
            path = get_file_path(MODEL_PATH)
        torch.save(self.model.state_dict(), path)
        with open(get_file_path(MEMORY_PATH), 'wb') as f:
            pickle.dump(self.memory, f)
        with open(get_file_path(AGENT_INFO_PATH), 'wb') as f:
            pickle.dump({'n_games': self.n_games, 'epsilon': self.epsilon, 'beta': self.beta}, f)

    def load_agent(self):
        """
        Loads the agent's model and memory from disk if they exist.
        """
        if os.path.exists(get_file_path(MODEL_PATH)):
            self.model.load_state_dict(torch.load(get_file_path(MODEL_PATH), map_location=self.device))
            self.model.eval()
        if os.path.exists(get_file_path(MEMORY_PATH)):
            with open(get_file_path(MEMORY_PATH), 'rb') as f:
                self.memory = pickle.load(f)
        if os.path.exists(get_file_path(AGENT_INFO_PATH)):
            with open(get_file_path(AGENT_INFO_PATH), 'rb') as f:
                info = pickle.load(f)
            self.n_games = info.get('n_games', 0)
            self.epsilon = info.get('epsilon', EPS_START)
            self.beta = info.get('beta', PER_BETA_START)
        self.writer.add_scalar('Epsilon', self.epsilon, self.n_games)

# ================================
# Plotting Function
# ================================
def plot_rewards_graph(rewards, avg_rewards):
    """
    Plots per-game rewards and the 100-game average reward with a fixed y-axis range (-100 to 100).
    """
    plt.figure(figsize=(16, 8))
    plt.title('Training Progress', fontsize=18)
    plt.xlabel('Games', fontsize=14)
    plt.ylabel('Reward', fontsize=14)

    # Plot per-game rewards, and the 100-game average at the games where it was computed
    plt.plot(rewards, label='Reward', linestyle='-', linewidth=1.5, alpha=0.8)
    plt.plot(range(100, 100 * len(avg_rewards) + 1, 100), avg_rewards,
             label='Avg Reward (100)', linestyle='--', linewidth=2)

    # Fixed y-axis limits
    plt.ylim(-100, 100)

    # Legend and grid for readability
    plt.legend(fontsize=12)
    plt.grid(alpha=0.4)

    # Save the plot and close the figure to free memory
    plt.tight_layout()
    plt.savefig(get_file_path(PLOTS_PATH))
    plt.close()


# ================================
# Save and Load Functions
# ================================
def save_agent(agent, checkpoint=None):
    """
    Saves the agent's model and memory to disk.

    Parameters:
        agent (Agent): The agent to save
        checkpoint (str): Optional checkpoint identifier
    """
    agent.save_agent(checkpoint)
    print("Agent saved successfully.")


def load_agent_func(agent):
    """
    Loads the agent's model and memory from disk if they exist.

    Parameters:
        agent (Agent): The agent to load
    """
    agent.load_agent()
    print("Agent loaded successfully.")

# ================================
# Training Function
# ================================
def train(agent, game):
    """
    The main training loop for the agent.

    Parameters:
        agent (Agent): The agent to train
        game (SnakeGame): The game environment
    """
    rewards = []
    avg_rewards = []
    clock = pygame.time.Clock()
    while True:
        state = agent.get_state(game)
        done = False
        score = 0
        episode_reward = 0.0
        while not done:
            # Handle Pygame events so the window can be closed during training
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    agent.writer.close()
                    pygame.quit()
                    sys.exit()

            # Agent takes an action
            action = agent.get_action(state, game)
            reward, done, score = game.play_step(action)
            next_state = agent.get_state(game)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward
            if done:
                break

            # Control the training speed
            clock.tick(FPS)

        # Record the total reward of the finished game
        rewards.append(episode_reward)

        # Update the number of games played
        agent.n_games += 1

        # Train the agent on a batch sampled from replay memory
        agent.train_long_memory()

        # Soft update of the target network
        agent.update_target(tau=TAU)

        # Decay epsilon linearly towards EPS_END
        agent.epsilon = max(EPS_END, agent.epsilon - (EPS_START - EPS_END) / EPS_DECAY)
        agent.writer.add_scalar('Epsilon', agent.epsilon, agent.n_games)

        # Log progress and plot rewards every 100 games
        if agent.n_games % 100 == 0:
            avg_reward = sum(rewards[-100:]) / 100
            avg_rewards.append(avg_reward)
            agent.writer.add_scalar('Average Reward', avg_reward, agent.n_games)
            print(f'Game: {agent.n_games}, Score: {score}, Avg Reward: {avg_reward:.2f}, Epsilon: {agent.epsilon:.2f}')
            plot_rewards_graph(rewards, avg_rewards)

        # Save the agent periodically
        if agent.n_games % 500 == 0:
            save_agent(agent)
            print("Model and memory saved.")

        # Reset the game after each game over
        game.reset()
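
# With the defaults above (EPS_START=1.0, EPS_END=0.05, EPS_DECAY=1000), epsilon drops by
# (1.0 - 0.05) / 1000 = 0.00095 per game and reaches its 0.05 floor after roughly 1000 games;
# checkpoints are written every 500 games, and plots/logs are produced every 100 games.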

# ================================
# Play Function
# ================================
def play(agent, game):
    """
    Lets the agent play the game using the trained model.

    Parameters:
        agent (Agent): The trained agent
        game (SnakeGame): The game environment
    """
    if not os.path.exists(get_file_path(MODEL_PATH)):
        print("No trained model found. Please train the agent first.")
        return

    clock = pygame.time.Clock()

    # Set up the Pygame display
    screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
    pygame.display.set_caption('Snake AI Play Mode')

    # Attach the screen to the game and start fresh
    game.screen = screen
    game.reset()

    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                agent.writer.close()
                pygame.quit()
                sys.exit()
            # Optional: press Escape to exit play mode
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_ESCAPE:
                    agent.writer.close()
                    pygame.quit()
                    sys.exit()

        state = agent.get_state(game)
        action = agent.get_action(state, game, play_mode=True)
        reward, done, score = game.play_step(action)
        if done:
            print(f'Game Over! Score: {score}')
            agent.writer.add_scalar('Score', score, agent.n_games)
            game.reset()

        # Render the game (with FPS set very high, lower it to watch comfortably)
        game.render()
        clock.tick(FPS)


# ================================
# Main Function
# ================================
def main():
    """
    The main entry point of the program.
    """
    agent = Agent()
    load_agent_func(agent)
    if MODE == 'train':
        # Initialize the Pygame display for training visualization
        screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        pygame.display.set_caption('Snake AI Training Mode')
        game = SnakeGame(screen)
        print("Starting training...")
        train(agent, game)
    elif MODE == 'play':
        game = SnakeGame()  # Created without a screen; play() attaches one
        print("Starting play mode...")
        play(agent, game)
    else:
        print("Invalid MODE! Please set MODE to 'train' or 'play'.")


if __name__ == '__main__':
    main()