次の方法で共有


テストの実行

機械学習向けの L1および L2 正規化

James McCaffrey

コード サンプルのダウンロード

James McCaffreyL1 正規化と L2 正規化は密接に関連する手法です。この手法を機械学習 (ML) トレーニング アルゴリズムで使用すると、モデルのオーバーフィット (過剰適合) を減らすことができます。オーバーフィットがなくなると、モデルの予測精度が高まります。今回は、この正規化をソフトウェア開発者の視点から取り上げます。正規化の背景にある考え方自体は難しくありませんが、複数の考え方が相互に関連してくるため、その説明はやや複雑になります。

今回は、ロジスティック回帰 (LR) 分類での正規化を取り上げますが、正規化は多くの種類の機械学習 (特にニューラル ネットワーク分類) で使用されます。LR 分類は、変数が可能性のある 2 つの値のうちいずれになるかを予測するモデルを作成するのが目標です。たとえば、あるサッカー チームの現在の勝率 (x1)、試合会場 (x2)、および怪我で出場できない選手の数 (x3) を基に、そのチームの次回の試合結果 (敗け = 0、勝ち = 1) を予測するような場合です。

Y を予測値とすると、この問題の LR モデルは以下の式で表されます。

z = b0 + b1(x1) + b2(x2) + b3(x3)
Y = 1.0 / (1.0 + e^-z)

b0、b1、b2、b3 は重みです。このような重みは決めておくべき数値にすぎません。言葉で説明すると、各入力値に対応する重み b を掛けて合算し、その合計に定数 b0 を加えた値を z とします。この z を自然対数の底 e を使用する方程式に渡します。その結果、Y は常に 0 ~ 1 の値になります。Y が 0.5 未満の場合は予測結果を 0 とし、Y が 0.5 以上の場合は予測結果を 1 とします。フィーチャーが n 個あれば、重み b は n+1 個になります。

たとえば、チームの現在の勝率が 0.75 で、相手チームのホームゲーム (-1) で、怪我で出場できない選手が 3 人いるとします。また、各重みは b0 = 5.0、b1 = 8.0、b2 = 3.0、b3 = -2.0 だとします。この場合は、z = 5.0 + (8.0)(0.75) + (3.0)(-1) + (-2.0)(3) = 2.0、Y = 1.0 / (1.0 + e^-2.0) = 0.88 となります。Y は 0.5 以上なので、チームは次回の試合で勝利すると予測されます。

正規化について理解する最適な方法は、具体例で考えてみることです。図 1 のデモ プログラムのスクリーンショットをご覧ください。実際のデータを使用するのではなく、デモ プログラムではまず、1,000 件の合成データ項目を生成します。各項目は 12 個の予測変数 (多くの場合、ML 用語で "フィーチャー" と呼ばれます) があります。従属変数の値は最後の列に含まれます。1,000 件のデータ項目を作成後、このデータ セットを無作為に分割し、800 件のトレーニング セットと 200 件のテスト セットに分けます。トレーニング セットはモデルの重み b を見つけるのに使用します。テスト セットは求めたモデルの品質評価に使用します。

ロジスティック回帰分類での正規化
図 1 ロジスティック回帰分類での正規化

次に、デモ プログラムは、正規化を使用しないで LR 分類のトレーニングを行います。求めたモデルの正確性は、トレーニング データで 85.00%、テスト データで 80.50% になりました。80.50% の正確性の方がより現実的な値で、新しいデータが提示された場合に期待できるモデルの正確性の大まかな予測値になります。後ほど説明しますが、モデルはオーバーフィットしているため、予測の正確性はあまり高くありません。

次に、デモは、L1 正規化と L2 正規化の適切な重みを求めています。正規化の重みは、正規化プロセスで使用される 1 つの数値です。デモは、L1 の適切な重みを 0.005、L2 の適切な重みを 0.001 と判断しています。

続いて、L1 正規化を使用してトレーニングを行ってから、L2 正規化を使用して再度トレーニングを行います。L1 正規化によって求めた LR モデルの正確性はテスト データで 95.00%、L2 正規化では 94.50% になりました。両方の正規化によって予測の正確性が大幅に向上しています。

