Advertisement
thewindmage420

Tetris Reward Function

Jan 23rd, 2025 (edited)
117
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.90 KB | None | 0 0
  1. def get_reward_and_next_state(
  2.     agent,
  3.     grid,
  4.     lines_cleared,
  5.     total_lines_cleared,
  6.     game_over,
  7.     rotations=0,
  8.     moved_horizontally=False,
  9.     moved_down=False,
  10.     locked=False,
  11.     tetrimino_shape=None,
  12.     x=0,
  13.     y=0,
  14.     score=0,
  15.     pieces_placed=0,
  16.     level=1,
  17.     device=None
  18. ):
  19.     """
  20.    Revised Tetris reward function to strongly discourage tower-building and incentivize rotations.
  21.    """
  22.     # 0) Step penalty: small negative each frame to avoid stalling
  23.     reward = -1
  24.  
  25.     # 1) LINE CLEAR REWARD
  26.     line_clear_reward_map = {1: 1000, 2: 3000, 3: 7000, 4: 14000}
  27.     line_clear_points = line_clear_reward_map.get(lines_cleared, 0)
  28.     reward += line_clear_points
  29.  
  30.     if lines_cleared > 0:
  31.         print(f"[DEBUG] Lines cleared: {lines_cleared} -> +{line_clear_points}")
  32.  
  33.     # 2) ROTATION REWARD: Reward for performing rotations
  34.     if rotations > 3:
  35.         rotation_bonus = rotations * -5  # Adjust the multiplier as needed
  36.         reward += rotation_bonus
  37.         print(f"[DEBUG] Rotations performed: {rotations} -> +{rotation_bonus} reward")
  38.  
  39.     # After state updates
  40.     after_holes = count_holes(grid)
  41.     after_max_height, after_col_heights = get_max_height_and_column_heights(grid)
  42.     after_bumpiness = get_bumpiness(after_col_heights)
  43.  
  44.     if locked:
  45.         # Before locking, remove the locked piece to measure before state
  46.         grid_without = [row[:] for row in grid]
  47.         for i, row_block in enumerate(tetrimino_shape):
  48.             for j, block in enumerate(row_block):
  49.                 if block and 0 <= (y + i) < GRID_HEIGHT and 0 <= (x + j) < GRID_WIDTH:
  50.                     grid_without[y + i][x + j] = 0
  51.  
  52.         before_holes = count_holes(grid_without)
  53.         before_max_height, before_col_heights = get_max_height_and_column_heights(grid_without)
  54.         before_bumpiness = get_bumpiness(before_col_heights)
  55.  
  56.         # 3) LARGE PENALTIES FOR HEIGHT, HOLES, BUMPINESS
  57.         height_penalty = 60.0
  58.         hole_penalty   = 40.0
  59.         bump_penalty   = 30.0
  60.  
  61.         reward -= after_max_height * height_penalty
  62.         reward -= after_holes * hole_penalty
  63.         reward -= after_bumpiness * bump_penalty
  64.  
  65.         # 4) HARSH PENALTY IF ANY COLUMN ABOVE 75%
  66.         if after_max_height >= 0.75 * GRID_HEIGHT:
  67.             reward -= 500
  68.             print("[DEBUG] Tall column penalty: -500")
  69.  
  70.         # 5) REWARD "IMPROVEMENTS" ONLY IF THE BOARD DIDN'T WORSEN
  71.         board_not_worse = (
  72.             after_holes <= before_holes and
  73.             after_max_height <= before_max_height and
  74.             after_bumpiness <= before_bumpiness
  75.         )
  76.         if board_not_worse:
  77.             improvement_bonus = 20
  78.             if (after_holes < before_holes) or (after_max_height < before_max_height) or (after_bumpiness < before_bumpiness):
  79.                 improvement_bonus += 20
  80.             reward += improvement_bonus
  81.             print(f"[DEBUG] Board not worse => +{improvement_bonus} improvement bonus")
  82.  
  83.         # 6) Add partial fill bonus
  84.         partial_fill = 0
  85.         for row in grid:
  86.             fill_count = sum(1 for cell in row if cell != 0)
  87.             if fill_count >= GRID_WIDTH - 2:
  88.                 partial_fill += (fill_count - (GRID_WIDTH - 2)) * 5
  89.         if partial_fill > 0:
  90.             reward += partial_fill
  91.             print(f"[DEBUG] Partial fill bonus => +{partial_fill}")
  92.  
  93.         # 7) Extra penalty for locking
  94.         reward -= 5
  95.  
  96.     # 8) Rotation penalty
  97.     if rotations > 3:
  98.         penalty = (rotations - 3) * 5
  99.         reward -= penalty
  100.         print(f"[DEBUG] Rotation penalty => -{penalty}")
  101.  
  102.     # 9) GAME OVER PENALTY
  103.     if game_over:
  104.         reward -= 3000  
  105.         print("[DEBUG] Game over penalty => -3000")
  106.  
  107.     # Build next state with 5 channels
  108.     next_state = construct_state_tensor(grid, score, lines_cleared, pieces_placed, level, device)
  109.  
  110.     return reward, next_state
  111.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement