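Python exercise: loading object-dtype data
Author: 凯鲁嘎吉 - cnblogs (博客园) http://www.cnblogs.com/kailugaji/
Given an npy file, loading it in Python shows that its data type is dtype=object. This post shows how to access and load object-dtype data, convert it to images, and save the images in PNG and GIF format.
The data file pool.npy used here can be downloaded (as a .rar archive) from: https://files-cdn.cnblogs.com/files/kailugaji/pool.rar?t=1681308366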
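To make it clearer where dtype=object comes from, here is a minimal sketch. It assumes the file was produced by calling np.save on a Python dict, which is a guess, since the post does not say how pool.npy was written. The toy dict and the file name toy_pool.npy are hypothetical stand-ins for the real data.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for pool.npy: a small dict of arrays.
pool = {
    "observations": np.zeros((5, 9, 64, 64), dtype=np.uint8),
    "rewards": np.zeros((5, 1), dtype=np.float32),
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_pool.npy")
    # np.save wraps the dict in a 0-d array with dtype=object and pickles it.
    np.save(path, pool)

    # With the default allow_pickle=False, np.load refuses object arrays.
    try:
        np.load(path)
        load_failed = False
    except ValueError:
        load_failed = True

    # Opting in to pickle lets the object array load.
    loaded = np.load(path, allow_pickle=True)

# .item() unwraps the 0-d object array back into the original dict.
restored = loaded.item()
print(loaded.dtype, loaded.shape)
print(sorted(restored.keys()))
```

Because allow_pickle=True unpickles arbitrary Python objects, it should only be used on files from a source you trust.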
1. object_load.py
# -*- coding: utf-8 -*-
# Author: 凯鲁嘎吉 Coral Gajic
# https://www.cnblogs.com/kailugaji/
# Python exercise: loading object-dtype data
# Example: a reinforcement-learning experience replay pool
# Data source: cheetah-run from the DeepMind Control Suite
# In the current state the agent takes a random action, interacts with
# the environment, and receives the next state and a reward.
# 50 interactions give a replay pool of 50 samples: {s, a, s', r, ter},
# i.e. current state, action, next state, reward, and terminal flag.
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from matplotlib import animation
# DMControlEnv("cheetah","run")

def save_frames_as_gif(frames, path, index):
    # Animate the frames on the current figure and write them to a GIF file
    filename = 'gym_' + index + '.gif'
    patch = plt.imshow(frames[0])
    plt.axis('off')
    def animate(i):
        patch.set_data(frames[i])
    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50, repeat=True, repeat_delay=10)
    anim.save(path + filename, writer='pillow', fps=60)
    return anim

num = 32
# The file holds a pickled dict, so recent NumPy versions need allow_pickle=True
dataset = np.load(r'./pool.npy', allow_pickle=True)
print('Data type:', 'dtype =', dataset.dtype)
# dtype=object
pool = dataset.item()  # unwrap the 0-d object array into a dict
observations = pool['observations'] # (50, 9, 64, 64)
print('Number of samples:', len(observations)) # 50
print('Keys in the pool:', pool.keys())
# dict_keys(['observations', 'next_observations', 'actions', 'rewards', 'terminals'])
next_observations = pool['next_observations'] # (50, 9, 64, 64)
terminals = pool['terminals'] # (50, 1)
rewards = pool['rewards'] # (50, 1)
actions = pool['actions'] # (50, 6)
toPIL = transforms.ToPILImage()
frames = []
fig = plt.figure(figsize=(15, 6))
print('First %d samples:' % num)
for j in range(num):
    # Take the first 3 channels and reorder (C, H, W) -> (H, W, C)
    state = observations[j, 0:3, :, :].transpose((1, 2, 0))
    frames.append(state.astype(np.uint8))
    pic = toPIL(state.astype(np.uint8))
    plt.subplot(4, num//4, j+1)
    plt.axis('off')
    plt.imshow(pic)
    print(j,
          '\treward:', np.round(rewards[j], 3),
          '\taction:', np.round(actions[j], 3),
          '\tterminal:', terminals[j])
plt.savefig('cheetah-run.png', bbox_inches='tight', pad_inches=0.0, dpi=500)
plt.show()
save_frames_as_gif(frames, path = './', index = 'cheetah-run')
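The script only plots channels 0:3 of each observation. A plausible reading, which the post does not confirm, is that the 9 channels are three stacked RGB frames. Under that assumption, the sketch below splits one observation into separate images in the (H, W, C) layout that plt.imshow expects. The helper name split_stacked_frames and the random test array are hypothetical.

```python
import numpy as np


def split_stacked_frames(observation, channels_per_frame=3):
    """Split a (C, H, W) observation into a list of (H, W, 3) uint8 images.

    Assumes the channels are consecutive RGB frames (not confirmed by the post).
    """
    num_channels = observation.shape[0]
    if num_channels % channels_per_frame != 0:
        raise ValueError("channel count is not a multiple of channels_per_frame")
    frames = []
    for start in range(0, num_channels, channels_per_frame):
        chw = observation[start:start + channels_per_frame]
        # Move the channel axis last: (C, H, W) -> (H, W, C).
        hwc = np.ascontiguousarray(chw.transpose(1, 2, 0))
        frames.append(hwc.astype(np.uint8))
    return frames


# Synthetic stand-in with the same per-sample shape as in pool.npy: (9, 64, 64).
rng = np.random.default_rng(0)
observation = rng.integers(0, 256, size=(9, 64, 64), dtype=np.uint8)

frames = split_stacked_frames(observation)
print(len(frames), frames[0].shape, frames[0].dtype)
```

The first image in the returned list matches what the script draws for a sample. The other two would show the remaining frames, if the stacking assumption holds.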
2. Results
D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/dict/object_load.py"
Data type: dtype = object
Number of samples: 50
Keys in the pool: dict_keys(['observations', 'next_observations', 'actions', 'rewards', 'terminals'])
First 32 samples:
0 	reward: [0.125] 	action: [-0.807 0.717 -0.953 0.181 -0.283 0.841] 	terminal: [0.]
1 	reward: [0.099] 	action: [-0.449 -0.307 0.473 0.719 0.055 -0.44 ] 	terminal: [0.]
2 	reward: [0.075] 	action: [ 0.985 -0.704 -0.039 -0.867 0.092 -0.714] 	terminal: [0.]
3 	reward: [0.108] 	action: [ 0.128 0.358 -0.66 0.788 -0.447 0.014] 	terminal: [0.]
4 	reward: [0.105] 	action: [-0.871 0.691 0.301 0.521 -0.547 0.144] 	terminal: [0.]
5 	reward: [0.043] 	action: [-0.687 0.79 0.455 0.584 0.179 0.568] 	terminal: [0.]
6 	reward: [0.] 	action: [-0.022 0.306 0.66 0.978 -0.361 -0.869] 	terminal: [0.]
7 	reward: [0.] 	action: [ 0.503 0.017 0.505 -0.649 -0.205 -0.179] 	terminal: [0.]
8 	reward: [0.] 	action: [ 0.993 -0.424 -0.48 -0.127 0.341 0.458] 	terminal: [0.]
9 	reward: [0.] 	action: [ 0.486 0.229 -0.494 -0.417 -0.93 0.258] 	terminal: [0.]
10 	reward: [0.] 	action: [ 0.505 -0.009 -0.047 -0.004 0.64 -0.223] 	terminal: [0.]
11 	reward: [0.] 	action: [ 0.103 0.038 0.757 -0.764 -0.852 0.023] 	terminal: [0.]
12 	reward: [0.] 	action: [-0.385 -0.62 0.126 0.046 0.135 0.871] 	terminal: [0.]
13 	reward: [0.] 	action: [-0.661 -0.92 0.128 0.705 0.841 0.32 ] 	terminal: [0.]
14 	reward: [0.] 	action: [ 0.515 0.011 -0.085 -0.863 0.69 -0.899] 	terminal: [0.]
15 	reward: [0.] 	action: [-0.16 0.08 0.342 -0.675 0.873 0.13 ] 	terminal: [0.]
16 	reward: [0.] 	action: [-0.221 -0.102 0.862 -0.151 0.938 0.122] 	terminal: [0.]
17 	reward: [0.] 	action: [ 0.915 0.735 -0.297 0.357 0.613 0.363] 	terminal: [0.]
18 	reward: [0.] 	action: [ 0.752 -0.251 -0.505 -0.525 0.76 0.026] 	terminal: [0.]
19 	reward: [0.] 	action: [-0.907 0.056 0.108 -0.921 -0.164 -0.508] 	terminal: [0.]
20 	reward: [0.] 	action: [-0.522 -0.065 -0.66 -0.229 0.88 0.583] 	terminal: [0.]
21 	reward: [0.] 	action: [-0.011 -0.137 0.209 0.014 -0.079 0.236] 	terminal: [0.]
22 	reward: [0.] 	action: [-0.663 0.654 -0.068 -0.728 0.537 -0.359] 	terminal: [0.]
23 	reward: [0.] 	action: [-0.602 -0.122 -0.313 -0.798 0.354 -0.558] 	terminal: [0.]
24 	reward: [0.] 	action: [-0.667 0.071 0.508 -0.219 -0.007 0.041] 	terminal: [0.]
25 	reward: [0.] 	action: [ 0.993 0.028 -0.229 0.809 0.502 0.281] 	terminal: [0.]
26 	reward: [0.] 	action: [ 0.335 0.411 -0.902 -0.487 -0.564 0.109] 	terminal: [0.]
27 	reward: [0.] 	action: [-0.509 -0.607 0.294 -0.391 0.997 0.134] 	terminal: [0.]
28 	reward: [0.] 	action: [ 0.312 0.554 0.741 -0.098 -0.257 -0.768] 	terminal: [0.]
29 	reward: [0.] 	action: [-0.855 -0.576 -0.122 -0.714 -0.436 -0.335] 	terminal: [0.]
30 	reward: [0.] 	action: [ 0.797 0.024 -0.432 -0.378 -0.555 0.935] 	terminal: [0.]
31 	reward: [0.] 	action: [ 0.768 0.445 0.59 -0.977 0.51 0.796] 	terminal: [0.]

Process finished with exit code 0
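The log above can also be checked programmatically rather than by eye. The following is only a sketch: summarize_pool is a hypothetical helper, and the toy pool holds just the first eight rewards as they appear (rounded) in the log, with zero terminals as printed.

```python
import numpy as np


def summarize_pool(pool):
    """Return {key: (shape, dtype name)} for every entry of a replay-pool dict."""
    summary = {}
    for key, value in pool.items():
        array = np.asarray(value)
        summary[key] = (array.shape, str(array.dtype))
    return summary


# Toy pool: rounded rewards of samples 0-7 copied from the log, terminals all zero.
toy_pool = {
    "rewards": np.array([[0.125], [0.099], [0.075], [0.108],
                         [0.105], [0.043], [0.0], [0.0]]),
    "terminals": np.zeros((8, 1)),
}

summary = summarize_pool(toy_pool)
nonzero_rewards = int(np.count_nonzero(toy_pool["rewards"]))
episode_ended = bool(toy_pool["terminals"].any())

for key, (shape, dtype) in summary.items():
    print(key, shape, dtype)
print("nonzero rewards:", nonzero_rewards)
print("any terminal flag set:", episode_ended)
```

Applied to the real dict returned by dataset.item(), the same helper would be expected to reproduce the shapes noted in the script's comments, such as (50, 9, 64, 64) for observations.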


3. References
[2] Liu Q, Zhou Q, Yang R, et al. Robust Representation Learning by Clustering with Bisimulation Metrics for Visual Reinforcement Learning with Distractions[C]. AAAI, 2023.