In [None]:
# Reload imported modules before each cell executes (%autoreload 2), so edits to
# library/source code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import spikeinterface.full as si
import numpy as np
from pathlib import Path
import time
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import kachery_cloud as kcl
import figurl

import sortingview.views as vv


%matplotlib widget

In [None]:
# Parallel-processing settings forwarded (as **job_kwargs) to the chunked
# spikeinterface computations in later cells.
n_jobs = 10
job_kwargs = {"n_jobs": n_jobs, "chunk_duration": "1s", "progress_bar": True}

In [None]:
# Load the cached recording/sorting from the "rec_bin"/"sort_bin" folders when
# they exist; otherwise generate a synthetic toy dataset once and save it there
# so later runs reuse it.
# A fixed seed keeps the generated dataset reproducible on a fresh
# Restart & Run All (previously each first run produced different data).
TOY_SEED = 0

if Path("rec_bin").is_dir():
    rec = si.load_extractor("rec_bin")
    sort = si.load_extractor("sort_bin")
else:
    rec, sort = si.toy_example(num_channels=32, num_units=20, num_segments=1, duration=120,
                               seed=TOY_SEED)
    rec = rec.save(folder="rec_bin")
    sort = sort.save(folder="sort_bin")

In [None]:
# Re-label channels and units with the same ids cast to str dtype; presumably
# so they are used consistently as keys in the sortingview payloads below.
str_channel_ids = rec.channel_ids.astype("str")
str_unit_ids = sort.unit_ids.astype("str")
rec = rec.channel_slice(rec.channel_ids, renamed_channel_ids=str_channel_ids)
sort = sort.select_units(sort.unit_ids, renamed_unit_ids=str_unit_ids)

channel_ids, unit_ids = rec.channel_ids, sort.unit_ids

In [None]:
# Extract per-unit waveforms into "wf_folder", reusing an existing extraction if present.
we = si.extract_waveforms(
    rec,
    sort,
    folder="wf_folder",
    load_if_exists=True,
    **job_kwargs,
)

In [None]:
# Radius-based channel sparsity (radius_um=50), mapping each unit to its channel ids.
sparsity = si.get_template_channel_sparsity(
    we,
    method="radius",
    radius_um=50,
)

In [None]:
# Sparse mean and std templates per unit, each transposed (.T) relative to what
# get_template returns, keyed by unit id.
templates = {
    unit: {
        "mean": we.get_template(unit, mode="average", sparsity=sparsity).T,
        "std": we.get_template(unit, mode="std", sparsity=sparsity).T,
    }
    for unit in sort.unit_ids
}

In [None]:
# NOTE: the pypi release names this first parameter `sorting` instead of
# `waveform_or_sorting_extractor`.

# Symmetrized correlograms for every unit pair, 0.5 ms bins (`bins` is in ms).
ccgs, bins = si.compute_correlograms(
    waveform_or_sorting_extractor=sort,
    bin_ms=0.5,
    symmetrize=True,
)

In [None]:
# NOTE: requires latest install from main branch

# Per-spike locations via monopolar triangulation, grouped by unit
# (outputs="by_unit"); reloaded if previously computed (load_if_exists=True).
# The radius key was misspelled "raidus", which is not a valid
# monopolar_triangulation parameter. NOTE(review): the name is
# `local_radius_um` on the main branch this notebook targets; newer spikeinterface
# releases call it `radius_um` — confirm against the installed version.
locs = si.compute_spike_locations(we, method="monopolar_triangulation",
                                  method_kwargs={"local_radius_um": 100},
                                  outputs="by_unit", load_if_exists=True, **job_kwargs)

In [None]:
# Per-spike amplitudes grouped by unit; indexed later as amplitudes[segment][unit_id].
amplitudes = si.compute_spike_amplitudes(
    we,
    outputs="by_unit",
    load_if_exists=True,
    **job_kwargs,
)

In [None]:
# Pairwise template similarity matrix; indexed later as similarity[i, j]
# in the order of `unit_ids`.
similarity = si.compute_template_similarity(we)

In [None]:
# Template-shape metrics per unit (templates upsampled 10x before measuring).
tm = si.calculate_template_metrics(
    we,
    upsampling_factor=10,
)

