use color_eyre::Result;
use futures::StreamExt;
use genai::{
    Client,
    ModelIden,
    chat::{ChatMessage, ChatRequest, ChatStreamEvent, StreamChunk},
    resolver::{AuthData, AuthResolver},
};
use tracing::info;

use crate::commands;
// Represents an LLM completion source.
|
|
// FIXME: Clone is probably temporary.
|
|
#[derive(Clone, Debug)]
|
|
pub struct LLMHandle {
|
|
chat_request: ChatRequest,
|
|
client: Client,
|
|
cmd_root: Option<commands::Root>,
|
|
model: String,
|
|
}
|
|
|
|
impl LLMHandle {
|
|
pub fn new(
|
|
api_key: String,
|
|
_base_url: impl AsRef<str>,
|
|
cmd_root: Option<commands::Root>,
|
|
model: impl Into<String>,
|
|
system_role: String,
|
|
) -> Result<LLMHandle> {
|
|
let auth_resolver = AuthResolver::from_resolver_fn(
|
|
|_model_iden: ModelIden| -> Result<Option<AuthData>, genai::resolver::Error> {
|
|
// let ModelIden { adapter_kind, model_name } = model_iden;
|
|
|
|
Ok(Some(AuthData::from_single(api_key)))
|
|
},
|
|
);
|
|
|
|
let client = Client::builder().with_auth_resolver(auth_resolver).build();
|
|
let chat_request = ChatRequest::default().with_system(system_role);
|
|
|
|
info!("New LLMHandle created.");
|
|
|
|
Ok(LLMHandle {
|
|
client,
|
|
chat_request,
|
|
cmd_root,
|
|
model: model.into(),
|
|
})
|
|
}
|
|
|
|
pub async fn send_request(&mut self, message: impl Into<String>) -> Result<String> {
|
|
let mut req = self.chat_request.clone();
|
|
let client = self.client.clone();
|
|
|
|
req = req.append_message(ChatMessage::user(message.into()));
|
|
let response = client
|
|
.exec_chat_stream(&self.model, req.clone(), None)
|
|
.await?;
|
|
let mut stream = response.stream;
|
|
let mut text = String::new();
|
|
|
|
while let Some(Ok(stream_event)) = stream.next().await {
|
|
if let ChatStreamEvent::Chunk(StreamChunk { content }) = stream_event {
|
|
text.push_str(&content);
|
|
}
|
|
}
|
|
|
|
Ok(text)
|
|
}
|
|
}
|