Compare commits

..

21 Commits

Author SHA1 Message Date
w-e-w e936dbb43b allow add middleware after app has started
this should completely fix "Cannot add middleware after an application has started" which can occur due to a race condition
2024-12-28 14:53:59 +09:00
w-e-w 813c3912fc open API docs in new tab (#16754) 2024-12-26 22:19:43 -05:00
w-e-w 078d04ef23 ruff <path> is deprecated. Use ruff check <path> (#16753) 2024-12-26 20:40:15 -05:00
w-e-w 1a773bf2c8 Merge pull request #16751 from Neokmi/master
Fix  Codeformer and gfpgan extension , Inconsistent overlay layer types when visibility value is less than 1
2024-12-26 06:33:04 +09:00
w-e-w f113474a6e lint 2024-12-26 06:26:47 +09:00
klx 6577e063d1 Update postprocessing_gfpgan.py
Fix  gfpgan extension , Inconsistent overlay layer types when visibility value is less than 1
2024-12-26 02:16:05 +08:00
klx 7953c570d9 Update postprocessing_codeformer.py
Fix  Codeformer extension , Inconsistent overlay layer types when visibility value is less than 1
2024-12-26 02:14:49 +08:00
w-e-w f25c3fc9cb fix sd_vae_explanation (#16748) 2024-12-24 15:43:55 -05:00
w-e-w fc0952abb9 Merge pull request #16745 from Sanchows/removed-unused-import-modules-errors
removed unnecessary import 'modules.errors'
2024-12-24 22:58:43 +09:00
Alexander Sachenko b414c62ce4 removed unnecessary import modules.errors 2024-12-24 15:45:10 +03:00
w-e-w 04903af798 Merge pull request #16604 from Haoming02/ext-updt-parallel
Check for Extension Updates in Parallel
2024-12-18 03:21:48 +09:00
w-e-w e8c3b1f2a0 Merge pull request #16718 from Haoming02/bracket-checker-order
[Bracket Checker] Also check for the order of brackets
2024-12-18 02:37:30 +09:00
Haoming 8bf30e3c42 revert IIFE 2024-12-18 01:02:40 +08:00
Haoming fbc51fa210 skip escaped 2024-12-16 09:47:38 +08:00
Haoming 7025a2c4a5 check-for-order 2024-12-12 16:08:15 +08:00
w-e-w 0120768f63 Merge pull request #16687 from Haoming02/dropdown4format
Use gr.Dropdown for Image Formats
2024-11-28 17:39:12 +09:00
w-e-w b425b97ad6 improve img fromat description 2024-11-28 17:14:03 +09:00
w-e-w 539ea3982d use DropdownEditable
use DropdownEditable so user can input other formats if they require it
make the default png the first on the list
2024-11-28 14:10:44 +09:00
Haoming 65bd61e87c format-dropdown 2024-11-27 10:42:50 +08:00
w-e-w 95686227bd limit number of simultaneous updates
shared.opts.concurrent_git_fetch_limit
2024-10-29 20:16:15 +09:00
Haoming df74c3c638 threading 2024-10-29 14:12:42 +08:00
15 changed files with 131 additions and 153 deletions
+1 -1
View File
@@ -22,7 +22,7 @@ jobs:
- name: Install Ruff
run: pip install ruff==0.3.3
- name: Run Ruff
run: ruff .
run: ruff check .
lint-js:
name: eslint
runs-on: ubuntu-latest
@@ -1,36 +1,69 @@
// Stable Diffusion WebUI - Bracket checker
// By Hingashi no Florin/Bwin4L & @akx
// Stable Diffusion WebUI - Bracket Checker
// By @Bwin4L, @akx, @w-e-w, @Haoming02
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
// If there's a mismatch, the keyword counter turns red, and if you hover on it, a tooltip tells you what's wrong.
function checkBrackets(textArea, counterElem) {
const pairs = [
['(', ')', 'round brackets'],
['[', ']', 'square brackets'],
['{', '}', 'curly brackets']
];
function checkBrackets(textArea, counterElt) {
const counts = {};
textArea.value.matchAll(/(?<!\\)(?:\\\\)*?([(){}[\]])/g).forEach(bracket => {
counts[bracket[1]] = (counts[bracket[1]] || 0) + 1;
});
const errors = [];
const errors = new Set();
let i = 0;
function checkPair(open, close, kind) {
if (counts[open] !== counts[close]) {
errors.push(
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
);
while (i < textArea.value.length) {
let char = textArea.value[i];
let escaped = false;
while (char === '\\' && i + 1 < textArea.value.length) {
escaped = !escaped;
i++;
char = textArea.value[i];
}
if (escaped) {
i++;
continue;
}
for (const [open, close, label] of pairs) {
if (char === open) {
counts[label] = (counts[label] || 0) + 1;
} else if (char === close) {
counts[label] = (counts[label] || 0) - 1;
if (counts[label] < 0) {
errors.add(`Incorrect order of ${label}.`);
}
}
}
i++;
}
for (const [open, close, label] of pairs) {
if (counts[label] == undefined) {
continue;
}
if (counts[label] > 0) {
errors.add(`${open} ... ${close} - Detected ${counts[label]} more opening than closing ${label}.`);
} else if (counts[label] < 0) {
errors.add(`${open} ... ${close} - Detected ${-counts[label]} more closing than opening ${label}.`);
}
}
checkPair('(', ')', 'round brackets');
checkPair('[', ']', 'square brackets');
checkPair('{', '}', 'curly brackets');
counterElt.title = errors.join('\n');
counterElt.classList.toggle('error', errors.length !== 0);
counterElem.title = [...errors].join('\n');
counterElem.classList.toggle('error', errors.size !== 0);
}
function setupBracketChecking(id_prompt, id_counter) {
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
var counter = gradioApp().getElementById(id_counter);
const textarea = gradioApp().querySelector(`#${id_prompt} > label > textarea`);
const counter = gradioApp().getElementById(id_counter);
if (textarea && counter) {
textarea.addEventListener("input", () => checkBrackets(textarea, counter));
onEdit(`${id_prompt}_BracketChecking`, textarea, 400, () => checkBrackets(textarea, counter));
}
}
+1 -1
View File
@@ -1,5 +1,5 @@
<div>
<a href="{api_docs}">API</a>
<a href="{api_docs}" target="_blank">API</a>
 • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">GitHub</a>
 • 
+1
View File
@@ -50,6 +50,7 @@ def check_versions():
def initialize():
from modules import initialize_util
initialize_util.allow_add_middleware_after_start()
initialize_util.fix_torch_version()
initialize_util.fix_pytorch_lightning()
initialize_util.fix_asyncio_event_loop_policy()
+37 -3
View File
@@ -5,6 +5,8 @@ import sys
import re
from modules.timer import startup_timer
from modules import patches
from functools import wraps
def gradio_server_name():
@@ -191,11 +193,8 @@ def configure_opts_onchange():
def setup_middleware(app):
from starlette.middleware.gzip import GZipMiddleware
app.middleware_stack = None # reset current middleware to allow modifying user provided list
app.add_middleware(GZipMiddleware, minimum_size=1000)
configure_cors_middleware(app)
app.build_middleware_stack() # rebuild middleware stack on-the-fly
def configure_cors_middleware(app):
@@ -213,3 +212,38 @@ def configure_cors_middleware(app):
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
app.add_middleware(CORSMiddleware, **cors_options)
def allow_add_middleware_after_start():
from starlette.applications import Starlette
def add_middleware_wrapper(func):
"""Patch Starlette.add_middleware to allow for middleware to be added after the app has started
Starlette.add_middleware raises RuntimeError("Cannot add middleware after an application has started") if middleware_stack is not None.
We can force add new middleware by first setting middleware_stack to None, then adding the middleware.
When middleware_stack is None, it will rebuild the middleware_stack on the next request (Lazily build middleware stack).
If packages are updated in the future, things may break, so we have two ways to add middleware after the app has started:
the first way is to just set middleware_stack to None and then retry
the second manually insert the middleware into the user_middleware list without calling add_middleware
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
res = None
try:
res = func(self, *args, **kwargs)
except RuntimeError as _:
try:
self.middleware_stack = None
res = func(self, *args, **kwargs)
except RuntimeError as e:
print(f'Warning: "{e}", Retrying...')
from starlette.middleware import Middleware
self.user_middleware.insert(0, Middleware(*args, **kwargs))
self.middleware_stack = None # ensure middleware_stack in the event of concurrent requests
return res
return wrapper
patches.patch(__name__, obj=Starlette, field="add_middleware", replacement=add_middleware_wrapper(Starlette.add_middleware))
+1 -3
View File
@@ -43,9 +43,7 @@ def check_python_version():
supported_minors = [7, 8, 9, 10, 11]
if not (major == 3 and minor in supported_minors):
import modules.errors
modules.errors.print_error_explanation(f"""
errors.print_error_explanation(f"""
INCOMPATIBLE PYTHON VERSION
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
+6 -34
View File
@@ -16,7 +16,7 @@ from skimage import exposure
from typing import Any
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling, util
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -457,20 +457,6 @@ class StableDiffusionProcessing:
opts.emphasis,
)
def apply_generation_params_list(self, generation_params_states):
"""add and apply generation_params_states to self.extra_generation_params"""
for key, value in generation_params_states.items():
if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
self.extra_generation_params[key] = current_value + value
else:
self.extra_generation_params[key] = value
def clear_marked_generation_params(self):
"""clears any generation parameters that are with the attribute to_be_clear_before_batch = True"""
for key, value in list(self.extra_generation_params.items()):
if getattr(value, 'to_be_clear_before_batch', False):
self.extra_generation_params.pop(key)
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -494,10 +480,6 @@ class StableDiffusionProcessing:
for cache in caches:
if cache[0] is not None and cached_params == cache[0]:
if len(cache) == 3:
generation_params_states, cached_cached_params = cache[2]
if cached_params == cached_cached_params:
self.apply_generation_params_list(generation_params_states)
return cache[1]
cache = caches[0]
@@ -505,13 +487,6 @@ class StableDiffusionProcessing:
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
generation_params_states = model_hijack.extract_generation_params_states()
self.apply_generation_params_list(generation_params_states)
if len(cache) == 2:
cache.append((generation_params_states, cached_params))
else:
cache[2] = (generation_params_states, cached_params)
cache[0] = cached_params
return cache[1]
@@ -527,8 +502,6 @@ class StableDiffusionProcessing:
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
self.extra_generation_params.update(model_hijack.extra_generation_params)
def get_conds(self):
return self.c, self.uc
@@ -828,10 +801,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
for key, value in generation_params.items():
try:
if callable(value):
generation_params[key] = value(**locals())
elif isinstance(value, list):
if isinstance(value, list):
generation_params[key] = value[index]
elif callable(value):
generation_params[key] = value(**locals())
except Exception:
errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
generation_params[key] = None
@@ -965,7 +938,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted or state.stopping_generation:
break
p.clear_marked_generation_params() # clean up some generation params are tagged to be cleared before batch
sd_models.reload_model_weights() # model can be changed for example by refiner
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
@@ -993,6 +965,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_conds()
p.extra_generation_params.update(model_hijack.extra_generation_params)
# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
# Example: a wildcard processed by process_batch sets an extra model
@@ -1539,8 +1513,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
self.extra_generation_params.update(model_hijack.extra_generation_params)
def setup_conds(self):
if self.is_hr_pass:
# if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
+1 -9
View File
@@ -2,7 +2,7 @@ import torch
from torch.nn.functional import silu
from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches, util
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
@@ -321,14 +321,6 @@ class StableDiffusionModelHijack:
self.comments = []
self.extra_generation_params = {}
def extract_generation_params_states(self):
"""Extracts GenerationParametersList so that they can be cached and restored later"""
states = {}
for key in list(self.extra_generation_params):
if isinstance(self.extra_generation_params[key], util.GenerationParametersList):
states[key] = self.extra_generation_params.pop(key)
return states
def get_prompt_lengths(self, text):
if self.clip is None:
return "-", "-"
+6 -28
View File
@@ -3,7 +3,7 @@ from collections import namedtuple
import torch
from modules import prompt_parser, devices, sd_hijack, sd_emphasis, util
from modules import prompt_parser, devices, sd_hijack, sd_emphasis
from modules.shared import opts
@@ -27,30 +27,6 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
class EmphasisMode(util.GenerationParametersList):
def __init__(self, emphasis_mode:str = None):
super().__init__()
self.emphasis_mode = emphasis_mode
def __call__(self, *args, **kwargs):
return self.emphasis_mode
def __add__(self, other):
if isinstance(other, EmphasisMode):
return self if self.emphasis_mode else other
elif isinstance(other, str):
return self.__str__() + other
return NotImplemented
def __radd__(self, other):
if isinstance(other, str):
return other + self.__str__()
return NotImplemented
def __str__(self):
return self.emphasis_mode if self.emphasis_mode else ''
class TextConditionalModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -262,10 +238,12 @@ class TextConditionalModel(torch.nn.Module):
hashes.append(f"{name}: {shorthash}")
if hashes:
self.hijack.extra_generation_params["TI hashes"] = util.GenerationParametersList(hashes)
if self.hijack.extra_generation_params.get("TI hashes"):
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x):
self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(opts.emphasis)
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
if self.return_pooled:
return torch.hstack(zs), zs[0].pooled
+1 -1
View File
@@ -125,7 +125,7 @@ def ui_reorder_categories():
def callbacks_order_settings():
options = {
"sd_vae_explanation": OptionHTML("""
"callbacks_order_explanation": OptionHTML("""
For categories below, callbacks added to dropdowns happen before others, in order listed.
"""),
+3 -2
View File
@@ -33,12 +33,12 @@ categories.register_category("training", "Training")
options_templates.update(options_section(('saving-images', "Saving images/grids", "saving"), {
"samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'),
"samples_format": OptionInfo('png', 'File format for images', ui_components.DropdownEditable, {"choices": ("png", "jpg", "jpeg", "webp", "avif")}).info("manual input of <a href='https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html' target='_blank'>other formats</a> is possible, but compatibility is not guaranteed"),
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
"save_images_replace_action": OptionInfo("Replace", "Saving the image to an existing file", gr.Radio, {"choices": ["Replace", "Add number suffix"], **hide_dirs}),
"grid_save": OptionInfo(True, "Always save all generated image grids"),
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_format": OptionInfo('png', 'File format for grids', ui_components.DropdownEditable, {"choices": ("png", "jpg", "jpeg", "webp", "avif")}).info("manual input of <a href='https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html' target='_blank'>other formats</a> is possible, but compatibility is not guaranteed"),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
@@ -128,6 +128,7 @@ options_templates.update(options_section(('system', "System", "system"), {
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
"dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
"concurrent_git_fetch_limit": OptionInfo(16, "Number of simultaneous extension update checks ", gr.Slider, {"step": 1, "minimum": 1, "maximum": 100}).info("reduce extension update check time"),
}))
options_templates.update(options_section(('profiler', "Profiler", "system"), {
+11 -4
View File
@@ -1,5 +1,6 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor
import threading
import time
from datetime import datetime, timezone
@@ -106,18 +107,24 @@ def check_updates(id_task, disable_list):
exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
shared.state.job_count = len(exts)
for ext in exts:
shared.state.textinfo = ext.name
lock = threading.Lock()
def _check_update(ext):
try:
ext.check_updates()
except FileNotFoundError as e:
if 'FETCH_HEAD' not in str(e):
raise
except Exception:
errors.report(f"Error checking updates for {ext.name}", exc_info=True)
with lock:
errors.report(f"Error checking updates for {ext.name}", exc_info=True)
with lock:
shared.state.textinfo = ext.name
shared.state.nextjob()
shared.state.nextjob()
with ThreadPoolExecutor(max_workers=max(1, int(shared.opts.concurrent_git_fetch_limit))) as executor:
for ext in exts:
executor.submit(_check_update, ext)
return extension_table(), ""
-46
View File
@@ -288,49 +288,3 @@ def compare_sha256(file_path: str, hash_prefix: str) -> bool:
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
class GenerationParametersList(list):
"""A special object used in sd_hijack.StableDiffusionModelHijack for setting extra_generation_params
due to StableDiffusionProcessing.get_conds_with_caching
extra_generation_params set in StableDiffusionModelHijack will be lost when cached is used
When an extra_generation_params is set in StableDiffusionModelHijack using this object,
the params will be extracted by StableDiffusionModelHijack.extract_generation_params_states
the extracted params will be cached in StableDiffusionProcessing.get_conds_with_caching
and applyed to StableDiffusionProcessing.extra_generation_params by StableDiffusionProcessing.apply_generation_params_states
Example see modules.sd_hijack_clip.TextConditionalModel.hijack.extra_generation_params 'TI hashes' 'Emphasis'
Depending on the use case the methods can be overwritten.
In general __call__ method should return str or None, as normally it's called in modules.processing.create_infotext.
When called by create_infotext it will access to the locals() of the caller,
if return str, the value will be written to infotext, if return None will be ignored.
"""
def __init__(self, *args, to_be_clear_before_batch=True, **kwargs):
super().__init__(*args, **kwargs)
self._to_be_clear_before_batch = to_be_clear_before_batch
def __call__(self, *args, **kwargs):
return ', '.join(sorted(set(self), key=natural_sort_key))
@property
def to_be_clear_before_batch(self):
return self._to_be_clear_before_batch
def __add__(self, other):
if isinstance(other, GenerationParametersList):
return self.__class__([*self, *other])
elif isinstance(other, str):
return self.__str__() + other
return NotImplemented
def __radd__(self, other):
if isinstance(other, str):
return other + self.__str__()
return NotImplemented
def __str__(self):
return self.__call__()
+4
View File
@@ -29,6 +29,10 @@ class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing
res = Image.fromarray(restored_img)
if codeformer_visibility < 1.0:
if pp.image.size != res.size:
res = res.resize(pp.image.size)
if pp.image.mode != res.mode:
res = res.convert(pp.image.mode)
res = Image.blend(pp.image, res, codeformer_visibility)
pp.image = res
+4
View File
@@ -26,6 +26,10 @@ class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
res = Image.fromarray(restored_img)
if gfpgan_visibility < 1.0:
if pp.image.size != res.size:
res = res.resize(pp.image.size)
if pp.image.mode != res.mode:
res = res.convert(pp.image.mode)
res = Image.blend(pp.image, res, gfpgan_visibility)
pp.image = res