Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using UnityEngine;
- using GeneticAlgorithm;
- using NeuralNetwork;
/// <summary>
/// Spawns <see cref="numAgents"/> pong instances, each driven by a <see cref="PongAgent"/>, and
/// evolves their neural networks with a genetic algorithm. Every <see cref="gameLengthSeconds"/>
/// the networks are serialized into genetic agents (scored by their game score), a new generation
/// is produced, and the resulting genetic codes are loaded back into the networks.
/// </summary>
public class AgentManager : MonoBehaviour
{
    /// <summary>Number of pong instances / agents spawned in <c>Start</c>.</summary>
    public int numAgents;
    /// <summary>Horizontal spacing (world units) between consecutive spawned pong instances.</summary>
    public float pongInstanceSize;
    /// <summary>Length of one generation's evaluation period, accumulated from <c>Time.deltaTime</c>.</summary>
    public float gameLengthSeconds;
    /// <summary>Prefab instantiated per agent; its root must carry a <see cref="PongGame"/> component.</summary>
    public GameObject pongPrefab;
    /// <summary>Mutation chance handed to the genetic algorithm (read once in <c>Start</c>).</summary>
    [Range(0, 1)]
    public float mutationChance;
    /// <summary>Mutation strength handed to the genetic algorithm (read once in <c>Start</c>).</summary>
    public int mutationStrength;
    /// <summary>Written to <c>Time.timeScale</c> every frame, so it can be tuned while running.</summary>
    public float timeScale;
    /// <summary>Incremented each time a new generation is produced.</summary>
    public int currentGeneration;

    GeneticAlgorithm.GeneticAlgorithm geneticAlgorithm;
    PongAgent[] pongAgents;
    Agent[] geneticAgents;
    float timer;

    void Start()
    {
        pongAgents = new PongAgent[numAgents];
        for (int i = 0; i < pongAgents.Length; i++)
        {
            pongAgents[i] = new PongAgent();
            GameObject pongObj = Instantiate(pongPrefab);
            pongObj.transform.parent = transform;
            // Position the spawned instance, never the prefab asset: moving the prefab would
            // persistently modify the asset and only affect instances spawned afterwards.
            pongObj.transform.position = Vector2.right * (i + 1) * pongInstanceSize;
            pongAgents[i].pongGame = pongObj.GetComponent<PongGame>();
        }

        // Populate the genetic agents before handing the array to the algorithm, so it never
        // observes null entries regardless of whether SetAgents keeps or copies the array.
        geneticAgents = new Agent[numAgents];
        for (int i = 0; i < geneticAgents.Length; i++)
        {
            geneticAgents[i] = new Agent();
        }

        geneticAlgorithm = new GeneticAlgorithm.GeneticAlgorithm();
        geneticAlgorithm.mutationChance = mutationChance;
        geneticAlgorithm.mutationStrength = mutationStrength;
        geneticAlgorithm.SetAgents(geneticAgents);
    }

    void Update()
    {
        Time.timeScale = timeScale;
        timer += Time.deltaTime;
        if (timer > gameLengthSeconds)
        {
            AdvanceGeneration();
            timer = 0f;
        }

        foreach (var agent in pongAgents)
        {
            agent.Evaluate(Time.deltaTime);
        }
    }

    /// <summary>
    /// Scores the current networks and resets each game's score. Then produces a new generation,
    /// logs the average score and loads the new genetic codes back into the networks.
    /// </summary>
    void AdvanceGeneration()
    {
        int scoreSum = 0;
        for (int i = 0; i < pongAgents.Length; i++)
        {
            PongAgent pongAgent = pongAgents[i];
            var score = pongAgent.pongGame.score.score;

            geneticAgents[i].geneticCode = NeuralNetworkSerializer.SerializeToBitString(pongAgent.neuralNetwork);
            geneticAgents[i].score = score;
            scoreSum += score;
            pongAgent.pongGame.score.ResetScore();
        }

        // Float division keeps the fractional part; guard against numAgents == 0.
        float averageScore = pongAgents.Length > 0 ? (float)scoreSum / pongAgents.Length : 0f;

        // NOTE(review): assumes NewGeneration rewrites geneticCode on the agents passed to
        // SetAgents in place — the loop below reads them back from the same array.
        geneticAlgorithm.NewGeneration();
        Debug.Log("New Generation! Average score: " + averageScore);
        currentGeneration++;

        for (int i = 0; i < pongAgents.Length; i++)
        {
            NeuralNetworkSerializer.DeserializeFromBitString(geneticAgents[i].geneticCode, pongAgents[i].neuralNetwork);
        }
    }
}
/// <summary>
/// Pairs a <see cref="PongGame"/> instance with a randomly initialized neural network
/// (3 inputs, hidden layers of 5 and 5 nodes, 2 outputs) that steers the game's paddle.
/// </summary>
public class PongAgent
{
    const int InputCount = 3;

    // NOTE(review): these look like the half-extents of the playfield (ball x, ball y, paddle y
    // travel) in the pong instance's local units, used to scale inputs to roughly [-1, 1] —
    // confirm against the PongGame prefab.
    const float BallRangeX = 6.5f;
    const float BallRangeY = 4.9f;
    const float PaddleRangeY = 4.3f;

    public PongGame pongGame;
    public NeuralNetwork.NeuralNetwork neuralNetwork;

    // Reused every frame so Evaluate does not allocate a new array per agent per frame.
    readonly float[] inputs = new float[InputCount];

    public PongAgent()
    {
        NeuralNetworkParams networkParams = new NeuralNetworkParams
        {
            numInputs = InputCount,
            hiddenLayersNodes = new int[] { 5, 5 },
            outputLayerNodes = 2
        };
        neuralNetwork = new NeuralNetwork.NeuralNetwork(networkParams);
        neuralNetwork.SetRandomParameters();
    }

    /// <summary>
    /// Feeds the normalized ball position and paddle height to the network and moves the paddle
    /// up or down according to its two outputs. Should be called every frame.
    /// </summary>
    /// <param name="deltaTime">Frame time forwarded to the paddle's MoveUp/MoveDown.</param>
    public void Evaluate(float deltaTime)
    {
        // Express ball and paddle in the same frame (the PongGame object's own space) so the
        // inputs do not depend on where this instance was spawned in the world.
        Transform gameRoot = pongGame.transform;
        Vector2 ballPosition = gameRoot.InverseTransformPoint(pongGame.ball.transform.position);
        Vector2 paddlePosition = gameRoot.InverseTransformPoint(pongGame.paddle.transform.position);

        inputs[0] = ballPosition.x / BallRangeX;
        inputs[1] = ballPosition.y / BallRangeY;
        inputs[2] = paddlePosition.y / PaddleRangeY;
        neuralNetwork.SetInputs(inputs);

        float[] output = neuralNetwork.Evaluate();
        // The two outputs act as competing "up" / "down" votes; a tie moves the paddle down.
        if (output[0] > output[1])
        {
            pongGame.paddle.MoveUp(deltaTime);
        }
        else
        {
            pongGame.paddle.MoveDown(deltaTime);
        }
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement