For 2, to add metadata after calling the chat completion API, you'd need to log to W&B yourself rather than using autolog. To do this, import the `Trace` class and create the object with your chat data after you make the call:
```python
from wandb.sdk.data_types.trace_tree import Trace

root_span = Trace(
    name="root_span",
    kind="llm",  # kind can be "llm", "chain", "agent" or "tool"
    status_code=status,             # e.g. "success", or "error" on failure
    status_message=status_message,  # e.g. the error message, if any
    metadata={
        "temperature": temperature,
        "token_usage": token_usage,
        "model_name": model_name,
    },
    start_time_ms=start_time_ms,
    end_time_ms=end_time_ms,
    inputs={"system_prompt": system_message, "query": query},
    outputs={"response": response_text},
)

# log the span to wandb
root_span.log(name="openai_trace")
```
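The snippet above assumes you've already captured `status`, `status_message`, the timestamps, and the response fields around your API call. As a rough sketch of one way to do that (the `timestamp_ms` helper and the error handling are my own, not part of the W&B or OpenAI APIs; `model_name`, `system_message`, `query`, and `temperature` are assumed to be defined):

```python
import datetime

import openai

def timestamp_ms():
    # same millisecond-timestamp pattern used in the full example below
    return round(datetime.datetime.now().timestamp() * 1000)

start_time_ms = timestamp_ms()
try:
    response = openai.ChatCompletion.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": query},
        ],
        temperature=temperature,
    )
    response_text = response["choices"][0]["message"]["content"]
    token_usage = response["usage"].to_dict()
    status, status_message = "success", None
except Exception as e:
    response_text, token_usage = "", {}
    status, status_message = "error", str(e)
end_time_ms = timestamp_ms()
```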
Here's a more complete example.

To solve 1: you can chain multiple `call_gpt` calls and add them as child spans of one parent span. This is how we log chains in LangChain, for instance. It shows up in W&B as one row in the Trace Table, with each call to `call_gpt` visible as a span within the trace.
```python
import datetime

import openai
import wandb
from wandb.sdk.data_types.trace_tree import Trace

wandb.init(project="custom-chained-trace-example")

# boilerplate for openai
model_name = "gpt-3.5-turbo"
temperature = 0.7
system_message = "You are a helpful assistant that always parses the user's query and replies in 3 concise bullet points using markdown."
docs = ["This is a document about cats. Cats are furry animals that like to eat mice. Cats are also very independent animals."]

def call_gpt(model_name, temperature, system_message, query):
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query},
    ]
    response = openai.ChatCompletion.create(
        model=model_name, messages=messages, temperature=temperature
    )
    llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
    response_text = response["choices"][0]["message"]["content"]
    token_usage = response["usage"].to_dict()
    return llm_end_time_ms, response_text, token_usage

def chunk_text(text, chunk_size=1):
    for i in range(0, len(text), chunk_size):
        yield text[i : i + chunk_size]

# logic to create a trace for each doc
for doc in docs:
    start_time_ms = round(datetime.datetime.now().timestamp() * 1000)

    # create a root span to represent the entire trace
    root_span = Trace(
        name="LLMChain",
        kind="chain",
        start_time_ms=start_time_ms,
    )

    for doc_chunk in chunk_text(doc, chunk_size=100):
        doc_query = "Parse this doc: " + doc_chunk
        start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
        llm_end_time_ms, response_text, token_usage = call_gpt(
            model_name, temperature, system_message, doc_query
        )

        # create a span to represent each LLM call
        llm_span = Trace(
            name="OpenAI",
            kind="llm",
            status_code="success",
            metadata={
                "temperature": temperature,
                "token_usage": token_usage,
                "model_name": model_name,
            },
            start_time_ms=start_time_ms,
            end_time_ms=llm_end_time_ms,
            inputs={"system_prompt": system_message, "query": doc_query},
            outputs={"response": response_text},
        )
        root_span.add_child(llm_span)

    # add the inputs and outputs of the whole chain to the root span
    root_span.add_inputs_and_outputs(
        inputs={"query": doc}, outputs={"response": response_text}
    )

    # update the Chain span's end time
    root_span._span.end_time_ms = llm_end_time_ms

    # add metadata to the trace table
    accuracy = 0.7
    wandb.log({"accuracy": accuracy}, commit=False)

    # log all spans to W&B by logging the root span
    root_span.log(name="docs_trace")
```
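Since `kind` can also be `"agent"` or `"tool"`, spans can be nested to match however your pipeline is structured, using the same `add_child` pattern. For instance, inside the loop above you could attach a tool span alongside the LLM spans (a minimal sketch; the retrieval step, its inputs, and its timings are hypothetical):

```python
# hypothetical tool step, e.g. a retrieval call made before the LLM call
tool_start_ms = round(datetime.datetime.now().timestamp() * 1000)
retrieved_docs = ["..."]  # placeholder result of the made-up retrieval step
tool_end_ms = round(datetime.datetime.now().timestamp() * 1000)

tool_span = Trace(
    name="Retriever",
    kind="tool",
    status_code="success",
    start_time_ms=tool_start_ms,
    end_time_ms=tool_end_ms,
    inputs={"query": doc},
    outputs={"documents": retrieved_docs},
)

# nest it under the same root span, next to the LLM spans
root_span.add_child(tool_span)
```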
Docs here: