'_IPEXLinear' object has no attribute 'use_dnnl' #618

Open
Alok-Ranjan23 opened this issue May 6, 2024 · 6 comments

Labels: Bug (Something isn't working) · CPU (CPU specific issues) · Crash (Execution crashes) · high priority · LLM

Alok-Ranjan23 commented May 6, 2024

Describe the bug

I am trying to run the `ipex.llm.optimize` API on GPT-J and get the following error:
'_IPEXLinear' object has no attribute 'use_dnnl'

The script that reproduces it:

```python
import time

import torch
from transformers import AutoModelForCausalLM

# Load GPT-J in bfloat16
model = AutoModelForCausalLM.from_pretrained(
    'EleutherAI/gpt-j-6b',
    torchscript=True,
    return_dict=False,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

vocab_size = model.config.vocab_size

step_count = 3
num_warmup = 2

# Apply the IPEX LLM optimizations
with torch.inference_mode(), torch.no_grad(), torch.cpu.amp.autocast(
    enabled=True
):
    import intel_extension_for_pytorch as ipex
    model = ipex.llm.optimize(model.eval(), dtype=torch.bfloat16, inplace=True, deployment_mode=False)

total_time = 0
batch_size = 32
seq_len = 384
# Dummy benchmark input: random token IDs of shape [batch_size, seq_len]
sample_input = torch.randint(vocab_size, size=[batch_size, seq_len])
with torch.inference_mode(), torch.no_grad(), torch.cpu.amp.autocast(
    enabled=True
):
    # Warm-up iterations (not timed)
    for i in range(int(num_warmup)):
        outputs = model(sample_input)
    # Timed iterations
    for i in range(int(step_count)):
        start_time = time.time()
        outputs = model(sample_input)
        end_time = time.time()
        total_time += (end_time - start_time) * 1e3  # ms
time_per_step = total_time / int(step_count)
if batch_size == 1:
    print("%s: Latency = %.2f ms" % ('GPT-J', time_per_step))
else:
    print(
        "%s: Throughput = %.2f QPS"
        % ('GPT-J', (batch_size * 1000) / time_per_step)
    )
```
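The timing logic in the script can be factored into a small, model-agnostic helper. This is a hypothetical sketch (the name `benchmark` and its parameters are not part of IPEX or of the original report): it keeps the script's warm-up/timed-loop structure and its throughput formula, QPS = batch_size × 1000 / (ms per step), but uses `time.perf_counter()`, a monotonic high-resolution clock, instead of `time.time()`.

```python
import time
from typing import Callable, Tuple


def benchmark(step: Callable[[], object], batch_size: int,
              num_warmup: int = 2, step_count: int = 3) -> Tuple[float, float]:
    """Return (ms per step, samples per second) for a zero-argument callable.

    Hypothetical helper mirroring the reproduction script: warm-up calls are
    not timed, then each timed call is measured and the total is averaged.
    """
    if step_count < 1:
        raise ValueError("step_count must be at least 1")
    for _ in range(num_warmup):
        step()  # warm-up, excluded from timing
    total_ms = 0.0
    for _ in range(step_count):
        start = time.perf_counter()
        step()
        total_ms += (time.perf_counter() - start) * 1e3  # seconds -> ms
    time_per_step = total_ms / step_count
    # Each step processes batch_size sequences: samples/s = batch_size * 1000 / ms-per-step.
    throughput = batch_size * 1000 / time_per_step if time_per_step > 0 else float("inf")
    return time_per_step, throughput
```

With the model from the script, it would be called as `benchmark(lambda: model(sample_input), batch_size=32)`, inside the same `with` block.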

Running the script above fails with:
`AttributeError: '_IPEXLinear' object has no attribute 'use_dnnl'`

My understanding is that `ipex.llm.optimize` converts `Linear` ops into `_IPEXLinear`, and that `_IPEXLinear` does not need a `use_dnnl` attribute.

Why does the `_IPEXLinear` op have no `use_dnnl` attribute, and how can this attribute be removed from it?
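To see which submodules are affected, one option is to scan the optimized model for `_IPEXLinear` instances that lack `use_dnnl`. The helper below is a hypothetical diagnostic, not an IPEX API. It only assumes an iterable of `(name, module)` pairs, such as the one `torch.nn.Module.named_modules()` yields, and matches modules by class name so it does not import IPEX internals.

```python
from typing import Iterable, List, Tuple


def modules_missing_attr(named_modules: Iterable[Tuple[str, object]],
                         class_name: str, attr: str) -> List[str]:
    """Return the names of modules whose class is `class_name` but which lack `attr`.

    Hypothetical diagnostic helper; `hasattr` returns False when attribute
    lookup raises AttributeError, which is how a missing attribute shows up.
    """
    return [
        name
        for name, module in named_modules
        if type(module).__name__ == class_name and not hasattr(module, attr)
    ]
```

For example, `modules_missing_attr(model.named_modules(), "_IPEXLinear", "use_dnnl")` after `ipex.llm.optimize` would list the affected submodule names, if any.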

Versions

Collecting environment information...
PyTorch version: 2.2.0+cu121
PyTorch CXX11 ABI: No
IPEX version: 2.2.0+cpu
IPEX commit: 211813b
Build type: Release

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
IGC version: N/A
CMake version: version 3.29.2
Libc version: glibc-2.35

Python version: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-97-generic-x86_64-with-glibc2.35
Is XPU available: False
DPCPP runtime version: N/A
MKL version: N/A
GPU models and configuration:

Intel OpenCL ICD version: N/A
Level Zero version: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 384
On-line CPU(s) list: 0-383
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 96
Socket(s): 2
Stepping: 1
BogoMIPS: 4799.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d
L1d cache: 6 MiB (192 instances)
L1i cache: 6 MiB (192 instances)
L2 cache: 192 MiB (192 instances)
L3 cache: 768 MiB (24 instances)
NUMA node(s): 8
NUMA node0 CPU(s): 0-23,192-215
NUMA node1 CPU(s): 24-47,216-239
NUMA node2 CPU(s): 48-71,240-263
NUMA node3 CPU(s): 72-95,264-287
NUMA node4 CPU(s): 96-119,288-311
NUMA node5 CPU(s): 120-143,312-335
NUMA node6 CPU(s): 144-167,336-359
NUMA node7 CPU(s): 168-191,360-383
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.2.0
[pip3] numpy==1.26.4
[pip3] torch==2.2.0
[conda] intel-extension-for-pytorch 2.2.0 pypi_0 pypi
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.2.0 pypi_0 pypi

feng-intel commented:

I can reproduce this issue. I will ask the dev team for help and follow up later.
Our 2.3.0 release is coming soon. The script works once `torch.inference_mode(),` is removed from the `with` statements.

jgong5 (Contributor) commented May 8, 2024

cc @jianan-gu

jgong5 added the `Bug` and `high priority` labels May 8, 2024

jgong5 (Contributor) commented May 8, 2024

Mark it as high priority due to the crash.

jingxu10 added the `CPU`, `Crash`, and `LLM` labels May 9, 2024
ZailiWang assigned and then unassigned feng-intel May 10, 2024
feng-intel commented:

We have found the root cause. Until it is fixed, users can generate the sample input as follows:

```python
# Continues the reproduction script above (torch, model, and vocab_size are already defined)
batch_size = 32
seq_len = 384
input_ids = torch.randint(vocab_size, size=[batch_size, seq_len])
att_mask = torch.ones_like(input_ids)
sample_input = model.prepare_inputs_for_generation(input_ids, attention_mask=att_mask)
```
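For readers without PyTorch at hand, here is a plain-Python sketch of the two tensors the workaround builds before calling `prepare_inputs_for_generation`. The function name `make_dummy_batch` is hypothetical, and the sketch does not reproduce what `prepare_inputs_for_generation` itself returns; it only shows the shapes and values of `input_ids` and `att_mask`.

```python
import random
from typing import List, Tuple

Matrix = List[List[int]]


def make_dummy_batch(vocab_size: int, batch_size: int, seq_len: int,
                     seed: int = 0) -> Tuple[Matrix, Matrix]:
    """Build (input_ids, attention_mask) as nested lists of shape [batch_size, seq_len]."""
    rng = random.Random(seed)
    # Like torch.randint(vocab_size, size=...): ids drawn uniformly from [0, vocab_size).
    input_ids = [[rng.randrange(vocab_size) for _ in range(seq_len)]
                 for _ in range(batch_size)]
    # Like torch.ones_like(input_ids): a mask value of 1 marks every position as a real token.
    attention_mask = [[1] * seq_len for _ in range(batch_size)]
    return input_ids, attention_mask
```

The key difference from the original benchmark is that the model now also receives an attention mask of the same shape as the token IDs, rather than a bare tensor of IDs.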

Alok-Ranjan23 (Author) commented:

That's fine. But why does this optimization not work with a dummy benchmark?

feng-intel commented:

It will be fixed in release 2.3.1.
