6 changes: 6 additions & 0 deletions QEfficient/__init__.py
@@ -48,6 +48,10 @@ def check_qaic_sdk():
QEFFCommonLoader,
)
from QEfficient.compile.compile_helper import compile

# Imports for the diffusers
from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEFFFluxPipeline
from QEfficient.diffusers.pipelines.wan.pipeline_wan import QEFFWanPipeline
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.peft import QEffAutoPeftModelForCausalLM
@@ -67,6 +71,8 @@ def check_qaic_sdk():
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
"QEFFFluxPipeline",
"QEFFWanPipeline"
]

else:
8 changes: 5 additions & 3 deletions QEfficient/base/modeling_qeff.py
@@ -22,7 +22,7 @@
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import constants, create_json, dump_qconfig, generate_mdp_partition_config, load_json
from QEfficient.utils import constants, create_json, generate_mdp_partition_config, load_json # dump_qconfig #TODO: debug and enable
from QEfficient.utils.cache import QEFF_HOME, to_hashable

logger = logging.getLogger(__name__)
@@ -179,7 +179,8 @@ def _export(
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
opset_version=17,
# verbose=True,
**export_kwargs,
)
logger.info("Pytorch export successful")
@@ -213,7 +214,7 @@ def _export(
self.onnx_path = onnx_path
return onnx_path

@dump_qconfig
# @dump_qconfig
def _compile(
self,
onnx_path: Optional[str] = None,
@@ -352,6 +353,7 @@ def _compile(

command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
print(command)
try:
subprocess.run(command, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
110 changes: 110 additions & 0 deletions QEfficient/diffusers/README.md
@@ -0,0 +1,110 @@

<div align="center">


# **Diffusion Models on Qualcomm Cloud AI 100**


<div align="center">

### 🎨 **Experience the Future of AI Image Generation**

*Optimized for Qualcomm Cloud AI 100*

<img src="../../docs/image/girl_laughing.png" alt="Sample Output" width="400">

**Generated with**: `stabilityai/stable-diffusion-3.5-large` • `"A girl laughing"` • 28 steps • 2.0 guidance scale • ⚡



</div>



[![Diffusers](https://img.shields.io/badge/Diffusers-0.31.0-orange.svg)](https://github.com/huggingface/diffusers)
</div>

---

## ✨ Overview

QEfficient Diffusers brings state-of-the-art diffusion models to Qualcomm Cloud AI 100 hardware for text-to-image generation. Built on top of the popular HuggingFace Diffusers library, our optimized pipelines run inference directly on Cloud AI 100 accelerators.

## 🛠️ Installation

### Prerequisites

Ensure you have Python 3.8+ and the required dependencies:

```bash
# Create Python virtual environment (Recommended Python 3.10)
sudo apt install python3.10-venv
python3.10 -m venv qeff_env
source qeff_env/bin/activate
pip install -U pip
```

### Install QEfficient

```bash
# Install from GitHub (includes diffusers support)
pip install git+https://github.com/quic/efficient-transformers

# Or build from source
git clone https://github.com/quic/efficient-transformers.git
cd efficient-transformers
pip install build wheel
python -m build --wheel --outdir dist
pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl
```

### Install Diffusers Dependencies

```bash
# Install diffusers optional dependencies
pip install "QEfficient[diffusers]"
```

---

## 🎯 Supported Models

### Stable Diffusion 3.x Series
- ✅ [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large)
- ✅ [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo)
---


## 📚 Examples

Check out our comprehensive examples in the [`examples/diffusers/`](../../examples/diffusers/) directory; the sketch below only illustrates the overall usage shape.
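
For a quick sense of that shape, here is a minimal sketch. It assumes the QEfficient pipeline wrappers (such as `QEFFFluxPipeline`, which this change exports from `QEfficient/__init__.py`) follow the standard HuggingFace Diffusers `from_pretrained` and call convention; the checkpoint id, keyword arguments, and output handling below are illustrative assumptions, not taken from the shipped examples.

```python
# Minimal sketch (assumed API): QEfficient pipelines are expected to mirror the
# upstream HuggingFace Diffusers interface; see examples/diffusers/ for the
# authoritative usage.
from QEfficient import QEFFFluxPipeline  # exported by QEfficient/__init__.py in this change

# Illustrative checkpoint id; the "Supported Models" list above names the
# Stable Diffusion 3.5 checkpoints this README currently covers.
pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# Prompt, step count, and guidance scale echo the banner sample above; the
# keyword names follow the Diffusers convention and are assumptions here.
output = pipeline(
    "A girl laughing",
    num_inference_steps=28,
    guidance_scale=2.0,
)
output.images[0].save("girl_laughing.png")
```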

---

## 🤝 Contributing

We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details.

### Development Setup

```bash
git clone https://github.com/quic/efficient-transformers.git
cd efficient-transformers
pip install -e ".[diffusers,test]"
```

---

## 🙏 Acknowledgments

- **HuggingFace Diffusers**: For the excellent foundation library
- **Stability AI**: For the amazing Stable Diffusion models
---

## 📞 Support

- 📖 **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/)
- 🐛 **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues)

---

Empty file.
Empty file.
75 changes: 75 additions & 0 deletions QEfficient/diffusers/models/attention.py
@@ -0,0 +1,75 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

import torch
from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward


class QEffJointTransformerBlock(JointTransformerBlock):
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
):
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb
)
else:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

if self.context_pre_only:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
else:
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)

# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
)

# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output

if self.use_dual_attention:
attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
hidden_states = hidden_states + attn_output2

norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
# ff_output = self.ff(norm_hidden_states)
ff_output = self.ff(norm_hidden_states, block_size=4096)
ff_output = gate_mlp.unsqueeze(1) * ff_output

hidden_states = hidden_states + ff_output

# Process attention outputs for the `encoder_hidden_states`.
if self.context_pre_only:
encoder_hidden_states = None
else:
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output

norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
context_ff_output = _chunked_feed_forward(
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
)
else:
# context_ff_output = self.ff_context(norm_encoder_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states, block_size=333)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

return encoder_hidden_states, hidden_states
155 changes: 155 additions & 0 deletions QEfficient/diffusers/models/attention_processor.py
@@ -0,0 +1,155 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from typing import Optional

import torch
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0


class QEffAttention(Attention):
def __qeff_init__(self):
processor = QEffJointAttnProcessor2_0()
self.processor = processor
processor.query_block_size = 64

def get_attention_scores(
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
dtype = query.dtype
if self.upcast_attention:
query = query.float()
key = key.float()

if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device
)
beta = 0
else:
baddbmm_input = attention_mask
beta = 1

attention_scores = torch.baddbmm(
baddbmm_input,
query,
key,
beta=beta,
alpha=self.scale,
)
del baddbmm_input

if self.upcast_softmax:
attention_scores = attention_scores.float()

attention_probs = attention_scores.softmax(dim=-1)
del attention_scores

attention_probs = attention_probs.to(dtype)

return attention_probs


class QEffJointAttnProcessor2_0(JointAttnProcessor2_0):
def __call__(
self,
attn: QEffAttention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states

batch_size = hidden_states.shape[0]

# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# `context` projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)

if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

query = query.reshape(-1, query.shape[-2], query.shape[-1])
key = key.reshape(-1, key.shape[-2], key.shape[-1])
value = value.reshape(-1, value.shape[-2], value.shape[-1])

# pre-transpose the key
key = key.transpose(-1, -2)
if query.size(-2) != value.size(-2): # cross-attention, use regular attention
# QKV done in single block
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
else: # self-attention, use blocked attention
# QKV done with block-attention (a la FlashAttentionV2)
query_block_size = self.query_block_size
query_seq_len = query.size(-2)
num_blocks = (query_seq_len + query_block_size - 1) // query_block_size
for qidx in range(num_blocks):
query_block = query[:, qidx * query_block_size : (qidx + 1) * query_block_size, :]
attention_probs = attn.get_attention_scores(query_block, key, attention_mask)
hidden_states_block = torch.bmm(attention_probs, value)
if qidx == 0:
hidden_states = hidden_states_block
else:
hidden_states = torch.cat((hidden_states, hidden_states_block), -2)
hidden_states = attn.batch_to_head_dim(hidden_states)

if encoder_hidden_states is not None:
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states