The Environment: ViZDoom
Golem operates within a Partially Observable Markov Decision Process (POMDP) defined by the DOOM engine. We utilize ViZDoom as the API bridge to extract observations and inject actions. Formally, this POMDP is defined by the tuple \((\mathcal{S},\mathcal{A},\mathcal{T},\mathcal{R},\Omega,\mathcal{O},\gamma)\).
Observation Space
While the true underlying engine state \(s_t\in\mathcal{S}\) contains exact entity coordinates and internal variables, the agent's observation \(o_t\in\Omega\) at time \(t\) is strictly constrained to its egocentric sensory field. With the introduction of multi-modal sensor fusion, this observation space scales dynamically based on the active configuration.
The primary visual and spatial tensor \(o_{vis}\) is defined as:
\[
o_{vis}\in\mathbb{R}^{C\times64\times64}
\]
- Channels (\(C\)): 3 (RGB) by default, expanding to 4 if the stereoscopic depth buffer is enabled.
- Resolution: 64×64 pixels (downsampled via bilinear interpolation).
- Normalization: \(o_{i,j,k}\in[0,1]\).
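As a minimal sketch of this preprocessing (the function name and raw buffer shapes are illustrative, not Golem's actual extraction code):

```python
import numpy as np
import torch
import torch.nn.functional as F

def preprocess_visual(screen_buffer, depth_buffer=None):
    """(H, W, 3) uint8 screen buffer -> (C, 64, 64) float tensor in [0, 1]."""
    # HWC uint8 -> (1, 3, H, W) float in [0, 1]
    x = torch.from_numpy(screen_buffer).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    # Bilinear interpolation down to the 64x64 working resolution
    x = F.interpolate(x, size=(64, 64), mode='bilinear', align_corners=False)
    if depth_buffer is not None:
        # Optional stereoscopic depth buffer becomes a fourth channel
        d = torch.from_numpy(depth_buffer).float().unsqueeze(0).unsqueeze(0) / 255.0
        d = F.interpolate(d, size=(64, 64), mode='bilinear', align_corners=False)
        x = torch.cat((x, d), dim=1)
    return x.squeeze(0)  # (C, 64, 64)
```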
If the auditory sensor is enabled, the agent also receives an audio tensor \(o_{aud}\). The raw audio is extracted from the engine as high-frequency stereo waveforms. To stabilize the network's input distribution, the raw buffer is scaled to zero mean and unit variance immediately during extraction.
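A sketch of that normalization step (per-buffer standardization; the epsilon guard is an assumption added here to avoid division by zero on silent buffers):

```python
import numpy as np

def normalize_waveform(raw, eps=1e-8):
    """Scale a raw (2, N_samples) stereo buffer to zero mean and unit variance."""
    raw = raw.astype(np.float32)
    return (raw - raw.mean()) / (raw.std() + eps)
```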
To maximize GPU utilization and prevent dataloader bottlenecks, these normalized 1D arrays are passed directly into the Liquid Neural Network (LNN), which transforms them on-the-fly into dense 2D time-frequency representations (Mel spectrograms) to exploit spatial locality in its parallel 2D Convolutional Auditory Cortex.
The Digital Signal Processing (DSP) transformation is mathematically defined by the active dsp configuration block:
- Mel Scale Transformation: The normalized waveform is processed via a GPU-accelerated Short-Time Fourier Transform (STFT) mapped to the Mel scale, governed by the sample_rate, n_fft, hop_length, and n_mels hyperparameters.
- Decibel Scaling: The resulting magnitudes are compressed logarithmically using an Amplitude-to-DB conversion.
While the raw extracted buffer is \(\mathbb{R}^{2\times N_{samples}}\), the resulting phenomenological multi-modal representation evaluated by the internal network is defined as:
\[
o_{aud}\in\mathbb{R}^{C\times H_{mels}\times W_{time}}
\]
Where \(C=2\) (stereo channels), \(H_{mels}\) represents the frequency bins dictated by n_mels, and \(W_{time}\) is the temporal width calculated dynamically from the engine's audio buffer capacity and the STFT hop_length.
If the thermal sensor is enabled, the agent also receives a discrete thermal tensor \(o_{thm}\). Extracted via ViZDoom's semantic segmentation labels_buffer, this modality isolates active, dynamic entities (e.g., monsters, projectiles, and interactive items) from the static environmental background plane.
The transformation pipeline applies a strict binary threshold (\(o_{i,j}=1\text{ if }\text{label}_{i,j}>0\text{ else }0\)) to the raw buffer, then downsamples the mask to \(64\times64\) using nearest-neighbor interpolation; this avoids the anti-aliasing artifacts that would otherwise blur distinct entity boundaries.
The resulting thermal mask tensor is defined as:
\[
o_{thm}\in\{0,1\}^{1\times64\times64}
\]
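A sketch of the thermal pipeline (the function name is illustrative; in ViZDoom's labels_buffer, label 0 denotes the static background plane):

```python
import numpy as np
import torch
import torch.nn.functional as F

def thermal_mask(labels_buffer):
    """(H, W) ViZDoom labels_buffer -> binary (1, 64, 64) mask of dynamic entities."""
    # Any positive label is a dynamic entity; 0 is the static background plane
    mask = (torch.from_numpy(labels_buffer) > 0).float().unsqueeze(0).unsqueeze(0)
    # Nearest-neighbor keeps the mask strictly binary (no anti-aliased edge values)
    mask = F.interpolate(mask, size=(64, 64), mode='nearest')
    return mask.squeeze(0)  # (1, 64, 64)
```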
We explicitly discard latent game variables (e.g., Health, Ammo, Coordinates) from the observation vector to force the model to learn multi-modal heuristics (e.g., a "red screen tint" implies damage, a specific visual spectrogram pattern implies a nearby threat, or a binary thermal cluster denotes a dynamic entity), encouraging robust topological generalization across unseen levels.
Action Space
Unlike standard architectures that utilize a rigid output structure, Golem's action space is a discrete, multi-label domain dynamically conditioned on the active environment configuration profile, denoted as \(\rho\).
Let \(\rho\in\{\text{basic},\text{classic},\text{fluid}\}\). The dimensionality of the action space \(n_\rho\) expands or contracts based on the superset defined by \(\rho\):
- Basic Profile (\(n_{\text{basic}}=8\)): \(\mathcal{A}_{\text{basic}}=\{\text{Fwd},\text{Back},\text{MoveL},\text{MoveR},\text{TurnL},\text{TurnR},\text{Attack},\text{Use}\}\)
- Fluid Profile (\(n_{\text{fluid}}=9\)): \(\mathcal{A}_{\text{fluid}}=\mathcal{A}_{\text{basic}}\cup\{\text{NextWeapon}\}\)
- Classic Profile (\(n_{\text{classic}}=10\)): \(\mathcal{A}_{\text{classic}}=\mathcal{A}_{\text{basic}}\cup\{\text{Weapon2},\text{Weapon3}\}\)
At any time step \(t\), the output vector \(y_t\) is drawn from a Multivariate Bernoulli distribution over these actions. Assuming conditional independence between individual key presses given the latent state representation, the network predicts the probability vector \(\mathbf{p}_t\), yielding:
\[
y_t\in\{0,1\}^{n_\rho}
\]
During inference, this distribution is thresholded at \(0.5\) to produce the deterministic binary vector fed back into the ViZDoom engine. This dynamic scaling prevents data sparsity and gradient dilution that would occur if unused weapon keys were permanently mapped to the output layer during non-combat tasks.
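A sketch of this decoding step (the logit values are arbitrary; the network's actual output head is not shown here):

```python
import torch

def decode_actions(logits, deterministic=True):
    """Map network logits to a multi-label binary action vector y_t in {0,1}^n_rho."""
    p = torch.sigmoid(logits)  # per-key press probabilities, conditionally independent
    if deterministic:
        return (p > 0.5).float()   # inference: threshold at 0.5
    return torch.bernoulli(p)      # training-time sampling from the Bernoulli distribution
```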
Temporal Dynamics & Network Synchronization
The environment runs at a fixed tic rate of 35 Hz (\(\Delta t\approx28.6\text{ ms}\)). Because standard RNNs assume uniformly spaced discrete steps, they struggle with the asynchronous loop of live gameplay. Golem instead utilizes Liquid Time-Constant (LTC) networks, which model the hidden state \(x(t)\) as the solution of an Ordinary Differential Equation (ODE).
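For reference, the standard LTC hidden-state dynamics (Hasani et al.) take the form below; the exact variant and solver Golem uses are not specified in this section:
\[
\frac{dx(t)}{dt}=-\left[\frac{1}{\tau}+f\big(x(t),I(t),t,\theta\big)\right]x(t)+f\big(x(t),I(t),t,\theta\big)\,A
\]
where \(\tau\) is the base time constant, \(I(t)\) the sensory input, and \(A\) a learned bias vector. Because the effective time constant depends on the input, the solver can integrate over the true elapsed \(\Delta t\) between observations rather than assuming a fixed temporal grid.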
To maintain strict temporal consistency between the continuous differential solver and the discrete game clock across a distributed architecture, Golem relies on Deterministic Lockstep networking via ViZDoom's Mode.PLAYER (Sync Mode).
Instead of relying on manual sleep heuristics (which are prone to drift and desynchronization), the central Host Server dictates the flow of time. The Host collects asynchronous ticcmds (action vectors) from all connected clients. The local ViZDoom engine inside the Golem container explicitly blocks execution until it receives the synchronized broadcast back from the Host. This guarantees that the numerical integration steps within the LNN perfectly align with the simulated passage of time within the multiplayer POMDP, regardless of hardware inference speeds.
During training, these temporal dynamics are strictly simulated and enforced by the StatefulStratifiedBatchSampler, which ensures that batch sequences remain perfectly contiguous, preventing the mathematical amnesia that occurs when standard dataloaders improperly shuffle temporal boundaries.
API Reference
The extraction and temporal windowing of the observation space is handled dynamically by the streaming dataset module and the custom sampling pipeline.
DoomStreamingDataset
Bases: Dataset
A PyTorch Dataset for loading and streaming DOOM gameplay sequences.
This class loads raw frame and action arrays from compressed .npz files into memory. Rather than copying the data to create individual sequence tensors, it builds a lightweight pointer map. During training, it slices continuous arrays into overlapping sequences of length seq_len on-the-fly.
It supports multi-modal sensor fusion, dynamically yielding a dictionary of active sensory tensors (visual, depth, audio, and thermal masks) alongside the target action vectors. It also supports dynamic horizontal mirror augmentation to double the effective dataset size while mitigating left/right turning bias.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_dir | str or Path | The directory containing the .npz training files. | required |
| seq_len | int | The temporal length of the sequence window to slice. | 32 |
| file_pattern | str | The glob pattern used to locate training files within data_dir. | '*.npz' |
| augment | bool | If True, dynamically mirrors visual and thermal frames horizontally and swaps corresponding left/right action labels. | False |
| action_names | list of str | The ordered list of string action names (e.g., ["MOVE_FORWARD", "TURN_LEFT", ...]) used to calculate which indices to swap during mirror augmentation. | None |
| dsp_config | DSPConfig | DSP tuning parameters for the Mel Spectrogram. | None |
| sensors | | | None |
Source code in app/models/dataset.py
class DoomStreamingDataset(Dataset):
r"""
A PyTorch Dataset for loading and streaming DOOM gameplay sequences.
This class loads raw frame and action arrays from compressed ``.npz`` files into memory. Rather than copying the data to create individual sequence tensors, it builds a lightweight pointer map. During training, it slices continuous arrays into overlapping sequences of length ``seq_len`` on-the-fly.
It supports multi-modal sensor fusion, dynamically yielding a dictionary of active sensory tensors (visual, depth, audio, and thermal masks) alongside the target action vectors. It also supports dynamic horizontal mirror augmentation to double the effective dataset size while mitigating left/right turning bias.
Args:
data_dir (str or Path): The directory containing the ``.npz`` training files.
seq_len (int, optional): The temporal length of the sequence window to slice. Default: ``32``.
file_pattern (str, optional): The glob pattern used to locate training files within ``data_dir``. Default: ``"*.npz"``.
augment (bool, optional): If ``True``, dynamically mirrors visual and thermal frames horizontally and swaps corresponding left/right action labels. Default: ``False``.
action_names (list of str, optional): The ordered list of string action names (e.g., ``["MOVE_FORWARD", "TURN_LEFT", ...]``) used to calculate which indices to swap during mirror augmentation. Default: ``None``.
dsp_config (DSPConfig, optional): DSP tuning parameters for the Mel Spectrogram.
sensors: TODO
"""
def __init__(self, data_dir, seq_len=32, file_pattern="*.npz",
augment=False, action_names=None, dsp_config=None, sensors=None):
self.seq_len = seq_len
self.augment = augment
self.action_names = action_names or []
self.dsp_config = dsp_config
self.sensors = sensors
# Memory stores
self.video_arrays = []
self.action_arrays = []
self.depth_arrays = []
self.audio_arrays = []
self.thermal_arrays = []
# Enforce configuration-driven modalities instead of dynamic inference
self.has_depth = getattr(self.sensors, 'depth', False) if self.sensors else False
self.has_audio = getattr(self.sensors, 'audio', False) if self.sensors else False
self.has_thermal = getattr(self.sensors, 'thermal', False) if self.sensors else False
self.index_map = []
self.base_episodes = [] # <-- NEW: Stores lists of contiguous base indices
self.recovery_episodes = [] # <-- NEW: Stores lists of contiguous recovery indices
self.swap_pairs = []
if self.augment and self.action_names:
self._build_swap_map()
# Handle data_dir as a single string, Path, or a list of them
if isinstance(data_dir, (str, Path)):
search_dirs = [Path(data_dir)]
else:
search_dirs = [Path(d) for d in data_dir]
files = []
for d in search_dirs:
if d.exists():
files.extend(list(d.glob(file_pattern)))
files = sorted(files)
for file_idx, file_path in enumerate(files):
is_recovery = "recovery" in str(file_path).lower() # <-- Identify origin
with np.load(file_path) as data:
frames = data['frames']
actions = data['actions']
self.video_arrays.append(frames)
self.action_arrays.append(actions)
if self.has_depth:
if 'depths' in data:
self.depth_arrays.append(data['depths'])
else:
raise ValueError(f"Config requires 'depth', but {file_path.name} is missing 'depths' array.")
if self.has_audio:
if 'audios' in data:
self.audio_arrays.append(data['audios'])
else:
raise ValueError(f"Config requires 'audio', but {file_path.name} is missing 'audios' array.")
if self.has_thermal:
if 'thermals' in data:
self.thermal_arrays.append(data['thermals'])
else:
raise ValueError(f"Config requires 'thermal', but {file_path.name} is missing 'thermals' array.")
total_frames = len(frames)
if total_frames < self.seq_len:
continue
episode_indices_normal = []
episode_indices_mirrored = []
# Build pointers and group them into contiguous episodes
for start_idx in range(0, total_frames - self.seq_len + 1, self.seq_len):
is_first = (start_idx == 0)
global_idx = len(self.index_map)
self.index_map.append({
'file_idx': file_idx,
'start_idx': start_idx,
'is_mirrored': False,
'is_first': is_first
})
episode_indices_normal.append(global_idx)
if self.augment:
global_idx_aug = len(self.index_map)
self.index_map.append({
'file_idx': file_idx,
'start_idx': start_idx,
'is_mirrored': True,
'is_first': is_first
})
episode_indices_mirrored.append(global_idx_aug)
# Register completed episodes to their respective pools
if episode_indices_normal:
if is_recovery: self.recovery_episodes.append(episode_indices_normal)
else: self.base_episodes.append(episode_indices_normal)
if self.augment and episode_indices_mirrored:
if is_recovery: self.recovery_episodes.append(episode_indices_mirrored)
else: self.base_episodes.append(episode_indices_mirrored)
logger.info(f"Dataset mapped to RAM: {len(self.base_episodes)} base episodes, {len(self.recovery_episodes)} recovery episodes. Modalities: [Visual: True, Depth: {self.has_depth}, Audio: {self.has_audio}, Thermal: {self.has_thermal}]")
def _build_swap_map(self):
r"""
Constructs a mapping of action indices that must be swapped when applying horizontal mirror augmentation.
This ensures that when a spatial tensor (visual or thermal) is visually flipped, an action like ``TURN_LEFT`` correctly transforms into ``TURN_RIGHT`` in the target vector.
"""
try:
self.swap_pairs.append((self.action_names.index("MOVE_LEFT"), self.action_names.index("MOVE_RIGHT")))
except ValueError: pass
try:
self.swap_pairs.append((self.action_names.index("TURN_LEFT"), self.action_names.index("TURN_RIGHT")))
except ValueError: pass
def __len__(self):
r"""
Returns the total number of sliding window sequences available.
Returns:
int: The total sequence count, including augmented sequences if enabled.
"""
return len(self.index_map)
def __getitem__(self, idx):
r"""
Retrieves a temporal sequence of frames and corresponding actions by index.
The retrieved visual frames are dynamically transposed from the storage shape of :math:`(H, W, C)` to the PyTorch convolutional shape of :math:`(C, H, W)`.
If the index maps to an augmented sequence, the visual and thermal tensors are flipped horizontally and lateral actions are swapped. For audio, the waveforms are converted to Mel Spectrograms and spatial auditory channels are swapped.
Args:
idx (int): The index of the sequence pointer in the internal map.
Returns:
tuple: A tuple containing:
- dict: A dictionary of active sensory inputs:
- ``'visual'`` (Tensor): Visual frames of shape :math:`(\text{seq\_len}, C, 64, 64)`.
- ``'audio'`` (Tensor, optional): Mel spectrograms of shape :math:`(\text{seq\_len}, 2, H_{mels}, W_{time})`.
- ``'thermal'`` (Tensor, optional): Binary thermal masks of shape :math:`(\text{seq\_len}, 1, 64, 64)`.
- Tensor: A sequence of action vectors of shape :math:`(\text{seq\_len}, \text{n\_actions})`.
"""
meta = self.index_map[idx]
file_idx, start_idx, is_mirrored, is_first = meta['file_idx'], meta['start_idx'], meta['is_mirrored'], meta['is_first']
window_frames = self.video_arrays[file_idx][start_idx : start_idx + self.seq_len]
window_actions = self.action_arrays[file_idx][start_idx : start_idx + self.seq_len]
window_frames = np.transpose(window_frames, (0, 3, 1, 2))
x_vis = torch.from_numpy(window_frames).float()
y = torch.from_numpy(window_actions).float()
if self.has_depth:
window_depths = self.depth_arrays[file_idx][start_idx : start_idx + self.seq_len]
window_depths = np.expand_dims(window_depths, axis=1)
x_depth = torch.from_numpy(window_depths).float()
x_vis = torch.cat((x_vis, x_depth), dim=1)
x_aud = None
if self.has_audio:
window_audios = self.audio_arrays[file_idx][start_idx : start_idx + self.seq_len]
x_aud = torch.from_numpy(window_audios).float()
x_thm = None
if self.has_thermal:
window_thermals = self.thermal_arrays[file_idx][start_idx : start_idx + self.seq_len]
window_thermals = np.expand_dims(window_thermals, axis=1)
x_thm = torch.from_numpy(window_thermals).float()
if is_mirrored:
x_vis = torch.flip(x_vis, [3])
if self.has_audio:
# Flip channel 1 (stereo channels) for spatial auditory swapping
x_aud = torch.flip(x_aud, [1])
if self.has_thermal:
x_thm = torch.flip(x_thm, [3])
y_flip = y.clone()
for left_idx, right_idx in self.swap_pairs:
y_flip[:, left_idx] = y[:, right_idx]
y_flip[:, right_idx] = y[:, left_idx]
inputs = {'visual': x_vis}
if self.has_audio:
inputs['audio'] = x_aud
if self.has_thermal:
inputs['thermal'] = x_thm
inputs['is_first'] = torch.tensor([is_first], dtype=torch.bool)
return inputs, y_flip
inputs = {'visual': x_vis}
if self.has_audio:
inputs['audio'] = x_aud
if self.has_thermal:
inputs['thermal'] = x_thm
inputs['is_first'] = torch.tensor([is_first], dtype=torch.bool)
return inputs, y
__getitem__(idx)
Retrieves a temporal sequence of frames and corresponding actions by index.
The retrieved visual frames are dynamically transposed from the storage shape of \((H, W, C)\) to the PyTorch convolutional shape of \((C, H, W)\).
If the index maps to an augmented sequence, the visual and thermal tensors are flipped horizontally and lateral actions are swapped. For audio, the waveforms are converted to Mel Spectrograms and spatial auditory channels are swapped.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| idx | int | The index of the sequence pointer in the internal map. | required |
Returns:

| Type | Description |
| --- | --- |
| tuple | A tuple (inputs, actions). inputs is a dict of active sensory tensors: 'visual' of shape \((\text{seq\_len}, C, 64, 64)\); optional 'audio' of shape \((\text{seq\_len}, 2, H_{mels}, W_{time})\); optional 'thermal' of shape \((\text{seq\_len}, 1, 64, 64)\). actions is a Tensor of shape \((\text{seq\_len}, \text{n\_actions})\). |
__len__()
Returns the total number of sliding window sequences available.
Returns:

| Type | Description |
| --- | --- |
| int | The total sequence count, including augmented sequences if enabled. |
StatefulStratifiedBatchSampler
Bases: Sampler
Custom Sampler designed to maintain continuous temporal streams across batch dimensions for Stateful Backpropagation Through Time (BPTT). Dynamically mixes base expert trajectories with recovery (DAgger) trajectories to prevent catastrophic forgetting.
Source code in app/models/dataset.py
class StatefulStratifiedBatchSampler(Sampler):
"""
Custom Sampler designed to maintain continuous temporal streams across
batch dimensions for Stateful Backpropagation Through Time (BPTT).
Dynamically mixes base expert trajectories with recovery (DAgger) trajectories
to prevent catastrophic forgetting.
"""
def __init__(self, base_episodes, recovery_episodes, batch_size, recovery_ratio=0.25):
self.base_episodes = base_episodes
self.recovery_episodes = recovery_episodes
self.batch_size = batch_size
# Calculate stream allocations
self.n_recovery = int(batch_size * recovery_ratio) if recovery_episodes else 0
self.n_base = batch_size - self.n_recovery
if self.n_base <= 0:
raise ValueError("Batch size is too small or recovery ratio is too high to maintain base streams.")
def __iter__(self):
# 1. Initialize shuffled pools for the epoch
base_pool = list(self.base_episodes)
rec_pool = list(self.recovery_episodes)
random.shuffle(base_pool)
random.shuffle(rec_pool)
active_streams = []
# 2. Allocate initial temporal streams
for i in range(self.batch_size):
is_rec = (i < self.n_recovery)
pool = rec_pool if is_rec else base_pool
original = self.recovery_episodes if is_rec else self.base_episodes
if not pool and original: # Refill immediately if pool is unexpectedly small
pool.extend(original)
random.shuffle(pool)
if pool:
active_streams.append(iter(pool.pop()))
else:
active_streams.append(None)
# 3. Yield continuous batches
while True:
batch = []
for i in range(self.batch_size):
is_rec = (i < self.n_recovery)
pool = rec_pool if is_rec else base_pool
original = self.recovery_episodes if is_rec else self.base_episodes
stream = active_streams[i]
if stream is None:
continue
try:
idx = next(stream)
batch.append(idx)
except StopIteration:
# Stream exhausted. Fetch a new episode.
if not pool and original:
# Endlessly recycle recovery data.
# Exhausting the base pool signals the end of the epoch.
if is_rec:
pool.extend(original)
random.shuffle(pool)
if pool:
new_stream = iter(pool.pop())
active_streams[i] = new_stream
batch.append(next(new_stream))
else:
active_streams[i] = None
# Drop the last batch if it cannot fully populate the parallel streams
if len(batch) == self.batch_size:
yield batch
else:
break
def __len__(self):
if self.n_base == 0:
return 0
total_base_seqs = sum(len(ep) for ep in self.base_episodes)
return total_base_seqs // self.n_base
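To illustrate how these pieces are wired together for training (the data directory, batch size, and action list below are assumptions for the sketch, not project defaults):

```python
from torch.utils.data import DataLoader

# Hypothetical wiring; data_dir and action_names are placeholders
dataset = DoomStreamingDataset(
    data_dir="data/recordings",
    seq_len=32,
    augment=True,
    action_names=["MOVE_FORWARD", "MOVE_BACKWARD", "MOVE_LEFT", "MOVE_RIGHT",
                  "TURN_LEFT", "TURN_RIGHT", "ATTACK", "USE"],
)
sampler = StatefulStratifiedBatchSampler(
    dataset.base_episodes, dataset.recovery_episodes,
    batch_size=16, recovery_ratio=0.25,
)
# batch_sampler (not sampler) preserves the sampler's stream-per-slot batch layout
loader = DataLoader(dataset, batch_sampler=sampler)
for inputs, actions in loader:
    # inputs['visual']: (16, 32, C, 64, 64); actions: (16, 32, n_actions)
    # inputs['is_first'] flags where the LNN hidden state should be reset
    ...
```

Passing the sampler via batch_sampler is essential: a plain sampler argument would let the DataLoader regroup indices and break the slot-to-stream correspondence that stateful BPTT relies on.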