Compare commits

..

4 Commits

Author SHA1 Message Date
AUTOMATIC1111 f1d7c07a5a repair img2img 2023-08-07 12:21:05 +03:00
AUTOMATIC1111 686598387f send noisy latent into refiner without adding noise 2023-08-07 12:10:16 +03:00
AUTOMATIC1111 3f82820612 apply unet overrides after switching model 2023-08-07 08:16:30 +03:00
AUTOMATIC1111 6c7b6ecb81 alternative refiner implementation 2023-08-06 22:08:52 +03:00
113 changed files with 3373 additions and 5512 deletions
+1 -3
View File
@@ -90,8 +90,6 @@ module.exports = {
// localStorage.js
localSet: "readonly",
localGet: "readonly",
localRemove: "readonly",
// resizeHandle.js
setupResizeHandle: "writable"
localRemove: "readonly"
}
};
+71 -7
View File
@@ -26,7 +26,7 @@ body:
id: steps
attributes:
label: Steps to reproduce the problem
description: Please provide us with precise step by step instructions on how to reproduce the bug
description: Please provide us with precise step by step information on how to reproduce the bug
value: |
1. Go to ....
2. Press ....
@@ -37,14 +37,64 @@ body:
id: what-should
attributes:
label: What should have happened?
description: Tell us what you think the normal behavior should be
description: Tell what you think the normal behavior should be
validations:
required: true
- type: textarea
id: sysinfo
- type: input
id: commit
attributes:
label: Sysinfo
description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
label: Version or Commit where the problem happens
description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)"
validations:
required: true
- type: dropdown
id: py-version
attributes:
label: What Python version are you running on ?
multiple: false
options:
- Python 3.10.x
- Python 3.11.x (above, no supported yet)
- Python 3.9.x (below, no recommended)
- type: dropdown
id: platforms
attributes:
label: What platforms do you use to access the UI ?
multiple: true
options:
- Windows
- Linux
- MacOS
- iOS
- Android
- Other/Cloud
- type: dropdown
id: device
attributes:
label: What device are you running WebUI on?
multiple: true
options:
- Nvidia GPUs (RTX 20 above)
- Nvidia GPUs (GTX 16 below)
- AMD GPUs (RX 6000 above)
- AMD GPUs (RX 5000 below)
- CPU
- Other GPUs
- type: dropdown
id: cross_attention_opt
attributes:
label: Cross attention optimization
description: What cross attention optimization are you using, Settings -> Optimizations -> Cross attention optimization
multiple: false
options:
- Automatic
- xformers
- sdp-no-mem
- sdp
- Doggettx
- V1
- InvokeAI
- "None "
validations:
required: true
- type: dropdown
@@ -58,7 +108,21 @@ body:
- Brave
- Apple Safari
- Microsoft Edge
- Other
- type: textarea
id: cmdargs
attributes:
label: Command Line Arguments
description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
render: Shell
validations:
required: true
- type: textarea
id: extensions
attributes:
label: List of extensions
description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise.
validations:
required: true
- type: textarea
id: logs
attributes:
-155
View File
@@ -1,158 +1,3 @@
## 1.6.0
### Features:
* refiner support [#12371](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12371)
* add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards
* add style editor dialog
* hires fix: add an option to use a different checkpoint for second pass ([#12181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12181))
* option to keep multiple loaded models in memory ([#12227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12227))
* new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542))
* rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers:
* makes all of them work with img2img
* makes prompt composition posssible (AND)
* makes them available for SDXL
* always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808))
* use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599))
* textual inversion inference support for SDXL
* extra networks UI: show metadata for SD checkpoints
* checkpoint merger: add metadata support
* prompt editing and attention: add support for whitespace after the number ([ red : green : 0.5 ]) (seed breaking change) ([#12177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12177))
* VAE: allow selecting own VAE for each checkpoint (in user metadata editor)
* VAE: add selected VAE to infotext
* options in main UI: add own separate setting for txt2img and img2img, correctly read values from pasted infotext, add setting for column count ([#12551](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12551))
* add resize handle to txt2img and img2img tabs, allowing to change the amount of horizontable space given to generation parameters and resulting image gallery ([#12687](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12687), [#12723](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12723))
* change default behavior for batching cond/uncond -- now it's on by default, and is disabled by an UI setting (Optimizatios -> Batch cond/uncond) - if you are on lowvram/medvram and are getting OOM exceptions, you will need to enable it
* show current position in queue and make it so that requests are processed in the order of arrival ([#12707](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12707))
* add `--medvram-sdxl` flag that only enables `--medvram` for SDXL models
* prompt editing timeline has separate range for first pass and hires-fix pass (seed breaking change) ([#12457](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12457))
### Minor:
* img2img batch: RAM savings, VRAM savings, .tif, .tiff in img2img batch ([#12120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12120), [#12514](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12514), [#12515](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12515))
* postprocessing/extras: RAM savings ([#12479](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12479))
* XYZ: in the axis labels, remove pathnames from model filenames
* XYZ: support hires sampler ([#12298](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12298))
* XYZ: new option: use text inputs instead of dropdowns ([#12491](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12491))
* add gradio version warning
* sort list of VAE checkpoints ([#12297](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12297))
* use transparent white for mask in inpainting, along with an option to select the color ([#12326](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12326))
* move some settings to their own section: img2img, VAE
* add checkbox to show/hide dirs for extra networks
* Add TAESD(or more) options for all the VAE encode/decode operation ([#12311](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12311))
* gradio theme cache, new gradio themes, along with explanation that the user can input his own values ([#12346](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12346), [#12355](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12355))
* sampler fixes/tweaks: s_tmax, s_churn, s_noise, s_tmax ([#12354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12354), [#12356](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12356), [#12357](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12357), [#12358](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12358), [#12375](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12375), [#12521](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12521))
* update README.md with correct instructions for Linux installation ([#12352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12352))
* option to not save incomplete images, on by default ([#12338](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12338))
* enable cond cache by default
* git autofix for repos that are corrupted ([#12230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12230))
* allow to open images in new browser tab by middle mouse button ([#12379](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12379))
* automatically open webui in browser when running "locally" ([#12254](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12254))
* put commonly used samplers on top, make DPM++ 2M Karras the default choice
* zoom and pan: option to auto-expand a wide image, improved integration ([#12413](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12413), [#12727](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12727))
* option to cache Lora networks in memory
* rework hires fix UI to use accordion
* face restoration and tiling moved to settings - use "Options in main UI" setting if you want them back
* change quicksettings items to have variable width
* Lora: add Norm module, add support for bias ([#12503](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12503))
* Lora: output warnings in UI rather than fail for unfitting loras; switch to logging for error output in console
* support search and display of hashes for all extra network items ([#12510](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12510))
* add extra noise param for img2img operations ([#12564](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12564))
* support for Lora with bias ([#12584](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12584))
* make interrupt quicker ([#12634](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12634))
* configurable gallery height ([#12648](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12648))
* make results column sticky ([#12645](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12645))
* more hash filename patterns ([#12639](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12639))
* make image viewer actually fit the whole page ([#12635](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12635))
* make progress bar work independently from live preview display which results in it being updated a lot more often
* forbid Full live preview method for medvram and add a setting to undo the forbidding
* make it possible to localize tooltips and placeholders
* add option to align with sgm repo's sampling implementation ([#12818](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818))
* Restore faces and Tiling generation parameters have been moved to settings out of main UI
* if you want to put them back into main UI, use `Options in main UI` setting on the UI page.
### Extensions and API:
* gradio 3.41.2
* also bump versions for packages: transformers, GitPython, accelerate, scikit-image, timm, tomesd
* support tooltip kwarg for gradio elements: gr.Textbox(label='hello', tooltip='world')
* properly clear the total console progressbar when using txt2img and img2img from API
* add cmd_arg --disable-extra-extensions and --disable-all-extensions ([#12294](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12294))
* shared.py and webui.py split into many files
* add --loglevel commandline argument for logging
* add a custom UI element that combines accordion and checkbox
* avoid importing gradio in tests because it spams warnings
* put infotext label for setting into OptionInfo definition rather than in a separate list
* make `StableDiffusionProcessingImg2Img.mask_blur` a property, make more inline with PIL `GaussianBlur` ([#12470](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12470))
* option to make scripts UI without gr.Group
* add a way for scripts to register a callback for before/after just a single component's creation
* use dataclass for StableDiffusionProcessing
* store patches for Lora in a specialized module instead of inside torch
* support http/https URLs in API ([#12663](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12663), [#12698](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12698))
* add extra noise callback ([#12616](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12616))
* dump current stack traces when exiting with SIGINT
* add type annotations for extra fields of shared.sd_model
### Bug Fixes:
* Don't crash if out of local storage quota for javascriot localStorage
* XYZ plot do not fail if an exception occurs
* fix missing TI hash in infotext if generation uses both negative and positive TI ([#12269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12269))
* localization fixes ([#12307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12307))
* fix sdxl model invalid configuration after the hijack
* correctly toggle extras checkbox for infotext paste ([#12304](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12304))
* open raw sysinfo link in new page ([#12318](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12318))
* prompt parser: Account for empty field in alternating words syntax ([#12319](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12319))
* add tab and carriage return to invalid filename chars ([#12327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12327))
* fix api only Lora not working ([#12387](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12387))
* fix options in main UI misbehaving when there's just one element
* make it possible to use a sampler from infotext even if it's hidden in the dropdown
* fix styles missing from the prompt in infotext when making a grid of batch of multiplie images
* prevent bogus progress output in console when calculating hires fix dimensions
* fix --use-textbox-seed
* fix broken `Lora/Networks: use old method` option ([#12466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12466))
* properly return `None` for VAE hash when using `--no-hashing` ([#12463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12463))
* MPS/macOS fixes and optimizations ([#12526](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12526))
* add second_order to samplers that mistakenly didn't have it
* when refreshing cards in extra networks UI, do not discard user's custom resolution
* fix processing error that happens if batch_size is not a multiple of how many prompts/negative prompts there are ([#12509](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12509))
* fix inpaint upload for alpha masks ([#12588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12588))
* fix exception when image sizes are not integers ([#12586](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12586))
* fix incorrect TAESD Latent scale ([#12596](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12596))
* auto add data-dir to gradio-allowed-path ([#12603](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12603))
* fix exception if extensuions dir is missing ([#12607](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12607))
* fix issues with api model-refresh and vae-refresh ([#12638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12638))
* fix img2img background color for transparent images option not being used ([#12633](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12633))
* attempt to resolve NaN issue with unstable VAEs in fp32 mk2 ([#12630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12630))
* implement missing undo hijack for SDXL
* fix xyz swap axes ([#12684](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12684))
* fix errors in backup/restore tab if any of config files are broken ([#12689](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12689))
* fix SD VAE switch error after model reuse ([#12685](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12685))
* fix trying to create images too large for the chosen format ([#12667](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12667))
* create Gradio temp directory if necessary ([#12717](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12717))
* prevent possible cache loss if exiting as it's being written by using an atomic operation to replace the cache with the new version
* set devices.dtype_unet correctly
* run RealESRGAN on GPU for non-CUDA devices ([#12737](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
* prevent extra network buttons being obscured by description for very small card sizes ([#12745](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12745))
* fix error that causes some extra networks to be disabled if both <lora:> and <lyco:> are present in the prompt
* fix defaults settings page breaking when any of main UI tabs are hidden
* fix incorrect save/display of new values in Defaults page in settings
* fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working
* fix an error that prevents VAE being reloaded after an option change if a VAE near the checkpoint exists ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
* hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
* don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
* fix style editing dialog breaking if it's opened in both img2img and txt2img tabs
* fix a bug allowing users to bypass gradio and API authentication (reported by vysecurity)
* fix notification not playing when built-in webui tab is inactive ([#12834](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12834))
* honor `--skip-install` for extension installers ([#12832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832))
* don't print blank stdout in extension installers ([#12833](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832), [#12855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12855))
* do not change quicksettings dropdown option when value returned is `None` ([#12854](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12854))
* get progressbar to display correctly in extensions tab
## 1.5.2
### Bug Fixes:
* fix memory leak when generation fails
* update doggettx cross attention optimization to not use an unreasonable amount of memory in some edge cases -- suggestion by MorkTheOrk
## 1.5.1
### Minor:
-7
View File
@@ -1,7 +0,0 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- given-names: AUTOMATIC1111
title: "Stable Diffusion Web UI"
date-released: 2022-08-22
url: "https://github.com/AUTOMATIC1111/stable-diffusion-webui"
+2 -5
View File
@@ -78,7 +78,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- Clip skip
- Hypernetworks
- Loras (same as Hypernetworks but more pretty)
- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
- Can select to load a different VAE from settings screen
- Estimated completion time in progress bar
- API
@@ -93,10 +93,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- Reorder elements in the UI from settings screen
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended)
- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page)
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
Alternatively, use online services (like Google Colab):
@@ -6,14 +6,9 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')
self.errors = {}
"""mapping of network names to the number of errors the network had during operation"""
def activate(self, p, params_list):
additional = shared.opts.sd_lora
self.errors.clear()
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@@ -61,7 +56,4 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
def deactivate(self, p):
if self.errors:
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
self.errors.clear()
pass
-31
View File
@@ -1,31 +0,0 @@
import torch
import networks
from modules import patches
class LoraPatches:
def __init__(self):
self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
def undo(self):
self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
+2 -5
View File
@@ -133,7 +133,7 @@ class NetworkModule:
return 1.0
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
def finalize_updown(self, updown, orig_weight, output_shape):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -145,10 +145,7 @@ class NetworkModule:
if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape)
if ex_bias is not None:
ex_bias = ex_bias * self.multiplier()
return updown * self.calc_scale() * self.multiplier(), ex_bias
return updown * self.calc_scale() * self.multiplier()
def calc_updown(self, target):
raise NotImplementedError()
+1 -6
View File
@@ -14,14 +14,9 @@ class NetworkModuleFull(network.NetworkModule):
super().__init__(net, weights)
self.weight = weights.w.get("diff")
self.ex_bias = weights.w.get("diff_b")
def calc_updown(self, orig_weight):
output_shape = self.weight.shape
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
if self.ex_bias is not None:
ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
else:
ex_bias = None
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
return self.finalize_updown(updown, orig_weight, output_shape)
-28
View File
@@ -1,28 +0,0 @@
import network
class ModuleTypeNorm(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["w_norm", "b_norm"]):
return NetworkModuleNorm(net, weights)
return None
class NetworkModuleNorm(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.w_norm = weights.w.get("w_norm")
self.b_norm = weights.w.get("b_norm")
def calc_updown(self, orig_weight):
output_shape = self.w_norm.shape
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
if self.b_norm is not None:
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
else:
ex_bias = None
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
+40 -143
View File
@@ -1,15 +1,12 @@
import logging
import os
import re
import lora_patches
import network
import network_lora
import network_hada
import network_ia3
import network_lokr
import network_full
import network_norm
import torch
from typing import Union
@@ -22,7 +19,6 @@ module_types = [
network_ia3.ModuleTypeIa3(),
network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(),
]
@@ -35,8 +31,6 @@ suffix_conversion = {
"resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
"norm1": "in_layers_0",
"norm2": "out_layers_0",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
}
@@ -196,19 +190,11 @@ def load_network(name, network_on_disk):
net.modules[key] = net_module
if keys_failed_to_match:
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
return net
def purge_networks_from_memory():
while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
name = next(iter(networks_in_memory))
networks_in_memory.pop(name, None)
devices.torch_gc()
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
already_loaded = {}
@@ -226,19 +212,15 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
failed_to_load_networks = []
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
for i, name in enumerate(names):
net = already_loaded.get(name, None)
if network_on_disk is not None:
if net is None:
net = networks_in_memory.get(name)
network_on_disk = networks_on_disk[i]
if network_on_disk is not None:
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
try:
net = load_network(name, network_on_disk)
networks_in_memory.pop(name, None)
networks_in_memory[name] = net
except Exception as e:
errors.display(e, f"loading network {network_on_disk.filename}")
continue
@@ -249,7 +231,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
if net is None:
failed_to_load_networks.append(name)
logging.info(f"Couldn't find network with name {name}")
print(f"Couldn't find network with name {name}")
continue
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
@@ -258,38 +240,23 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
loaded_networks.append(net)
if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
purge_networks_from_memory()
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
if weights_backup is None and bias_backup is None:
if weights_backup is None:
return
if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
self.weight.copy_(weights_backup)
if bias_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.out_proj.bias.copy_(bias_backup)
else:
self.bias.copy_(bias_backup)
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
if isinstance(self, torch.nn.MultiheadAttention):
self.out_proj.bias = None
else:
self.bias = None
self.weight.copy_(weights_backup)
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of networks to the weights of torch layer self.
If weights already have this particular set of networks applied, does nothing.
@@ -304,10 +271,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
weights_backup = getattr(self, "network_weights_backup", None)
if weights_backup is None and wanted_names != ():
if current_names != ():
raise RuntimeError("no backup weights found and current weights are not unchanged")
if weights_backup is None:
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
else:
@@ -315,41 +279,21 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None:
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
elif getattr(self, 'bias', None) is not None:
bias_backup = self.bias.to(devices.cpu, copy=True)
else:
bias_backup = None
self.network_bias_backup = bias_backup
if current_names != wanted_names:
network_restore_weights_from_backup(self)
for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'):
try:
with torch.no_grad():
updown, ex_bias = module.calc_updown(self.weight)
with torch.no_grad():
updown = module.calc_updown(self.weight)
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
self.weight += updown
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
self.bias = torch.nn.Parameter(ex_bias)
else:
self.bias += ex_bias
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
continue
self.weight += updown
continue
module_q = net.modules.get(network_layer_name + "_q_proj", None)
module_k = net.modules.get(network_layer_name + "_k_proj", None)
@@ -357,33 +301,21 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
module_out = net.modules.get(network_layer_name + "_out_proj", None)
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
try:
with torch.no_grad():
updown_q, _ = module_q.calc_updown(self.in_proj_weight)
updown_k, _ = module_k.calc_updown(self.in_proj_weight)
updown_v, _ = module_v.calc_updown(self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
with torch.no_grad():
updown_q = module_q.calc_updown(self.in_proj_weight)
updown_k = module_k.calc_updown(self.in_proj_weight)
updown_v = module_v.calc_updown(self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
updown_out = module_out.calc_updown(self.out_proj.weight)
self.in_proj_weight += updown_qkv
self.out_proj.weight += updown_out
if ex_bias is not None:
if self.out_proj.bias is None:
self.out_proj.bias = torch.nn.Parameter(ex_bias)
else:
self.out_proj.bias += ex_bias
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
continue
self.in_proj_weight += updown_qkv
self.out_proj.weight += updown_out
continue
if module is None:
continue
logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
print(f'failed to calculate network weights for layer {network_layer_name}')
self.network_current_names = wanted_names
@@ -410,7 +342,7 @@ def network_forward(module, input, original_forward):
if module is None:
continue
y = module.forward(input, y)
y = module.forward(y, input)
return y
@@ -422,74 +354,44 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
def network_Linear_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, originals.Linear_forward)
return network_forward(self, input, torch.nn.Linear_forward_before_network)
network_apply_weights(self)
return originals.Linear_forward(self, input)
return torch.nn.Linear_forward_before_network(self, input)
def network_Linear_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return originals.Linear_load_state_dict(self, *args, **kwargs)
return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
def network_Conv2d_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, originals.Conv2d_forward)
return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
network_apply_weights(self)
return originals.Conv2d_forward(self, input)
return torch.nn.Conv2d_forward_before_network(self, input)
def network_Conv2d_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return originals.Conv2d_load_state_dict(self, *args, **kwargs)
def network_GroupNorm_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, originals.GroupNorm_forward)
network_apply_weights(self)
return originals.GroupNorm_forward(self, input)
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
def network_LayerNorm_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, originals.LayerNorm_forward)
network_apply_weights(self)
return originals.LayerNorm_forward(self, input)
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self)
return originals.MultiheadAttention_forward(self, *args, **kwargs)
return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
def list_available_networks():
@@ -557,14 +459,9 @@ def infotext_pasted(infotext, params):
params["Prompt"] += "\n" + "".join(added)
originals: lora_patches.LoraPatches = None
extra_network_lora = None
available_networks = {}
available_network_aliases = {}
loaded_networks = []
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}
+34 -10
View File
@@ -1,30 +1,57 @@
import re
import torch
import gradio as gr
from fastapi import FastAPI
import network
import networks
import lora # noqa:F401
import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
networks.originals.undo()
torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
def before_ui():
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
extra_networks.register_extra_network(networks.extra_network_lora)
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
extra_network = extra_networks_lora.ExtraNetworkLora()
extra_networks.register_extra_network(extra_network)
extra_networks.register_extra_network_alias(extra_network, "lyco")
networks.originals = lora_patches.LoraPatches()
if not hasattr(torch.nn, 'Linear_forward_before_network'):
torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
torch.nn.Linear.forward = networks.network_Linear_forward
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
@@ -38,7 +65,6 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
}))
@@ -95,5 +121,3 @@ def infotext_pasted(infotext, d):
script_callbacks.on_infotext_pasted(infotext_pasted)
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
@@ -70,7 +70,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
metadata = item.get("metadata") or {}
keys = {
'ss_output_name': "Output name:",
'ss_sd_model_name': "Model:",
'ss_clip_skip': "Clip skip:",
'ss_network_module': "Kohya module:",
@@ -25,10 +25,9 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
item = {
"name": name,
"filename": lora_on_disk.filename,
"shorthash": lora_on_disk.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
"search_term": self.search_terms_from_path(lora_on_disk.filename),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": lora_on_disk.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
@@ -12,22 +12,8 @@ onUiLoaded(async() => {
"Sketch": elementIDs.sketch
};
// Helper functions
// Get active tab
/**
* Waits for an element to be present in the DOM.
*/
const waitForElement = (id) => new Promise(resolve => {
const checkForElement = () => {
const element = document.querySelector(id);
if (element) return resolve(element);
setTimeout(checkForElement, 100);
};
checkForElement();
});
function getActiveTab(elements, all = false) {
const tabs = elements.img2imgTabs.querySelectorAll("button");
@@ -48,7 +34,7 @@ onUiLoaded(async() => {
// Wait until opts loaded
async function waitForOpts() {
for (; ;) {
for (;;) {
if (window.opts && Object.keys(window.opts).length) {
return window.opts;
}
@@ -56,11 +42,6 @@ onUiLoaded(async() => {
}
}
// Detect whether the element has a horizontal scroll bar
function hasHorizontalScrollbar(element) {
return element.scrollWidth > element.clientWidth;
}
// Function for defining the "Ctrl", "Shift" and "Alt" keys
function isModifierKey(event, key) {
switch (key) {
@@ -220,8 +201,7 @@ onUiLoaded(async() => {
canvas_hotkey_overlap: "KeyO",
canvas_disabled_functions: [],
canvas_show_tooltip: true,
canvas_auto_expand: true,
canvas_blur_prompt: false,
canvas_blur_prompt: false
};
const functionMap = {
@@ -269,7 +249,7 @@ onUiLoaded(async() => {
input?.addEventListener("input", () => restoreImgRedMask(elements));
}
function applyZoomAndPan(elemId, isExtension = true) {
function applyZoomAndPan(elemId) {
const targetElement = gradioApp().querySelector(elemId);
if (!targetElement) {
@@ -381,12 +361,6 @@ onUiLoaded(async() => {
panY: 0
};
if (isExtension) {
targetElement.style.overflow = "hidden";
}
targetElement.isZoomed = false;
fixCanvas();
targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
@@ -397,27 +371,8 @@ onUiLoaded(async() => {
toggleOverlap("off");
fullScreenMode = false;
const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']");
if (closeBtn) {
closeBtn.addEventListener("click", resetZoom);
}
if (canvas && isExtension) {
const parentElement = targetElement.closest('[id^="component-"]');
if (
canvas &&
parseFloat(canvas.style.width) > parentElement.offsetWidth &&
parseFloat(targetElement.style.width) > parentElement.offsetWidth
) {
fitToElement();
return;
}
}
if (
canvas &&
!isExtension &&
parseFloat(canvas.style.width) > 865 &&
parseFloat(targetElement.style.width) > 865
) {
@@ -426,6 +381,9 @@ onUiLoaded(async() => {
}
targetElement.style.width = "";
if (canvas) {
targetElement.style.height = canvas.style.height;
}
}
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
@@ -481,7 +439,7 @@ onUiLoaded(async() => {
// Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables
function updateZoom(newZoomLevel, mouseX, mouseY) {
newZoomLevel = Math.max(0.1, Math.min(newZoomLevel, 15));
newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15));
elemData[elemId].panX +=
mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel;
@@ -492,10 +450,6 @@ onUiLoaded(async() => {
targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;
toggleOverlap("on");
if (isExtension) {
targetElement.style.overflow = "visible";
}
return newZoomLevel;
}
@@ -518,12 +472,10 @@ onUiLoaded(async() => {
fullScreenMode = false;
elemData[elemId].zoomLevel = updateZoom(
elemData[elemId].zoomLevel +
(operation === "+" ? delta : -delta),
(operation === "+" ? delta : -delta),
zoomPosX - targetElement.getBoundingClientRect().left,
zoomPosY - targetElement.getBoundingClientRect().top
);
targetElement.isZoomed = true;
}
}
@@ -537,19 +489,10 @@ onUiLoaded(async() => {
//Reset Zoom
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
let parentElement;
if (isExtension) {
parentElement = targetElement.closest('[id^="component-"]');
} else {
parentElement = targetElement.parentElement;
}
// Get element and screen dimensions
const elementWidth = targetElement.offsetWidth;
const elementHeight = targetElement.offsetHeight;
const parentElement = targetElement.parentElement;
const screenWidth = parentElement.clientWidth;
const screenHeight = parentElement.clientHeight;
@@ -602,12 +545,8 @@ onUiLoaded(async() => {
if (!canvas) return;
if (canvas.offsetWidth > 862 || isExtension) {
targetElement.style.width = (canvas.offsetWidth + 2) + "px";
}
if (isExtension) {
targetElement.style.overflow = "visible";
if (canvas.offsetWidth > 862) {
targetElement.style.width = canvas.offsetWidth + "px";
}
if (fullScreenMode) {
@@ -709,48 +648,8 @@ onUiLoaded(async() => {
mouseY = e.offsetY;
}
// Simulation of the function to put a long image into the screen.
// We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element.
// We hide the image and show it to the user when it is ready.
targetElement.isExpanded = false;
function autoExpand() {
const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
if (canvas) {
if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) {
targetElement.style.visibility = "hidden";
setTimeout(() => {
fitToScreen();
resetZoom();
targetElement.style.visibility = "visible";
targetElement.isExpanded = true;
}, 10);
}
}
}
targetElement.addEventListener("mousemove", getMousePosition);
//observers
// Creating an observer with a callback function to handle DOM changes
const observer = new MutationObserver((mutationsList, observer) => {
for (let mutation of mutationsList) {
// If the style attribute of the canvas has changed, by observation it happens only when the picture changes
if (mutation.type === 'attributes' && mutation.attributeName === 'style' &&
mutation.target.tagName.toLowerCase() === 'canvas') {
targetElement.isExpanded = false;
setTimeout(resetZoom, 10);
}
}
});
// Apply auto expand if enabled
if (hotkeysConfig.canvas_auto_expand) {
targetElement.addEventListener("mousemove", autoExpand);
// Set up an observer to track attribute changes
observer.observe(targetElement, {attributes: true, childList: true, subtree: true});
}
// Handle events only inside the targetElement
let isKeyDownHandlerAttached = false;
@@ -855,11 +754,6 @@ onUiLoaded(async() => {
if (isMoving && elemId === activeElement) {
updatePanPosition(e.movementX, e.movementY);
targetElement.style.pointerEvents = "none";
if (isExtension) {
targetElement.style.overflow = "visible";
}
} else {
targetElement.style.pointerEvents = "auto";
}
@@ -870,93 +764,13 @@ onUiLoaded(async() => {
isMoving = false;
};
// Checks for extension
function checkForOutBox() {
const parentElement = targetElement.closest('[id^="component-"]');
if (parentElement.offsetWidth < targetElement.offsetWidth && !targetElement.isExpanded) {
resetZoom();
targetElement.isExpanded = true;
}
if (parentElement.offsetWidth < targetElement.offsetWidth && elemData[elemId].zoomLevel == 1) {
resetZoom();
}
if (parentElement.offsetWidth < targetElement.offsetWidth && targetElement.offsetWidth * elemData[elemId].zoomLevel > parentElement.offsetWidth && elemData[elemId].zoomLevel < 1 && !targetElement.isZoomed) {
resetZoom();
}
}
if (isExtension) {
targetElement.addEventListener("mousemove", checkForOutBox);
}
window.addEventListener('resize', (e) => {
resetZoom();
if (isExtension) {
targetElement.isExpanded = false;
targetElement.isZoomed = false;
}
});
gradioApp().addEventListener("mousemove", handleMoveByKey);
}
applyZoomAndPan(elementIDs.sketch, false);
applyZoomAndPan(elementIDs.inpaint, false);
applyZoomAndPan(elementIDs.inpaintSketch, false);
applyZoomAndPan(elementIDs.sketch);
applyZoomAndPan(elementIDs.inpaint);
applyZoomAndPan(elementIDs.inpaintSketch);
// Make the function global so that other extensions can take advantage of this solution
const applyZoomAndPanIntegration = async(id, elementIDs) => {
const mainEl = document.querySelector(id);
if (id.toLocaleLowerCase() === "none") {
for (const elementID of elementIDs) {
const el = await waitForElement(elementID);
if (!el) break;
applyZoomAndPan(elementID);
}
return;
}
if (!mainEl) return;
mainEl.addEventListener("click", async() => {
for (const elementID of elementIDs) {
const el = await waitForElement(elementID);
if (!el) break;
applyZoomAndPan(elementID);
}
}, {once: true});
};
window.applyZoomAndPan = applyZoomAndPan; // Only 1 elements, argument elementID, for example applyZoomAndPan("#txt2img_controlnet_ControlNet_input_image")
window.applyZoomAndPanIntegration = applyZoomAndPanIntegration; // for any extension
/*
The function `applyZoomAndPanIntegration` takes two arguments:
1. `id`: A string identifier for the element to which zoom and pan functionality will be applied on click.
If the `id` value is "none", the functionality will be applied to all elements specified in the second argument without a click event.
2. `elementIDs`: An array of string identifiers for elements. Zoom and pan functionality will be applied to each of these elements on click of the element specified by the first argument.
If "none" is specified in the first argument, the functionality will be applied to each of these elements without a click event.
Example usage:
applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
In this example, zoom and pan functionality will be applied to the element with the identifier "txt2img_controlnet_ControlNet_input_image" upon clicking the element with the identifier "txt2img_controlnet".
*/
// More examples
// Add integration with ControlNet txt2img One TAB
// applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
// Add integration with ControlNet txt2img Tabs
// applyZoomAndPanIntegration("#txt2img_controlnet",Array.from({ length: 10 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`));
// Add integration with Inpaint Anything
// applyZoomAndPanIntegration("None", ["#ia_sam_image", "#ia_sel_mask"]);
window.applyZoomAndPan = applyZoomAndPan;
});
@@ -9,7 +9,6 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
"canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
}))
@@ -61,6 +61,3 @@
to {opacity: 1;}
}
.styler {
overflow:inherit !important;
}
@@ -1,7 +1,5 @@
import math
import gradio as gr
from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
from modules import scripts, shared, ui_components, ui_settings
from modules.ui_components import FormColumn
@@ -21,38 +19,18 @@ class ExtraOptionsSection(scripts.Script):
def ui(self, is_img2img):
self.comps = []
self.setting_names = []
self.infotext_fields = []
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
with gr.Blocks() as interface:
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
for setting_name in shared.opts.extra_options:
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
for row in range(row_count):
with gr.Row():
for col in range(shared.opts.extra_options_cols):
index = row * shared.opts.extra_options_cols + col
if index >= len(extra_options):
break
setting_name = extra_options[index]
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)
self.comps.append(comp)
self.setting_names.append(setting_name)
setting_infotext_name = mapping.get(setting_name)
if setting_infotext_name is not None:
self.infotext_fields.append((comp, setting_infotext_name))
self.comps.append(comp)
self.setting_names.append(setting_name)
def get_settings_values():
res = [ui_settings.get_value_for_setting(key) for key in self.setting_names]
return res[0] if len(res) == 1 else res
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
@@ -65,10 +43,6 @@ class ExtraOptionsSection(scripts.Script):
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
"extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
"extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
"extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
"extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(),
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion").needs_restart()
}))
@@ -20,13 +20,7 @@ function reportWindowSize() {
var button = gradioApp().getElementById(tab + '_generate_box');
var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
target.insertBefore(button, target.firstElementChild);
gradioApp().getElementById(tab + '_results').classList.toggle('mobile', currentlyMobile);
}
}
window.addEventListener("resize", reportWindowSize);
onUiLoaded(function() {
reportWindowSize();
});
+1 -1
View File
@@ -33,7 +33,7 @@ function extensions_check() {
var id = randomId();
requestProgress(id, gradioApp().getElementById('extensions_installed_html'), null, function() {
requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() {
});
+1 -10
View File
@@ -249,15 +249,6 @@ function popup(contents) {
globalPopup.style.display = "flex";
}
var storedPopupIds = {};
function popupId(id) {
if (!storedPopupIds[id]) {
storedPopupIds[id] = gradioApp().getElementById(id);
}
popup(storedPopupIds[id]);
}
function extraNetworksShowMetadata(text) {
var elem = document.createElement('pre');
elem.classList.add('popup-metadata');
@@ -341,7 +332,7 @@ function extraNetworksRefreshSingleCard(page, tabname, name) {
newDiv.innerHTML = data.html;
var newCard = newDiv.firstElementChild;
newCard.style.display = '';
newCard.style = '';
card.parentElement.insertBefore(newCard, card);
card.parentElement.removeChild(card);
}
-5
View File
@@ -136,11 +136,6 @@ function setupImageForLightbox(e) {
var event = isFirefox ? 'mousedown' : 'click';
e.addEventListener(event, function(evt) {
if (evt.button == 1) {
open(evt.target.src);
evt.preventDefault();
return;
}
if (!opts.js_modal_lightbox || evt.button != 0) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
-37
View File
@@ -1,37 +0,0 @@
var observerAccordionOpen = new MutationObserver(function(mutations) {
mutations.forEach(function(mutationRecord) {
var elem = mutationRecord.target;
var open = elem.classList.contains('open');
var accordion = elem.parentNode;
accordion.classList.toggle('input-accordion-open', open);
var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
checkbox.checked = open;
updateInput(checkbox);
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
if (extra) {
extra.style.display = open ? "" : "none";
}
});
});
function inputAccordionChecked(id, checked) {
var label = gradioApp().querySelector('#' + id + " .label-wrap");
if (label.classList.contains('open') != checked) {
label.click();
}
}
onUiLoaded(function() {
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
var labelWrap = accordion.querySelector('.label-wrap');
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
if (extra) {
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
}
}
});
+2 -31
View File
@@ -107,41 +107,12 @@ function processNode(node) {
});
}
function localizeWholePage() {
processNode(gradioApp());
function elem(comp) {
var elem_id = comp.props.elem_id ? comp.props.elem_id : "component-" + comp.id;
return gradioApp().getElementById(elem_id);
}
for (var comp of window.gradio_config.components) {
if (comp.props.webui_tooltip) {
let e = elem(comp);
let tl = e ? getTranslation(e.title) : undefined;
if (tl !== undefined) {
e.title = tl;
}
}
if (comp.props.placeholder) {
let e = elem(comp);
let textbox = e ? e.querySelector('[placeholder]') : null;
let tl = textbox ? getTranslation(textbox.placeholder) : undefined;
if (tl !== undefined) {
textbox.placeholder = tl;
}
}
}
}
function dumpTranslations() {
if (!hasLocalization()) {
// If we don't have any localization,
// we will not have traversed the app to find
// original_lines, so do that now.
localizeWholePage();
processNode(gradioApp());
}
var dumped = {};
if (localization.rtl) {
@@ -183,7 +154,7 @@ document.addEventListener("DOMContentLoaded", function() {
});
});
localizeWholePage();
processNode(gradioApp());
if (localization.rtl) { // if the language is from right to left,
(new MutationObserver((mutations, observer) => { // wait for the style to load
+1 -1
View File
@@ -15,7 +15,7 @@ onAfterUiUpdate(function() {
}
}
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"] div[id$="_results"] .thumbnail-item > img');
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img');
if (galleryPreviews == null) return;
+29 -38
View File
@@ -69,6 +69,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
var dateStart = new Date();
var wasEverActive = false;
var parentProgressbar = progressbarContainer.parentNode;
var parentGallery = gallery ? gallery.parentNode : null;
var divProgress = document.createElement('div');
divProgress.className = 'progressDiv';
@@ -79,26 +80,32 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
divProgress.appendChild(divInner);
parentProgressbar.insertBefore(divProgress, progressbarContainer);
var livePreview = null;
if (parentGallery) {
var livePreview = document.createElement('div');
livePreview.className = 'livePreview';
parentGallery.insertBefore(livePreview, gallery);
}
var removeProgressBar = function() {
if (!divProgress) return;
setTitle("");
parentProgressbar.removeChild(divProgress);
if (gallery && livePreview) gallery.removeChild(livePreview);
if (parentGallery) parentGallery.removeChild(livePreview);
atEnd();
divProgress = null;
};
var funProgress = function(id_task) {
request("./internal/progress", {id_task: id_task, live_preview: false}, function(res) {
var fun = function(id_task, id_live_preview) {
request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
if (res.completed) {
removeProgressBar();
return;
}
var rect = progressbarContainer.getBoundingClientRect();
if (rect.width) {
divProgress.style.width = rect.width + "px";
}
let progressText = "";
divInner.style.width = ((res.progress || 0) * 100.0) + '%';
@@ -112,6 +119,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
progressText += " ETA: " + formatTime(res.eta);
}
setTitle(progressText);
if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
@@ -134,33 +142,16 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
return;
}
if (onProgress) {
onProgress(res);
}
setTimeout(() => {
funProgress(id_task, res.id_live_preview);
}, opts.live_preview_refresh_period || 500);
}, function() {
removeProgressBar();
});
};
var funLivePreview = function(id_task, id_live_preview) {
request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
if (!divProgress) {
return;
}
if (res.live_preview && gallery) {
rect = gallery.getBoundingClientRect();
if (rect.width) {
livePreview.style.width = rect.width + "px";
livePreview.style.height = rect.height + "px";
}
var img = new Image();
img.onload = function() {
if (!livePreview) {
livePreview = document.createElement('div');
livePreview.className = 'livePreview';
gallery.insertBefore(livePreview, gallery.firstElementChild);
}
livePreview.appendChild(img);
if (livePreview.childElementCount > 2) {
livePreview.removeChild(livePreview.firstElementChild);
@@ -169,18 +160,18 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
img.src = res.live_preview;
}
if (onProgress) {
onProgress(res);
}
setTimeout(() => {
funLivePreview(id_task, res.id_live_preview);
fun(id_task, res.id_live_preview);
}, opts.live_preview_refresh_period || 500);
}, function() {
removeProgressBar();
});
};
funProgress(id_task, 0);
if (gallery) {
funLivePreview(id_task, 0);
}
fun(id_task, 0);
}
-141
View File
@@ -1,141 +0,0 @@
(function() {
const GRADIO_MIN_WIDTH = 320;
const GRID_TEMPLATE_COLUMNS = '1fr 16px 1fr';
const PAD = 16;
const DEBOUNCE_TIME = 100;
const R = {
tracking: false,
parent: null,
parentWidth: null,
leftCol: null,
leftColStartWidth: null,
screenX: null,
};
let resizeTimer;
let parents = [];
function setLeftColGridTemplate(el, width) {
el.style.gridTemplateColumns = `${width}px 16px 1fr`;
}
function displayResizeHandle(parent) {
if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) {
parent.style.display = 'flex';
if (R.handle != null) {
R.handle.style.opacity = '0';
}
return false;
} else {
parent.style.display = 'grid';
if (R.handle != null) {
R.handle.style.opacity = '100';
}
return true;
}
}
function afterResize(parent) {
if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != GRID_TEMPLATE_COLUMNS) {
const oldParentWidth = R.parentWidth;
const newParentWidth = parent.offsetWidth;
const widthL = parseInt(parent.style.gridTemplateColumns.split(' ')[0]);
const ratio = newParentWidth / oldParentWidth;
const newWidthL = Math.max(Math.floor(ratio * widthL), GRADIO_MIN_WIDTH);
setLeftColGridTemplate(parent, newWidthL);
R.parentWidth = newParentWidth;
}
}
function setup(parent) {
const leftCol = parent.firstElementChild;
const rightCol = parent.lastElementChild;
parents.push(parent);
parent.style.display = 'grid';
parent.style.gap = '0';
parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
const resizeHandle = document.createElement('div');
resizeHandle.classList.add('resize-handle');
parent.insertBefore(resizeHandle, rightCol);
resizeHandle.addEventListener('mousedown', (evt) => {
if (evt.button !== 0) return;
evt.preventDefault();
evt.stopPropagation();
document.body.classList.add('resizing');
R.tracking = true;
R.parent = parent;
R.parentWidth = parent.offsetWidth;
R.handle = resizeHandle;
R.leftCol = leftCol;
R.leftColStartWidth = leftCol.offsetWidth;
R.screenX = evt.screenX;
});
resizeHandle.addEventListener('dblclick', (evt) => {
evt.preventDefault();
evt.stopPropagation();
parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
});
afterResize(parent);
}
window.addEventListener('mousemove', (evt) => {
if (evt.button !== 0) return;
if (R.tracking) {
evt.preventDefault();
evt.stopPropagation();
const delta = R.screenX - evt.screenX;
const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH);
setLeftColGridTemplate(R.parent, leftColWidth);
}
});
window.addEventListener('mouseup', (evt) => {
if (evt.button !== 0) return;
if (R.tracking) {
evt.preventDefault();
evt.stopPropagation();
R.tracking = false;
document.body.classList.remove('resizing');
}
});
window.addEventListener('resize', () => {
clearTimeout(resizeTimer);
resizeTimer = setTimeout(function() {
for (const parent of parents) {
afterResize(parent);
}
}, DEBOUNCE_TIME);
});
setupResizeHandle = setup;
})();
onUiLoaded(function() {
for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) {
if (!elem.querySelector('.resize-handle')) {
setupResizeHandle(elem);
}
}
});
+19 -2
View File
@@ -19,11 +19,28 @@ function all_gallery_buttons() {
}
function selected_gallery_button() {
return all_gallery_buttons().find(elem => elem.classList.contains('selected')) ?? null;
var allCurrentButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnail-item.thumbnail-small.selected');
var visibleCurrentButton = null;
allCurrentButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
visibleCurrentButton = elem;
}
});
return visibleCurrentButton;
}
function selected_gallery_index() {
return all_gallery_buttons().findIndex(elem => elem.classList.contains('selected'));
var buttons = all_gallery_buttons();
var button = selected_gallery_button();
var result = -1;
buttons.forEach(function(v, i) {
if (v == button) {
result = i;
}
});
return result;
}
function extract_image_from_gallery(gallery) {
-7
View File
@@ -25,13 +25,6 @@ start = launch_utils.start
def main():
if args.dump_sysinfo:
filename = launch_utils.dump_sysinfo()
print(f"Sysinfo saved as {filename}. Exiting...")
exit(0)
launch_utils.startup_timer.record("initial startup")
with launch_utils.startup_timer.subcategory("prepare environment"):
+5 -44
View File
@@ -4,8 +4,6 @@ import os
import time
import datetime
import uvicorn
import ipaddress
import requests
import gradio as gr
from threading import Lock
from io import BytesIO
@@ -25,7 +23,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_vae import vae_dict
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -57,41 +56,7 @@ def setUpscalers(req: dict):
return reqDict
def verify_url(url):
"""Returns True if the url refers to a global resource."""
import socket
from urllib.parse import urlparse
try:
parsed_url = urlparse(url)
domain_name = parsed_url.netloc
host = socket.gethostbyname_ex(domain_name)
for ip in host[2]:
ip_addr = ipaddress.ip_address(ip)
if not ip_addr.is_global:
return False
except Exception:
return False
return True
def decode_base64_to_image(encoding):
if encoding.startswith("http://") or encoding.startswith("https://"):
if not opts.api_enable_requests:
raise HTTPException(status_code=500, detail="Requests not allowed")
if opts.api_forbid_local_requests and not verify_url(encoding):
raise HTTPException(status_code=500, detail="Request to local resource not allowed")
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
response = requests.get(encoding, timeout=30, headers=headers)
try:
image = Image.open(BytesIO(response.content))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid image url") from e
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
@@ -365,7 +330,6 @@ class Api:
with self.queue_lock:
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
p.is_api = True
p.scripts = script_runner
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
@@ -426,7 +390,6 @@ class Api:
with self.queue_lock:
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
p.init_images = [decode_base64_to_image(x) for x in init_images]
p.is_api = True
p.scripts = script_runner
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples
@@ -570,7 +533,7 @@ class Api:
raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items():
shared.opts.set(k, v, is_api=True)
shared.opts.set(k, v)
shared.opts.save(shared.config_filename)
return
@@ -602,12 +565,10 @@ class Api:
]
def get_sd_models(self):
import modules.sd_models as sd_models
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
def get_sd_vaes(self):
import modules.sd_vae as sd_vae
return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]
return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
+4 -5
View File
@@ -50,12 +50,10 @@ class PydanticModelGenerator:
additional_fields = None,
):
def field_type_generator(k, v):
# field_type = str if not overrides.get(k) else overrides[k]["type"]
# print(k, v.annotation, v.default)
field_type = v.annotation
if field_type == 'Image':
# images are sent as base64 strings via API
field_type = 'str'
return Optional[field_type]
def merge_class_params(class_):
@@ -65,6 +63,7 @@ class PydanticModelGenerator:
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
return parameters
self._model_name = model_name
self._class_data = merge_class_params(class_instance)
@@ -73,7 +72,7 @@ class PydanticModelGenerator:
field=underscore(k),
field_alias=k,
field_type=field_type_generator(k, v),
field_value=None if isinstance(v.default, property) else v.default
field_value=v.default
)
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
+2 -6
View File
@@ -1,12 +1,11 @@
import json
import os
import os.path
import threading
import time
from modules.paths import data_path, script_path
cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json"))
cache_filename = os.path.join(data_path, "cache.json")
cache_data = None
cache_lock = threading.Lock()
@@ -30,12 +29,9 @@ def dump_cache():
time.sleep(1)
with cache_lock:
cache_filename_tmp = cache_filename + "-"
with open(cache_filename_tmp, "w", encoding="utf8") as file:
with open(cache_filename, "w", encoding="utf8") as file:
json.dump(cache_data, file, indent=4)
os.replace(cache_filename_tmp, cache_filename)
dump_cache_after = None
dump_cache_thread = None
+3 -2
View File
@@ -1,10 +1,11 @@
from functools import wraps
import html
import threading
import time
from modules import shared, progress, errors, devices, fifo_lock
from modules import shared, progress, errors, devices
queue_lock = fifo_lock.FIFOLock()
queue_lock = threading.Lock()
def wrap_queued_call(func):
+2 -5
View File
@@ -16,8 +16,6 @@ parser.add_argument("--test-server", action='store_true', help="launch.py argume
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
parser.add_argument("--dump-sysinfo", action='store_true', help="launch.py argument: dump limited sysinfo file (without information about extensions, options) to disk and quit")
parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
@@ -36,10 +34,9 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
@@ -83,7 +80,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication l
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
+7 -9
View File
@@ -8,12 +8,14 @@ import time
import tqdm
from datetime import datetime
from collections import OrderedDict
import git
from modules import shared, extensions, errors
from modules.paths_internal import script_path, config_states_dir
all_config_states = {}
all_config_states = OrderedDict()
def list_config_states():
@@ -26,14 +28,10 @@ def list_config_states():
for filename in os.listdir(config_states_dir):
if filename.endswith(".json"):
path = os.path.join(config_states_dir, filename)
try:
with open(path, "r", encoding="utf-8") as f:
j = json.load(f)
assert "created_at" in j, '"created_at" does not exist'
j["filepath"] = path
config_states.append(j)
except Exception as e:
print(f'[ERROR]: Config states {path}, {e}')
with open(path, "r", encoding="utf-8") as f:
j = json.load(f)
j["filepath"] = path
config_states.append(j)
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
+87 -2
View File
@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
from modules import errors, shared
from modules import errors, rng_philox
if sys.platform == "darwin":
from modules import mac_specific
@@ -17,6 +17,8 @@ def has_mps() -> bool:
def get_cuda_device_string():
from modules import shared
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@@ -38,6 +40,8 @@ def get_optimal_device():
def get_device_for(task):
from modules import shared
if task in shared.cmd_opts.use_cpu:
return cpu
@@ -92,7 +96,87 @@ def cond_cast_float(input):
nv_rng = None
def randn(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
from modules.shared import opts
manual_seed(seed)
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(shape), device=device)
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_local(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
from modules.shared import opts
if opts.randn_source == "NV":
rng = rng_philox.Generator(seed)
return torch.asarray(rng.randn(shape), device=device)
local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
local_generator = torch.Generator(local_device).manual_seed(int(seed))
return torch.randn(shape, device=local_device, generator=local_generator).to(device)
def randn_like(x):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""
from modules.shared import opts
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=cpu).to(x.device)
return torch.randn_like(x)
def randn_without_seed(shape):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""
from modules.shared import opts
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(shape), device=device)
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def manual_seed(seed):
"""Set up a global random number generator using the specified seed."""
from modules.shared import opts
if opts.randn_source == "NV":
global nv_rng
nv_rng = rng_philox.Generator(seed)
return
torch.manual_seed(seed)
def autocast(disable=False):
from modules import shared
if disable:
return contextlib.nullcontext()
@@ -111,6 +195,8 @@ class NansException(Exception):
def test_for_nans(x, where):
from modules import shared
if shared.cmd_opts.disable_nan_check:
return
@@ -150,4 +236,3 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
+1 -1
View File
@@ -95,7 +95,7 @@ def check_versions():
expected_torch_version = "2.0.0"
expected_xformers_version = "0.0.20"
expected_gradio_version = "3.41.2"
expected_gradio_version = "3.39.0"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
print_error_explanation(f"""
+3 -1
View File
@@ -1,7 +1,7 @@
import os
import threading
from modules import shared, errors, cache, scripts
from modules import shared, errors, cache
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
@@ -90,6 +90,8 @@ class Extension:
self.have_info_from_repo = True
def list_files(self, subdir, extension):
from modules import scripts
dirpath = os.path.join(self.path, subdir)
if not os.path.isdir(dirpath):
return []
+17 -43
View File
@@ -1,7 +1,6 @@
import json
import os
import re
import logging
from collections import defaultdict
from modules import errors
@@ -87,55 +86,27 @@ class ExtraNetwork:
raise NotImplementedError
def lookup_extra_networks(extra_network_data):
"""returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks.
Example input:
{
'lora': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>],
'lyco': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
'hypernet': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
}
Example output:
{
<extra_networks_lora.ExtraNetworkLora object at 0x0000020581BEECE0>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>, <modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
<modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x0000020581BEEE60>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
}
"""
res = {}
for extra_network_name, extra_network_args in list(extra_network_data.items()):
extra_network = extra_network_registry.get(extra_network_name, None)
alias = extra_network_aliases.get(extra_network_name, None)
if alias is not None and extra_network is None:
extra_network = alias
if extra_network is None:
logging.info(f"Skipping unknown extra network: {extra_network_name}")
continue
res.setdefault(extra_network, []).extend(extra_network_args)
return res
def activate(p, extra_network_data):
"""call activate for extra networks in extra_network_data in specified order, then call
activate for all remaining registered networks with an empty argument list"""
activated = []
for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items():
for extra_network_name, extra_network_args in extra_network_data.items():
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
extra_network = extra_network_aliases.get(extra_network_name, None)
if extra_network is None:
print(f"Skipping unknown extra network: {extra_network_name}")
continue
try:
extra_network.activate(p, extra_network_args)
activated.append(extra_network)
except Exception as e:
errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}")
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
for extra_network_name, extra_network in extra_network_registry.items():
if extra_network in activated:
@@ -154,16 +125,19 @@ def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
deactivate for all remaining registered networks"""
data = lookup_extra_networks(extra_network_data)
for extra_network_name in extra_network_data:
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
continue
for extra_network in data:
try:
extra_network.deactivate(p)
except Exception as e:
errors.display(e, f"deactivating extra network {extra_network.name}")
errors.display(e, f"deactivating extra network {extra_network_name}")
for extra_network_name, extra_network in extra_network_registry.items():
if extra_network in data:
args = extra_network_data.get(extra_network_name, None)
if args is not None:
continue
try:
-37
View File
@@ -1,37 +0,0 @@
import threading
import collections
# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
class FIFOLock(object):
def __init__(self):
self._lock = threading.Lock()
self._inner_lock = threading.Lock()
self._pending_threads = collections.deque()
def acquire(self, blocking=True):
with self._inner_lock:
lock_acquired = self._lock.acquire(False)
if lock_acquired:
return True
elif not blocking:
return False
release_event = threading.Event()
self._pending_threads.append(release_event)
release_event.wait()
return self._lock.acquire()
def release(self):
with self._inner_lock:
if self._pending_threads:
release_event = self._pending_threads.popleft()
release_event.set()
self._lock.release()
__enter__ = acquire
def __exit__(self, t, v, tb):
self.release()
+30 -19
View File
@@ -6,10 +6,10 @@ import re
import gradio as gr
from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks, processing
from modules import shared, ui_tempdir, script_callbacks
from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
@@ -32,7 +32,6 @@ class ParamBinding:
def reset():
paste_fields.clear()
registered_param_bindings.clear()
def quote(text):
@@ -199,6 +198,7 @@ def restore_old_hires_fix_params(res):
height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0:
from modules import processing
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
res['Size-1'] = firstpass_width
@@ -317,18 +317,36 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
infotext_to_setting_name_mapping = [
]
"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
Example content:
infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'),
('Model hash', 'sd_model_checkpoint'),
('ENSD', 'eta_noise_seed_delta'),
('Schedule type', 'k_sched_type'),
('Schedule max sigma', 'sigma_max'),
('Schedule min sigma', 'sigma_min'),
('Schedule rho', 'rho'),
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
('Sigma churn', 's_churn'),
('Sigma tmin', 's_tmin'),
('Sigma tmax', 's_tmax'),
('Sigma noise', 's_noise'),
('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
('UniPC variant', 'uni_pc_variant'),
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
('Token merging ratio', 'token_merging_ratio'),
('Token merging ratio hr', 'token_merging_ratio_hr'),
('RNG', 'randn_source'),
('NGMS', 's_min_uncond'),
('Pad conds', 'pad_cond_uncond'),
('VAE Encoder', 'sd_vae_encode_method'),
('VAE Decoder', 'sd_vae_decode_method'),
('Refiner', 'sd_refiner_checkpoint'),
('Refiner switch at', 'sd_refiner_switch_at'),
]
"""
def create_override_settings_dict(text_pairs):
@@ -349,8 +367,7 @@ def create_override_settings_dict(text_pairs):
params[k] = v.strip()
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
for param_name, setting_name in infotext_to_setting_name_mapping:
value = params.get(param_name, None)
if value is None:
@@ -399,16 +416,10 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
return res
if override_settings_component is not None:
already_handled_fields = {key: 1 for _, key in paste_fields}
def paste_settings(params):
vals = {}
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
if param_name in already_handled_fields:
continue
for param_name, setting_name in infotext_to_setting_name_mapping:
v = params.get(param_name, None)
if v is None:
continue
+7 -20
View File
@@ -1,7 +1,6 @@
import gradio as gr
from modules import scripts, ui_tempdir, patches
from modules import scripts
def add_classes_to_gradio_component(comp):
"""
@@ -41,8 +40,6 @@ def Block_get_config(self):
if webui_tooltip:
config["webui_tooltip"] = webui_tooltip
config.pop('example_inputs', None)
return config
@@ -54,20 +51,10 @@ def BlockContext_init(self, *args, **kwargs):
return res
def Blocks_get_config_file(self, *args, **kwargs):
config = original_Blocks_get_config_file(self, *args, **kwargs)
original_IOComponent_init = gr.components.IOComponent.__init__
original_Block_get_config = gr.blocks.Block.get_config
original_BlockContext_init = gr.blocks.BlockContext.__init__
for comp_config in config["components"]:
if "example_inputs" in comp_config:
comp_config["example_inputs"] = {"serialized": []}
return config
original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
ui_tempdir.install_ui_tempdir_override()
gr.components.IOComponent.__init__ = IOComponent_init
gr.blocks.Block.get_config = Block_get_config
gr.blocks.BlockContext.__init__ = BlockContext_init
+14 -34
View File
@@ -21,6 +21,8 @@ from modules import sd_samplers, shared, script_callbacks, errors
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
import modules.sd_vae as sd_vae
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -340,6 +342,16 @@ def sanitize_filename_part(text, replace_spaces=True):
class FilenameGenerator:
def get_vae_filename(self): #get the name of the VAE file.
if sd_vae.loaded_vae_file is None:
return "NoneType"
file_name = os.path.basename(sd_vae.loaded_vae_file)
split_file_name = file_name.split('.')
if len(split_file_name) > 1 and split_file_name[0] == '':
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
else:
return split_file_name[0]
replacements = {
'seed': lambda self: self.seed if self.seed is not None else '',
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
@@ -355,9 +367,7 @@ class FilenameGenerator:
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
'prompt_hash': lambda self, *args: self.string_hash(self.prompt, *args),
'negative_prompt_hash': lambda self, *args: self.string_hash(self.p.negative_prompt, *args),
'full_prompt_hash': lambda self, *args: self.string_hash(f"{self.p.prompt} {self.p.negative_prompt}", *args), # a space in between to create a unique string
'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
'prompt': lambda self: sanitize_filename_part(self.prompt),
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
@@ -370,8 +380,7 @@ class FilenameGenerator:
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(),
'none': lambda self: '', # Overrides the default, so you can get just the sequence number
'image_hash': lambda self, *args: self.image_hash(*args) # accepts formats: [image_hash<length>] default full hash
'none': lambda self: '', # Overrides the default so you can get just the sequence number
}
default_time_format = '%Y%m%d%H%M%S'
@@ -382,22 +391,6 @@ class FilenameGenerator:
self.image = image
self.zip = zip
def get_vae_filename(self):
"""Get the name of the VAE file."""
import modules.sd_vae as sd_vae
if sd_vae.loaded_vae_file is None:
return "NoneType"
file_name = os.path.basename(sd_vae.loaded_vae_file)
split_file_name = file_name.split('.')
if len(split_file_name) > 1 and split_file_name[0] == '':
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
else:
return split_file_name[0]
def hasprompt(self, *args):
lower = self.prompt.lower()
if self.p is None or self.prompt is None:
@@ -451,14 +444,6 @@ class FilenameGenerator:
return sanitize_filename_part(formatted_time, replace_spaces=False)
def image_hash(self, *args):
length = int(args[0]) if (args and args[0] != "") else None
return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length]
def string_hash(self, text, *args):
length = int(args[0]) if (args and args[0] != "") else 8
return hashlib.sha256(text.encode()).hexdigest()[0:length]
def apply(self, x):
res = ''
@@ -600,11 +585,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
"""
namegen = FilenameGenerator(p, seed, prompt, image)
# WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively. switch to PNG which has a much higher limit
if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp":
print('Image dimensions too large; saving as PNG')
extension = ".png"
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
+16 -6
View File
@@ -6,7 +6,7 @@ import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
import gradio as gr
from modules import images as imgutil
from modules import sd_samplers, images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
@@ -116,20 +116,21 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
process_images(p)
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
if mode == 0: # img2img
image = init_img
image = init_img.convert("RGB")
mask = None
elif mode == 1: # img2img sketch
image = sketch
image = sketch.convert("RGB")
mask = None
elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
mask = processing.create_binary_mask(mask)
mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
image = image.convert("RGB")
elif mode == 3: # inpaint sketch
image = inpaint_color_sketch
orig = inpaint_color_sketch_orig or inpaint_color_sketch
@@ -138,6 +139,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
blur = ImageFilter.GaussianBlur(mask_blur)
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
image = image.convert("RGB")
elif mode == 4: # inpaint upload mask
image = init_img_inpaint
mask = init_mask_inpaint
@@ -164,13 +166,21 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
prompt=prompt,
negative_prompt=negative_prompt,
styles=prompt_styles,
sampler_name=sampler_name,
seed=seed,
subseed=subseed,
subseed_strength=subseed_strength,
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
cfg_scale=cfg_scale,
width=width,
height=height,
restore_faces=restore_faces,
tiling=tiling,
init_images=[image],
mask=mask,
mask_blur=mask_blur,
-168
View File
@@ -1,168 +0,0 @@
import importlib
import logging
import sys
import warnings
from threading import Thread
from modules.timer import startup_timer
def imports():
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
import torch # noqa: F401
startup_timer.record("import torch")
import pytorch_lightning # noqa: F401
startup_timer.record("import torch")
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
import gradio # noqa: F401
startup_timer.record("import gradio")
from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer.record("setup paths")
import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
import sgm.modules.encoders.modules # noqa: F401
startup_timer.record("import sgm")
from modules import shared_init
shared_init.initialize()
startup_timer.record("initialize shared")
from modules import processing, gradio_extensons, ui # noqa: F401
startup_timer.record("other imports")
def check_versions():
from modules.shared_cmd_options import cmd_opts
if not cmd_opts.skip_version_check:
from modules import errors
errors.check_versions()
def initialize():
from modules import initialize_util
initialize_util.fix_torch_version()
initialize_util.fix_asyncio_event_loop_policy()
initialize_util.validate_tls_options()
initialize_util.configure_sigint_handler()
initialize_util.configure_opts_onchange()
from modules import modelloader
modelloader.cleanup_models()
from modules import sd_models
sd_models.setup_model()
startup_timer.record("setup SD model")
from modules.shared_cmd_options import cmd_opts
from modules import codeformer_model
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
codeformer_model.setup_model(cmd_opts.codeformer_models_path)
startup_timer.record("setup codeformer")
from modules import gfpgan_model
gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
startup_timer.record("setup gfpgan")
initialize_rest(reload_script_modules=False)
def initialize_rest(*, reload_script_modules=False):
"""
Called both from initialize() and when reloading the webui.
"""
from modules.shared_cmd_options import cmd_opts
from modules import sd_samplers
sd_samplers.set_samplers()
startup_timer.record("set samplers")
from modules import extensions
extensions.list_extensions()
startup_timer.record("list extensions")
from modules import initialize_util
initialize_util.restore_config_state_file()
startup_timer.record("restore config state file")
from modules import shared, upscaler, scripts
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
scripts.load_scripts()
return
from modules import sd_models
sd_models.list_models()
startup_timer.record("list SD models")
from modules import localization
localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list localizations")
with startup_timer.subcategory("load scripts"):
scripts.load_scripts()
if reload_script_modules:
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
startup_timer.record("reload script modules")
from modules import modelloader
modelloader.load_upscalers()
startup_timer.record("load upscalers")
from modules import sd_vae
sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
from modules import textual_inversion
textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
sd_hijack.list_optimizers()
startup_timer.record("scripts list_optimizers")
from modules import sd_unet
sd_unet.list_unets()
startup_timer.record("scripts list_unets")
def load_model():
"""
Accesses shared.sd_model property to load model.
After it's available, if it has been loaded before this access by some extension,
its optimization may be None because the list of optimizaers has neet been filled
by that time, so we apply optimization again.
"""
shared.sd_model # noqa: B018
if sd_hijack.current_optimizer is None:
sd_hijack.apply_optimizations()
from modules import devices
devices.first_time_calculation()
Thread(target=load_model).start()
from modules import shared_items
shared_items.reload_hypernetworks()
startup_timer.record("reload hypernetworks")
from modules import ui_extra_networks
ui_extra_networks.initialize()
ui_extra_networks.register_default_pages()
from modules import extra_networks
extra_networks.initialize()
extra_networks.register_default_extra_networks()
startup_timer.record("initialize extra networks")
-202
View File
@@ -1,202 +0,0 @@
import json
import os
import signal
import sys
import re
from modules.timer import startup_timer
def gradio_server_name():
from modules.shared_cmd_options import cmd_opts
if cmd_opts.server_name:
return cmd_opts.server_name
else:
return "0.0.0.0" if cmd_opts.listen else None
def fix_torch_version():
import torch
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
def fix_asyncio_event_loop_policy():
"""
The default `asyncio` event loop policy only automatically creates
event loops in the main threads. Other threads must create event
loops explicitly or `asyncio.get_event_loop` (and therefore
`.IOLoop.current`) will fail. Installing this policy allows event
loops to be created automatically on any thread, matching the
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
"""
import asyncio
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
# "Any thread" and "selector" should be orthogonal, but there's not a clean
# interface for composing policies so pick the right base.
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
else:
_BasePolicy = asyncio.DefaultEventLoopPolicy
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
"""Event loop policy that allows loop creation on any thread.
Usage::
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
"""
def get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
return super().get_event_loop()
except (RuntimeError, AssertionError):
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
# and changed to a RuntimeError in 3.4.3.
# "There is no current event loop in thread %r"
loop = self.new_event_loop()
self.set_event_loop(loop)
return loop
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
def restore_config_state_file():
from modules import shared, config_states
config_state_file = shared.opts.restore_config_state_file
if config_state_file == "":
return
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
if os.path.isfile(config_state_file):
print(f"*** About to restore extension state from file: {config_state_file}")
with open(config_state_file, "r", encoding="utf-8") as f:
config_state = json.load(f)
config_states.restore_extension_config(config_state)
startup_timer.record("restore extension config")
elif config_state_file:
print(f"!!! Config state backup not found: {config_state_file}")
def validate_tls_options():
from modules.shared_cmd_options import cmd_opts
if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
return
try:
if not os.path.exists(cmd_opts.tls_keyfile):
print("Invalid path to TLS keyfile given")
if not os.path.exists(cmd_opts.tls_certfile):
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
except TypeError:
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
print("TLS setup invalid, running webui without TLS")
else:
print("Running with TLS")
startup_timer.record("TLS")
def get_gradio_auth_creds():
"""
Convert the gradio_auth and gradio_auth_path commandline arguments into
an iterable of (username, password) tuples.
"""
from modules.shared_cmd_options import cmd_opts
def process_credential_line(s):
s = s.strip()
if not s:
return None
return tuple(s.split(':', 1))
if cmd_opts.gradio_auth:
for cred in cmd_opts.gradio_auth.split(','):
cred = process_credential_line(cred)
if cred:
yield cred
if cmd_opts.gradio_auth_path:
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
for cred in line.strip().split(','):
cred = process_credential_line(cred)
if cred:
yield cred
def dumpstacks():
import threading
import traceback
id2name = {th.ident: th.name for th in threading.enumerate()}
code = []
for threadId, stack in sys._current_frames().items():
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
for filename, lineno, name, line in traceback.extract_stack(stack):
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
if line:
code.append(" " + line.strip())
print("\n".join(code))
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
dumpstacks()
os._exit(0)
if not os.environ.get("COVERAGE_RUN"):
# Don't install the immediate-quit handler when running under coverage,
# as then the coverage report won't be generated.
signal.signal(signal.SIGINT, sigint_handler)
def configure_opts_onchange():
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
from modules.call_queue import wrap_queued_call
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
startup_timer.record("opts onchange")
def setup_middleware(app):
from starlette.middleware.gzip import GZipMiddleware
app.middleware_stack = None # reset current middleware to allow modifying user provided list
app.add_middleware(GZipMiddleware, minimum_size=1000)
configure_cors_middleware(app)
app.build_middleware_stack() # rebuild middleware stack on-the-fly
def configure_cors_middleware(app):
from starlette.middleware.cors import CORSMiddleware
from modules.shared_cmd_options import cmd_opts
cors_options = {
"allow_methods": ["*"],
"allow_headers": ["*"],
"allow_credentials": True,
}
if cmd_opts.cors_allow_origins:
cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
if cmd_opts.cors_allow_origins_regex:
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
app.add_middleware(CORSMiddleware, **cors_options)
+3 -2
View File
@@ -186,8 +186,9 @@ class InterrogateModels:
res = ""
shared.state.begin(job="interrogate")
try:
lowvram.send_everything_to_cpu()
devices.torch_gc()
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
devices.torch_gc()
self.load()
+27 -61
View File
@@ -1,9 +1,7 @@
# this scripts installs necessary requirements and launches main program in webui.py
import logging
import re
import subprocess
import os
import shutil
import sys
import importlib.util
import platform
@@ -13,10 +11,8 @@ from functools import lru_cache
from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir
from modules.timer import startup_timer
from modules import logging_config
args, _ = cmd_args.parser.parse_known_args()
logging_config.setup_logging(args.loglevel)
python = sys.executable
git = os.environ.get('GIT', "git")
@@ -143,25 +139,6 @@ def check_run_python(code: str) -> bool:
return result.returncode == 0
def git_fix_workspace(dir, name):
run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
return
def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
try:
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
except RuntimeError:
if not autofix:
raise
print(f"{errdesc}, attempting autofix...")
git_fix_workspace(dir, name)
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
def git_clone(url, dir, name, commithash=None):
# TODO clone into temporary dir and move if successful
@@ -169,24 +146,15 @@ def git_clone(url, dir, name, commithash=None):
if commithash is None:
return
current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
if current_hash == commithash:
return
if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
return
try:
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
except RuntimeError:
shutil.rmtree(dir, ignore_errors=True)
raise
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
if commithash is not None:
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
@@ -228,9 +196,7 @@ def run_extension_installer(extension_dir):
env = os.environ.copy()
env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}"
stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env).strip()
if stdout:
print(stdout)
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
except Exception as e:
errors.report(str(e))
@@ -248,7 +214,7 @@ def list_extensions(settings_file):
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
if disable_all_extensions != 'none':
return []
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
@@ -260,8 +226,6 @@ def run_extensions_installers(settings_file):
with startup_timer.subcategory("run extensions installers"):
for dirname_extension in list_extensions(settings_file):
logging.debug(f"Installing {dirname_extension}")
path = os.path.join(extensions_dir, dirname_extension)
if os.path.isdir(path):
@@ -313,6 +277,7 @@ def prepare_environment():
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
@@ -323,13 +288,13 @@ def prepare_environment():
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try:
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
# the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
os.remove(os.path.join(script_path, "tmp", "restart"))
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
except OSError:
@@ -359,6 +324,11 @@ def prepare_environment():
)
startup_timer.record("torch GPU test")
if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan")
startup_timer.record("install gfpgan")
if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip")
startup_timer.record("install clip")
@@ -368,7 +338,17 @@ def prepare_environment():
startup_timer.record("install open_clip")
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
if platform.system() == "Windows":
if platform.python_version().startswith("3.10"):
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
else:
print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
if not is_installed("xformers"):
exit(0)
elif platform.system() == "Linux":
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
startup_timer.record("install xformers")
if not is_installed("ngrok") and args.ngrok:
@@ -396,8 +376,7 @@ def prepare_environment():
run_pip(f"install -r \"{requirements_file}\"", "requirements")
startup_timer.record("install requirements")
if not args.skip_install:
run_extensions_installers(settings_file=args.ui_settings_file)
run_extensions_installers(settings_file=args.ui_settings_file)
if args.update_check:
version_check(commit)
@@ -434,16 +413,3 @@ def start():
webui.api_only()
else:
webui.webui()
def dump_sysinfo():
from modules import sysinfo
import datetime
text = sysinfo.get()
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
with open(filename, "w", encoding="utf8") as file:
file.write(text)
return filename
+2 -1
View File
@@ -1,7 +1,7 @@
import json
import os
from modules import errors, scripts
from modules import errors
localizations = {}
@@ -16,6 +16,7 @@ def list_localizations(dirname):
localizations[fn] = os.path.join(dirname, file)
from modules import scripts
for file in scripts.list_scripts("localizations", ".json"):
fn, ext = os.path.splitext(file.filename)
localizations[fn] = file.path
-16
View File
@@ -1,16 +0,0 @@
import os
import logging
def setup_logging(loglevel):
if loglevel is None:
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
if loglevel:
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
logging.basicConfig(
level=log_level,
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
+2 -16
View File
@@ -1,5 +1,5 @@
import torch
from modules import devices, shared
from modules import devices
module_in_gpu = None
cpu = torch.device("cpu")
@@ -14,20 +14,6 @@ def send_everything_to_cpu():
module_in_gpu = None
def is_needed(sd_model):
return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
def apply(sd_model):
enable = is_needed(sd_model)
shared.parallel_processing_allowed = not enable
if enable:
setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
else:
sd_model.lowvram = False
def setup_for_low_vram(sd_model, use_medvram):
if getattr(sd_model, 'lowvram', False):
return
@@ -144,4 +130,4 @@ def setup_for_low_vram(sd_model, use_medvram):
def is_enabled(sd_model):
return sd_model.lowvram
return getattr(sd_model, 'lowvram', False)
+5 -2
View File
@@ -4,7 +4,6 @@ import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
from modules import shared
log = logging.getLogger(__name__)
@@ -31,7 +30,8 @@ has_mps = check_for_mps()
def torch_mps_gc() -> None:
try:
if shared.state.current_latent is not None:
from modules.shared import state
if state.current_latent is not None:
log.debug("`current_latent` is set, skipping MPS garbage collection")
return
from torch.mps import empty_cache
@@ -52,6 +52,9 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
if has_mps:
# MPS fix for randn in torchsde
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
-245
View File
@@ -1,245 +0,0 @@
import json
import sys
import gradio as gr
from modules import errors
from modules.shared_cmd_options import cmd_opts
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
self.section = section
self.refresh = refresh
self.do_not_save = False
self.comment_before = comment_before
"""HTML text that will be added after label in UI"""
self.comment_after = comment_after
"""HTML text that will be added before label in UI"""
self.infotext = infotext
self.restrict_api = restrict_api
"""If True, the setting will not be accessible via API"""
def link(self, label, url):
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
return self
def js(self, label, js_func):
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
return self
def info(self, info):
self.comment_after += f"<span class='info'>({info})</span>"
return self
def html(self, html):
self.comment_after += html
return self
def needs_restart(self):
self.comment_after += " <span class='info'>(requires restart)</span>"
return self
def needs_reload_ui(self):
self.comment_after += " <span class='info'>(requires Reload UI)</span>"
return self
class OptionHTML(OptionInfo):
def __init__(self, text):
super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
self.do_not_save = True
def options_section(section_identifier, options_dict):
for v in options_dict.values():
v.section = section_identifier
return options_dict
options_builtin_fields = {"data_labels", "data", "restricted_opts", "typemap"}
class Options:
typemap = {int: float}
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
self.data_labels = data_labels
self.data = {k: v.default for k, v in self.data_labels.items()}
self.restricted_opts = restricted_opts
def __setattr__(self, key, value):
if key in options_builtin_fields:
return super(Options, self).__setattr__(key, value)
if self.data is not None:
if key in self.data or key in self.data_labels:
assert not cmd_opts.freeze_settings, "changing settings is disabled"
info = self.data_labels.get(key, None)
if info.do_not_save:
return
comp_args = info.component_args if info else None
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
raise RuntimeError(f"not possible to set {key} because it is restricted")
if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
raise RuntimeError(f"not possible to set {key} because it is restricted")
self.data[key] = value
return
return super(Options, self).__setattr__(key, value)
def __getattr__(self, item):
if item in options_builtin_fields:
return super(Options, self).__getattribute__(item)
if self.data is not None:
if item in self.data:
return self.data[item]
if item in self.data_labels:
return self.data_labels[item].default
return super(Options, self).__getattribute__(item)
def set(self, key, value, is_api=False, run_callbacks=True):
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
oldval = self.data.get(key, None)
if oldval == value:
return False
option = self.data_labels[key]
if option.do_not_save:
return False
if is_api and option.restrict_api:
return False
try:
setattr(self, key, value)
except RuntimeError:
return False
if run_callbacks and option.onchange is not None:
try:
option.onchange()
except Exception as e:
errors.display(e, f"changing setting {key} to {value}")
setattr(self, key, oldval)
return False
return True
def get_default(self, key):
"""returns the default value for the key"""
data_label = self.data_labels.get(key)
if data_label is None:
return None
return data_label.default
def save(self, filename):
assert not cmd_opts.freeze_settings, "saving settings is disabled"
with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file, indent=4)
def same_type(self, x, y):
if x is None or y is None:
return True
type_x = self.typemap.get(type(x), type(x))
type_y = self.typemap.get(type(y), type(y))
return type_x == type_y
def load(self, filename):
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
# 1.6.0 VAE defaults
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
# 1.1.1 quicksettings list migration
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
# 1.4.0 ui_reorder
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
bad_settings = 0
for k, v in self.data.items():
info = self.data_labels.get(k, None)
if info is not None and not self.same_type(info.default, v):
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
bad_settings += 1
if bad_settings > 0:
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
def onchange(self, key, func, call=True):
item = self.data_labels.get(key)
item.onchange = func
if call:
func()
def dumpjson(self):
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
return json.dumps(d)
def add_option(self, key, info):
self.data_labels[key] = info
def reorder(self):
"""reorder settings so that all items related to section always go together"""
section_ids = {}
settings_items = self.data_labels.items()
for _, item in settings_items:
if item.section not in section_ids:
section_ids[item.section] = len(section_ids)
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
"""
if value is None:
return None
default_value = self.data_labels[key].default
if default_value is None:
default_value = getattr(self, key, None)
if default_value is None:
return None
expected_type = type(default_value)
if expected_type == bool and value == "False":
value = False
else:
value = expected_type(value)
return value
-64
View File
@@ -1,64 +0,0 @@
from collections import defaultdict
def patch(key, obj, field, replacement):
"""Replaces a function in a module or a class.
Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
Arguments:
key: identifying information for who is doing the replacement. You can use __name__.
obj: the module or the class
field: name of the function as a string
replacement: the new function
Returns:
the original function
"""
patch_key = (obj, field)
if patch_key in originals[key]:
raise RuntimeError(f"patch for {field} is already applied")
original_func = getattr(obj, field)
originals[key][patch_key] = original_func
setattr(obj, field, replacement)
return original_func
def undo(key, obj, field):
"""Undoes the peplacement by the patch().
If the function is not replaced, raises an exception.
Arguments:
key: identifying information for who is doing the replacement. You can use __name__.
obj: the module or the class
field: name of the function as a string
Returns:
Always None
"""
patch_key = (obj, field)
if patch_key not in originals[key]:
raise RuntimeError(f"there is no patch for {field} to undo")
original_func = originals[key].pop(patch_key)
setattr(obj, field, original_func)
return None
def original(key, obj, field):
"""Returns the original function for the patch created by the patch() function"""
patch_key = (obj, field)
return originals[key].get(patch_key, None)
originals = defaultdict(dict)
+31 -30
View File
@@ -11,32 +11,37 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
shared.state.begin(job="extras")
image_data = []
image_names = []
outputs = []
def get_images(extras_mode, image, image_folder, input_dir):
if extras_mode == 1:
for img in image_folder:
if isinstance(img, Image.Image):
image = img
fn = ''
else:
image = Image.open(os.path.abspath(img.name))
fn = os.path.splitext(img.orig_name)[0]
yield image, fn
elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
assert input_dir, 'input directory not selected'
if extras_mode == 1:
for img in image_folder:
if isinstance(img, Image.Image):
image = img
fn = ''
else:
image = Image.open(os.path.abspath(img.name))
fn = os.path.splitext(img.orig_name)[0]
image_data.append(image)
image_names.append(fn)
elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
assert input_dir, 'input directory not selected'
image_list = shared.listfiles(input_dir)
for filename in image_list:
try:
image = Image.open(filename)
except Exception:
continue
yield image, filename
else:
assert image, 'image not selected'
yield image, None
image_list = shared.listfiles(input_dir)
for filename in image_list:
try:
image = Image.open(filename)
except Exception:
continue
image_data.append(image)
image_names.append(filename)
else:
assert image, 'image not selected'
image_data.append(image)
image_names.append(None)
if extras_mode == 2 and output_dir != '':
outpath = output_dir
@@ -45,16 +50,14 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
infotext = ''
for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
image_data: Image.Image
for image, name in zip(image_data, image_names):
shared.state.textinfo = name
parameters, existing_pnginfo = images.read_info_from_image(image_data)
parameters, existing_pnginfo = images.read_info_from_image(image)
if parameters:
existing_pnginfo["parameters"] = parameters
pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
scripts.scripts_postproc.run(pp, args)
@@ -75,8 +78,6 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
if extras_mode != 2 or show_extras_results:
outputs.append(pp.image)
image_data.close()
devices.torch_gc()
return outputs, ui_common.plaintext_to_html(infotext), ''
+377 -412
View File
File diff suppressed because it is too large Load Diff
-49
View File
@@ -1,49 +0,0 @@
import gradio as gr
from modules import scripts, sd_models
from modules.ui_common import create_refresh_button
from modules.ui_components import InputAccordion
class ScriptRefiner(scripts.ScriptBuiltinUI):
section = "accordions"
create_group = False
def __init__(self):
pass
def title(self):
return "Refiner"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
with gr.Row():
refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
def lookup_checkpoint(title):
info = sd_models.get_closet_checkpoint_match(title)
return None if info is None else info.title
self.infotext_fields = [
(enable_refiner, lambda d: 'Refiner' in d),
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
(refiner_switch_at, 'Refiner switch at'),
]
return enable_refiner, refiner_checkpoint, refiner_switch_at
def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
# the actual implementation is in sd_samplers_common.py, apply_refiner
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
p.refiner_checkpoint = None
p.refiner_switch_at = None
else:
p.refiner_checkpoint = refiner_checkpoint
p.refiner_switch_at = refiner_switch_at
-111
View File
@@ -1,111 +0,0 @@
import json
import gradio as gr
from modules import scripts, ui, errors
from modules.shared import cmd_opts
from modules.ui_components import ToolButton
class ScriptSeed(scripts.ScriptBuiltinUI):
section = "seed"
create_group = False
def __init__(self):
self.seed = None
self.reuse_seed = None
self.reuse_subseed = None
def title(self):
return "Seed"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
with gr.Row(elem_id=self.elem_id("seed_row")):
if cmd_opts.use_textbox_seed:
self.seed = gr.Textbox(label='Seed', value="", elem_id=self.elem_id("seed"), min_width=100)
else:
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
with gr.Group(visible=False, elem_id=self.elem_id("seed_extras")) as seed_extras:
with gr.Row(elem_id=self.elem_id("subseed_row")):
subseed = gr.Number(label='Variation seed', value=-1, elem_id=self.elem_id("subseed"), precision=0)
random_subseed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_subseed"))
reuse_subseed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_subseed"))
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=self.elem_id("subseed_strength"))
with gr.Row(elem_id=self.elem_id("seed_resize_from_row")):
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=self.elem_id("seed_resize_from_w"))
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=self.elem_id("seed_resize_from_h"))
random_seed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("seed") + "')}", show_progress=False, inputs=[], outputs=[])
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("subseed") + "')}", show_progress=False, inputs=[], outputs=[])
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
self.infotext_fields = [
(self.seed, "Seed"),
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
(subseed, "Variation seed"),
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
]
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
self.on_after_component(lambda x: connect_reuse_seed(subseed, reuse_subseed, x.component, True), elem_id=f'generation_info_{self.tabname}')
return self.seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h
def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h):
p.seed = seed
if seed_checkbox and subseed_strength > 0:
p.subseed = subseed
p.subseed_strength = subseed_strength
if seed_checkbox and seed_resize_from_w > 0 and seed_resize_from_h > 0:
p.seed_resize_from_w = seed_resize_from_w
p.seed_resize_from_h = seed_resize_from_h
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
def copy_seed(gen_info_string: str, index):
res = -1
try:
gen_info = json.loads(gen_info_string)
index -= gen_info.get('index_of_first_image', 0)
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
all_subseeds = gen_info.get('all_subseeds', [-1])
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
else:
all_seeds = gen_info.get('all_seeds', [-1])
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
except json.decoder.JSONDecodeError:
if gen_info_string:
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
return [res, gr.update()]
reuse_seed.click(
fn=copy_seed,
_js="(x, y) => [x, selected_gallery_index()]",
show_progress=False,
inputs=[generation_info, seed],
outputs=[seed, seed]
)
+22 -27
View File
@@ -48,7 +48,6 @@ def add_task_to_queue(id_job):
class ProgressRequest(BaseModel):
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
live_preview: bool = Field(default=True, title="Include live preview", description="boolean flag indicating whether to include the live preview image")
class ProgressResponse(BaseModel):
@@ -72,12 +71,7 @@ def progressapi(req: ProgressRequest):
completed = req.id_task in finished_tasks
if not active:
textinfo = "Waiting..."
if queued:
sorted_queued = sorted(pending_tasks.keys(), key=lambda x: pending_tasks[x])
queue_index = sorted_queued.index(req.id_task)
textinfo = "In queue: {}/{}".format(queue_index + 1, len(sorted_queued))
return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo=textinfo)
return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
progress = 0
@@ -95,30 +89,31 @@ def progressapi(req: ProgressRequest):
predicted_duration = elapsed_since_start / progress if progress > 0 else None
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
live_preview = None
id_live_preview = req.id_live_preview
shared.state.set_current_image()
if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
image = shared.state.current_image
if image is not None:
buffered = io.BytesIO()
if opts.live_previews_enable and req.live_preview:
shared.state.set_current_image()
if shared.state.id_live_preview != req.id_live_preview:
image = shared.state.current_image
if image is not None:
buffered = io.BytesIO()
if opts.live_previews_image_format == "png":
# using optimize for large images takes an enormous amount of time
if max(*image.size) <= 256:
save_kwargs = {"optimize": True}
else:
save_kwargs = {"optimize": False, "compress_level": 1}
if opts.live_previews_image_format == "png":
# using optimize for large images takes an enormous amount of time
if max(*image.size) <= 256:
save_kwargs = {"optimize": True}
else:
save_kwargs = {}
save_kwargs = {"optimize": False, "compress_level": 1}
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
id_live_preview = shared.state.id_live_preview
else:
save_kwargs = {}
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
id_live_preview = shared.state.id_live_preview
else:
live_preview = None
else:
live_preview = None
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
+11 -32
View File
@@ -26,7 +26,7 @@ plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
def get_learned_conditioning_prompt_schedules(prompts, steps):
"""
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
>>> g("test")
@@ -57,39 +57,18 @@ def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=N
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
>>> g("[fe|||]male")
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
>>> g("a [b:.5] c")
[[10, 'a b c']]
>>> g("a [b:1.5] c")
[[5, 'a c'], [10, 'a b c']]
"""
if hires_steps is None or use_old_scheduling:
int_offset = 0
flt_offset = 0
steps = base_steps
else:
int_offset = base_steps
flt_offset = 1.0
steps = hires_steps
def collect_steps(steps, tree):
res = [steps]
class CollectSteps(lark.Visitor):
def scheduled(self, tree):
s = tree.children[-2]
v = float(s)
if use_old_scheduling:
v = v*steps if v<1 else v
else:
if "." in s:
v = (v - flt_offset) * steps
else:
v = (v - int_offset)
tree.children[-2] = min(steps, int(v))
if tree.children[-2] >= 1:
res.append(tree.children[-2])
tree.children[-2] = float(tree.children[-2])
if tree.children[-2] < 1:
tree.children[-2] *= steps
tree.children[-2] = min(steps, int(tree.children[-2]))
res.append(tree.children[-2])
def alternate(self, tree):
res.extend(range(1, steps+1))
@@ -107,7 +86,7 @@ def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=N
yield args[(step - 1) % len(args)]
def start(self, args):
def flatten(x):
if isinstance(x, str):
if type(x) == str:
yield x
else:
for gen in x:
@@ -155,7 +134,7 @@ class SdConditioning(list):
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@@ -175,7 +154,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps,
"""
res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
cache = {}
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
@@ -250,7 +229,7 @@ class MulticondLearnedConditioning:
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
@@ -259,7 +238,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
res = []
for indexes in res_indexes:
-1
View File
@@ -55,7 +55,6 @@ class UpscalerRealESRGAN(Upscaler):
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
device=self.device,
)
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
-170
View File
@@ -1,170 +0,0 @@
import torch
from modules import devices, rng_philox, shared
def randn(seed, shape, generator=None):
"""Generate a tensor with random numbers from a normal distribution using seed.
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
manual_seed(seed)
if shared.opts.randn_source == "NV":
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
return torch.randn(shape, device=devices.device, generator=generator)
def randn_local(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
if shared.opts.randn_source == "NV":
rng = rng_philox.Generator(seed)
return torch.asarray(rng.randn(shape), device=devices.device)
local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
local_generator = torch.Generator(local_device).manual_seed(int(seed))
return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)
def randn_like(x):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""
if shared.opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
if shared.opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
return torch.randn_like(x)
def randn_without_seed(shape, generator=None):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""
if shared.opts.randn_source == "NV":
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
return torch.randn(shape, device=devices.device, generator=generator)
def manual_seed(seed):
"""Set up a global random number generator using the specified seed."""
if shared.opts.randn_source == "NV":
global nv_rng
nv_rng = rng_philox.Generator(seed)
return
torch.manual_seed(seed)
def create_generator(seed):
if shared.opts.randn_source == "NV":
return rng_philox.Generator(seed)
device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
generator = torch.Generator(device).manual_seed(int(seed))
return generator
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
low_norm = low/torch.norm(low, dim=1, keepdim=True)
high_norm = high/torch.norm(high, dim=1, keepdim=True)
dot = (low_norm*high_norm).sum(1)
if dot.mean() > 0.9995:
return low * val + high * (1 - val)
omega = torch.acos(dot)
so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
return res
class ImageRNG:
def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
self.shape = tuple(map(int, shape))
self.seeds = seeds
self.subseeds = subseeds
self.subseed_strength = subseed_strength
self.seed_resize_from_h = seed_resize_from_h
self.seed_resize_from_w = seed_resize_from_w
self.generators = [create_generator(seed) for seed in seeds]
self.is_first = True
def first(self):
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
xs = []
for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
subnoise = None
if self.subseeds is not None and self.subseed_strength != 0:
subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
subnoise = randn(subseed, noise_shape)
if noise_shape != self.shape:
noise = randn(seed, noise_shape)
else:
noise = randn(seed, self.shape, generator=generator)
if subnoise is not None:
noise = slerp(self.subseed_strength, noise, subnoise)
if noise_shape != self.shape:
x = randn(seed, self.shape, generator=generator)
dx = (self.shape[2] - noise_shape[2]) // 2
dy = (self.shape[1] - noise_shape[1]) // 2
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
tx = 0 if dx < 0 else dx
ty = 0 if dy < 0 else dy
dx = max(-dx, 0)
dy = max(-dy, 0)
x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
noise = x
xs.append(noise)
eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
if eta_noise_seed_delta:
self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]
return torch.stack(xs).to(shared.device)
def next(self):
if self.is_first:
self.is_first = False
return self.first()
xs = []
for generator in self.generators:
x = randn_without_seed(self.shape, generator=generator)
xs.append(x)
return torch.stack(xs).to(shared.device)
devices.randn = randn
devices.randn_local = randn_local
devices.randn_like = randn_like
devices.randn_without_seed = randn_without_seed
devices.manual_seed = manual_seed
-29
View File
@@ -28,18 +28,6 @@ class ImageSaveParams:
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
class ExtraNoiseParams:
def __init__(self, noise, x, xi):
self.noise = noise
"""Random noise generated by the seed"""
self.x = x
"""Latent representation of the image"""
self.xi = xi
"""Noisy latent representation of the image"""
class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x
@@ -112,7 +100,6 @@ callback_map = dict(
callbacks_ui_settings=[],
callbacks_before_image_saved=[],
callbacks_image_saved=[],
callbacks_extra_noise=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
callbacks_cfg_after_cfg=[],
@@ -202,14 +189,6 @@ def image_saved_callback(params: ImageSaveParams):
report_exception(c, 'image_saved_callback')
def extra_noise_callback(params: ExtraNoiseParams):
for c in callback_map['callbacks_extra_noise']:
try:
c.callback(params)
except Exception:
report_exception(c, 'callbacks_extra_noise')
def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callback_map['callbacks_cfg_denoiser']:
try:
@@ -388,14 +367,6 @@ def on_image_saved(callback):
add_callback(callback_map['callbacks_image_saved'], callback)
def on_extra_noise(callback):
"""register a function to be called before adding extra noise in img2img or hires fix;
The callback is called with one argument:
- params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
"""
add_callback(callback_map['callbacks_extra_noise'], callback)
def on_cfg_denoiser(callback):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument:
+11 -136
View File
@@ -3,7 +3,6 @@ import re
import sys
import inspect
from collections import namedtuple
from dataclasses import dataclass
import gradio as gr
@@ -22,11 +21,6 @@ class PostprocessBatchListArgs:
self.images = images
@dataclass
class OnComponent:
component: gr.blocks.Block
class Script:
name = None
"""script's internal name derived from title"""
@@ -41,13 +35,9 @@ class Script:
is_txt2img = False
is_img2img = False
tabname = None
group = None
"""A gr.Group component that has all script's UI inside it."""
create_group = True
"""If False, for alwayson scripts, a group component will not be created."""
"""A gr.Group component that has all script's UI inside it"""
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
@@ -62,15 +52,6 @@ class Script:
api_info = None
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
on_before_component_elem_id = None
"""list of callbacks to be called before a component with an elem_id is created"""
on_after_component_elem_id = None
"""list of callbacks to be called after a component with an elem_id is created"""
setup_for_ui_only = False
"""If true, the script setup will only be run in Gradio UI, not in API"""
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -109,16 +90,9 @@ class Script:
pass
def setup(self, p, *args):
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
args contains all values returned by components from ui().
"""
pass
def before_process(self, p, *args):
"""
This function is called very early during processing begins for AlwaysVisible scripts.
This function is called very early before processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
args contains all values returned by components from ui()
"""
@@ -238,29 +212,6 @@ class Script:
pass
def on_before_component(self, callback, *, elem_id):
"""
Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.
May be called in show() or ui() - but it may be too late in latter as some components may already be created.
This function is an alternative to before_component in that it also cllows to run before a component is created, but
it doesn't require to be called for every created component - just for the one you need.
"""
if self.on_before_component_elem_id is None:
self.on_before_component_elem_id = []
self.on_before_component_elem_id.append((elem_id, callback))
def on_after_component(self, callback, *, elem_id):
"""
Calls callback after a component is created. The callback function is called with a single argument of type OnComponent.
"""
if self.on_after_component_elem_id is None:
self.on_after_component_elem_id = []
self.on_after_component_elem_id.append((elem_id, callback))
def describe(self):
"""unused"""
return ""
@@ -269,7 +220,7 @@ class Script:
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False)
tabkind = 'img2img' if self.is_img2img else 'txt2img'
tabkind = 'img2img' if self.is_img2img else 'txt2txt'
tabname = f"{tabkind}_" if need_tabname else ""
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
@@ -281,19 +232,6 @@ class Script:
"""
pass
class ScriptBuiltinUI(Script):
setup_for_ui_only = True
def elem_id(self, item_id):
"""helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False)
tabname = ('img2img' if self.is_img2img else 'txt2img') + "_" if need_tabname else ""
return f'{tabname}{item_id}'
current_basedir = paths.script_path
@@ -312,7 +250,7 @@ postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def list_scripts(scriptdirname, extension, *, include_extensions=True):
def list_scripts(scriptdirname, extension):
scripts_list = []
basedir = os.path.join(paths.script_path, scriptdirname)
@@ -320,9 +258,8 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True):
for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
if include_extensions:
for ext in extensions.active():
scripts_list += ext.list_files(scriptdirname, extension)
for ext in extensions.active():
scripts_list += ext.list_files(scriptdirname, extension)
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
@@ -351,7 +288,7 @@ def load_scripts():
postprocessing_scripts_data.clear()
script_callbacks.clear_callbacks()
scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)
scripts_list = list_scripts("scripts", ".py")
syspath = sys.path
@@ -412,17 +349,10 @@ class ScriptRunner:
self.selectable_scripts = []
self.alwayson_scripts = []
self.titles = []
self.title_map = {}
self.infotext_fields = []
self.paste_field_names = []
self.inputs = [None]
self.on_before_component_elem_id = {}
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
self.on_after_component_elem_id = {}
"""dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""
def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
@@ -437,7 +367,6 @@ class ScriptRunner:
script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
script.tabname = "img2img" if is_img2img else "txt2img"
visibility = script.show(script.is_img2img)
@@ -450,28 +379,6 @@ class ScriptRunner:
self.scripts.append(script)
self.selectable_scripts.append(script)
self.apply_on_before_component_callbacks()
def apply_on_before_component_callbacks(self):
for script in self.scripts:
on_before = script.on_before_component_elem_id or []
on_after = script.on_after_component_elem_id or []
for elem_id, callback in on_before:
if elem_id not in self.on_before_component_elem_id:
self.on_before_component_elem_id[elem_id] = []
self.on_before_component_elem_id[elem_id].append((callback, script))
for elem_id, callback in on_after:
if elem_id not in self.on_after_component_elem_id:
self.on_after_component_elem_id[elem_id] = []
self.on_after_component_elem_id[elem_id].append((callback, script))
on_before.clear()
on_after.clear()
def create_script_ui(self, script):
import modules.api.models as api_models
@@ -522,20 +429,15 @@ class ScriptRunner:
if script.alwayson and script.section != section:
continue
if script.create_group:
with gr.Group(visible=script.alwayson) as group:
self.create_script_ui(script)
script.group = group
else:
with gr.Group(visible=script.alwayson) as group:
self.create_script_ui(script)
script.group = group
def prepare_ui(self):
self.inputs = [None]
def setup_ui(self):
all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts]
self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)}
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
self.setup_ui_for_section(None)
@@ -582,8 +484,6 @@ class ScriptRunner:
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
self.apply_on_before_component_callbacks()
return self.inputs
def run(self, p, *args):
@@ -677,12 +577,6 @@ class ScriptRunner:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try:
callback(OnComponent(component=component))
except Exception:
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
for script in self.scripts:
try:
script.before_component(component, **kwargs)
@@ -690,21 +584,12 @@ class ScriptRunner:
errors.report(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
for callback, script in self.on_after_component_elem_id.get(component.elem_id, []):
try:
callback(OnComponent(component=component))
except Exception:
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
def script(self, title):
return self.title_map.get(title.lower())
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
args_from = script.args_from
@@ -723,6 +608,7 @@ class ScriptRunner:
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
def before_hr(self, p):
for script in self.alwayson_scripts:
try:
@@ -731,17 +617,6 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
def setup_scrips(self, p, *, is_ui=True):
for script in self.alwayson_scripts:
if not is_ui and script.setup_for_ui_only:
continue
try:
script_args = p.script_args[script.args_from:script.args_to]
script.setup(p, *script_args)
except Exception:
errors.report(f"Error running setup: {script.filename}", exc_info=True)
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
+10 -53
View File
@@ -155,16 +155,10 @@ class LoadStateDictOnMeta(ReplaceHelper):
```
"""
def __init__(self, state_dict, device, weight_dtype_conversion=None):
def __init__(self, state_dict, device):
super().__init__()
self.state_dict = state_dict
self.device = device
self.weight_dtype_conversion = weight_dtype_conversion or {}
self.default_dtype = self.weight_dtype_conversion.get('')
def get_weight_dtype(self, key):
key_first_term, _ = key.split('.', 1)
return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
@@ -173,60 +167,23 @@ class LoadStateDictOnMeta(ReplaceHelper):
sd = self.state_dict
device = self.device
def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
used_param_keys = []
for name, param in module._parameters.items():
if param is None:
continue
key = prefix + name
sd_param = sd.pop(key, None)
if sd_param is not None:
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
used_param_keys.append(key)
def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
for name, param in params:
if param.is_meta:
dtype = sd_param.dtype if sd_param is not None else param.dtype
module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
for name in module._buffers:
original(self, state_dict, prefix, *args, **kwargs)
for name, _ in params:
key = prefix + name
if key in sd:
del sd[key]
sd_param = sd.pop(key, None)
if sd_param is not None:
state_dict[key] = sd_param
used_param_keys.append(key)
original(module, state_dict, prefix, *args, **kwargs)
for key in used_param_keys:
state_dict.pop(key, None)
def load_state_dict(original, module, state_dict, strict=True):
"""torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
the function and does not call the original) the state dict will just fail to load because weights
would be on the meta device.
"""
if state_dict == sd:
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
original(module, state_dict, strict=strict)
module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()
+4 -16
View File
@@ -5,7 +5,7 @@ from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -34,6 +34,8 @@ ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print
sd_hijack_inpainting.do_inpainting_hijack()
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
@@ -245,21 +247,7 @@ class StableDiffusionModelHijack:
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
for i in range(len(conditioner.embedders)):
embedder = conditioner.embedders[i]
if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
conditioner.embedders[i] = embedder.wrapped
if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
conditioner.embedders[i] = embedder.wrapped
if hasattr(m, 'cond_stage_model'):
delattr(m, 'cond_stage_model')
elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
+95
View File
@@ -0,0 +1,95 @@
import torch
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from ldm.models.diffusion.ddim import noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = {}
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
def do_inpainting_hijack():
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
+3 -10
View File
@@ -1,7 +1,6 @@
from __future__ import annotations
import math
import psutil
import platform
import torch
from torch import einsum
@@ -95,10 +94,7 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
class SdOptimizationSubQuad(SdOptimization):
name = "sub-quadratic"
cmd_opt = "opt_sub_quad_attention"
@property
def priority(self):
return 1000 if shared.device.type == 'mps' else 10
priority = 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
@@ -124,7 +120,7 @@ class SdOptimizationInvokeAI(SdOptimization):
@property
def priority(self):
return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
return 1000 if not torch.cuda.is_available() else 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
@@ -431,10 +427,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
if chunk_threshold is None:
if q.device.type == 'mps':
chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
else:
chunk_threshold_bytes = int(get_available_vram() * 0.7)
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
elif chunk_threshold == 0:
chunk_threshold_bytes = None
else:
+22 -77
View File
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
from modules.timer import Timer
import tomesd
@@ -27,24 +27,6 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()
def replace_key(d, key, new_key, value):
keys = list(d.keys())
d[new_key] = value
if key not in keys:
return d
index = keys.index(key)
keys[index] = new_key
new_d = {k: d[k] for k in keys}
d.clear()
d.update(new_d)
return d
class CheckpointInfo:
def __init__(self, filename):
self.filename = filename
@@ -86,9 +68,7 @@ class CheckpointInfo:
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
if self.shorthash:
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -100,20 +80,14 @@ class CheckpointInfo:
if self.sha256 is None:
return
shorthash = self.sha256[0:10]
if self.shorthash == self.sha256[0:10]:
return self.shorthash
self.shorthash = shorthash
self.shorthash = self.sha256[0:10]
if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
old_title = self.title
checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]'
self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
replace_key(checkpoints_list, old_title, self.title, self)
self.register()
return self.shorthash
@@ -167,9 +141,6 @@ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
def get_closet_checkpoint_match(search_string):
if not search_string:
return None
checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -363,11 +334,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
if shared.cmd_opts.no_half:
model.float()
devices.dtype_unet = torch.float32
timer.record("apply float()")
else:
if not shared.cmd_opts.no_half:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
@@ -383,9 +350,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if depth_model:
model.depth_model = depth_model
devices.dtype_unet = torch.float16
timer.record("apply half()")
devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -405,7 +372,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
sd_vae.load_vae(model, vae_file, vae_source)
timer.record("load VAE")
@@ -506,12 +473,8 @@ class SdModelData:
return self.sd_model
def set_sd_model(self, v, already_loaded=False):
def set_sd_model(self, v):
self.sd_model = v
if already_loaded:
sd_vae.base_vae = getattr(v, "base_vae", None)
sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
sd_vae.checkpoint_info = v.sd_checkpoint_info
try:
self.loaded_sd_models.remove(v)
@@ -526,6 +489,7 @@ model_data = SdModelData()
def get_empty_cond(sd_model):
from modules import extra_networks, processing
p = processing.StableDiffusionProcessingTxt2Img()
extra_networks.activate(p, {})
@@ -538,7 +502,9 @@ def get_empty_cond(sd_model):
def send_model_to_cpu(m):
if m.lowvram:
from modules import lowvram
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
m.to(devices.cpu)
@@ -546,17 +512,12 @@ def send_model_to_cpu(m):
devices.torch_gc()
def model_target_device(m):
if lowvram.is_needed(m):
return devices.cpu
else:
return devices.device
def send_model_to_device(m):
lowvram.apply(m)
from modules import lowvram
if not m.lowvram:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
else:
m.to(shared.device)
@@ -614,15 +575,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("create model")
if shared.cmd_opts.no_half:
weight_dtype_conversion = None
else:
weight_dtype_conversion = {
'first_stage_model': None,
'': torch.float16,
}
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
@@ -685,14 +638,8 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
send_model_to_device(already_loaded)
timer.record("send model to device")
model_data.set_sd_model(already_loaded, already_loaded=True)
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
model_data.set_sd_model(already_loaded)
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
sd_vae.reload_vae_weights(already_loaded)
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
@@ -704,10 +651,6 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
sd_model = model_data.loaded_sd_models.pop()
model_data.sd_model = sd_model
sd_vae.base_vae = getattr(sd_model, "base_vae", None)
sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
return sd_model
else:
@@ -715,6 +658,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
def reload_model_weights(sd_model=None, info=None):
from modules import devices, sd_hijack
checkpoint_info = info or select_checkpoint()
timer = Timer()
@@ -764,7 +708,7 @@ def reload_model_weights(sd_model=None, info=None):
script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
if not sd_model.lowvram:
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
timer.record("move model to device")
@@ -777,6 +721,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
from modules import devices, sd_hijack
timer = Timer()
if model_data.sd_model:
+2 -1
View File
@@ -2,7 +2,7 @@ import os
import torch
from modules import shared, paths, sd_disable_initialization, devices
from modules import shared, paths, sd_disable_initialization
sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
@@ -29,6 +29,7 @@ def is_using_v_parameterization_for_sd2(state_dict):
"""
import ldm.modules.diffusionmodules.openaimodel
from modules import devices
device = devices.cpu
-31
View File
@@ -1,31 +0,0 @@
from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from modules.sd_models import CheckpointInfo
class WebuiSdModel(LatentDiffusion):
"""This class is not actually instantinated, but its fields are created and fieeld by webui"""
lowvram: bool
"""True if lowvram/medvram optimizations are enabled -- see modules.lowvram for more info"""
sd_model_hash: str
"""short hash, 10 first characters of SHA1 hash of the model file; may be None if --no-hashing flag is used"""
sd_model_checkpoint: str
"""path to the file on disk that model weights were obtained from"""
sd_checkpoint_info: 'CheckpointInfo'
"""structure with additional information about the file with model's weights"""
is_sdxl: bool
"""True if the model's architecture is SDXL"""
is_sd2: bool
"""True if the model's architecture is SD 2.x"""
is_sd1: bool
"""True if the model's architecture is SD 1.x"""
+8 -11
View File
@@ -1,18 +1,17 @@
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
*sd_samplers_timesteps.samplers_data_timesteps,
*sd_samplers_compvis.samplers_data_compvis,
]
all_samplers_map = {x.name: x for x in all_samplers}
samplers = []
samplers_for_img2img = []
samplers_map = {}
samplers_hidden = {}
def find_sampler_config(name):
@@ -39,11 +38,13 @@ def create_sampler(name, model):
def set_samplers():
global samplers, samplers_for_img2img, samplers_hidden
global samplers, samplers_for_img2img
samplers_hidden = set(shared.opts.hide_samplers)
samplers = all_samplers
samplers_for_img2img = all_samplers
hidden = set(shared.opts.hide_samplers)
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
samplers_map.clear()
for sampler in all_samplers:
@@ -52,8 +53,4 @@ def set_samplers():
samplers_map[alias.lower()] = sampler.name
def visible_sampler_names():
return [x.name for x in samplers if x.name not in samplers_hidden]
set_samplers()
-230
View File
@@ -1,230 +0,0 @@
import torch
from modules import prompt_parser, devices, sd_samplers_common
from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
def catenate_conds(conds):
if not isinstance(conds[0], dict):
return torch.cat(conds)
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
def subscript_cond(cond, a, b):
if not isinstance(cond, dict):
return cond[a:b]
return {key: vec[a:b] for key, vec in cond.items()}
def pad_cond(tensor, repeats, empty):
if not isinstance(tensor, dict):
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
return tensor
class CFGDenoiser(torch.nn.Module):
"""
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
negative prompt.
"""
def __init__(self, sampler):
super().__init__()
self.model_wrap = None
self.mask = None
self.nmask = None
self.init_latent = None
self.steps = None
"""number of steps as specified by user in UI"""
self.total_steps = None
"""expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
self.sampler = sampler
self.model_wrap = None
self.p = None
self.mask_before_denoising = False
@property
def inner_model(self):
raise NotImplementedError()
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
return denoised
def combine_denoised_for_edit_model(self, x_out, cond_scale):
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
return denoised
def get_pred_x0(self, x_in, x_out, sigma):
return x_out
def update_inner_model(self):
self.model_wrap = None
c, uc = self.p.get_conds()
self.sampler.sampler_extra_args['cond'] = c
self.sampler.sampler_extra_args['uncond'] = uc
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
if sd_samplers_common.apply_refiner(self):
cond = self.sampler.sampler_extra_args['cond']
uncond = self.sampler.sampler_extra_args['uncond']
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
# so is_edit_model is set to False to support AND composition.
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
if self.mask_before_denoising and self.mask is not None:
x = self.init_latent * self.mask + self.nmask * x
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
else:
image_uncond = image_cond
if isinstance(uncond, dict):
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
else:
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
else:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma
tensor = denoiser_params.text_cond
uncond = denoiser_params.text_uncond
skip_uncond = False
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
skip_uncond = True
x_in = x_in[:-batch_size]
sigma_in = sigma_in[:-batch_size]
self.padded_cond_uncond = False
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
empty = shared.sd_model.cond_stage_model_empty_prompt
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
if num_repeats < 0:
tensor = pad_cond(tensor, -num_repeats, empty)
self.padded_cond_uncond = True
elif num_repeats > 0:
uncond = pad_cond(uncond, num_repeats, empty)
self.padded_cond_uncond = True
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
cond_in = catenate_conds([tensor, uncond, uncond])
elif skip_uncond:
cond_in = tensor
else:
cond_in = catenate_conds([tensor, uncond])
if shared.opts.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
if not is_edit_model:
c_crossattn = subscript_cond(tensor, a, b)
else:
c_crossattn = torch.cat([tensor[a:b]], uncond)
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
if not skip_uncond:
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
denoised_image_indexes = [x[0][0] for x in conds_list]
if skip_uncond:
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
elif skip_uncond:
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if not self.mask_before_denoising and self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
if opts.live_preview_content == "Prompt":
preview = self.sampler.last_latent
elif opts.live_preview_content == "Negative prompt":
preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
else:
preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
sd_samplers_common.store_latent(preview)
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
cfg_after_cfg_callback(after_cfg_callback_params)
denoised = after_cfg_callback_params.x
self.step += 1
return denoised
+8 -216
View File
@@ -1,22 +1,11 @@
import inspect
from collections import namedtuple
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state
import k_diffusion.sampling
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
class SamplerData(SamplerDataTuple):
def total_steps(self, steps):
if self.options.get("second_order", False):
steps = steps * 2
return steps
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
def setup_img2img_steps(p, steps=None):
@@ -35,27 +24,22 @@ approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD":
def samples_to_images_tensor(sample, approximation=None, model=None):
"""Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1]."""
if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
'''latents -> images [-1, 1]'''
if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0)
from modules import lowvram
if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
approximation = 1
if approximation == 2:
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
elif approximation == 3:
x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
x_sample = sample * 1.5
x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
x_sample = x_sample * 2 - 1
else:
if model is None:
model = shared.sd_model
with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
return x_sample
@@ -95,19 +79,9 @@ def images_tensor_to_samples(image, approximation=None, model=None):
else:
if model is None:
model = shared.sd_model
model.first_stage_model.to(devices.dtype_vae)
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1
if len(image) > 1:
x_latent = torch.stack([
model.get_first_stage_encoding(
model.encode_first_stage(torch.unsqueeze(img, 0))
)[0]
for img in image
])
else:
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
return x_latent
@@ -153,185 +127,3 @@ def replace_torchsde_browinan():
replace_torchsde_browinan()
def apply_refiner(cfg_denoiser):
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
return False
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
return False
if getattr(cfg_denoiser.p, "enable_hr", False):
is_second_pass = cfg_denoiser.p.is_hr_pass
if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
return False
if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
return False
if opts.hires_fix_refiner_pass != "second pass":
cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=refiner_checkpoint_info)
devices.torch_gc()
cfg_denoiser.p.setup_conds()
cfg_denoiser.update_inner_model()
return True
class TorchHijack:
"""This is here to replace torch.randn_like of k-diffusion.
k-diffusion has random_sampler argument for most samplers, but not for all, so
this is needed to properly replace every use of torch.randn_like.
We need to replace to make images generated in batches to be same as images generated individually."""
def __init__(self, p):
self.rng = p.rng
def __getattr__(self, item):
if item == 'randn_like':
return self.randn_like
if hasattr(torch, item):
return getattr(torch, item)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x):
return self.rng.next()
class Sampler:
def __init__(self, funcname):
self.funcname = funcname
self.func = funcname
self.extra_params = []
self.sampler_noises = None
self.stop_at = None
self.eta = None
self.config: SamplerData = None # set by the function calling the constructor
self.last_latent = None
self.s_min_uncond = None
self.s_churn = 0.0
self.s_tmin = 0.0
self.s_tmax = float('inf')
self.s_noise = 1.0
self.eta_option_field = 'eta_ancestral'
self.eta_infotext_field = 'Eta'
self.eta_default = 1.0
self.conditioning_key = shared.sd_model.model.conditioning_key
self.p = None
self.model_wrap_cfg = None
self.sampler_extra_args = None
self.options = {}
def callback_state(self, d):
step = d['i']
if self.stop_at is not None and step > self.stop_at:
raise InterruptedException
state.sampling_step = step
shared.total_tqdm.update()
def launch_sampling(self, steps, func):
self.model_wrap_cfg.steps = steps
self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
state.sampling_steps = steps
state.sampling_step = 0
try:
return func()
except RecursionError:
print(
'Encountered RecursionError during sampling, returning last latent. '
'rho >5 with a polyexponential scheduler may cause this error. '
'You should try to use a smaller rho value instead.'
)
return self.last_latent
except InterruptedException:
return self.last_latent
def number_of_needed_noises(self, p):
return p.steps
def initialize(self, p) -> dict:
self.p = p
self.model_wrap_cfg.p = p
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
k_diffusion.sampling.torch = TorchHijack(p)
extra_params_kwargs = {}
for param_name in self.extra_params:
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
extra_params_kwargs[param_name] = getattr(p, param_name)
if 'eta' in inspect.signature(self.func).parameters:
if self.eta != self.eta_default:
p.extra_generation_params[self.eta_infotext_field] = self.eta
extra_params_kwargs['eta'] = self.eta
if len(self.extra_params) > 0:
s_churn = getattr(opts, 's_churn', p.s_churn)
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
s_noise = getattr(opts, 's_noise', p.s_noise)
if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
extra_params_kwargs['s_churn'] = s_churn
p.s_churn = s_churn
p.extra_generation_params['Sigma churn'] = s_churn
if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
extra_params_kwargs['s_tmin'] = s_tmin
p.s_tmin = s_tmin
p.extra_generation_params['Sigma tmin'] = s_tmin
if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
extra_params_kwargs['s_tmax'] = s_tmax
p.s_tmax = s_tmax
p.extra_generation_params['Sigma tmax'] = s_tmax
if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
extra_params_kwargs['s_noise'] = s_noise
p.s_noise = s_noise
p.extra_generation_params['Sigma noise'] = s_noise
return extra_params_kwargs
def create_noise_sampler(self, x, sigmas, p):
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
if shared.opts.no_dpmpp_sde_batch_determinism:
return None
from k_diffusion.sampling import BrownianTreeNoiseSampler
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
raise NotImplementedError()
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
raise NotImplementedError()
+224
View File
@@ -0,0 +1,224 @@
import math
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import numpy as np
import torch
from modules.shared import state
from modules import sd_samplers_common, prompt_parser, shared
import modules.models.diffusion.uni_pc
samplers_data_compvis = [
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
]
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
self.orig_p_sample_ddim = None
if self.is_plms:
self.orig_p_sample_ddim = self.sampler.p_sample_plms
elif self.is_ddim:
self.orig_p_sample_ddim = self.sampler.p_sample_ddim
self.mask = None
self.nmask = None
self.init_latent = None
self.sampler_noises = None
self.step = 0
self.stop_at = None
self.eta = None
self.config = None
self.last_latent = None
self.conditioning_key = sd_model.model.conditioning_key
def number_of_needed_noises(self, p):
return 0
def launch_sampling(self, steps, func):
state.sampling_steps = self.stop_at if self.stop_at is not None else steps
state.sampling_step = 0
try:
return func()
except sd_samplers_common.InterruptedException:
return self.last_latent
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
return res
def before_sample(self, x, ts, cond, unconditional_conditioning):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
if self.stop_at is not None and self.step > self.stop_at:
raise sd_samplers_common.InterruptedException
# Have to unwrap the inpainting conditioning here to perform pre-processing
image_conditioning = None
uc_image_conditioning = None
if isinstance(cond, dict):
if self.conditioning_key == "crossattn-adm":
image_conditioning = cond["c_adm"]
uc_image_conditioning = unconditional_conditioning["c_adm"]
else:
image_conditioning = cond["c_concat"][0]
cond = cond["c_crossattn"][0]
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
# for DDIM, shapes must match, we can't just process cond and uncond independently;
# filling unconditional_conditioning with repeats of the last vector to match length is
# not 100% correct but should work well enough
if unconditional_conditioning.shape[1] < cond.shape[1]:
last_vector = unconditional_conditioning[:, -1:]
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
elif unconditional_conditioning.shape[1] > cond.shape[1]:
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x = img_orig * self.mask + self.nmask * x
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later.
if image_conditioning is not None:
if self.conditioning_key == "crossattn-adm":
cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
else:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
return x, ts, cond, unconditional_conditioning
def update_step(self, last_latent):
if self.mask is not None:
self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
else:
self.last_latent = last_latent
sd_samplers_common.store_latent(self.last_latent)
self.step += 1
state.sampling_step = self.step
shared.total_tqdm.update()
def after_sample(self, x, ts, cond, uncond, res):
if not self.is_unipc:
self.update_step(res[1])
return x, ts, cond, uncond, res
def unipc_after_update(self, x, model_x):
self.update_step(x)
def initialize(self, p):
if self.is_ddim:
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
else:
self.eta = 0.0
if self.eta != 0.0:
p.extra_generation_params["Eta DDIM"] = self.eta
if self.is_unipc:
keys = [
('UniPC variant', 'uni_pc_variant'),
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
]
for name, key in keys:
v = getattr(shared.opts, key)
if v != shared.opts.get_default(key):
p.extra_generation_params[name] = v
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
if self.is_unipc:
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps):
if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
num_steps = shared.opts.uni_pc_order
valid_step = 999 / (1000 // num_steps)
if valid_step == math.floor(valid_step):
return int(valid_step) + 1
return num_steps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
steps = self.adjust_steps_if_invalid(p, steps)
self.initialize(p)
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
self.last_latent = x
self.step = 0
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
if self.conditioning_key == "crossattn-adm":
conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
else:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
self.initialize(p)
self.init_latent = None
self.last_latent = x
self.step = 0
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
if image_conditioning is not None:
if self.conditioning_key == "crossattn-adm":
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
else:
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim
+339 -68
View File
@@ -1,41 +1,38 @@
from collections import deque
import torch
import inspect
import k_diffusion.sampling
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
from modules.shared import opts
from modules.processing import StableDiffusionProcessing
from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}),
('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}),
('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
]
@@ -49,12 +46,6 @@ sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_dpm_fast': ['s_noise'],
'sample_dpm_2_ancestral': ['s_noise'],
'sample_dpmpp_2s_ancestral': ['s_noise'],
'sample_dpmpp_sde': ['s_noise'],
'sample_dpmpp_2m_sde': ['s_noise'],
'sample_dpmpp_3m_sde': ['s_noise'],
}
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
@@ -66,27 +57,319 @@ k_diffusion_scheduler = {
}
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
@property
def inner_model(self):
if self.model_wrap is None:
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
def catenate_conds(conds):
if not isinstance(conds[0], dict):
return torch.cat(conds)
return self.model_wrap
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
class KDiffusionSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model, options=None):
super().__init__(funcname)
def subscript_cond(cond, a, b):
if not isinstance(cond, dict):
return cond[a:b]
self.extra_params = sampler_extra_params.get(funcname, [])
return {key: vec[a:b] for key, vec in cond.items()}
self.options = options or {}
def pad_cond(tensor, repeats, empty):
if not isinstance(tensor, dict):
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
return tensor
class CFGDenoiser(torch.nn.Module):
"""
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
negative prompt.
"""
def __init__(self, model):
super().__init__()
self.inner_model = model
self.mask = None
self.nmask = None
self.init_latent = None
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
return denoised
def combine_denoised_for_edit_model(self, x_out, cond_scale):
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
return denoised
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
# so is_edit_model is set to False to support AND composition.
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
else:
image_uncond = image_cond
if isinstance(uncond, dict):
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
else:
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
else:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma
tensor = denoiser_params.text_cond
uncond = denoiser_params.text_uncond
skip_uncond = False
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
skip_uncond = True
x_in = x_in[:-batch_size]
sigma_in = sigma_in[:-batch_size]
self.padded_cond_uncond = False
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
empty = shared.sd_model.cond_stage_model_empty_prompt
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
if num_repeats < 0:
tensor = pad_cond(tensor, -num_repeats, empty)
self.padded_cond_uncond = True
elif num_repeats > 0:
uncond = pad_cond(uncond, num_repeats, empty)
self.padded_cond_uncond = True
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
cond_in = catenate_conds([tensor, uncond, uncond])
elif skip_uncond:
cond_in = tensor
else:
cond_in = catenate_conds([tensor, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
if not is_edit_model:
c_crossattn = subscript_cond(tensor, a, b)
else:
c_crossattn = torch.cat([tensor[a:b]], uncond)
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
if not skip_uncond:
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
denoised_image_indexes = [x[0][0] for x in conds_list]
if skip_uncond:
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
if opts.live_preview_content == "Prompt":
sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
elif opts.live_preview_content == "Negative prompt":
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
elif skip_uncond:
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
cfg_after_cfg_callback(after_cfg_callback_params)
denoised = after_cfg_callback_params.x
self.step += 1
return denoised
class TorchHijack:
def __init__(self, sampler_noises):
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
# implementation.
self.sampler_noises = deque(sampler_noises)
def __getattr__(self, item):
if item == 'randn_like':
return self.randn_like
if hasattr(torch, item):
return getattr(torch, item)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x):
if self.sampler_noises:
noise = self.sampler_noises.popleft()
if noise.shape == x.shape:
return noise
return devices.randn_like(x)
class KDiffusionSampler:
def __init__(self, funcname, sd_model):
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
self.stop_at = None
self.noisy_output = None
self.eta = None
self.config = None # set by the function calling the constructor
self.last_latent = None
self.s_min_uncond = None
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
self.model_wrap = self.model_wrap_cfg.inner_model
# NOTE: These are also defined in the StableDiffusionProcessing class.
# They should have been here to begin with but we're going to
# leave that class __init__ signature alone.
self.s_churn = 0.0
self.s_tmin = 0.0
self.s_tmax = float('inf')
self.s_noise = 1.0
self.conditioning_key = sd_model.model.conditioning_key
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
if opts.live_preview_content == "Combined":
sd_samplers_common.store_latent(latent)
self.last_latent = latent
self.noisy_output = d['x']
if self.stop_at is not None and step > self.stop_at:
raise sd_samplers_common.InterruptedException
state.sampling_step = step
shared.total_tqdm.update()
def launch_sampling(self, steps, func):
state.sampling_steps = self.stop_at if self.stop_at is not None else steps
state.sampling_step = 0
try:
return func()
except RecursionError:
print(
'Encountered RecursionError during sampling, returning last latent. '
'rho >5 with a polyexponential scheduler may cause this error. '
'You should try to use a smaller rho value instead.'
)
return self.last_latent
except sd_samplers_common.InterruptedException:
return self.last_latent
def number_of_needed_noises(self, p):
return p.steps
def initialize(self, p: StableDiffusionProcessing):
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
extra_params_kwargs = {}
for param_name in self.extra_params:
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
extra_params_kwargs[param_name] = getattr(p, param_name)
if 'eta' in inspect.signature(self.func).parameters:
if self.eta != 1.0:
p.extra_generation_params["Eta"] = self.eta
extra_params_kwargs['eta'] = self.eta
if len(self.extra_params) > 0:
s_churn = getattr(opts, 's_churn', p.s_churn)
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
s_noise = getattr(opts, 's_noise', p.s_noise)
if s_churn != self.s_churn:
extra_params_kwargs['s_churn'] = s_churn
p.s_churn = s_churn
p.extra_generation_params['Sigma churn'] = s_churn
if s_tmin != self.s_tmin:
extra_params_kwargs['s_tmin'] = s_tmin
p.s_tmin = s_tmin
p.extra_generation_params['Sigma tmin'] = s_tmin
if s_tmax != self.s_tmax:
extra_params_kwargs['s_tmax'] = s_tmax
p.s_tmax = s_tmax
p.extra_generation_params['Sigma tmax'] = s_tmax
if s_noise != self.s_noise:
extra_params_kwargs['s_noise'] = s_noise
p.s_noise = s_noise
p.extra_generation_params['Sigma noise'] = s_noise
return extra_params_kwargs
def get_sigmas(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -138,21 +421,24 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
return sigmas
def create_noise_sampler(self, x, sigmas, p):
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
if shared.opts.no_dpmpp_sde_batch_determinism:
return None
from k_diffusion.sampling import BrownianTreeNoiseSampler
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
sigmas = self.get_sigmas(p, steps)
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
if opts.img2img_extra_noise > 0:
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
extra_noise_params = ExtraNoiseParams(noise, x, xi)
extra_noise_callback(extra_noise_params)
noise = extra_noise_params.noise
xi += noise * opts.img2img_extra_noise
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
@@ -172,12 +458,9 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler
if self.config.options.get('solver_type', None) == 'heun':
extra_params_kwargs['solver_type'] = 'heun'
self.model_wrap_cfg.init_latent = x
self.last_latent = x
self.sampler_extra_args = {
extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
@@ -185,7 +468,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
's_min_uncond': self.s_min_uncond
}
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
@@ -197,46 +480,34 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
sigmas = self.get_sigmas(p, steps)
if opts.sgm_noise_multiplier:
p.extra_generation_params["SGM noise multiplier"] = True
x = x * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
x = x * sigmas[0]
x = x * sigmas[0]
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
if 'n' in parameters:
extra_params_kwargs['n'] = steps
if 'sigma_min' in parameters:
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
if 'sigmas' in parameters:
if 'n' in parameters:
extra_params_kwargs['n'] = steps
else:
extra_params_kwargs['sigmas'] = sigmas
if self.config.options.get('brownian_noise', False):
noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler
if self.config.options.get('solver_type', None) == 'heun':
extra_params_kwargs['solver_type'] = 'heun'
self.last_latent = x
self.sampler_extra_args = {
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
return samples
-167
View File
@@ -1,167 +0,0 @@
import torch
import inspect
import sys
from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
import modules.shared as shared
samplers_timesteps = [
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
]
samplers_data_timesteps = [
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)
for label, funcname, aliases, options in samplers_timesteps
]
class CompVisTimestepsDenoiser(torch.nn.Module):
def __init__(self, model, *args, **kwargs):
super().__init__(*args, **kwargs)
self.inner_model = model
def forward(self, input, timesteps, **kwargs):
return self.inner_model.apply_model(input, timesteps, **kwargs)
class CompVisTimestepsVDenoiser(torch.nn.Module):
def __init__(self, model, *args, **kwargs):
super().__init__(*args, **kwargs)
self.inner_model = model
def predict_eps_from_z_and_v(self, x_t, t, v):
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
return e_t
class CFGDenoiserTimesteps(CFGDenoiser):
def __init__(self, sampler):
super().__init__(sampler)
self.alphas = shared.sd_model.alphas_cumprod
self.mask_before_denoising = True
def get_pred_x0(self, x_in, x_out, sigma):
ts = sigma.to(dtype=int)
a_t = self.alphas[ts][:, None, None, None]
sqrt_one_minus_at = (1 - a_t).sqrt()
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
return pred_x0
@property
def inner_model(self):
if self.model_wrap is None:
denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
self.model_wrap = denoiser(shared.sd_model)
return self.model_wrap
class CompVisSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model):
super().__init__(funcname)
self.eta_option_field = 'eta_ddim'
self.eta_infotext_field = 'Eta DDIM'
self.eta_default = 0.0
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
def get_timesteps(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
discard_next_to_last_sigma = True
p.extra_generation_params["Discard penultimate sigma"] = True
steps += 1 if discard_next_to_last_sigma else 0
timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)
return timesteps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
timesteps = self.get_timesteps(p, steps)
timesteps_sched = timesteps[:t_enc]
alphas_cumprod = shared.sd_model.alphas_cumprod
sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
if opts.img2img_extra_noise > 0:
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
extra_noise_params = ExtraNoiseParams(noise, x, xi)
extra_noise_callback(extra_noise_params)
noise = extra_noise_params.noise
xi += noise * opts.img2img_extra_noise * sqrt_alpha_cumprod
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
if 'timesteps' in parameters:
extra_params_kwargs['timesteps'] = timesteps_sched
if 'is_img2img' in parameters:
extra_params_kwargs['is_img2img'] = True
self.model_wrap_cfg.init_latent = x
self.last_latent = x
self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps = steps or p.steps
timesteps = self.get_timesteps(p, steps)
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
if 'timesteps' in parameters:
extra_params_kwargs['timesteps'] = timesteps
self.last_latent = x
self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
return samples
sys.modules['modules.sd_samplers_compvis'] = sys.modules[__name__]
VanillaStableDiffusionSampler = CompVisSampler # temp. compatibility with older extensions
-137
View File
@@ -1,137 +0,0 @@
import torch
import tqdm
import k_diffusion.sampling
import numpy as np
from modules import shared
from modules.models.diffusion.uni_pc import uni_pc
@torch.no_grad()
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones((x.shape[0]))
s_x = x.new_ones((x.shape[0], 1, 1, 1))
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
index = len(timesteps) - 1 - i
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
a_t = alphas[index].item() * s_x
a_prev = alphas_prev[index].item() * s_x
sigma_t = sigmas[index].item() * s_x
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
x = a_prev.sqrt() * pred_x0 + dir_xt + noise
if callback is not None:
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
return x
@torch.no_grad()
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
s_x = x.new_ones((x.shape[0], 1, 1, 1))
old_eps = []
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = alphas[index].item() * s_x
a_prev = alphas_prev[index].item() * s_x
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
# direction pointing to x_t
dir_xt = (1. - a_prev).sqrt() * e_t
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
return x_prev, pred_x0
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
index = len(timesteps) - 1 - i
ts = timesteps[index].item() * s_in
t_next = timesteps[max(index - 1, 0)].item() * s_in
e_t = model(x, ts, **extra_args)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = model(x_prev, t_next, **extra_args)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
else:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
x = x_prev
if callback is not None:
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
return x
class UniPCCFG(uni_pc.UniPC):
def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
super().__init__(None, *args, **kwargs)
def after_update(x, model_x):
callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
self.index += 1
self.cfg_model = cfg_model
self.extra_args = extra_args
self.callback = callback
self.index = 0
self.after_update = after_update
def get_model_input_time(self, t_continuous):
return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
def model(self, x, t):
t_input = self.get_model_input_time(t)
res = self.cfg_model(x, t_input, **self.extra_args)
return res
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
return x
+1 -1
View File
@@ -47,7 +47,7 @@ def apply_unet(option=None):
if current_unet_option is None:
current_unet = None
if not shared.sd_model.lowvram:
if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
shared.sd_model.model.diffusion_model.to(devices.device)
return
+25 -80
View File
@@ -1,9 +1,6 @@
import os
import collections
from dataclasses import dataclass
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
import glob
from copy import deepcopy
@@ -20,22 +17,6 @@ checkpoint_info = None
checkpoints_loaded = collections.OrderedDict()
def get_loaded_vae_name():
if loaded_vae_file is None:
return None
return os.path.basename(loaded_vae_file)
def get_loaded_vae_hash():
if loaded_vae_file is None:
return None
sha256 = hashes.sha256(loaded_vae_file, 'vae')
return sha256[0:10] if sha256 else None
def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
return base_vae
@@ -70,6 +51,7 @@ def get_filename(filepath):
def refresh_vae_list():
global vae_dict
vae_dict.clear()
paths = [
@@ -103,7 +85,7 @@ def refresh_vae_list():
name = get_filename(filepath)
vae_dict[name] = filepath
vae_dict.update(dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))))
vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))
def find_vae_near_checkpoint(checkpoint_file):
@@ -115,74 +97,37 @@ def find_vae_near_checkpoint(checkpoint_file):
return None
@dataclass
class VaeResolution:
vae: str = None
source: str = None
resolved: bool = True
def resolve_vae(checkpoint_file):
if shared.cmd_opts.vae_path is not None:
return shared.cmd_opts.vae_path, 'from commandline argument'
def tuple(self):
return self.vae, self.source
def is_automatic():
return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
def resolve_vae_from_setting() -> VaeResolution:
if shared.opts.sd_vae == "None":
return VaeResolution()
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
if vae_from_options is not None:
return VaeResolution(vae_from_options, 'specified in settings')
if not is_automatic():
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
return VaeResolution(resolved=False)
def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
metadata = extra_networks.get_user_metadata(checkpoint_file)
vae_metadata = metadata.get("vae", None)
if vae_metadata is not None and vae_metadata != "Automatic":
if vae_metadata == "None":
return VaeResolution()
return None, None
vae_from_metadata = vae_dict.get(vae_metadata, None)
if vae_from_metadata is not None:
return VaeResolution(vae_from_metadata, "from user metadata")
return vae_from_metadata, "from user metadata"
return VaeResolution(resolved=False)
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
if vae_near_checkpoint is not None and (not shared.opts.sd_vae_overrides_per_model_preferences or is_automatic()):
return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
return vae_near_checkpoint, 'found near the checkpoint'
return VaeResolution(resolved=False)
if shared.opts.sd_vae == "None":
return None, None
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
if vae_from_options is not None:
return vae_from_options, 'specified in settings'
def resolve_vae(checkpoint_file) -> VaeResolution:
if shared.cmd_opts.vae_path is not None:
return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')
if not is_automatic:
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
return resolve_vae_from_setting()
res = resolve_vae_from_user_metadata(checkpoint_file)
if res.resolved:
return res
res = resolve_vae_near_checkpoint(checkpoint_file)
if res.resolved:
return res
res = resolve_vae_from_setting()
return res
return None, None
def load_vae_dict(filename, map_location):
@@ -192,7 +137,7 @@ def load_vae_dict(filename, map_location):
def load_vae(model, vae_file=None, vae_source="from unknown source"):
global vae_dict, base_vae, loaded_vae_file
global vae_dict, loaded_vae_file
# save_settings = False
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
@@ -230,8 +175,6 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
restore_base_vae(model)
loaded_vae_file = vae_file
model.base_vae = base_vae
model.loaded_vae_file = loaded_vae_file
# don't call this from outside
@@ -249,6 +192,8 @@ unspecified = object()
def reload_vae_weights(sd_model=None, vae_file=unspecified):
from modules import lowvram, devices, sd_hijack
if not sd_model:
sd_model = shared.sd_model
@@ -256,14 +201,14 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
checkpoint_file = checkpoint_info.filename
if vae_file == unspecified:
vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
vae_file, vae_source = resolve_vae(checkpoint_file)
else:
vae_source = "from function argument"
if loaded_vae_file == vae_file:
return
if sd_model.lowvram:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
@@ -275,7 +220,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
if not sd_model.lowvram:
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
print("VAE weights loaded.")
+924 -38
View File
@@ -1,51 +1,840 @@
import datetime
import json
import os
import re
import sys
import threading
import time
import logging
import gradio as gr
import torch
import tqdm
from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
import launch
import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from modules import util
from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import Optional
cmd_opts = shared_cmd_options.cmd_opts
parser = shared_cmd_options.parser
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
parallel_processing_allowed = True
styles_filename = cmd_opts.styles_file
config_filename = cmd_opts.ui_settings_file
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
log = logging.getLogger(__name__)
demo = None
device = None
parser = cmd_args.parser
weight_load_location = None
script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
script_loading.preload_extensions(extensions_builtin_dir, parser)
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
cmd_opts = parser.parse_args()
else:
cmd_opts, _ = parser.parse_known_args()
restricted_opts = {
"samples_filename_pattern",
"directories_filename_pattern",
"outdir_samples",
"outdir_txt2img_samples",
"outdir_img2img_samples",
"outdir_extras_samples",
"outdir_grids",
"outdir_txt2img_grids",
"outdir_save",
"outdir_init_images"
}
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
gradio_hf_hub_themes = [
"gradio/base",
"gradio/glass",
"gradio/monochrome",
"gradio/seafoam",
"gradio/soft",
"gradio/dracula_test",
"abidlabs/dracula_test",
"abidlabs/Lime",
"abidlabs/pakistan",
"Ama434/neutral-barlow",
"dawood/microsoft_windows",
"finlaymacklon/smooth_slate",
"Franklisi/darkmode",
"freddyaboulton/dracula_revamped",
"freddyaboulton/test-blue",
"gstaff/xkcd",
"Insuz/Mocha",
"Insuz/SimpleIndigo",
"JohnSmith9982/small_and_pretty",
"nota-ai/theme",
"nuttea/Softblue",
"ParityError/Anime",
"reilnuud/polite",
"remilia/Ghostly",
"rottenlittlecreature/Moon_Goblin",
"step-3-profit/Midnight-Deep",
"Taithrah/Minimal",
"ysharma/huggingface",
"ysharma/steampunk"
]
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
xformers_available = False
config_filename = cmd_opts.ui_settings_file
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = {}
loaded_hypernetworks = []
state = None
prompt_styles = None
def reload_hypernetworks():
from modules.hypernetworks import hypernetwork
global hypernetworks
interrogator = None
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
class State:
skipped = False
interrupted = False
job = ""
job_no = 0
job_count = 0
processing_has_refined_job_count = False
job_timestamp = '0'
sampling_step = 0
sampling_steps = 0
current_latent = None
current_image = None
current_image_sampling_step = 0
id_live_preview = 0
textinfo = None
time_start = None
server_start = None
_server_command_signal = threading.Event()
_server_command: Optional[str] = None
@property
def need_restart(self) -> bool:
# Compatibility getter for need_restart.
return self.server_command == "restart"
@need_restart.setter
def need_restart(self, value: bool) -> None:
# Compatibility setter for need_restart.
if value:
self.server_command = "restart"
@property
def server_command(self):
return self._server_command
@server_command.setter
def server_command(self, value: Optional[str]) -> None:
"""
Set the server command to `value` and signal that it's been set.
"""
self._server_command = value
self._server_command_signal.set()
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
"""
Wait for server command to get set; return and clear the value and signal.
"""
if self._server_command_signal.wait(timeout):
self._server_command_signal.clear()
req = self._server_command
self._server_command = None
return req
return None
def request_restart(self) -> None:
self.interrupt()
self.server_command = "restart"
log.info("Received restart request")
def skip(self):
self.skipped = True
log.info("Received skip request")
def interrupt(self):
self.interrupted = True
log.info("Received interrupt request")
def nextjob(self):
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
def dict(self):
obj = {
"skipped": self.skipped,
"interrupted": self.interrupted,
"job": self.job,
"job_count": self.job_count,
"job_timestamp": self.job_timestamp,
"job_no": self.job_no,
"sampling_step": self.sampling_step,
"sampling_steps": self.sampling_steps,
}
return obj
def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
self.job_no = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.current_latent = None
self.current_image = None
self.current_image_sampling_step = 0
self.id_live_preview = 0
self.skipped = False
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
self.job = job
devices.torch_gc()
log.info("Starting job %s", job)
def end(self):
duration = time.time() - self.time_start
log.info("Ending job %s (%.2f seconds)", self.job, duration)
self.job = ""
self.job_count = 0
devices.torch_gc()
def set_current_image(self):
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
if not parallel_processing_allowed:
return
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
self.do_set_current_image()
def do_set_current_image(self):
if self.current_latent is None:
return
import modules.sd_samplers
try:
if opts.show_progress_grid:
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.current_image_sampling_step = self.sampling_step
except Exception:
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
# we silently ignore this error
errors.record_exception()
def assign_current_image(self, image):
self.current_image = image
self.id_live_preview += 1
state = State()
state.server_start = time.time()
styles_filename = cmd_opts.styles_file
prompt_styles = modules.styles.StyleDatabase(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
options_templates = None
opts = None
restricted_opts = None
sd_model: sd_models_types.WebuiSdModel = None
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
self.section = section
self.refresh = refresh
self.do_not_save = False
self.comment_before = comment_before
"""HTML text that will be added after label in UI"""
self.comment_after = comment_after
"""HTML text that will be added before label in UI"""
def link(self, label, url):
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
return self
def js(self, label, js_func):
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
return self
def info(self, info):
self.comment_after += f"<span class='info'>({info})</span>"
return self
def html(self, html):
self.comment_after += html
return self
def needs_restart(self):
self.comment_after += " <span class='info'>(requires restart)</span>"
return self
def needs_reload_ui(self):
self.comment_after += " <span class='info'>(requires Reload UI)</span>"
return self
class OptionHTML(OptionInfo):
def __init__(self, text):
super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
self.do_not_save = True
def options_section(section_identifier, options_dict):
for v in options_dict.values():
v.section = section_identifier
return options_dict
def list_checkpoint_tiles():
import modules.sd_models
return modules.sd_models.checkpoint_tiles()
def refresh_checkpoints():
import modules.sd_models
return modules.sd_models.list_models()
def list_samplers():
import modules.sd_samplers
return modules.sd_samplers.all_samplers
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
tab_names = []
options_templates = {}
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
"samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'),
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
"grid_save": OptionInfo(True, "Always save all generated image grids"),
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"font": OptionInfo("", "Font for image grids that have text"),
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
}))
options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
}))
options_templates.update(options_section(('system', "System"), {
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
}))
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
"save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
"sd_refiner_checkpoint": OptionInfo(None, "Refiner checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"),
"sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}).info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
}))
options_templates.update(options_section(('vae', "VAE"), {
"sd_vae_explanation": OptionHTML("""
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished.
For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling.
"""),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
}))
options_templates.update(options_section(('img2img', "img2img"), {
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
"img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(),
"img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(),
"img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(),
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
}))
options_templates.update(options_section(('optimizations', "Optimizations"), {
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("Do not recalculate conds from prompts if prompts have not changed since previous calculation"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
}))
options_templates.update(options_section(('interrogate', "Interrogate"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
}))
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(),
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
}))
options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
<li>Ignore: keep prompt and styles dropdown as it is.</li>
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
</ul>"""),
}))
options_templates.update(options_section(('ui', "Live previews"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_reload_ui(),
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf"),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
}))
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
}))
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable these extensions"),
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
options_templates.update()
class Options:
data = None
data_labels = options_templates
typemap = {int: float}
def __init__(self):
self.data = {k: v.default for k, v in self.data_labels.items()}
def __setattr__(self, key, value):
if self.data is not None:
if key in self.data or key in self.data_labels:
assert not cmd_opts.freeze_settings, "changing settings is disabled"
info = opts.data_labels.get(key, None)
if info.do_not_save:
return
comp_args = info.component_args if info else None
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
raise RuntimeError(f"not possible to set {key} because it is restricted")
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
raise RuntimeError(f"not possible to set {key} because it is restricted")
self.data[key] = value
return
return super(Options, self).__setattr__(key, value)
def __getattr__(self, item):
if self.data is not None:
if item in self.data:
return self.data[item]
if item in self.data_labels:
return self.data_labels[item].default
return super(Options, self).__getattribute__(item)
def set(self, key, value):
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
oldval = self.data.get(key, None)
if oldval == value:
return False
if self.data_labels[key].do_not_save:
return False
try:
setattr(self, key, value)
except RuntimeError:
return False
if self.data_labels[key].onchange is not None:
try:
self.data_labels[key].onchange()
except Exception as e:
errors.display(e, f"changing setting {key} to {value}")
setattr(self, key, oldval)
return False
return True
def get_default(self, key):
"""returns the default value for the key"""
data_label = self.data_labels.get(key)
if data_label is None:
return None
return data_label.default
def save(self, filename):
assert not cmd_opts.freeze_settings, "saving settings is disabled"
with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file, indent=4)
def same_type(self, x, y):
if x is None or y is None:
return True
type_x = self.typemap.get(type(x), type(x))
type_y = self.typemap.get(type(y), type(y))
return type_x == type_y
def load(self, filename):
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
# 1.1.1 quicksettings list migration
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
# 1.4.0 ui_reorder
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
bad_settings = 0
for k, v in self.data.items():
info = self.data_labels.get(k, None)
if info is not None and not self.same_type(info.default, v):
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
bad_settings += 1
if bad_settings > 0:
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
def onchange(self, key, func, call=True):
item = self.data_labels.get(key)
item.onchange = func
if call:
func()
def dumpjson(self):
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
return json.dumps(d)
def add_option(self, key, info):
self.data_labels[key] = info
def reorder(self):
"""reorder settings so that all items related to section always go together"""
section_ids = {}
settings_items = self.data_labels.items()
for _, item in settings_items:
if item.section not in section_ids:
section_ids[item.section] = len(section_ids)
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
"""
if value is None:
return None
default_value = self.data_labels[key].default
if default_value is None:
default_value = getattr(self, key, None)
if default_value is None:
return None
expected_type = type(default_value)
if expected_type == bool and value == "False":
value = False
else:
value = expected_type(value)
return value
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
class Shared(sys.modules[__name__].__class__):
"""
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
at program startup.
"""
sd_model_val = None
@property
def sd_model(self):
import modules.sd_models
return modules.sd_models.model_data.get_sd_model()
@sd_model.setter
def sd_model(self, value):
import modules.sd_models
modules.sd_models.model_data.set_sd_model(value)
sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
sys.modules[__name__].__class__ = Shared
settings_components = None
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
tab_names = []
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
"Latent": {"mode": "bilinear", "antialias": False},
@@ -64,24 +853,121 @@ progress_print_out = sys.stdout
gradio_theme = gr.themes.Base()
total_tqdm = None
mem_mon = None
def reload_gradio_theme(theme_name=None):
global gradio_theme
if not theme_name:
theme_name = opts.gradio_theme
options_section = options.options_section
OptionInfo = options.OptionInfo
OptionHTML = options.OptionHTML
default_theme_args = dict(
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
)
natural_sort_key = util.natural_sort_key
listfiles = util.listfiles
html_path = util.html_path
html = util.html
walk_files = util.walk_files
ldm_print = util.ldm_print
if theme_name == "Default":
gradio_theme = gr.themes.Default(**default_theme_args)
else:
try:
theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
if opts.gradio_themes_cache and os.path.exists(theme_cache_path):
gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
else:
os.makedirs(theme_cache_dir, exist_ok=True)
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
gradio_theme.dump(theme_cache_path)
except Exception as e:
errors.display(e, "changing gradio theme")
gradio_theme = gr.themes.Default(**default_theme_args)
reload_gradio_theme = shared_gradio_themes.reload_gradio_theme
list_checkpoint_tiles = shared_items.list_checkpoint_tiles
refresh_checkpoints = shared_items.refresh_checkpoints
list_samplers = shared_items.list_samplers
reload_hypernetworks = shared_items.reload_hypernetworks
class TotalTQDM:
def __init__(self):
self._tqdm = None
def reset(self):
self._tqdm = tqdm.tqdm(
desc="Total progress",
total=state.job_count * state.sampling_steps,
position=1,
file=progress_print_out
)
def update(self):
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.update()
def updateTotal(self, new_total):
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.total = new_total
def clear(self):
if self._tqdm is not None:
self._tqdm.refresh()
self._tqdm.close()
self._tqdm = None
total_tqdm = TotalTQDM()
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
def html_path(filename):
return os.path.join(script_path, "html", filename)
def html(filename):
path = html_path(filename)
if os.path.exists(path):
with open(path, encoding="utf8") as file:
return file.read()
return ""
def walk_files(path, allowed_extensions=None):
if not os.path.exists(path):
return
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
items = list(os.walk(path, followlinks=True))
items = sorted(items, key=lambda x: natural_sort_key(x[0]))
for root, _, files in items:
for filename in sorted(files, key=natural_sort_key):
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
continue
if not opts.list_hidden_files and ("/." in root or "\\." in root):
continue
yield os.path.join(root, filename)
def ldm_print(*args, **kwargs):
if opts.hide_ldm_prints:
return
print(*args, **kwargs)
-18
View File
@@ -1,18 +0,0 @@
import os
import launch
from modules import cmd_args, script_loading
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
parser = cmd_args.parser
script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
script_loading.preload_extensions(extensions_builtin_dir, parser)
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
cmd_opts = parser.parse_args()
else:
cmd_opts, _ = parser.parse_known_args()
cmd_opts.disable_extension_access = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) and not cmd_opts.enable_insecure_extension_access
-67
View File
@@ -1,67 +0,0 @@
import os
import gradio as gr
from modules import errors, shared
from modules.paths_internal import script_path
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
gradio_hf_hub_themes = [
"gradio/base",
"gradio/glass",
"gradio/monochrome",
"gradio/seafoam",
"gradio/soft",
"gradio/dracula_test",
"abidlabs/dracula_test",
"abidlabs/Lime",
"abidlabs/pakistan",
"Ama434/neutral-barlow",
"dawood/microsoft_windows",
"finlaymacklon/smooth_slate",
"Franklisi/darkmode",
"freddyaboulton/dracula_revamped",
"freddyaboulton/test-blue",
"gstaff/xkcd",
"Insuz/Mocha",
"Insuz/SimpleIndigo",
"JohnSmith9982/small_and_pretty",
"nota-ai/theme",
"nuttea/Softblue",
"ParityError/Anime",
"reilnuud/polite",
"remilia/Ghostly",
"rottenlittlecreature/Moon_Goblin",
"step-3-profit/Midnight-Deep",
"Taithrah/Minimal",
"ysharma/huggingface",
"ysharma/steampunk",
"NoCrypt/miku"
]
def reload_gradio_theme(theme_name=None):
if not theme_name:
theme_name = shared.opts.gradio_theme
default_theme_args = dict(
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
)
if theme_name == "Default":
shared.gradio_theme = gr.themes.Default(**default_theme_args)
else:
try:
theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
if shared.opts.gradio_themes_cache and os.path.exists(theme_cache_path):
shared.gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
else:
os.makedirs(theme_cache_dir, exist_ok=True)
shared.gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
shared.gradio_theme.dump(theme_cache_path)
except Exception as e:
errors.display(e, "changing gradio theme")
shared.gradio_theme = gr.themes.Default(**default_theme_args)
-49
View File
@@ -1,49 +0,0 @@
import os
import torch
from modules import shared
from modules.shared import cmd_opts
def initialize():
"""Initializes fields inside the shared module in a controlled manner.
Should be called early because some other modules you can import mingt need these fields to be already set.
"""
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
from modules import options, shared_options
shared.options_templates = shared_options.options_templates
shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)
shared.restricted_opts = shared_options.restricted_opts
if os.path.exists(shared.config_filename):
shared.opts.load(shared.config_filename)
from modules import devices
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
shared.device = devices.device
shared.weight_load_location = None if cmd_opts.lowram else "cpu"
from modules import shared_state
shared.state = shared_state.State()
from modules import styles
shared.prompt_styles = styles.StyleDatabase(shared.styles_filename)
from modules import interrogate
shared.interrogator = interrogate.InterrogateModels("interrogate")
from modules import shared_total_tqdm
shared.total_tqdm = shared_total_tqdm.TotalTQDM()
from modules import memmon, devices
shared.mem_mon = memmon.MemUsageMonitor("MemMon", devices.device, shared.opts)
shared.mem_mon.start()
+2 -52
View File
@@ -1,6 +1,3 @@
import sys
from modules.shared_cmd_options import cmd_opts
def realesrgan_models_names():
@@ -44,36 +41,13 @@ def refresh_unet_list():
modules.sd_unet.list_unets()
def list_checkpoint_tiles():
import modules.sd_models
return modules.sd_models.checkpoint_tiles()
def refresh_checkpoints():
import modules.sd_models
return modules.sd_models.list_models()
def list_samplers():
import modules.sd_samplers
return modules.sd_samplers.all_samplers
def reload_hypernetworks():
from modules.hypernetworks import hypernetwork
from modules import shared
shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
ui_reorder_categories_builtin_items = [
"inpaint",
"sampler",
"accordions",
"checkboxes",
"hires_fix",
"dimensions",
"cfg",
"denoising",
"seed",
"batch",
"override_settings",
@@ -87,33 +61,9 @@ def ui_reorder_categories():
sections = {}
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
if isinstance(script.section, str) and script.section not in ui_reorder_categories_builtin_items:
if isinstance(script.section, str):
sections[script.section] = 1
yield from sections
yield "scripts"
class Shared(sys.modules[__name__].__class__):
"""
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
at program startup.
"""
sd_model_val = None
@property
def sd_model(self):
import modules.sd_models
return modules.sd_models.model_data.get_sd_model()
@sd_model.setter
def sd_model(self, value):
import modules.sd_models
modules.sd_models.model_data.set_sd_model(value)
sys.modules['modules.shared'].__class__ = Shared
-332
View File
@@ -1,332 +0,0 @@
import gradio as gr
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from modules.shared_cmd_options import cmd_opts
from modules.options import options_section, OptionInfo, OptionHTML
options_templates = {}
hide_dirs = shared.hide_dirs
restricted_opts = {
"samples_filename_pattern",
"directories_filename_pattern",
"outdir_samples",
"outdir_txt2img_samples",
"outdir_img2img_samples",
"outdir_extras_samples",
"outdir_grids",
"outdir_txt2img_grids",
"outdir_save",
"outdir_init_images"
}
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
"samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'),
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
"grid_save": OptionInfo(True, "Always save all generated image grids"),
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"font": OptionInfo("", "Font for image grids that have text"),
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
}))
options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
"face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"),
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
}))
options_templates.update(options_section(('system', "System"), {
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
}))
options_templates.update(options_section(('API', "API"), {
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
"api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
}))
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
"save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}, infotext="RNG").info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
"hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"),
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
}))
options_templates.update(options_section(('vae', "VAE"), {
"sd_vae_explanation": OptionHTML("""
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished.
For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling.
"""),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
}))
options_templates.update(options_section(('img2img', "img2img"), {
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
"img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(),
"img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(),
"img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(),
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
}))
options_templates.update(options_section(('optimizations', "Optimizations"), {
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
}))
options_templates.update(options_section(('interrogate', "Interrogate"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": interrogate.category_types()}, refresh=interrogate.category_types),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
}))
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
"gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
}))
options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
<li>Ignore: keep prompt and styles dropdown as it is.</li>
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
</ul>"""),
}))
options_templates.update(options_section(('ui', "Live previews"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
"live_preview_allow_lowvram_full": OptionInfo(False, "Allow Full live preview method with lowvram/medvram").info("If not, Approx NN will be used instead; Full live preview method is very detrimental to speed if lowvram/medvram optimizations are enabled"),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
'sgm_noise_multiplier': OptionInfo(False, "SGM noise multiplier", infotext='SGM noise multplier').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818").info("Match initial noise to official SDXL implementation - only useful for reproducing images"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}, infotext='UniPC variant'),
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
}))
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
}))
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable these extensions"),
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
-159
View File
@@ -1,159 +0,0 @@
import datetime
import logging
import threading
import time
from modules import errors, shared, devices
from typing import Optional
log = logging.getLogger(__name__)
class State:
skipped = False
interrupted = False
job = ""
job_no = 0
job_count = 0
processing_has_refined_job_count = False
job_timestamp = '0'
sampling_step = 0
sampling_steps = 0
current_latent = None
current_image = None
current_image_sampling_step = 0
id_live_preview = 0
textinfo = None
time_start = None
server_start = None
_server_command_signal = threading.Event()
_server_command: Optional[str] = None
def __init__(self):
self.server_start = time.time()
@property
def need_restart(self) -> bool:
# Compatibility getter for need_restart.
return self.server_command == "restart"
@need_restart.setter
def need_restart(self, value: bool) -> None:
# Compatibility setter for need_restart.
if value:
self.server_command = "restart"
@property
def server_command(self):
return self._server_command
@server_command.setter
def server_command(self, value: Optional[str]) -> None:
"""
Set the server command to `value` and signal that it's been set.
"""
self._server_command = value
self._server_command_signal.set()
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
"""
Wait for server command to get set; return and clear the value and signal.
"""
if self._server_command_signal.wait(timeout):
self._server_command_signal.clear()
req = self._server_command
self._server_command = None
return req
return None
def request_restart(self) -> None:
self.interrupt()
self.server_command = "restart"
log.info("Received restart request")
def skip(self):
self.skipped = True
log.info("Received skip request")
def interrupt(self):
self.interrupted = True
log.info("Received interrupt request")
def nextjob(self):
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
def dict(self):
obj = {
"skipped": self.skipped,
"interrupted": self.interrupted,
"job": self.job,
"job_count": self.job_count,
"job_timestamp": self.job_timestamp,
"job_no": self.job_no,
"sampling_step": self.sampling_step,
"sampling_steps": self.sampling_steps,
}
return obj
def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
self.job_no = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.current_latent = None
self.current_image = None
self.current_image_sampling_step = 0
self.id_live_preview = 0
self.skipped = False
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
self.job = job
devices.torch_gc()
log.info("Starting job %s", job)
def end(self):
duration = time.time() - self.time_start
log.info("Ending job %s (%.2f seconds)", self.job, duration)
self.job = ""
self.job_count = 0
devices.torch_gc()
def set_current_image(self):
"""if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
if not shared.parallel_processing_allowed:
return
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
self.do_set_current_image()
def do_set_current_image(self):
if self.current_latent is None:
return
import modules.sd_samplers
try:
if shared.opts.show_progress_grid:
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.current_image_sampling_step = self.sampling_step
except Exception:
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
# we silently ignore this error
errors.record_exception()
def assign_current_image(self, image):
self.current_image = image
self.id_live_preview += 1
-37
View File
@@ -1,37 +0,0 @@
import tqdm
from modules import shared
class TotalTQDM:
def __init__(self):
self._tqdm = None
def reset(self):
self._tqdm = tqdm.tqdm(
desc="Total progress",
total=shared.state.job_count * shared.state.sampling_steps,
position=1,
file=shared.progress_print_out
)
def update(self):
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.update()
def updateTotal(self, new_total):
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.total = new_total
def clear(self):
if self._tqdm is not None:
self._tqdm.refresh()
self._tqdm.close()
self._tqdm = None
+2 -2
View File
@@ -58,7 +58,7 @@ def _summarize_chunk(
scale: float,
) -> AttnChunk:
attn_weights = torch.baddbmm(
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
@@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
scale: float,
) -> Tensor:
attn_scores = torch.baddbmm(
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
+8 -18
View File
@@ -10,7 +10,7 @@ import psutil
import re
import launch
from modules import paths_internal, timer, shared, extensions, errors
from modules import paths_internal, timer
checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
environment_whitelist = {
@@ -23,6 +23,7 @@ environment_whitelist = {
"TORCH_COMMAND",
"REQS_FILE",
"XFORMERS_PACKAGE",
"GFPGAN_PACKAGE",
"CLIP_PACKAGE",
"OPENCLIP_PACKAGE",
"STABLE_DIFFUSION_REPO",
@@ -82,7 +83,7 @@ def get_dict():
"Data path": paths_internal.data_path,
"Extensions dir": paths_internal.extensions_dir,
"Checksum": checksum_token,
"Commandline": get_argv(),
"Commandline": sys.argv,
"Torch env info": get_torch_sysinfo(),
"Exceptions": get_exceptions(),
"CPU": {
@@ -114,6 +115,8 @@ def format_exception(e, tb):
def get_exceptions():
try:
from modules import errors
return list(reversed(errors.exception_records))
except Exception as e:
return str(e)
@@ -123,22 +126,6 @@ def get_environment():
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
def get_argv():
res = []
for v in sys.argv:
if shared.cmd_opts.gradio_auth and shared.cmd_opts.gradio_auth == v:
res.append("<hidden>")
continue
if shared.cmd_opts.api_auth and shared.cmd_opts.api_auth == v:
res.append("<hidden>")
continue
res.append(v)
return res
re_newline = re.compile(r"\r*\n")
@@ -155,6 +142,8 @@ def get_torch_sysinfo():
def get_extensions(*, enabled):
try:
from modules import extensions
def to_json(x: extensions.Extension):
return {
"name": x.name,
@@ -171,6 +160,7 @@ def get_extensions(*, enabled):
def get_config():
try:
from modules import shared
return shared.opts.data
except Exception as e:
return str(e)
+12 -4
View File
@@ -1,7 +1,7 @@
from contextlib import closing
import modules.scripts
from modules import processing
from modules import sd_samplers, processing
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.shared import opts, cmd_opts
import modules.shared as shared
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import gradio as gr
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img(
@@ -19,13 +19,21 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
prompt=prompt,
styles=prompt_styles,
negative_prompt=negative_prompt,
sampler_name=sampler_name,
seed=seed,
subseed=subseed,
subseed_strength=subseed_strength,
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
sampler_name=sd_samplers.samplers[sampler_index].name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
cfg_scale=cfg_scale,
width=width,
height=height,
restore_faces=restore_faces,
tiling=tiling,
enable_hr=enable_hr,
denoising_strength=denoising_strength if enable_hr else None,
hr_scale=hr_scale,
@@ -34,7 +42,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y,
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings,
+181 -71
View File
@@ -1,4 +1,5 @@
import datetime
import json
import mimetypes
import os
import sys
@@ -12,8 +13,8 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.ui_common import create_refresh_button
from modules.ui_gradio_extensions import reload_javascript
@@ -28,6 +29,7 @@ import modules.shared as shared
import modules.images
from modules import prompt_parser
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.generation_parameters_copypaste import image_from_url_text
create_setting_component = ui_settings.create_setting_component
@@ -39,9 +41,6 @@ warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
# Likewise, add explicit content-type header for certain missing image types
mimetypes.add_type('image/webp', '.webp')
if not cmd_opts.share and not cmd_opts.listen:
# fix gradio phoning home
gradio.utils.version_check = lambda: None
@@ -77,6 +76,7 @@ extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
restore_progress_symbol = '\U0001F300' # 🌀
detect_image_size_symbol = '\U0001F4D0' # 📐
up_down_symbol = '\u2195\ufe0f' # ↕️
plaintext_to_html = ui_common.plaintext_to_html
@@ -89,13 +89,17 @@ def send_gradio_gallery_to_image(x):
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
from modules import processing, devices
if not enable:
return ""
p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
p.calculate_target_resolution()
return f"from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
with devices.autocast():
p.init([""], [0], [0])
return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
def resize_from_to_html(width, height, scale_by):
@@ -141,6 +145,41 @@ def interrogate_deepbooru(image):
return gr.update() if prompt is None else prompt
def create_seed_inputs(target_interface):
with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
seed_checkbox = gr.Checkbox(label='Extra', elem_id=f"{target_interface}_subseed_show", value=False)
# Components to show/hide based on the 'Extra' checkbox
seed_extras = []
with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
seed_extras.append(seed_extra_row_1)
subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
with FormRow(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2)
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")
random_seed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_seed')}", show_progress=False, inputs=[], outputs=[])
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_subseed')}", show_progress=False, inputs=[], outputs=[])
def change_visibility(show):
return {comp: gr_show(show) for comp in seed_extras}
seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
def connect_clear_prompt(button):
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
button.click(
@@ -151,6 +190,39 @@ def connect_clear_prompt(button):
)
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
def copy_seed(gen_info_string: str, index):
res = -1
try:
gen_info = json.loads(gen_info_string)
index -= gen_info.get('index_of_first_image', 0)
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
all_subseeds = gen_info.get('all_subseeds', [-1])
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
else:
all_seeds = gen_info.get('all_seeds', [-1])
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
except json.decoder.JSONDecodeError:
if gen_info_string:
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
return [res, gr_show(False)]
reuse_seed.click(
fn=copy_seed,
_js="(x, y) => [x, selected_gallery_index()]",
show_progress=False,
inputs=[generation_info, dummy_component],
outputs=[seed, dummy_component]
)
def update_token_counter(text, steps):
try:
text, _ = extra_networks.parse_prompt(text)
@@ -285,14 +357,14 @@ def create_output_panel(tabname, outdir):
def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
return steps, sampler_name
return steps, sampler_index
def ordered_ui_categories():
@@ -333,13 +405,13 @@ def create_ui():
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
extra_tabs.__enter__()
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
scripts.scripts_txt2img.prepare_ui()
for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
elif category == "dimensions":
with FormRow():
@@ -356,45 +428,44 @@ def create_ui():
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
elif category == "cfg":
with gr.Row():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
elif category == "seed":
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
elif category == "checkboxes":
with FormRow(elem_classes="checkboxes-row", variant="compact"):
pass
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
elif category == "accordions":
with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
with enable_hr.extra():
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
elif category == "hires_fix":
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
with gr.Row():
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
with gr.Column(scale=80):
with gr.Row():
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
scripts.scripts_txt2img.setup_ui_for_section(category)
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
with gr.Row():
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
with gr.Column(scale=80):
with gr.Row():
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
elif category == "batch":
if not opts.dimensions_and_batch_together:
@@ -410,7 +481,7 @@ def create_ui():
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = scripts.scripts_txt2img.setup_ui()
if category not in {"accordions"}:
else:
scripts.scripts_txt2img.setup_ui_for_section(category)
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
@@ -434,6 +505,9 @@ def create_ui():
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
@@ -443,10 +517,14 @@ def create_ui():
toprow.negative_prompt,
toprow.ui_styles.dropdown,
steps,
sampler_name,
sampler_index,
restore_faces,
tiling,
batch_count,
batch_size,
cfg_scale,
seed,
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
height,
width,
enable_hr,
@@ -457,7 +535,7 @@ def create_ui():
hr_resize_x,
hr_resize_y,
hr_checkpoint_name,
hr_sampler_name,
hr_sampler_index,
hr_prompt,
hr_negative_prompt,
override_settings,
@@ -491,25 +569,40 @@ def create_ui():
show_progress=False,
)
enable_hr.change(
fn=lambda x: gr_show(x),
inputs=[enable_hr],
outputs=[hr_options],
show_progress = False,
)
txt2img_paste_fields = [
(toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_name, "Sampler"),
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"),
(seed, "Seed"),
(width, "Size-1"),
(height, "Size-2"),
(batch_size, "Batch size"),
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
(subseed, "Variation seed"),
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))),
(hr_scale, "Hires upscale"),
(hr_upscaler, "Hires upscaler"),
(hr_second_pass_steps, "Hires steps"),
(hr_resize_x, "Hires resize-1"),
(hr_resize_y, "Hires resize-2"),
(hr_checkpoint_name, "Hires checkpoint"),
(hr_sampler_name, "Hires sampler"),
(hr_sampler_index, "Hires sampler"),
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
(hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"),
@@ -525,9 +618,9 @@ def create_ui():
toprow.prompt,
toprow.negative_prompt,
steps,
sampler_name,
sampler_index,
cfg_scale,
scripts.scripts_txt2img.script('Seed').seed,
seed,
width,
height,
]
@@ -535,6 +628,7 @@ def create_ui():
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
from modules import ui_extra_networks
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
@@ -549,7 +643,7 @@ def create_ui():
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
extra_tabs.__enter__()
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
copy_image_buttons = []
copy_image_destinations = {}
@@ -575,7 +669,7 @@ def create_ui():
add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
@@ -583,7 +677,7 @@ def create_ui():
add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
inpaint_color_sketch_orig = gr.State(None)
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
@@ -598,7 +692,7 @@ def create_ui():
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
@@ -647,7 +741,7 @@ def create_ui():
for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
elif category == "dimensions":
with FormRow():
@@ -697,21 +791,20 @@ def create_ui():
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
elif category == "denoising":
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
elif category == "cfg":
with gr.Row():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
with FormGroup():
with FormRow():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
elif category == "seed":
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
elif category == "checkboxes":
with FormRow(elem_classes="checkboxes-row", variant="compact"):
pass
elif category == "accordions":
with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"):
scripts.scripts_img2img.setup_ui_for_section(category)
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
elif category == "batch":
if not opts.dimensions_and_batch_together:
@@ -755,12 +848,14 @@ def create_ui():
inputs=[],
outputs=[inpaint_controls, mask_alpha],
)
if category not in {"accordions"}:
else:
scripts.scripts_img2img.setup_ui_for_section(category)
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img",
@@ -778,15 +873,19 @@ def create_ui():
init_img_inpaint,
init_mask_inpaint,
steps,
sampler_name,
sampler_index,
mask_blur,
mask_alpha,
inpainting_fill,
restore_faces,
tiling,
batch_count,
batch_size,
cfg_scale,
image_cfg_scale,
denoising_strength,
seed,
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
selected_scale_tab,
height,
width,
@@ -870,12 +969,19 @@ def create_ui():
(toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_name, "Sampler"),
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"),
(image_cfg_scale, "Image CFG scale"),
(seed, "Seed"),
(width, "Size-1"),
(height, "Size-2"),
(batch_size, "Batch size"),
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
(subseed, "Variation seed"),
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
@@ -887,6 +993,7 @@ def create_ui():
paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
))
from modules import ui_extra_networks
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
@@ -1279,8 +1386,11 @@ def create_ui():
with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
interface.render()
if ifid not in ["extensions", "settings"]:
loadsave.add_block(interface, ifid)
for interface, _label, ifid in interfaces:
if ifid in ["extensions", "settings"]:
continue
loadsave.add_block(interface, ifid)
loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
+9 -11
View File
@@ -11,7 +11,7 @@ from modules import call_queue, shared
from modules.generation_parameters_copypaste import image_from_url_text
import modules.images
from modules.ui_components import ToolButton
import modules.generation_parameters_copypaste as parameters_copypaste
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@@ -105,6 +105,8 @@ def save_files(js_data, images, do_make_zip, index):
def create_output_panel(tabname, outdir):
from modules import shared
import modules.generation_parameters_copypaste as parameters_copypaste
def open_folder(f):
if not os.path.exists(f):
@@ -132,22 +134,18 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4)
generation_info = None
with gr.Column():
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
if tabname != "extras":
save = ToolButton('💾', elem_id=f'save_{tabname}', tooltip=f"Save the image to a dedicated directory ({shared.opts.outdir_save}).")
save_zip = ToolButton('🗃️', elem_id=f'save_zip_{tabname}', tooltip=f"Save zip archive with images to a dedicated directory ({shared.opts.outdir_save})")
save = gr.Button('Save', elem_id=f'save_{tabname}')
save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
buttons = {
'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."),
'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."),
'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
}
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
open_folder_button.click(
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
@@ -261,7 +259,7 @@ def setup_dialog(button_show, dialog, *, button_close=None):
fn=lambda: gr.update(visible=True),
inputs=[],
outputs=[dialog],
).then(fn=None, _js="function(){ popupId('" + dialog.elem_id + "'); }")
).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }")
if button_close:
button_close.click(fn=None, _js="closePopup")
-71
View File
@@ -20,18 +20,6 @@ class ToolButton(FormComponent, gr.Button):
return "button"
class ResizeHandleRow(gr.Row):
"""Same as gr.Row but fits inside gradio forms"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.elem_classes.append("resize-handle-row")
def get_block_name(self):
return "row"
class FormRow(FormComponent, gr.Row):
"""Same as gr.Row but fits inside gradio forms"""
@@ -84,62 +72,3 @@ class DropdownEditable(FormComponent, gr.Dropdown):
def get_block_name(self):
return "dropdown"
class InputAccordion(gr.Checkbox):
"""A gr.Accordion that can be used as an input - returns True if open, False if closed.
Actaully just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox.
"""
global_index = 0
def __init__(self, value, **kwargs):
self.accordion_id = kwargs.get('elem_id')
if self.accordion_id is None:
self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
InputAccordion.global_index += 1
kwargs_checkbox = {
**kwargs,
"elem_id": f"{self.accordion_id}-checkbox",
"visible": False,
}
super().__init__(value, **kwargs_checkbox)
self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
kwargs_accordion = {
**kwargs,
"elem_id": self.accordion_id,
"label": kwargs.get('label', 'Accordion'),
"elem_classes": ['input-accordion'],
"open": value,
}
self.accordion = gr.Accordion(**kwargs_accordion)
def extra(self):
"""Allows you to put something into the label of the accordion.
Use it like this:
```
with InputAccordion(False, label="Accordion") as acc:
with acc.extra():
FormHTML(value="hello", min_width=0)
...
```
"""
return gr.Column(elem_id=self.accordion_id + '-extra', elem_classes='input-accordion-extra', min_width=0)
def __enter__(self):
self.accordion.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.accordion.__exit__(exc_type, exc_val, exc_tb)
def get_block_name(self):
return "checkbox"
+111 -125
View File
@@ -2,7 +2,7 @@ import json
import os
import threading
import time
from datetime import datetime, timezone
from datetime import datetime
import git
@@ -65,7 +65,7 @@ def save_config_state(name):
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
print(f"Saving backup of webui/extension state to {filename}.")
with open(filename, "w", encoding="utf-8") as f:
json.dump(current_config_state, f, indent=4)
json.dump(current_config_state, f)
config_states.list_config_states()
new_value = next(iter(config_states.all_config_states.keys()), "Current")
new_choices = ["Current"] + list(config_states.all_config_states.keys())
@@ -177,7 +177,7 @@ def extension_table():
<td>{remote}</td>
<td>{ext.branch}</td>
<td>{version_link}</td>
<td>{datetime.fromtimestamp(ext.commit_date) if ext.commit_date else ""}</td>
<td>{time.asctime(time.gmtime(ext.commit_date))}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
@@ -200,129 +200,119 @@ def update_config_states_table(state_name):
created_date = time.asctime(time.gmtime(config_state["created_at"]))
filepath = config_state.get("filepath", "<unknown>")
try:
webui_remote = config_state["webui"]["remote"] or ""
webui_branch = config_state["webui"]["branch"]
webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
webui_commit_date = config_state["webui"]["commit_date"]
if webui_commit_date:
webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
code = f"""<!-- {time.time()} -->"""
webui_remote = config_state["webui"]["remote"] or ""
webui_branch = config_state["webui"]["branch"]
webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
webui_commit_date = config_state["webui"]["commit_date"]
if webui_commit_date:
webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
else:
webui_commit_date = "<unknown>"
remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
commit_link = make_commit_link(webui_commit_hash, webui_remote)
date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
current_webui = config_states.get_webui_config()
style_remote = ""
style_branch = ""
style_commit = ""
if current_webui["remote"] != webui_remote:
style_remote = STYLE_PRIMARY
if current_webui["branch"] != webui_branch:
style_branch = STYLE_PRIMARY
if current_webui["commit_hash"] != webui_commit_hash:
style_commit = STYLE_PRIMARY
code += f"""<h2>Config Backup: {config_name}</h2>
<div><b>Filepath:</b> {filepath}</div>
<div><b>Created at:</b> {created_date}</div>"""
code += f"""<h2>WebUI State</h2>
<table id="config_state_webui">
<thead>
<tr>
<th>URL</th>
<th>Branch</th>
<th>Commit</th>
<th>Date</th>
</tr>
</thead>
<tbody>
<tr>
<td><label{style_remote}>{remote}</label></td>
<td><label{style_branch}>{webui_branch}</label></td>
<td><label{style_commit}>{commit_link}</label></td>
<td><label{style_commit}>{date_link}</label></td>
</tr>
</tbody>
</table>
"""
code += """<h2>Extension State</h2>
<table id="config_state_extensions">
<thead>
<tr>
<th>Extension</th>
<th>URL</th>
<th>Branch</th>
<th>Commit</th>
<th>Date</th>
</tr>
</thead>
<tbody>
"""
ext_map = {ext.name: ext for ext in extensions.extensions}
for ext_name, ext_conf in config_state["extensions"].items():
ext_remote = ext_conf["remote"] or ""
ext_branch = ext_conf["branch"] or "<unknown>"
ext_enabled = ext_conf["enabled"]
ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
ext_commit_date = ext_conf["commit_date"]
if ext_commit_date:
ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
else:
webui_commit_date = "<unknown>"
ext_commit_date = "<unknown>"
remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
commit_link = make_commit_link(webui_commit_hash, webui_remote)
date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
current_webui = config_states.get_webui_config()
remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
commit_link = make_commit_link(ext_commit_hash, ext_remote)
date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
style_enabled = ""
style_remote = ""
style_branch = ""
style_commit = ""
if current_webui["remote"] != webui_remote:
style_remote = STYLE_PRIMARY
if current_webui["branch"] != webui_branch:
style_branch = STYLE_PRIMARY
if current_webui["commit_hash"] != webui_commit_hash:
style_commit = STYLE_PRIMARY
if ext_name in ext_map:
current_ext = ext_map[ext_name]
current_ext.read_info_from_repo()
if current_ext.enabled != ext_enabled:
style_enabled = STYLE_PRIMARY
if current_ext.remote != ext_remote:
style_remote = STYLE_PRIMARY
if current_ext.branch != ext_branch:
style_branch = STYLE_PRIMARY
if current_ext.commit_hash != ext_commit_hash:
style_commit = STYLE_PRIMARY
code = f"""<!-- {time.time()} -->
<h2>Config Backup: {config_name}</h2>
<div><b>Filepath:</b> {filepath}</div>
<div><b>Created at:</b> {created_date}</div>
<h2>WebUI State</h2>
<table id="config_state_webui">
<thead>
<tr>
<th>URL</th>
<th>Branch</th>
<th>Commit</th>
<th>Date</th>
</tr>
</thead>
<tbody>
<tr>
<td>
<label{style_remote}>{remote}</label>
</td>
<td>
<label{style_branch}>{webui_branch}</label>
</td>
<td>
<label{style_commit}>{commit_link}</label>
</td>
<td>
<label{style_commit}>{date_link}</label>
</td>
</tr>
</tbody>
</table>
<h2>Extension State</h2>
<table id="config_state_extensions">
<thead>
<tr>
<th>Extension</th>
<th>URL</th>
<th>Branch</th>
<th>Commit</th>
<th>Date</th>
</tr>
</thead>
<tbody>
"""
code += f"""
<tr>
<td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
<td><label{style_remote}>{remote}</label></td>
<td><label{style_branch}>{ext_branch}</label></td>
<td><label{style_commit}>{commit_link}</label></td>
<td><label{style_commit}>{date_link}</label></td>
</tr>
"""
ext_map = {ext.name: ext for ext in extensions.extensions}
for ext_name, ext_conf in config_state["extensions"].items():
ext_remote = ext_conf["remote"] or ""
ext_branch = ext_conf["branch"] or "<unknown>"
ext_enabled = ext_conf["enabled"]
ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
ext_commit_date = ext_conf["commit_date"]
if ext_commit_date:
ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
else:
ext_commit_date = "<unknown>"
remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
commit_link = make_commit_link(ext_commit_hash, ext_remote)
date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
style_enabled = ""
style_remote = ""
style_branch = ""
style_commit = ""
if ext_name in ext_map:
current_ext = ext_map[ext_name]
current_ext.read_info_from_repo()
if current_ext.enabled != ext_enabled:
style_enabled = STYLE_PRIMARY
if current_ext.remote != ext_remote:
style_remote = STYLE_PRIMARY
if current_ext.branch != ext_branch:
style_branch = STYLE_PRIMARY
if current_ext.commit_hash != ext_commit_hash:
style_commit = STYLE_PRIMARY
code += f""" <tr>
<td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
<td><label{style_remote}>{remote}</label></td>
<td><label{style_branch}>{ext_branch}</label></td>
<td><label{style_commit}>{commit_link}</label></td>
<td><label{style_commit}>{date_link}</label></td>
</tr>
"""
code += """ </tbody>
</table>"""
except Exception as e:
print(f"[ERROR]: Config states {filepath}, {e}")
code = f"""<!-- {time.time()} -->
<h2>Config Backup: {config_name}</h2>
<div><b>Filepath:</b> {filepath}</div>
<div><b>Created at:</b> {created_date}</div>
<h2>This file is corrupted</h2>"""
code += """
</tbody>
</table>
"""
return code
@@ -442,7 +432,7 @@ sort_ordering = [
def get_date(info: dict, key):
try:
return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc).astimezone().strftime("%Y-%m-%d")
return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
except (ValueError, TypeError):
return ''
@@ -557,12 +547,8 @@ def create_ui():
msg = '"--disable-extra-extensions" was used, remove it to load all extensions again'
html = f'<span style="color: var(--primary-400);">{msg}</span>'
with gr.Row():
info = gr.HTML(html)
with gr.Row(elem_classes="progress-container"):
extensions_table = gr.HTML('Loading...', elem_id="extensions_installed_html")
info = gr.HTML(html)
extensions_table = gr.HTML('Loading...')
ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
apply.click(
+2 -3
View File
@@ -4,6 +4,7 @@ from pathlib import Path
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.ui import up_down_symbol
import gradio as gr
import json
import html
@@ -347,8 +348,6 @@ def pages_in_preferred_order(pages):
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
from modules.ui import switch_values_symbol
ui = ExtraNetworksUi()
ui.pages = []
ui.pages_contents = []
@@ -374,7 +373,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
+1 -3
View File
@@ -19,7 +19,6 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
return {
"name": checkpoint.name_for_extra,
"filename": checkpoint.filename,
"shorthash": checkpoint.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
@@ -30,8 +29,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
}
def list_items(self):
names = list(sd_models.checkpoints_list)
for index, name in enumerate(names):
for index, name in enumerate(sd_models.checkpoints_list):
yield self.create_item(name, index)
def allowed_directories_for_previews(self):
@@ -1,6 +1,6 @@
import gradio as gr
from modules import ui_extra_networks_user_metadata, sd_vae, shared
from modules import ui_extra_networks_user_metadata, sd_vae
from modules.ui_common import create_refresh_button
@@ -18,10 +18,6 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
self.write_user_metadata(name, user_metadata)
def update_vae(self, name):
if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
sd_vae.reload_vae_weights()
def put_values_into_components(self, name):
user_metadata = self.get_user_metadata(name)
values = super().put_values_into_components(name)
@@ -62,5 +58,3 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
]
self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])
+1 -5
View File
@@ -2,7 +2,6 @@ import os
from modules import shared, ui_extra_networks
from modules.ui_extra_networks import quote_js
from modules.hashes import sha256_from_cache
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
@@ -15,16 +14,13 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def create_item(self, name, index=None, enable_filter=True):
full_path = shared.hypernetworks[name]
path, ext = os.path.splitext(full_path)
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
shorthash = sha256[0:10] if sha256 else None
return {
"name": name,
"filename": full_path,
"shorthash": shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
"search_term": self.search_terms_from_path(path),
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
@@ -19,10 +19,9 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
return {
"name": name,
"filename": embedding.filename,
"shorthash": embedding.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
"search_term": self.search_terms_from_path(embedding.filename),
"prompt": quote_js(embedding.name),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},

Some files were not shown because too many files have changed in this diff Show More