-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Expand file tree
/
Copy pathmosaic_memory_profiling_tutorial.py
More file actions
1239 lines (1052 loc) · 41.5 KB
/
mosaic_memory_profiling_tutorial.py
File metadata and controls
1239 lines (1052 loc) · 41.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -*- coding: utf-8 -*-
"""
Mosaic: Memory Profiling for PyTorch
====================================
.. meta::
:description: Learn how to use Mosaic for PyTorch GPU memory profiling. Capture and analyze memory snapshots, identify memory savings from activation checkpointing, debug OOM errors, and integrate memory analysis into training pipelines.
:keywords: PyTorch, Mosaic, memory profiling, GPU memory, CUDA, activation checkpointing, OOM debugging, deep learning, memory optimization, memory snapshots, distributed training, LLaMA, transformer models
**Author:** `Basil Wong <https://github.com/basilwong>`_
.. grid:: 2
.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
:class-card: card-prerequisites
* How to capture and analyze PyTorch memory snapshots
* Identify memory savings from activation checkpointing
* Debug unexpected memory usage from abandoned code
* Integrate memory analysis into training pipelines
.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
:class-card: card-prerequisites
* PyTorch v2.0.0 or later
* CUDA-capable GPU
* Basic understanding of PyTorch training loops
This tutorial demonstrates how to use `Mosaic <https://github.com/facebookresearch/mosaic>`_, a post-processing memory
snapshot analysis tool for PyTorch. Mosaic helps analyze GPU memory usage in
distributed deep learning, providing detailed insights into memory allocations,
peak usage, and memory imbalances across parallel workers.
Mosaic was instrumental in debugging OOM issues during the
`405B LLaMA training <https://ai.meta.com/blog/meta-llama-3-1/>`_
and is now open source.
"""
######################################################################
# Introduction to Mosaic
# ======================
#
# Overview
# --------
#
# In distributed deep learning, understanding GPU memory usage is critical
# for optimizing training efficiency and debugging Out-of-Memory (OOM) errors.
# Mosaic is a post-analysis tool for memory usage designed to work with
# large-scale jobs. It helps analyze PyTorch memory snapshots captured during
# the execution of PyTorch training jobs, providing detailed insights into
# memory allocations, peak usage, and memory imbalances across parallel workers.
#
# Getting Started
# ---------------
#
# Clone the mosaic repository and install from the mosaic directory:
#
# .. code-block:: bash
#
# git clone https://github.com/facebookresearch/mosaic
# cd mosaic
# python3 -m venv venv
# source venv/bin/activate
# pip3 install -r requirements.txt
# pip3 install -e .
#
# Alternatively, install directly via pip:
#
# .. code-block:: bash
#
# pip install git+https://github.com/facebookresearch/mosaic.git
#
# Simple Usage Examples
# ---------------------
#
# **1. Peak Memory Usage Analysis**
#
# When addressing memory problems like OOM errors, focusing on peak memory
# usage is crucial. The ``mosaic_get_memory_usage_peak`` command presents a
# stack trace of the memory allocations that contributed to the peak memory
# usage:
#
# .. code-block:: bash
#
# mosaic_get_memory_usage_peak --snapshot <path to snapshot>
#
# **2. Categorical Memory Profiling**
#
# Mosaic classifies allocations into categories (activation, backward,
# optimizer, etc.):
#
# - **Activation Memory:** Tensors saved for backward pass
# - **Gradient Memory:** Gradients computed during backpropagation
# - **Optimizer State:** Adam/SGD momentum and variance buffers
# - **Parameter Memory:** Model weights
#
# .. code-block:: bash
#
# mosaic_get_memory_profile --snapshot <path> --out-path <html> \
# --profile categories
#
# An example HTML output looks like:
#
# .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-no-allocation-ordering.png
# :alt: Mosaic categorical memory profiling without allocation ordering
# :align: center
# :width: 600px
#
# Categorical memory profiling showing memory breakdown by type
# (activation, gradient, optimizer, etc.)
#
# To maintain allocation order for the categories, add ``--preserve-allocation-order``:
#
# .. code-block:: bash
#
# mosaic_get_memory_profile --snapshot <path> --out-path <html> \
# --profile categories --preserve-allocation-order
#
# .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-allocation-ordering.png
# :alt: Mosaic categorical memory profiling with allocation ordering preserved
# :align: center
# :width: 600px
#
# Categorical profiling with ``--preserve-allocation-order`` shows memory
# allocations in chronological order
#
# **3. Custom Dictionary Profiling**
#
# For targeted analysis via regex pattern matching:
#
# .. code-block:: bash
#
# mosaic_get_memory_profile --snapshot <path> --profile custom \
# --custom-profile '{"ncclx": "ncclx"}'
#
# This is invaluable for tracking specific kernels, optimizers, or custom code patterns:
#
# .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-ncclx.png
# :alt: Mosaic custom dictionary profiling with ncclx pattern
# :align: center
# :width: 600px
#
# Custom profiling with regex patterns to track specific operations like
# NCCL communications
#
######################################################################
# Dependencies and Imports
# ========================
#
# Let's set up the required dependencies and imports for this tutorial.
import subprocess
import sys
import shutil
from contextlib import contextmanager
import pickle
# Fix for sphinx-gallery environment where __main__.__file__ may not exist
# This is needed for transformers library compatibility
import os

