Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using UnityEngine;
- using Evolution;
- using NeuralNetwork;
- using System;
- using UnityEngine.Events;
- using System.Linq;
/// <summary>
/// Spawns a population of <see cref="WalkingAgent"/>s from <see cref="agentPrefab"/>, evaluates every agent
/// each physics step, and asks <see cref="Neuroevolution"/> for a new generation once
/// <see cref="generationLengthSeconds"/> of fixed-step time has elapsed.
/// </summary>
public class NeuralAgentManager : MonoBehaviour
{
    [Header("Setup")]
    // Must contain a Creature component; validated in InitializeAgents.
    public GameObject agentPrefab;
    // Optional; when assigned, receives (generation, highest score) after every generation.
    public Graph scoreGraph;
    // NOTE(review): seed is not read anywhere in this class; presumably meant to seed the RNG — confirm.
    public int seed;

    [Header("Population Settings")]
    public int populationCount = 100;
    public float generationLengthSeconds = 30;

    [Header("Evolution Parameters")]
    [Range(0f, 1f)] public float mutationRate = .5f;
    [Range(0f, 1f)] public float mutationStrength = .2f;
    [Range(0f, 1f)] public float selectionRate = .02f;

    [Header("Visualization")]
    // When true only agents[0] is rendered.
    // NOTE(review): assumes Neuroevolution keeps the best performer at index 0 — confirm.
    public bool showBestAgent;

    [Header("Events")]
    // Invoked after currentGeneration is incremented, before the new generation is bred.
    public UnityEvent onNewGeneration;

    public int currentGeneration { get; private set; }

    // Null until InitializeAgents succeeds; never contains null entries once assigned.
    public WalkingAgent[] agents { get; private set; }

    private Neuroevolution neuroevolution;
    private float timer;

    private void Start()
    {
        InitializeAgents();
    }

    /// <summary>
    /// Validates the setup, instantiates <see cref="populationCount"/> agents under this transform and
    /// creates the <see cref="Neuroevolution"/> driver. On invalid setup an error is logged and
    /// <see cref="agents"/> stays null, which keeps the per-step update disabled.
    /// </summary>
    private void InitializeAgents()
    {
        if (agentPrefab == null)
        {
            Debug.LogError("Agent prefab is not assigned!");
            return;
        }
        if (agentPrefab.GetComponent<Creature>() == null)
        {
            Debug.LogError("Agent prefab must contain a Creature component!");
            return;
        }
        if (populationCount <= 0)
        {
            Debug.LogError("Population count must be greater than zero!");
            return;
        }

        var population = new WalkingAgent[populationCount];
        for (int i = 0; i < population.Length; i++)
        {
            var obj = Instantiate(agentPrefab);
            obj.transform.parent = transform;
            population[i] = new WalkingAgent(obj.GetComponent<Creature>());
        }

        neuroevolution = new(population, mutationRate, mutationStrength, selectionRate);
        // Publish the population only once it is fully built, so the null checks elsewhere stay valid.
        agents = population;

        // Apply showBestAgent to the very first generation too, not only after the first reset.
        UpdateAgentVisibility();
    }

    private void OnValidate()
    {
        // Inspector edits can arrive before the population exists (or after it failed to initialise).
        if (agents == null) return;
        UpdateAgentVisibility();
    }

    /// <summary>Shows only agents[0] when <see cref="showBestAgent"/> is set; otherwise shows every agent.</summary>
    private void UpdateAgentVisibility()
    {
        for (int i = 0; i < agents.Length; i++)
        {
            agents[i].SetVisible(!showBestAgent || i == 0);
        }
    }

    private void FixedUpdate()
    {
        // Initialisation failed (error already logged): nothing to simulate.
        if (agents == null) return;

        EvaluateAgents();
        UpdateGeneration();
    }

    /// <summary>Runs one network step for every agent that is still alive.</summary>
    private void EvaluateAgents()
    {
        foreach (var agent in agents)
        {
            // Dead agents have their GameObject deactivated, so driving their joints has no effect.
            if (agent.isDead) continue;
            agent.Evaluate(Time.fixedDeltaTime);
        }
    }

    /// <summary>Advances the generation timer and starts a new generation once it exceeds the configured length.</summary>
    private void UpdateGeneration()
    {
        timer += Time.fixedDeltaTime;
        if (timer > generationLengthSeconds)
        {
            timer = 0;
            NewGeneration();
        }
    }

    /// <summary>Notifies listeners, breeds the next generation, records the score and resets all agents.</summary>
    private void NewGeneration()
    {
        currentGeneration++;
        onNewGeneration?.Invoke();
        neuroevolution.NewGeneration();
        UpdateScoreGraph();
        ResetAgents();
    }

