Skip to content

tsam.plot

tsam.plot

Plotting accessor for tsam aggregation results.

Provides convenient plotting methods directly on the result object for validation and visualization of aggregation quality.

Usage:

    >>> result = tsam.aggregate(df, n_clusters=8)
    >>> result.plot.compare()  # Compare original vs reconstructed
    >>> result.plot.residuals()  # View reconstruction errors
    >>> result.plot.cluster_representatives()
    >>> result.plot.cluster_members()  # All periods per cluster
    >>> result.plot.cluster_weights()
    >>> result.plot.accuracy()

To explore raw data before aggregation, use plotly directly, together with tsam.unstack_to_periods() to reshape the data for heatmaps:

    >>> import plotly.express as px
    >>> unstacked = tsam.unstack_to_periods(df, period_duration=24)
    >>> px.imshow(unstacked["Load"].values.T)

Note: This module requires the optional 'plotly' dependency. Install it with: pip install "tsam[plot]" (the quotes prevent shells such as zsh from treating the brackets as a glob pattern).

ResultPlotAccessor

Plotting accessor for AggregationResult.

Provides convenient plotting methods directly on the result object.

Examples

>>> result = tsam.aggregate(df, n_clusters=8)
>>> result.plot.compare()  # Compare original vs reconstructed
>>> result.plot.residuals()  # View reconstruction errors
>>> result.plot.cluster_representatives()
>>> result.plot.cluster_members()
>>> result.plot.cluster_weights()

