Coverage for hiphive/input_output/pretty_table_prints.py: 97%
54 statements
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
1from itertools import product
2import numpy as np
def _obj2str(a, none_char='-'):
    """Cast object ``a`` to the string displayed in a table cell.

    ``None`` is shown as ``none_char``. Floats whose string form carries more
    than five fractional digits (typically floating-point noise such as
    ``2.4999999999`` or ``0.30000000000000004``) are shown with exactly five
    decimals. Only the fractional digits of the mantissa are counted, so
    exponent-notation floats keep their exponent form (with the mantissa
    rounded to five decimals) instead of being expanded to fixed point, where
    small values would collapse to ``0.00000``. Everything else goes through
    ``str``.

    Parameters
    ----------
    a
        object to be converted
    none_char
        string used to represent ``None``

    Returns
    -------
    str
        the display string
    """
    if a is None:
        return none_char
    if isinstance(a, float):
        text = str(a)
        # split off a possible exponent, e.g. '1.23456e-07' -> '1.23456'
        mantissa, exp_sep, _ = text.partition('e')
        decimals = mantissa.partition('.')[2]
        if len(decimals) > 5:
            # round away float noise; keep exponent notation if str used it
            return ('{:.5e}' if exp_sep else '{:.5f}').format(a)
        return text
    return str(a)


# element-wise version of _obj2str for whole (object) arrays
_array2str = np.vectorize(_obj2str)
def print_table(matrix: np.ndarray,
                include_sum: bool = False) -> None:
    """Print ``matrix`` as a formatted body/order table.

    Element ``matrix[i][j]`` is interpreted as the value for n-body ``i + 1``
    and order ``j + 2``; ``None`` entries are shown as ``-``.

    Example
    -------
    >> matrix = numpy.array([[None, None], [4.0, 3.0]])
    >> print_table(matrix)
     body/order |  2  |  3
    ------------------------
         1      |  -  |  -
         2      | 4.0 | 3.0

    Parameters
    ----------
    matrix
        matrix to be printed
    include_sum
        whether or not to print the sum along each row and column
    """
    print(table_array_to_string(matrix, include_sum))
def table_array_to_string(matrix: np.ndarray,
                          include_sum: bool = False) -> str:
    """Render a body/order matrix of floats/ints as a table string.

    Parameters
    ----------
    matrix
        matrix to be rendered; entry ``[i][j]`` holds the value for
        n-body ``i + 1`` and order ``j + 2``
    include_sum
        whether or not to add the sum along each row and column

    Returns
    -------
    str
        the multi-line table, ready to be printed
    """
    layout = _generate_table_array(matrix, include_sum)
    return _generate_table_str(_array2str(layout))
def _generate_table_array(table_array: np.ndarray,
                          include_sum: bool = False):
    """Place the raw matrix inside a labelled table frame.

    Entry ``table_array[i, j]`` (n-body ``i + 1``, order ``j + 2``) is copied
    into the frame only where the n-body does not exceed the order; all other
    cells keep the frame's initial ``None``.

    Parameters
    ----------
    table_array
        matrix to be printed; rows index n-body, columns index order
    include_sum
        if True, fill the extra 'sum' column/row with per-row and per-column
        totals, ignoring ``None`` entries

    Returns
    -------
    object-dtype array holding the header labels and the cell values
    """
    nbody_max, n_orders = table_array.shape
    frame = _build_table_frame(order=n_orders + 1, nbody=nbody_max,
                               include_sum=include_sum)

    # only combinations with nbody <= order (and order >= 2) are filled
    for nbody in range(1, nbody_max + 1):
        for order in range(max(nbody, 2), n_orders + 2):
            frame[nbody, order - 1] = table_array[nbody - 1, order - 2]

    if include_sum:
        inner = frame[1:-1, 1:-1]

        def _total(values):
            # None marks an empty cell and contributes nothing
            return sum(v for v in values if v is not None)

        for row_idx, row in enumerate(inner, start=1):
            frame[row_idx, -1] = _total(row)
        for col_idx, col in enumerate(inner.T, start=1):
            frame[-1, col_idx] = _total(col)
        # the grand-total corner is deliberately left blank
        frame[-1, -1] = ''

    return frame
def _generate_table_str(table_array: np.ndarray) -> str:
    """Join a 2D array of cell strings into an aligned, multi-line table.

    Every cell is centred in a column two characters wider than that column's
    longest entry, columns are separated by ``|``, and a dashed rule is
    inserted below the first (header) row.

    Parameters
    ----------
    table_array
        2D array whose elements are strings

    Returns
    -------
    str
        the table rows joined by newlines
    """
    n_rows, n_cols = table_array.shape

    # each column is padded two characters beyond its widest cell
    col_widths = [max(map(len, column)) + 2 for column in table_array.T]
    cell_template = '|'.join('{{:^{}}}'.format(w) for w in col_widths)

    lines = [cell_template.format(*row) for row in table_array]
    if n_rows > 1:
        # the rule spans all columns plus the (n_cols - 1) separators
        lines.insert(1, '-' * (sum(col_widths) + n_cols - 1))
    return '\n'.join(lines)
def _build_table_frame(order: int,
                       nbody: int,
                       include_sum: bool = False):
    """Allocate an object array that holds only the table's header labels.

    Row 0 holds ``'body/order'`` followed by the orders ``2 .. order``, and
    column 0 holds the n-body values ``1 .. nbody``; all remaining cells are
    ``None``. With ``include_sum``, one extra trailing row and one extra
    trailing column are added, both labelled ``'sum'``.
    """
    extra = 1 if include_sum else 0
    frame = np.empty((nbody + 1 + extra, order + extra), dtype='object')
    if include_sum:
        frame[0, -1] = 'sum'
        frame[-1, 0] = 'sum'

    frame[0, 0] = 'body/order'
    frame[0, 1:order] = range(2, order + 1)
    frame[1:nbody + 1, 0] = range(1, nbody + 1)
    return frame
if __name__ == '__main__':
    # dummy cutoff table; the first row (nbody=1) has no cutoffs
    cutoffs = np.array([[None, None, None, None, None],
                        [6.0, 6.0, 6.0, 3.7, 3.7],
                        [5.0, 5.0, 5.0, 3.0, 3.0],
                        [3.7, 3.7, 3.7, 0.0, 0.0]])

    # dummy cluster count table
    cluster_counts = np.array([[1, 3, 5, 5, 2],
                               [12, 22, 39, 42, 58],
                               [19, 41, 123, 421, 912],
                               [42, 112, 410, 617, 3271]])

    demos = [(cutoffs, False),
             (cluster_counts, False),
             (cluster_counts, True)]
    for k, (table, with_sum) in enumerate(demos):
        if k > 0:
            print('\n')  # separator between consecutive demo tables
        print_table(table, include_sum=with_sum)