Advertisement
El_Chaderino

source localization module

May 17th, 2025
361
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 10.19 KB | None | 0 0
  1. # In modules/source_localization.py (new module)
  2. #github.com/ElChaderino/The-Squiggle-Interpreter
  3.  
  4. import os
  5. import mne
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. from pathlib import Path
  9. from modules.plotting import plot_source_estimate
  10. import logging
  11.  
  12. logger = logging.getLogger(__name__)
  13.  
  14. def setup_forward_solution(raw, subject="fsaverage", subjects_dir=None, ico=4, conductivity=(0.3, 0.006, 0.3)):
  15.     """
  16.    Set up forward solution using the fsaverage subject.
  17.    """
  18.     if subjects_dir is None:
  19.         fs_dir = mne.datasets.fetch_fsaverage(verbose=True)
  20.         subjects_dir = os.path.dirname(fs_dir)
  21.     montage = mne.channels.make_standard_montage("standard_1020")
  22.     raw.set_montage(montage, match_case=False)
  23.    
  24.     # Create a source space.
  25.     src = mne.setup_source_space(subject, spacing="oct6", subjects_dir=subjects_dir, add_dist=False)
  26.    
  27.     # Create BEM model and solution.
  28.     bem_model = mne.make_bem_model(subject=subject, ico=ico, conductivity=conductivity, subjects_dir=subjects_dir)
  29.     bem_solution = mne.make_bem_solution(bem_model)
  30.    
  31.     # Compute forward solution.
  32.     fwd = mne.make_forward_solution(raw.info, trans="fsaverage", src=src, bem=bem_solution,
  33.                                     eeg=True, meg=False, verbose=False)
  34.     return fwd, src, bem_solution, subjects_dir
  35.  
  36. def compute_noise_covariance(epochs, tmax=0.0):
  37.     """
  38.    Compute noise covariance from epochs (using the pre-stimulus period).
  39.    """
  40.     cov = mne.compute_covariance(epochs, tmax=tmax, method="empirical", verbose=False)
  41.     return cov
  42.  
  43. def compute_inverse_operator(raw, fwd, cov, loose=0.2, depth=0.8):
  44.     """
  45.    Construct an inverse operator.
  46.    """
  47.     inv_op = mne.minimum_norm.make_inverse_operator(raw.info, fwd, cov, loose=loose, depth=depth, verbose=False)
  48.     return inv_op
  49.  
  50. def apply_inverse_for_band(evoked, inv_op, lambda2=1.0/9.0, method="sLORETA"):
  51.     """
  52.    Apply inverse solution using the specified method.
  53.    
  54.    Parameters:
  55.      evoked (mne.Evoked): The evoked response (or pseudo-ERP).
  56.      inv_op: Inverse operator.
  57.      lambda2 (float): Regularization parameter.
  58.      method (str): "sLORETA", "MNE" (for LORETA-like, adjust parameters), etc.
  59.      
  60.    Returns:
  61.      mne.SourceEstimate: The source estimate.
  62.    """
  63.     stc = mne.minimum_norm.apply_inverse(evoked, inv_op, lambda2=lambda2,
  64.                                          method=method, pick_ori=None, verbose=False)
  65.     return stc
  66.  
  67. def compute_source_localization(raw, band_range, method, tmin, tmax, fwd, inv_op):
  68.     """
  69.    Filter raw data to a frequency band, compute epochs/pseudo-ERP, apply the inverse operator,
  70.    and return the source estimate.
  71.    
  72.    Parameters:
  73.      raw (mne.io.Raw): Raw EEG data.
  74.      band_range (tuple): Frequency band (fmin, fmax).
  75.      method (str): Inverse method, e.g., "sLORETA" or "MNE".
  76.      tmin, tmax (float): Time window for epochs.
  77.      fwd: Forward solution.
  78.      inv_op: Inverse operator.
  79.      
  80.    Returns:
  81.      mne.SourceEstimate: The computed source estimate.
  82.    """
  83.     # Bandpass filter the raw data to the band of interest.
  84.     raw_band = raw.copy().filter(band_range[0], band_range[1], verbose=False)
  85.     events = mne.make_fixed_length_events(raw_band, duration=tmax-tmin)
  86.     epochs = mne.Epochs(raw_band, events, tmin=tmin, tmax=tmax, baseline=None, preload=True, verbose=False)
  87.     evoked = epochs.average()
  88.     stc = apply_inverse_for_band(evoked, inv_op, method=method)
  89.     return stc
  90.  
  91. def save_source_estimate_topomap(stc, subjects_dir, subject, output_path, time_point=0.1, hemi="both", colormap="hot"):
  92.     """
  93.    Generate and save a screenshot of the source estimate topomap at a specific time point.
  94.    
  95.    Parameters:
  96.      stc (mne.SourceEstimate): Source estimate.
  97.      subjects_dir (str): Directory for subject MRI data.
  98.      subject (str): Subject name.
  99.      output_path (str): File path to save the image.
  100.      time_point (float): Time point to display.
  101.      hemi (str): Hemisphere to display.
  102.      colormap (str): Colormap to use.
  103.    """
  104.     brain = stc.plot(hemi=hemi, subjects_dir=subjects_dir, subject=subject,
  105.                      surface="inflated", time_viewer=False, colormap=colormap,
  106.                      smoothing_steps=10, show=False)
  107.     brain.set_time(time_point)
  108.     brain.save_image(output_path)
  109.     brain.close()
  110.  
  111. # Example function to loop over all bands and methods for a given raw data.
  112. def compute_and_save_source_maps(raw, methods, output_base, tmin=0.0, tmax=0.4):
  113.     """
  114.    For each frequency band in BANDS and for each specified inverse method, compute
  115.    the source estimate and save the topomap image.
  116.    
  117.    Parameters:
  118.      raw (mne.io.Raw): Raw EEG data.
  119.      methods (list): List of inverse methods, e.g., ["sLORETA", "MNE"].
  120.      output_base (str): Base output directory to save images.
  121.      tmin, tmax (float): Time window for epochs.
  122.    """
  123.     # Set up forward model.
  124.     fwd, src, bem_solution, subjects_dir = setup_forward_solution(raw)
  125.     # Compute noise covariance from raw data's fixed-length epochs (using tmax=0 for pre-stimulus).
  126.     events = mne.make_fixed_length_events(raw, duration=tmax-tmin)
  127.     epochs = mne.Epochs(raw, events, tmin=tmin, tmax=tmax, baseline=None, preload=True, verbose=False)
  128.     cov = compute_noise_covariance(epochs, tmax=0.0)
  129.     inv_op = compute_inverse_operator(raw, fwd, cov)
  130.    
  131.     subject = "fsaverage"  # Or your subject name.
  132.    
  133.     for band, band_range in BANDS.items():
  134.         for method in methods:
  135.             stc = compute_source_localization(raw, band_range, method, tmin, tmax, fwd, inv_op)
  136.             out_dir = os.path.join(output_base, method, band)
  137.             os.makedirs(out_dir, exist_ok=True)
  138.             # Save a topomap screenshot at a specific time point (e.g., 0.1 sec).
  139.             out_path = os.path.join(out_dir, f"topomap_{band}_{method}.png")
  140.             try:
  141.                 save_source_estimate_topomap(stc, subjects_dir, subject, out_path, time_point=0.1)
  142.                 logger.info(f"Saved {method} topomap for {band} to {out_path}")
  143.             except Exception as e:
  144.                 logger.warning(f"Error saving {method} topomap for {band}: {e}")
  145.  
  146. def compute_and_save_ica_source_maps(raw, ica, inv_op, methods, output_base, subjects_dir, cond_name, time_point=0.1):
  147.     """
  148.    For each ICA component, compute source localization using the given methods and save the topomap image.
  149.    Args:
  150.        raw (mne.io.Raw): Raw EEG data (for info).
  151.        ica (mne.preprocessing.ICA): Fitted ICA object.
  152.        inv_op: Inverse operator.
  153.        methods (dict): Dict of method labels, e.g., {"LORETA": "MNE", ...}.
  154.        output_base (str): Base output directory to save images.
  155.        subjects_dir (str): MNE subjects_dir for MRI.
  156.        cond_name (str): "EO" or "EC".
  157.        time_point (float): Time point for plotting.
  158.    Returns:
  159.        dict: {method: [list of rel_paths per component]}
  160.    """
  161.     results = {m: [] for m in methods}
  162.     if ica is None or inv_op is None:
  163.         return results
  164.     # Get ICA sources (components x time)
  165.     sources = ica.get_sources(raw).get_data()
  166.     n_components = sources.shape[0]
  167.     info = raw.info
  168.     subject_folder = Path(output_base).parent
  169.     for method, method_label in methods.items():
  170.         out_dir = Path(output_base) / cond_name / f"ICA_{method}"
  171.         out_dir.mkdir(parents=True, exist_ok=True)
  172.         for comp_idx in range(n_components):
  173.             comp_ts = sources[comp_idx]
  174.             # Project component back to sensor space (channels used in ICA)
  175.             sensor_proj = ica.mixing_matrix_[:, comp_idx][:, np.newaxis] @ comp_ts[np.newaxis, :]
  176.             # Now, ensure this matches the full channel count in info['ch_names']
  177.             n_channels = len(info['ch_names'])
  178.             # Get the channel names used in ICA (should match mixing_matrix rows)
  179.             ica_ch_names = ica.ch_names if hasattr(ica, 'ch_names') else info['ch_names'][:sensor_proj.shape[0]]
  180.             full_sensor_proj = np.zeros((n_channels, sensor_proj.shape[1]))
  181.             assigned_channels = []
  182.             skipped_channels = []
  183.             for i, ch in enumerate(ica_ch_names):
  184.                 if ch in info['ch_names']:
  185.                     idx = info['ch_names'].index(ch)
  186.                     if i < sensor_proj.shape[0] and idx < full_sensor_proj.shape[0]:
  187.                         full_sensor_proj[idx] = sensor_proj[i]
  188.                         logger.info(f"[ICA] Assigning sensor_proj[{i}] to full_sensor_proj[{idx}] for channel '{ch}'")
  189.                         assigned_channels.append((i, idx, ch))
  190.                     else:
  191.                         logger.warning(f"[ICA] Out-of-bounds: sensor_proj[{i}] or full_sensor_proj[{idx}] for channel '{ch}'")
  192.                         skipped_channels.append((i, idx, ch))
  193.                 else:
  194.                     logger.warning(f"[ICA] Channel '{ch}' not found in info['ch_names'] – skipping.")
  195.                     skipped_channels.append((i, None, ch))
  196.             logger.info(f"[ICA] Assigned channels for component {comp_idx}: {assigned_channels}")
  197.             if skipped_channels:
  198.                 logger.warning(f"[ICA] Skipped channels for component {comp_idx}: {skipped_channels}")
  199.             # Average over time to get a single "evoked" (mean across time)
  200.             comp_evoked = np.mean(full_sensor_proj, axis=1, keepdims=True)
  201.             evoked = mne.EvokedArray(comp_evoked, info, tmin=0)
  202.             try:
  203.                 stc = mne.minimum_norm.apply_inverse(evoked, inv_op, lambda2=1.0/9.0, method=method_label, pick_ori=None, verbose=False)
  204.                 fig = plot_source_estimate(stc, view="lateral", time_point=time_point, subjects_dir=subjects_dir)
  205.                 out_path = out_dir / f"component_{comp_idx:02d}.png"
  206.                 fig.savefig(str(out_path), dpi=150, facecolor='black')
  207.                 plt.close(fig)
  208.                 rel_path = str(out_path.relative_to(subject_folder))
  209.                 results[method].append(rel_path)
  210.                 logger.info(f"Saved {method} topomap for {band} to {out_path}")
  211.             except Exception as e:
  212.                 logger.warning(f"Error saving {method} topomap for {band}: {e}")
  213.     return results
  214.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement