For 2: if you want to wait and add metadata after calling the chat completion API, you'd need to log to W&B yourself rather than relying on autolog. To do this, import the `Trace` class and construct a `Trace` object with your chat data after the call completes:
```python
from wandb.sdk.data_types.trace_tree import Trace

# `status`, `status_message`, the timestamps, and the metadata values
# below are variables you compute yourself around the API call
root_span = Trace(
    name="root_span",
    kind="llm",  # kind can be "llm", "chain", "agent" or "tool"
    status_code=status,
    status_message=status_message,
    metadata={"temperature": temperature,
              "token_usage": token_usage,
              "model_name": model_name},
    start_time_ms=start_time_ms,
    end_time_ms=end_time_ms,
    inputs={"system_prompt": system_message, "query": query},
    outputs={"response": response_text},
)

# log the span to W&B
root_span.log(name="openai_trace")
```
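The snippet above assumes `status`, `status_message`, the timestamps, `token_usage`, and the prompt variables already exist; autolog would normally capture these for you. Here's a minimal, illustrative sketch of one way to gather them around the API call (the try/except pattern and the variable names are my own, not a W&B requirement):

```python
import datetime
import openai

model_name = "gpt-3.5-turbo"
temperature = 0.7
system_message = "You are a helpful assistant."
query = "Summarize W&B Trace in one sentence."

start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
try:
    response = openai.ChatCompletion.create(
        model=model_name,
        messages=[{"role": "system", "content": system_message},
                  {"role": "user", "content": query}],
        temperature=temperature,
    )
    status, status_message = "success", None
    response_text = response["choices"][0]["message"]["content"]
    token_usage = response["usage"].to_dict()
except Exception as e:
    # record the failure on the span instead of crashing
    status, status_message = "error", str(e)
    response_text, token_usage = "", {}
end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
```

With those variables in place, the `Trace` construction above works as written. Note you also need an active run (via `wandb.init`) before calling `root_span.log`.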
To solve 1: you can chain multiple `call_gpt` calls and add them as child spans of one parent span; this is how we log chains in LangChain, for instance. The result shows up in W&B as one row in the Trace Table, with each call to `call_gpt` visible as a span within the trace. Here's a more complete example:
```python
import datetime

import openai
import wandb
from wandb.sdk.data_types.trace_tree import Trace

wandb.init(project="custom-chained-trace-example")

# boilerplate for openai
model_name = "gpt-3.5-turbo"
temperature = 0.7
system_message = ("You are a helpful assistant that always parses the user's "
                  "query and replies in 3 concise bullet points using markdown.")
docs = ["This is a document about cats. Cats are furry animals that like to "
        "eat mice. Cats are also very independent animals."]

def call_gpt(model_name, temperature, system_message, query):
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query},
    ]
    response = openai.ChatCompletion.create(model=model_name,
                                            messages=messages,
                                            temperature=temperature)
    llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)
    response_text = response["choices"][0]["message"]["content"]
    token_usage = response["usage"].to_dict()
    return llm_end_time_ms, response_text, token_usage

def chunk_text(text, chunk_size=1):
    # yield successive chunk_size-character slices of the text
    for i in range(0, len(text), chunk_size):
        yield text[i:i + chunk_size]
```
```python
# logic to create a trace for each doc
for doc in docs:
    start_time_ms = round(datetime.datetime.now().timestamp() * 1000)

    # create a root span to represent the entire trace
    root_span = Trace(
        name="LLMChain",
        kind="chain",
        start_time_ms=start_time_ms)

    for doc_chunk in chunk_text(doc, chunk_size=100):
        doc_query = "Parse this doc: " + doc_chunk
        start_time_ms = round(datetime.datetime.now().timestamp() * 1000)
        llm_end_time_ms, response_text, token_usage = call_gpt(
            model_name, temperature, system_message, doc_query)

        # create a span to represent each LLM call
        llm_span = Trace(
            name="OpenAI",
            kind="llm",
            status_code="success",
            metadata={"temperature": temperature,
                      "token_usage": token_usage,
                      "model_name": model_name},
            start_time_ms=start_time_ms,
            end_time_ms=llm_end_time_ms,
            inputs={"system_prompt": system_message, "query": doc_query},
            outputs={"response": response_text},
        )
        root_span.add_child(llm_span)

    # record the overall inputs and outputs on the Chain span
    root_span.add_inputs_and_outputs(
        inputs={"query": doc},
        outputs={"response": response_text})

    # update the Chain span's end time
    root_span._span.end_time_ms = llm_end_time_ms

    # add metadata to the trace table
    accuracy = 0.7
    wandb.log({"accuracy": accuracy}, commit=False)

    # log all spans to W&B by logging the root span
    root_span.log(name="docs_trace")
```
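Once the loop completes, it's good practice to close the run so any buffered traces and metrics are flushed to W&B:

```python
# finish the run so buffered traces and metrics are uploaded
wandb.finish()
```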
Docs here: