{
 "nbformat_minor": 0,
 "nbformat": 4,
 "cells": [
  {
   "execution_count": null,
   "cell_type": "code",
   "source": [
    "%matplotlib inline"
   ],
   "outputs": [],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "source": [
    "\n",
    "# Decoding real-time data\n",
    "\n",
    "\n",
    "Supervised machine learning applied to MEG data in sensor space.\n",
    "Here the classifier is re-fit and cross-validated (5 shuffle splits) on\n",
    "every new trial once more than 10 trials have been collected, and the\n",
    "decoding accuracy is plotted.\n",
    "\n"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "execution_count": null,
   "cell_type": "code",
   "source": [
    "# Authors: Mainak Jas\n",
    "#\n",
    "# License: BSD (3-clause)\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn import preprocessing\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.pipeline import Pipeline\n",
    "# sklearn.cross_validation was removed in scikit-learn 0.20;\n",
    "# model_selection is its replacement (available since 0.18)\n",
    "from sklearn.model_selection import cross_val_score, ShuffleSplit\n",
    "\n",
    "import mne\n",
    "from mne.realtime import MockRtClient, RtEpochs\n",
    "from mne.datasets import sample\n",
    "from mne.decoding import Vectorizer, FilterEstimator\n",
    "\n",
    "print(__doc__)\n",
    "\n",
    "# Fiff file to simulate the realtime client\n",
    "data_path = sample.data_path()\n",
    "raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'\n",
    "raw = mne.io.read_raw_fif(raw_fname, preload=True)\n",
    "\n",
    "tmin, tmax = -0.2, 0.5\n",
    "event_id = dict(aud_l=1, vis_l=3)\n",
    "\n",
    "min_trials = 10  # minimum trials after which decoding should start\n",
    "\n",
    "# select gradiometers, plus EOG (used by the `reject` threshold below)\n",
    "# and the stim channel\n",
    "picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True,\n",
    "                       stim=True, exclude=raw.info['bads'])\n",
    "\n",
    "# create the mock-client object\n",
    "rt_client = MockRtClient(raw)\n",
    "\n",
    "# create the real-time epochs object\n",
    "rt_epochs = RtEpochs(rt_client, event_id, tmin, tmax, picks=picks, decim=1,\n",
    "                     reject=dict(grad=4000e-13, eog=150e-6), baseline=None,\n",
    "                     isi_max=4.)\n",
    "\n",
    "# start the acquisition\n",
    "rt_epochs.start()\n",
    "\n",
    "# send raw buffers\n",
    "rt_client.send_data(rt_epochs, picks, tmin=0, tmax=90, buffer_size=1000)\n",
    "\n",
    "# Decoding in sensor space using a linear SVM\n",
    "\n",
    "# don't highpass filter because it's epoched data and the signal length\n",
    "# is small\n",
    "filt = FilterEstimator(rt_epochs.info, None, 40)\n",
    "scaler = preprocessing.StandardScaler()\n",
    "vectorizer = Vectorizer()\n",
    "clf = SVC(C=1, kernel='linear')\n",
    "\n",
    "concat_classifier = Pipeline([('filter', filt), ('vector', vectorizer),\n",
    "                              ('scaler', scaler), ('svm', clf)])\n",
    "\n",
    "# The splitter no longer needs the sample count, so it is built once and\n",
    "# reused for every evaluation; the fixed seed keeps the splits repeatable.\n",
    "cv = ShuffleSplit(n_splits=5, test_size=0.2, random_state=42)\n",
    "\n",
    "# channels fed to the classifier (no stim channel)\n",
    "data_picks = mne.pick_types(rt_epochs.info, meg='grad', eeg=False, eog=True,\n",
    "                            stim=False, exclude=raw.info['bads'])\n",
    "\n",
    "ax = plt.subplot(111)\n",
    "ax.set_xlabel('Trials')\n",
    "ax.set_ylabel('Classification score (% correct)')\n",
    "ax.set_title('Real-time decoding')\n",
    "ax.set_xlim([min_trials, 50])\n",
    "ax.set_ylim([30, 105])\n",
    "ax.axhline(50, color='k', linestyle='--', label='Chance level')\n",
    "plt.show(block=False)\n",
    "\n",
    "epochs_data, labels = [], []  # one entry per received epoch\n",
    "scores_x, scores, std_scores = [], [], []\n",
    "\n",
    "for ev_num, ev in enumerate(rt_epochs.iter_evoked()):\n",
    "\n",
    "    print('Just got epoch %d' % (ev_num + 1))\n",
    "\n",
    "    epochs_data.append(ev.data[data_picks])\n",
    "    # the comment attribute contains the event_id\n",
    "    labels.append(int(ev.comment))\n",
    "\n",
    "    if ev_num < min_trials:\n",
    "        continue  # not enough trials yet to cross-validate\n",
    "\n",
    "    X = np.array(epochs_data)  # (n_epochs, n_channels, n_times)\n",
    "    y = np.array(labels)\n",
    "    scores_t = cross_val_score(concat_classifier, X, y, cv=cv,\n",
    "                               n_jobs=1) * 100\n",
    "\n",
    "    std_scores.append(scores_t.std())\n",
    "    scores.append(scores_t.mean())\n",
    "    scores_x.append(ev_num)\n",
    "\n",
    "    # Plot accuracy: only the newest segment is drawn, earlier segments\n",
    "    # stay on the axes\n",
    "    ax.plot(scores_x[-2:], scores[-2:], '-x', color='b',\n",
    "            label='Classif. score')\n",
    "\n",
    "    # mean +/- one standard deviation across the CV splits\n",
    "    lower = np.asarray(scores) - np.asarray(std_scores)\n",
    "    upper = np.asarray(scores) + np.asarray(std_scores)\n",
    "    fill = ax.fill_between(scores_x, lower, y2=upper, color='b', alpha=0.5)\n",
    "    plt.pause(0.01)\n",
    "    plt.draw()\n",
    "    # Remove old fill area; Artist.remove() also works on Matplotlib\n",
    "    # versions where ax.collections is read-only\n",
    "    fill.remove()\n",
    "\n",
    "# Re-draw the last band for the final figure; skipped when no score was\n",
    "# ever computed (fewer epochs than needed), where the bounds do not exist.\n",
    "if scores:\n",
    "    ax.fill_between(scores_x, lower, y2=upper, color='b', alpha=0.5)\n",
    "plt.draw()  # Final figure"
   ],
   "outputs": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "name": "python2",
   "language": "python"
  },
  "language_info": {
   "mimetype": "text/x-python",
   "nbconvert_exporter": "python",
   "name": "python",
   "file_extension": ".py",
   "version": "2.7.13",
   "pygments_lexer": "ipython2",
   "codemirror_mode": {
    "version": 2,
    "name": "ipython"
   }
  }
 }
}