C#/수업내용

2020.07.08. 수업내용 - 인공지능 강화학습 2 (펭귄 ML-Agents)

dev_sr 2020. 7. 8. 16:38

 

 

 

trainer_config.yaml (UTF-8로 저장해야 하며, YAML은 들여쓰기가 중요함)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# ML-Agents trainer configuration (format with top-level "behaviors:" and "curriculum:" keys).
# Save as UTF-8 and indent with spaces only: YAML nesting is defined purely by indentation.
behaviors:
  # Must match the Behavior Name set on the agent's Behavior Parameters component.
  # NOTE(review): confirm the scene uses exactly "PenguinLearning".
  PenguinLearning:
    trainer_type: ppo
    hyperparameters:
      batch_size: 128
      buffer_size: 2048
      learning_rate: 0.0003
      beta: 1.0e-2
      epsilon: 0.2
      # 'lambd' (not 'lambda') is the key name ML-Agents expects; do not "fix" the spelling.
      lambd: 0.99
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 256
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 1.0e6
    time_horizon: 128
    summary_freq: 5000
    threaded: true
    
# Curriculum lessons for the same behavior name as above.
# The 7 thresholds separate 8 lessons, so each parameter list below holds 8 values (one per lesson).
# NOTE(review): these values only have an effect if the C# scripts read "fish_speed" and
# "feed_radius" through Academy.Instance.EnvironmentParameters (GetWithDefault).
curriculum:
  PenguinLearning:
    measure: reward
    thresholds: [ -0.1, 0.7, 1.7, 1.7, 1.7, 2.7, 2.7 ]
    min_lesson_length: 80
    signal_smoothing: true
    parameters: 
        fish_speed: [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5 ]
        feed_radius: [ 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.5, 0.2 ]

 

YAML 파일 유효성 검사 사이트 (결과에 "Valid YAML"이 나오면 됨)

 

Best YAML Validator Online

 

codebeautify.org

 

 

PenguinArea 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
using System.Collections;
using System.Collections.Generic;
using System.Net;
using TMPro;
using UnityEngine;
 
 
/// <summary>
/// One training area. It owns references to the mother-penguin agent, the baby penguin and
/// the fish, and re-places / respawns them whenever the area is reset.
/// </summary>
public class PenguinArea : MonoBehaviour
{
    [Tooltip("The agent inside the area")]
    public PenguinAgent penguinAgent;

    [Tooltip("The baby penguin inside the area")]
    public GameObject penguinBaby;

    [Tooltip("The TextMeshPro text that shows the cumulative reward of the agent")]
    public TextMeshPro cumulativeRewardText;

    [Tooltip("Prefab of a live fish")]
    public Fish fishPrefab;

    [Tooltip("Swim speed given to the fish spawned by ResetArea")]
    public float fishSpeed = 5f;

    [Tooltip("Number of fish spawned by ResetArea")]
    public int fishCount = 4;

    // Fish currently alive in this area; rebuilt by RemoveAllFish on every reset.
    private List<GameObject> fishList;

    private void Start()
    {
        ResetArea();
    }

    private void Update()
    {
        // The reward label is optional; skip it when none is assigned in the Inspector.
        if (cumulativeRewardText != null)
        {
            cumulativeRewardText.text = penguinAgent.GetCumulativeReward().ToString("0.00");
        }
    }

    /// <summary>
    /// Puts the area back in its initial state: removes all fish, places the penguin and
    /// the baby at new random positions, then spawns <see cref="fishCount"/> fish that swim
    /// at <see cref="fishSpeed"/>.
    /// </summary>
    public void ResetArea()
    {
        RemoveAllFish();
        PlacePenguin();
        PlaceBaby();
        SpawnFish(fishCount, fishSpeed);
    }

    /// <summary>
    /// Remove a specific fish from the area when it is eaten
    /// </summary>
    /// <param name="fishObject">The fish to remove</param>
    public void RemoveSpecificFish(GameObject fishObject)
    {
        if (fishList != null)
        {
            fishList.Remove(fishObject);
        }
        Destroy(fishObject);
    }

    /// <summary>
    /// The number of fish remaining (0 before the area has been reset for the first time).
    /// </summary>
    public int FishRemaining
    {
        get { return fishList == null ? 0 : fishList.Count; }
    }

