Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a36a30fb93 | |||
| 2ea8726597 | |||
| 5dbd0355b0 | |||
| 64fd916334 |
@@ -20,7 +20,7 @@ jobs:
|
||||
# not to have GHA download an (at the time of writing) 4 GB cache
|
||||
# of PyTorch and other dependencies.
|
||||
- name: Install Ruff
|
||||
run: pip install ruff==0.1.6
|
||||
run: pip install ruff==0.0.272
|
||||
- name: Run Ruff
|
||||
run: ruff .
|
||||
lint-js:
|
||||
|
||||
@@ -121,9 +121,7 @@ Alternatively, use online services (like Google Colab):
|
||||
# Debian-based:
|
||||
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
|
||||
# Red Hat-based:
|
||||
sudo dnf install wget git python3 gperftools-libs libglvnd-glx
|
||||
# openSUSE-based:
|
||||
sudo zypper install wget git python3 libtcmalloc4 libglvnd
|
||||
sudo dnf install wget git python3
|
||||
# Arch-based:
|
||||
sudo pacman -S wget git python3
|
||||
```
|
||||
@@ -149,7 +147,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
|
||||
## Credits
|
||||
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
|
||||
|
||||
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
|
||||
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
|
||||
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
|
||||
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
|
||||
- CodeFormer - https://github.com/sczhou/CodeFormer
|
||||
@@ -176,6 +174,5 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
||||
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
|
||||
- LyCORIS - KohakuBlueleaf
|
||||
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
|
||||
- Hypertile - tfernd - https://github.com/tfernd/HyperTile
|
||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||
- (You)
|
||||
|
||||
@@ -19,50 +19,3 @@ def rebuild_cp_decomposition(up, down, mid):
|
||||
up = up.reshape(up.size(0), -1)
|
||||
down = down.reshape(down.size(0), -1)
|
||||
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
|
||||
|
||||
|
||||
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
|
||||
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
||||
'''
|
||||
return a tuple of two value of input dimension decomposed by the number closest to factor
|
||||
second value is higher or equal than first value.
|
||||
|
||||
In LoRA with Kroneckor Product, first value is a value for weight scale.
|
||||
secon value is a value for weight.
|
||||
|
||||
Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
|
||||
|
||||
examples)
|
||||
factor
|
||||
-1 2 4 8 16 ...
|
||||
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
|
||||
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
|
||||
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
|
||||
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
|
||||
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
|
||||
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
|
||||
'''
|
||||
|
||||
if factor > 0 and (dimension % factor) == 0:
|
||||
m = factor
|
||||
n = dimension // factor
|
||||
if m > n:
|
||||
n, m = m, n
|
||||
return m, n
|
||||
if factor < 0:
|
||||
factor = dimension
|
||||
m, n = 1, dimension
|
||||
length = m + n
|
||||
while m<n:
|
||||
new_m = m + 1
|
||||
while dimension%new_m != 0:
|
||||
new_m += 1
|
||||
new_n = dimension // new_m
|
||||
if new_m + new_n > length or new_m>factor:
|
||||
break
|
||||
else:
|
||||
m, n = new_m, new_n
|
||||
if m > n:
|
||||
n, m = m, n
|
||||
return m, n
|
||||
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
import torch
|
||||
import network
|
||||
from lyco_helpers import factorization
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class ModuleTypeOFT(network.ModuleType):
|
||||
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||
if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
|
||||
return NetworkModuleOFT(net, weights)
|
||||
|
||||
return None
|
||||
|
||||
# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
|
||||
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
|
||||
class NetworkModuleOFT(network.NetworkModule):
|
||||
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||
|
||||
super().__init__(net, weights)
|
||||
|
||||
self.lin_module = None
|
||||
self.org_module: list[torch.Module] = [self.sd_module]
|
||||
|
||||
# kohya-ss
|
||||
if "oft_blocks" in weights.w.keys():
|
||||
self.is_kohya = True
|
||||
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
|
||||
self.alpha = weights.w["alpha"] # alpha is constraint
|
||||
self.dim = self.oft_blocks.shape[0] # lora dim
|
||||
# LyCORIS
|
||||
elif "oft_diag" in weights.w.keys():
|
||||
self.is_kohya = False
|
||||
self.oft_blocks = weights.w["oft_diag"]
|
||||
# self.alpha is unused
|
||||
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
||||
|
||||
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
||||
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
||||
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
|
||||
|
||||
if is_linear:
|
||||
self.out_dim = self.sd_module.out_features
|
||||
elif is_conv:
|
||||
self.out_dim = self.sd_module.out_channels
|
||||
elif is_other_linear:
|
||||
self.out_dim = self.sd_module.embed_dim
|
||||
|
||||
if self.is_kohya:
|
||||
self.constraint = self.alpha * self.out_dim
|
||||
self.num_blocks = self.dim
|
||||
self.block_size = self.out_dim // self.dim
|
||||
else:
|
||||
self.constraint = None
|
||||
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||
|
||||
def calc_updown_kb(self, orig_weight, multiplier):
|
||||
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
||||
|
||||
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)
|
||||
|
||||
# This errors out for MultiheadAttention, might need to be handled up-stream
|
||||
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||
merged_weight = torch.einsum(
|
||||
'k n m, k n ... -> k m ...',
|
||||
R,
|
||||
merged_weight
|
||||
)
|
||||
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
||||
|
||||
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||
output_shape = orig_weight.shape
|
||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
# if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it
|
||||
multiplier = self.multiplier()
|
||||
return self.calc_updown_kb(orig_weight, multiplier)
|
||||
|
||||
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
|
||||
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||
if self.bias is not None:
|
||||
updown = updown.reshape(self.bias.shape)
|
||||
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
updown = updown.reshape(output_shape)
|
||||
|
||||
if len(output_shape) == 4:
|
||||
updown = updown.reshape(output_shape)
|
||||
|
||||
if orig_weight.size().numel() == updown.size().numel():
|
||||
updown = updown.reshape(orig_weight.shape)
|
||||
|
||||
if ex_bias is not None:
|
||||
ex_bias = ex_bias * self.multiplier()
|
||||
|
||||
return updown, ex_bias
|
||||
@@ -11,7 +11,6 @@ import network_ia3
|
||||
import network_lokr
|
||||
import network_full
|
||||
import network_norm
|
||||
import network_oft
|
||||
|
||||
import torch
|
||||
from typing import Union
|
||||
@@ -29,7 +28,6 @@ module_types = [
|
||||
network_full.ModuleTypeFull(),
|
||||
network_norm.ModuleTypeNorm(),
|
||||
network_glora.ModuleTypeGLora(),
|
||||
network_oft.ModuleTypeOFT(),
|
||||
]
|
||||
|
||||
|
||||
@@ -191,17 +189,6 @@ def load_network(name, network_on_disk):
|
||||
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
|
||||
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||
|
||||
# kohya_ss OFT module
|
||||
elif sd_module is None and "oft_unet" in key_network_without_network_parts:
|
||||
key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
|
||||
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||
|
||||
# KohakuBlueLeaf OFT module
|
||||
if sd_module is None and "oft_diag" in key:
|
||||
key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
|
||||
key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
|
||||
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||
|
||||
if sd_module is None:
|
||||
keys_failed_to_match[key_network] = key
|
||||
continue
|
||||
|
||||
@@ -17,8 +17,6 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
|
||||
def create_item(self, name, index=None, enable_filter=True):
|
||||
lora_on_disk = networks.available_networks.get(name)
|
||||
if lora_on_disk is None:
|
||||
return
|
||||
|
||||
path, ext = os.path.splitext(lora_on_disk.filename)
|
||||
|
||||
@@ -68,10 +66,9 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
return item
|
||||
|
||||
def list_items(self):
|
||||
# instantiate a list to protect against concurrent modification
|
||||
names = list(networks.available_networks)
|
||||
for index, name in enumerate(names):
|
||||
for index, name in enumerate(networks.available_networks):
|
||||
item = self.create_item(name, index)
|
||||
|
||||
if item is not None:
|
||||
yield item
|
||||
|
||||
|
||||
@@ -1,345 +0,0 @@
|
||||
"""
|
||||
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
|
||||
Warn: The patch works well only if the input image has a width and height that are multiples of 128
|
||||
Original author: @tfernd Github: https://github.com/tfernd/HyperTile
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
from functools import wraps, cache
|
||||
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import random
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
@dataclass
|
||||
class HypertileParams:
|
||||
depth = 0
|
||||
layer_name = ""
|
||||
tile_size: int = 0
|
||||
swap_size: int = 0
|
||||
aspect_ratio: float = 1.0
|
||||
forward = None
|
||||
enabled = False
|
||||
|
||||
|
||||
|
||||
# TODO add SD-XL layers
|
||||
DEPTH_LAYERS = {
|
||||
0: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
|
||||
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"input_blocks.1.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.2.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.9.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.10.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.11.1.transformer_blocks.0.attn1",
|
||||
# SD 1.5 VAE
|
||||
"decoder.mid_block.attentions.0",
|
||||
"decoder.mid.attn_1",
|
||||
],
|
||||
1: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||
"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||
"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"input_blocks.4.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.5.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.6.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.7.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.8.1.transformer_blocks.0.attn1",
|
||||
],
|
||||
2: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"down_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||
"down_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||
"up_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.1.attentions.2.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"input_blocks.7.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.3.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.4.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.5.1.transformer_blocks.0.attn1",
|
||||
],
|
||||
3: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"mid_block.attentions.0.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"middle_block.1.transformer_blocks.0.attn1",
|
||||
],
|
||||
}
|
||||
# XL layers, thanks for GitHub@gel-crabs for the help
|
||||
DEPTH_LAYERS_XL = {
|
||||
0: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
|
||||
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"input_blocks.4.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.5.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.3.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.4.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.5.1.transformer_blocks.0.attn1",
|
||||
# SD 1.5 VAE
|
||||
"decoder.mid_block.attentions.0",
|
||||
"decoder.mid.attn_1",
|
||||
],
|
||||
1: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
#"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||
#"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||
#"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||
#"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||
#"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"input_blocks.4.1.transformer_blocks.1.attn1",
|
||||
"input_blocks.5.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.3.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.4.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.5.1.transformer_blocks.1.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.1.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.1.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.2.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.2.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.2.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.2.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.2.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.3.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.3.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.3.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.3.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.3.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.4.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.4.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.4.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.4.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.4.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.5.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.5.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.5.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.5.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.5.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.6.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.6.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.6.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.6.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.6.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.7.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.7.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.7.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.7.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.7.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.8.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.8.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.8.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.8.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.8.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.9.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.9.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.9.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.9.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.9.attn1",
|
||||
],
|
||||
2: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"mid_block.attentions.0.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"middle_block.1.transformer_blocks.0.attn1",
|
||||
"middle_block.1.transformer_blocks.1.attn1",
|
||||
"middle_block.1.transformer_blocks.2.attn1",
|
||||
"middle_block.1.transformer_blocks.3.attn1",
|
||||
"middle_block.1.transformer_blocks.4.attn1",
|
||||
"middle_block.1.transformer_blocks.5.attn1",
|
||||
"middle_block.1.transformer_blocks.6.attn1",
|
||||
"middle_block.1.transformer_blocks.7.attn1",
|
||||
"middle_block.1.transformer_blocks.8.attn1",
|
||||
"middle_block.1.transformer_blocks.9.attn1",
|
||||
],
|
||||
3 : [] # TODO - separate layers for SD-XL
|
||||
}
|
||||
|
||||
|
||||
RNG_INSTANCE = random.Random()
|
||||
|
||||
|
||||
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||
"""
|
||||
Returns a random divisor of value that
|
||||
x * min_value <= value
|
||||
if max_options is 1, the behavior is deterministic
|
||||
"""
|
||||
min_value = min(min_value, value)
|
||||
|
||||
# All big divisors of value (inclusive)
|
||||
divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
|
||||
|
||||
ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
|
||||
|
||||
idx = RNG_INSTANCE.randint(0, len(ns) - 1)
|
||||
|
||||
return ns[idx]
|
||||
|
||||
|
||||
def set_hypertile_seed(seed: int) -> None:
|
||||
RNG_INSTANCE.seed(seed)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def largest_tile_size_available(width: int, height: int) -> int:
|
||||
"""
|
||||
Calculates the largest tile size available for a given width and height
|
||||
Tile size is always a power of 2
|
||||
"""
|
||||
gcd = math.gcd(width, height)
|
||||
largest_tile_size_available = 1
|
||||
while gcd % (largest_tile_size_available * 2) == 0:
|
||||
largest_tile_size_available *= 2
|
||||
return largest_tile_size_available
|
||||
|
||||
|
||||
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
|
||||
"""
|
||||
Finds h and w such that h*w = hw and h/w = aspect_ratio
|
||||
We check all possible divisors of hw and return the closest to the aspect ratio
|
||||
"""
|
||||
divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
|
||||
pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
|
||||
ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
|
||||
closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
|
||||
closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
|
||||
return closest_pair
|
||||
|
||||
|
||||
@cache
|
||||
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
|
||||
"""
|
||||
Finds h and w such that h*w = hw and h/w = aspect_ratio
|
||||
"""
|
||||
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
||||
# find h and w such that h*w = hw and h/w = aspect_ratio
|
||||
if h * w != hw:
|
||||
w_candidate = hw / h
|
||||
# check if w is an integer
|
||||
if not w_candidate.is_integer():
|
||||
h_candidate = hw / w
|
||||
# check if h is an integer
|
||||
if not h_candidate.is_integer():
|
||||
return iterative_closest_divisors(hw, aspect_ratio)
|
||||
else:
|
||||
h = int(h_candidate)
|
||||
else:
|
||||
w = int(w_candidate)
|
||||
return h, w
|
||||
|
||||
|
||||
def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
|
||||
|
||||
@wraps(params.forward)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not params.enabled:
|
||||
return params.forward(*args, **kwargs)
|
||||
|
||||
latent_tile_size = max(128, params.tile_size) // 8
|
||||
x = args[0]
|
||||
|
||||
# VAE
|
||||
if x.ndim == 4:
|
||||
b, c, h, w = x.shape
|
||||
|
||||
nh = random_divisor(h, latent_tile_size, params.swap_size)
|
||||
nw = random_divisor(w, latent_tile_size, params.swap_size)
|
||||
|
||||
if nh * nw > 1:
|
||||
x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
|
||||
|
||||
out = params.forward(x, *args[1:], **kwargs)
|
||||
|
||||
if nh * nw > 1:
|
||||
out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
|
||||
|
||||
# U-Net
|
||||
else:
|
||||
hw: int = x.size(1)
|
||||
h, w = find_hw_candidates(hw, params.aspect_ratio)
|
||||
assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
|
||||
|
||||
factor = 2 ** params.depth if scale_depth else 1
|
||||
nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
|
||||
nw = random_divisor(w, latent_tile_size * factor, params.swap_size)
|
||||
|
||||
if nh * nw > 1:
|
||||
x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
||||
|
||||
out = params.forward(x, *args[1:], **kwargs)
|
||||
|
||||
if nh * nw > 1:
|
||||
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
|
||||
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
|
||||
|
||||
return out
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
|
||||
hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
|
||||
if hypertile_layers is None:
|
||||
if not enable:
|
||||
return
|
||||
|
||||
hypertile_layers = {}
|
||||
layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS
|
||||
|
||||
for depth in range(4):
|
||||
for layer_name, module in model.named_modules():
|
||||
if any(layer_name.endswith(try_name) for try_name in layers[depth]):
|
||||
params = HypertileParams()
|
||||
module.__webui_hypertile_params = params
|
||||
params.forward = module.forward
|
||||
params.depth = depth
|
||||
params.layer_name = layer_name
|
||||
module.forward = self_attn_forward(params)
|
||||
|
||||
hypertile_layers[layer_name] = 1
|
||||
|
||||
model.__webui_hypertile_layers = hypertile_layers
|
||||
|
||||
aspect_ratio = width / height
|
||||
tile_size = min(largest_tile_size_available(width, height), tile_size_max)
|
||||
|
||||
for layer_name, module in model.named_modules():
|
||||
if layer_name in hypertile_layers:
|
||||
params = module.__webui_hypertile_params
|
||||
|
||||
params.tile_size = tile_size
|
||||
params.swap_size = swap_size
|
||||
params.aspect_ratio = aspect_ratio
|
||||
params.enabled = enable and params.depth <= max_depth
|
||||
@@ -1,73 +0,0 @@
|
||||
import hypertile
|
||||
from modules import scripts, script_callbacks, shared
|
||||
|
||||
|
||||
class ScriptHypertile(scripts.Script):
|
||||
name = "Hypertile"
|
||||
|
||||
def title(self):
|
||||
return self.name
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def process(self, p, *args):
|
||||
hypertile.set_hypertile_seed(p.all_seeds[0])
|
||||
|
||||
configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
|
||||
|
||||
def before_hr(self, p, *args):
|
||||
configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet)
|
||||
|
||||
|
||||
def configure_hypertile(width, height, enable_unet=True):
|
||||
hypertile.hypertile_hook_model(
|
||||
shared.sd_model.first_stage_model,
|
||||
width,
|
||||
height,
|
||||
swap_size=shared.opts.hypertile_swap_size_vae,
|
||||
max_depth=shared.opts.hypertile_max_depth_vae,
|
||||
tile_size_max=shared.opts.hypertile_max_tile_vae,
|
||||
enable=shared.opts.hypertile_enable_vae,
|
||||
)
|
||||
|
||||
hypertile.hypertile_hook_model(
|
||||
shared.sd_model.model,
|
||||
width,
|
||||
height,
|
||||
swap_size=shared.opts.hypertile_swap_size_unet,
|
||||
max_depth=shared.opts.hypertile_max_depth_unet,
|
||||
tile_size_max=shared.opts.hypertile_max_tile_unet,
|
||||
enable=enable_unet,
|
||||
is_sdxl=shared.sd_model.is_sdxl
|
||||
)
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
import gradio as gr
|
||||
|
||||
options = {
|
||||
"hypertile_explanation": shared.OptionHTML("""
|
||||
<a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,
|
||||
resulting in a reduction in computation time ranging from 1 to 4 times. The larger the generated image is, the greater the
|
||||
benefit.
|
||||
"""),
|
||||
|
||||
"hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
|
||||
"hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
|
||||
"hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
|
||||
"hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
|
||||
|
||||
"hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
|
||||
"hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
|
||||
"hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
|
||||
}
|
||||
|
||||
for name, opt in options.items():
|
||||
opt.section = ('hypertile', "Hypertile")
|
||||
shared.opts.add_option(name, opt)
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
@@ -130,10 +130,6 @@ function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePromp
|
||||
} else {
|
||||
promptContainer.insertBefore(prompt, promptContainer.firstChild);
|
||||
}
|
||||
|
||||
if (elem) {
|
||||
elem.classList.toggle('extra-page-prompts-active', showNegativePrompt || showPrompt);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -26,11 +26,7 @@ onAfterUiUpdate(function() {
|
||||
lastHeadImg = headImg;
|
||||
|
||||
// play notification sound if available
|
||||
const notificationAudio = gradioApp().querySelector('#audio_notification audio');
|
||||
if (notificationAudio) {
|
||||
notificationAudio.volume = opts.notification_volume / 100.0 || 1.0;
|
||||
notificationAudio.play();
|
||||
}
|
||||
gradioApp().querySelector('#audio_notification audio')?.play();
|
||||
|
||||
if (document.hasFocus()) return;
|
||||
|
||||
|
||||
@@ -44,28 +44,3 @@ onUiLoaded(function() {
|
||||
|
||||
buttonShowAllPages.addEventListener("click", settingsShowAllTabs);
|
||||
});
|
||||
|
||||
|
||||
onOptionsChanged(function() {
|
||||
if (gradioApp().querySelector('#settings .settings-category')) return;
|
||||
|
||||
var sectionMap = {};
|
||||
gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) {
|
||||
sectionMap[x.textContent.trim()] = x;
|
||||
});
|
||||
|
||||
opts._categories.forEach(function(x) {
|
||||
var section = x[0];
|
||||
var category = x[1];
|
||||
|
||||
var span = document.createElement('SPAN');
|
||||
span.textContent = category;
|
||||
span.className = 'settings-category';
|
||||
|
||||
var sectionElem = sectionMap[section];
|
||||
if (!sectionElem) return;
|
||||
|
||||
sectionElem.parentElement.insertBefore(span, sectionElem);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -93,8 +93,8 @@ class PydanticModelGenerator:
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
|
||||
}
|
||||
DynamicModel = create_model(self._model_name, **fields)
|
||||
DynamicModel.model_config['populate_by_name'] = True
|
||||
DynamicModel.model_config['frozen'] = True
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
return DynamicModel
|
||||
|
||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
|
||||
+1
-1
@@ -32,7 +32,7 @@ def dump_cache():
|
||||
with cache_lock:
|
||||
cache_filename_tmp = cache_filename + "-"
|
||||
with open(cache_filename_tmp, "w", encoding="utf8") as file:
|
||||
json.dump(cache_data, file, indent=4, ensure_ascii=False)
|
||||
json.dump(cache_data, file, indent=4)
|
||||
|
||||
os.replace(cache_filename_tmp, cache_filename)
|
||||
|
||||
|
||||
+2
-16
@@ -6,21 +6,6 @@ import traceback
|
||||
exception_records = []
|
||||
|
||||
|
||||
def format_traceback(tb):
|
||||
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
|
||||
|
||||
|
||||
def format_exception(e, tb):
|
||||
return {"exception": str(e), "traceback": format_traceback(tb)}
|
||||
|
||||
|
||||
def get_exceptions():
|
||||
try:
|
||||
return list(reversed(exception_records))
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
|
||||
|
||||
def record_exception():
|
||||
_, e, tb = sys.exc_info()
|
||||
if e is None:
|
||||
@@ -29,7 +14,8 @@ def record_exception():
|
||||
if exception_records and exception_records[-1] == e:
|
||||
return
|
||||
|
||||
exception_records.append(format_exception(e, tb))
|
||||
from modules import sysinfo
|
||||
exception_records.append(sysinfo.format_exception(e, tb))
|
||||
|
||||
if len(exception_records) > 5:
|
||||
exception_records.pop(0)
|
||||
|
||||
+12
-84
@@ -1,14 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import configparser
|
||||
import os
|
||||
import threading
|
||||
import re
|
||||
|
||||
from modules import shared, errors, cache, scripts
|
||||
from modules.gitpython_hack import Repo
|
||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||
|
||||
extensions = []
|
||||
|
||||
os.makedirs(extensions_dir, exist_ok=True)
|
||||
|
||||
@@ -22,55 +19,11 @@ def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
||||
|
||||
class ExtensionMetadata:
|
||||
filename = "metadata.ini"
|
||||
config: configparser.ConfigParser
|
||||
canonical_name: str
|
||||
requires: list
|
||||
|
||||
def __init__(self, path, canonical_name):
|
||||
self.config = configparser.ConfigParser()
|
||||
|
||||
filepath = os.path.join(path, self.filename)
|
||||
if os.path.isfile(filepath):
|
||||
try:
|
||||
self.config.read(filepath)
|
||||
except Exception:
|
||||
errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
|
||||
|
||||
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
|
||||
self.canonical_name = canonical_name.lower().strip()
|
||||
|
||||
self.requires = self.get_script_requirements("Requires", "Extension")
|
||||
|
||||
def get_script_requirements(self, field, section, extra_section=None):
|
||||
"""reads a list of requirements from the config; field is the name of the field in the ini file,
|
||||
like Requires or Before, and section is the name of the [section] in the ini file; additionally,
|
||||
reads more requirements from [extra_section] if specified."""
|
||||
|
||||
x = self.config.get(section, field, fallback='')
|
||||
|
||||
if extra_section:
|
||||
x = x + ', ' + self.config.get(extra_section, field, fallback='')
|
||||
|
||||
return self.parse_list(x.lower())
|
||||
|
||||
def parse_list(self, text):
|
||||
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
|
||||
|
||||
if not text:
|
||||
return []
|
||||
|
||||
# both "," and " " are accepted as separator
|
||||
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
|
||||
|
||||
|
||||
class Extension:
|
||||
lock = threading.Lock()
|
||||
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
|
||||
metadata: ExtensionMetadata
|
||||
|
||||
def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
|
||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
||||
self.name = name
|
||||
self.path = path
|
||||
self.enabled = enabled
|
||||
@@ -83,8 +36,6 @@ class Extension:
|
||||
self.branch = None
|
||||
self.remote = None
|
||||
self.have_info_from_repo = False
|
||||
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
|
||||
self.canonical_name = metadata.canonical_name
|
||||
|
||||
def to_dict(self):
|
||||
return {x: getattr(self, x) for x in self.cached_fields}
|
||||
@@ -105,7 +56,6 @@ class Extension:
|
||||
self.do_read_info_from_repo()
|
||||
|
||||
return self.to_dict()
|
||||
|
||||
try:
|
||||
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
|
||||
self.from_dict(d)
|
||||
@@ -186,6 +136,9 @@ class Extension:
|
||||
def list_extensions():
|
||||
extensions.clear()
|
||||
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
if shared.cmd_opts.disable_all_extensions:
|
||||
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
||||
elif shared.opts.disable_all_extensions == "all":
|
||||
@@ -195,43 +148,18 @@ def list_extensions():
|
||||
elif shared.opts.disable_all_extensions == "extra":
|
||||
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
||||
|
||||
loaded_extensions = {}
|
||||
|
||||
# scan through extensions directory and load metadata
|
||||
for dirname in [extensions_builtin_dir, extensions_dir]:
|
||||
extension_paths = []
|
||||
for dirname in [extensions_dir, extensions_builtin_dir]:
|
||||
if not os.path.isdir(dirname):
|
||||
continue
|
||||
return
|
||||
|
||||
for extension_dirname in sorted(os.listdir(dirname)):
|
||||
path = os.path.join(dirname, extension_dirname)
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
canonical_name = extension_dirname
|
||||
metadata = ExtensionMetadata(path, canonical_name)
|
||||
extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
||||
|
||||
# check for duplicated canonical names
|
||||
already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
|
||||
if already_loaded_extension is not None:
|
||||
errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
|
||||
continue
|
||||
|
||||
is_builtin = dirname == extensions_builtin_dir
|
||||
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
|
||||
extensions.append(extension)
|
||||
loaded_extensions[canonical_name] = extension
|
||||
|
||||
# check for requirements
|
||||
for extension in extensions:
|
||||
for req in extension.metadata.requires:
|
||||
required_extension = loaded_extensions.get(req)
|
||||
if required_extension is None:
|
||||
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
|
||||
continue
|
||||
|
||||
if not extension.enabled:
|
||||
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
|
||||
continue
|
||||
|
||||
|
||||
extensions: list[Extension] = []
|
||||
for dirname, path, is_builtin in extension_paths:
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
||||
extensions.append(extension)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from inspect import signature
|
||||
from functools import wraps
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, ui_tempdir, patches
|
||||
@@ -66,77 +64,10 @@ def Blocks_get_config_file(self, *args, **kwargs):
|
||||
return config
|
||||
|
||||
|
||||
def gradio_component_compatibility_layer(component_function):
|
||||
@wraps(component_function)
|
||||
def patched_function(*args, **kwargs):
|
||||
original_signature = signature(component_function).parameters
|
||||
valid_kwargs = {k: v for k, v in kwargs.items() if k in original_signature}
|
||||
result = component_function(*args, **valid_kwargs)
|
||||
return result
|
||||
|
||||
return patched_function
|
||||
|
||||
|
||||
sub_events = ['then', 'success']
|
||||
|
||||
|
||||
def gradio_component_events_compatibility_layer(component_function):
|
||||
@wraps(component_function)
|
||||
def patched_function(*args, **kwargs):
|
||||
kwargs['js'] = kwargs.get('js', kwargs.pop('_js', None))
|
||||
original_signature = signature(component_function).parameters
|
||||
valid_kwargs = {k: v for k, v in kwargs.items() if k in original_signature}
|
||||
|
||||
result = component_function(*args, **valid_kwargs)
|
||||
|
||||
for sub_event in sub_events:
|
||||
component_event_then_function = getattr(result, sub_event, None)
|
||||
if component_event_then_function:
|
||||
patched_component_event_then_function = gradio_component_sub_events_compatibility_layer(component_event_then_function)
|
||||
setattr(result, sub_event, patched_component_event_then_function)
|
||||
# original_component_event_then_function = patches.patch(f'{__name__}.', obj=result, field='then', replacement=patched_component_event_then_function)
|
||||
|
||||
return result
|
||||
|
||||
return patched_function
|
||||
|
||||
|
||||
def gradio_component_sub_events_compatibility_layer(component_function):
|
||||
@wraps(component_function)
|
||||
def patched_function(*args, **kwargs):
|
||||
kwargs['js'] = kwargs.get('js', kwargs.pop('_js', None))
|
||||
original_signature = signature(component_function).parameters
|
||||
valid_kwargs = {k: v for k, v in kwargs.items() if k in original_signature}
|
||||
result = component_function(*args, **valid_kwargs)
|
||||
return result
|
||||
|
||||
return patched_function
|
||||
|
||||
|
||||
for component_name in set(gr.components.__all__ + gr.layouts.__all__):
|
||||
try:
|
||||
component = getattr(gr, component_name)
|
||||
component_init = getattr(component, '__init__')
|
||||
patched_component_init = gradio_component_compatibility_layer(component_init)
|
||||
original_IOComponent_init = patches.patch(f'{__name__}.{component_name}', obj=component, field="__init__", replacement=patched_component_init)
|
||||
|
||||
component_events = set(getattr(component, 'EVENTS'))
|
||||
for component_event in component_events:
|
||||
component_event_function = getattr(component, component_event)
|
||||
patched_component_event_function = gradio_component_events_compatibility_layer(component_event_function)
|
||||
original_component_event_function = patches.patch(f'{__name__}.{component_name}.{component_event}', obj=component, field=component_event, replacement=patched_component_event_function)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
gr.Box = gr.Group
|
||||
|
||||
|
||||
original_IOComponent_init = patches.patch(__name__, obj=gr.components.base.Component, field="__init__", replacement=IOComponent_init)
|
||||
original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
|
||||
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
|
||||
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
|
||||
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
|
||||
|
||||
|
||||
ui_tempdir.install_ui_tempdir_override()
|
||||
|
||||
|
||||
+4
-20
@@ -44,8 +44,6 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
steps = p.steps
|
||||
override_settings = p.override_settings
|
||||
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
|
||||
batch_results = None
|
||||
discard_further_results = False
|
||||
for i, image in enumerate(images):
|
||||
state.job = f"{i+1} out of {len(images)}"
|
||||
if state.skipped:
|
||||
@@ -129,21 +127,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
|
||||
if proc is None:
|
||||
p.override_settings.pop('save_images_replace_action', None)
|
||||
proc = process_images(p)
|
||||
|
||||
if not discard_further_results and proc:
|
||||
if batch_results:
|
||||
batch_results.images.extend(proc.images)
|
||||
batch_results.infotexts.extend(proc.infotexts)
|
||||
else:
|
||||
batch_results = proc
|
||||
|
||||
if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
|
||||
discard_further_results = True
|
||||
batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
|
||||
batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
|
||||
|
||||
return batch_results
|
||||
process_images(p)
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
@@ -228,10 +212,10 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
with closing(p):
|
||||
if is_batch:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||
processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
|
||||
|
||||
if processed is None:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
|
||||
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
else:
|
||||
processed = modules.scripts.scripts_img2img.run(p, *args)
|
||||
if processed is None:
|
||||
|
||||
@@ -441,7 +441,7 @@ def dump_sysinfo():
|
||||
import datetime
|
||||
|
||||
text = sysinfo.get()
|
||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
|
||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
|
||||
|
||||
with open(filename, "w", encoding="utf8") as file:
|
||||
file.write(text)
|
||||
|
||||
@@ -1,41 +1,16 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
class TqdmLoggingHandler(logging.Handler):
|
||||
def __init__(self, level=logging.INFO):
|
||||
super().__init__(level)
|
||||
|
||||
def emit(self, record):
|
||||
try:
|
||||
msg = self.format(record)
|
||||
tqdm.write(msg)
|
||||
self.flush()
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
TQDM_IMPORTED = True
|
||||
except ImportError:
|
||||
# tqdm does not exist before first launch
|
||||
# I will import once the UI finishes seting up the enviroment and reloads.
|
||||
TQDM_IMPORTED = False
|
||||
|
||||
def setup_logging(loglevel):
|
||||
if loglevel is None:
|
||||
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||
|
||||
loghandlers = []
|
||||
|
||||
if TQDM_IMPORTED:
|
||||
loghandlers.append(TqdmLoggingHandler())
|
||||
|
||||
if loglevel:
|
||||
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
handlers=loghandlers
|
||||
)
|
||||
|
||||
|
||||
+10
-71
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
|
||||
import gradio as gr
|
||||
|
||||
@@ -9,14 +8,13 @@ from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None):
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False):
|
||||
self.default = default
|
||||
self.label = label
|
||||
self.component = component
|
||||
self.component_args = component_args
|
||||
self.onchange = onchange
|
||||
self.section = section
|
||||
self.category_id = category_id
|
||||
self.refresh = refresh
|
||||
self.do_not_save = False
|
||||
|
||||
@@ -65,11 +63,7 @@ class OptionHTML(OptionInfo):
|
||||
|
||||
def options_section(section_identifier, options_dict):
|
||||
for v in options_dict.values():
|
||||
if len(section_identifier) == 2:
|
||||
v.section = section_identifier
|
||||
elif len(section_identifier) == 3:
|
||||
v.section = section_identifier[0:2]
|
||||
v.category_id = section_identifier[2]
|
||||
v.section = section_identifier
|
||||
|
||||
return options_dict
|
||||
|
||||
@@ -82,7 +76,7 @@ class Options:
|
||||
|
||||
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
|
||||
self.data_labels = data_labels
|
||||
self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save}
|
||||
self.data = {k: v.default for k, v in self.data_labels.items()}
|
||||
self.restricted_opts = restricted_opts
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
@@ -164,7 +158,7 @@ class Options:
|
||||
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
||||
|
||||
with open(filename, "w", encoding="utf8") as file:
|
||||
json.dump(self.data, file, indent=4, ensure_ascii=False)
|
||||
json.dump(self.data, file, indent=4)
|
||||
|
||||
def same_type(self, x, y):
|
||||
if x is None or y is None:
|
||||
@@ -212,59 +206,23 @@ class Options:
|
||||
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
||||
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
||||
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
||||
|
||||
item_categories = {}
|
||||
for item in self.data_labels.values():
|
||||
category = categories.mapping.get(item.category_id)
|
||||
category = "Uncategorized" if category is None else category.label
|
||||
if category not in item_categories:
|
||||
item_categories[category] = item.section[1]
|
||||
|
||||
# _categories is a list of pairs: [section, category]. Each section (a setting page) will get a special heading above it with the category as text.
|
||||
d["_categories"] = [[v, k] for k, v in item_categories.items()] + [["Defaults", "Other"]]
|
||||
|
||||
return json.dumps(d)
|
||||
|
||||
def add_option(self, key, info):
|
||||
self.data_labels[key] = info
|
||||
if key not in self.data and not info.do_not_save:
|
||||
if key not in self.data:
|
||||
self.data[key] = info.default
|
||||
|
||||
def reorder(self):
|
||||
"""Reorder settings so that:
|
||||
- all items related to section always go together
|
||||
- all sections belonging to a category go together
|
||||
- sections inside a category are ordered alphabetically
|
||||
- categories are ordered by creation order
|
||||
|
||||
Category is a superset of sections: for category "postprocessing" there could be multiple sections: "face restoration", "upscaling".
|
||||
|
||||
This function also changes items' category_id so that all items belonging to a section have the same category_id.
|
||||
"""
|
||||
|
||||
category_ids = {}
|
||||
section_categories = {}
|
||||
"""reorder settings so that all items related to section always go together"""
|
||||
|
||||
section_ids = {}
|
||||
settings_items = self.data_labels.items()
|
||||
for _, item in settings_items:
|
||||
if item.section not in section_categories:
|
||||
section_categories[item.section] = item.category_id
|
||||
if item.section not in section_ids:
|
||||
section_ids[item.section] = len(section_ids)
|
||||
|
||||
for _, item in settings_items:
|
||||
item.category_id = section_categories.get(item.section)
|
||||
|
||||
for category_id in categories.mapping:
|
||||
if category_id not in category_ids:
|
||||
category_ids[category_id] = len(category_ids)
|
||||
|
||||
def sort_key(x):
|
||||
item: OptionInfo = x[1]
|
||||
category_order = category_ids.get(item.category_id, len(category_ids))
|
||||
section_order = item.section[1]
|
||||
|
||||
return category_order, section_order
|
||||
|
||||
self.data_labels = dict(sorted(settings_items, key=sort_key))
|
||||
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
||||
|
||||
def cast_value(self, key, value):
|
||||
"""casts an arbitrary to the same type as this setting's value with key
|
||||
@@ -287,22 +245,3 @@ class Options:
|
||||
value = expected_type(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptionsCategory:
|
||||
id: str
|
||||
label: str
|
||||
|
||||
class OptionsCategories:
|
||||
def __init__(self):
|
||||
self.mapping = {}
|
||||
|
||||
def register_category(self, category_id, label):
|
||||
if category_id in self.mapping:
|
||||
return category_id
|
||||
|
||||
self.mapping[category_id] = OptionsCategory(category_id, label)
|
||||
|
||||
|
||||
categories = OptionsCategories()
|
||||
|
||||
@@ -78,7 +78,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
image_data.close()
|
||||
|
||||
devices.torch_gc()
|
||||
shared.state.end()
|
||||
|
||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||
|
||||
|
||||
|
||||
@@ -296,7 +296,7 @@ class StableDiffusionProcessing:
|
||||
return conditioning
|
||||
|
||||
def edit_image_conditioning(self, source_image):
|
||||
conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
|
||||
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
|
||||
|
||||
return conditioning_image
|
||||
|
||||
@@ -799,6 +799,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
|
||||
with torch.no_grad(), p.sd_model.ema_scope():
|
||||
with devices.autocast():
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
@@ -872,6 +873,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
else:
|
||||
if opts.sd_vae_decode_method != 'Full':
|
||||
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
||||
|
||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||
|
||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||
@@ -1145,7 +1147,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
if not self.enable_hr:
|
||||
return samples
|
||||
devices.torch_gc()
|
||||
|
||||
if self.latent_scale_mode is None:
|
||||
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||
@@ -1155,6 +1156,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
with sd_models.SkipWritingToConfig():
|
||||
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
||||
|
||||
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
||||
@@ -1162,6 +1165,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
return samples
|
||||
|
||||
self.is_hr_pass = True
|
||||
|
||||
target_width = self.hr_upscale_to_x
|
||||
target_height = self.hr_upscale_to_y
|
||||
|
||||
@@ -1250,6 +1254,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||
|
||||
self.is_hr_pass = False
|
||||
|
||||
return decoded_samples
|
||||
|
||||
def close(self):
|
||||
|
||||
+1
-1
@@ -110,7 +110,7 @@ class ImageRNG:
|
||||
self.is_first = True
|
||||
|
||||
def first(self):
|
||||
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))
|
||||
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
|
||||
|
||||
xs = []
|
||||
|
||||
|
||||
+16
-103
@@ -311,113 +311,20 @@ scripts_data = []
|
||||
postprocessing_scripts_data = []
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||
|
||||
def topological_sort(dependencies):
|
||||
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
|
||||
Ignores errors relating to missing dependeencies or circular dependencies
|
||||
"""
|
||||
|
||||
visited = {}
|
||||
result = []
|
||||
|
||||
def inner(name):
|
||||
visited[name] = True
|
||||
|
||||
for dep in dependencies.get(name, []):
|
||||
if dep in dependencies and dep not in visited:
|
||||
inner(dep)
|
||||
|
||||
result.append(name)
|
||||
|
||||
for depname in dependencies:
|
||||
if depname not in visited:
|
||||
inner(depname)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptWithDependencies:
|
||||
script_canonical_name: str
|
||||
file: ScriptFile
|
||||
requires: list
|
||||
load_before: list
|
||||
load_after: list
|
||||
|
||||
|
||||
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
||||
scripts = {}
|
||||
scripts_list = []
|
||||
|
||||
loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}
|
||||
loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}
|
||||
|
||||
# build script dependency map
|
||||
root_script_basedir = os.path.join(paths.script_path, scriptdirname)
|
||||
if os.path.exists(root_script_basedir):
|
||||
for filename in sorted(os.listdir(root_script_basedir)):
|
||||
if not os.path.isfile(os.path.join(root_script_basedir, filename)):
|
||||
continue
|
||||
|
||||
if os.path.splitext(filename)[1].lower() != extension:
|
||||
continue
|
||||
|
||||
script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
|
||||
scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])
|
||||
basedir = os.path.join(paths.script_path, scriptdirname)
|
||||
if os.path.exists(basedir):
|
||||
for filename in sorted(os.listdir(basedir)):
|
||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
||||
|
||||
if include_extensions:
|
||||
for ext in extensions.active():
|
||||
extension_scripts_list = ext.list_files(scriptdirname, extension)
|
||||
for extension_script in extension_scripts_list:
|
||||
if not os.path.isfile(extension_script.path):
|
||||
continue
|
||||
scripts_list += ext.list_files(scriptdirname, extension)
|
||||
|
||||
script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
|
||||
relative_path = scriptdirname + "/" + extension_script.filename
|
||||
|
||||
script = ScriptWithDependencies(
|
||||
script_canonical_name=script_canonical_name,
|
||||
file=extension_script,
|
||||
requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
|
||||
load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname),
|
||||
load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname),
|
||||
)
|
||||
|
||||
scripts[script_canonical_name] = script
|
||||
loaded_extensions_scripts[ext.canonical_name].append(script)
|
||||
|
||||
for script_canonical_name, script in scripts.items():
|
||||
# load before requires inverse dependency
|
||||
# in this case, append the script name into the load_after list of the specified script
|
||||
for load_before in script.load_before:
|
||||
# if this requires an individual script to be loaded before
|
||||
other_script = scripts.get(load_before)
|
||||
if other_script:
|
||||
other_script.load_after.append(script_canonical_name)
|
||||
|
||||
# if this requires an extension
|
||||
other_extension_scripts = loaded_extensions_scripts.get(load_before)
|
||||
if other_extension_scripts:
|
||||
for other_script in other_extension_scripts:
|
||||
other_script.load_after.append(script_canonical_name)
|
||||
|
||||
# if After mentions an extension, remove it and instead add all of its scripts
|
||||
for load_after in list(script.load_after):
|
||||
if load_after not in scripts and load_after in loaded_extensions_scripts:
|
||||
script.load_after.remove(load_after)
|
||||
|
||||
for other_script in loaded_extensions_scripts.get(load_after, []):
|
||||
script.load_after.append(other_script.script_canonical_name)
|
||||
|
||||
dependencies = {}
|
||||
|
||||
for script_canonical_name, script in scripts.items():
|
||||
for required_script in script.requires:
|
||||
if required_script not in scripts and required_script not in loaded_extensions:
|
||||
errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False)
|
||||
|
||||
dependencies[script_canonical_name] = script.load_after
|
||||
|
||||
ordered_scripts = topological_sort(dependencies)
|
||||
scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]
|
||||
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
return scripts_list
|
||||
|
||||
@@ -458,9 +365,15 @@ def load_scripts():
|
||||
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
|
||||
# here the scripts_list is already ordered
|
||||
# processing_script is not considered though
|
||||
for scriptfile in scripts_list:
|
||||
def orderby(basedir):
|
||||
# 1st webui, 2nd extensions-builtin, 3rd extensions
|
||||
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
|
||||
for key in priority:
|
||||
if basedir.startswith(key):
|
||||
return priority[key]
|
||||
return 9999
|
||||
|
||||
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
|
||||
try:
|
||||
if scriptfile.basedir != paths.script_path:
|
||||
sys.path = [scriptfile.basedir] + sys.path
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections import namedtuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, sd_vae_consistency, shared, sd_models
|
||||
from modules.shared import opts, state
|
||||
import k_diffusion.sampling
|
||||
|
||||
@@ -31,7 +31,7 @@ def setup_img2img_steps(p, steps=None):
|
||||
return steps, t_enc
|
||||
|
||||
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3, "Consistency Decoder": 4}
|
||||
|
||||
|
||||
def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
@@ -51,6 +51,13 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
elif approximation == 3:
|
||||
x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
|
||||
x_sample = x_sample * 2 - 1
|
||||
elif approximation == 4:
|
||||
with devices.autocast(), torch.no_grad():
|
||||
x_sample = sd_vae_consistency.decoder_model()(
|
||||
sample.detach().to(devices.device, devices.dtype)/0.18215,
|
||||
schedule=[float(i.strip()) for i in shared.opts.sd_vae_consistency_schedule.split(',')],
|
||||
)
|
||||
sd_vae_consistency.unload()
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
|
||||
@@ -60,7 +60,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
|
||||
while restart_times > 0:
|
||||
restart_times -= 1
|
||||
step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))
|
||||
step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
|
||||
|
||||
last_sigma = None
|
||||
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Consistency Decoder
|
||||
Improved decoding for stable diffusion vaes.
|
||||
|
||||
https://github.com/openai/consistencydecoder
|
||||
"""
|
||||
import os
|
||||
|
||||
from modules import devices, paths_internal, shared
|
||||
from consistencydecoder import ConsistencyDecoder
|
||||
|
||||
|
||||
sd_vae_consistency_models = None
|
||||
model_path = os.path.join(paths_internal.models_path, 'consistencydecoder')
|
||||
|
||||
|
||||
def decoder_model():
|
||||
global sd_vae_consistency_models
|
||||
if getattr(shared.sd_model, 'is_sdxl', False):
|
||||
raise NotImplementedError("SDXL is not supported for consistency decoder")
|
||||
if sd_vae_consistency_models is not None:
|
||||
sd_vae_consistency_models.ckpt.to(devices.device)
|
||||
return sd_vae_consistency_models
|
||||
|
||||
loaded_model = ConsistencyDecoder(devices.device, model_path)
|
||||
sd_vae_consistency_models = loaded_model
|
||||
return loaded_model
|
||||
|
||||
|
||||
def unload():
|
||||
global sd_vae_consistency_models
|
||||
if sd_vae_consistency_models is not None:
|
||||
devices.torch_gc()
|
||||
sd_vae_consistency_models.ckpt.to('cpu')
|
||||
+24
-34
@@ -3,7 +3,7 @@ import gradio as gr
|
||||
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
from modules.options import options_section, OptionInfo, OptionHTML, categories
|
||||
from modules.options import options_section, OptionInfo, OptionHTML
|
||||
|
||||
options_templates = {}
|
||||
hide_dirs = shared.hide_dirs
|
||||
@@ -21,14 +21,7 @@ restricted_opts = {
|
||||
"outdir_init_images"
|
||||
}
|
||||
|
||||
categories.register_category("saving", "Saving images")
|
||||
categories.register_category("sd", "Stable Diffusion")
|
||||
categories.register_category("ui", "User Interface")
|
||||
categories.register_category("system", "System")
|
||||
categories.register_category("postprocessing", "Postprocessing")
|
||||
categories.register_category("training", "Training")
|
||||
|
||||
options_templates.update(options_section(('saving-images', "Saving images/grids", "saving"), {
|
||||
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
||||
"samples_save": OptionInfo(True, "Always save all generated images"),
|
||||
"samples_format": OptionInfo('png', 'File format for images'),
|
||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||
@@ -71,10 +64,9 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
|
||||
|
||||
"notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(),
|
||||
"notification_volume": OptionInfo(100, "Notification sound volume", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("in %"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), {
|
||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
||||
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
||||
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
||||
@@ -86,7 +78,7 @@ options_templates.update(options_section(('saving-paths', "Paths for saving", "s
|
||||
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), {
|
||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||
@@ -94,21 +86,21 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('upscaling', "Upscaling", "postprocessing"), {
|
||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), {
|
||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||
"face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"),
|
||||
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
|
||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('system', "System", "system"), {
|
||||
options_templates.update(options_section(('system', "System"), {
|
||||
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
|
||||
"enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."),
|
||||
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
||||
@@ -123,13 +115,13 @@ options_templates.update(options_section(('system', "System", "system"), {
|
||||
"dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('API', "API", "system"), {
|
||||
options_templates.update(options_section(('API', "API"), {
|
||||
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
|
||||
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
|
||||
"api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training", "training"), {
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||
@@ -144,7 +136,7 @@ options_templates.update(options_section(('training', "Training", "training"), {
|
||||
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
|
||||
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
||||
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
||||
@@ -161,14 +153,14 @@ options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
|
||||
"hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), {
|
||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('vae', "VAE", "sd"), {
|
||||
options_templates.update(options_section(('vae', "VAE"), {
|
||||
"sd_vae_explanation": OptionHTML("""
|
||||
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
|
||||
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
|
||||
@@ -180,10 +172,11 @@ For img2img, VAE is used to process user's input image before the sampling, and
|
||||
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
|
||||
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
|
||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD", "Consistency Decoder"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
||||
"sd_vae_consistency_schedule": OptionInfo("1.0, 0.5", "consistency schedule").info("sampling schedule for consistency decoder."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('img2img', "img2img", "sd"), {
|
||||
options_templates.update(options_section(('img2img', "img2img"), {
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
|
||||
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
|
||||
@@ -196,10 +189,9 @@ options_templates.update(options_section(('img2img', "img2img", "sd"), {
|
||||
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
|
||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||
"img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
|
||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||
@@ -210,7 +202,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
|
||||
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||
@@ -235,7 +227,7 @@ options_templates.update(options_section(('interrogate', "Interrogate"), {
|
||||
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks", "sd"), {
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
||||
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
||||
@@ -243,7 +235,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
|
||||
"extra_networks_card_order_field": OptionInfo("Name", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
|
||||
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||
@@ -252,7 +244,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface", "ui"), {
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
|
||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
|
||||
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
|
||||
@@ -281,13 +273,11 @@ options_templates.update(options_section(('ui', "User interface", "ui"), {
|
||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||
"txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
|
||||
"img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
|
||||
"compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(),
|
||||
}))
|
||||
|
||||
|
||||
options_templates.update(options_section(('infotext', "Infotext", "ui"), {
|
||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
||||
@@ -302,7 +292,7 @@ options_templates.update(options_section(('infotext', "Infotext", "ui"), {
|
||||
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "Live previews", "ui"), {
|
||||
options_templates.update(options_section(('ui', "Live previews"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
||||
@@ -315,7 +305,7 @@ options_templates.update(options_section(('ui', "Live previews", "ui"), {
|
||||
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
|
||||
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
|
||||
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
|
||||
@@ -337,7 +327,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
|
||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
|
||||
+17
-1
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import platform
|
||||
import hashlib
|
||||
@@ -83,7 +84,7 @@ def get_dict():
|
||||
"Checksum": checksum_token,
|
||||
"Commandline": get_argv(),
|
||||
"Torch env info": get_torch_sysinfo(),
|
||||
"Exceptions": errors.get_exceptions(),
|
||||
"Exceptions": get_exceptions(),
|
||||
"CPU": {
|
||||
"model": platform.processor(),
|
||||
"count logical": psutil.cpu_count(logical=True),
|
||||
@@ -103,6 +104,21 @@ def get_dict():
|
||||
return res
|
||||
|
||||
|
||||
def format_traceback(tb):
|
||||
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
|
||||
|
||||
|
||||
def format_exception(e, tb):
|
||||
return {"exception": str(e), "traceback": format_traceback(tb)}
|
||||
|
||||
|
||||
def get_exceptions():
|
||||
try:
|
||||
return list(reversed(errors.exception_records))
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
|
||||
|
||||
def get_environment():
|
||||
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
|
||||
|
||||
|
||||
+11
-20
@@ -4,7 +4,6 @@ import os
|
||||
import sys
|
||||
from functools import reduce
|
||||
import warnings
|
||||
from contextlib import ExitStack
|
||||
|
||||
import gradio as gr
|
||||
import gradio.utils
|
||||
@@ -32,8 +31,8 @@ from modules.generation_parameters_copypaste import image_from_url_text
|
||||
|
||||
create_setting_component = ui_settings.create_setting_component
|
||||
|
||||
# warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
||||
# warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
|
||||
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
||||
warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||
mimetypes.init()
|
||||
@@ -271,11 +270,7 @@ def create_ui():
|
||||
extra_tabs.__enter__()
|
||||
|
||||
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||
with ExitStack() as stack:
|
||||
if shared.opts.txt2img_settings_accordion:
|
||||
stack.enter_context(gr.Accordion("Open for Settings", open=False))
|
||||
stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings"))
|
||||
|
||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
||||
scripts.scripts_txt2img.prepare_ui()
|
||||
|
||||
for category in ordered_ui_categories():
|
||||
@@ -494,11 +489,7 @@ def create_ui():
|
||||
extra_tabs.__enter__()
|
||||
|
||||
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||
with ExitStack() as stack:
|
||||
if shared.opts.img2img_settings_accordion:
|
||||
stack.enter_context(gr.Accordion("Open for Settings", open=False))
|
||||
stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings"))
|
||||
|
||||
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
||||
copy_image_buttons = []
|
||||
copy_image_destinations = {}
|
||||
|
||||
@@ -635,6 +626,12 @@ def create_ui():
|
||||
scale_by.release(**on_change_args)
|
||||
button_update_resize_to.click(**on_change_args)
|
||||
|
||||
# the code below is meant to update the resolution label after the image in the image selection UI has changed.
|
||||
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
|
||||
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
|
||||
for component in [init_img, sketch]:
|
||||
component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
|
||||
|
||||
tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
|
||||
tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
|
||||
|
||||
@@ -695,12 +692,6 @@ def create_ui():
|
||||
if category not in {"accordions"}:
|
||||
scripts.scripts_img2img.setup_ui_for_section(category)
|
||||
|
||||
# the code below is meant to update the resolution label after the image in the image selection UI has changed.
|
||||
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
|
||||
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
|
||||
for component in [init_img, sketch]:
|
||||
component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
|
||||
|
||||
def select_img2img_tab(tab):
|
||||
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
||||
|
||||
@@ -1308,7 +1299,7 @@ def setup_ui_api(app):
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
text = sysinfo.get()
|
||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
|
||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
|
||||
|
||||
return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ def save_config_state(name):
|
||||
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
|
||||
print(f"Saving backup of webui/extension state to {filename}.")
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
json.dump(current_config_state, f, indent=4, ensure_ascii=False)
|
||||
json.dump(current_config_state, f, indent=4)
|
||||
config_states.list_config_states()
|
||||
new_value = next(iter(config_states.all_config_states.keys()), "Current")
|
||||
new_choices = ["Current"] + list(config_states.all_config_states.keys())
|
||||
|
||||
@@ -279,7 +279,6 @@ class ExtraNetworksPage:
|
||||
"date_created": int(stat.st_ctime or 0),
|
||||
"date_modified": int(stat.st_mtime or 0),
|
||||
"name": pth.name.lower(),
|
||||
"path": str(pth.parent).lower(),
|
||||
}
|
||||
|
||||
def find_preview(self, path):
|
||||
@@ -370,9 +369,6 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab:
|
||||
with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]):
|
||||
pass
|
||||
|
||||
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
||||
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
||||
ui.pages.append(page_elem)
|
||||
@@ -386,7 +382,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
related_tabs.append(tab)
|
||||
|
||||
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
||||
dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
||||
dropdown_sort = gr.Dropdown(choices=['Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
||||
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
|
||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
||||
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
||||
@@ -403,7 +399,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
allow_prompt = "true" if page.allow_prompt else "false"
|
||||
allow_negative_prompt = "true" if page.allow_negative_prompt else "false"
|
||||
|
||||
jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
|
||||
jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
|
||||
|
||||
tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
|
||||
|
||||
|
||||
@@ -17,9 +17,6 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
|
||||
def create_item(self, name, index=None, enable_filter=True):
|
||||
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
|
||||
if checkpoint is None:
|
||||
return
|
||||
|
||||
path, ext = os.path.splitext(checkpoint.filename)
|
||||
return {
|
||||
"name": checkpoint.name_for_extra,
|
||||
@@ -35,12 +32,9 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
}
|
||||
|
||||
def list_items(self):
|
||||
# instantiate a list to protect against concurrent modification
|
||||
names = list(sd_models.checkpoints_list)
|
||||
for index, name in enumerate(names):
|
||||
item = self.create_item(name, index)
|
||||
if item is not None:
|
||||
yield item
|
||||
yield self.create_item(name, index)
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
||||
|
||||
@@ -13,10 +13,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
def create_item(self, name, index=None, enable_filter=True):
|
||||
full_path = shared.hypernetworks.get(name)
|
||||
if full_path is None:
|
||||
return
|
||||
|
||||
full_path = shared.hypernetworks[name]
|
||||
path, ext = os.path.splitext(full_path)
|
||||
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
|
||||
shorthash = sha256[0:10] if sha256 else None
|
||||
@@ -34,12 +31,8 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||
}
|
||||
|
||||
def list_items(self):
|
||||
# instantiate a list to protect against concurrent modification
|
||||
names = list(shared.hypernetworks)
|
||||
for index, name in enumerate(names):
|
||||
item = self.create_item(name, index)
|
||||
if item is not None:
|
||||
yield item
|
||||
for index, name in enumerate(shared.hypernetworks):
|
||||
yield self.create_item(name, index)
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return [shared.cmd_opts.hypernetwork_dir]
|
||||
|
||||
@@ -14,8 +14,6 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||
|
||||
def create_item(self, name, index=None, enable_filter=True):
|
||||
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
|
||||
if embedding is None:
|
||||
return
|
||||
|
||||
path, ext = os.path.splitext(embedding.filename)
|
||||
return {
|
||||
@@ -31,12 +29,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||
}
|
||||
|
||||
def list_items(self):
|
||||
# instantiate a list to protect against concurrent modification
|
||||
names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)
|
||||
for index, name in enumerate(names):
|
||||
item = self.create_item(name, index)
|
||||
if item is not None:
|
||||
yield item
|
||||
for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
|
||||
yield self.create_item(name, index)
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
||||
|
||||
@@ -134,7 +134,7 @@ class UserMetadataEditor:
|
||||
basename, ext = os.path.splitext(filename)
|
||||
|
||||
with open(basename + '.json', "w", encoding="utf8") as file:
|
||||
json.dump(metadata, file, indent=4, ensure_ascii=False)
|
||||
json.dump(metadata, file, indent=4)
|
||||
|
||||
def save_user_metadata(self, name, desc, notes):
|
||||
user_metadata = self.get_user_metadata(name)
|
||||
|
||||
@@ -141,7 +141,7 @@ class UiLoadsave:
|
||||
|
||||
def write_to_file(self, current_ui_settings):
|
||||
with open(self.filename, "w", encoding="utf8") as file:
|
||||
json.dump(current_ui_settings, file, indent=4, ensure_ascii=False)
|
||||
json.dump(current_ui_settings, file, indent=4)
|
||||
|
||||
def dump_defaults(self):
|
||||
"""saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
|
||||
|
||||
@@ -68,10 +68,10 @@ class UiPromptStyles:
|
||||
self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
|
||||
|
||||
with gr.Row():
|
||||
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"])
|
||||
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
|
||||
|
||||
with gr.Row():
|
||||
self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3, elem_classes=["prompt"])
|
||||
self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
|
||||
|
||||
with gr.Row():
|
||||
self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
|
||||
|
||||
@@ -61,7 +61,7 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):
|
||||
|
||||
def install_ui_tempdir_override():
|
||||
"""override save to file function so that it also writes PNG info"""
|
||||
# gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
|
||||
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
|
||||
|
||||
|
||||
def on_tmpdir_changed():
|
||||
|
||||
@@ -16,7 +16,6 @@ exclude = [
|
||||
|
||||
ignore = [
|
||||
"E501", # Line too long
|
||||
"E721", # Do not compare types, use `isinstance`
|
||||
"E731", # Do not assign a `lambda` expression, use a `def`
|
||||
|
||||
"I001", # Import block is un-sorted or un-formatted
|
||||
|
||||
+3
-1
@@ -8,7 +8,7 @@ clean-fid
|
||||
einops
|
||||
fastapi>=0.90.1
|
||||
gfpgan
|
||||
gradio==4.7.1
|
||||
gradio==3.41.2
|
||||
inflection
|
||||
jsonmerge
|
||||
kornia
|
||||
@@ -32,3 +32,5 @@ torch
|
||||
torchdiffeq
|
||||
torchsde
|
||||
transformers==4.30.2
|
||||
|
||||
git+https://github.com/openai/consistencydecoder.git
|
||||
|
||||
@@ -5,9 +5,9 @@ basicsr==1.4.2
|
||||
blendmodes==2022
|
||||
clean-fid==0.1.35
|
||||
einops==0.4.1
|
||||
fastapi==0.104.1
|
||||
fastapi==0.94.0
|
||||
gfpgan==1.3.8
|
||||
gradio==4.7.1
|
||||
gradio==3.41.2
|
||||
httpcore==0.15
|
||||
inflection==0.5.1
|
||||
jsonmerge==1.8.0
|
||||
@@ -30,3 +30,4 @@ torchdiffeq==0.2.3
|
||||
torchsde==0.2.6
|
||||
transformers==4.30.2
|
||||
httpx==0.24.1
|
||||
git+https://github.com/openai/consistencydecoder.git
|
||||
|
||||
@@ -133,18 +133,9 @@ document.addEventListener('keydown', function(e) {
|
||||
if (isEnter && isModifierKey) {
|
||||
if (interruptButton.style.display === 'block') {
|
||||
interruptButton.click();
|
||||
const callback = (mutationList) => {
|
||||
for (const mutation of mutationList) {
|
||||
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
|
||||
if (interruptButton.style.display === 'none') {
|
||||
generateButton.click();
|
||||
observer.disconnect();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
const observer = new MutationObserver(callback);
|
||||
observer.observe(interruptButton, {attributes: true});
|
||||
setTimeout(function() {
|
||||
generateButton.click();
|
||||
}, 500);
|
||||
} else {
|
||||
generateButton.click();
|
||||
}
|
||||
|
||||
@@ -462,15 +462,6 @@ div.toprow-compact-tools{
|
||||
padding: 4px;
|
||||
}
|
||||
|
||||
#settings > div.tab-nav .settings-category{
|
||||
display: block;
|
||||
margin: 1em 0 0.25em 0;
|
||||
font-weight: bold;
|
||||
text-decoration: underline;
|
||||
cursor: default;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
#settings_result{
|
||||
height: 1.4em;
|
||||
margin: 0 1.2em;
|
||||
@@ -849,16 +840,8 @@ footer {
|
||||
|
||||
/* extra networks UI */
|
||||
|
||||
.extra-page > div.gap{
|
||||
gap: 0;
|
||||
}
|
||||
|
||||
.extra-page-prompts{
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.extra-page-prompts.extra-page-prompts-active{
|
||||
margin-bottom: 1em;
|
||||
.extra-page .prompt{
|
||||
margin: 0 0 0.5em 0;
|
||||
}
|
||||
|
||||
.extra-network-cards{
|
||||
|
||||
@@ -89,7 +89,7 @@ delimiter="################################################################"
|
||||
|
||||
printf "\n%s\n" "${delimiter}"
|
||||
printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n"
|
||||
printf "\e[1m\e[34mTested on Debian 11 (Bullseye), Fedora 34+ and openSUSE Leap 15.4 or newer.\e[0m"
|
||||
printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
|
||||
printf "\n%s\n" "${delimiter}"
|
||||
|
||||
# Do not run as root
|
||||
@@ -223,7 +223,7 @@ fi
|
||||
# Try using TCMalloc on Linux
|
||||
prepare_tcmalloc() {
|
||||
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
|
||||
TCMALLOC="$(PATH=/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
|
||||
TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
|
||||
if [[ ! -z "${TCMALLOC}" ]]; then
|
||||
echo "Using TCMalloc: ${TCMALLOC}"
|
||||
export LD_PRELOAD="${TCMALLOC}"
|
||||
|
||||
Reference in New Issue
Block a user