Coverage for hiphive/input_output/read_write_files.py: 100%

81 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-11-28 11:20 +0000

1""" 

2Helper functions for reading and writing objects to tar files 

3""" 

4 

5import pickle 

6import tempfile 

7import warnings 

8from tarfile import TarFile 

9import ase.io as aseIO 

10from ase import Atoms 

11with warnings.catch_warnings(): 

12 warnings.filterwarnings('ignore', category=FutureWarning) 

13 import h5py 

14 

15 

def add_ase_atoms_to_tarfile(tar_file: TarFile,
                             atoms: Atoms,
                             arcname: str,
                             format: str = 'json'):
    """ Adds an ase.Atoms object to tar_file.

    The configuration is written with ``ase.io.write`` to a named temporary
    file, which is then stored in the archive under ``arcname``.

    Parameters
    ----------
    tar_file
        tar file to write to
    atoms
        atomic configuration
    arcname
        name of field in archive
    format
        format to use for writing (see ase)
    """
    # The context manager guarantees the temporary file is closed (and
    # deleted) even if writing or archiving raises.
    # NOTE(review): ase re-opens the file by path while it is still open here,
    # which works on POSIX but not on Windows.
    with tempfile.NamedTemporaryFile() as temp_file:
        aseIO.write(temp_file.name, atoms, format=format)
        # ase wrote through its own handle on the path; rewind ours so the
        # archive reads the file from the beginning.
        temp_file.seek(0)
        tar_info = tar_file.gettarinfo(arcname=arcname, fileobj=temp_file)
        tar_file.addfile(tar_info, temp_file)

38 

39 

def read_ase_atoms_from_tarfile(tar_file: TarFile,
                                arcname: str,
                                format: str = 'json') -> Atoms:
    """ Reads ase.Atoms from tar file.

    The archive member is copied into a named temporary file, which is then
    parsed with ``ase.io.read``.

    Parameters
    ----------
    tar_file
        tar file from which to read
    arcname
        name of field in archive
    format
        format to use for reading (see ase)

    Returns
    -------
    the atomic configuration returned by ``ase.io.read``

    Raises
    ------
    KeyError
        if ``arcname`` is not a member of ``tar_file``
    """
    # The context manager guarantees the temporary file is closed (and
    # deleted) once ase has parsed it, or if anything raises.
    with tempfile.NamedTemporaryFile() as temp_file:
        temp_file.write(tar_file.extractfile(arcname).read())
        # ase re-opens the file by path, so pending buffered bytes must be
        # pushed to disk first.
        temp_file.flush()
        atoms = aseIO.read(temp_file.name, format=format)
    return atoms

59 

60 

def add_items_to_tarfile_hdf5(tar_file: TarFile, items: dict, arcname: str):
    """ Add items to one hdf5 file stored in tar_file.

    Every key/value pair of ``items`` becomes a gzip-compressed dataset in a
    single hdf5 file, which is stored in the archive under ``arcname``.

    Parameters
    ----------
    tar_file
        tar file to write to
    items
        mapping from dataset name to data accepted by
        ``h5py.File.create_dataset``
    arcname
        name of field in archive
    """
    with tempfile.NamedTemporaryFile() as temp_file:
        # Closing the h5py file (end of the inner block) writes everything to
        # disk before the temporary file is archived; it is also closed if
        # create_dataset raises.
        with h5py.File(temp_file.name, 'w') as hf:
            for key, value in items.items():
                hf.create_dataset(key, data=value, compression='gzip')
        temp_file.seek(0)
        tar_info = tar_file.gettarinfo(arcname=arcname, fileobj=temp_file)
        tar_file.addfile(tar_info, temp_file)

72 

73 

def add_items_to_tarfile_pickle(tar_file: TarFile, items, arcname: str):
    """ Add items to tar_file by pickling them into a single member.

    Parameters
    ----------
    tar_file
        tar file to write to
    items
        any picklable object
    arcname
        name of field in archive
    """
    # The context manager closes the temporary file even if pickling or
    # archiving raises.
    with tempfile.TemporaryFile() as temp_file:
        pickle.dump(items, temp_file)
        # Rewinding also flushes buffered writes, so the size gettarinfo
        # reads via fstat covers the whole pickle.
        temp_file.seek(0)
        tar_info = tar_file.gettarinfo(arcname=arcname, fileobj=temp_file)
        tar_file.addfile(tar_info, temp_file)