今回は、少なくとも中級レベルのプログラミング スキルがあることを前提としますが、L1 正規化または L2 正規化の知識は問いません。デモ プログラムは C# を使用してコーディングしていますが、コードを JavaScript または Python などの別の言語にリファクタリングしても大きな問題は起きません。

デモ コードは長すぎてコラムにすべて掲載することはできませんが、完全なソース コードは、このコラム付属のコード ダウンロードから入手できます。また、コードを小さく抑え、中心となる考え方を可能な限り明瞭にするため、通常のエラー チェックはすべて削除しています。

プログラムの全体構造

スペースを節約するために少し編集したプログラムの全体構造を図 2 に示します。デモを作成するには、Visual Studio を起動して、Regularization という名前で新しい C# コンソール アプリケーションを作成します。デモは、Microsoft .NET Framework との大きな依存関係はないので、新しいバージョンの Visual Studio であれば動作します。

図 2 プログラムの全体構造

using System;
namespace Regularization
{
  class RegularizationProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine("Begin L1 and L2 Regularization demo");
      int numFeatures = 12;
      int numRows = 1000;
      int seed = 42;
      Console.WriteLine("Generating " + numRows +
        " artificial data items with " + numFeatures + " features");
      double[][] allData = MakeAllData(numFeatures, numRows, seed);
      Console.WriteLine("Creating train and test matrices");
      double[][] trainData;
      double[][] testData;
      MakeTrainTest(allData, 0, out trainData, out testData);
      Console.WriteLine("Training data: ");
      ShowData(trainData, 4, 2, true);
      Console.WriteLine("Test data: ");
      ShowData(testData, 3, 2, true);
      Console.WriteLine("Creating LR binary classifier");
      LogisticClassifier lc = new LogisticClassifier(numFeatures);
      int maxEpochs = 1000;
      Console.WriteLine("Starting training using no regularization");
      double[] weights = lc.Train(trainData, maxEpochs,
        seed, 0.0, 0.0);
      Console.WriteLine("Best weights found:");
      ShowVector(weights, 3, weights.Length, true);
      double trainAccuracy = lc.Accuracy(trainData, weights);
      Console.WriteLine("Prediction accuracy on training data = " +
        trainAccuracy.ToString("F4"));
      double testAccuracy = lc.Accuracy(testData, weights);
      Console.WriteLine("Prediction accuracy on test data = " +
        testAccuracy.ToString("F4"));
      Console.WriteLine("Seeking good L1 weight");
      double alpha1 = lc.FindGoodL1Weight(trainData, seed);
      Console.WriteLine("L1 weight = " + alpha1.ToString("F3"));
      Console.WriteLine("Seeking good L2 weight");
      double alpha2 = lc.FindGoodL2Weight(trainData, seed);
      Console.WriteLine("L2 weight = " + alpha2.ToString("F3"));
      Console.WriteLine("Training with L1 regularization, " +
        "alpha1 = " + alpha1.ToString("F3"));
      weights = lc.Train(trainData, maxEpochs, seed, alpha1, 0.0);
      Console.WriteLine("Best weights found:");
      ShowVector(weights, 3, weights.Length, true);
      trainAccuracy = lc.Accuracy(trainData, weights);
      Console.WriteLine("Prediction accuracy on training data = " +
        trainAccuracy.ToString("F4"));
      testAccuracy = lc.Accuracy(testData, weights);
      Console.WriteLine("Prediction accuracy on test data = " +
        testAccuracy.ToString("F4"));
      Console.WriteLine("Training with L2 regularization, " +
        "alpha2 = " + alpha2.ToString("F3"));
      weights = lc.Train(trainData, maxEpochs, seed, 0.0, alpha2);
      Console.WriteLine("Best weights found:");
      ShowVector(weights, 3, weights.Length, true);
      trainAccuracy = lc.Accuracy(trainData, weights);
      Console.WriteLine("Prediction accuracy on training data = " +
        trainAccuracy.ToString("F4"));
      testAccuracy = lc.Accuracy(testData, weights);
      Console.WriteLine("Prediction accuracy on test data = " +
        testAccuracy.ToString("F4"));
      Console.WriteLine("End Regularization demo");
      Console.ReadLine();
    }
    static double[][] MakeAllData(int numFeatures,
      int numRows, int seed) { . . }
    static void MakeTrainTest(double[][] allData, int seed,
      out double[][] trainData, out double[][] testData) { . . }
    public static void ShowData(double[][] data, int numRows,
      int decimals, bool indices) { . . }
    public static void ShowVector(double[] vector, int decimals,
      int lineLen, bool newLine) { . . }
  }
  public class LogisticClassifier
  {
    private int numFeatures;
    private double[] weights;
    private Random rnd;
    public LogisticClassifier(int numFeatures) { . . }
    public double FindGoodL1Weight(double[][] trainData,
      int seed) { . . }
    public double FindGoodL2Weight(double[][] trainData,
      int seed) { . . }
    public double[] Train(double[][] trainData, int maxEpochs,
      int seed, double alpha1, double alpha2) { . . }
    private void Shuffle(int[] sequence) { . . }
    public double Error(double[][] trainData, double[] weights,
      double alpha1, double alpha2) { . . }
    public double ComputeOutput(double[] dataItem,
      double[] weights) { . . }
    public int ComputeDependent(double[] dataItem,
      double[] weights) { . . }
    public double Accuracy(double[][] trainData,
      double[] weights) { . . }
    public class Particle { . . }
  }
} // ns