if not hasattr(sys.modules["__main__"], "__file__"):
    # Use this file's path as a fallback, or a dummy path if __file__ is not available
    try:
        sys.modules["__main__"].__file__ = os.path.abspath(__file__)
    except NameError:
        # __file__ not available, use transformers modeling file as fallback
        import transformers.modeling_utils

        sys.modules["__main__"].__file__ = transformers.modeling_utils.__file__

import torch
from torch.utils.data import DataLoader, Dataset

# Install dependencies if needed
try:
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
except ImportError:
    # transformers is missing: install quietly into the current interpreter,
    # then retry the imports.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-q", "transformers"]
    )
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

try:
    from mosaic.libmosaic.analyzer.memory_abstract import MemoryAbstract
except ImportError:
    # Mosaic is installed straight from GitHub when not already available.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "-q",
            "git+https://github.com/facebookresearch/mosaic.git",
        ]
    )
    from mosaic.libmosaic.analyzer.memory_abstract import MemoryAbstract

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
######################################################################
# Shared Utilities
# ================
#
# These helper classes and functions are used throughout the tutorial.
class RandomTokenDataset(Dataset):
    """Dataset yielding random token-id sequences for language-model training.

    Produces synthetic ``input_ids``/``labels`` pairs so a training loop can
    run without requiring a real text corpus.
    """

    def __init__(self, vocab_size, seq_length=512, num_samples=100, seed=None):
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.num_samples = num_samples
        # A seeded generator makes the sampled sequences reproducible.
        self.generator = None
        if seed is not None:
            self.generator = torch.Generator().manual_seed(seed)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):  # noqa: ARG002
        sampling_kwargs = {}
        if self.generator is not None:
            sampling_kwargs["generator"] = self.generator
        input_ids = torch.randint(
            0, self.vocab_size, (self.seq_length,), **sampling_kwargs
        )
        # Labels are a copy of the inputs (standard causal-LM setup).
        return {"input_ids": input_ids, "labels": input_ids.clone()}
@contextmanager
def capture_memory_snapshot(output_path):
    """Record CUDA allocations inside the ``with`` body and dump them to disk.

    Enables PyTorch's memory-history recorder on entry; on exit (even when the
    body raises) the accumulated snapshot is pickled so it can later be
    analyzed with Mosaic.

    Args:
        output_path: Destination file for the pickled memory snapshot.
    """
    torch.cuda.memory._record_memory_history(max_entries=100000)
    try:
        yield
    finally:
        # Grab the recorded history before disabling the recorder, then
        # persist it for offline analysis.
        captured = torch.cuda.memory._snapshot()
        torch.cuda.memory._record_memory_history(enabled=None)
        with open(output_path, "wb") as sink:
            pickle.dump(captured, sink)
        print(f"✓ Memory snapshot saved to {output_path}")