In [None]:
# Quality metrics to compute: the default list plus the PCA-based
# "nearest_neighbor" metric. To include every PCA metric instead, extend with:
#     si.get_quality_pca_metric_list()
metric_names = si.get_quality_metric_list()
metric_names.append("nearest_neighbor")

# Principal components (by_channel_local mode), presumably needed by the
# PCA-based nearest_neighbor metric; reused if already computed.
pc = si.compute_principal_components(
    we,
    mode="by_channel_local",
    load_if_exists=True,
    n_jobs=n_jobs,
    progress_bar=True,
)

In [None]:
# Quality metrics table (recomputed every run: load_if_exists=False).
qm = si.compute_quality_metrics(
    we,
    metric_names=metric_names,
    sparsity=sparsity,
    load_if_exists=False,
    n_jobs=n_jobs,
    verbose=True,
    progress_bar=True,
)

In [None]:
# Combine quality and template metrics into one table by joining on the
# index (unit id); only units present in both tables are kept (inner join).
metrics = qm.merge(tm, how="inner", left_index=True, right_index=True)
metrics

In [None]:
# Per-unit location estimates (monopolar triangulation), returned per unit.
# NOTE: the parameter is `outputs`. On the latest main branch `output='dict'`
# is not accepted, and `outputs='dict'` returns None, so use 'by_unit'.
unit_locations = si.localize_units(
    we,
    method="monopolar_triangulation",
    outputs="by_unit",
)

In [None]:
# Units table view: one row per unit, with no metric columns attached yet.
ut_columns = []
ut_rows = [vv.UnitsTableRow(unit_id=unit_id, values={}) for unit_id in unit_ids]

v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns)
url = v_units_table.url(label='Example units table')
print(url)

In [None]:
# Average waveforms view: each unit's sparse mean template with its std,
# drawn at the recording's channel locations.
aw_items = [
    vv.AverageWaveformItem(
        unit_id=unit_id,
        channel_ids=list(sparsity[unit_id]),
        waveform=unit_templates['mean'].astype('float32'),
        waveform_std_dev=unit_templates['std'].astype('float32'),
    )
    for unit_id, unit_templates in templates.items()
]

# Map each channel id to its location row (float32), in channel order.
locations = rec.get_channel_locations()
channel_locations = {
    ch_id: ch_loc.astype("float32")
    for ch_id, ch_loc in zip(channel_ids, locations)
}

v_average_waveforms = vv.AverageWaveforms(
    average_waveforms=aw_items,
    channel_locations=channel_locations,
)
url = v_average_waveforms.url(label='Test average waveforms')
print(url)

In [None]:
# Correlogram views. Autocorrelograms come from the diagonal of `ccgs`;
# cross-correlograms are built for every pair (i, j) with j >= i (diagonal
# included). `bins` is in ms, while the views take bin edges in seconds, so the
# conversion is computed once and shared by every item.
bin_edges_sec = (bins / 1000.).astype("float32")
n_ccg_units = ccgs.shape[0]

ac_items = [
    vv.AutocorrelogramItem(
        unit_id=unit_ids[i],
        bin_edges_sec=bin_edges_sec,
        bin_counts=ccgs[i, i].astype("int32"),
    )
    for i in range(n_ccg_units)
]
cc_items = [
    vv.CrossCorrelogramItem(
        unit_id1=unit_ids[i],
        unit_id2=unit_ids[j],
        bin_edges_sec=bin_edges_sec,
        bin_counts=ccgs[i, j].astype("int32"),
    )
    for i in range(n_ccg_units)
    for j in range(i, n_ccg_units)
]

v_autocorrelograms = vv.Autocorrelograms(autocorrelograms=ac_items)
v_cross_correlograms = vv.CrossCorrelograms(cross_correlograms=cc_items)

view = vv.Splitter(
    direction='horizontal',
    item1=vv.LayoutItem(v_autocorrelograms),
    item2=vv.LayoutItem(v_cross_correlograms),
)

url = view.url(label='Test correlograms')
print(url)

In [None]:
# Spike amplitudes view for the first segment: spike times (seconds) paired
# with per-spike amplitudes, one item per unit.
sampling_frequency = sort.get_sampling_frequency()
first_segment_amplitudes = amplitudes[0]

