数理最適化セミナーのご案内

2.18 ロジスティック回帰モデル

 本章では,ロジスティック回帰モデルによるパラメータ推定方法を,アヤメの識別問題を例に紹介します.

例題

 同じアヤメ科のIris setosaとIris versicolorに関して,がくの長さ,がくの幅,花弁の長さ,花弁の幅に関する各々50個のデータが存在する.これらの4つを説明変数として,Iris setosaとIris versicolorを判別するようなロジスティック回帰モデルを構築せよ.

 この問題の定式化は次のようになります.

集合
$I$ サンプル集合
$J$ 説明変数集合
 
定数
$d_{i,j}, i \in I, j \in J$ 計測データ
$k_{i}, i \in I$ アヤメの種類(0:Iris setosa,1:Iris versicolor)
 
変数
$a0$

$a_{j}, j \in J$
推定したいパラメータ
${x}_{i}, i\in I$(中間変数)

$\displaystyle \left( x_i = a0 + \sum_j a_j d_{ij} \right)$
 
 
目的関数(最大化)
$\displaystyle \sum_i (k_i x_i - \log(1 + \exp(x_i)))$ 対数尤度

 サンプル集合の要素$i \in I$に対し,ロジスティック関数$L(x_i)$は,以下のような式で表わすことができます.

\[L(x_i) = \frac{\exp (x_i)}{1 + \exp (x_i)}\]

ただし$x_i$は,$\displaystyle x_i = a0 + \sum_j a_j d_{ij}$で表わされるものとします.

 次に,ロジスティック関数を用いた$a0$, $a_j$のパラメータ推定方法の説明を行います.対象データ全てに対する尤もらしさを考え,それが最大となるような$a0$, $a_j$を求めます.即ち,目的関数$f$

\[f = \prod_i L(x_i)^{k_i} (1 - L(x_i))^{1 - k_i}\]

で定義し,これを最大化する問題を考えることにより,$a0$, $a_j$の推定を行います.

 なお,このままでは目的関数の形状が複雑なため,目的関数$f$に対して対数をとった$\tilde{f}$

\begin{align*} \tilde{f} &= \log f \\ &  = \sum_{i} (k_{i} \log L(x_{i}) + (1 - k_{i}) \log (1 - L(x_{i}))) \\ &= \sum_{i} \left( k_{i} \log \frac{\exp (x_{i})}{1 + \exp (x_{i})} + (1 - k_{i}) \log \frac{1}{1 + \exp (x_{i})} \right) \\ &= \sum_{i} \left( k_{i} \log (\exp (x_{i})) + \log \frac{1}{1 + \exp (x_{i})} \right) \\ &= \sum_{i} (k_{i} x_{i} - \log (1 + \exp (x_{i}))) \end{align*}

で目的関数を置きなおしても一般性は失われない性質を利用して,対象となる問題の置き換えを行います.

 以上が定式化の説明となります.

 次に,アヤメの計測データですが,これは[5]の文献にあるデータ(Iris setosa,Iris versicolor,各々50個)を使用します.

 具体的には,以下のような2種類のcsvファイルを用意します.

