mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-08-03 07:26:31 +08:00)
Update the optimized_attention_for_device function for the new functions that support masked attention.
@@ -333,7 +333,6 @@ def attention_pytorch(q, k, v, heads, mask=None):
 
 optimized_attention = attention_basic
-optimized_attention_masked = attention_basic
 
 if model_management.xformers_enabled():
     print("Using xformers cross attention")
     optimized_attention = attention_xformers
@@ -349,15 +348,15 @@ else:
         print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
 
-if model_management.pytorch_attention_enabled():
-    optimized_attention_masked = attention_pytorch
+optimized_attention_masked = optimized_attention
 
-def optimized_attention_for_device(device, mask=False):
-    if device == torch.device("cpu"): #TODO
-        if model_management.pytorch_attention_enabled():
-            return attention_pytorch
-        else:
-            return attention_basic
+def optimized_attention_for_device(device, mask=False, small_input=False):
+    if small_input and model_management.pytorch_attention_enabled():
+        return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
+
+    if device == torch.device("cpu"):
+        return attention_sub_quad
+
     if mask:
         return optimized_attention_masked
 
     return optimized_attention
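
For context, here is a minimal usage sketch of the updated dispatcher. The import path, tensor shapes, and mask shape are illustrative assumptions, not part of this commit; the (q, k, v, heads, mask) calling convention comes from the attention_pytorch signature shown in the first hunk header.

    # Hypothetical usage sketch; shapes and module path are assumptions.
    import torch
    from comfy.ldm.modules.attention import optimized_attention_for_device

    device = torch.device("cpu")
    heads = 8
    q = torch.randn(2, 77, 512)  # assumed (batch, tokens, heads * dim_head)
    k = torch.randn(2, 77, 512)
    v = torch.randn(2, 77, 512)
    mask = torch.ones(77, 77, dtype=torch.bool)  # assumed broadcastable mask

    # mask=True requests a backend that supports the mask argument.
    attn = optimized_attention_for_device(device, mask=True, small_input=False)
    out = attn(q, k, v, heads, mask=mask)
    print(out.shape)  # expected: torch.Size([2, 77, 512])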
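For readers who want to see what such a masked backend computes, below is a short, self-contained sketch of masked scaled-dot-product attention in the same (q, k, v, heads, mask) convention. It is standard PyTorch (F.scaled_dot_product_attention), not code from this repository, and the head-splitting layout is an assumption based on the signature above.

    # Standalone sketch of masked attention; not ComfyUI's implementation.
    import torch
    import torch.nn.functional as F

    def masked_attention_sketch(q, k, v, heads, mask=None):
        b, n, _ = q.shape
        dim_head = q.shape[-1] // heads
        # Split heads: (b, n, heads * dim_head) -> (b, heads, n, dim_head)
        q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
        # attn_mask may be boolean (True = attend) or an additive float bias.
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        # Merge heads back: (b, heads, n, dim_head) -> (b, n, heads * dim_head)
        return out.transpose(1, 2).reshape(b, n, heads * dim_head)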