• 정책평가는 정책π 가 고정된 상태에서 상태가치함수 V(s)를 계산하는 것입니다.
  • 그리고 이 과정을 V(s) 변화량이 아주 작아질 때까지 반복하는 것입니다.
  • V(s)의 값이 대칭적으로 나와야 하는데.. 왜 다르게 나오는지 모르겠네요. 문제가 있는 것 같은데 나중에 찾으면 고칠게요

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
import matplotlib.pyplot as plt

# 상태가치함수 v_k-1을 입력으로 새로운 상태의 값을 계산하는 함수 
def get_value(v_prev, row, col, pi, gamma=1.0):
    # N, E, S, W
    north_value = pi[0] * ( -1 + gamma * v_prev[max(0, row - 1), col])
    east_value = pi[1] * ( -1 + gamma * v_prev[row, min(v_prev.shape[1] - 1, col + 1)])
    south_value = pi[2] * ( -1 + gamma * v_prev[min(v_prev.shape[0] - 1, row + 1), col])
    west_value = pi[3] * ( -1 + gamma * v_prev[row, max(0, col - 1)])
    # 밸류의 합 계산
    value = north_value + south_value + east_value + west_value;
    return value

# 상태가치함수 v_k-1을 입력으로 새로운 상태가치함수 v_k를 만드는 함수
def update_value(v_prev, pi, gamma=1.0):
    v_new = np.zeros_like(v_prev)
    for row in range(v_new.shape[0]):
        for col in range(v_new.shape[1]):
            if row == v_new.shape[0] - 1 and col == v_new.shape[1] - 1:
                pass
            else:
                v_new[row, col] = get_value(v_prev, row, col, pi, gamma)
    return v_new

def plot_value(V, title="Value Function"):
    plt.figure(figsize=(4, 4))
    im = plt.imshow(V, cmap="viridis")
    plt.colorbar(im, fraction=0.046, pad=0.04)

    # 좌표계 정리
    # plt.gca().invert_yaxis()
    plt.xticks(range(V.shape[1]))
    plt.yticks(range(V.shape[0]))
    plt.xlabel("col")
    plt.ylabel("row")
    plt.title(title)

    # 값 숫자로 표시 (디버깅에 매우 유용)
    for row in range(V.shape[0]):
        for col in range(V.shape[1]):
            plt.text(
                col, row,
                f"{V[row, col]:.2f}",
                ha="center", va="center",
                color="white" if V[row, col] < np.mean(V) else "black"
            )

    plt.tight_layout()
    plt.show()

def main():
    # N, E, S, W
    pi = [0.25, 0.25, 0.25, 0.25]
    gamma = 1.0

    # v_0 초기화
    v_0 = np.zeros((4, 4))

    v_prev = v_0
    for k in range(1000):
        v_new = update_value(v_prev, pi, gamma)
        print(f'-------- {k+1:03d} --------')
        print(np.round(v_new, 2))
        if np.sum(np.abs(v_new - v_prev)) < 0.01:
            break
        v_prev = v_new

    plot_value(v_prev)

if __name__ == "__main__":
    main()