Advertisement
JontePonte

Untitled

May 11th, 2025 (edited)
162
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 5.67 KB | None | 0 0
  1. using UnityEngine;
  2. using Evolution;
  3. using NeuralNetwork;
  4. using System;
  5. using UnityEngine.Events;
  6. using System.Linq;
  7.  
  8. public class NeuralAgentManager : MonoBehaviour
  9. {
  10.     [Header("Setup")]
  11.     public GameObject agentPrefab;
  12.     public Graph scoreGraph;
  13.     public int seed;
  14.  
  15.     [Header("Population Settings")]
  16.     public int populationCount = 100;
  17.     public float generationLengthSeconds = 30;
  18.  
  19.     [Header("Evolution Parameters")]
  20.     [Range(0f, 1f)] public float mutationRate = .5f;
  21.     [Range(0f, 1f)] public float mutationStrength = .2f;
  22.     [Range(0f, 1f)] public float selectionRate = .02f;
  23.  
  24.     [Header("Visualization")]
  25.     public bool showBestAgent;
  26.  
  27.     [Header("Events")]
  28.     public UnityEvent onNewGeneration;
  29.  
  30.     public int currentGeneration { get; private set; }
  31.     public WalkingAgent[] agents { get; private set; }
  32.  
  33.     private Neuroevolution neuroevolution;
  34.     private float timer;
  35.  
  36.     private void Start()
  37.     {
  38.         InitalizeAgents();
  39.     }
  40.  
  41.     private void InitalizeAgents()
  42.     {
  43.         agents = new WalkingAgent[populationCount];
  44.  
  45.         if (agentPrefab.GetComponent<Creature>() == null)
  46.         {
  47.             Debug.LogError("Agent prefab must contain a Creature component!");
  48.             return;
  49.         }
  50.  
  51.         for (int i = 0; i < agents.Length; i++)
  52.         {
  53.             var obj = Instantiate(agentPrefab);
  54.             obj.transform.parent = transform;
  55.             agents[i] = new WalkingAgent(obj.GetComponent<Creature>());
  56.         }
  57.  
  58.         neuroevolution = new(agents, mutationRate, mutationStrength, selectionRate);
  59.     }
  60.  
  61.     private void OnValidate()
  62.     {
  63.         if (agents == null) return;
  64.         UpdateAgentVisibility();
  65.     }
  66.  
  67.     private void UpdateAgentVisibility()
  68.     {
  69.         if (showBestAgent)
  70.         {
  71.             agents[0].SetVisible(true);
  72.             for (int i = 1; i < agents.Length; i++)
  73.             {
  74.                 agents[i].SetVisible(false);
  75.             }
  76.         }
  77.         else
  78.         {
  79.             foreach (var agent in agents)
  80.             {
  81.                 agent.SetVisible(true);
  82.             }
  83.         }
  84.     }
  85.  
  86.     void FixedUpdate()
  87.     {
  88.         EvaluateAgents();
  89.         UpdateGeneration();
  90.     }
  91.  
  92.     private void EvaluateAgents()
  93.     {
  94.         foreach (var agent in agents)
  95.         {
  96.             agent.Evaluate(Time.fixedDeltaTime);
  97.         }
  98.     }
  99.  
  100.     private void UpdateGeneration()
  101.     {
  102.         timer += Time.fixedDeltaTime;
  103.  
  104.         if (timer > generationLengthSeconds)
  105.         {
  106.             timer = 0;
  107.             NewGeneration();
  108.         }
  109.     }
  110.  
  111.     private void NewGeneration()
  112.     {
  113.         currentGeneration++;
  114.         onNewGeneration.Invoke();
  115.         neuroevolution.NewGeneration();
  116.  
  117.         UpdateScoreGraph();
  118.         ResetAgents();
  119.     }
  120.  
  121.     private void UpdateScoreGraph()
  122.     {
  123.         if (scoreGraph != null)
  124.         {
  125.             scoreGraph.AddPoint(new Vector2(currentGeneration, GetHighestScore()));
  126.         }
  127.     }
  128.  
  129.     private void ResetAgents()
  130.     {
  131.         foreach (var agent in agents)
  132.         {
  133.             if (showBestAgent)
  134.             {
  135.                 agent.SetVisible(false);
  136.             }
  137.             agent.Reset();
  138.         }
  139.  
  140.         if (showBestAgent)
  141.         {
  142.             agents[0].SetVisible(true);
  143.         }
  144.     }
  145.  
  146.     private float GetHighestScore()
  147.     {
  148.         return agents.Max(agent => agent.GetScore());
  149.     }
  150. }
  151.  
  152. public class WalkingAgent : EvolvableAgent
  153. {
  154.     public Creature creature { get; private set; }
  155.     public bool isDead { get; private set; }
  156.  
  157.     //float[] inputs;
  158.     int inputCount;
  159.  
  160.     public WalkingAgent(Creature creature)
  161.     {
  162.         int inputCount = creature.BodyPartsCount() * 4;
  163.         this.inputCount = inputCount;
  164.  
  165.         //This can be tweaked
  166.         NeuralNetworkParams neuralNetworkParams = new NeuralNetworkParams()
  167.         {
  168.             hiddenLayerNeuronCount = new int[] { 8, 8 },
  169.             outputCount = creature.jointController.GetJointsCount(),
  170.             inputCount = inputCount
  171.         };
  172.  
  173.         neuralNetwork = new NeuralNetwork.NeuralNetwork(neuralNetworkParams);
  174.         neuralNetwork.SetRandomParameters();
  175.  
  176.         this.creature = creature;
  177.         creature.GetHead().GetComponent<OnCollision>().onCollision.AddListener(OnHeadCollision);
  178.     }
  179.  
  180.     public override void Evaluate(float deltaTime)
  181.     {
  182.         float[] inputs = new float[inputCount];
  183.  
  184.         int idx = 0;
  185.         for (int i = 0; i < creature.BodyPartsCount(); i++)
  186.         {
  187.             Transform bodyPart = creature.GetBodyPart(i).transform;
  188.             Vector2 direction = bodyPart.right;
  189.             Vector2 position = creature.GetRelativePosition(bodyPart);
  190.  
  191.             inputs[idx] = position.x;
  192.             inputs[idx + 1] = position.y;
  193.             inputs[idx + 2] = direction.x;
  194.             inputs[idx + 3] = direction.y;
  195.  
  196.             idx += 4;
  197.         }
  198.  
  199.         neuralNetwork.SetInputs(inputs);
  200.  
  201.         float[] outputs = neuralNetwork.Evaluate();
  202.  
  203.         for (int i = 0; i < creature.jointController.GetJointsCount(); i++)
  204.         {
  205.             creature.jointController.SetMotorSpeed(i, outputs[i]);
  206.         }
  207.     }
  208.  
  209.     public void Reset()
  210.     {
  211.         creature.Reset();
  212.         isDead = false;
  213.         creature.gameObject.SetActive(true);
  214.     }
  215.  
  216.     public void SetVisible(bool isVisible)
  217.     {
  218.         creature.setVisible.SetVisibility(isVisible);
  219.     }
  220.  
  221.     // How far creature has travelled
  222.     public override float GetScore()
  223.     {
  224.         return creature.GetHead().transform.position.x;
  225.     }
  226.  
  227.     void OnHeadCollision()
  228.     {
  229.         isDead = true;
  230.         creature.gameObject.SetActive(false);
  231.     }
  232. }
  233.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement