-
Notifications
You must be signed in to change notification settings - Fork 18
/
chat_with_image.py
executable file
·115 lines (90 loc) · 4.64 KB
/
chat_with_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python
# Load variables from a local .env file when python-dotenv is installed.
# override=True lets values from .env replace ones already set in the environment.
try:
    import dotenv
except ImportError:
    # python-dotenv is an optional dependency: without it, only the real
    # process environment is used.
    pass
else:
    dotenv.load_dotenv(override=True)
import os
import requests
import argparse
from datauri import DataURI
from openai import OpenAI
def file_url_to_path(url: str) -> str:
    """Strip a leading ``file://`` or ``file:`` scheme from *url* and return the rest.

    Only the prefix is removed, so a path that itself contains ``file:`` is kept
    intact. Strings without either prefix are returned unchanged.
    """
    for prefix in ('file://', 'file:'):
        if url.startswith(prefix):
            return url[len(prefix):]
    return url


def url_for_api(img_url: str = None, filename: str = None, always_data=False, *, timeout: float = 60) -> str:
    """Convert an image reference into a URL usable in an OpenAI ``image_url`` part.

    - URLs starting with ``http`` are downloaded and returned as a base64 ``data:`` URI.
    - ``file:`` URLs are read from disk and returned as a ``data:`` URI.
    - Anything else (for example an existing ``data:`` URI) is returned unchanged.

    Args:
        img_url: The image reference to convert.
        filename: Local file path, used only when ``img_url`` is omitted.
        always_data: Not used by this function; kept for interface compatibility.
        timeout: Seconds to wait for the HTTP download (keyword-only).

    Returns:
        The URL string to send to the API.

    Raises:
        ValueError: If neither ``img_url`` nor ``filename`` is given.
        requests.HTTPError: If the HTTP download returns an error status.
    """
    if img_url is None:
        if filename is None:
            raise ValueError('url_for_api() needs img_url or filename')
        return str(DataURI.from_file(filename))

    if img_url.startswith('http'):
        import mimetypes

        response = requests.get(img_url, timeout=timeout)
        # Without this check an error page (e.g. a 404 HTML body) would be
        # silently embedded and sent to the model as the "image".
        response.raise_for_status()
        content_type = (response.headers.get('content-type')
                        or mimetypes.guess_type(img_url)[0]
                        or 'application/octet-stream')
        # Keep only the media type; header parameters such as "; charset=..."
        # are not part of it.
        mimetype = content_type.split(';', 1)[0].strip()
        return str(DataURI.make(mimetype=mimetype, charset='utf-8', base64=True, data=response.content))

    if img_url.startswith('file:'):
        return str(DataURI.from_file(file_url_to_path(img_url)))

    return img_url
def parse_args(argv=None):
    """Parse the command-line options for the vision chat test.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Test vision using OpenAI',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-s', '--system-prompt', type=str, default=None, help="Set a system prompt.")
    parser.add_argument('--openai-model', type=str, default="gpt-4-vision-preview", help="OpenAI model to use.")
    parser.add_argument('-S', '--start-with', type=str, default=None, help="Start reply with, ex. 'Sure, ' (doesn't work with all models)")
    parser.add_argument('-m', '--max-tokens', type=int, default=None, help="Max tokens to generate.")
    parser.add_argument('-t', '--temperature', type=float, default=None, help="Temperature.")
    parser.add_argument('-p', '--top_p', type=float, default=None, help="top_p")
    parser.add_argument('-u', '--keep-remote-urls', action='store_true', help="Normally, http urls are converted to data: urls for better latency.")
    parser.add_argument('-1', '--single', action='store_true', help='Single turn Q&A, output is only the model response.')
    parser.add_argument('--no-stream', action='store_true', help='Disable streaming response.')
    parser.add_argument('image_url', type=str, help='URL or image file to be tested')
    parser.add_argument('questions', type=str, nargs='*', help='The question to ask the image')
    return parser.parse_args(argv)


def resolve_image_url(image_url, keep_remote_urls=False):
    """Return the URL to send to the API for a user-supplied image reference.

    - ``http`` URLs are kept as-is when *keep_remote_urls* is true; otherwise
      they are inlined via ``url_for_api``.
    - ``data:`` and ``file:`` URLs go through ``url_for_api``.
    - Anything else is read as a local file path.
    """
    if image_url.startswith('http'):
        return image_url if keep_remote_urls else url_for_api(image_url)
    if image_url.startswith(('data:', 'file:')):
        return url_for_api(image_url)
    return str(DataURI.from_file(image_url))


def make_user_message(text, image_url=None):
    """Build a user chat message: an optional image part followed by a text part."""
    content = [{"type": "image_url", "image_url": {"url": image_url}}] if image_url else []
    content.append({"type": "text", "text": text})
    return {"role": "user", "content": content}


def collect_reply(response, stream):
    """Print the model's reply as it arrives and return its full text.

    Args:
        response: A chat-completion result, or an iterable of chunks when *stream* is true.
        stream: Whether *response* is a stream of delta chunks.

    Returns:
        The concatenated reply text (``''`` if the model returned no content).
    """
    if not stream:
        text = response.choices[0].message.content or ''
        print(text)
        return text

    parts = []
    for chunk in response:
        # Some chunks carry no text: the role-only first chunk and the final chunk
        # have content None, and usage-only chunks have an empty choices list.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if not delta:
            continue
        parts.append(delta)
        print(delta, end='', flush=True)
    print('')
    return ''.join(parts)


def main():
    """Ask questions about an image and, unless --single is set, keep chatting interactively.

    A follow-up input that starts with http, data: or file: is treated as a new
    image. The next line read is then the question about that image.
    """
    args = parse_args()
    client = OpenAI(base_url=os.environ.get('OPENAI_BASE_URL', 'http://localhost:5006/v1'),
                    api_key=os.environ.get('OPENAI_API_KEY', 'sk-ip'))

    stream = not args.no_stream
    params = {'stream': stream}
    # Only forward sampling options the user actually set.
    for key in ('max_tokens', 'temperature', 'top_p'):
        value = getattr(args, key)
        if value is not None:
            params[key] = value

    messages = []
    if args.system_prompt:
        messages.append({"role": "system", "content": [{'type': 'text', 'text': args.system_prompt}]})
    messages.append(make_user_message(' '.join(args.questions),
                                      resolve_image_url(args.image_url, args.keep_remote_urls)))

    while True:
        if args.start_with:
            # Prefill the assistant turn so the model continues from this text.
            messages.append({"role": "assistant", "content": [{"type": "text", "text": args.start_with}]})

        response = client.chat.completions.create(model=args.openai_model, messages=messages, **params)
        if not args.single:
            print("Answer: ", end='', flush=True)
        assistant_text = collect_reply(response, stream)

        if args.start_with:
            # Fold the prefill and the continuation into one assistant turn, so the
            # history never has two consecutive assistant messages.
            # NOTE(review): assumes the server returns only the continuation (the
            # usual prefill convention) rather than echoing the prefix; confirm.
            messages.pop()
            assistant_text = args.start_with + assistant_text
        messages.append({"role": "assistant", "content": [{'type': 'text', 'text': assistant_text}]})

        if args.single:
            break

        image_url = None
        try:
            question = input("\nQuestion: ")
            if question.startswith(('http', 'data:', 'file:')):
                image_url = resolve_image_url(question, args.keep_remote_urls)
                question = input("Question: ")
        except EOFError:
            print('')
            break

        # A new image belongs with the user's question, not with the assistant's reply.
        messages.append(make_user_message(question, image_url))


if __name__ == '__main__':
    main()