計測データ, がくの長さ, がくの幅, 花弁の長さ, 花弁の幅
1, 5.1, 3.5, 1.4, 0.2
2, 4.9, 3, 1.4, 0.2
3, 4.7, 3.2, 1.3, 0.2
4, 4.6, 3.1, 1.5, 0.2
5, 5, 3.6, 1.4, 0.2
6, 5.4, 3.9, 1.7, 0.4
7, 4.6, 3.4, 1.4, 0.3
8, 5, 3.4, 1.5, 0.2
9, 4.4, 2.9, 1.4, 0.2
10, 4.9, 3.1, 1.5, 0.1
11, 5.4, 3.7, 1.5, 0.2
12, 4.8, 3.4, 1.6, 0.2
13, 4.8, 3, 1.4, 0.1
14, 4.3, 3, 1.1, 0.1
15, 5.8, 4, 1.2, 0.2
16, 5.7, 4.4, 1.5, 0.4
17, 5.4, 3.9, 1.3, 0.4
18, 5.1, 3.5, 1.4, 0.3
19, 5.7, 3.8, 1.7, 0.3
20, 5.1, 3.8, 1.5, 0.3
21, 5.4, 3.4, 1.7, 0.2
22, 5.1, 3.7, 1.5, 0.4
23, 4.6, 3.6, 1, 0.2
24, 5.1, 3.3, 1.7, 0.5
25, 4.8, 3.4, 1.9, 0.2
26, 5, 3, 1.6, 0.2
27, 5, 3.4, 1.6, 0.4
28, 5.2, 3.5, 1.5, 0.2
29, 5.2, 3.4, 1.4, 0.2
30, 4.7, 3.2, 1.6, 0.2
31, 4.8, 3.1, 1.6, 0.2
32, 5.4, 3.4, 1.5, 0.4
33, 5.2, 4.1, 1.5, 0.1
34, 5.5, 4.2, 1.4, 0.2
35, 4.9, 3.1, 1.5, 0.2
36, 5, 3.2, 1.2, 0.2
37, 5.5, 3.5, 1.3, 0.2
38, 4.9, 3.6, 1.4, 0.1
39, 4.4, 3, 1.3, 0.2
40, 5.1, 3.4, 1.5, 0.2
41, 5, 3.5, 1.3, 0.3
42, 4.5, 2.3, 1.3, 0.3
43, 4.4, 3.2, 1.3, 0.2
44, 5, 3.5, 1.6, 0.6
45, 5.1, 3.8, 1.9, 0.4
46, 4.8, 3, 1.4, 0.3
47, 5.1, 3.8, 1.6, 0.2
48, 4.6, 3.2, 1.4, 0.2
49, 5.3, 3.7, 1.5, 0.2
50, 5, 3.3, 1.4, 0.2
51, 7, 3.2, 4.7, 1.4
52, 6.4, 3.2, 4.5, 1.5
53, 6.9, 3.1, 4.9, 1.5
54, 5.5, 2.3, 4, 1.3
55, 6.5, 2.8, 4.6, 1.5
56, 5.7, 2.8, 4.5, 1.3
57, 6.3, 3.3, 4.7, 1.6
58, 4.9, 2.4, 3.3, 1
59, 6.6, 2.9, 4.6, 1.3
60, 5.2, 2.7, 3.9, 1.4
61, 5, 2, 3.5, 1
62, 5.9, 3, 4.2, 1.5
63, 6, 2.2, 4, 1
64, 6.1, 2.9, 4.7, 1.4
65, 5.6, 2.9, 3.6, 1.3
66, 6.7, 3.1, 4.4, 1.4
67, 5.6, 3, 4.5, 1.5
68, 5.8, 2.7, 4.1, 1
69, 6.2, 2.2, 4.5, 1.5
70, 5.6, 2.5, 3.9, 1.1
71, 5.9, 3.2, 4.8, 1.8
72, 6.1, 2.8, 4, 1.3
73, 6.3, 2.5, 4.9, 1.5
74, 6.1, 2.8, 4.7, 1.2
75, 6.4, 2.9, 4.3, 1.3
76, 6.6, 3, 4.4, 1.4
77, 6.8, 2.8, 4.8, 1.4
78, 6.7, 3, 5, 1.7
79, 6, 2.9, 4.5, 1.5
80, 5.7, 2.6, 3.5, 1
81, 5.5, 2.4, 3.8, 1.1
82, 5.5, 2.4, 3.7, 1
83, 5.8, 2.7, 3.9, 1.2
84, 6, 2.7, 5.1, 1.6
85, 5.4, 3, 4.5, 1.5
86, 6, 3.4, 4.5, 1.6
87, 6.7, 3.1, 4.7, 1.5
88, 6.3, 2.3, 4.4, 1.3
89, 5.6, 3, 4.1, 1.3
90, 5.5, 2.5, 4, 1.3
91, 5.5, 2.6, 4.4, 1.2
92, 6.1, 3, 4.6, 1.4
93, 5.8, 2.6, 4, 1.2
94, 5, 2.3, 3.3, 1
95, 5.6, 2.7, 4.2, 1.3
96, 5.7, 3, 4.2, 1.2
97, 5.7, 2.9, 4.2, 1.3
98, 6.2, 2.9, 4.3, 1.3
99, 5.1, 2.5, 3, 1.1
100, 5.7, 2.8, 4.1, 1.3
i, 種類
1, 1
2, 1
3, 1
4, 1
5, 1
6, 1
7, 1
8, 1
9, 1
10, 1
11, 1
12, 1
13, 1
14, 1
15, 1
16, 1
17, 1
18, 1
19, 1
20, 1
21, 1
22, 1
23, 1
24, 1
25, 1
26, 1
27, 1
28, 1
29, 1
30, 1
31, 1
32, 1
33, 1
34, 1
35, 1
36, 1
37, 1
38, 1
39, 1
40, 1
41, 1
42, 1
43, 1
44, 1
45, 1
46, 1
47, 1
48, 1
49, 1
50, 1
51, 0
52, 0
53, 0
54, 0
55, 0
56, 0
57, 0
58, 0
59, 0
60, 0
61, 0
62, 0
63, 0
64, 0
65, 0
66, 0
67, 0
68, 0
69, 0
70, 0
71, 0
72, 0
73, 0
74, 0
75, 0
76, 0
77, 0
78, 0
79, 0
80, 0
81, 0
82, 0
83, 0
84, 0
85, 0
86, 0
87, 0
88, 0
89, 0
90, 0
91, 0
92, 0
93, 0
94, 0
95, 0
96, 0
97, 0
98, 0
99, 0
100, 0

 以上をもとにC++SIMPLEで記述すると,以下のようになります.

// 集合と添字
Set I(name = "サンプル集合");
Element i(set = I);
Set J(name = "説明変数集合");
Element j(set = J);

// パラメータ
Parameter d(name = "計測データ", index = (i, j));
Parameter k(name = "種類", index = i); // アヤメの種類(0 : Iris setosa, 1 : Iris versicolor)

// 変数
Variable a(name = "a", index = j);
Variable a0(name = "a0");
Variable x(name = "x", index = i); // 中間変数

// 制約条件
x[i] == a0 + sum(a[j] * d[i, j], j);

// 目的関数
Objective f(name = "対数尤度", type = maximize);
f = sum(k[i] * x[i] - log(1 + exp(x[i])), i);

// 求解
solve();

// 結果出力
a0.val.print();
a.val.print();

 このモデルを実行すると,以下のような解が得られます.

a0 = -3.37151
a["がくの長さ"] = 6.43568
a["がくの幅"] = 5.8427
a["花弁の長さ"] = -13.2033
a["花弁の幅"] = -18.566

 

 

上に戻る