EDF Data Export Module
El_Chaderino, May 17th, 2025
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
data_to_csv.py - EDF Data Export Module
By El Chaderino Github.com/ElChaderino/The-Squiggle-Interpreter

This module provides comprehensive data export functionality supporting multiple clinical
and research formats. It handles:
- Standard epoch-based metrics
- Clinical summary formats (QEEG style)
- Research-grade exports with detailed feature sets
- Specialized exports (connectivity, TFR, ICA components, source estimates)
"""

import argparse
import numpy as np
import pandas as pd
import mne
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from modules import io_utils

# Enhanced frequency bands including clinical and research definitions
BANDS = {
    # Clinical bands
    "Delta": (1, 4),
    "Theta": (4, 8),
    "Alpha": (8, 12),
    "SMR": (12, 15),
    "Beta": (15, 27),
    "HighBeta": (28, 38),
    # Research bands
    "Low_Delta": (0.5, 2),
    "High_Delta": (2, 4),
    "Low_Theta": (4, 6),
    "High_Theta": (6, 8),
    "Low_Alpha": (8, 10),
    "High_Alpha": (10, 12),
    "Low_Beta": (12, 15),
    "Mid_Beta": (15, 18),
    "High_Beta": (18, 25),
    "Gamma": (35, 45),
    "High_Gamma": (45, 80),
    # Additional specialized bands
    "Mu": (8, 13),        # Motor cortex rhythm
    "Sigma": (12, 16),    # Sleep spindles
    "SCP": (0.1, 1),      # Slow cortical potentials
    "HFO": (80, 200)      # High-frequency oscillations
}

# Clinical feature sets
CLINICAL_FEATURES = {
    "basic": ["Delta", "Theta", "Alpha", "SMR", "Beta", "HighBeta"],
    "extended": ["Low_Delta", "High_Delta", "Low_Theta", "High_Theta",
                 "Low_Alpha", "High_Alpha", "Low_Beta", "Mid_Beta", "High_Beta"],
    "full": list(BANDS.keys()),
    "sleep": ["Delta", "Theta", "Alpha", "Sigma", "Beta"],
    "motor": ["Delta", "Theta", "Mu", "Beta", "Gamma"],
    "cognitive": ["Theta", "Alpha", "Beta", "Gamma"]
}

# Export format specifications
EXPORT_FORMATS = {
    "standard": {
        "description": "Basic epoch-based metrics",
        "features": CLINICAL_FEATURES["basic"],
        "include_ratios": False
    },
    "clinical": {
        "description": "Clinical QEEG format with ratios",
        "features": CLINICAL_FEATURES["basic"],
        "include_ratios": True,
        "ratios": [
            ("Theta", "Beta", "Theta/Beta"),
            ("Alpha", "Theta", "Alpha/Theta"),
            ("Beta", "Alpha", "Beta/Alpha"),
            ("Delta", "Alpha", "Delta/Alpha"),
            ("Theta", "Alpha", "Theta/Alpha"),
            ("Beta", "Theta", "Beta/Theta")
        ]
    },
    "research": {
        "description": "Comprehensive research format",
        "features": CLINICAL_FEATURES["full"],
        "include_ratios": True,
        "include_connectivity": True,
        "include_complexity": True,
        "include_advanced_stats": True
    },
    "minimal": {
        "description": "Minimal clinical format",
        "features": ["Delta", "Theta", "Alpha", "Beta"],
        "include_ratios": False
    },
    "sleep": {
        "description": "Sleep analysis format",
        "features": CLINICAL_FEATURES["sleep"],
        "include_ratios": True,
        "include_spindles": True,
        "ratios": [
            ("Delta", "Beta", "Delta/Beta"),
            ("Theta", "Beta", "Theta/Beta"),
            ("Sigma", "Beta", "Sigma/Beta")
        ]
    },
    "motor": {
        "description": "Motor analysis format",
        "features": CLINICAL_FEATURES["motor"],
        "include_ratios": True,
        "include_mu_rhythm": True,
        "ratios": [
            ("Mu", "Beta", "Mu/Beta"),
            ("Beta", "Gamma", "Beta/Gamma")
        ]
    },
    "cognitive": {
        "description": "Cognitive analysis format",
        "features": CLINICAL_FEATURES["cognitive"],
        "include_ratios": True,
        "include_complexity": True,
        "ratios": [
            ("Theta", "Alpha", "Theta/Alpha"),
            ("Alpha", "Beta", "Alpha/Beta"),
            ("Theta", "Beta", "Theta/Beta")
        ]
    },
    "connectivity": {
        "description": "Connectivity-focused format",
        "features": CLINICAL_FEATURES["basic"],
        "include_connectivity": True,
        "connectivity_metrics": ["wpli", "plv", "pli", "dpli", "imcoh"],
        "include_network_metrics": True
    }
}
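
# A quick sketch of how a format specification resolves (keys as defined
# above; absent boolean flags default to False via dict.get):
#
#   spec = EXPORT_FORMATS["clinical"]
#   spec["features"]                          # ["Delta", "Theta", "Alpha", "SMR", "Beta", "HighBeta"]
#   spec.get("ratios", [])                    # [("Theta", "Beta", "Theta/Beta"), ...]
#   spec.get("include_connectivity", False)   # False for the clinical format
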
  134.  
  135. def compute_band_power(data: np.ndarray, sfreq: float, band: Tuple[float, float]) -> float:
  136.     """Enhanced band power computation with better error handling and normalization."""
  137.     try:
  138.         fmin, fmax = band
  139.         filtered = mne.filter.filter_data(data.astype(np.float64), sfreq, fmin, fmax, verbose=False)
  140.         # Normalize by frequency range to make bands comparable
  141.         power = np.mean(filtered ** 2) / (fmax - fmin)
  142.         return float(power)
  143.     except Exception as e:
  144.         print(f"Error computing band power for band {band}: {e}")
  145.         return np.nan
  146.  
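# Sanity-check sketch for compute_band_power (illustrative, not part of the
# pipeline): a pure 10 Hz sine sampled at 250 Hz should carry far more
# normalized power in Alpha (8-12 Hz) than in Delta (1-4 Hz):
#
#   t = np.arange(0, 4, 1 / 250.0)
#   sig = np.sin(2 * np.pi * 10 * t)
#   compute_band_power(sig, 250.0, BANDS["Alpha"]) > \
#       compute_band_power(sig, 250.0, BANDS["Delta"])   # expected: True
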
def compute_connectivity_metrics(epoch_data: np.ndarray, sfreq: float, band: Tuple[float, float]) -> Dict[str, float]:
    """Compute connectivity metrics between channels."""
    try:
        # Note: in MNE >= 1.0 spectral connectivity lives in the separate
        # mne-connectivity package (spectral_connectivity_epochs); this import
        # targets the pre-1.0 API, and the except branch degrades to NaNs
        # when it is unavailable.
        from mne.connectivity import spectral_connectivity
        conn = spectral_connectivity(
            epoch_data[np.newaxis, :, :],  # add a leading epochs axis
            method='wpli',
            sfreq=sfreq,
            fmin=band[0],
            fmax=band[1],
            verbose=False
        )
        return {
            'wpli_mean': float(np.mean(conn[0])),
            'wpli_std': float(np.std(conn[0]))
        }
    except Exception as e:
        print(f"Error computing connectivity for band {band}: {e}")
        return {'wpli_mean': np.nan, 'wpli_std': np.nan}

def compute_complexity_metrics(data: np.ndarray) -> Dict[str, float]:
    """Compute signal complexity metrics."""
    try:
        # antropy is an optional dependency (pip install antropy)
        from antropy import sample_entropy, perm_entropy
        return {
            'sample_entropy': float(sample_entropy(data)),
            'perm_entropy': float(perm_entropy(data))
        }
    except Exception as e:
        print(f"Error computing complexity metrics: {e}")
        return {'sample_entropy': np.nan, 'perm_entropy': np.nan}

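# Usage sketch (requires antropy): both entropies on a noise segment,
# e.g. for a quick smoke test of the optional dependency:
#
#   rng = np.random.default_rng(0)
#   compute_complexity_metrics(rng.standard_normal(1000))
#   # -> {'sample_entropy': ..., 'perm_entropy': ...}
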
def compute_advanced_stats(values: np.ndarray) -> Dict[str, float]:
    """Compute advanced statistical measures."""
    try:
        from scipy import stats

        # Drop NaN/inf samples up front; most scipy.stats functions do not
        # honour masked arrays, so we operate on the valid values only.
        valid = values[np.isfinite(values)]

        # Robust location via Huber's estimator. The original call,
        # stats.huber, does not exist in scipy; statsmodels provides an
        # equivalent as an optional dependency, so it degrades to NaN on
        # its own rather than poisoning the whole result dict.
        try:
            from statsmodels.robust.scale import huber
            huber_mean = float(huber(valid)[0])
        except Exception:
            huber_mean = np.nan

        return {
            # Central tendency
            'mean': float(np.mean(valid)),
            'median': float(np.median(valid)),
            'mode': float(stats.mode(valid, keepdims=True)[0][0]),
            'trimmed_mean': float(stats.trim_mean(valid, 0.1)),

            # Dispersion
            'std': float(np.std(valid)),
            'var': float(np.var(valid)),
            'mad': float(stats.median_abs_deviation(valid)),
            'iqr': float(stats.iqr(valid)),
            'range': float(np.ptp(valid)),

            # Shape
            'skew': float(stats.skew(valid)),
            'kurtosis': float(stats.kurtosis(valid)),

            # Distribution (Shapiro-Wilk normality test)
            'shapiro_stat': float(stats.shapiro(valid)[0]),
            'shapiro_p': float(stats.shapiro(valid)[1]),

            # Robust statistics (winsorize lives in scipy.stats.mstats)
            'winsorized_mean': float(stats.mstats.winsorize(valid, limits=0.05).mean()),
            'huber_mean': huber_mean
        }
    except Exception as e:
        print(f"Error computing advanced stats: {e}")
        return {stat: np.nan for stat in [
            'mean', 'median', 'mode', 'trimmed_mean', 'std', 'var', 'mad', 'iqr',
            'range', 'skew', 'kurtosis', 'shapiro_stat', 'shapiro_p',
            'winsorized_mean', 'huber_mean'
        ]}

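# Usage sketch: NaN samples are dropped before any statistic is computed,
# so partially invalid band-power series still yield numbers:
#
#   vals = np.array([1.0, 2.0, np.nan, 3.0, 4.0])
#   compute_advanced_stats(vals)['median']   # 2.5
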
def compute_network_metrics(connectivity_matrix: np.ndarray) -> Dict[str, float]:
    """Compute graph theory metrics from connectivity matrix."""
    try:
        import networkx as nx

        # Create weighted graph from connectivity matrix
        G = nx.from_numpy_array(np.abs(connectivity_matrix))
        connected = nx.is_connected(G)

        # Greedy modularity maximisation: find communities first, then score
        # them (networkx exposes no callable named community.modularity_max)
        communities = nx.community.greedy_modularity_communities(G, weight='weight')

        return {
            'density': float(nx.density(G)),
            'avg_clustering': float(nx.average_clustering(G, weight='weight')),
            # Path-based metrics are only defined on connected graphs
            'avg_path_length': float(nx.average_shortest_path_length(G, weight='weight')) if connected else np.nan,
            'global_efficiency': float(nx.global_efficiency(G)),
            'modularity': float(nx.community.modularity(G, communities, weight='weight')),
            'assortativity': float(nx.degree_assortativity_coefficient(G, weight='weight')),
            # nx.sigma compares against rewired random graphs and can be slow
            'small_worldness': float(nx.sigma(G)) if connected else np.nan
        }
    except Exception as e:
        print(f"Error computing network metrics: {e}")
        return {metric: np.nan for metric in [
            'density', 'avg_clustering', 'avg_path_length', 'global_efficiency',
            'modularity', 'assortativity', 'small_worldness'
        ]}

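# Usage sketch on a small synthetic connectivity matrix (symmetric, zero
# diagonal, weights in [0, 1]), the shape these metrics expect:
#
#   rng = np.random.default_rng(0)
#   m = rng.uniform(0, 1, (8, 8))
#   m = (m + m.T) / 2
#   np.fill_diagonal(m, 0)
#   compute_network_metrics(m)   # density, clustering, efficiency, ...
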
def compute_enhanced_connectivity_metrics(epoch_data: np.ndarray, sfreq: float, band: Tuple[float, float]) -> Dict[str, float]:
    """Compute enhanced connectivity metrics between channels."""
    try:
        # Pre-1.0 MNE API; see the note in compute_connectivity_metrics above
        from mne.connectivity import spectral_connectivity

        methods = ['wpli', 'plv', 'pli', 'dpli', 'imcoh']
        results = {}

        for method in methods:
            conn = spectral_connectivity(
                epoch_data[np.newaxis, :, :],
                method=method,
                sfreq=sfreq,
                fmin=band[0],
                fmax=band[1],
                verbose=False
            )
            results[f'{method}_mean'] = float(np.mean(conn[0]))
            results[f'{method}_std'] = float(np.std(conn[0]))

            # Add graph metrics when we have a usable channel-by-channel
            # matrix: average across the frequency axis first, since conn[0]
            # is shaped (n_channels, n_channels, n_freqs)
            conn_matrix = np.mean(conn[0], axis=-1)
            if conn_matrix.shape[0] > 2:
                network_metrics = compute_network_metrics(conn_matrix)
                results.update({f'{method}_{k}': v for k, v in network_metrics.items()})

        return results
    except Exception as e:
        print(f"Error computing enhanced connectivity for band {band}: {e}")
        return {f'{method}_{metric}': np.nan
                for method in ['wpli', 'plv', 'pli', 'dpli', 'imcoh']
                for metric in ['mean', 'std']}

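# Usage sketch: epoch_data is a single epoch shaped (n_channels, n_times);
# the leading epochs axis is added internally. Keep in mind that estimators
# such as wPLI are normally averaged over many epochs, so single-epoch
# values are rough descriptive numbers rather than stable estimates:
#
#   conn = compute_enhanced_connectivity_metrics(epoch, sfreq, BANDS["Alpha"])
#   conn["wpli_mean"], conn["plv_mean"]
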
def process_edf_to_csv(edf_path: str,
                       epoch_length: float,
                       output_dir: str,
                       export_format: str = "standard",
                       conditions: Optional[List[str]] = None,
                       overwrite: bool = False) -> None:
    """
    Enhanced EDF processing with multiple export formats and conditions.

    Args:
        edf_path: Path to the EDF file
        epoch_length: Duration of epochs in seconds
        output_dir: Directory for output files
        export_format: One of EXPORT_FORMATS keys
        conditions: List of conditions to process (defaults to ["EO", "EC"])
        overwrite: Whether to overwrite existing files
    """
    # Avoid a mutable default argument
    if conditions is None:
        conditions = ["EO", "EC"]

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load data
    raw = io_utils.load_eeg_data(edf_path, use_csd=False, for_source=False, apply_filter=True)
    sfreq = raw.info["sfreq"]

    # Get format specification
    format_spec = EXPORT_FORMATS.get(export_format, EXPORT_FORMATS["standard"])

    # Note: each condition label is applied to the full recording; splitting
    # the raw data into EO/EC segments is assumed to happen upstream.
    for condition in conditions:
        # Process each condition
        output_files = {
            "epochs": output_dir / f"{condition}_epochs.csv",
            "summary": output_dir / f"{condition}_summary.csv",
            "connectivity": output_dir / f"{condition}_connectivity.csv",
            "advanced_stats": output_dir / f"{condition}_advanced_stats.csv",
            "network": output_dir / f"{condition}_network_metrics.csv"
        }

        # Skip if files exist and not overwriting
        if not overwrite and all(f.exists() for f in output_files.values()):
            print(f"Files already exist for condition {condition}. Skipping...")
            continue

        # Create epochs
        events = mne.make_fixed_length_events(raw, duration=epoch_length)
        epochs = mne.Epochs(raw, events, tmin=0, tmax=epoch_length, baseline=None, preload=True, verbose=False)

        # Process epochs
        epoch_rows = []
        summary_data = {ch: {band: [] for band in format_spec["features"]} for ch in epochs.ch_names}

        for i, epoch in enumerate(epochs.get_data()):
            epoch_start = events[i, 0] / sfreq

            # Process each channel
            for ch_idx, ch in enumerate(epochs.ch_names):
                row = {
                    "Condition": condition,
                    "Channel": ch,
                    "Epoch": i,
                    "Start_Time": epoch_start,
                    "End_Time": epoch_start + epoch_length
                }

                # Compute band powers
                for band_name in format_spec["features"]:
                    power = compute_band_power(epoch[ch_idx], sfreq, BANDS[band_name])
                    row[band_name] = power
                    summary_data[ch][band_name].append(power)

                # Add ratios if specified. ratio_name reads "numerator/denominator",
                # so divide band1 by band2, guarding against a zero denominator
                if format_spec.get("include_ratios"):
                    for band1, band2, ratio_name in format_spec.get("ratios", []):
                        if row[band2] != 0:
                            row[ratio_name] = row[band1] / row[band2]
                        else:
                            row[ratio_name] = np.nan

                # Add complexity metrics for research format
                if format_spec.get("include_complexity"):
                    complexity = compute_complexity_metrics(epoch[ch_idx])
                    row.update(complexity)

                epoch_rows.append(row)

            # Per-epoch connectivity is handled in its own export pass below;
            # attaching it here would mutate only the last channel's row.

        # Save epoch-level data
        pd.DataFrame(epoch_rows).to_csv(output_files["epochs"], index=False)

        # Create and save summary statistics
        summary_rows = []
        for ch in epochs.ch_names:
            row = {"Channel": ch, "Condition": condition}
            for band in format_spec["features"]:
                values = summary_data[ch][band]
                row.update({
                    f"{band}_Mean": np.mean(values),
                    f"{band}_Std": np.std(values),
                    f"{band}_Median": np.median(values)
                })
            summary_rows.append(row)

        pd.DataFrame(summary_rows).to_csv(output_files["summary"], index=False)

        # Add advanced statistics if specified
        if format_spec.get("include_advanced_stats"):
            advanced_stats_rows = []
            for ch in epochs.ch_names:
                row = {"Channel": ch, "Condition": condition}
                for band in format_spec["features"]:
                    values = summary_data[ch][band]
                    band_stats = compute_advanced_stats(np.array(values))
                    row.update({f"{band}_{stat}": value
                                for stat, value in band_stats.items()})
                advanced_stats_rows.append(row)
            pd.DataFrame(advanced_stats_rows).to_csv(output_files["advanced_stats"], index=False)

        # Add enhanced connectivity metrics if specified
        if format_spec.get("include_connectivity"):
            connectivity_rows = []
            for i, epoch in enumerate(epochs.get_data()):
                row = {"Epoch": i, "Condition": condition}
                for band_name in format_spec["features"]:
                    conn_metrics = compute_enhanced_connectivity_metrics(
                        epoch, sfreq, BANDS[band_name])
                    row.update({f"{band_name}_{k}": v
                                for k, v in conn_metrics.items()})
                connectivity_rows.append(row)
            pd.DataFrame(connectivity_rows).to_csv(output_files["connectivity"], index=False)

        print(f"Processed {condition}. Files saved to {output_dir}")

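# Example call (paths are illustrative):
#
#   process_edf_to_csv(
#       edf_path="subject01.edf",
#       epoch_length=2.0,
#       output_dir="exports/subject01",
#       export_format="clinical",
#       conditions=["EO", "EC"],
#       overwrite=True,
#   )
#
# This writes EO_epochs.csv / EO_summary.csv (plus connectivity and
# advanced-stats files when the chosen format requests them) into output_dir.
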
def save_computed_features_to_csv(features: Dict[str, Any], info: Dict[str, Any], output_path: str) -> None:
    """
    Save computed EEG features to a CSV file.

    Args:
        features: Dictionary containing computed features
        info: Dictionary containing metadata and information about the recording
        output_path: Path where to save the CSV file
    """
    # Create a flat dictionary for DataFrame
    flat_dict = {}

    # Add metadata
    for key, value in info.items():
        if isinstance(value, (str, int, float)):
            flat_dict[f"meta_{key}"] = [value]

    # Add features
    for feature_type, feature_dict in features.items():
        if isinstance(feature_dict, dict):
            for metric_name, value in feature_dict.items():
                if isinstance(value, (int, float, str, np.number)):
                    flat_dict[f"{feature_type}_{metric_name}"] = [value]
                elif isinstance(value, np.ndarray) and value.size == 1:
                    flat_dict[f"{feature_type}_{metric_name}"] = [float(value)]

    # Create DataFrame and save
    df = pd.DataFrame(flat_dict)
    df.to_csv(output_path, index=False)

def save_band_powers_to_csv(band_powers: dict, output_path: str):
    """Save band powers (dict[channel][band]) to CSV."""
    rows = []
    for ch, bands in band_powers.items():
        row = {"Channel": ch}
        row.update(bands)
        rows.append(row)
    pd.DataFrame(rows).to_csv(output_path, index=False)

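# Input shape sketch (values are illustrative):
#
#   save_band_powers_to_csv(
#       {"Fz": {"Delta": 1.2e-11, "Theta": 3.4e-12},
#        "Pz": {"Delta": 9.8e-12, "Theta": 2.1e-12}},
#       "band_powers.csv")
#   # -> CSV columns: Channel, Delta, Theta
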
def save_zscores_to_csv(zscores: dict, output_path: str):
    """Save z-scores (dict[band] = list of z-scores per channel) to CSV."""
    df = pd.DataFrame(zscores)
    df.to_csv(output_path, index=False)

def save_tfr_to_csv(tfr, output_path: str):
    """Save TFR (mne.time_frequency.AverageTFR) as long-form CSV: channel, freq, time, value."""
    data = []
    for ch_idx, ch_name in enumerate(tfr.ch_names):
        for f_idx, freq in enumerate(tfr.freqs):
            for t_idx, time in enumerate(tfr.times):
                data.append({
                    "Channel": ch_name,
                    "Frequency": freq,
                    "Time": time,
                    "Power": tfr.data[ch_idx, f_idx, t_idx]
                })
    pd.DataFrame(data).to_csv(output_path, index=False)

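# Output layout sketch: one row per (channel, frequency, time) triple, e.g.
#
#   Channel,Frequency,Time,Power
#   Fz,8.0,0.000,1.2e-11
#   Fz,8.0,0.004,1.3e-11
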
def save_ica_to_csv(ica, raw, output_path_prefix: str):
    """Save ICA mixing matrix and component time series to CSVs."""
    # Mixing matrix
    pd.DataFrame(ica.mixing_matrix_).to_csv(f"{output_path_prefix}_mixing_matrix.csv", index=False)
    # Component time series
    sources = ica.get_sources(raw).get_data()
    pd.DataFrame(sources).to_csv(f"{output_path_prefix}_sources.csv", index=False)

def save_source_localization_to_csv(stc, output_path: str):
    """Save source estimate (mne.SourceEstimate) as long-form CSV: vertex, time, value."""
    data = []
    # stc.data rows are stacked across the per-hemisphere vertex lists, so
    # iterate all of stc.vertices with a running row offset (iterating only
    # stc.vertices[0] would drop the right hemisphere of a surface estimate)
    offset = 0
    for hemi_vertices in stc.vertices:
        for v_idx, vertex in enumerate(hemi_vertices):
            for t_idx, time in enumerate(stc.times):
                data.append({
                    "Vertex": vertex,
                    "Time": time,
                    "Value": stc.data[offset + v_idx, t_idx]
                })
        offset += len(hemi_vertices)
    pd.DataFrame(data).to_csv(output_path, index=False)

def main():
    parser = argparse.ArgumentParser(
        description="Enhanced EDF processing with multiple export formats"
    )
    parser.add_argument("--edf", required=True, help="Path to the EDF file")
    parser.add_argument("--epoch_length", type=float, default=2.0, help="Epoch length in seconds")
    parser.add_argument("--output_dir", required=True, help="Output directory")
    parser.add_argument("--format", choices=list(EXPORT_FORMATS.keys()), default="standard",
                        help="Export format specification")
    parser.add_argument("--conditions", nargs="+", default=["EO", "EC"],
                        help="Conditions to process")
    parser.add_argument("--overwrite", action="store_true",
                        help="Overwrite existing files")

    args = parser.parse_args()
    process_edf_to_csv(args.edf, args.epoch_length, args.output_dir,
                       args.format, args.conditions, args.overwrite)

if __name__ == "__main__":
    main()

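# Example invocation (file names are illustrative):
#
#   python data_to_csv.py --edf subject01.edf --epoch_length 2.0 \
#       --output_dir exports/subject01 --format research --conditions EO EC \
#       --overwrite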