テンプレート コードが Visual Studio エディターに読み込まれたら、内容がよくわかるようにソリューション エクスプローラー ウィンドウで Program.cs の名前を「RegularizationProgram.cs」に変更します。これにより、Visual Studio が自動的に Program クラスの名前を変更します。ソース コードの先頭で、不要な名前空間を指定する using ステートメントをすべて削除し、トップ レベルの System 名前空間への参照のみを残します。

ロジスティック回帰のロジックはすべて 1 つの LogisticClassifier クラスに含めています。LogisticClassifier クラスには、トレーニングに使用する最適化アルゴリズムの粒子群最適化 (PSO) をカプセル化するために、入れ子になった Particle ヘルパー クラスがあります。LogisticClassifier クラスには、alpha1 および alpha2 というパラメーターを受け取る Error メソッドを含めています。この 1 つのパラメーターが L1 正規化と L2 正規化用の重みです。

Main メソッドでは、次のステートメントを使用して合成データを作成します。

int numFeatures = 12;
int numRows = 1000;
int seed = 42;
double[][] allData = MakeAllData(numFeatures, numRows, seed);

シード値に 42 を使用したのは、単に代表的なデモの出力として優れた値が得られるという理由からです。MakeAllData メソッドは、-10.0 ~ +10.0 の間の 13 個のランダムな重み (フィーチャーごとに重みが 1 つずつと、b0 の重み) を生成します。その後、このメソッドは 1,000 回繰り返されます。この繰り返し処理では毎回、12 個のランダムな入力値のセットを生成後、このランダムな重みを使用してロジスティック回帰の中間出力値を計算します。別のランダム値を出力に加算して意味をなさないデータにし、オーバーフィットの可能性を高めます。

以下のステートメントでは、データをトレーニング用の 800 項目と、モデル評価用の 200 項目に分割しています。

double[][] trainData;
double[][] testData;
MakeTrainTest(allData, 0, out trainData, out testData);

ロジスティック回帰の予測モデルは、以下のステートメントで作成しています。

LogisticClassifier lc = new LogisticClassifier(numFeatures);
int maxEpochs = 1000;
double[] weights = lc.Train(trainData, maxEpochs, seed, 0.0, 0.0);
ShowVector(weights, 4, weights.Length, true);

変数 maxEpochs は、PSO トレーニング アルゴリズムのループ カウンターの上限です。Train メソッドに引数として渡している 2 つの 0.0 は、L1 正規化の重みと L2 正規化の重みです。このように重みを 0.0 に設定すると、正規化が使用されなくなります。モデルの品質は、以下の 2 つのステートメントで評価しています。

double trainAccuracy = lc.Accuracy(trainData, weights);
double testAccuracy = lc.Accuracy(testData, weights);

正規化を使用する場合に面倒なことの 1 つが、正規化の重みを決めなければならないことです。正規化の適切な重みを見つけるアプローチの 1 つは手作業で試行錯誤を繰り返すことですが、通常はプログラムを使用します。以下のステートメントで、適切な L1 正規化の重みを見つけ、モデルの品質を評価しています。

double alpha1 = lc.FindGoodL1Weight(trainData, seed);
weights = lc.Train(trainData, maxEpochs, seed, alpha1, 0.0);
trainAccuracy = lc.Accuracy(trainData, weights);
testAccuracy = lc.Accuracy(testData, weights);

L2 正規化を使用して LR 分類をトレーニングするステートメントは、L1 正規化を使用して LR 分類をトレーニングするステートメントと同様です。

double alpha2 = lc.FindGoodL2Weight(trainData, seed);
weights = lc.Train(trainData, maxEpochs, seed, 0.0, alpha2);
trainAccuracy = lc.Accuracy(trainData, weights);
testAccuracy = lc.Accuracy(testData, weights);

デモでは、LR オブジェクトのパブリック スコープ メソッド (FindGoodL1Weight と FindGoodL2Weight) を使用して、alpha1 と alpha2 の値を決定後、Train メソッドに渡しています。別の設計として、以下のコードを呼び出す方法もあります。

bool useL1 = true;
bool useL2 = false:
lc.Train(traiData, maxEpochs, useL1, useL2);

この設計アプローチでは、トレーニング メソッドが正規化の重みを決定するため、インターフェイスが若干わかりやすくなります。

正規化について

L1 正規化と L2 正規化は、モデルのオーバーフィットを減らす手法なので、正規化を理解するには、まずオーバーフィットについて理解する必要があります。大ざっぱに言うと、モデルを過剰にトレーニングすると、最終的にはトレーニング データに極めて正確にフィット (適合) する重みが得られますが、このトレーニング後のモデルを新しいデータに適用すると、予測精度が大幅に低下します。

オーバーフィットについて、図 3 の 2 つのグラフを使って説明します。最初のグラフは、赤の点と緑の点で示される 2 種類の項目を分類することを目標とする、仮定の状況を表しています。滑らかな青い曲線は 2 つのクラスの真の分割を表し、分割曲線の上に赤の点、下に緑の点が来ることを想定しています。データにはランダムな誤差が含まれるため、赤の点のうち 2 つが分割曲線の下に位置し、緑の点のうち 2 つが上に位置しているのがわかります。オーバーフィットが発生しない優れたトレーニングでは、滑らかな青い曲線に対応する重みが求められます。新しいデータ点として (3, 7) が与えられるとします。このデータ項目は曲線の上に配置されることになり、赤のクラスになると正しく予測されます。

モデルのオーバーフィット
図 3 モデルのオーバーフィット

図 3 の 2 つ目のグラフにも同じ点が含まれていますが、オーバーフィットが生じた結果、青い曲線が変化しています。今度は、赤い点はすべて曲線の上に位置し、緑の点はすべて曲線の下に位置するようになっています。ただし、曲線はかなり複雑な形状になります。今度は新しいデータ項目 (3, 7) が曲線の下に位置することになり、緑のクラスだと間違って予測されます。

オーバーフィットによって生じる予測曲線は滑らかではなくなります。つまり、「正規化」されません。このように不適切で複雑な予測曲線の特徴は、通常、重みが非常に大きな値になるか非常に小さな値になるかのいずれかです。したがって、オーバーフィットを減らす 1 つの方法は、モデルの重みが小さくなりすぎたり、大きくなりすぎないようにすることです。これが正規化の動機付けになります。

ML モデルをトレーニングする場合、誤差をなんらかの方法で測定し、適切な重みを決める必要があります。誤差を測定する方法はいくつかあります。最もよく使われる手法の 1 つが平均 2 乗誤差です。この手法では、重みの値のセットに対して計算した出力値とトレーニング データの既知の正確な出力値との差を 2 乗して合計し、その結果をトレーニング項目の数で除算します。たとえば、以下の 3 つのトレーニング項目の計算出力と正しい出力値 (期待値や目標値と呼ばれることもあります) を使って、ロジスティック回帰の重みの候補セットを求めるとします。

computed  desired
  0.60      1.0
  0.30      0.0
  0.80      1.0

次のように平均 2 乗誤差を求めます。

((0.6 - 1.0)^2 + (0.3 - 0.0)^2 + (0.8 - 1.0)^2) / 3 =
(0.16 + 0.09 + 0.04) / 3 =
0.097

数式で表現すると、平均 2 乗誤差は以下のようになります。

E = Sum(o - t)^2 / n

ここで Sum はすべてのトレーニング項目の累積合計、o は計算出力、t は目標出力、n はトレーニング データの項目数を表します。誤差は、約 12 個ある数値手法 (勾配降下、反復的なニュートン ラフソン、L-BFGS、逆伝播、群最適化など) の 1 つを使ったトレーニングによって最小化されます。

モデルの重み値が大きくならないように、正規化では、重み値を誤差項の計算に加算することで重み値にペナルティを課す、という考え方を取り入れています。最小化される誤差項の合計に重み値を含めると、重み値が小さくなるほど誤差値が小さくなります。L1 正規化の重みは、重み値の絶対値の合計を誤差項に加算することで、ペナルティを課します。数式で表すと以下のようになります。

E = Sum(o - t)^2 / n + Sum(Abs(w))

L2 正規化の重みは、2 乗した重み値の合計を誤差項に加算することで、ペナルティを課します。数式で表すと以下のようになります。

E = Sum(o - t)^2 / n + Sum(w^2)

たとえば、4 つの重みを決める必要があり、現在値が (2.0、-3.0、1.0、-4.0) だとします。平均 2 乗誤差 0.097 に加算される L1 正規化の重みのペナルティは、(2.0 + 3.0 + 1.0 + 4.0) = 10.0 です。L2 正規化の重みのペナルティは、2.0^2 + -3.0^2 + 1.0^2 + -4.0^2 = 4.0 + 9.0 + 1.0 + 16.0 = 30.0 です。

ここまでの説明をまとめると、モデルの重みが大きくなるほど、オーバーフィットが生じ、不適切な予測精度になる可能性があります。正規化は、モデルの誤差関数に重みのペナルティを加算することで、モデルの重みが大きくならないようにします。L1 正規化では、重みの絶対値の合計を使用されます。L2 正規化では、重みの値の 2 乗の合計が使用されます。

2 種類の正規化を使用する理由

L1 正規化と L2 正規化は似ています。どちらが優れているでしょう。特定の問題のシナリオでどちらの正規化が適切かを示した理論上のガイドラインはありますが、個人的には、どちらの正規化が適切かを実際に実験して判断するか、とにかく使用してみることをお勧めします。

L1 正規化を使用する場合、1 つ以上の重みを 0.0 にすることにより対応するフィーチャーを事実上除外できるという二次的な効果があります。これは、フィーチャーの選択という 1 つの形式です。たとえば、図 1 で実行しているデモは、L1 正規化を使用し、最後のモデルの重みを 0.0 にしています。つまり、最後の予測値は LR モデルに貢献しません。L2 正規化は、モデルの重み値を制限しますが、通常重みを 0.0 にしてもその重みが完全に取り除かれることはありません。

そのため、L1 正規化の方が L2 正規化よりも優れているように思えます。ただし、L1 正規化には、一部の ML トレーニング アルゴリズム (特に、微積分を使用して勾配を計算するアルゴリズム) では簡単に使用できないという問題点があります。L2 正規化は、すべての種類のトレーニング アルゴリズムで使用できます。

まとめると、L1 正規化は、関連する重みを 0.0 にして不要なフィーチャーを取り除くことができますが、機能しないトレーニングもあります。L2 正規化はすべてのトレーニングで機能しますが、暗黙のうちにフィーチャーを選択することはできません。実際のところ、特定の問題でどちらの形式の正規化が適切か (または両方とも適切でないか) を判断するには試行錯誤が必要になります。

正規化の実装

L1 正規化と L2 正規化を実装のは比較的簡単です。デモ プログラムでは、明示的な誤差関数と PSO トレーニングを併用するため、必要なのは L1 と L2 の重みのペナルティを加えることでだけです。Error メソッドの定義の先頭部分を以下に示します。

