Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6aa7925d36 | |||
| a57adde5ae | |||
| e72a6c411a |
+23
-28
@@ -16,7 +16,7 @@ from skimage import exposure
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling, util
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
|
||||||
from modules.rng import slerp # noqa: F401
|
from modules.rng import slerp # noqa: F401
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
||||||
@@ -187,6 +187,7 @@ class StableDiffusionProcessing:
|
|||||||
|
|
||||||
cached_uc = [None, None]
|
cached_uc = [None, None]
|
||||||
cached_c = [None, None]
|
cached_c = [None, None]
|
||||||
|
hijack_generation_params_state_list = []
|
||||||
|
|
||||||
comments: dict = None
|
comments: dict = None
|
||||||
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
|
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
|
||||||
@@ -457,20 +458,6 @@ class StableDiffusionProcessing:
|
|||||||
opts.emphasis,
|
opts.emphasis,
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_generation_params_list(self, generation_params_states):
|
|
||||||
"""add and apply generation_params_states to self.extra_generation_params"""
|
|
||||||
for key, value in generation_params_states.items():
|
|
||||||
if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
|
|
||||||
self.extra_generation_params[key] = current_value + value
|
|
||||||
else:
|
|
||||||
self.extra_generation_params[key] = value
|
|
||||||
|
|
||||||
def clear_marked_generation_params(self):
|
|
||||||
"""clears any generation parameters that are with the attribute to_be_clear_before_batch = True"""
|
|
||||||
for key, value in list(self.extra_generation_params.items()):
|
|
||||||
if getattr(value, 'to_be_clear_before_batch', False):
|
|
||||||
self.extra_generation_params.pop(key)
|
|
||||||
|
|
||||||
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
||||||
"""
|
"""
|
||||||
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
||||||
@@ -495,9 +482,9 @@ class StableDiffusionProcessing:
|
|||||||
for cache in caches:
|
for cache in caches:
|
||||||
if cache[0] is not None and cached_params == cache[0]:
|
if cache[0] is not None and cached_params == cache[0]:
|
||||||
if len(cache) == 3:
|
if len(cache) == 3:
|
||||||
generation_params_states, cached_cached_params = cache[2]
|
generation_params_state, cached_params_2 = cache[2]
|
||||||
if cached_params == cached_cached_params:
|
if cached_params == cached_params_2:
|
||||||
self.apply_generation_params_list(generation_params_states)
|
self.hijack_generation_params_state_list.extend(generation_params_state)
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
cache = caches[0]
|
cache = caches[0]
|
||||||
@@ -505,16 +492,25 @@ class StableDiffusionProcessing:
|
|||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
|
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
|
||||||
|
|
||||||
generation_params_states = model_hijack.extract_generation_params_states()
|
generation_params_state = model_hijack.capture_generation_params_state()
|
||||||
self.apply_generation_params_list(generation_params_states)
|
self.hijack_generation_params_state_list.extend(generation_params_state)
|
||||||
if len(cache) == 2:
|
if len(cache) == 2:
|
||||||
cache.append((generation_params_states, cached_params))
|
cache.append((generation_params_state, cached_params))
|
||||||
else:
|
else:
|
||||||
cache[2] = (generation_params_states, cached_params)
|
cache[2] = (generation_params_state, cached_params)
|
||||||
|
|
||||||
cache[0] = cached_params
|
cache[0] = cached_params
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
|
def apply_hijack_generation_params(self):
|
||||||
|
self.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||||
|
for func in self.hijack_generation_params_state_list:
|
||||||
|
try:
|
||||||
|
func(self.extra_generation_params)
|
||||||
|
except Exception:
|
||||||
|
errors.report('Failed to apply hijack generation params state', exc_info=True)
|
||||||
|
self.hijack_generation_params_state_list.clear()
|
||||||
|
|
||||||
def setup_conds(self):
|
def setup_conds(self):
|
||||||
prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
|
prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
|
||||||
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
|
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
|
||||||
@@ -527,7 +523,7 @@ class StableDiffusionProcessing:
|
|||||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
|
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
|
||||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
|
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
|
||||||
|
|
||||||
self.extra_generation_params.update(model_hijack.extra_generation_params)
|
self.apply_hijack_generation_params()
|
||||||
|
|
||||||
def get_conds(self):
|
def get_conds(self):
|
||||||
return self.c, self.uc
|
return self.c, self.uc
|
||||||
@@ -828,10 +824,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
|
|
||||||
for key, value in generation_params.items():
|
for key, value in generation_params.items():
|
||||||
try:
|
try:
|
||||||
if callable(value):
|
if isinstance(value, list):
|
||||||
generation_params[key] = value(**locals())
|
|
||||||
elif isinstance(value, list):
|
|
||||||
generation_params[key] = value[index]
|
generation_params[key] = value[index]
|
||||||
|
elif callable(value):
|
||||||
|
generation_params[key] = value(**locals())
|
||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
|
errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
|
||||||
generation_params[key] = None
|
generation_params[key] = None
|
||||||
@@ -965,7 +961,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if state.interrupted or state.stopping_generation:
|
if state.interrupted or state.stopping_generation:
|
||||||
break
|
break
|
||||||
|
|
||||||
p.clear_marked_generation_params() # clean up some generation params are tagged to be cleared before batch
|
|
||||||
sd_models.reload_model_weights() # model can be changed for example by refiner
|
sd_models.reload_model_weights() # model can be changed for example by refiner
|
||||||
|
|
||||||
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
@@ -1539,7 +1534,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
|
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
|
||||||
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
|
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
|
||||||
|
|
||||||
self.extra_generation_params.update(model_hijack.extra_generation_params)
|
self.apply_hijack_generation_params()
|
||||||
|
|
||||||
def setup_conds(self):
|
def setup_conds(self):
|
||||||
if self.is_hr_pass:
|
if self.is_hr_pass:
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ import torch
|
|||||||
from torch.nn.functional import silu
|
from torch.nn.functional import silu
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches, util
|
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
||||||
|
from modules.util import GenerationParamsState
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
@@ -321,13 +322,12 @@ class StableDiffusionModelHijack:
|
|||||||
self.comments = []
|
self.comments = []
|
||||||
self.extra_generation_params = {}
|
self.extra_generation_params = {}
|
||||||
|
|
||||||
def extract_generation_params_states(self):
|
def capture_generation_params_state(self):
|
||||||
"""Extracts GenerationParametersList so that they can be cached and restored later"""
|
state = []
|
||||||
states = {}
|
|
||||||
for key in list(self.extra_generation_params):
|
for key in list(self.extra_generation_params):
|
||||||
if isinstance(self.extra_generation_params[key], util.GenerationParametersList):
|
if isinstance(self.extra_generation_params[key], GenerationParamsState):
|
||||||
states[key] = self.extra_generation_params.pop(key)
|
state.append(self.extra_generation_params.pop(key))
|
||||||
return states
|
return state
|
||||||
|
|
||||||
def get_prompt_lengths(self, text):
|
def get_prompt_lengths(self, text):
|
||||||
if self.clip is None:
|
if self.clip is None:
|
||||||
|
|||||||
+21
-20
@@ -5,6 +5,7 @@ import torch
|
|||||||
|
|
||||||
from modules import prompt_parser, devices, sd_hijack, sd_emphasis, util
|
from modules import prompt_parser, devices, sd_hijack, sd_emphasis, util
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
from modules.util import GenerationParamsState
|
||||||
|
|
||||||
|
|
||||||
class PromptChunk:
|
class PromptChunk:
|
||||||
@@ -27,28 +28,29 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
|
|||||||
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
||||||
|
|
||||||
|
|
||||||
class EmphasisMode(util.GenerationParametersList):
|
class EmbeddingHashes(GenerationParamsState):
|
||||||
def __init__(self, emphasis_mode:str = None):
|
def __init__(self, hashes: list):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.emphasis_mode = emphasis_mode
|
self.hashes = hashes
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, extra_generation_params):
|
||||||
return self.emphasis_mode
|
unique_hashes = dict.fromkeys(self.hashes)
|
||||||
|
if existing_ti_hashes := extra_generation_params.get('TI hashes'):
|
||||||
|
unique_hashes.update(dict.fromkeys(existing_ti_hashes.split(', ')))
|
||||||
|
extra_generation_params['TI hashes'] = ', '.join(sorted(unique_hashes, key=util.natural_sort_key))
|
||||||
|
|
||||||
def __add__(self, other):
|
|
||||||
if isinstance(other, EmphasisMode):
|
|
||||||
return self if self.emphasis_mode else other
|
|
||||||
elif isinstance(other, str):
|
|
||||||
return self.__str__() + other
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def __radd__(self, other):
|
class EmphasisMode(GenerationParamsState):
|
||||||
if isinstance(other, str):
|
def __init__(self, texts):
|
||||||
return other + self.__str__()
|
super().__init__()
|
||||||
return NotImplemented
|
if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x):
|
||||||
|
self.emphasis = opts.emphasis
|
||||||
|
else:
|
||||||
|
self.emphasis = None
|
||||||
|
|
||||||
def __str__(self):
|
def __call__(self, extra_generation_params):
|
||||||
return self.emphasis_mode if self.emphasis_mode else ''
|
if self.emphasis:
|
||||||
|
extra_generation_params['Emphasis'] = self.emphasis
|
||||||
|
|
||||||
|
|
||||||
class TextConditionalModel(torch.nn.Module):
|
class TextConditionalModel(torch.nn.Module):
|
||||||
@@ -262,10 +264,9 @@ class TextConditionalModel(torch.nn.Module):
|
|||||||
hashes.append(f"{name}: {shorthash}")
|
hashes.append(f"{name}: {shorthash}")
|
||||||
|
|
||||||
if hashes:
|
if hashes:
|
||||||
self.hijack.extra_generation_params["TI hashes"] = util.GenerationParametersList(hashes)
|
self.hijack.extra_generation_params["TI hashes"] = EmbeddingHashes(hashes)
|
||||||
|
|
||||||
if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x):
|
self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(texts)
|
||||||
self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(opts.emphasis)
|
|
||||||
|
|
||||||
if self.return_pooled:
|
if self.return_pooled:
|
||||||
return torch.hstack(zs), zs[0].pooled
|
return torch.hstack(zs), zs[0].pooled
|
||||||
|
|||||||
+9
-40
@@ -290,47 +290,16 @@ def compare_sha256(file_path: str, hash_prefix: str) -> bool:
|
|||||||
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
|
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
|
||||||
|
|
||||||
|
|
||||||
class GenerationParametersList(list):
|
class GenerationParamsState:
|
||||||
"""A special object used in sd_hijack.StableDiffusionModelHijack for setting extra_generation_params
|
"""A custom class used in StableDiffusionModelHijack for assigning extra_generation_params
|
||||||
due to StableDiffusionProcessing.get_conds_with_caching
|
generation_params assigned using this class will work properly with StableDiffusionProcessing.get_conds_with_caching()
|
||||||
extra_generation_params set in StableDiffusionModelHijack will be lost when cached is used
|
if assigned directly the generation_params will not be populated if conda cache is used
|
||||||
|
|
||||||
When an extra_generation_params is set in StableDiffusionModelHijack using this object,
|
Generation_params of this class will be captured (see StableDiffusionModelHijack.capture_generation_params_state) and stored with conda cache, and will be extracted in StableDiffusionProcessing.apply_hijack_generation_params()
|
||||||
the params will be extracted by StableDiffusionModelHijack.extract_generation_params_states
|
|
||||||
the extracted params will be cached in StableDiffusionProcessing.get_conds_with_caching
|
|
||||||
and applyed to StableDiffusionProcessing.extra_generation_params by StableDiffusionProcessing.apply_generation_params_states
|
|
||||||
|
|
||||||
Example see modules.sd_hijack_clip.TextConditionalModel.hijack.extra_generation_params 'TI hashes' 'Emphasis'
|
To use this class, create a subclass with a __call__ method that takes extra_generation_params: dict as input
|
||||||
|
|
||||||
Depending on the use case the methods can be overwritten.
|
Example usage: sd_hijack_clip.EmbeddingHashes, sd_hijack_clip.EmphasisMode
|
||||||
In general __call__ method should return str or None, as normally it's called in modules.processing.create_infotext.
|
|
||||||
When called by create_infotext it will access to the locals() of the caller,
|
|
||||||
if return str, the value will be written to infotext, if return None will be ignored.
|
|
||||||
"""
|
"""
|
||||||
|
def __call__(self, extra_generation_params: dict):
|
||||||
def __init__(self, *args, to_be_clear_before_batch=True, **kwargs):
|
raise NotImplementedError
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._to_be_clear_before_batch = to_be_clear_before_batch
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
return ', '.join(sorted(set(self), key=natural_sort_key))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def to_be_clear_before_batch(self):
|
|
||||||
return self._to_be_clear_before_batch
|
|
||||||
|
|
||||||
def __add__(self, other):
|
|
||||||
if isinstance(other, GenerationParametersList):
|
|
||||||
return self.__class__([*self, *other])
|
|
||||||
elif isinstance(other, str):
|
|
||||||
return self.__str__() + other
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def __radd__(self, other):
|
|
||||||
if isinstance(other, str):
|
|
||||||
return other + self.__str__()
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.__call__()
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user