######################################################################
# Case 1: Understanding Memory Differences with Activation Checkpointing
# =======================================================================
#
# This section demonstrates how to use Mosaic to analyze and compare GPU
# memory usage between different model configurations.
#
# **What we'll do:**
#
# 1. Train GPT-2 and capture a memory snapshot (baseline)
# 2. Enable activation checkpointing and train again (modified)
# 3. Use Mosaic to identify exactly where memory savings occur
#
######################################################################
# Training Function for Activation Checkpointing Comparison
# ----------------------------------------------------------
def run_training_ac(
    activation_checkpointing: bool,
    snapshot_path: str,
    batch_size: int = 4,
    seq_length: int = 512,
    num_steps: int = 5,
):
    """Run training loop and capture memory snapshot.

    Trains GPT-2 on random tokens for a few steps while recording CUDA
    memory history, then reports the peak GPU memory observed.

    Args:
        activation_checkpointing: Whether to enable gradient checkpointing.
        snapshot_path: Path to save the memory snapshot.
        batch_size: Training batch size.
        seq_length: Sequence length for input tokens.
        num_steps: Number of training steps to run.

    Returns:
        Peak GPU memory usage in GB.
    """
    # Clear any previous memory so the peak-memory reading reflects only this run
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    device = torch.device("cuda")
    # Load model
    print(f"Loading GPT-2 (activation_checkpointing={activation_checkpointing})...")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    if activation_checkpointing:
        model.gradient_checkpointing_enable()
        print("Activation checkpointing is ENABLED")
    else:
        print("Activation checkpointing is DISABLED")
    model = model.to(device)
    model.train()
    # Create dataset and dataloader
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    dataset = RandomTokenDataset(
        vocab_size=tokenizer.vocab_size,
        seq_length=seq_length,
        num_samples=100,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    # Training loop with memory capture
    print(f"Running {num_steps} training steps...")
    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f" Step {step + 1}/{num_steps}, Loss: {loss.item():.4f}")
    peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
    print(f"✓ Peak GPU memory: {peak_memory_gb:.2f} GB")
    # Cleanup so subsequent runs start from a comparable memory baseline
    del model, optimizer
    torch.cuda.empty_cache()
    return peak_memory_gb
######################################################################
# Run Baseline Training (Without Activation Checkpointing)
# ---------------------------------------------------------
#
# .. note::
#
# This tutorial requires a CUDA-capable GPU. If you're running in
# Google Colab, make sure to select a GPU runtime:
# Runtime → Change runtime type → Hardware accelerator → GPU
# Warn early instead of crashing later: every profiling example below needs CUDA.
if not torch.cuda.is_available():
    print("=" * 60)
    print("WARNING: No CUDA GPU detected!")
    print("=" * 60)
    print("\nThis tutorial requires a CUDA-capable GPU for memory profiling.")
    print("\nIf you're running in Google Colab:")
    print(" 1. Go to Runtime → Change runtime type")
    print(" 2. Set Hardware accelerator to 'GPU'")
    print(" 3. Click 'Save' and re-run the notebook")
    print("\nSkipping GPU memory profiling examples...")
    HAS_CUDA = False
else:
    HAS_CUDA = True

# Check if Mosaic CLI is available
HAS_MOSAIC_CLI = shutil.which("mosaic_get_memory_profile") is not None
if HAS_CUDA and not HAS_MOSAIC_CLI:
    print("Note: Mosaic CLI not found. Install Mosaic to generate HTML profiles.")
    print(" pip install git+https://github.com/facebookresearch/mosaic.git")

if HAS_CUDA:
    print("=" * 60)
    print("BASELINE: Training WITHOUT Activation Checkpointing")
    print("=" * 60)
    baseline_memory = run_training_ac(
        activation_checkpointing=False,
        snapshot_path="snapshot_baseline.pickle",
        batch_size=4,
        seq_length=512,
        num_steps=5,
    )
######################################################################
# Run Modified Training (With Activation Checkpointing)
# ------------------------------------------------------
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("MODIFIED: Training WITH Activation Checkpointing")
    print("=" * 60)
    # Same hyperparameters as the baseline run so the comparison is apples-to-apples.
    ac_memory = run_training_ac(
        activation_checkpointing=True,
        snapshot_path="snapshot_with_ac.pickle",
        batch_size=4,
        seq_length=512,
        num_steps=5,
    )
    # Summary
    print("\n" + "=" * 60)
    print("MEMORY COMPARISON SUMMARY")
    print("=" * 60)
    print(f"Baseline (no AC): {baseline_memory:.2f} GB")
    print(f"With AC: {ac_memory:.2f} GB")
    if baseline_memory > 0:
        # Guard against division by zero when computing the percentage saved.
        saved_pct = 100 * (baseline_memory - ac_memory) / baseline_memory
        print(
            f"Memory Saved: {baseline_memory - ac_memory:.2f} GB ({saved_pct:.1f}%)"
        )