82 

83 

def add_items_to_tarfile_custom(tar_file: TarFile, items: dict):
    """ Add items assuming they have a custom write function.

    Each value of ``items`` must provide a ``write(fileobj)`` method. Each
    value is stored as its own archive member, named by its key.

    Parameters
    ----------
    tar_file
        tar file to write to
    items
        mapping from archive member name to an object with ``write``
    """
    for key, value in items.items():
        # One temporary file per item; the context manager closes it even if
        # value.write or archiving raises.
        with tempfile.TemporaryFile() as temp_file:
            value.write(temp_file)
            temp_file.seek(0)
            tar_info = tar_file.gettarinfo(arcname=key, fileobj=temp_file)
            tar_file.addfile(tar_info, temp_file)

94 

def add_list_to_tarfile_custom(tar_file: TarFile, objects, arcname: str):
    """ Add list of objects assuming they have a custom write function.

    Object ``i`` is written with its ``write(fileobj)`` method and stored as
    archive member ``'{arcname}_{i}'``.

    Parameters
    ----------
    tar_file
        tar file to write to
    objects
        iterable of objects providing ``write``
    arcname
        prefix of the archive member names
    """
    for i, obj in enumerate(objects):
        fname = '{}_{}'.format(arcname, i)
        # The context manager closes the temporary file even if obj.write or
        # archiving raises.
        with tempfile.TemporaryFile() as temp_file:
            obj.write(temp_file)
            temp_file.seek(0)
            tar_info = tar_file.gettarinfo(arcname=fname, fileobj=temp_file)
            tar_file.addfile(tar_info, temp_file)

105 

106 

def read_items_hdf5(tar_file: TarFile, arcname: str) -> dict:
    """ Read items from an hdf5 file stored inside tar_file.

    Parameters
    ----------
    tar_file
        tar file from which to read
    arcname
        name of the hdf5 member in the archive

    Returns
    -------
    dict mapping each top-level hdf5 key to the data read with ``value[:]``

    Raises
    ------
    KeyError
        if ``arcname`` is not a member of ``tar_file``
    """
    with tempfile.NamedTemporaryFile() as temp_file:
        temp_file.write(tar_file.extractfile(arcname).read())
        # h5py re-opens the file by path, so pending buffered bytes must be
        # pushed to disk first.
        temp_file.flush()
        # Both handles are closed even if reading a dataset raises.
        with h5py.File(temp_file.name, 'r') as hf:
            items = {key: value[:] for key, value in hf.items()}
    return items

118 

119 

def read_items_pickle(tar_file: TarFile, arcname: str):
    """ Read items that were pickled into member arcname of tar_file.

    Parameters
    ----------
    tar_file
        tar file from which to read
    arcname
        name of the pickled member in the archive

    Returns
    -------
    the unpickled object

    Raises
    ------
    KeyError
        if ``arcname`` is not a member of ``tar_file``
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # data; only read archives from trusted sources.
    with tar_file.extractfile(arcname) as f:
        return pickle.load(f)

124 

125 

def read_list_custom(tar_file: TarFile, arcname: str, read_function,
                     **kwargs) -> list:
    """ Read a list of objects stored as members arcname_0, arcname_1, ...

    Members are read in index order until the first missing index.

    Parameters
    ----------
    tar_file
        tar file from which to read
    arcname
        prefix of the archive member names
    read_function
        callable ``read_function(fileobj, **kwargs)`` returning one object
    kwargs
        extra keyword arguments forwarded to ``read_function``

    Returns
    -------
    list of the objects read (empty if ``arcname_0`` is absent)
    """
    objects = []
    i = 0
    while True:
        fname = '{}_{}'.format(arcname, i)
        try:
            f = tar_file.extractfile(fname)
        except KeyError:
            # No member with this index: the list has ended.
            break
        # Only the member lookup is guarded, so a KeyError raised by
        # read_function itself propagates instead of silently truncating the
        # list. The member is closed even if read_function raises.
        with f:
            objects.append(read_function(f, **kwargs))
        i += 1
    return objects