
Vectorize per-channel PCA transform in run_for_all_spikes #4488

Open
galenlynch wants to merge 2 commits into SpikeInterface:main from galenlynch:feat/vectorize-per-channel-pca

Conversation

@galenlynch

@galenlynch galenlynch commented Apr 1, 2026

Problem

_all_pc_extractor_chunk calls pca_model[chan_ind].transform() once per spike per channel in a Python loop:

for i in range(i0, i1):
    st = spike_times[i]  # spike time in samples, absolute within the recording
    wf = traces[st - start - nbefore : st - start + nafter, :]
    for c, chan_ind in enumerate(chan_inds):
        w = wf[:, chan_ind]
        all_pcs[i, :, c] = pca_model[chan_ind].transform(w[None, :])

For a ~70-minute Neuropixels recording with ~10M spikes and ~26 sparse channels, this is ~260M individual sklearn transform calls, each on a 1×210 matrix. The Python/sklearn per-call overhead dominates.
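The per-call overhead is easy to see in isolation. This is an illustrative microbenchmark (the 1000-spike / 210-sample / 5-component shapes are made up, not taken from the PR) comparing one `transform` call per spike against a single batched call:

```python
# Hypothetical microbenchmark: per-spike transform calls vs one batched call.
import time

import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
waveforms = rng.standard_normal((1000, 210)).astype("float32")

pca = PCA(n_components=5)
pca.fit(waveforms)

# One transform call per spike, as in the original loop
t0 = time.perf_counter()
per_spike = np.stack([pca.transform(w[None, :])[0] for w in waveforms])
t_loop = time.perf_counter() - t0

# Single batched call over all spikes at once
t0 = time.perf_counter()
batched = pca.transform(waveforms)
t_batch = time.perf_counter() - t0

print(f"loop: {t_loop:.4f}s  batch: {t_batch:.4f}s")
```

The two paths produce the same projections up to float rounding; only the call count differs.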

Solution

Batch all spikes within a chunk by channel and call transform once per channel:

  1. Extract all valid waveform snippets in the chunk at once using vectorized fancy indexing
  2. Group spikes by channel index across all units
  3. Call pca_model[chan_ind].transform(wfs_batch) once per channel with the full batch

In by_channel_local mode the PCA model is per-channel (not per-unit), so all spikes on a given channel share the same model regardless of unit identity.
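The three steps above can be sketched roughly as follows. All names here (`extract_pcs_batched`, `sparse_channels`, `spike_unit_inds`) are illustrative stand-ins for the real variables in `_all_pc_extractor_chunk`; spike times are assumed chunk-relative, bounds masking is omitted for brevity, and every unit is assumed to have the same number of sparse channels:

```python
import numpy as np

def extract_pcs_batched(traces, spike_times, spike_unit_inds, sparse_channels,
                        pca_model, nbefore, nafter, n_components):
    nsamples = nbefore + nafter
    n_spikes = len(spike_times)
    max_sparse = max(len(c) for c in sparse_channels)
    all_pcs = np.zeros((n_spikes, n_components, max_sparse), dtype="float32")

    # 1. extract all waveform snippets at once with vectorized fancy indexing
    offsets = spike_times - nbefore
    sample_idx = offsets[:, None] + np.arange(nsamples)[None, :]
    wfs = traces[sample_idx, :]              # (n_spikes, nsamples, n_chans)

    # 2. group spikes by channel index across all units
    chan_of_spike = np.array([sparse_channels[u] for u in spike_unit_inds])

    # 3. one transform call per channel over the full batch
    for chan_ind in np.unique(chan_of_spike):
        spike_idx, col_idx = np.nonzero(chan_of_spike == chan_ind)
        batch = wfs[spike_idx, :, chan_ind]  # (n_group, nsamples)
        all_pcs[spike_idx, :, col_idx] = pca_model[chan_ind].transform(batch)
    return all_pcs
```

The key point is that each `pca_model[chan_ind]` now sees one 2-D batch per chunk instead of thousands of 1×nsamples matrices.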

Benchmarks

Synthetic (500 spikes, 10 channels, 50-sample waveforms)

|            | Time   | Speedup |
|------------|--------|---------|
| Original   | 0.126s |         |
| Vectorized | 0.002s | 53x     |

Results match: max absolute difference 9.5e-7 (float rounding).

Real data (Neuropixels probe, 5 minutes of recording from local .dat)

  • 706K total spikes
  • 379 channels, 26 sparse channels per unit, 210-sample waveforms
  • n_jobs=1, chunk_duration=10s
|                              | Time             | Per chunk   | Speedup |
|------------------------------|------------------|-------------|---------|
| Original                     | 548.7s (9.1 min) | 18.3s/chunk |         |
| Vectorized                   | 110.3s (1.8 min) | 2.1s/chunk  | 5.0x    |
| Vectorized + numpy grouping  | 108.2s           | 2.0s/chunk  | 5.6x    |

Results match: max absolute difference 1.49e-08. np.allclose confirms identical output.

The real-data speedup is lower than synthetic because disk I/O, waveform extraction, and memory allocation are shared costs. The optimization only affects the transform call overhead, which is ~80% of chunk time in the original code.

Projected impact

For a full 69-minute recording, PC extraction drops from ~60 min to ~12 min.

RAM impact

Minimal. For a 10s chunk with ~25K spikes, the batch waveform array is ~25K × 210 × 4 bytes ≈ 20 MB per channel call, reused across channels. Peak additional memory vs. original: ~50 MB.
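The estimate above is straightforward to reproduce; a float32 batch of 25K spikes × 210 samples works out to about 21 MB:

```python
# Memory of one per-channel waveform batch (float32 = 4 bytes per sample)
n_spikes, nsamples = 25_000, 210
batch_bytes = n_spikes * nsamples * 4
print(f"{batch_bytes / 1e6:.0f} MB")
```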

Changes

  • _all_pc_extractor_chunk in principal_component.py: replaced per-spike per-channel loop with vectorized batch-by-channel approach
  • No API changes, no new dependencies
  • All existing test_principal_component.py tests pass
  • A follow-up optimization replaces the per-spike Python dict loop (building the channel→spike mapping) with numpy unit grouping: loop over unique unit indices and use vectorized boolean masks. This reduces Python iterations from n_spikes × n_sparse_channels (~600K) to n_units × n_sparse_channels (~10K) per chunk.
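The follow-up grouping optimization might look like the sketch below (names are hypothetical stand-ins, not the PR's actual helpers): instead of one dict update per spike, loop over unique unit indices and apply one boolean mask per unit.

```python
import numpy as np

def group_spikes_by_channel(spike_unit_inds, sparse_channels, n_channels):
    """Build channel -> spike-index arrays by looping over units, not spikes."""
    by_channel = {c: [] for c in range(n_channels)}
    for unit_ind in np.unique(spike_unit_inds):
        # one vectorized mask per unit instead of one dict update per spike
        (spike_idx,) = np.nonzero(spike_unit_inds == unit_ind)
        for chan_ind in sparse_channels[unit_ind]:
            by_channel[int(chan_ind)].append(spike_idx)
    return {c: np.concatenate(v) if v else np.empty(0, dtype=int)
            for c, v in by_channel.items()}
```

The Python-level loop count becomes n_units × n_sparse_channels; the per-spike work happens inside `np.nonzero` and `np.concatenate`.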

Fixes #4485
Related: #979

galenlynch and others added 2 commits April 1, 2026 10:43
`_all_pc_extractor_chunk` calls `pca_model[chan_ind].transform()` once per spike
per channel in a Python loop.

For a ~70-minute Neuropixels recording with ~10M spikes and ~26 sparse channels,
this is ~260M individual sklearn `transform` calls, each on a 1×210 matrix. The
Python/sklearn per-call overhead dominates.

This commit improves performance by batching all spikes within a chunk by
channel:

1. Extract all valid waveform snippets in the chunk at once using vectorized
   fancy indexing
2. Group spikes by channel index across all units
3. Call `pca_model[chan_ind].transform(wfs_batch)` once per channel with the
   full batch

For synthetic data of 500 spikes, 10 channels, with 50-sample waveforms, this
improves performance 53x, from 0.126s to 0.002s, with max absolute difference in
projections of 9.5e-7.

For an integration benchmark extracting PCs from a 5-minute, 379-channel .dat in
RAM with 706k spikes and 26 sparse channels × 210-sample waveforms, the
vectorization improves performance 5x, from 9.1 minutes to 1.8 minutes. Max
absolute difference between the two paths was 1.49e-08.
Instead of iterating over every spike to build the channel-to-spike
mapping, loop over unique unit indices and use vectorized boolean
masks. Reduces Python loop iterations from n_spikes*n_sparse_channels
(~600K) to n_units*n_sparse_channels (~10K) per chunk.

5.0x → 5.6x speedup vs original on real Neuropixels data.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@galenlynch
Author

The test failure is unrelated. Seems like a problem with cross-correlogram tests and fast_mode?

@alejoe91 alejoe91 added postprocessing Related to postprocessing module performance Performance issues/improvements labels Apr 2, 2026
# valid_mask tracks which spikes have valid (in-bounds) waveforms
chunk_spike_times = spike_times[i0:i1]
offsets = chunk_spike_times - start - nbefore
valid_mask = (offsets >= 0) & (offsets + nsamples <= traces.shape[0])
Member


This is probably me misreading, but I'm bad at parsing these kinds of > and < checks in general. If we have to be less than or equal to the shape, couldn't we run into an issue where we are equal to the shape, which is out of bounds?

I.e. for an array of shape (4, 5), shape[0] == 4, but if I try to index at 4 it will raise an out-of-bounds error. Again, I don't work on the PC code at all, so maybe I'm completely wrong here.
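For what it's worth, the `<=` bound does appear to be right here, because the offsets index slice *ends*, which are exclusive in Python: a waveform slice `traces[offset : offset + nsamples]` only touches rows up to `offset + nsamples - 1`. A quick check (toy shapes, not the PR's data):

```python
# Why `offsets + nsamples <= traces.shape[0]` is the right bound:
# slice ends are exclusive, so the last row touched is offset + nsamples - 1.
import numpy as np

traces = np.arange(40).reshape(4, 10)  # shape[0] == 4
nsamples = 2

offset = 2                    # offset + nsamples == 4 == shape[0]: still valid
wf = traces[offset : offset + nsamples, :]
assert wf.shape == (2, 10)    # slice covers rows 2 and 3, both in bounds

offset = 3                    # offset + nsamples == 5 > shape[0]: masked out
valid = offset >= 0 and offset + nsamples <= traces.shape[0]
assert not valid
```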
