Compare commits

..

1 Commits

Author SHA1 Message Date
w-e-w e936dbb43b allow add middleware after app has started
this should completely fix "Cannot add middleware after an application has started" which can occur due to a race condition
2024-12-28 14:53:59 +09:00
20 changed files with 73 additions and 169 deletions
-1
View File
@@ -88,7 +88,6 @@ module.exports = {
// imageviewer.js
modalPrevImage: "readonly",
modalNextImage: "readonly",
updateModalImageIfVisible: "readonly",
// localStorage.js
localSet: "readonly",
localGet: "readonly",
+1 -1
View File
@@ -133,7 +133,7 @@ If your system is very new, you need to install python3.11 or python3.10:
# Ubuntu 24.04
sudo add-apt-repository ppa:deadsnakes/ppa
sudo apt update
sudo apt install python3.11 python3.11-venv
sudo apt install python3.11
# Manjaro/Arch
sudo pacman -S yay
-2
View File
@@ -54,7 +54,6 @@ function updateOnBackgroundChange() {
updateModalImage();
}
}
const updateModalImageIfVisible = updateOnBackgroundChange;
function modalImageSwitch(offset) {
var galleryButtons = all_gallery_buttons();
@@ -165,7 +164,6 @@ function modalLivePreviewToggle(event) {
const modalToggleLivePreview = gradioApp().getElementById("modal_toggle_live_preview");
opts.js_live_preview_in_modal_lightbox = !opts.js_live_preview_in_modal_lightbox;
modalToggleLivePreview.innerHTML = opts.js_live_preview_in_modal_lightbox ? "🗇" : "🗆";
updateModalImageIfVisible();
event.stopPropagation();
}
+1 -1
View File
@@ -190,7 +190,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
livePreview.className = 'livePreview';
gallery.insertBefore(livePreview, gallery.firstElementChild);
}
updateModalImageIfVisible();
livePreview.appendChild(img);
if (livePreview.childElementCount > 2) {
livePreview.removeChild(livePreview.firstElementChild);
-5
View File
@@ -6,11 +6,6 @@ git = launch_utils.git
index_url = launch_utils.index_url
dir_repos = launch_utils.dir_repos
if args.uv:
from modules.uv_hook import patch
patch()
commit_hash = launch_utils.commit_hash
git_tag = launch_utils.git_tag
-1
View File
@@ -126,4 +126,3 @@ parser.add_argument("--skip-load-model-at-start", action='store_true', help="if
parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system")
parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system')
parser.add_argument("--no-prompt-history", action='store_true', help="disable read prompt from last generation feature; settings this argument will not create '--data_path/params.txt' file")
parser.add_argument("--uv", action='store_true', help="use the uv package manager")
+2 -30
View File
@@ -1,7 +1,7 @@
import hashlib
import os.path
from modules import shared, errors
from modules import shared
import modules.cache
dump_cache = modules.cache.dump_cache
@@ -32,7 +32,7 @@ def sha256_from_cache(filename, title, use_addnet_hash=False):
cached_sha256 = hashes[title].get("sha256", None)
cached_mtime = hashes[title].get("mtime", 0)
if ondisk_mtime != cached_mtime or cached_sha256 is None:
if ondisk_mtime > cached_mtime or cached_sha256 is None:
return None
return cached_sha256
@@ -82,31 +82,3 @@ def addnet_hash_safetensors(b):
return hash_sha256.hexdigest()
def partial_hash_from_cache(filename, *, ignore_cache: bool = False, digits: int = 8):
"""old hash that only looks at a small part of the file and is prone to collisions
kept for compatibility, don't use this for new things
"""
try:
filename = str(filename)
mtime = os.path.getmtime(filename)
hashes = cache('partial-hash')
cache_entry = hashes.get(filename, {})
cache_mtime = cache_entry.get("mtime", 0)
cache_hash = cache_entry.get("hash", None)
if mtime == cache_mtime and cache_hash and not ignore_cache:
return cache_hash[0:digits]
with open(filename, 'rb') as file:
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
partial_hash = m.hexdigest()
hashes[filename] = {'mtime': mtime, 'hash': partial_hash}
return partial_hash[0:digits]
except FileNotFoundError:
pass
except Exception:
errors.report(f'Error calculating partial hash for {filename}', exc_info=True)
return 'NOFILE'
-1
View File
@@ -409,7 +409,6 @@ class FilenameGenerator:
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'randn_source': lambda self: opts.data["randn_source"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(),
+1
View File
@@ -50,6 +50,7 @@ def check_versions():
def initialize():
from modules import initialize_util
initialize_util.allow_add_middleware_after_start()
initialize_util.fix_torch_version()
initialize_util.fix_pytorch_lightning()
initialize_util.fix_asyncio_event_loop_policy()
+37 -3
View File
@@ -5,6 +5,8 @@ import sys
import re
from modules.timer import startup_timer
from modules import patches
from functools import wraps
def gradio_server_name():
@@ -191,11 +193,8 @@ def configure_opts_onchange():
def setup_middleware(app):
from starlette.middleware.gzip import GZipMiddleware
app.middleware_stack = None # reset current middleware to allow modifying user provided list
app.add_middleware(GZipMiddleware, minimum_size=1000)
configure_cors_middleware(app)
app.build_middleware_stack() # rebuild middleware stack on-the-fly
def configure_cors_middleware(app):
@@ -213,3 +212,38 @@ def configure_cors_middleware(app):
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
app.add_middleware(CORSMiddleware, **cors_options)
def allow_add_middleware_after_start():
from starlette.applications import Starlette
def add_middleware_wrapper(func):
"""Patch Starlette.add_middleware to allow for middleware to be added after the app has started
Starlette.add_middleware raises RuntimeError("Cannot add middleware after an application has started") if middleware_stack is not None.
We can force add new middleware by first setting middleware_stack to None, then adding the middleware.
When middleware_stack is None, it will rebuild the middleware_stack on the next request (Lazily build middleware stack).
If packages are updated in the future, things may break, so we have two ways to add middleware after the app has started:
the first way is to just set middleware_stack to None and then retry
the second manually insert the middleware into the user_middleware list without calling add_middleware
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
res = None
try:
res = func(self, *args, **kwargs)
except RuntimeError as _:
try:
self.middleware_stack = None
res = func(self, *args, **kwargs)
except RuntimeError as e:
print(f'Warning: "{e}", Retrying...')
from starlette.middleware import Middleware
self.user_middleware.insert(0, Middleware(*args, **kwargs))
self.middleware_stack = None # ensure middleware_stack in the event of concurrent requests
return res
return wrapper
patches.patch(__name__, obj=Starlette, field="add_middleware", replacement=add_middleware_wrapper(Starlette.add_middleware))
+5 -55
View File
@@ -313,43 +313,9 @@ def requirements_met(requirements_file):
return True
def get_cuda_comp_cap():
"""
Returns float of CUDA Compute Capability using nvidia-smi
Returns 0.0 on error
CUDA Compute Capability
ref https://developer.nvidia.com/cuda-gpus
ref https://en.wikipedia.org/wiki/CUDA
Blackwell consumer GPUs should return 12.0 data-center GPUs should return 10.0
"""
try:
return max(map(float, subprocess.check_output(['nvidia-smi', '--query-gpu=compute_cap', '--format=noheader,csv'], text=True).splitlines()))
except Exception as _:
return 0.0
def early_access_blackwell_wheels():
"""For Blackwell GPUs, use Early Access PyTorch Wheels provided by Nvidia"""
print('deprecated early_access_blackwell_wheels')
if all([
os.environ.get('TORCH_INDEX_URL') is None,
sys.version_info.major == 3,
sys.version_info.minor in (10, 11, 12),
platform.system() == "Windows",
get_cuda_comp_cap() >= 10, # Blackwell
]):
base_repo = 'https://huggingface.co/w-e-w/torch-2.6.0-cu128.nv/resolve/main/'
ea_whl = {
10: f'{base_repo}torch-2.6.0+cu128.nv-cp310-cp310-win_amd64.whl#sha256=fef3de7ce8f4642e405576008f384304ad0e44f7b06cc1aa45e0ab4b6e70490d {base_repo}torchvision-0.20.0a0+cu128.nv-cp310-cp310-win_amd64.whl#sha256=50841254f59f1db750e7348b90a8f4cd6befec217ab53cbb03780490b225abef',
11: f'{base_repo}torch-2.6.0+cu128.nv-cp311-cp311-win_amd64.whl#sha256=6665c36e6a7e79e7a2cb42bec190d376be9ca2859732ed29dd5b7b5a612d0d26 {base_repo}torchvision-0.20.0a0+cu128.nv-cp311-cp311-win_amd64.whl#sha256=bbc0ee4938e35fe5a30de3613bfcd2d8ef4eae334cf8d49db860668f0bb47083',
12: f'{base_repo}torch-2.6.0+cu128.nv-cp312-cp312-win_amd64.whl#sha256=a3197f72379d34b08c4a4bcf49ea262544a484e8702b8c46cbcd66356c89def6 {base_repo}torchvision-0.20.0a0+cu128.nv-cp312-cp312-win_amd64.whl#sha256=235e7be71ac4e75b0f8e817bae4796d7bac8a67146d2037ab96394f2bdc63e6c'
}
return f'pip install {ea_whl.get(sys.version_info.minor)}'
def prepare_environment():
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu128")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.7.0 torchvision==0.22.0 --extra-index-url {torch_index_url}")
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
if args.use_ipex:
if platform.system() == "Windows":
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
@@ -373,12 +339,12 @@ def prepare_environment():
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.30')
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/w-e-w/stablediffusion.git")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
@@ -422,24 +388,8 @@ def prepare_environment():
)
startup_timer.record("torch GPU test")
# Ensure build dependencies are installed before any package that might need them
def ensure_build_dependencies():
"""Ensure essential build tools are available"""
if not is_installed("wheel"):
run_pip("install wheel", "wheel")
# Check setuptools version compatibility
try:
setuptools_version = run(f'"{python}" -c "import setuptools; print(setuptools.__version__)"', None, None).strip()
if setuptools_version >= "70":
run_pip("install setuptools==69.5.1", "setuptools")
except Exception:
# If setuptools check fails, install compatible version
run_pip("install setuptools==69.5.1", "setuptools")
# Install build dependencies early
ensure_build_dependencies()
if not is_installed("clip"):
run_pip(f"install --no-build-isolation {clip_package}", "clip")
run_pip(f"install {clip_package}", "clip")
startup_timer.record("install clip")
if not is_installed("open_clip"):
+1 -1
View File
@@ -54,7 +54,7 @@ class SdOptimizationXformers(SdOptimization):
priority = 100
def is_available(self):
return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (12, 0))
return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
def apply(self):
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
+16 -2
View File
@@ -13,7 +13,6 @@ from urllib import request
import ldm.modules.midas as midas
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.hashes import partial_hash_from_cache as model_hash # noqa: F401 for backwards compatibility
from modules.timer import Timer
from modules.shared import opts
import tomesd
@@ -88,7 +87,7 @@ class CheckpointInfo:
self.name = name
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = hashes.partial_hash_from_cache(filename)
self.hash = model_hash(filename)
self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
self.shorthash = self.sha256[0:10] if self.sha256 else None
@@ -201,6 +200,21 @@ def get_closet_checkpoint_match(search_string):
return None
def model_hash(filename):
"""old hash that only looks at a small part of the file and is prone to collisions"""
try:
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
return m.hexdigest()[0:8]
except FileNotFoundError:
return 'NOFILE'
def select_checkpoint():
"""Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint
+4 -7
View File
@@ -117,15 +117,12 @@ def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
alpha = shared.opts.beta_dist_alpha
beta = shared.opts.beta_dist_beta
curve = [stats.beta.ppf(x, alpha, beta) for x in np.linspace(1, 0, n)]
start = inner_model.sigma_to_t(torch.tensor(sigma_max))
end = inner_model.sigma_to_t(torch.tensor(sigma_min))
timesteps = [end + x * (start - end) for x in curve]
sigmas = [inner_model.t_to_sigma(ts) for ts in timesteps]
timesteps = 1 - np.linspace(0, 1, n)
timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps]
sigmas += [0.0]
return torch.FloatTensor(sigmas).to(device)
+2 -2
View File
@@ -407,8 +407,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models"),
'skip_early_cond': OptionInfo(0.0, "Ignore negative prompt during early sampling", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext="Skip Early CFG").info("disables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling; XYZ plot: Skip Early CFG"),
'beta_dist_alpha': OptionInfo(0.6, "Beta scheduler - alpha", gr.Slider, {"minimum": 0.01, "maximum": 5.0, "step": 0.01}, infotext='Beta scheduler alpha').info('Default = 0.6; the alpha parameter of the beta distribution used in Beta sampling'),
'beta_dist_beta': OptionInfo(0.6, "Beta scheduler - beta", gr.Slider, {"minimum": 0.01, "maximum": 5.0, "step": 0.01}, infotext='Beta scheduler beta').info('Default = 0.6; the beta parameter of the beta distribution used in Beta sampling'),
'beta_dist_alpha': OptionInfo(0.6, "Beta scheduler - alpha", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler alpha').info('Default = 0.6; the alpha parameter of the beta distribution used in Beta sampling'),
'beta_dist_beta': OptionInfo(0.6, "Beta scheduler - beta", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler beta').info('Default = 0.6; the beta parameter of the beta distribution used in Beta sampling'),
}))
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
-1
View File
@@ -1,4 +1,3 @@
from __future__ import annotations
import os
import re
-50
View File
@@ -1,50 +0,0 @@
import sys
import copy
import shlex
import subprocess
from functools import wraps
BAD_FLAGS = ("--prefer-binary", '-I', '--ignore-installed')
def patch():
if hasattr(subprocess, "__original_run"):
return
print("using uv")
try:
subprocess.run(['uv', '-V'])
except FileNotFoundError:
subprocess.run([sys.executable, '-m', 'pip', 'install', 'uv'])
subprocess.__original_run = subprocess.run
@wraps(subprocess.__original_run)
def patched_run(*args, **kwargs):
_kwargs = copy.copy(kwargs)
if args:
command, *_args = args
else:
command, _args = _kwargs.pop("args", ""), ()
if isinstance(command, str):
command = shlex.split(command)
else:
command = [arg.strip() for arg in command]
if not isinstance(command, list) or "pip" not in command:
return subprocess.__original_run(*args, **kwargs)
cmd = command[command.index("pip") + 1:]
cmd = [arg for arg in cmd if arg not in BAD_FLAGS]
modified_command = ["uv", "pip", *cmd]
cmd_str = shlex.join([*modified_command, *_args])
result = subprocess.__original_run(cmd_str, **_kwargs)
if result.returncode != 0:
return subprocess.__original_run(*args, **kwargs)
return result
subprocess.run = patched_run
+1 -1
View File
@@ -182,7 +182,7 @@ document.addEventListener('keydown', function(e) {
const lightboxModal = document.querySelector('#lightboxModal');
if (!globalPopup || globalPopup.style.display === 'none') {
if (document.activeElement === lightboxModal) return;
if (interruptButton?.style.display === 'block') {
if (interruptButton.style.display === 'block') {
interruptButton.click();
e.preventDefault();
}
+1 -4
View File
@@ -480,10 +480,8 @@ div.toprow-compact-tools{
}
#settings_result{
min-height: 1.4em;
height: 1.4em;
margin: 0 1.2em;
max-height: calc(var(--text-md) * var(--line-sm) * 5);
overflow-y: auto;
}
table.popup-table{
@@ -602,7 +600,6 @@ table.popup-table .link{
background: var(--background-fill-primary);
width: 100%;
height: 100%;
pointer-events: none;
}
.livePreview img{
+1 -1
View File
@@ -127,7 +127,7 @@ then
fi
# Check prerequisites
gpu_info=$(lspci 2>/dev/null | grep -E "VGA|Display|CMP")
gpu_info=$(lspci 2>/dev/null | grep -E "VGA|Display")
case "$gpu_info" in
*"Navi 1"*)
export HSA_OVERRIDE_GFX_VERSION=10.3.0