1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
| import numpy as np
import matplotlib.pyplot as plt
def get_value(v_prev, row, col, pi, gamma=1.0):
    """Return the new value of cell ``(row, col)`` computed from ``v_prev``.

    Each of the four moves contributes
    ``pi[a] * (-1 + gamma * v_prev[next_row, next_col])``: a reward of -1 per
    step plus the discounted previous value of the cell that is reached.
    A move that would leave the grid has its index clamped to the edge, so
    the agent stays in the same cell.

    Args:
        v_prev: 2-D array holding the previous state values v_{k-1},
            indexed as ``v_prev[row, col]``.
        row: Row index of the cell being updated.
        col: Column index of the cell being updated.
        pi: Action probabilities in the order north, east, south, west.
        gamma: Discount factor applied to the previous value.

    Returns:
        The sum of the four weighted action values.
    """
    last_row = v_prev.shape[0] - 1
    last_col = v_prev.shape[1] - 1

    # Action order: N, E, S, W. max/min clamp the target index at the walls.
    north_value = pi[0] * (-1 + gamma * v_prev[max(0, row - 1), col])
    east_value = pi[1] * (-1 + gamma * v_prev[row, min(last_col, col + 1)])
    south_value = pi[2] * (-1 + gamma * v_prev[min(last_row, row + 1), col])
    west_value = pi[3] * (-1 + gamma * v_prev[row, max(0, col - 1)])

    # Sum of the four action values.
    return north_value + south_value + east_value + west_value
def update_value(v_prev, pi, gamma=1.0):
    """Build the next state-value table v_k from the previous table v_{k-1}.

    Every cell is recomputed with ``get_value`` except the bottom-right cell,
    which is left at 0.

    Args:
        v_prev: 2-D array holding the previous state values v_{k-1}.
        pi: Action probabilities in the order north, east, south, west.
        gamma: Discount factor forwarded to ``get_value``.

    Returns:
        A new array with the same shape as ``v_prev``; ``v_prev`` itself is
        not modified.
    """
    # An integer input would make zeros_like allocate an integer array and
    # silently truncate the fractional values written below, so fall back to
    # float unless the input is already a floating/complex dtype.
    out_dtype = v_prev.dtype if np.issubdtype(v_prev.dtype, np.inexact) else float
    v_new = np.zeros_like(v_prev, dtype=out_dtype)

    n_rows, n_cols = v_new.shape[0], v_new.shape[1]
    skipped_cell = (n_rows - 1, n_cols - 1)

    for row in range(n_rows):
        for col in range(n_cols):
            if (row, col) == skipped_cell:
                # Bottom-right cell keeps its initial value of 0.
                continue
            v_new[row, col] = get_value(v_prev, row, col, pi, gamma)
    return v_new
def plot_value(V, title="Value Function"):
    """Display ``V`` as a heat map with each cell's value printed on top.

    Args:
        V: 2-D array of state values, indexed as ``V[row, col]``.
        title: Title shown above the plot.

    Returns:
        None. Opens a matplotlib figure and shows it with ``plt.show()``.
    """
    n_rows, n_cols = V.shape[0], V.shape[1]

    plt.figure(figsize=(4, 4))
    im = plt.imshow(V, cmap="viridis")
    plt.colorbar(im, fraction=0.046, pad=0.04)

    # Axis setup. The y-axis is not inverted here; call
    # plt.gca().invert_yaxis() to flip the row order.
    plt.xticks(range(n_cols))
    plt.yticks(range(n_rows))
    plt.xlabel("col")
    plt.ylabel("row")
    plt.title(title)

    # Print each value on its cell (very handy when debugging).
    # The mean does not change inside the loop, so compute it once.
    mean_value = np.mean(V)
    for row in range(n_rows):
        for col in range(n_cols):
            cell = V[row, col]
            plt.text(
                col, row,  # x is the column, y is the row
                f"{cell:.2f}",
                ha="center", va="center",
                # White text on below-average cells, black on the rest.
                color="white" if cell < mean_value else "black",
            )

    plt.tight_layout()
    plt.show()
def main():
    """Run iterative policy evaluation on a 4x4 grid and plot the result.

    Starts from an all-zero value table, applies ``update_value`` with a
    uniform policy and ``gamma = 1.0`` up to 1000 times, prints every table
    (rounded to two decimals), and stops early once the summed absolute
    change between two consecutive tables drops below 0.01. The last computed
    table is then passed to ``plot_value``.
    """
    # Action probabilities in the order N, E, S, W.
    pi = [0.25, 0.25, 0.25, 0.25]
    gamma = 1.0

    # Initialise v_0.
    v_prev = np.zeros((4, 4))

    for k in range(1000):
        v_new = update_value(v_prev, pi, gamma)
        print(f'-------- {k+1:03d} --------')
        print(np.round(v_new, 2))

        change = np.sum(np.abs(v_new - v_prev))
        # Adopt the new table before testing for convergence so that the
        # table plotted below is the same one that was printed last.
        v_prev = v_new
        if change < 0.01:
            break

    plot_value(v_prev)
if __name__ == "__main__":
    # Run the demo only when the file is executed as a script.
    main()
|