public double Error(double[][] trainData, double[] weights,
  double alpha1, double alpha2)
{
  int yIndex = trainData[0].Length - 1;
  double sumSquaredError = 0.0;
  for (int i = 0; i < trainData.Length; ++i)
  {
    double computed = ComputeOutput(trainData[i], weights);
    double desired = trainData[i][yIndex];
    sumSquaredError += (computed - desired) * (computed - desired);
  }
...

最初の手順では、計算出力値と目標出力の差の 2 乗を合計することで、平均 2 乗誤差を計算しています (誤差のもう 1 つの一般的な形式はクロス エントロピー誤差と呼ばれます)。次に、L1 ペナルティを計算しています。

double sumAbsVals = 0.0; // L1 penalty
for (int i = 0; i < weights.Length; ++i)
  sumAbsVals += Math.Abs(weights[i]);

次に、L2 ペナルティを計算しています。

double sumSquaredVals = 0.0; // L2 penalty
for (int i = 0; i < weights.Length; ++i)
  sumSquaredVals += (weights[i] * weights[i]);

Error メソッドは、MSE にペナルティを加算して返します。

...
  return (sumSquaredError / trainData.Length) +
    (alpha1 * sumAbsVals) +
    (alpha2 * sumSquaredVals);
}

デモでは、明示的な誤差関数を使用しています。トレーニング アルゴリズムの中には、勾配降下、逆伝播など、誤差関数の微積分偏導関数 (勾配) を計算することによって、暗黙のうちに誤差関数を使用するものもあります。このようなトレーニング アルゴリズムで L2 正規化を使用する場合は、(w^2 の導関数は 2w になるため)、2w 項を勾配に加算するだけです (詳細はもう少し複雑になります)。

適切な正規化の重みの算出

適切な (最適でなくてもよい) 正規化の重みを算出する方法はいくつかあります。デモ プログラムでは、候補値のセットを用意し、各候補に関連する誤差を計算し、求めた最適候補を返します。L1 の適切な重みを算出するメソッドの冒頭部分は以下のとおりです。

public double FindGoodL1Weight(double[][] trainData, int seed)
{
  double result = 0.0;
  double bestErr = double.MaxValue;
  double currErr = double.MaxValue;
  double[] candidates = new double[] { 0.000, 0.001, 0.005,
    0.010, 0.020, 0.050, 0.100, 0.150 };
  int maxEpochs = 1000;
  LogisticClassifier c =
    new LogisticClassifier(this.numFeatures);

別の候補を加算することで処理に時間がかかるようになりますが、正規化の最適な重みを算出する機会が増えます。次に、各候補を評価して、求めた最適候補を返しています。

for (int trial = 0; trial < candidates.Length; ++trial) {
    double alpha1 = candidates[trial];
    double[] wts = c.Train(trainData, maxEpochs, seed, alpha1, 0.0);
    currErr = Error(trainData, wts, 0.0, 0.0);
    if (currErr < bestErr) {
      bestErr = currErr; result = candidates[trial];
    }
  }
  return result;
}

候補となる正規化の重みは、評価分類機能のトレーニングに使用されますが、誤差は正規化の重みなしで計算されます。

まとめ

正規化は、数学方程式に基づくすべての ML 分類手法で使用されています。ロジスティック回帰、プロビット分類、ニューラル ネットワークなどがその例です。モデルの重み値の大きさが小さくなるため、正規化を重みの減衰と呼ぶこともあります。正規化を使用する主なメリットは、多くの場合、より正確なモデルにできることです。主なデメリットは、追加のパラメーター値 (正規化の重み) を決定する必要が生じることです。ロジスティック回帰の場合、通常学習率のパラメーターだけなのでそれほど問題ではありませんが、複雑な分類手法 (特にニューラル ネットワーク) を使用する場合、別のいわゆるハイパーパラメーターを追加することで、複数のパラメーターの組み合わせ値を調整するために膨大な追加作業が発生することがあります。


Dr. James McCaffrey は、ワシントン州レドモンドにある Microsoft Research に勤務しています。これまでに、Internet Explorer、Bing などの複数のマイクロソフト製品にも携わってきました。McCaffrey 博士の連絡先は、jammc@microsoft.com (英語のみ) です。

この記事のレビューに協力してくれた技術スタッフの Richard Hughes (Microsoft Research) に心より感謝いたします。