######################################################################
# Generate Categorical Memory Profiles with Mosaic
# -------------------------------------------------
#
# Use Mosaic to generate HTML profiles for both snapshots.
if HAS_CUDA and HAS_MOSAIC_CLI:
    print("\n" + "=" * 60)
    print("MOSAIC: Categorical Memory Profiling")
    print("=" * 60)
    # Generate HTML profiles using subprocess
    print("\nGenerating baseline profile...")
    # --preserve-allocation-order keeps categories in chronological order;
    # --plotter_sampling_rate trades plot detail for smaller HTML output.
    result1 = subprocess.run(
        [
            "mosaic_get_memory_profile",
            "--snapshot",
            "snapshot_baseline.pickle",
            "--out-path",
            "profile_baseline.html",
            "--profile",
            "categories",
            "--preserve-allocation-order",
            "--plotter_sampling_rate",
            "20",
        ],
        capture_output=True,
        text=True,
    )
    print(result1.stdout)
    if result1.stderr:
        print(result1.stderr)
    print("\nGenerating activation checkpointing profile...")
    result2 = subprocess.run(
        [
            "mosaic_get_memory_profile",
            "--snapshot",
            "snapshot_with_ac.pickle",
            "--out-path",
            "profile_with_ac.html",
            "--profile",
            "categories",
            "--preserve-allocation-order",
            "--plotter_sampling_rate",
            "20",
        ],
        capture_output=True,
        text=True,
    )
    print(result2.stdout)
    if result2.stderr:
        print(result2.stderr)
    # Only advertise the output files if both CLI invocations succeeded.
    if result1.returncode == 0 and result2.returncode == 0:
        print("\nGenerated profile_baseline.html")
        print("Generated profile_with_ac.html")
        print("\nDownload these files to view the interactive memory profiles.")
    else:
        print("\nNote: Mosaic profile generation encountered issues.")
        print("This may happen if running in an environment without full Mosaic support.")