Source code in src/tsam/plot.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
class ResultPlotAccessor:
    """Plotting accessor for AggregationResult.

    Provides convenient plotting methods directly on the result object.
    Every method reads lazily from the wrapped result and returns a new
    plotly figure.

    Examples
    --------
    >>> result = tsam.aggregate(df, n_clusters=8)
    >>> result.plot.compare()  # Compare original vs reconstructed
    >>> result.plot.residuals()  # View reconstruction errors
    >>> result.plot.cluster_representatives()
    >>> result.plot.cluster_members()
    >>> result.plot.cluster_weights()
    >>> result.plot.accuracy()
    """

    def __init__(self, result: AggregationResult):
        # Only a reference is stored; nothing is computed up front.
        self._result = result

    def cluster_representatives(
        self,
        columns: list[str] | None = None,
        title: str = "Cluster Representatives",
    ) -> go.Figure:
        """Plot all cluster representatives (typical periods).

        One line is drawn per representative period. Each line is labelled
        with its period id and its cluster weight. When more than one column
        is plotted, every column gets its own facet.

        Parameters
        ----------
        columns : list[str], optional
            Columns to plot. If None, plots all available columns.
        title : str, default "Cluster Representatives"
            Plot title.

        Returns
        -------
        go.Figure
        """
        typ = self._result.cluster_representatives
        weights = self._result.cluster_weights

        available_columns = [c for c in typ.columns if c not in ["cluster", "timestep"]]
        columns = _validate_columns(
            columns, available_columns, "cluster_representatives"
        )

        # Only the first two index levels (period id, step within the period)
        # are used. Any further level, such as the "Segment Duration" level
        # used with segmentation, would add extra columns on reset_index()
        # and break the column relabelling below, so it is dropped.
        df = typ[columns]
        if df.index.nlevels > 2:
            df = df.reset_index(level=list(range(2, df.index.nlevels)), drop=True)
        df = df.reset_index()
        df.columns = pd.Index(["Period", "Timestep", *columns])

        # Map period IDs to labels with weights
        df["Period"] = df["Period"].map(lambda p: f"Period {p} (n={weights.get(p, 1)})")

        long_df = df.melt(
            id_vars=["Period", "Timestep"],
            var_name="Column",
            value_name="Value",
        )

        fig = px.line(
            long_df,
            x="Timestep",
            y="Value",
            color="Period",
            facet_col="Column" if len(columns) > 1 else None,
            title=title,
        )

        return fig

    def cluster_members(
        self,
        columns: list[str] | None = None,
        clusters: list[int] | None = None,
        slider: Literal["cluster", "column"] = "cluster",
        title: str | None = None,
    ) -> go.Figure:
        """Plot all original periods grouped by cluster with representative highlighted.

        Shows individual member periods as faint lines and the cluster
        representative as a bold line. A slider lets you flip through
        either clusters or columns.

        Parameters
        ----------
        columns : list[str], optional
            Columns to plot. If None, plots all columns.
        clusters : list[int], optional
            Cluster indices to include. If None, includes all clusters.
            Unknown indices are ignored with a ``UserWarning`` as long as at
            least one requested cluster exists.
        slider : ``"cluster"`` or ``"column"``, default ``"cluster"``
            Which dimension to put on the slider.
            The other dimension becomes ``facet_col``.

            - ``"cluster"``: slider flips through clusters, columns are facets.
            - ``"column"``: slider flips through columns, clusters are facets.
        title : str, optional
            Plot title. Defaults to "Cluster Members".

        Returns
        -------
        go.Figure

        Raises
        ------
        ValueError
            If ``slider`` is not ``"cluster"`` or ``"column"``, or if none of
            the requested ``clusters`` exist.

        Examples
        --------
        >>> result.plot.cluster_members(columns=["Load"])
        >>> result.plot.cluster_members(clusters=[0, 3])  # specific clusters
        >>> result.plot.cluster_members(slider="column")  # flip through columns
        """
        from plotly.subplots import make_subplots

        from tsam.api import unstack_to_periods

        # Validate the cheap argument first so a typo fails before any data
        # is reshaped.
        _slider = slider.lower()
        if _slider not in ("cluster", "column"):
            raise ValueError(f"slider must be 'cluster' or 'column', got {slider!r}")

        result = self._result
        columns = _validate_columns(
            columns, list(result.original.columns), "original data"
        )
        n_ts = result.n_timesteps_per_period
        idx = result.original.index
        # unstack_to_periods receives the period length in hours. Derive the
        # timestep resolution from the index, falling back to 1 hour.
        if isinstance(idx, pd.DatetimeIndex) and len(idx) > 1:
            timestep_hours = (idx[1] - idx[0]).total_seconds() / 3600
        else:
            timestep_hours = 1.0
        unstacked = unstack_to_periods(result.original, n_ts * timestep_hours)
        # np.asarray makes `assignments == cid` element-wise even if the
        # result exposes the assignments as a plain sequence.
        assignments = np.asarray(result.cluster_assignments)
        representatives = result.cluster_representatives
        weights = result.cluster_weights
        timesteps = np.arange(n_ts)

        # .tolist() yields plain Python ints, so error messages and labels
        # show "3" rather than "np.int64(3)".
        all_cluster_ids = sorted(set(assignments.tolist()))
        if clusters is not None:
            invalid = [c for c in clusters if c not in all_cluster_ids]
            cluster_ids = [c for c in clusters if c in all_cluster_ids]
            if invalid and cluster_ids:
                warnings.warn(
                    f"Cluster indices not found and will be ignored: {invalid}. "
                    f"Available clusters: {all_cluster_ids}",
                    UserWarning,
                    stacklevel=2,
                )
            if not cluster_ids:
                raise ValueError(
                    f"None of the requested clusters {clusters} exist. "
                    f"Available clusters: {all_cluster_ids}"
                )
        else:
            cluster_ids = all_cluster_ids
        # Positions (row numbers in `unstacked`) of the periods in each cluster.
        members_by_cluster = {
            cid: np.where(assignments == cid)[0] for cid in cluster_ids
        }

        def _rep_values(cluster_id: int, col: str) -> np.ndarray:
            """Get representative values expanded to full timesteps."""
            rep = representatives.loc[cluster_id]
            if result.n_segments is not None:
                # Segmented representatives hold one value per segment;
                # repeat each value for its duration to align with timesteps.
                durations = rep.index.get_level_values("Segment Duration").astype(int)
                return np.repeat(rep[col].values, durations)
            return rep[col].values  # type: ignore[no-any-return]

        # Pre-extract member data as numpy arrays for fast access.
        # member_arrays[cid][col] = 2D array (n_members, n_ts)
        member_arrays: dict[int, dict[str, np.ndarray]] = {}
        for cid in cluster_ids:
            members = members_by_cluster[cid]
            member_arrays[cid] = {
                col: np.asarray(unstacked[col].iloc[members].values) for col in columns
            }

        cluster_labels = {
            cid: f"Cluster {cid} (n={weights.get(cid, 1)})" for cid in cluster_ids
        }

        # Determine which dimension is animated vs faceted.
        anim_keys: list[int | str]
        if _slider == "cluster":
            anim_keys = list(cluster_ids)
            anim_labels = [cluster_labels[c] for c in cluster_ids]
            facet_labels = columns
        else:
            anim_keys = list(columns)
            anim_labels = list(columns)
            facet_labels = [cluster_labels[c] for c in cluster_ids]

        n_facets = len(facet_labels)
        traces_per_facet = 2  # one bundled member trace + one representative
        MEMBER = {"color": "rgba(99, 110, 250, 0.3)"}
        REP = {"color": "#EF553B", "width": 3}

        # Precompute NaN-separated x-arrays (one per unique member count).
        # Each member's timesteps are separated by a NaN to break the line,
        # so all members of a facet can share a single trace.
        _member_x: dict[int, np.ndarray] = {}
        for cid in cluster_ids:
            n_m = len(members_by_cluster[cid])
            if n_m not in _member_x:
                tile = np.empty(n_ts + 1)
                tile[:n_ts] = timesteps
                tile[n_ts] = np.nan
                _member_x[n_m] = np.tile(tile, n_m)[:-1]

        def _member_y(cid: int, col: str) -> np.ndarray:
            """All members as NaN-separated y-values (vectorized)."""
            data = member_arrays[cid][col]  # (n_members, n_ts)
            padded = np.column_stack([data, np.full(data.shape[0], np.nan)])
            return padded.ravel()[:-1]

        def _padded_range(col: str) -> list[float] | None:
            """Y-range covering members and representatives of *col*, plus 5 % padding.

            Returns None when *col* has no finite values, in which case the
            axis is left to autoscale.
            """
            vals = np.concatenate(
                [member_arrays[cid][col].ravel() for cid in cluster_ids]
                + [np.asarray(_rep_values(cid, col)).ravel() for cid in cluster_ids]
            ).astype(float)
            finite = vals[np.isfinite(vals)]
            if finite.size == 0:
                return None
            ymin, ymax = float(finite.min()), float(finite.max())
            margin = (ymax - ymin) * 0.05
            if margin == 0:
                # Constant data would give a zero-height axis; pad around it.
                margin = abs(ymin) * 0.05 or 1.0
            return [ymin - margin, ymax + margin]

        def _frame_traces(anim_key: int | str) -> list[go.Scatter]:
            """Build Scatter traces for one animation frame."""
            out: list[go.Scatter] = []
            first_member = True
            first_rep = True
            for facet_idx in range(n_facets):
                if _slider == "cluster":
                    cid, col = cast("int", anim_key), columns[facet_idx]
                else:
                    cid, col = cluster_ids[facet_idx], cast("str", anim_key)

                n_m = len(members_by_cluster[cid])
                out.append(
                    go.Scatter(
                        x=_member_x[n_m],
                        y=_member_y(cid, col),
                        mode="lines",
                        line=MEMBER,
                        name="Member",
                        legendgroup="Member",
                        showlegend=first_member,
                    )
                )
                first_member = False

                out.append(
                    go.Scatter(
                        x=timesteps,
                        y=_rep_values(cid, col),
                        mode="lines",
                        line=REP,
                        name="Representative",
                        legendgroup="Representative",
                        showlegend=first_rep,
                    )
                )
                first_rep = False

            return out

        # Build figure with subplots for facets.
        if n_facets > 1:
            fig = make_subplots(rows=1, cols=n_facets, subplot_titles=facet_labels)
        else:
            fig = go.Figure()

        # Initial traces (first animation frame).
        initial = _frame_traces(anim_keys[0])
        if n_facets > 1:
            rows = [1] * len(initial)
            cols_idx = [i // traces_per_facet + 1 for i in range(len(initial))]
            fig.add_traces(initial, rows=rows, cols=cols_idx)
        else:
            fig.add_traces(initial)

        # Animation frames.
        fig.frames = [
            go.Frame(data=_frame_traces(key), name=label)
            for key, label in zip(anim_keys, anim_labels)
        ]

        # Slider.
        steps = [
            {
                "args": [
                    [f.name],
                    {
                        "frame": {"duration": 0, "redraw": True},
                        "mode": "immediate",
                    },
                ],
                "label": f.name,
                "method": "animate",
            }
            for f in fig.frames
        ]
        fig.update_layout(
            sliders=[{"active": 0, "steps": steps}],
            title=title or "Cluster Members",
        )

        # Y-axis scaling.
        if _slider == "cluster":
            # Facets are columns (different units) — independent y-axes,
            # fixed across all cluster frames.
            if n_facets > 1:
                fig.update_yaxes(matches=None, showticklabels=True)
            for i, col in enumerate(columns):
                y_range = _padded_range(col)
                if y_range is not None:
                    key = "yaxis" if i == 0 else f"yaxis{i + 1}"
                    fig.layout[key].range = y_range
        else:
            # Facets are clusters (same column) — y-axis range adapts per
            # column frame.
            n_axes = max(n_facets, 1)
            for frame_idx, col in enumerate(columns):
                y_range = _padded_range(col)
                axis_spec = (
                    {"range": y_range} if y_range is not None else {"autorange": True}
                )
                axis_ranges = {
                    ("yaxis" if i == 0 else f"yaxis{i + 1}"): dict(axis_spec)
                    for i in range(n_axes)
                }
                fig.frames[frame_idx].layout = go.Layout(**axis_ranges)
            # The initial view shows frame 0, so mirror its ranges.
            if fig.frames:
                for key, val in fig.frames[0].layout.to_plotly_json().items():
                    if key.startswith("yaxis") and "range" in val:
                        fig.layout[key].range = val["range"]

        return fig

    def cluster_weights(self, title: str = "Cluster Weights") -> go.Figure:
        """Plot cluster weight distribution.

        Draws one bar per cluster, showing how many original periods it
        represents.

        Parameters
        ----------
        title : str, default "Cluster Weights"
            Plot title.

        Returns
        -------
        go.Figure
        """
        weights = self._result.cluster_weights
        df = pd.DataFrame(
            {
                "Period": [f"Period {p}" for p in weights],
                "Count": list(weights.values()),
            }
        )

        fig = px.bar(
            df,
            x="Period",
            y="Count",
            title=title,
            text="Count",
            color="Count",
            color_continuous_scale="Viridis",
        )
        fig.update_traces(textposition="auto")
        fig.update_layout(showlegend=False)

        return fig

    def accuracy(self, title: str = "Accuracy Metrics") -> go.Figure:
        """Plot accuracy metrics by column.

        Shows RMSE, MAE and duration-curve RMSE side by side (grouped bars)
        for every column reported in ``result.accuracy``.

        Parameters
        ----------
        title : str, default "Accuracy Metrics"
            Plot title.

        Returns
        -------
        go.Figure
        """
        acc = self._result.accuracy
        columns = list(acc.rmse.index)

        records = []
        for col in columns:
            records.append({"Column": col, "Metric": "RMSE", "Value": acc.rmse[col]})
            records.append({"Column": col, "Metric": "MAE", "Value": acc.mae[col]})
            records.append(
                {
                    "Column": col,
                    "Metric": "RMSE (Duration)",
                    "Value": acc.rmse_duration[col],
                }
            )

        df = pd.DataFrame(records)

        fig = px.bar(
            df,
            x="Column",
            y="Value",
            color="Metric",
            barmode="group",
            title=title,
        )

        return fig

    def segment_durations(self, title: str = "Segment Durations") -> go.Figure:
        """Plot segment durations (if segmentation was used).

        The bar height is the duration of each segment averaged over all
        typical periods.

        Parameters
        ----------
        title : str, default "Segment Durations"
            Plot title.

        Returns
        -------
        go.Figure

        Raises
        ------
        ValueError
            If no segmentation was used, if there are no segment durations,
            or if the periods have differing segment counts.
        """
        if self._result.segment_durations is None:
            raise ValueError("No segmentation was used in this aggregation")

        # segment_durations is tuple[tuple[int, ...], ...] - one tuple per period
        # Average durations across all typical periods for the bar chart
        durations = self._result.segment_durations
        if not durations:
            raise ValueError("No segment durations available to plot")

        # Validate uniform structure across periods
        segment_counts = {len(period) for period in durations}
        if len(segment_counts) != 1:
            raise ValueError(
                f"Inconsistent segment counts across periods: {segment_counts}. "
                "Cannot compute average durations."
            )

        n_segments = len(durations[0])
        avg_durations = [
            sum(period[s] for period in durations) / len(durations)
            for s in range(n_segments)
        ]

        df = pd.DataFrame(
            {
                "Segment": [f"Segment {s}" for s in range(n_segments)],
                "Duration": avg_durations,
            }
        )

        fig = px.bar(
            df,
            x="Segment",
            y="Duration",
            title=title,
            text="Duration",
            color="Duration",
            color_continuous_scale="Viridis",
        )
        fig.update_traces(texttemplate="%{text:.1f}", textposition="auto")
        fig.update_layout(showlegend=False, yaxis_title="Duration (timesteps)")

        return fig

    def compare(
        self,
        columns: list[str] | None = None,
        mode: str = "overlay",
        title: str | None = None,
    ) -> go.Figure:
        """Compare original vs reconstructed time series.

        Parameters
        ----------
        columns : list[str], optional
            Columns to compare. If None, compares all columns.
        mode : str, default "overlay"
            Comparison mode:
            - "overlay": Both series on same axes
            - "side_by_side": Separate subplots
            - "duration_curve": Compare sorted values
        title : str, optional
            Plot title.

        Returns
        -------
        go.Figure

        Raises
        ------
        ValueError
            If ``mode`` is not one of the values listed above.

        Examples
        --------
        >>> result.plot.compare()  # Compare all columns
        >>> result.plot.compare(columns=["Load"])  # Compare specific column
        >>> result.plot.compare(mode="duration_curve")
        """
        orig = self._result.original
        recon = self._result.reconstructed

        columns = _validate_columns(columns, list(orig.columns), "original data")

        if mode == "duration_curve":
            return _duration_curve_figure(
                {"Original": orig, "Reconstructed": recon},
                columns=columns,
                title=title,
            )

        elif mode in ("overlay", "side_by_side"):
            # Build long-form data with Source (Original/Reconstructed) and Column
            orig_df = orig[columns].copy()
            orig_df["Source"] = "Original"
            recon_df = recon[columns].copy()
            recon_df["Source"] = "Reconstructed"

            combined = pd.concat([orig_df, recon_df])
            combined.index.name = "Time"
            long_df = combined.reset_index().melt(
                id_vars=["Time", "Source"],
                var_name="Column",
                value_name="Value",
            )

            if mode == "overlay":
                # Color by Column, dash by Source (Original/Reconstructed)
                fig = px.line(
                    long_df,
                    x="Time",
                    y="Value",
                    color="Column",
                    line_dash="Source",
                    title=title or "Original vs Reconstructed",
                )
            else:  # side_by_side
                fig = px.line(
                    long_df,
                    x="Time",
                    y="Value",
                    color="Column",
                    facet_row="Source",
                    title=title or "Original vs Reconstructed",
                )
                fig.update_layout(height=600)

            return fig

        else:
            raise ValueError(
                f"Unknown mode: {mode}. Use 'overlay', 'side_by_side', or 'duration_curve'."
            )

    def residuals(
        self,
        columns: list[str] | None = None,
        mode: str = "time_series",
        title: str | None = None,
    ) -> go.Figure:
        """Plot residuals (original - reconstructed).

        Parameters
        ----------
        columns : list[str], optional
            Columns to plot. If None, plots all.
        mode : str, default "time_series"
            Display mode:
            - "time_series": Residuals over time
            - "histogram": Distribution of residuals
            - "by_period": Mean absolute error per period (bar chart)
            - "by_timestep": Mean absolute error by timestep within period
        title : str, optional
            Plot title.

        Returns
        -------
        go.Figure

        Raises
        ------
        ValueError
            If ``mode`` is not one of the values listed above.

        Examples
        --------
        >>> result.plot.residuals()  # Time series of residuals
        >>> result.plot.residuals(mode="histogram")  # Error distribution
        >>> result.plot.residuals(mode="by_period")  # Which periods have highest error
        >>> result.plot.residuals(mode="by_timestep")  # Error pattern within day
        """
        resid = self._result.residuals
        columns = _validate_columns(columns, list(resid.columns), "residuals")

        if mode == "time_series":
            df_plot = resid[columns].copy()
            df_plot.index.name = "Time"
            long_df = df_plot.reset_index().melt(
                id_vars=["Time"],
                var_name="Column",
                value_name="Residual",
            )
            fig = px.line(
                long_df,
                x="Time",
                y="Residual",
                color="Column",
                title=title or "Residuals Over Time",
            )
            fig.add_hline(y=0, line_dash="dash", line_color="gray")
            return fig

        elif mode == "histogram":
            long_df = resid[columns].melt(var_name="Column", value_name="Residual")
            fig = px.histogram(
                long_df,
                x="Residual",
                color="Column",
                barmode="overlay",
                opacity=0.7,
                title=title or "Residual Distribution",
            )
            fig.add_vline(x=0, line_dash="dash", line_color="red")
            return fig

        elif mode == "by_period":
            n_timesteps = self._result.n_timesteps_per_period
            abs_resid = resid[columns].abs().copy()
            # Consecutive blocks of n_timesteps rows form one period.
            abs_resid["Period"] = np.arange(len(abs_resid)) // n_timesteps

            df = abs_resid.groupby("Period")[columns].mean().reset_index()
            long_df = df.melt(id_vars="Period", var_name="Column", value_name="MAE")

            fig = px.bar(
                long_df,
                x="Period",
                y="MAE",
                color="Column",
                barmode="group",
                title=title or "Mean Absolute Error by Period",
            )
            return fig

        elif mode == "by_timestep":
            n_timesteps = self._result.n_timesteps_per_period
            abs_resid = resid[columns].abs().copy()
            # Position of each row within its period.
            abs_resid["Timestep"] = np.arange(len(abs_resid)) % n_timesteps

            df = abs_resid.groupby("Timestep")[columns].mean().reset_index()
            long_df = df.melt(id_vars="Timestep", var_name="Column", value_name="MAE")

            fig = px.line(
                long_df,
                x="Timestep",
                y="MAE",
                color="Column",
                title=title or "Mean Absolute Error by Timestep",
            )
            return fig

        else:
            raise ValueError(
                f"Unknown mode: {mode}. Use 'time_series', 'histogram', 'by_period', or 'by_timestep'."
            )

cluster_representatives

cluster_representatives(
    columns: list[str] | None = None,
    title: str = "Cluster Representatives",
) -> go.Figure

Plot all cluster representatives (typical periods).

Parameters:

Name Type Description Default
columns list[str]

Columns to plot.

None
title str

Plot title.

"Cluster Representatives"

Returns:

Type Description
Figure
Source code in src/tsam/plot.py
def cluster_representatives(
    self,
    columns: list[str] | None = None,
    title: str = "Cluster Representatives",
) -> go.Figure:
    """Plot all cluster representatives (typical periods).

    One line is drawn per representative period. Each line is labelled with
    its period id and its cluster weight. When more than one column is
    plotted, every column gets its own facet.

    Parameters
    ----------
    columns : list[str], optional
        Columns to plot. If None, plots all available columns.
    title : str, default "Cluster Representatives"
        Plot title.

    Returns
    -------
    go.Figure
    """
    typ = self._result.cluster_representatives
    weights = self._result.cluster_weights

    available_columns = [c for c in typ.columns if c not in ["cluster", "timestep"]]
    columns = _validate_columns(
        columns, available_columns, "cluster_representatives"
    )

    # Only the first two index levels (period id, step within the period) are
    # used. Any further level (presumably e.g. a segment-duration level when
    # segmentation is used) would add extra columns on reset_index() and break
    # the column relabelling below, so it is dropped.
    df = typ[columns]
    if df.index.nlevels > 2:
        df = df.reset_index(level=list(range(2, df.index.nlevels)), drop=True)
    df = df.reset_index()
    df.columns = pd.Index(["Period", "Timestep", *columns])

    # Map period IDs to labels with weights
    df["Period"] = df["Period"].map(lambda p: f"Period {p} (n={weights.get(p, 1)})")

    long_df = df.melt(
        id_vars=["Period", "Timestep"],
        var_name="Column",
        value_name="Value",
    )

    fig = px.line(
        long_df,
        x="Timestep",
        y="Value",
        color="Period",
        facet_col="Column" if len(columns) > 1 else None,
        title=title,
    )

    return fig

cluster_members

cluster_members(
    columns: list[str] | None = None,
    clusters: list[int] | None = None,
    slider: Literal["cluster", "column"] = "cluster",
    title: str | None = None,
) -> go.Figure

Plot all original periods grouped by cluster with representative highlighted.

Shows individual member periods as faint lines and the cluster representative as a bold line. A slider lets you flip through either clusters or columns.

Parameters:

Name Type Description Default
columns list[str]

Columns to plot. If None, plots all columns.

None
clusters list[int]

Cluster indices to include. If None, includes all clusters.

None
slider Literal["cluster", "column"]

Which dimension to put on the slider. The other dimension becomes facet_col.

  • "cluster": slider flips through clusters, columns are facets.
  • "column": slider flips through columns, clusters are facets.
"cluster"
title str

Plot title. Defaults to "Cluster Members".

None

Returns:

Type Description
Figure

Examples:

>>> result.plot.cluster_members(columns=["Load"])
>>> result.plot.cluster_members(clusters=[0, 3])  # specific clusters
>>> result.plot.cluster_members(slider="column")  # flip through columns
Source code in src/tsam/plot.py
def cluster_members(
    self,
    columns: list[str] | None = None,
    clusters: list[int] | None = None,
    slider: Literal["cluster", "column"] = "cluster",
    title: str | None = None,
) -> go.Figure:
    """Plot all original periods grouped by cluster with representative highlighted.

    Shows individual member periods as faint lines and the cluster
    representative as a bold line. A slider lets you flip through
    either clusters or columns.

    Parameters
    ----------
    columns : list[str], optional
        Columns to plot. If None, plots all columns.
    clusters : list[int], optional
        Cluster indices to include. If None, includes all clusters.
    slider : {"cluster", "column"}, default "cluster"
        Which dimension to put on the slider. The other dimension is laid
        out as side-by-side subplots (facets).

        - `"cluster"`: slider flips through clusters, columns are facets.
        - `"column"`: slider flips through columns, clusters are facets.
    title : str, optional
        Plot title. Defaults to "Cluster Members".

    Returns
    -------
    go.Figure

    Raises
    ------
    ValueError
        If `slider` is not "cluster" or "column" (compared
        case-insensitively), or if none of the requested `clusters` exist.

    Warns
    -----
    UserWarning
        If only some of the requested `clusters` exist; the missing ones
        are ignored.

    Examples
    --------
    >>> result.plot.cluster_members(columns=["Load"])
    >>> result.plot.cluster_members(clusters=[0, 3])  # specific clusters
    >>> result.plot.cluster_members(slider="column")  # flip through columns
    """
    # Validate the cheap argument first so a typo fails fast, before the
    # optional plotting imports and the data reshaping below.
    _slider = slider.lower()
    if _slider not in ("cluster", "column"):
        raise ValueError(f"slider must be 'cluster' or 'column', got {slider!r}")

    from plotly.subplots import make_subplots

    from tsam.api import unstack_to_periods

    result = self._result
    columns = _validate_columns(
        columns, list(result.original.columns), "original data"
    )
    n_ts = result.n_timesteps_per_period
    idx = result.original.index
    # The period length passed to unstack_to_periods is n_ts * step-size in
    # hours; the step is taken from the first two timestamps (assumes a
    # regular index), falling back to 1 hour for non-datetime indexes.
    if isinstance(idx, pd.DatetimeIndex) and len(idx) > 1:
        timestep_hours = (idx[1] - idx[0]).total_seconds() / 3600
    else:
        timestep_hours = 1.0
    unstacked = unstack_to_periods(result.original, n_ts * timestep_hours)
    # np.asarray so the vectorised ``==`` below also works if the
    # assignments are exposed as a plain sequence.
    assignments = np.asarray(result.cluster_assignments)
    representatives = result.cluster_representatives
    weights = result.cluster_weights
    timesteps = np.arange(n_ts)

    # tolist() yields plain Python scalars, so ids render as "3" rather than
    # "np.int64(3)" in the warning/error messages below.
    all_cluster_ids = sorted(set(assignments.tolist()))
    if clusters is not None:
        invalid = [c for c in clusters if c not in all_cluster_ids]
        cluster_ids = [c for c in clusters if c in all_cluster_ids]
        if invalid and cluster_ids:
            warnings.warn(
                f"Cluster indices not found and will be ignored: {invalid}. "
                f"Available clusters: {all_cluster_ids}",
                UserWarning,
                stacklevel=2,
            )
        if not cluster_ids:
            raise ValueError(
                f"None of the requested clusters {clusters} exist. "
                f"Available clusters: {all_cluster_ids}"
            )
    else:
        cluster_ids = all_cluster_ids
    members_by_cluster = {
        cid: np.where(assignments == cid)[0] for cid in cluster_ids
    }

    def _rep_values(cluster_id: int, col: str) -> np.ndarray:
        """Get representative values expanded to full timesteps."""
        rep = representatives.loc[cluster_id]
        if result.n_segments is not None:
            # Segmented representatives hold one value per segment; repeat
            # each value by its segment duration to get one per timestep.
            durations = rep.index.get_level_values("Segment Duration").astype(int)
            return np.repeat(rep[col].values, durations)
        return rep[col].values  # type: ignore[no-any-return]

    # Pre-extract member data as numpy arrays for fast access.
    # member_arrays[cid][col] = 2D array (n_members, n_ts)
    member_arrays: dict[int, dict[str, np.ndarray]] = {}
    for cid in cluster_ids:
        members = members_by_cluster[cid]
        member_arrays[cid] = {
            col: np.asarray(unstacked[col].iloc[members].values) for col in columns
        }

    cluster_labels = {
        cid: f"Cluster {cid} (n={weights.get(cid, 1)})" for cid in cluster_ids
    }

    # Determine which dimension is animated vs faceted.
    anim_keys: list[int | str]
    if _slider == "cluster":
        anim_keys = list(cluster_ids)
        anim_labels = [cluster_labels[c] for c in cluster_ids]
        facet_labels = columns
    else:
        anim_keys = list(columns)
        anim_labels = list(columns)
        facet_labels = [cluster_labels[c] for c in cluster_ids]

    n_facets = len(facet_labels)
    traces_per_facet = 2  # one bundled member trace + one representative
    MEMBER = {"color": "rgba(99, 110, 250, 0.3)"}
    REP = {"color": "#EF553B", "width": 3}

    # Precompute NaN-separated x-arrays (one per unique member count).
    # Each member's timesteps are separated by a NaN to break the line.
    _member_x: dict[int, np.ndarray] = {}
    for cid in cluster_ids:
        n_m = len(members_by_cluster[cid])
        if n_m not in _member_x:
            tile = np.empty(n_ts + 1)
            tile[:n_ts] = timesteps
            tile[n_ts] = np.nan
            _member_x[n_m] = np.tile(tile, n_m)[:-1]

    def _member_y(cid: int, col: str) -> np.ndarray:
        """All members as NaN-separated y-values (vectorized)."""
        data = member_arrays[cid][col]  # (n_members, n_ts)
        padded = np.column_stack([data, np.full(data.shape[0], np.nan)])
        return padded.ravel()[:-1]

    def _frame_traces(anim_key: int | str) -> list[go.Scatter]:
        """Build Scatter traces for one animation frame."""
        out: list[go.Scatter] = []
        first_member = True
        first_rep = True
        for facet_idx in range(n_facets):
            if _slider == "cluster":
                cid, col = cast("int", anim_key), columns[facet_idx]
            else:
                cid, col = cluster_ids[facet_idx], cast("str", anim_key)

            n_m = len(members_by_cluster[cid])
            out.append(
                go.Scatter(
                    x=_member_x[n_m],
                    y=_member_y(cid, col),
                    mode="lines",
                    line=MEMBER,
                    name="Member",
                    legendgroup="Member",
                    showlegend=first_member,
                )
            )
            first_member = False

            out.append(
                go.Scatter(
                    x=timesteps,
                    y=_rep_values(cid, col),
                    mode="lines",
                    line=REP,
                    name="Representative",
                    legendgroup="Representative",
                    showlegend=first_rep,
                )
            )
            first_rep = False

        return out

    # Build figure with subplots for facets.
    if n_facets > 1:
        fig = make_subplots(rows=1, cols=n_facets, subplot_titles=facet_labels)
    else:
        fig = go.Figure()

    # Initial traces (first animation frame).
    initial = _frame_traces(anim_keys[0])
    if n_facets > 1:
        rows = [1] * len(initial)
        cols_idx = [i // traces_per_facet + 1 for i in range(len(initial))]
        fig.add_traces(initial, rows=rows, cols=cols_idx)
    else:
        fig.add_traces(initial)

    # Animation frames.
    fig.frames = [
        go.Frame(data=_frame_traces(key), name=label)
        for key, label in zip(anim_keys, anim_labels)
    ]

    # Slider.
    steps = [
        {
            "args": [
                [f.name],
                {
                    "frame": {"duration": 0, "redraw": True},
                    "mode": "immediate",
                },
            ],
            "label": f.name,
            "method": "animate",
        }
        for f in fig.frames
    ]
    fig.update_layout(
        sliders=[{"active": 0, "steps": steps}],
        title=title or "Cluster Members",
    )

    def _axis_range(col: str) -> list[float] | None:
        """Padded [min, max] y-range over all plotted members and representatives.

        Returns None when ``col`` has no finite values, so the caller can
        leave that axis on plotly's autorange instead of a NaN range.
        """
        parts = [member_arrays[cid][col].ravel() for cid in cluster_ids]
        # Include representatives: they are not guaranteed to lie inside
        # the member envelope, and a members-only range would clip them.
        parts.extend(_rep_values(cid, col) for cid in cluster_ids)
        vals = np.concatenate([np.asarray(p, dtype=float).ravel() for p in parts])
        vals = vals[np.isfinite(vals)]
        if vals.size == 0:
            return None
        ymin, ymax = float(vals.min()), float(vals.max())
        span = ymax - ymin
        # A constant series would otherwise produce a zero-width axis.
        margin = span * 0.05 if span > 0 else (abs(ymin) * 0.05 or 1.0)
        return [ymin - margin, ymax + margin]

    # Y-axis scaling.
    if _slider == "cluster":
        # Facets are columns (different units) — independent y-axes,
        # fixed across all cluster frames.
        if n_facets > 1:
            fig.update_yaxes(matches=None, showticklabels=True)
        for i, col in enumerate(columns):
            rng = _axis_range(col)
            if rng is not None:
                fig.layout["yaxis" if i == 0 else f"yaxis{i + 1}"].range = rng
    else:
        # Facets are clusters (same column) — y-axis range adapts per
        # column frame.
        axis_keys = [
            "yaxis" if i == 0 else f"yaxis{i + 1}" for i in range(max(n_facets, 1))
        ]
        col_ranges = [_axis_range(col) for col in columns]
        for frame, rng in zip(fig.frames, col_ranges):
            spec = {"range": rng} if rng is not None else {"autorange": True}
            frame.layout = go.Layout(**{key: dict(spec) for key in axis_keys})
        # The initial view shows the first frame, so mirror its ranges.
        if col_ranges and col_ranges[0] is not None:
            for key in axis_keys:
                fig.layout[key].range = col_ranges[0]

    return fig

cluster_weights

cluster_weights(
    title: str = "Cluster Weights",
) -> go.Figure

Plot cluster weight distribution.

Parameters:

Name Type Description Default
title str

Plot title.

"Cluster Weights"

Returns:

Type Description
Figure
Source code in src/tsam/plot.py
def cluster_weights(self, title: str = "Cluster Weights") -> go.Figure:
    """Draw a bar chart of the weight (member count) of each cluster.

    Parameters
    ----------
    title : str, default "Cluster Weights"
        Plot title.

    Returns
    -------
    go.Figure
        One bar per cluster, labelled "Period <id>", coloured and annotated
        by its count.
    """
    counts = self._result.cluster_weights
    labels: list[str] = []
    values: list = []
    for period_id, count in counts.items():
        labels.append(f"Period {period_id}")
        values.append(count)

    chart_data = pd.DataFrame({"Period": labels, "Count": values})

    figure = px.bar(
        chart_data,
        x="Period",
        y="Count",
        title=title,
        text="Count",
        color="Count",
        color_continuous_scale="Viridis",
    )
    figure.update_traces(textposition="auto")
    figure.update_layout(showlegend=False)
    return figure

accuracy

accuracy(title: str = 'Accuracy Metrics') -> go.Figure

Plot accuracy metrics by column.

Parameters:

Name Type Description Default
title str

Plot title.

"Accuracy Metrics"

Returns:

Type Description
Figure
Source code in src/tsam/plot.py
def accuracy(self, title: str = "Accuracy Metrics") -> go.Figure:
    """Draw grouped bars of the accuracy metrics for every column.

    For each column, three bars are shown: RMSE, MAE and RMSE of the
    duration curve.

    Parameters
    ----------
    title : str, default "Accuracy Metrics"
        Plot title.

    Returns
    -------
    go.Figure
    """
    acc = self._result.accuracy
    # Display label paired with the per-column series that holds its values.
    metric_series = (
        ("RMSE", acc.rmse),
        ("MAE", acc.mae),
        ("RMSE (Duration)", acc.rmse_duration),
    )

    rows = [
        {"Column": col, "Metric": label, "Value": series[col]}
        for col in list(acc.rmse.index)
        for label, series in metric_series
    ]
    chart_data = pd.DataFrame(rows)

    return px.bar(
        chart_data,
        x="Column",
        y="Value",
        color="Metric",
        barmode="group",
        title=title,
    )

segment_durations

segment_durations(
    title: str = "Segment Durations",
) -> go.Figure

Plot segment durations (if segmentation was used).

Parameters:

Name Type Description Default
title str

Plot title.

"Segment Durations"

Returns:

Type Description
Figure

Raises:

Type Description
ValueError

If no segmentation was used.

Source code in src/tsam/plot.py
def segment_durations(self, title: str = "Segment Durations") -> go.Figure:
    """Draw a bar chart of the mean length of each segment.

    Only available when the aggregation used segmentation. Bar *s* shows
    how many timesteps segment *s* spans, averaged over all typical periods.

    Parameters
    ----------
    title : str, default "Segment Durations"
        Plot title.

    Returns
    -------
    go.Figure

    Raises
    ------
    ValueError
        If no segmentation was used, or if the typical periods do not all
        have the same number of segments.
    """
    per_period = self._result.segment_durations
    if per_period is None:
        raise ValueError("No segmentation was used in this aggregation")

    # One tuple of segment lengths per typical period. Averaging segment by
    # segment only makes sense when every period has the same segment count.
    lengths = {len(row) for row in per_period}
    if len(lengths) != 1:
        raise ValueError(
            f"Inconsistent segment counts across periods: {lengths}. "
            "Cannot compute average durations."
        )

    n_periods = len(per_period)
    # zip(*rows) transposes (periods x segments) into (segments x periods).
    mean_lengths = [sum(column) / n_periods for column in zip(*per_period)]
    labels = [f"Segment {idx}" for idx in range(len(mean_lengths))]

    chart_data = pd.DataFrame({"Segment": labels, "Duration": mean_lengths})

    figure = px.bar(
        chart_data,
        x="Segment",
        y="Duration",
        title=title,
        text="Duration",
        color="Duration",
        color_continuous_scale="Viridis",
    )
    figure.update_traces(texttemplate="%{text:.1f}", textposition="auto")
    figure.update_layout(showlegend=False, yaxis_title="Duration (timesteps)")
    return figure

compare

compare(
    columns: list[str] | None = None,
    mode: str = "overlay",
    title: str | None = None,
) -> go.Figure

Compare original vs reconstructed time series.

Parameters:

Name Type Description Default
columns list[str]

Columns to compare. If None, compares all columns.

None
mode str

Comparison mode:

  • "overlay": Both series on the same axes
  • "side_by_side": Separate subplots
  • "duration_curve": Compare sorted values

"overlay"
title str

Plot title.

None

Returns:

Type Description
Figure

Examples:

>>> result.plot.compare()  # Compare all columns
>>> result.plot.compare(columns=["Load"])  # Compare specific column
>>> result.plot.compare(mode="duration_curve")
Source code in src/tsam/plot.py
def compare(
    self,
    columns: list[str] | None = None,
    mode: str = "overlay",
    title: str | None = None,
) -> go.Figure:
    """Plot the original time series against its reconstruction.

    Parameters
    ----------
    columns : list[str], optional
        Columns to compare. If None, compares all columns.
    mode : str, default "overlay"
        Comparison mode:

        - "overlay": both series on the same axes, told apart by dash style
        - "side_by_side": one subplot row per source
        - "duration_curve": compare sorted values
    title : str, optional
        Plot title. The "overlay" and "side_by_side" modes default to
        "Original vs Reconstructed".

    Returns
    -------
    go.Figure

    Raises
    ------
    ValueError
        If `mode` is not one of the values listed above.

    Examples
    --------
    >>> result.plot.compare()  # Compare all columns
    >>> result.plot.compare(columns=["Load"])  # Compare specific column
    >>> result.plot.compare(mode="duration_curve")
    """
    original = self._result.original
    reconstructed = self._result.reconstructed

    columns = _validate_columns(columns, list(original.columns), "original data")

    if mode == "duration_curve":
        return _duration_curve_figure(
            {"Original": original, "Reconstructed": reconstructed},
            columns=columns,
            title=title,
        )

    if mode not in ("overlay", "side_by_side"):
        raise ValueError(
            f"Unknown mode: {mode}. Use 'overlay', 'side_by_side', or 'duration_curve'."
        )

    # Tag each source with a "Source" column, stack them, then melt to long
    # form so plotly can split the lines by Column and by Source.
    pieces = []
    for label, frame in (("Original", original), ("Reconstructed", reconstructed)):
        piece = frame[columns].copy()
        piece["Source"] = label
        pieces.append(piece)

    stacked = pd.concat(pieces)
    stacked.index.name = "Time"
    long_df = stacked.reset_index().melt(
        id_vars=["Time", "Source"],
        var_name="Column",
        value_name="Value",
    )

    # Overlay distinguishes sources by dash style; side-by-side by facet row.
    source_kwargs = (
        {"line_dash": "Source"} if mode == "overlay" else {"facet_row": "Source"}
    )
    fig = px.line(
        long_df,
        x="Time",
        y="Value",
        color="Column",
        title=title or "Original vs Reconstructed",
        **source_kwargs,
    )
    if mode == "side_by_side":
        fig.update_layout(height=600)

    return fig

residuals

residuals(
    columns: list[str] | None = None,
    mode: str = "time_series",
    title: str | None = None,
) -> go.Figure

Plot residuals (original - reconstructed).

Parameters:

Name Type Description Default
columns list[str]

Columns to plot. If None, plots all.

None
mode str

Display mode:

  • "time_series": Residuals over time
  • "histogram": Distribution of residuals
  • "by_period": Mean absolute error per period (bar chart)
  • "by_timestep": Mean absolute error by timestep within the period

"time_series"
title str

Plot title.

None

Returns:

Type Description
Figure

Examples:

>>> result.plot.residuals()  # Time series of residuals
>>> result.plot.residuals(mode="histogram")  # Error distribution
>>> result.plot.residuals(mode="by_period")  # Which periods have highest error
>>> result.plot.residuals(mode="by_timestep")  # Error pattern within day
Source code in src/tsam/plot.py
def residuals(
    self,
    columns: list[str] | None = None,
    mode: str = "time_series",
    title: str | None = None,
) -> go.Figure:
    """Plot residuals (original - reconstructed).

    Parameters
    ----------
    columns : list[str], optional
        Columns to plot. If None, plots all.
    mode : str, default "time_series"
        Display mode:

        - "time_series": residuals over time
        - "histogram": distribution of residuals
        - "by_period": mean absolute error per period (bar chart)
        - "by_timestep": mean absolute error by timestep within the period
    title : str, optional
        Plot title. Each mode has its own default title.

    Returns
    -------
    go.Figure

    Raises
    ------
    ValueError
        If `mode` is not one of the values listed above.

    Examples
    --------
    >>> result.plot.residuals()  # Time series of residuals
    >>> result.plot.residuals(mode="histogram")  # Error distribution
    >>> result.plot.residuals(mode="by_period")  # Which periods have highest error
    >>> result.plot.residuals(mode="by_timestep")  # Error pattern within day
    """
    resid = self._result.residuals
    columns = _validate_columns(columns, list(resid.columns), "residuals")
    selected = resid[columns]

    if mode == "time_series":
        frame = selected.copy()
        frame.index.name = "Time"
        tidy = frame.reset_index().melt(
            id_vars=["Time"],
            var_name="Column",
            value_name="Residual",
        )
        figure = px.line(
            tidy,
            x="Time",
            y="Residual",
            color="Column",
            title=title or "Residuals Over Time",
        )
        figure.add_hline(y=0, line_dash="dash", line_color="gray")
        return figure

    if mode == "histogram":
        tidy = selected.melt(var_name="Column", value_name="Residual")
        figure = px.histogram(
            tidy,
            x="Residual",
            color="Column",
            barmode="overlay",
            opacity=0.7,
            title=title or "Residual Distribution",
        )
        figure.add_vline(x=0, line_dash="dash", line_color="red")
        return figure

    if mode in ("by_period", "by_timestep"):
        # Grouping is positional: row i belongs to period i // n and to
        # timestep i % n, where n is the number of timesteps per period.
        per_period = self._result.n_timesteps_per_period
        magnitudes = selected.abs().copy()
        position = np.arange(len(magnitudes))
        if mode == "by_period":
            group_key = "Period"
            magnitudes[group_key] = position // per_period
        else:
            group_key = "Timestep"
            magnitudes[group_key] = position % per_period

        tidy = (
            magnitudes.groupby(group_key)[columns]
            .mean()
            .reset_index()
            .melt(id_vars=group_key, var_name="Column", value_name="MAE")
        )

        if mode == "by_period":
            return px.bar(
                tidy,
                x="Period",
                y="MAE",
                color="Column",
                barmode="group",
                title=title or "Mean Absolute Error by Period",
            )
        return px.line(
            tidy,
            x="Timestep",
            y="MAE",
            color="Column",
            title=title or "Mean Absolute Error by Timestep",
        )

    raise ValueError(
        f"Unknown mode: {mode}. Use 'time_series', 'histogram', 'by_period', or 'by_timestep'."
    )