    /// <summary>
    /// Choose a random position on the X-Z plane within a partial donut shape
    /// </summary>
    /// <param name="center">The center of the donut</param>
    /// <param name="minAngle">Minimum angle of the wedge (degrees)</param>
    /// <param name="maxAngle">Maximum angle of the wedge (degrees)</param>
    /// <param name="minRadius">Minimum distance from the center</param>
    /// <param name="maxRadius">Maximum distance from the center</param>
    /// <returns>A position falling within the specified region</returns>
    public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
    {
        float radius = minRadius;
        float angle = minAngle;

        if (maxRadius > minRadius)
        {
            // Pick a random radius
            radius = UnityEngine.Random.Range(minRadius, maxRadius);
        }

        if (maxAngle > minAngle)
        {
            // Pick a random angle
            angle = UnityEngine.Random.Range(minAngle, maxAngle);
        }

        // Center position + forward vector rotated around the Y axis by "angle" degrees, multiplied by "radius"
        return center + Quaternion.Euler(0f, angle, 0f) * Vector3.forward * radius;
    }

    /// <summary>
    /// Destroy every tracked fish and start a fresh, empty list.
    /// </summary>
    private void RemoveAllFish()
    {
        if (fishList != null)
        {
            foreach (GameObject fish in fishList)
            {
                // Entries may already be destroyed; Unity's overloaded == reports those as null.
                if (fish != null)
                {
                    Destroy(fish);
                }
            }
        }

        fishList = new List<GameObject>();
    }

    /// <summary>
    /// Stop the penguin and place it at a random position/heading within 9 units of the area center.
    /// </summary>
    private void PlacePenguin()
    {
        Rigidbody rb = penguinAgent.GetComponent<Rigidbody>();
        rb.velocity = Vector3.zero;
        rb.angularVelocity = Vector3.zero;
        penguinAgent.transform.position = ChooseRandomPosition(transform.position, 0f, 360f, 0f, 9f) + Vector3.up * .5f;
        penguinAgent.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
    }

    /// <summary>
    /// Stop the baby and place it in the -45..45 degree wedge, 4-9 units from the center, facing 180 degrees.
    /// </summary>
    private void PlaceBaby()
    {
        Rigidbody rb = penguinBaby.GetComponent<Rigidbody>();
        rb.velocity = Vector3.zero;
        rb.angularVelocity = Vector3.zero;
        penguinBaby.transform.position = ChooseRandomPosition(transform.position, -45f, 45f, 4f, 9f) + Vector3.up * .5f;
        penguinBaby.transform.rotation = Quaternion.Euler(0f, 180f, 0f);
    }

    /// <summary>
    /// Spawn some number of fish in the area and set their swim speed
    /// </summary>
    /// <param name="count">The number to spawn (non-positive spawns nothing)</param>
    /// <param name="speed">The swim speed</param>
    private void SpawnFish(int count, float speed)
    {
        for (int i = 0; i < count; i++)
        {
            Vector3 position = ChooseRandomPosition(transform.position, 100f, 260f, 2f, 13f) + Vector3.up * .5f;
            Quaternion rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);

            // Cloning the Fish component clones its whole GameObject. With this overload the
            // position/rotation are world-space and the clone is parented to this area.
            Fish fish = Instantiate(fishPrefab, position, rotation, transform);
            fish.fishSpeed = speed;

            // Keep track of the fish
            fishList.Add(fish.gameObject);
        }
    }
}
 

 

 

PenguinAgent  (엄마펭귄)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
 
/// <summary>
/// Mother-penguin agent. It catches a fish and brings it back to the baby penguin.
/// Rewards: +1 for eating a fish, +1 for feeding the baby, and a small per-step penalty
/// (only when MaxStep is positive).
/// </summary>
public class PenguinAgent : Agent
{
    [Tooltip("How fast the agent moves forward")]
    public float moveSpeed = 5f;

    [Tooltip("How fast the agent turns")]
    public float turnSpeed = 180f;

    [Tooltip("Prefab of the heart that appears when the baby is fed")]
    public GameObject heartPrefab;

    [Tooltip("Prefab of the regurgitated fish that appears when the baby is fed")]
    public GameObject regurgitatedFishPrefab;

    private PenguinArea penguinArea;
    new private Rigidbody rigidbody;    // `new` hides the inherited Component.rigidbody member
    private GameObject baby;

    private bool isFull; // If true, penguin has a full stomach

    // Distance at which the baby is fed without touching it (0 = only feed on collision).
    // Refreshed from the "feed_radius" environment parameter on every episode reset.
    private float feedRadius = 0f;

    // Environment parameters (e.g. curriculum values) supplied by the trainer.
    private EnvironmentParameters resetParams;

    /// <summary>
    /// One-time setup: cache component references and the environment parameters.
    /// The area itself is reset in OnEpisodeBegin, which ML-Agents calls at the start of every episode.
    /// </summary>
    public override void Initialize()
    {
        Debug.Log("Initialize");
        penguinArea = GetComponentInParent<PenguinArea>();
        baby = penguinArea.penguinBaby;
        rigidbody = GetComponent<Rigidbody>();
        resetParams = Academy.Instance.EnvironmentParameters;
    }

    /// <summary>
    /// Called by ML-Agents at the start of each episode. That includes after EndEpisode() and
    /// after MaxStep is reached, so resetting here covers both ways an episode can end.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        Debug.Log("OnEpisodeBegin");
        AgentReset();
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // Whether the penguin has eaten a fish (1 float = 1 value)
        sensor.AddObservation(isFull);

        // Distance to the baby (1 float = 1 value)
        sensor.AddObservation(Vector3.Distance(baby.transform.position, transform.position));

        // Direction to baby (1 Vector3 = 3 values)
        sensor.AddObservation((baby.transform.position - transform.position).normalized);

        // Direction penguin is facing (1 Vector3 = 3 values)
        sensor.AddObservation(transform.forward);

        // 1 + 1 + 3 + 3 = 8 total values
    }

    /// <summary>
    /// Applies one decision. Action 0 scales forward movement.
    /// Action 1 selects the turn: 1 = left, 2 = right, anything else = no turn.
    /// </summary>
    /// <param name="vectorAction">Actions chosen by the policy or heuristic.</param>
    public override void OnActionReceived(float[] vectorAction)
    {
        // Convert the first action to forward movement
        float forwardAmount = vectorAction[0];

        // Discrete branch values arrive as floats; round instead of comparing floats with ==.
        float turnAmount = 0f;
        switch (Mathf.RoundToInt(vectorAction[1]))
        {
            case 1:
                turnAmount = -1f;
                break;
            case 2:
                turnAmount = 1f;
                break;
        }

        // Apply movement
        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);
        transform.Rotate(transform.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Tiny negative reward every step to encourage action. MaxStep == 0 means "no limit";
        // dividing by it would add -Infinity, so the penalty is skipped in that case.
        if (MaxStep > 0)
        {
            AddReward(-1f / MaxStep);
        }
    }

    /// <summary>
    /// Reset the agent and area. It also pulls the current curriculum values. "feed_radius" goes
    /// to this agent and "fish_speed" goes to the area; each keeps its current value when the
    /// trainer does not provide the parameter.
    /// </summary>
    public void AgentReset()
    {
        isFull = false;

        if (resetParams != null)
        {
            feedRadius = resetParams.GetWithDefault("feed_radius", feedRadius);
            penguinArea.fishSpeed = resetParams.GetWithDefault("fish_speed", penguinArea.fishSpeed);
        }

        penguinArea.ResetArea();
    }

    private void FixedUpdate()
    {
        // Test if the agent is close enough to feed the baby
        if (Vector3.Distance(transform.position, baby.transform.position) < feedRadius)
        {
            // Close enough, try to feed the baby
            RegurgitateFish();
        }
    }

    /// <summary>
    /// When the agent collides with something, take action
    /// </summary>
    /// <param name="collision">The collision info</param>
    private void OnCollisionEnter(Collision collision)
    {
        if (collision.transform.CompareTag("fish"))
        {
            // Try to eat the fish
            EatFish(collision.gameObject);
        }
        else if (collision.transform.CompareTag("baby"))
        {
            // Try to feed the baby
            RegurgitateFish();
        }
    }

    /// <summary>
    /// Check if agent is full, if not, eat the fish and get a reward
    /// </summary>
    /// <param name="fishObject">The fish to eat</param>
    private void EatFish(GameObject fishObject)
    {
        if (isFull) return; // Can't eat another fish while full
        isFull = true;

        penguinArea.RemoveSpecificFish(fishObject);

        AddReward(1f);
    }

    /// <summary>
    /// Check if agent is full; if yes, feed the baby. When no fish remain, end the episode.
    /// </summary>
    private void RegurgitateFish()
    {
        if (!isFull) return; // Nothing to regurgitate
        isFull = false;

        // Spawn regurgitated fish
        GameObject regurgitatedFish = Instantiate<GameObject>(regurgitatedFishPrefab);
        regurgitatedFish.transform.parent = transform.parent;
        regurgitatedFish.transform.position = baby.transform.position;
        Destroy(regurgitatedFish, 4f);

        // Spawn heart
        GameObject heart = Instantiate<GameObject>(heartPrefab);
        heart.transform.parent = transform.parent;
        heart.transform.position = baby.transform.position + Vector3.up;
        Destroy(heart, 4f);

        AddReward(1f);

        if (penguinArea.FishRemaining <= 0)
        {
            // EndEpisode triggers OnEpisodeBegin, which resets the area; no explicit reset needed here.
            Debug.Log("EndEpisode");
            EndEpisode();
        }
    }
}
 

 

 

Fish 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
 
/// <summary>
/// A fish that wanders inside its parent area. It repeatedly picks a random point in the
/// area's fish zone, turns toward it, and swims straight there at a randomized speed.
/// </summary>
public class Fish : MonoBehaviour
{
    [Tooltip("The swim speed")]
    public float fishSpeed;

    // Speed for the current leg: fishSpeed scaled by a random factor between 0.5 and 1.5.
    private float randomizedSpeed = 0f;

    // Time.fixedTime at which the fish should choose a new destination.
    private float nextActionTime = -1f;

    // World-space destination of the current leg.
    private Vector3 targetPosition;

    private void FixedUpdate()
    {
        // Only a strictly positive speed makes the fish move.
        if (!(fishSpeed > 0f))
        {
            return;
        }

        Swim();
    }

    /// <summary>
    /// Starts a new leg once the current one's time is up; otherwise keeps moving toward the target.
    /// </summary>
    private void Swim()
    {
        if (Time.fixedTime >= nextActionTime)
        {
            StartNewLeg();
        }
        else
        {
            MoveTowardTarget();
        }
    }

    /// <summary>
    /// Picks a random speed and destination, faces the destination, and schedules when to pick again.
    /// </summary>
    private void StartNewLeg()
    {
        randomizedSpeed = fishSpeed * UnityEngine.Random.Range(0.5f, 1.5f);

        targetPosition = PenguinArea.ChooseRandomPosition(transform.parent.position, 100f, 260f, 2f, 13f);

        Vector3 heading = targetPosition - transform.position;
        transform.rotation = Quaternion.LookRotation(heading, Vector3.up);

        float travelTime = Vector3.Distance(transform.position, targetPosition) / randomizedSpeed;
        nextActionTime = Time.fixedTime + travelTime;
    }

    /// <summary>
    /// Advances one physics step along the current heading without passing the target.
    /// </summary>
    private void MoveTowardTarget()
    {
        Vector3 step = randomizedSpeed * transform.forward * Time.fixedDeltaTime;
        float remaining = Vector3.Distance(transform.position, targetPosition);

        if (step.magnitude <= remaining)
        {
            transform.position += step;
            return;
        }

        // The step would overshoot: land exactly on the target and pick a new leg on the next tick.
        transform.position = targetPosition;
        nextActionTime = Time.fixedTime;
    }
}
 
 

 

trainer_config.yaml 파일을 C:\Users\tjoeun\Desktop\ml-agents-release_3\config (아나콘다 깔린 곳) 폴더에 넣고

cmd에서 cd 명령으로 해당 config 폴더로 이동한 다음

다음 명령을 입력한다: mlagents-learn trainer_config.yaml --run-id=PenguinLearning

중단했던 훈련을 이어서 다시 시작할 때는 --resume 옵션을 붙인다:

mlagents-learn trainer_config.yaml --run-id=PenguinLearning --resume

훈련을 끝내고 모델 파일(.nn)을 생성하려면 cmd에서

Ctrl + C를 누른다. 그러면 .nn 모델 파일이 생성되며, 이 파일을 PenguinAgent(펭귄 에이전트)에 넣어준다.

 

10만 번(스텝) 정도는 학습해야 에이전트가 눈에 띄게 빠릿해지는 듯함

 

 

 

출처: 

 

Learn to leverage Artificial Intelligence to enhance your Unity projects / 人工知能を活用して Unity プロジェクトを…

Our newest additions to the Unity Learn platform will teach you how to use Reinforcement Learning and AI to solve game development challenges and make bett...

blogs.unity3d.com

ML-Agents(Release 3) API에 맞게 코드를 다시 수정함 (별도의 Academy 클래스는 사용하지 않음)