######################################################################
# Download Generated Files (Google Colab)
# ----------------------------------------
#
# If running in Google Colab, uncomment the following lines to download
# the generated snapshot and profile files:
# from google.colab import files
#
# print("Downloading memory snapshots and profiles...")
# files.download('snapshot_baseline.pickle')
# files.download('snapshot_with_ac.pickle')
# files.download('profile_baseline.html')
# files.download('profile_with_ac.html')
######################################################################
# Results Interpretation: Activation Checkpointing
# -------------------------------------------------
#
# The generated HTML profiles visualize memory usage over time, with
# allocations colored by category. Here's what the profiles look like:
#
# .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-gpt2-without-ac.png
# :alt: GPT-2 memory profile without activation checkpointing
# :align: center
# :width: 600px
#
# **Baseline (without activation checkpointing):** Notice the large
# activation memory (shown in one color) that persists throughout
# the forward pass.
#
# .. figure:: /_static/img/mosaic/mosaic-categorical-memory-profiling-gpt2-with-ac.png
# :alt: GPT-2 memory profile with activation checkpointing
# :align: center
# :width: 600px
#
# **With activation checkpointing:** Activation memory is significantly
# reduced as intermediate activations are discarded and recomputed
# during the backward pass.
#
# What We Observed
# ~~~~~~~~~~~~~~~~
#
# Based on the Mosaic categorical profiling results:
#
# .. list-table:: Memory Comparison Results
# :header-rows: 1
#
# * - Metric
# - Baseline
# - With Activation Checkpointing
# - Difference
# * - **Total Peak Memory**
# - **4.62 GB**
# - **2.55 GB**
# - **2.07 GB (45% reduction)**
# * - Activation Memory
# - 2.93 GB
# - 872.79 MB
# - **2.08 GB saved (71% reduction)**
# * - Backward/Gradient Memory
# - 793.39 MB
# - 785.27 MB
# - 8 MB (minimal change)
# * - Optimizer State
# - 949.4 MB
# - 949.4 MB
# - No change
# * - Unknown
# - 32 KB
# - 32 KB
# - No change
#
# Key Insights
# ~~~~~~~~~~~~
#
# **Primary Finding:** Activation memory dropped from **2.93 GB → 872 MB**
# (71% reduction), which accounts for nearly all the total memory savings.
#
# Why Does This Happen?
# ~~~~~~~~~~~~~~~~~~~~~
#
# **Activation checkpointing** is a memory optimization technique that:
#
# 1. **Without AC (Baseline):** All intermediate activations from the forward
# pass are stored in memory for use during backpropagation. GPT-2 has 12
# transformer layers, each storing multiple activations (attention outputs,
# MLP outputs, etc.). For batch_size=4, seq_length=512, this adds up quickly.
#
# 2. **With AC (Optimized):** Only activations at checkpoint boundaries are
# stored; intermediate activations are recomputed during the backward pass.
# This dramatically reduces activation memory (71% in our case) while other
# memory categories remain unchanged.
#
# How Mosaic Helped
# ~~~~~~~~~~~~~~~~~
#
# Mosaic's categorical profiling immediately identified:
#
# - Activation memory is the category with the largest difference (2.08 GB saved)
# - Backward/Gradient memory stayed nearly constant (793 MB → 785 MB)
# - Optimizer state remained unchanged (949 MB) - expected since model
# parameters don't change
#
# **Without Mosaic:** You would need to manually instrument your code, track
# allocations, and categorize them yourself.
#
# **With Mosaic:** You get instant categorical breakdowns with exact numbers,
# making it trivial to identify/quantify memory optimizations.
#
######################################################################
# Case 2: Debugging Unexpected Memory Usage
# ==========================================
#
# This section demonstrates how to use Mosaic to debug when your model is
# using more memory than expected and you're not sure why.
#
# **What we'll do:**
#
# 1. Train GPT-2 and capture a memory snapshot.
# 2. Train GPT-2 with a bug that introduces additional memory and capture
# a memory snapshot.
# 3. Use Mosaic to identify potential culprits introducing additional memory.
#
######################################################################
# The Buggy Model
# ---------------
#
# This model has **abandoned debug code** that creates unnecessary GPU memory
# overhead. Someone added projection layers to "analyze hidden states" during
# debugging, but forgot to remove them before training.
class GPT2WithDebugOverhead(torch.nn.Module):
    """GPT2 wrapper with abandoned 'feature analysis' code that bloats peak memory.

    This wrapper adds extra projection layers that consume memory but serve no
    purpose - simulating abandoned debug code that was never cleaned up.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        config = base_model.config
        # BUG: Large projection layers from an abandoned experiment
        # (one n_embd -> 4*n_embd linear per transformer layer).
        self.debug_projections = torch.nn.ModuleList(
            [
                torch.nn.Linear(config.n_embd, config.n_embd * 4)
                for _ in range(config.n_layer)
            ]
        )
        debug_params = sum(p.numel() for p in self.debug_projections.parameters())
        print(f" [DEBUG] Added {config.n_layer} debug projection layers")
        print(f" [DEBUG] Extra parameters: {debug_params:,}")

    def forward(self, input_ids=None, labels=None, **kwargs):
        """Forward pass that adds the wasteful debug projections to the loss.

        Returns a ``CausalLMOutputWithCrossAttentions`` with the (slightly
        perturbed) loss and the base model's logits.
        """
        # Run normal GPT-2 forward with hidden states
        outputs = self.base_model(
            input_ids=input_ids,
            labels=labels,
            output_hidden_states=True,
            **kwargs,
        )
        # BUG: Project all hidden states through debug layers.
        # (Original code used enumerate() with an unused index and a manual
        # append loop; a comprehension over zip() is equivalent and cleaner.)
        projected = [
            proj(hidden)
            for hidden, proj in zip(
                outputs.hidden_states[1:], self.debug_projections
            )
        ]
        # Tie to loss so gradients flow through (tiny weight keeps the loss
        # numerically unchanged while forcing the projections into the graph).
        debug_regularization = sum(p.mean() for p in projected) * 1e-10
        return CausalLMOutputWithCrossAttentions(
            loss=outputs.loss + debug_regularization,
            logits=outputs.logits,
        )
######################################################################
# Training Functions for Debug Comparison
# ----------------------------------------
def run_training_clean(snapshot_path, num_steps=3):
    """Training with the normal model.

    Args:
        snapshot_path: Path to save the memory snapshot.
        num_steps: Number of training steps to run.

    Returns:
        Peak GPU memory usage in GB.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    device = torch.device("cuda")
    print("Loading clean model (no debug overhead)...")
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    model.train()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # Fixed seed + shuffle=False so the clean and buggy runs see identical data.
    dataset = RandomTokenDataset(
        vocab_size=tokenizer.vocab_size, seq_length=512, seed=42
    )
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    print("Running training (should contain no debug overhead)...")
    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f" Step {step + 1}, Loss: {loss.item():.4f}")
    peak_memory = torch.cuda.max_memory_allocated() / 1024**3
    print(f"Peak GPU memory: {peak_memory:.2f} GB")
    # Free the model/optimizer so the next experiment starts clean
    del model, optimizer
    torch.cuda.empty_cache()
    return peak_memory