sa_items = [
    vv.SpikeAmplitudesItem(
        unit_id=unit_id,
        spike_times_sec=(sort.get_unit_spike_train(unit_id) / sampling_frequency).astype("float32"),
        spike_amplitudes=unit_amps.astype("float32"),
    )
    for unit_id, unit_amps in first_segment_amplitudes.items()
]

v_spike_amplitudes = vv.SpikeAmplitudes(
    start_time_sec=0,
    end_time_sec=rec.get_total_duration(),
    plots=sa_items,
)

url = v_spike_amplitudes.url(label='Test spike amplitudes')
print(url)

In [None]:
# sl_unit_data = [{"unitId": u, 
#                  "xLocations": loc["x"].astype("float32"),
#                  "yLocations": loc["y"].astype("float32"),
#                  "zLocations": loc["z"].astype("float32"),
#                  "spikeTimesSec":  (sort.get_unit_spike_train(u) / sort.get_sampling_frequency()).astype("float32")}
#                 for u, loc in locs[0].items()]
# # channel_locations = [chan: loc]
# spike_locations_view_data = dict(type="SpikeLocations",
#                                  startTimeSec=0, 
#                                  endTimeSec=rec.get_total_duration(), 
#                                  units=sl_unit_data)
# sv_dict.update(dict(SpikeLocationsViewData=spike_locations_view_data))

In [None]:
# Unit similarity matrix view: one score for every ordered pair of units,
# read from `similarity` in `unit_ids` order.
ss_items = [
    vv.UnitSimilarityScore(
        unit_id1=u1,
        unit_id2=u2,
        similarity=similarity[i1, i2],
    )
    for i1, u1 in enumerate(unit_ids)
    for i2, u2 in enumerate(unit_ids)
]

v_unit_similarity_matrix = vv.UnitSimilarityMatrix(
    unit_ids=list(unit_ids),
    similarity_scores=ss_items,
)

url = v_unit_similarity_matrix.url(label='Test unit similarity matrix')
print(url)

In [None]:
# ul_unit_data = [dict(unitId=u, location=loc.astype("float32")) for u, loc in unit_locations.items()]

# unit_locations_view_data = dict(type="UnitLocations", 
#                                 units=ul_unit_data, 
#                                 channelLocations=channel_locations)
# sv_dict.update(dict(UnitLocationsViewData=unit_locations_view_data))

In [None]:
# template_metric_names = si.get_template_metric_names()

# skip_metrics = ['isi_violations_rate', 'isi_violations_count']

# um_metrics = []
# for metric in metrics.columns:
#     if metric not in skip_metrics:
#         if metric in template_metric_names:
#             metric_type = "template"
#         else:
#             metric_type = "quality"
#         um_metrics.append(dict(name=metric, metricType=metric_type, description=""))

# um_units = []
# for index, row in metrics.iterrows():
#     values = {}
#     for metric in row.keys():
#         if metric not in skip_metrics:
#             values[metric] = row[metric]
#     um_units.append(dict(unitId=int(index), values=values))
    
# unit_metrics_view_data = dict(type="UnitMetrics", 
#                               metrics=um_metrics, units=um_units)

# sv_dict.update(dict(UnitMetricsViewData=unit_metrics_view_data))

In [None]:
# Combined layout, built from named panels:
#   [units table | splitter(left_panel, right_panel)]
# left_panel  = spike amplitudes above (cross-correlograms 3 : similarity 1)
# right_panel = average waveforms above autocorrelograms
left_panel = vv.Box(
    direction='vertical',
    items=[
        vv.LayoutItem(v_spike_amplitudes),
        vv.LayoutItem(
            vv.Splitter(
                direction='horizontal',
                item1=vv.LayoutItem(v_cross_correlograms, stretch=3),
                item2=vv.LayoutItem(v_unit_similarity_matrix, stretch=1),
            )
        ),
    ],
)

right_panel = vv.Box(
    direction='vertical',
    items=[
        vv.LayoutItem(v_average_waveforms),
        vv.LayoutItem(v_autocorrelograms),
    ],
)

main_panel = vv.Splitter(
    direction='horizontal',
    item1=vv.LayoutItem(left_panel),
    item2=vv.LayoutItem(right_panel),
)

view = vv.Box(
    direction='horizontal',
    items=[
        vv.LayoutItem(v_units_table, max_size=150),
        vv.LayoutItem(main_panel),
    ],
)

url = view.url(label='prepare-data-for-sv.ipynb')
print(url)