Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cf2772fab0 | |||
| 0dfffe53ec | |||
| 2be85f8fe0 | |||
| eb52c803b8 | |||
| f8871dedcf | |||
| b7e0d4a7e1 | |||
| 5cb1ce470d | |||
| 888b928f0d | |||
| b55f09c4e1 | |||
| c7cd9b441d | |||
| 6ef0ff39f2 | |||
| 120a84bd2f | |||
| 368d66c9cc | |||
| 81105ee013 | |||
| 24dae9bc4c |
@@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
self.lin_module = None
|
||||
self.org_module: list[torch.Module] = [self.sd_module]
|
||||
|
||||
self.scale = 1.0
|
||||
|
||||
# kohya-ss
|
||||
if "oft_blocks" in weights.w.keys():
|
||||
self.is_kohya = True
|
||||
@@ -53,12 +55,18 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
self.constraint = None
|
||||
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||
|
||||
def calc_updown_kb(self, orig_weight, multiplier):
|
||||
def calc_updown(self, orig_weight):
|
||||
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
||||
eye = torch.eye(self.block_size, device=self.oft_blocks.device)
|
||||
|
||||
if self.is_kohya:
|
||||
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
||||
norm_Q = torch.norm(block_Q.flatten())
|
||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
|
||||
|
||||
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)
|
||||
|
||||
# This errors out for MultiheadAttention, might need to be handled up-stream
|
||||
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||
@@ -72,26 +80,3 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||
output_shape = orig_weight.shape
|
||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
# if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it
|
||||
multiplier = self.multiplier()
|
||||
return self.calc_updown_kb(orig_weight, multiplier)
|
||||
|
||||
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
|
||||
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||
if self.bias is not None:
|
||||
updown = updown.reshape(self.bias.shape)
|
||||
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
updown = updown.reshape(output_shape)
|
||||
|
||||
if len(output_shape) == 4:
|
||||
updown = updown.reshape(output_shape)
|
||||
|
||||
if orig_weight.size().numel() == updown.size().numel():
|
||||
updown = updown.reshape(orig_weight.shape)
|
||||
|
||||
if ex_bias is not None:
|
||||
ex_bias = ex_bias * self.multiplier()
|
||||
|
||||
return updown, ex_bias
|
||||
|
||||
@@ -159,7 +159,8 @@ def load_network(name, network_on_disk):
|
||||
bundle_embeddings = {}
|
||||
|
||||
for key_network, weight in sd.items():
|
||||
key_network_without_network_parts, network_part = key_network.split(".", 1)
|
||||
key_network_without_network_parts, _, network_part = key_network.partition(".")
|
||||
|
||||
if key_network_without_network_parts == "bundle_emb":
|
||||
emb_name, vec_name = network_part.split(".", 1)
|
||||
emb_dict = bundle_embeddings.get(emb_name, {})
|
||||
|
||||
@@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script):
|
||||
self.setting_names = []
|
||||
self.infotext_fields = []
|
||||
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
|
||||
elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
|
||||
|
||||
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
|
||||
|
||||
with gr.Blocks() as interface:
|
||||
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
|
||||
with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
|
||||
|
||||
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
|
||||
|
||||
@@ -70,7 +71,7 @@ This page allows you to add some settings to the main interface of txt2img and i
|
||||
"""),
|
||||
"extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
|
||||
"extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
|
||||
"extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(),
|
||||
"extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(),
|
||||
"extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
|
||||
}))
|
||||
|
||||
|
||||
@@ -17,11 +17,42 @@ class ScriptHypertile(scripts.Script):
|
||||
|
||||
configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
|
||||
|
||||
self.add_infotext(p)
|
||||
|
||||
def before_hr(self, p, *args):
|
||||
|
||||
enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet
|
||||
|
||||
# exclusive hypertile seed for the second pass
|
||||
if not shared.opts.hypertile_enable_unet:
|
||||
if enable:
|
||||
hypertile.set_hypertile_seed(p.all_seeds[0])
|
||||
configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass)
|
||||
|
||||
configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable)
|
||||
|
||||
if enable and not shared.opts.hypertile_enable_unet:
|
||||
p.extra_generation_params["Hypertile U-Net second pass"] = True
|
||||
|
||||
self.add_infotext(p, add_unet_params=True)
|
||||
|
||||
def add_infotext(self, p, add_unet_params=False):
|
||||
def option(name):
|
||||
value = getattr(shared.opts, name)
|
||||
default_value = shared.opts.get_default(name)
|
||||
return None if value == default_value else value
|
||||
|
||||
if shared.opts.hypertile_enable_unet:
|
||||
p.extra_generation_params["Hypertile U-Net"] = True
|
||||
|
||||
if shared.opts.hypertile_enable_unet or add_unet_params:
|
||||
p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet')
|
||||
p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet')
|
||||
p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet')
|
||||
|
||||
if shared.opts.hypertile_enable_vae:
|
||||
p.extra_generation_params["Hypertile VAE"] = True
|
||||
p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae')
|
||||
p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae')
|
||||
p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae')
|
||||
|
||||
|
||||
def configure_hypertile(width, height, enable_unet=True):
|
||||
@@ -57,16 +88,16 @@ def on_ui_settings():
|
||||
benefit.
|
||||
"""),
|
||||
|
||||
"hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
|
||||
"hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
|
||||
"hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
|
||||
"hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
|
||||
"hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"),
|
||||
"hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"),
|
||||
"hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"),
|
||||
"hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"),
|
||||
"hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"),
|
||||
|
||||
"hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
|
||||
"hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
|
||||
"hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
|
||||
"hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"),
|
||||
"hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"),
|
||||
"hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"),
|
||||
"hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"),
|
||||
}
|
||||
|
||||
for name, opt in options.items():
|
||||
|
||||
@@ -34,7 +34,7 @@ function updateOnBackgroundChange() {
|
||||
if (modalImage && modalImage.offsetParent) {
|
||||
let currentButton = selected_gallery_button();
|
||||
let preview = gradioApp().querySelectorAll('.livePreview > img');
|
||||
if (preview.length > 0) {
|
||||
if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {
|
||||
// show preview image if available
|
||||
modalImage.src = preview[preview.length - 1].src;
|
||||
} else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
||||
|
||||
@@ -215,9 +215,33 @@ function restoreProgressImg2img() {
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Configure the width and height elements on `tabname` to accept
|
||||
* pasting of resolutions in the form of "width x height".
|
||||
*/
|
||||
function setupResolutionPasting(tabname) {
|
||||
var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`);
|
||||
var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`);
|
||||
for (const el of [width, height]) {
|
||||
el.addEventListener('paste', function(event) {
|
||||
var pasteData = event.clipboardData.getData('text/plain');
|
||||
var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/);
|
||||
if (parsed) {
|
||||
width.value = parsed[1];
|
||||
height.value = parsed[2];
|
||||
updateInput(width);
|
||||
updateInput(height);
|
||||
event.preventDefault();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
onUiLoaded(function() {
|
||||
showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
|
||||
showRestoreProgressButton('img2img', localGet("img2img_task_id"));
|
||||
setupResolutionPasting('txt2img');
|
||||
setupResolutionPasting('img2img');
|
||||
});
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from modules import shared, images, devices, scripts, scripts_postprocessing, ui
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
||||
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
||||
devices.torch_gc()
|
||||
|
||||
shared.state.begin(job="extras")
|
||||
@@ -128,6 +128,10 @@ def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, out
|
||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||
|
||||
|
||||
def run_postprocessing_webui(id_task, *args, **kwargs):
|
||||
return run_postprocessing(*args, **kwargs)
|
||||
|
||||
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
||||
"""old handler for API"""
|
||||
|
||||
|
||||
@@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
||||
would be on the meta device.
|
||||
"""
|
||||
|
||||
if state_dict == sd:
|
||||
if state_dict is sd:
|
||||
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
||||
|
||||
original(module, state_dict, strict=strict)
|
||||
|
||||
@@ -256,6 +256,7 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing",
|
||||
"keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
|
||||
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
|
||||
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||
}))
|
||||
|
||||
@@ -330,6 +331,7 @@ options_templates.update(options_section(('ui', "Live previews", "ui"), {
|
||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
||||
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
|
||||
"js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
|
||||
|
||||
+10
-44
@@ -2,7 +2,6 @@ import csv
|
||||
import fnmatch
|
||||
import os
|
||||
import os.path
|
||||
import re
|
||||
import typing
|
||||
import shutil
|
||||
|
||||
@@ -14,22 +13,6 @@ class PromptStyle(typing.NamedTuple):
|
||||
path: str = None
|
||||
|
||||
|
||||
def clean_text(text: str) -> str:
|
||||
"""
|
||||
Iterating through a list of regular expressions and replacement strings, we
|
||||
clean up the prompt and style text to make it easier to match against each
|
||||
other.
|
||||
"""
|
||||
re_list = [
|
||||
("multiple commas", re.compile("(,+\s+)+,?"), ", "),
|
||||
("multiple spaces", re.compile("\s{2,}"), " "),
|
||||
]
|
||||
for _, regex, replace in re_list:
|
||||
text = regex.sub(replace, text)
|
||||
|
||||
return text.strip(", ")
|
||||
|
||||
|
||||
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
||||
if "{prompt}" in style_prompt:
|
||||
res = style_prompt.replace("{prompt}", prompt)
|
||||
@@ -44,7 +27,7 @@ def apply_styles_to_prompt(prompt, styles):
|
||||
for style in styles:
|
||||
prompt = merge_prompts(style, prompt)
|
||||
|
||||
return clean_text(prompt)
|
||||
return prompt
|
||||
|
||||
|
||||
def unwrap_style_text_from_prompt(style_text, prompt):
|
||||
@@ -56,8 +39,8 @@ def unwrap_style_text_from_prompt(style_text, prompt):
|
||||
Note that the "cleaned" version of the style text is only used for matching
|
||||
purposes here. It isn't returned; the original style text is not modified.
|
||||
"""
|
||||
stripped_prompt = clean_text(prompt)
|
||||
stripped_style_text = clean_text(style_text)
|
||||
stripped_prompt = prompt
|
||||
stripped_style_text = style_text
|
||||
if "{prompt}" in stripped_style_text:
|
||||
# Work out whether the prompt is wrapped in the style text. If so, we
|
||||
# return True and the "inner" prompt text that isn't part of the style.
|
||||
@@ -115,10 +98,8 @@ class StyleDatabase:
|
||||
self.path = path
|
||||
|
||||
folder, file = os.path.split(self.path)
|
||||
self.default_file = file.split("*")[0] + ".csv"
|
||||
if self.default_file == ".csv":
|
||||
self.default_file = "styles.csv"
|
||||
self.default_path = os.path.join(folder, self.default_file)
|
||||
filename, _, ext = file.partition('*')
|
||||
self.default_path = os.path.join(folder, filename + ext)
|
||||
|
||||
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
|
||||
|
||||
@@ -172,10 +153,8 @@ class StyleDatabase:
|
||||
row["name"], prompt, negative_prompt, path
|
||||
)
|
||||
|
||||
def get_style_paths(self) -> list():
|
||||
"""
|
||||
Returns a list of all distinct paths, including the default path, of
|
||||
files that styles are loaded from."""
|
||||
def get_style_paths(self) -> set:
|
||||
"""Returns a set of all distinct paths of files that styles are loaded from."""
|
||||
# Update any styles without a path to the default path
|
||||
for style in list(self.styles.values()):
|
||||
if not style.path:
|
||||
@@ -189,9 +168,9 @@ class StyleDatabase:
|
||||
style_paths.add(style.path)
|
||||
|
||||
# Remove any paths for styles that are just list dividers
|
||||
style_paths.remove("do_not_save")
|
||||
style_paths.discard("do_not_save")
|
||||
|
||||
return list(style_paths)
|
||||
return style_paths
|
||||
|
||||
def get_style_prompts(self, styles):
|
||||
return [self.styles.get(x, self.no_style).prompt for x in styles]
|
||||
@@ -213,20 +192,7 @@ class StyleDatabase:
|
||||
# The path argument is deprecated, but kept for backwards compatibility
|
||||
_ = path
|
||||
|
||||
# Update any styles without a path to the default path
|
||||
for style in list(self.styles.values()):
|
||||
if not style.path:
|
||||
self.styles[style.name] = style._replace(path=self.default_path)
|
||||
|
||||
# Create a list of all distinct paths, including the default path
|
||||
style_paths = set()
|
||||
style_paths.add(self.default_path)
|
||||
for _, style in self.styles.items():
|
||||
if style.path:
|
||||
style_paths.add(style.path)
|
||||
|
||||
# Remove any paths for styles that are just list dividers
|
||||
style_paths.remove("do_not_save")
|
||||
style_paths = self.get_style_paths()
|
||||
|
||||
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ def create_ui():
|
||||
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
|
||||
|
||||
submit.click(
|
||||
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
|
||||
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']),
|
||||
_js="submit_extras",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
|
||||
@@ -48,3 +48,12 @@ if has_xpu:
|
||||
CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
CondFunc('torch.bmm',
|
||||
lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
|
||||
lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
|
||||
CondFunc('torch.cat',
|
||||
lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
|
||||
lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
|
||||
CondFunc('torch.nn.functional.scaled_dot_product_attention',
|
||||
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
|
||||
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)
|
||||
|
||||
Reference in New Issue
Block a user