def run_training_with_bug(snapshot_path, num_steps=3):
    """Training with the buggy model.

    Args:
        snapshot_path: Path to save the memory snapshot.
        num_steps: Number of training steps to run.

    Returns:
        Peak GPU memory usage in GB.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    device = torch.device("cuda")
    print("Loading buggy model with debug overhead...")
    # Load pretrained GPT-2 and wrap it with the debug overhead
    base_model = GPT2LMHeadModel.from_pretrained("gpt2")
    model = GPT2WithDebugOverhead(base_model).to(device)
    model.train()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # Same seed/settings as run_training_clean so the only difference is the bug.
    dataset = RandomTokenDataset(
        vocab_size=tokenizer.vocab_size, seq_length=512, seed=42
    )
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    print("Running training (WITH debug overhead bug)...")
    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f" Step {step + 1}, Loss: {loss.item():.4f}")
    peak_memory = torch.cuda.max_memory_allocated() / 1024**3
    print(f"Peak GPU memory: {peak_memory:.2f} GB")
    # Free the model/optimizer so the next experiment starts clean
    del model, optimizer
    torch.cuda.empty_cache()
    return peak_memory
######################################################################
# Run Training for Baseline (Clean Model)
# ----------------------------------------
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("Training with baseline model")
    print("=" * 60)
    baseline_memory_debug = run_training_clean(
        "snapshot_debug_baseline.pickle", num_steps=3
    )

######################################################################
# Run Training WITH the Bug
# --------------------------
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("Training with debug projection overhead (BUG)")
    print("=" * 60)
    buggy_memory = run_training_with_bug("snapshot_with_bug.pickle", num_steps=3)
######################################################################
# Use Mosaic to Find the Problem
# -------------------------------
#
# Analyze both snapshots to identify the source of extra memory usage.
# We'll run Mosaic's peak memory analysis on each snapshot separately.
######################################################################
# Analyze the Baseline (Clean) Snapshot
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if HAS_CUDA and HAS_MOSAIC_CLI:
    print("=" * 60)
    print("MOSAIC: Analyzing the Baseline Snapshot")
    print("=" * 60)
    # Peak analysis prints the stack traces of allocations at peak memory.
    result = subprocess.run(
        ["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_debug_baseline.pickle"],
        capture_output=True,
        text=True,
    )
    print(result.stdout)
    if result.stderr:
        print(result.stderr)

######################################################################
# Analyze the Buggy Snapshot
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
if HAS_CUDA and HAS_MOSAIC_CLI:
    print("=" * 60)
    print("MOSAIC: Analyzing the Buggy Snapshot")
    print("=" * 60)
    result = subprocess.run(
        ["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_with_bug.pickle"],
        capture_output=True,
        text=True,
    )
    print(result.stdout)
    if result.stderr:
        print(result.stderr)
######################################################################
# Analyzing The Mosaic Output
# ----------------------------
#
# When you run Mosaic's peak memory analysis, it shows stack traces for each
# memory allocation. Let's look at how to find abandoned or unnecessary code
# that's bloating the memory.
#
# **1. Optimizer State Allocations Delta**
#
# In the buggy snapshot output, we can see that the first two stack traces
# represent the **optimizer state allocations** (like ``zeros_like`` for Adam
# optimizer state). See ``torch/optim/adam.py`` in the stack trace.
#
# In the snapshot of the buggy model we can see around a total of 0.21 GB
# more memory:
#
# .. list-table:: Optimizer State Comparison
# :header-rows: 1
#
# * - Version
# - Stack Trace Position
# - Calls
# - Memory (per trace)
# * - Buggy model
# - 1st and 2nd
# - 172 calls
# - 0.569 GB + 0.569 GB
# * - Baseline
# - 2nd and 3rd
# - 148 calls
# - 0.464 GB + 0.464 GB
#
# What this tells us: The optimizer is tracking more tensors! This is your
# first clue that there are extra parameters or tensors in the computation graph.
#
# **2. Additional Activation Allocations**
#
# The buggy version shows **extra allocations** that don't appear in the
# baseline model. Scrolling down the Mosaic output of the buggy model we can
# see additional stack traces which contain:
#
# 1. ``torch::autograd::Engine::evaluate_function``: We're in the backward pass
# 2. ``AddmmBackward0::apply``: Computing gradients for an addmm operation
# 3. ``empty_cuda`` at the bottom: Allocating a new CUDA tensor to store
# the gradient
#
# - 0.176 GB from matrix multiply gradients (``AddmmBackward0``, ``mm_mat1_backward``)
#
# Memory Total Explanation
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# **Total Peak Dynamic Memory Usage:** This is the peak memory that changes
# during execution, measured relative to the starting point of the snapshot.
# It tracks memory allocations that occur during the traced execution timeline.
#
# **Total Static Memory Usage:** This is the "starting memory" or baseline
# memory that exists before tracing begins. It's estimated by the PyTorch
# visualizer and remains constant throughout the snapshot (doesn't come with
# stack traces).
#
# .. note::
#
# In the snapshots you may observe differences in total *static* memory
# usage, which accounts for the remaining difference.
#
# **Total Overall Peak Memory Usage:** Dynamic + Static
#
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("COMPARISON")
    print("=" * 60)
    # Side-by-side peak-memory numbers from the two Case-2 runs above.
    print(f"Baseline (clean model): {baseline_memory_debug:.2f} GB")
    print(f"With bug (debug projections): {buggy_memory:.2f} GB")
    print(
        f"Extra memory from bug: {buggy_memory - baseline_memory_debug:.2f} GB"
    )
######################################################################
# Case 3: Integrating Memory Analysis into Your Training Pipeline
# ================================================================
#
# This section demonstrates how to use Mosaic to automatically capture memory
# snapshots during training, get structured memory breakdown data for
# monitoring/dashboards, and build automated memory monitoring for large-scale
# training using Mosaic **programmatically** (as a Python dependency).
#
# Mosaic integrates memory analysis directly into your training pipeline.
#
######################################################################
# Training with Automatic Memory Capture
# ---------------------------------------
def run_training_with_memory_capture(
    batch_size=4,
    seq_length=512,
    num_steps=5,
    snapshot_path="training_snapshot.pickle",
):
    """Run training and automatically capture memory snapshot.

    Args:
        batch_size: Training batch size.
        seq_length: Sequence length for input tokens.
        num_steps: Number of training steps to run.
        snapshot_path: Path to save the memory snapshot.

    Returns:
        Path of the saved snapshot file (for downstream Mosaic analysis).
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    device = torch.device("cuda")
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    model.train()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    dataset = RandomTokenDataset(tokenizer.vocab_size, seq_length)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    print(f"Running {num_steps} training steps with memory capture...")
    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            outputs.loss.backward()
            optimizer.step()
            print(f" Step {step + 1}/{num_steps}, Loss: {outputs.loss.item():.4f}")
    peak_memory_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"✓ PyTorch reported peak memory: {peak_memory_gb:.3f} GB")
    # Free the model/optimizer before the snapshot is analyzed downstream
    del model, optimizer
    torch.cuda.empty_cache()
    return snapshot_path

if HAS_CUDA:
    print("\n" + "=" * 60)
    print("CASE 3: Pipeline Integration")
    print("=" * 60)
    pipeline_snapshot_path = run_training_with_memory_capture(batch_size=4, seq_length=512)
######################################################################
# Mosaic Memory Analysis via Python API
# --------------------------------------
#
# Instead of using CLI commands, we can use Mosaic's Python API directly
# for programmatic integration.
if HAS_CUDA:
print("\n" + "=" * 60)
print("MOSAIC MEMORY ANALYSIS (via Python API)")
print("=" * 60)
# Load and analyze the memory snapshot
memory_abstract = MemoryAbstract(memory_snapshot_file=pipeline_snapshot_path)
memory_abstract.load_memory_snapshot()
# Analyze peak memory usage