Cell: combine | 36.89s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels-benchmark-tools",
#     "matplotlib",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import os
import sys
from pathlib import Path
import json
import torch # noqa: F401 # imported because upstream may expect torch to be importable
import kernels_benchmark_tools as kbt
# --- Matplotlib setup and helpers ------------------------------------------------
import matplotlib as mpl
import matplotlib.pyplot as plt
import csv
# Keep text as text (not paths) so CSS can style fonts, size, etc.
mpl.rcParams["svg.fonttype"] = "none"
# Make ids deterministic across builds
mpl.rcParams["svg.hashsalt"] = "latency-benchmark-combined"
# Avoid auto-closed figures interfering with our tagging
mpl.rcParams["figure.autolayout"] = True
# Make background transparent
mpl.rcParams["figure.facecolor"] = "none"
mpl.rcParams["axes.facecolor"] = "none"
mpl.rcParams["savefig.facecolor"] = "none"
mpl.rcParams["savefig.edgecolor"] = "none"
def _slugify(s: str) -> str:
    s = (s or "").strip().lower()
    keep = []
    for ch in s:
        if ch.isalnum():
            keep.append(ch)
        elif ch in (" ", "-", "_", "/", ".", ":"):
            keep.append("-")
        else:
            keep.append("")
    out = "".join(keep)
    while "--" in out:
        out = out.replace("--", "-")
    return out.strip("-") or "unnamed"
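# e.g. _slugify("Flash (PyTorch SDPA)") -> "flash-pytorch-sdpa"; empty or symbol-only input -> "unnamed"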
def _tag_current_figure(default_series_prefix="series"):
    """Attach SVG ids (gid) to key artists so they can be targeted from CSS."""
    fig = plt.gcf()
    if fig is None:
        return
    # Tag the figure itself
    fig.set_gid("figure--latency")
    for ax_idx, ax in enumerate(fig.get_axes(), start=1):
        ax.set_gid(f"axes--{ax_idx}")
        # Axis title & labels (the title Text lives at ax.title, not in ax.texts)
        if ax.get_title():
            ax.title.set_gid("title--main")
        if ax.xaxis and ax.xaxis.get_label():
            ax.xaxis.label.set_gid("label--x")
        if ax.yaxis and ax.yaxis.get_label():
            ax.yaxis.label.set_gid("label--y")
        # Gridlines
        for i, gl in enumerate(ax.get_xgridlines(), start=1):
            gl.set_gid(f"grid-x--{i}")
        for i, gl in enumerate(ax.get_ygridlines(), start=1):
            gl.set_gid(f"grid-y--{i}")
        # Legend block & entries
        leg = ax.get_legend()
        if leg is not None:
            leg.set_gid("legend")
            for i, txt in enumerate(leg.get_texts(), start=1):
                label_slug = _slugify(txt.get_text())
                txt.set_gid(f"legend-label--{label_slug or i}")
        # Series: lines
        line_seen = {}
        for ln in getattr(ax, "lines", []):
            raw_label = ln.get_label() or ""
            # Matplotlib uses labels beginning with "_" for non-legendable items
            label = raw_label if not raw_label.startswith("_") else default_series_prefix
            slug = _slugify(label)
            line_seen[slug] = line_seen.get(slug, 0) + 1
            suffix = "" if line_seen[slug] == 1 else f"-{line_seen[slug]}"
            ln.set_gid(f"series--{slug}{suffix}")
        # Series: patches (bars, areas)
        patch_seen = {}
        for pt in getattr(ax, "patches", []):
            label = getattr(pt, "get_label", lambda: "")() or default_series_prefix
            if isinstance(label, str) and label.startswith("_"):
                label = default_series_prefix
            slug = _slugify(label)
            patch_seen[slug] = patch_seen.get(slug, 0) + 1
            suffix = "" if patch_seen[slug] == 1 else f"-{patch_seen[slug]}"
            pt.set_gid(f"series--{slug}{suffix}")
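# Resulting id scheme: "figure--latency", "axes--<n>", "title--main", "label--x"/"label--y",
# "grid-x--<i>"/"grid-y--<i>", "legend", "legend-label--<slug>", "series--<slug>[-<k>]".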
def _postprocess_svg_add_classes(svg_path: Path):
    """Add convenient CSS classes alongside ids (e.g., class='series grid grid-x')."""
    try:
        import xml.etree.ElementTree as ET
        ET.register_namespace("", "http://www.w3.org/2000/svg")
        tree = ET.parse(svg_path)
        root = tree.getroot()
        for el in root.iter():
            el_id = el.attrib.get("id", "")
            if not el_id:
                continue
            cls = []
            if el_id.startswith("figure--"):
                cls.append("figure")
            elif el_id.startswith("axes--"):
                cls.append("axes")
            elif el_id.startswith("grid-x--"):
                cls += ["grid", "grid-x"]
            elif el_id.startswith("grid-y--"):
                cls += ["grid", "grid-y"]
            elif el_id.startswith("legend"):
                cls.append("legend")
            elif el_id.startswith("label--x"):
                cls.append("xlabel")
            elif el_id.startswith("label--y"):
                cls.append("ylabel")
            elif el_id.startswith("title--"):
                cls.append("title")
            elif el_id.startswith("series--"):
                cls.append("series")
            if cls:
                # Preserve any existing class (unlikely from Matplotlib)
                existing = el.attrib.get("class", "")
                el.set("class", (existing + " " + " ".join(cls)).strip())
        tree.write(svg_path, encoding="utf-8", xml_declaration=True)
    except Exception as e:
        print(f"✗ SVG postprocess (classes) skipped: {e}")
# Monkey-patch savefig to force SVG & ensure tagging occurs even if kbt.viz saves internally.
_orig_savefig = plt.savefig
def _savefig_svg(fname, *args, **kwargs):
    # Always save as SVG at a stable path for the artifact system
    out = Path("latency.svg")
    kwargs["format"] = "svg"
    # Ensure everything we care about has ids before export
    _tag_current_figure()
    res = _orig_savefig(out, *args, **kwargs)
    # Add helpful CSS classes on top of ids
    _postprocess_svg_add_classes(out)
    print(f"✓ Combined visualization saved as {out}")
    return res
plt.savefig = _savefig_svg  # apply patch
# Capture close calls in case kbt.viz() closes figures before we re-save
_orig_close = plt.close
_last_closed = {"fig": None}
def _capture_close(arg=None):
    try:
        if hasattr(arg, "savefig"):  # looks like a Figure
            _last_closed["fig"] = arg
        else:
            _last_closed["fig"] = plt.gcf()
    finally:
        return _orig_close(arg)
plt.close = _capture_close
# --- Locate benchmark artifacts --------------------------------------------------
cache_dirs = {
    "Flash (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK'),
    "MemEff (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_MEM_EFFICIENT_ATTENTION_BENCHMARK'),
    "Flash Attn 2": os.environ.get('UVNOTE_FILE_FLASH_ATTN2_BENCHMARK'),
    "xFormers": os.environ.get('UVNOTE_FILE_XFORMERS_BENCHMARK'),
    "SageAttention": os.environ.get('UVNOTE_FILE_SAGE_ATTENTION_BENCHMARK'),
    "Compiled (default)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_DEFAULT'),
    "Compiled (max-autotune)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_MAX_AUTOTUNE'),
    "HF Kernels Flash Attn": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN_BENCHMARK'),
    "HF Kernels Flash Attn3": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN3_BENCHMARK'),
}
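# Each UVNOTE_FILE_* variable is expected to hold the producing cell's cache directory
# (resolved paths are printed below); a None value means no cached result is available.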
print("LOADING BENCHMARK DATA")
for name, cache_dir in cache_dirs.items():
print(f"{name:30s}: {cache_dir}")
print()
file_mapping = {
"Flash (PyTorch SDPA)": "attn.jsonl",
"MemEff (PyTorch SDPA)": "attn.jsonl",
"Flash Attn 2": "attn.jsonl",
"xFormers": "attn.jsonl",
"SageAttention": "attn.jsonl",
"Compiled (default)": "attn_default.jsonl",
"Compiled (max-autotune)": "attn_max_autotune.jsonl",
"HF Kernels Flash Attn": "attn.jsonl",
"HF Kernels Flash Attn3": "attn.jsonl",
}
all_paths = []
for name, cache_dir in cache_dirs.items():
if cache_dir:
path = Path(cache_dir) / file_mapping[name]
if path.exists() and path.stat().st_size > 0:
all_paths.append(str(path))
print(f"✓ Found {name}: {path}")
else:
print(f"⊘ Empty/Missing {name}: {path}")
else:
print(f"✗ No cache dir for {name}")
print()
if not all_paths:
print("ERROR: No benchmark data files found!")
# restore patched functions before exiting
plt.savefig = _orig_savefig
plt.close = _orig_close
sys.exit(1)
# --- Summary + Visualization -----------------------------------------------------
print("COMBINED BENCHMARK SUMMARY\n")
kbt.summarize(all_paths)
print("\nGENERATING COMBINED VISUALIZATION\n")
try:
    # If kbt.viz saves internally, our patched savefig ensures SVG gets written,
    # and it will carry ids/classes for CSS styling.
    kbt.viz(all_paths)
    # Safety net: if kbt.viz didn't save, save now.
    # if not Path("latency.svg").exists():
    #     _tag_current_figure()
    #     plt.savefig("latency.svg")
    plt.savefig("latency.svg")  # ensure saved with tagging
    print("✓ SVG visualization ready: latency.svg!")
except ImportError as e:
    print(f"✗ Visualization requires matplotlib: {e}")
except Exception as e:
    print(f"✗ Visualization failed: {e}")
finally:
    # Clean up patches to avoid side effects in later cells
    plt.savefig = _orig_savefig
    plt.close = _orig_close
print()
print("ANALYSIS COMPLETE")
print(f"Total implementations analyzed: {len(all_paths)}")
print(f"\nImplementations included:")
for name, cache_dir in cache_dirs.items():
    if cache_dir:
        path = Path(cache_dir) / file_mapping[name]
        if path.exists() and path.stat().st_size > 0:
            print(f" ✓ {name}")
# Collect all benchmark data and export to CSV
all_data = {}
for name, cache_dir in cache_dirs.items():
    if cache_dir:
        path = Path(cache_dir) / file_mapping[name]
        if path.exists() and path.stat().st_size > 0:
            with open(path, 'r') as f:
                records = [json.loads(line) for line in f]
            all_data[name] = records
# Export to CSV
csv_path = Path("latency.csv")
with open(csv_path, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    # Write header
    header = ["Implementation", "Impl ID", "Workload", "Batch", "Seq Length", "Heads", "Head Dim", "Dtype",
              "Mean (ms)", "P10 (ms)", "P50 (ms)", "P90 (ms)", "Reps",
              # "Compile (ms)",
              "Peak Mem (MB)", "Backend", "Family"]
    writer.writerow(header)
    # Write data rows
    for impl_name, records in all_data.items():
        for record in records:
            wl = record.get('wl', {})
            lat = record.get('lat_ms', {})
            tags = record.get('tags', {})
            row = [
                impl_name,
                record.get('impl', ''),
                wl.get('name', ''),
                wl.get('batch', ''),
                wl.get('seq_len', ''),
                wl.get('heads', ''),
                wl.get('head_dim', ''),
                wl.get('dtype', ''),
                lat.get('mean', ''),
                lat.get('p10', ''),
                lat.get('p50', ''),
                lat.get('p90', ''),
                lat.get('reps', ''),
                # record.get('compile_ms', ''),
                round(record.get('peak_bytes', 0) / 1024 / 1024, 2) if record.get('peak_bytes') else '',
                tags.get('backend', ''),
                tags.get('family', ''),
            ]
            writer.writerow(row)
print(f"✓ CSV export complete: {csv_path}")
print(f"Total implementations: {len(all_data)}")
print(f"Total records: {sum(len(records) for records in all_data.values())}")
LOADING BENCHMARK DATA
Flash (PyTorch SDPA) : /repo/flash_attn/impls/.uvnote/cache/327a3408e7cdfeef6984786686ce13137074d9f083e6e434c29f02589d28a0f8
MemEff (PyTorch SDPA) : /repo/flash_attn/impls/.uvnote/cache/25ca9e52daa50b9289780b3e1302f2949db718140ef9eedd44a8a554afaff9ee
Flash Attn 2 : None
xFormers : /repo/flash_attn/impls/.uvnote/cache/6802a31176fbf22c1f5dd5442cf5ae77d8e3527d679642244908984c16933902
SageAttention : None
Compiled (default) : /repo/flash_attn/impls/.uvnote/cache/bd779935ea10d468a5a99c29b029da0e0ef4dc2a7b82bc8595d04b2f142a3a44
Compiled (max-autotune) : /repo/flash_attn/impls/.uvnote/cache/f4bc4785407df53e53f91c190279cdf3dbe3cf7028e2e352d1cc90b92bfcf86e
HF Kernels Flash Attn : /repo/flash_attn/impls/.uvnote/cache/e4c157a6bfc7f8394530835e3b63d4e2032ebfed3f19a0693978eb24ba415910
HF Kernels Flash Attn3 : /repo/flash_attn/impls/.uvnote/cache/65da999faf55d11c76155fa1d198e77708e1fe8247e3d0b5fd7093a206551ce5
✓ Found Flash (PyTorch SDPA): /repo/flash_attn/impls/.uvnote/cache/327a3408e7cdfeef6984786686ce13137074d9f083e6e434c29f02589d28a0f8/attn.jsonl
✓ Found MemEff (PyTorch SDPA): /repo/flash_attn/impls/.uvnote/cache/25ca9e52daa50b9289780b3e1302f2949db718140ef9eedd44a8a554afaff9ee/attn.jsonl
✗ No cache dir for Flash Attn 2
✓ Found xFormers: /repo/flash_attn/impls/.uvnote/cache/6802a31176fbf22c1f5dd5442cf5ae77d8e3527d679642244908984c16933902/attn.jsonl
✗ No cache dir for SageAttention
✓ Found Compiled (default): /repo/flash_attn/impls/.uvnote/cache/bd779935ea10d468a5a99c29b029da0e0ef4dc2a7b82bc8595d04b2f142a3a44/attn_default.jsonl
✓ Found Compiled (max-autotune): /repo/flash_attn/impls/.uvnote/cache/f4bc4785407df53e53f91c190279cdf3dbe3cf7028e2e352d1cc90b92bfcf86e/attn_max_autotune.jsonl
✓ Found HF Kernels Flash Attn: /repo/flash_attn/impls/.uvnote/cache/e4c157a6bfc7f8394530835e3b63d4e2032ebfed3f19a0693978eb24ba415910/attn.jsonl
✓ Found HF Kernels Flash Attn3: /repo/flash_attn/impls/.uvnote/cache/65da999faf55d11c76155fa1d198e77708e1fe8247e3d0b5fd7093a206551ce5/attn.jsonl
COMBINED BENCHMARK SUMMARY
impl wl p50(ms) ok
hf_kernels_flash_attn flux_L128 0.35 True
hf_kernels_flash_attn flux_L256 0.38 True
hf_kernels_flash_attn flux_L320 0.49 True
hf_kernels_flash_attn flux_L384 0.52 True
hf_kernels_flash_attn flux_L448 0.54 True
hf_kernels_flash_attn flux_L512 0.56 True
hf_kernels_flash_attn3 flux_L128 0.36 True
hf_kernels_flash_attn3 flux_L256 0.39 True
hf_kernels_flash_attn3 flux_L320 0.52 True
hf_kernels_flash_attn3 flux_L384 0.53 True
hf_kernels_flash_attn3 flux_L448 0.57 True
hf_kernels_flash_attn3 flux_L512 0.57 True
torch_flash_compiled_default flux_L128 0.52 True
torch_flash_compiled_default flux_L256 0.56 True
torch_flash_compiled_default flux_L320 0.69 True
torch_flash_compiled_default flux_L384 0.72 True
torch_flash_compiled_default flux_L448 0.74 True
torch_flash_compiled_default flux_L512 0.77 True
torch_flash_compiled_max_autotune flux_L128 0.62 True
torch_flash_compiled_max_autotune flux_L256 0.69 True
torch_flash_compiled_max_autotune flux_L320 0.82 True
torch_flash_compiled_max_autotune flux_L384 0.85 True
torch_flash_compiled_max_autotune flux_L448 0.88 True
torch_flash_compiled_max_autotune flux_L512 0.92 True
torch_flash_ma flux_L128 0.49 True
torch_flash_ma flux_L256 0.52 True
torch_flash_ma flux_L320 0.65 True
torch_flash_ma flux_L384 0.68 True
torch_flash_ma flux_L448 0.71 True
torch_flash_ma flux_L512 0.74 True
torch_mem_eff flux_L128 0.59 True
torch_mem_eff flux_L256 0.65 True
torch_mem_eff flux_L320 0.78 True
torch_mem_eff flux_L384 0.79 True
torch_mem_eff flux_L448 0.85 True
torch_mem_eff flux_L512 0.95 True
xformers_meff flux_L128 0.45 True
xformers_meff flux_L256 0.47 True
xformers_meff flux_L320 0.60 True
xformers_meff flux_L384 0.60 True
xformers_meff flux_L448 0.64 True
xformers_meff flux_L512 0.65 True
GENERATING COMBINED VISUALIZATION
Loaded 42 records
✓ Combined visualization saved as latency.svg
Saved latency.png
✓ Combined visualization saved as latency.svg
✓ SVG visualization ready: latency.svg!
ANALYSIS COMPLETE
Total implementations analyzed: 7
Implementations included:
✓ Flash (PyTorch SDPA)
✓ MemEff (PyTorch SDPA)
✓ xFormers
✓ Compiled (default)
✓ Compiled (max-autotune)
✓ HF Kernels Flash Attn
✓ HF Kernels Flash Attn3
✓ CSV export complete: latency.csv
Total implementations: 7
Total records: 42
Artifacts:
latency.csv latency.svg

Implementation | Impl ID | Workload | Batch | Seq Length | Heads | Head Dim | Dtype | Mean (ms) | P10 (ms) | P50 (ms) | P90 (ms) | Reps | Peak Mem (MB) | Backend | Family |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Flash (PyTorch SDPA) | torch_flash_ma | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.49411200881004336 | 0.48844799399375916 | 0.4936000108718872 | 0.4944640100002289 | 5 | 83.38 | FLASH | torch-sdpa |
Flash (PyTorch SDPA) | torch_flash_ma | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.5234112024307251 | 0.5224320292472839 | 0.5235199928283691 | 0.5235840082168579 | 5 | 90.62 | FLASH | torch-sdpa |
Flash (PyTorch SDPA) | torch_flash_ma | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.6527232170104981 | 0.6503040194511414 | 0.6524800062179565 | 0.6545600295066833 | 5 | 95.06 | FLASH | torch-sdpa |
Flash (PyTorch SDPA) | torch_flash_ma | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.682803213596344 | 0.6805760264396667 | 0.6828799843788147 | 0.6832640171051025 | 5 | 99.88 | FLASH | torch-sdpa |
Flash (PyTorch SDPA) | torch_flash_ma | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.7075456142425537 | 0.7057600021362305 | 0.7063360214233398 | 0.7070720195770264 | 5 | 103.81 | FLASH | torch-sdpa |
Flash (PyTorch SDPA) | torch_flash_ma | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.7379711985588073 | 0.7368639707565308 | 0.7372480034828186 | 0.7391039729118347 | 5 | 109.12 | FLASH | torch-sdpa |
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.5874239921569824 | 0.5861759781837463 | 0.5873280167579651 | 0.5877439975738525 | 5 | 83.38 | EFFICIENT | torch-sdpa |
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.6502719998359681 | 0.6490240097045898 | 0.649183988571167 | 0.6517760157585144 | 5 | 90.62 | EFFICIENT | torch-sdpa |
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.7812095880508423 | 0.7761600017547607 | 0.7803199887275696 | 0.7852799892425537 | 5 | 95.94 | EFFICIENT | torch-sdpa |
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.7948480010032654 | 0.7911999821662903 | 0.7935360074043274 | 0.7948480248451233 | 5 | 100.0 | EFFICIENT | torch-sdpa |
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.8463295936584473 | 0.8449919819831848 | 0.8459839820861816 | 0.8461120128631592 | 5 | 103.81 | EFFICIENT | torch-sdpa |
MemEff (PyTorch SDPA) | torch_mem_eff | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.9538687944412232 | 0.9492800235748291 | 0.9518399834632874 | 0.9581760168075562 | 5 | 109.12 | EFFICIENT | torch-sdpa |
xFormers | xformers_meff | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.4515071928501129 | 0.44364801049232483 | 0.4524799883365631 | 0.4557119905948639 | 5 | 83.38 | memory_efficient | xformers |
xFormers | xformers_meff | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.46787199974060056 | 0.46489599347114563 | 0.4684160053730011 | 0.46908798813819885 | 5 | 90.62 | memory_efficient | xformers |
xFormers | xformers_meff | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.6001471996307373 | 0.596992015838623 | 0.5984640121459961 | 0.6016640067100525 | 5 | 95.06 | memory_efficient | xformers |
xFormers | xformers_meff | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.6023231983184815 | 0.5997440218925476 | 0.6031039953231812 | 0.6032639741897583 | 5 | 99.88 | memory_efficient | xformers |
xFormers | xformers_meff | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.6411136031150818 | 0.6381760239601135 | 0.6414719820022583 | 0.6421440243721008 | 5 | 103.81 | memory_efficient | xformers |
xFormers | xformers_meff | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.6594688057899475 | 0.6441280245780945 | 0.6496639847755432 | 0.6527680158615112 | 5 | 109.12 | memory_efficient | xformers |
Compiled (default) | torch_flash_compiled_default | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.5181439876556396 | 0.5141760110855103 | 0.5175679922103882 | 0.5197759866714478 | 5 | 83.38 | FLASH | torch-sdpa |
Compiled (default) | torch_flash_compiled_default | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.5579584002494812 | 0.5549119710922241 | 0.5582720041275024 | 0.5598080158233643 | 5 | 90.62 | FLASH | torch-sdpa |
Compiled (default) | torch_flash_compiled_default | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.6872959971427918 | 0.6853119730949402 | 0.687391996383667 | 0.6883519887924194 | 5 | 95.25 | FLASH | torch-sdpa |
Compiled (default) | torch_flash_compiled_default | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.716153597831726 | 0.7128639817237854 | 0.7160959839820862 | 0.7167680263519287 | 5 | 99.88 | FLASH | torch-sdpa |
Compiled (default) | torch_flash_compiled_default | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.7418303966522217 | 0.7386879920959473 | 0.7400959730148315 | 0.7415040135383606 | 5 | 103.81 | FLASH | torch-sdpa |
Compiled (default) | torch_flash_compiled_default | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.7745471954345703 | 0.7708160281181335 | 0.7740799784660339 | 0.7753919959068298 | 5 | 109.12 | FLASH | torch-sdpa |
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.6468096017837525 | 0.6144000291824341 | 0.6245759725570679 | 0.6483200192451477 | 5 | 67.5 | FLASH | torch-sdpa |
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.7060160160064697 | 0.6689280271530151 | 0.6851199865341187 | 0.7184960246086121 | 5 | 75.0 | FLASH | torch-sdpa |
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.8332608103752136 | 0.7953600287437439 | 0.8155840039253235 | 0.8403519988059998 | 5 | 80.38 | FLASH | torch-sdpa |
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.8719295978546142 | 0.8470720052719116 | 0.849727988243103 | 0.8745279908180237 | 5 | 82.5 | FLASH | torch-sdpa |
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.9034304022789001 | 0.8677120208740234 | 0.8835520148277283 | 0.9034240245819092 | 5 | 86.25 | FLASH | torch-sdpa |
Compiled (max-autotune) | torch_flash_compiled_max_autotune | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.9387519836425782 | 0.9154239892959595 | 0.9213759899139404 | 0.9359679818153381 | 5 | 90.0 | FLASH | torch-sdpa |
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.3455295979976654 | 0.34355199337005615 | 0.34563198685646057 | 0.34643200039863586 | 5 | 83.38 | flash-attn | hf-kernels |
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.3756160080432892 | 0.37411201000213623 | 0.3752000033855438 | 0.3770880103111267 | 5 | 90.62 | flash-attn | hf-kernels |
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.4953216016292572 | 0.49324798583984375 | 0.49433600902557373 | 0.49663999676704407 | 5 | 95.06 | flash-attn | hf-kernels |
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.5157055854797363 | 0.5142719745635986 | 0.516319990158081 | 0.516543984413147 | 5 | 99.88 | flash-attn | hf-kernels |
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.5356672048568726 | 0.5346879959106445 | 0.5358080267906189 | 0.5361599922180176 | 5 | 103.81 | flash-attn | hf-kernels |
HF Kernels Flash Attn | hf_kernels_flash_attn | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.5587136030197144 | 0.5557760000228882 | 0.5574079751968384 | 0.5581120252609253 | 5 | 109.12 | flash-attn | hf-kernels |
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L128 | 1 | 1152 | 24 | 128 | bfloat16 | 0.3619711995124817 | 0.3603839874267578 | 0.361952006816864 | 0.3624640107154846 | 5 | 83.38 | flash-attn3 | hf-kernels |
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L256 | 1 | 1280 | 24 | 128 | bfloat16 | 0.3912447988986969 | 0.3892799913883209 | 0.3909760117530823 | 0.3922559916973114 | 5 | 90.62 | flash-attn3 | hf-kernels |
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L320 | 1 | 1344 | 24 | 128 | bfloat16 | 0.5258048176765442 | 0.5240640044212341 | 0.5248960256576538 | 0.5248960256576538 | 5 | 95.06 | flash-attn3 | hf-kernels |
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L384 | 1 | 1408 | 24 | 128 | bfloat16 | 0.5276032090187073 | 0.5265600085258484 | 0.5277760028839111 | 0.5282559990882874 | 5 | 99.88 | flash-attn3 | hf-kernels |
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L448 | 1 | 1472 | 24 | 128 | bfloat16 | 0.5656383991241455 | 0.5639039874076843 | 0.5657920241355896 | 0.5668479800224304 | 5 | 103.81 | flash-attn3 | hf-kernels |
HF Kernels Flash Attn3 | hf_kernels_flash_attn3 | flux_L512 | 1 | 1536 | 24 | 128 | bfloat16 | 0.5789952039718628 | 0.5689600110054016 | 0.5698239803314209 | 0.5713919997215271 | 5 | 109.12 | flash-attn3 | hf-kernels |