    /// <summary>Plots the highest score of the finished generation, if a graph is assigned.</summary>
    private void UpdateScoreGraph()
    {
        if (scoreGraph != null)
        {
            scoreGraph.AddPoint(new Vector2(currentGeneration, GetHighestScore()));
        }
    }

    /// <summary>Resets every agent for the next generation and re-applies the visibility setting.</summary>
    private void ResetAgents()
    {
        foreach (var agent in agents)
        {
            agent.Reset();
        }
        UpdateAgentVisibility();
    }

    /// <summary>Highest <see cref="WalkingAgent.GetScore"/> in the population (population is non-empty once initialised).</summary>
    private float GetHighestScore()
    {
        return agents.Max(agent => agent.GetScore());
    }
}
/// <summary>
/// Pairs a <see cref="Creature"/> with a neural network. Each evaluation feeds the relative position and
/// orientation of every body part into the network and uses its outputs as joint motor speeds.
/// The agent is marked dead (and its GameObject deactivated) when the creature's head collides.
/// </summary>
public class WalkingAgent : EvolvableAgent
{
    // Each body part contributes: relative position (x, y) and its right vector (x, y).
    private const int InputsPerBodyPart = 4;

    public Creature creature { get; private set; }
    public bool isDead { get; private set; }

    // Input buffer sized once in the constructor and refilled on every Evaluate call,
    // so no array is allocated per agent per physics step.
    private readonly float[] inputs;

    /// <param name="creature">The body this agent controls; must not be null.</param>
    /// <param name="hiddenLayerNeuronCount">
    /// Neuron count per hidden layer; defaults to two layers of 8 neurons when null.
    /// </param>
    /// <exception cref="ArgumentNullException">When <paramref name="creature"/> is null.</exception>
    public WalkingAgent(Creature creature, int[] hiddenLayerNeuronCount = null)
    {
        if (creature == null) throw new ArgumentNullException(nameof(creature));

        this.creature = creature;
        inputs = new float[creature.BodyPartsCount() * InputsPerBodyPart];

        NeuralNetworkParams neuralNetworkParams = new NeuralNetworkParams()
        {
            hiddenLayerNeuronCount = hiddenLayerNeuronCount ?? new int[] { 8, 8 },
            // One output per joint: each drives that joint's motor speed.
            outputCount = creature.jointController.GetJointsCount(),
            inputCount = inputs.Length
        };
        neuralNetwork = new NeuralNetwork.NeuralNetwork(neuralNetworkParams);
        neuralNetwork.SetRandomParameters();

        creature.GetHead().GetComponent<OnCollision>().onCollision.AddListener(OnHeadCollision);
    }

    /// <summary>
    /// Samples every body part's relative position and right vector, runs the network, and applies each
    /// output as the motor speed of the corresponding joint. <paramref name="deltaTime"/> is not used.
    /// </summary>
    public override void Evaluate(float deltaTime)
    {
        int bodyPartCount = creature.BodyPartsCount();
        for (int i = 0; i < bodyPartCount; i++)
        {
            Transform bodyPart = creature.GetBodyPart(i).transform;
            Vector2 direction = bodyPart.right;
            Vector2 position = creature.GetRelativePosition(bodyPart);

            int idx = i * InputsPerBodyPart;
            inputs[idx] = position.x;
            inputs[idx + 1] = position.y;
            inputs[idx + 2] = direction.x;
            inputs[idx + 3] = direction.y;
        }

        neuralNetwork.SetInputs(inputs);
        float[] outputs = neuralNetwork.Evaluate();

        int jointCount = creature.jointController.GetJointsCount();
        for (int i = 0; i < jointCount; i++)
        {
            creature.jointController.SetMotorSpeed(i, outputs[i]);
        }
    }

    /// <summary>Resets the creature, clears the dead flag and reactivates its GameObject.</summary>
    public void Reset()
    {
        creature.Reset();
        isDead = false;
        creature.gameObject.SetActive(true);
    }

    /// <summary>Forwards the visibility flag to the creature's visibility component.</summary>
    public void SetVisible(bool isVisible)
    {
        creature.setVisible.SetVisibility(isVisible);
    }

    /// <summary>Score = world-space x coordinate of the creature's head (how far it has travelled).</summary>
    public override float GetScore()
    {
        return creature.GetHead().transform.position.x;
    }

    // Head contact ends the run: mark dead and deactivate, freezing the head where it stopped.
    void OnHeadCollision()
    {
        isDead = true;
        creature.gameObject.SetActive(false);
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement