UNITYで動くNSGA2のサンプルプログラムを実装しましたので、公開します。
どこ探してもなかったので自分で作りました。
一応Github(4月1日公開予定)GitHubの使い方がよくわかってないので適当にあげてます。
https://github.com/AokiMotohide/NSGA2_Unity.git
NSGA2については以下のブログがわかりやすいです。
混雑度ソートについてはこちらを参考にしています。
不明点あればご気軽にコメントどうぞ、間違っている可能性も大いにあります
目次
動作の様子
こんな形で動きます。非支配解(ランク1)の解を赤、それ以外を青で示すようになっています。
問題は多目的最適化のベンチマークとしてよく使われるらしいDTLZ2を用いています。
上記の画像のようにパラメータを設定できるようになっています。
それぞれのソースコード
NSGA2_Run
空のオブジェクトにアタッチし、実際にプログラムを動作させるためのプログラムです。
ーーーーーーーーーーーーーーーーーーーーーーーー
using System.Collections;
using System.Collections.Generic;
using System.Linq;using UnityEngine;
//メインクラス
public class NSGA2_Run : MonoBehaviour
{
public int populationSize = 100; //進化サイズ
public int generations = 10; //世代数
public float mutationProbability = 0.1f; //突然変異率
public float crossoverProbability = 0.9f; //交叉率
public int numberOfVariables = 12; // DTLZ2 for 3 objectives typically uses k=9 (12-3=9) decision variables
public float lowerBound = 0f;
public float upperBound = 1f;
public float delay = 0.5f; public GameObject pointPrefab; // 小さな3Dオブジェクト(例:球)のプレハブ // 生成した視覚化点を一時保存するリスト
private List<GameObject> instantiatedPoints = new List<GameObject>(); //全体の初期集合
public static List<Sample_Solution> population;
//アーカイブ母集団
public static List<Sample_Solution> archivePopulation = new List<Sample_Solution>(); // アーカイブ母集団を保持するリスト
//子母集団
public static List<Sample_Solution> offspringPopulation;
//合体母集団
List<Sample_Solution> combinedPopulation; void Start()
{
//NSGA2クラスのインスタンス化
NSGAII nsga2 = new NSGAII(
populationSize,
generations,
mutationProbability,
crossoverProbability,
numberOfVariables,
lowerBound,
upperBound); //NSGA2の進化(進化のみを行う)
//nsga2.Evolve();
//進化の視覚化(進化と視覚化をどちらも行う)
StartCoroutine(RunAndVisualize()); Debug.Log("完了でございます"); } //結果の視覚化 //動作的にはEvolve関数と同じことをする
public IEnumerator RunAndVisualize()
{ //全体の集合
population = NSGAII.InitializePopulation();
NSGAII.Evaluate(population); archivePopulation = population;
//VisualizeParetoFront(NSGAII.archivePopulation);
Debug.Log("今からループ回すべ"); //進化開始
for (int gen = 0; gen < NSGAII.generations; gen++)
{
// アーカイブ母集団を基にオフスプリングを生成
offspringPopulation = NSGAII.NextGeneration(archivePopulation); // アーカイブ母集団とオフスプリングを結合
combinedPopulation = archivePopulation.Concat(offspringPopulation).ToList(); // 新しいアーカイブ母集団を抽出
NSGAII.Evaluate(combinedPopulation);
archivePopulation = NSGAII.ExtractNewArchive(combinedPopulation); // 現在の非支配解を視覚化
VisualizeParetoFront(archivePopulation);
//VisualizeParetoFront(combinedPopulation); // アーカイブ母集団を新しい母集団として使用
population = archivePopulation;
Debug.Log(gen + "世代目の計算完了");
// 各世代の終わりに、少しの遅延を持たせる
yield return new WaitForSeconds(delay); }
}
//アーカイブ母集団を表示
public void VisualizeParetoFront(List<Sample_Solution> population)
{ // 以前に生成した全ての点を削除
foreach (var point in instantiatedPoints)
{
Destroy(point);
}
instantiatedPoints.Clear(); int count = 1;
foreach (Sample_Solution sol in population)
{
if (sol.Rank == 1) // ランクが1の場合のみ表示
{
//Debug.Log(count + ":アーカイブのランク" + sol.Rank);
count++; Vector3 position = new Vector3*1.FirstOrDefault();
}
///交叉
private static List<Sample_Solution> Crossover(Sample_Solution parent1, Sample_Solution parent2)
{
// ここでは単純な1点交叉を使用します
Sample_Solution child1 = new Sample_Solution(numberOfVariables);
Sample_Solution child2 = new Sample_Solution(numberOfVariables); int crossoverPoint = UnityEngine.Random.Range(0, numberOfVariables); for (int i = 0; i < numberOfVariables; i++)
{
if (i < crossoverPoint)
{
child1.variables[i] = parent1.variables[i];
child2.variables[i] = parent2.variables[i];
}
else
{
child1.variables[i] = parent2.variables[i];
child2.variables[i] = parent1.variables[i];
}
} return new List<Sample_Solution>() { child1, child2 };
}
//突然変異
private static void Mutate(Sample_Solution solution)
{
for (int i = 0; i < numberOfVariables; i++)
{
if (UnityEngine.Random.value < mutationProbability)
{
// ここでは単純な突然変異を使用し、変数の値をランダムに変更します
solution.variables[i] = UnityEngine.Random.Range(lowerBound, upperBound);
}
}
}
// 非支配ソリューションを抽出
public static List<Sample_Solution> ExtractNewArchive(List<Sample_Solution> combinedPopulation)
{
// 1. 非支配ソートを実行して、各解のランクを取得
List<int> ranks = NonDominatedSorting.Sort(combinedPopulation); // 2. 最高ランクから順に、アーカイブ母集団を構築
int maxRank = ranks.Max(); //新しいアーカイブ母集団
List<Sample_Solution> newArchivePopulation = new List<Sample_Solution>();
for (int rank = 1; rank <= maxRank && newArchivePopulation.Count < populationSize; rank++)
{
List<Sample_Solution> currentFront = new List<Sample_Solution>();
for (int i = 0; i < combinedPopulation.Count; i++)
{
if (ranks[i] == rank)
{
currentFront.Add(combinedPopulation[i]);
}
} // 3. 混雑度を計算
List<double> crowdingDistances = NonDominatedSorting.CalculateCrowdingDistance(currentFront); // 4. アーカイブ母集団に解を追加
if (newArchivePopulation.Count + currentFront.Count <= populationSize)
{
newArchivePopulation.AddRange(currentFront);
}
else
{
// 混雑度が高い解を優先して追加
List<int> sortedIndices = crowdingDistances
.Select*2 //もしiがjを支配していたら
{
dominatedBy[i].Add(j); //解iを支配するすべての解のインデックスを保存する
}
else if (Dominates(solutions[i], solutions[j])) //もしiがjに支配されていたら
{
numDominated[i]++; //解iが支配される他の解の数を保存します。
}
}
} //被支配数がゼロならランク1(,多分ランクが高い=値が小さい、ランクが低い:値が大きい)ランク1が最高
if (numDominated[i] == 0)
{
frontLevels[i] = 1;
}
} for (int i = 0; i < N; i++)
{
foreach (int j in dominatedBy[i]) //解iが支配するすべての解jに対してループを回す
{
numDominated[j]--; //解jが支配されている数を1つ減少させます。これは、解iが解jを支配することが確認されたためです。
if (numDominated[j] == 0)
{
frontLevels[j] = frontLevels[i] + 1;
}
}
} return frontLevels;
}
///混雑度の計算
public static List<double> CalculateCrowdingDistance(List<Sample_Solution> front)
{
int size = front.Count; //フロントの解の数 // 各解の混雑度を保存するためのリストを初期化
List<double> crowdingDistance = new List<double>(new double[size]); // フロントに解がない場合
if (size == 0)
return crowdingDistance; // フロントに1つの解のみがある場合混雑度は無限大
if (size == 1)
{
crowdingDistance[0] = double.PositiveInfinity; //正の無限大
return crowdingDistance;
} // フロントに2つの解がある場合、混雑度は無限大
if (size == 2)
{
crowdingDistance[0] = double.PositiveInfinity;
crowdingDistance[1] = double.PositiveInfinity;
return crowdingDistance;
}
//目的関数ごとに解をソート
List<int>[] sortedIndicesByObjective = new List<int>[3]; // 各目的関数ごとに解をソートするためのインデックスを保存
for (int i = 0; i < 3; i++)
{
sortedIndicesByObjective[i] = SortByObjective(front, i); //目的関数ごとに解をソート
//各目的関数においてソートされた解の最初と最後(境界解)の混雑度を無限大に設定
crowdingDistance[sortedIndicesByObjective[i][0]] = double.PositiveInfinity;
crowdingDistance[sortedIndicesByObjective[i][size - 1]] = double.PositiveInfinity;
}
// 中間の解の混雑度を計算、sizeは配列のサイズ
for (int i = 1; i < size - 1; i++)
{
double distance = 0.0; //目的関数ごとの混雑度の計算
for (int j = 0; j < 3; j++)
{
double objectiveDiff = GetObjectiveValue(front[sortedIndicesByObjective[j][i + 1]], j) -
GetObjectiveValue(front[sortedIndicesByObjective[j][i - 1]], j);
distance += objectiveDiff; //目的関数ごとの距離
}
crowdingDistance[sortedIndicesByObjective[0][i]] = distance; //混雑度
} return crowdingDistance;
}
// 指定された目的関数に基づいてフロントの解をソートする関数
private static List<int> SortByObjective(List<Sample_Solution> front, int objectiveIndex)
{
List<int> indices = new List<int>(front.Count);
for (int i = 0; i < front.Count; i++)
{
indices.Add(i);
} //ラムダ式を使ったソート,(a, b) => a.CompareTo(b)で並べ替えることができる
indices.Sort((x, y) => GetObjectiveValue(front[x], objectiveIndex).CompareTo(GetObjectiveValue(front[y], objectiveIndex))); return indices;
} // 解と目的関数のインデックスを受け取り、その目的関数の値を返す関数
private static double GetObjectiveValue(Sample_Solution solution, int objectiveIndex)
{
if (objectiveIndex < 0 || objectiveIndex >= solution.objectives.Length)
{
throw new ArgumentOutOfRangeException("Invalid objective index.");
} return solution.objectives[objectiveIndex];
} private static bool Dominates(Sample_Solution a, Sample_Solution b)
{
bool betterInAnyObjective = false; // 全ての目的関数(objectives)において非支配の関係を調べる
bool isBetterOrEqualInAllObjectives = true;
for (int i = 0; i < a.objectives.Length; i++)
{
if (a.objectives[i] > b.objectives[i])
{
isBetterOrEqualInAllObjectives = false;
break;
}
} if (isBetterOrEqualInAllObjectives)
{
for (int i = 0; i < a.objectives.Length; i++)
{
if (a.objectives[i] < b.objectives[i])
{
betterInAnyObjective = true;
break;
}
}
} return betterInAnyObjective;
}}
ーーーーーーーーーーーーーーーーーーーー
*1:float)sol.objectives[0], (float)sol.objectives[1], (float)sol.objectives[2]);
GameObject point = Instantiate(pointPrefab, position, Quaternion.identity);
point.GetComponent<Renderer>().material.color = Color.red;
instantiatedPoints.Add(point); // この点をリストに追加
}
else
{
Vector3 position = new Vector3((float)sol.objectives[0], (float)sol.objectives[1], (float)sol.objectives[2]);
GameObject point = Instantiate(pointPrefab, position, Quaternion.identity);
point.GetComponent<Renderer>().material.color = Color.blue;
instantiatedPoints.Add(point); // この点をリストに追加
}
} }
}
ーーーーーーーーーーーーーーーーーーーー
NSGA2.cs
実際にNSGA2の計算を行うクラスです。
ーーーーーーーーーーーーーーーーーーーーーーーーーーーー
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using System.Linq;
public class NSGAII
{
// ポピュレーションの設定値
public static int populationSize;
public static int generations;
public static float mutationProbability;
public static float crossoverProbability;
public static int numberOfVariables;
public static float lowerBound;
public static float upperBound; //public static List<Sample_Solution> archivePopulation = new List<Sample_Solution>(); // アーカイブ母集団を保持するリスト
// コンストラクタ:初期化と設定
public NSGAII(int _populationSize, int _generations, float _mutationProbability, float _crossoverProbability, int _numberOfVariables, float _lowerBound, float _upperBound)
{
populationSize = _populationSize;
generations = _generations;
mutationProbability = _mutationProbability;
crossoverProbability = _crossoverProbability;
numberOfVariables = _numberOfVariables;
lowerBound = _lowerBound;
upperBound = _upperBound;
} // 進化の主要なルーチン
public void Evolve()
{
List<Sample_Solution> population = InitializePopulation(); // 初期のポピュレーションを生成
Evaluate(population); // ポピュレーションの評価
NSGA2_Run.archivePopulation = ExtractNewArchive(population); // 初期のアーカイブ母集団を初期化 //進化開始
for (int i = 0; i < generations; i++)
{ List<Sample_Solution> offspringPopulation = NextGeneration(population); //オフスプリングは次の世代の子母集団
List<Sample_Solution> combinedPopulation = NSGA2_Run.archivePopulation.Concat(offspringPopulation).ToList(); // アーカイブ母集団と小母集団を結合 //全母集団を非優劣ソート、混雑度ソートによって選別し、アーカイブ母集団を作成する
List<Sample_Solution> newArchive = ExtractNewArchive(combinedPopulation); //アーカイブ母集団の完成
NSGA2_Run.archivePopulation = newArchive; //子母集団の生成
population = NSGA2_Run.archivePopulation;
}
} // 初期のポピュレーションを生成
public static List<Sample_Solution> InitializePopulation()
{
List<Sample_Solution>population2 = new List<Sample_Solution>(); for (int i = 0; i < populationSize; i++)
{
Sample_Solution sol = new Sample_Solution(numberOfVariables);
for (int j = 0; j < numberOfVariables; j++)
{
sol.variables[j] = UnityEngine.Random.Range(lowerBound, upperBound); // 変数をランダムに初期化
}
population2.Add(sol);
}
return population2;
} // ポピュレーションの評価
public static void Evaluate(List<Sample_Solution> population)
{
foreach (Sample_Solution sol in population)
{
sol.objectives = DTLZ2(sol.variables); // DTLZ2関数を用いた評価
}
}
// DTLZ2関数の実装
public static float DTLZ2(float x)
{
int M = 3; // 目的関数の数
int k = x.Length - M + 1; // k parameter (usually k=10 for DTLZ2) float g = 0.0f; // gの計算 (最後のk変数に基づく)
for (int i = 1; i < M; i++)
{
g += (x[i] - 0.5f) * (x[i] - 0.5f);
} //for (int i = M; i < M + k; i++)
//{
// g += (x[i - 1] - 0.5f) * (x[i - 1] - 0.5f); // i-1 because array index starts from 0
//} float[] objectives = new float[M]; // 目的関数 f1
objectives[0] = (1.0f + g) * Mathf.Cos(x[0] * Mathf.PI / 2) * Mathf.Cos(x[1] * Mathf.PI / 2); // 目的関数 f2
objectives[1] = (1.0f + g) * Mathf.Cos(x[0] * Mathf.PI / 2) * Mathf.Sin(x[1] * Mathf.PI / 2); // 目的関数 f3
objectives[2] = (1.0f + g) * Mathf.Sin(x[0] * Mathf.PI / 2); return objectives;
}
// 次の世代のポピュレーションを生成
public static List<Sample_Solution> NextGeneration(List<Sample_Solution> population)
{
//次の世代のPopulationを保持するやつ
List<Sample_Solution> offspringPopulationNext = new List<Sample_Solution>(); //人口のサイズになるまで
while (offspringPopulationNext.Count < populationSize)
{
// 1. トーナメント選択を用いてペアの親を選択
Sample_Solution parent1 = TournamentSelect(population);
Sample_Solution parent2 = TournamentSelect(population); // 2. 交叉を用いて2つの子供を生成
if (UnityEngine.Random.value < crossoverProbability) //交叉するかしないか
{
List<Sample_Solution> children = Crossover(parent1, parent2);
offspringPopulationNext.AddRange(children);
}
else
{
offspringPopulationNext.Add(parent1);
offspringPopulationNext.Add(parent2);
}
// 3.すべての子に対して変異率に応じて突然変異を適用して子供の遺伝子を変更
foreach (var child in offspringPopulationNext)
{
Mutate(child);
} } return offspringPopulationNext;
}
//トーナメント選択 5個を選択し、それからも最も良いものを選ぶもの選ぶ
private static Sample_Solution TournamentSelect(List<Sample_Solution> population)
{
int tournamentSize = 5; // トーナメントのサイズ
List<Sample_Solution> tournament = new List<Sample_Solution>(); for (int i = 0; i < tournamentSize; i++)
{
int randomIndex = UnityEngine.Random.Range(0, population.Count);
tournament.Add(population[randomIndex]);
} // 最良の解を返す(ここでは目的関数の合計値が最も低いものを最良とします)
return tournament.OrderBy(sol => sol.objectives.Sum(
*2:value, index) => new { Value = value, Index = index })
.OrderByDescending(item => item.Value)
.Select(item => item.Index)
.ToList(); for (int i = 0; i < sortedIndices.Count && newArchivePopulation.Count < populationSize; i++)
{
newArchivePopulation.Add(currentFront[sortedIndices[i]]);
}
}
} // newArchivePopulation のランクを再計算
List<int> newArchiveRanks = NonDominatedSorting.Sort(newArchivePopulation);
for (int i = 0; i < newArchivePopulation.Count; i++)
{
newArchivePopulation[i].Rank = newArchiveRanks[i];
}
return newArchivePopulation;
}}
ーーーーーーーーーーーーーーーーーーーーーーーーーー
Sample_Solution.cs
解の個体の情報を持つクラスです。
ーーーーーーーーーーーーーーーーーーー
using System.Collections;
using System.Collections.Generic;
using UnityEngine;public class Sample_Solution
{
//非優劣ソートにおけるランク
public int Rank { get; set; } // numberOfVariablesに基づいて変数のサイズを初期化するための新しいプロパティとコンストラクタ
public float variables;
public float objectives; public Sample_Solution(int numberOfVars)
{
variables = new float[numberOfVars];
objectives = new float[3]; // 3つの目的関数を持っていると仮定しています
} // 他の追加の属性やメソッドが必要な場合はこちらに追加してください。
}
ーーーーーーーーーーーーーーーーーーー
NonDominatedSorting
非支配ソートと混雑度ソートを行うクラスです。
ーーーーーーーーーーーーーーーーーーーーーーーーーーーー
//「非支配ソート」・「混雑度の計算」を行うクラスusing System.Collections;
using System.Collections.Generic;
using UnityEngine;
using System;
using System.Linq;public static class NonDominatedSorting
{ public static List<int> Sort(List<Sample_Solution> solutions) //各要素の値を取得
{
int N = solutions.Count; //候補解の数をカウント、
List<int> frontLevels = new List<int>(new int[N]); //支配ソートレベル
List<int> dominatedBy = new List<int>[N]; //解iを支配するすべての解のインデックスを保存します。
int numDominated = new int[N]; //解iを支配する他の解の数 //候補解の数ごとにループ
for (int i = 0; i < N; i++)
{
dominatedBy[i] = new List<int>();
//要素ごとに2周目のループ
for (int j = 0; j < N; j++)
{
if (i != j)
{
// iがjに支配されるかどうかを確認します。
if (Dominates(solutions[j], solutions[i]